Replace AffinityPropagation with SpectralClustering
This commit is contained in:
		
							parent
							
								
									7e356fed37
								
							
						
					
					
						commit
						b4e90c1174
					
				@ -5,7 +5,7 @@ from matplotlib.pyplot import *
 | 
				
			|||||||
from pandas import DataFrame
 | 
					from pandas import DataFrame
 | 
				
			||||||
from seaborn import heatmap, set_style, set_theme, pairplot
 | 
					from seaborn import heatmap, set_style, set_theme, pairplot
 | 
				
			||||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
					from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
				
			||||||
from sklearn.cluster import KMeans, Birch, AffinityPropagation, MeanShift, DBSCAN
 | 
					from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from preprocessing import parse_data, filter_dataframe
 | 
					from preprocessing import parse_data, filter_dataframe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -15,8 +15,8 @@ def choose_model(model):
 | 
				
			|||||||
        return KMeans(random_state=42)
 | 
					        return KMeans(random_state=42)
 | 
				
			||||||
    elif model == "birch":
 | 
					    elif model == "birch":
 | 
				
			||||||
        return Birch()
 | 
					        return Birch()
 | 
				
			||||||
    elif model == "affinity":
 | 
					    elif model == "spectral":
 | 
				
			||||||
        return AffinityPropagation(random_state=42)
 | 
					        return SpectralClustering()
 | 
				
			||||||
    elif model == "meanshift":
 | 
					    elif model == "meanshift":
 | 
				
			||||||
        return MeanShift()
 | 
					        return MeanShift()
 | 
				
			||||||
    elif model == "dbscan":
 | 
					    elif model == "dbscan":
 | 
				
			||||||
@ -29,6 +29,7 @@ def predict_data(data, model, results, sample):
 | 
				
			|||||||
    start_time = time.time()
 | 
					    start_time = time.time()
 | 
				
			||||||
    prediction = model.fit_predict(data)
 | 
					    prediction = model.fit_predict(data)
 | 
				
			||||||
    execution_time = time.time() - start_time
 | 
					    execution_time = time.time() - start_time
 | 
				
			||||||
 | 
					    cluster_number = len(set(prediction))
 | 
				
			||||||
    calinski = calinski_harabasz_score(X=data, labels=prediction)
 | 
					    calinski = calinski_harabasz_score(X=data, labels=prediction)
 | 
				
			||||||
    silhouette = silhouette_score(
 | 
					    silhouette = silhouette_score(
 | 
				
			||||||
        X=data,
 | 
					        X=data,
 | 
				
			||||||
@ -41,7 +42,7 @@ def predict_data(data, model, results, sample):
 | 
				
			|||||||
        df=results,
 | 
					        df=results,
 | 
				
			||||||
        model=model_name,
 | 
					        model=model_name,
 | 
				
			||||||
        prediction=prediction,
 | 
					        prediction=prediction,
 | 
				
			||||||
        clusters=len(set(prediction)),
 | 
					        clusters=cluster_number,
 | 
				
			||||||
        calinski=calinski,
 | 
					        calinski=calinski,
 | 
				
			||||||
        silhouette=silhouette,
 | 
					        silhouette=silhouette,
 | 
				
			||||||
        time=execution_time,
 | 
					        time=execution_time,
 | 
				
			||||||
@ -79,17 +80,12 @@ def plot_scatter_plot(results):
 | 
				
			|||||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
					    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def print_dataframe(df):
 | 
					 | 
				
			||||||
    df.set_index("model")
 | 
					 | 
				
			||||||
    print(df)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def show_results(results):
 | 
					def show_results(results):
 | 
				
			||||||
    set_theme()
 | 
					    set_theme()
 | 
				
			||||||
    set_style("white")
 | 
					    set_style("white")
 | 
				
			||||||
    plot_heatmap(results=results)
 | 
					    plot_heatmap(results=results)
 | 
				
			||||||
    plot_scatter_plot(results=results)
 | 
					    plot_scatter_plot(results=results)
 | 
				
			||||||
    print_dataframe(df=results)
 | 
					    print(results)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_result_dataframes():
 | 
					def create_result_dataframes():
 | 
				
			||||||
@ -103,8 +99,7 @@ def create_result_dataframes():
 | 
				
			|||||||
            "time",
 | 
					            "time",
 | 
				
			||||||
        ]
 | 
					        ]
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    indexed_results = results.set_index("model")
 | 
					    return results, results
 | 
				
			||||||
    return indexed_results, indexed_results
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def populate_results(df, model, clusters, prediction, calinski, silhouette, time):
 | 
					def populate_results(df, model, clusters, prediction, calinski, silhouette, time):
 | 
				
			||||||
@ -153,7 +148,7 @@ def usage():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    models = ["kmeans", "birch", "affinity", "meanshift", "dbscan"]
 | 
					    models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
 | 
				
			||||||
    if len(argv) != 4:
 | 
					    if len(argv) != 4:
 | 
				
			||||||
        usage()
 | 
					        usage()
 | 
				
			||||||
    case, sample = argv[2], int(argv[3])
 | 
					    case, sample = argv[2], int(argv[3])
 | 
				
			||||||
@ -172,7 +167,7 @@ def main():
 | 
				
			|||||||
            individual_result.append(model_results)
 | 
					            individual_result.append(model_results)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    complete_results.set_index("model")
 | 
					    complete_results.set_index("model")
 | 
				
			||||||
    print_dataframe(df=complete_results)
 | 
					    print(complete_results)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user