Replace AffinityPropagation with SpectralClustering
This commit is contained in:
		
							parent
							
								
									7e356fed37
								
							
						
					
					
						commit
						b4e90c1174
					
				@ -5,7 +5,7 @@ from matplotlib.pyplot import *
 | 
			
		||||
from pandas import DataFrame
 | 
			
		||||
from seaborn import heatmap, set_style, set_theme, pairplot
 | 
			
		||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
			
		||||
from sklearn.cluster import KMeans, Birch, AffinityPropagation, MeanShift, DBSCAN
 | 
			
		||||
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
 | 
			
		||||
 | 
			
		||||
from preprocessing import parse_data, filter_dataframe
 | 
			
		||||
 | 
			
		||||
@ -15,8 +15,8 @@ def choose_model(model):
 | 
			
		||||
        return KMeans(random_state=42)
 | 
			
		||||
    elif model == "birch":
 | 
			
		||||
        return Birch()
 | 
			
		||||
    elif model == "affinity":
 | 
			
		||||
        return AffinityPropagation(random_state=42)
 | 
			
		||||
    elif model == "spectral":
 | 
			
		||||
        return SpectralClustering()
 | 
			
		||||
    elif model == "meanshift":
 | 
			
		||||
        return MeanShift()
 | 
			
		||||
    elif model == "dbscan":
 | 
			
		||||
@ -29,6 +29,7 @@ def predict_data(data, model, results, sample):
 | 
			
		||||
    start_time = time.time()
 | 
			
		||||
    prediction = model.fit_predict(data)
 | 
			
		||||
    execution_time = time.time() - start_time
 | 
			
		||||
    cluster_number = len(set(prediction))
 | 
			
		||||
    calinski = calinski_harabasz_score(X=data, labels=prediction)
 | 
			
		||||
    silhouette = silhouette_score(
 | 
			
		||||
        X=data,
 | 
			
		||||
@ -41,7 +42,7 @@ def predict_data(data, model, results, sample):
 | 
			
		||||
        df=results,
 | 
			
		||||
        model=model_name,
 | 
			
		||||
        prediction=prediction,
 | 
			
		||||
        clusters=len(set(prediction)),
 | 
			
		||||
        clusters=cluster_number,
 | 
			
		||||
        calinski=calinski,
 | 
			
		||||
        silhouette=silhouette,
 | 
			
		||||
        time=execution_time,
 | 
			
		||||
@ -79,17 +80,12 @@ def plot_scatter_plot(results):
 | 
			
		||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def print_dataframe(df):
 | 
			
		||||
    df.set_index("model")
 | 
			
		||||
    print(df)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def show_results(results):
 | 
			
		||||
    set_theme()
 | 
			
		||||
    set_style("white")
 | 
			
		||||
    plot_heatmap(results=results)
 | 
			
		||||
    plot_scatter_plot(results=results)
 | 
			
		||||
    print_dataframe(df=results)
 | 
			
		||||
    print(results)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_result_dataframes():
 | 
			
		||||
@ -103,8 +99,7 @@ def create_result_dataframes():
 | 
			
		||||
            "time",
 | 
			
		||||
        ]
 | 
			
		||||
    )
 | 
			
		||||
    indexed_results = results.set_index("model")
 | 
			
		||||
    return indexed_results, indexed_results
 | 
			
		||||
    return results, results
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def populate_results(df, model, clusters, prediction, calinski, silhouette, time):
 | 
			
		||||
@ -153,7 +148,7 @@ def usage():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    models = ["kmeans", "birch", "affinity", "meanshift", "dbscan"]
 | 
			
		||||
    models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
 | 
			
		||||
    if len(argv) != 4:
 | 
			
		||||
        usage()
 | 
			
		||||
    case, sample = argv[2], int(argv[3])
 | 
			
		||||
@ -172,7 +167,7 @@ def main():
 | 
			
		||||
            individual_result.append(model_results)
 | 
			
		||||
        )
 | 
			
		||||
    complete_results.set_index("model")
 | 
			
		||||
    print_dataframe(df=complete_results)
 | 
			
		||||
    print(complete_results)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user