Replace the samples argument with cluster number
This commit is contained in:
		
							parent
							
								
									b4e90c1174
								
							
						
					
					
						commit
						e63406c0a8
					
				@ -3,18 +3,18 @@ from sys import argv
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from matplotlib.pyplot import *
 | 
					from matplotlib.pyplot import *
 | 
				
			||||||
from pandas import DataFrame
 | 
					from pandas import DataFrame
 | 
				
			||||||
from seaborn import heatmap, set_style, set_theme, pairplot
 | 
					from seaborn import clustermap, set_style, set_theme, pairplot
 | 
				
			||||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
					from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
				
			||||||
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
 | 
					from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from preprocessing import parse_data, filter_dataframe
 | 
					from preprocessing import parse_data, filter_dataframe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def choose_model(model):
 | 
					def choose_model(model, cluster_number):
 | 
				
			||||||
    if model == "kmeans":
 | 
					    if model == "kmeans":
 | 
				
			||||||
        return KMeans(random_state=42)
 | 
					        return KMeans(n_clusters=cluster_number, random_state=42)
 | 
				
			||||||
    elif model == "birch":
 | 
					    elif model == "birch":
 | 
				
			||||||
        return Birch()
 | 
					        return Birch(n_clusters=cluster_number)
 | 
				
			||||||
    elif model == "spectral":
 | 
					    elif model == "spectral":
 | 
				
			||||||
        return SpectralClustering()
 | 
					        return SpectralClustering()
 | 
				
			||||||
    elif model == "meanshift":
 | 
					    elif model == "meanshift":
 | 
				
			||||||
@ -23,9 +23,9 @@ def choose_model(model):
 | 
				
			|||||||
        return DBSCAN()
 | 
					        return DBSCAN()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def predict_data(data, model, results, sample):
 | 
					def predict_data(data, model, cluster_number, results):
 | 
				
			||||||
    model_name = model
 | 
					    model_name = model
 | 
				
			||||||
    model = choose_model(model)
 | 
					    model = choose_model(model=model, cluster_number=cluster_number)
 | 
				
			||||||
    start_time = time.time()
 | 
					    start_time = time.time()
 | 
				
			||||||
    prediction = model.fit_predict(data)
 | 
					    prediction = model.fit_predict(data)
 | 
				
			||||||
    execution_time = time.time() - start_time
 | 
					    execution_time = time.time() - start_time
 | 
				
			||||||
@ -35,7 +35,6 @@ def predict_data(data, model, results, sample):
 | 
				
			|||||||
        X=data,
 | 
					        X=data,
 | 
				
			||||||
        labels=prediction,
 | 
					        labels=prediction,
 | 
				
			||||||
        metric="euclidean",
 | 
					        metric="euclidean",
 | 
				
			||||||
        sample_size=sample,
 | 
					 | 
				
			||||||
        random_state=42,
 | 
					        random_state=42,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    populated_results = populate_results(
 | 
					    populated_results = populate_results(
 | 
				
			||||||
@ -52,10 +51,13 @@ def predict_data(data, model, results, sample):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def plot_heatmap(results):
 | 
					def plot_heatmap(results):
 | 
				
			||||||
    fig = figure(figsize=(20, 10))
 | 
					    fig = figure(figsize=(20, 10))
 | 
				
			||||||
    heatmap(
 | 
					    results.reset_index()
 | 
				
			||||||
        data=results,
 | 
					    matrix = results["prediction"]
 | 
				
			||||||
        cmap="Blues",
 | 
					    print(matrix)
 | 
				
			||||||
        square=True,
 | 
					    clustermap(
 | 
				
			||||||
 | 
					        data=matrix,
 | 
				
			||||||
 | 
					        cmap="mako",
 | 
				
			||||||
 | 
					        metric="euclidean",
 | 
				
			||||||
        annot=True,
 | 
					        annot=True,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    fig_title = "Heatmap"
 | 
					    fig_title = "Heatmap"
 | 
				
			||||||
@ -66,10 +68,10 @@ def plot_heatmap(results):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def plot_scatter_plot(results):
 | 
					def plot_scatter_plot(results):
 | 
				
			||||||
    fig = figure(figsize=(20, 10))
 | 
					    fig = figure(figsize=(20, 10))
 | 
				
			||||||
    original_data = results.drop("prediction")
 | 
					    matrix = results.filter(items=["input", "prediction"])
 | 
				
			||||||
    pairplot(
 | 
					    pairplot(
 | 
				
			||||||
        data=results,
 | 
					        data=results,
 | 
				
			||||||
        vars=original_data,
 | 
					        vars=matrix,
 | 
				
			||||||
        hue="prediction",
 | 
					        hue="prediction",
 | 
				
			||||||
        palette="Paired",
 | 
					        palette="Paired",
 | 
				
			||||||
        diag_kind="hist",
 | 
					        diag_kind="hist",
 | 
				
			||||||
@ -138,12 +140,14 @@ def construct_case(df, choice):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def usage():
 | 
					def usage():
 | 
				
			||||||
    print("Usage: " + argv[0] + "<preprocessing action> <case> <sample size>")
 | 
					    print("Usage: " + argv[0] + "<preprocessing action> <case> <number of clusters>")
 | 
				
			||||||
    print("preprocessing actions:")
 | 
					    print("preprocessing actions:")
 | 
				
			||||||
    print("fill: fills the na values with the mean")
 | 
					    print("fill: fills the na values with the mean")
 | 
				
			||||||
    print("drop: drops the na values")
 | 
					    print("drop: drops the na values")
 | 
				
			||||||
    print("cases: choice of case study")
 | 
					    print("cases: choice of case study")
 | 
				
			||||||
    print("sample size: size of the sample when computing the Silhouette Coefficient")
 | 
					    print(
 | 
				
			||||||
 | 
					        "number of clusters: number of clusters for the algorithms that use a fixed number"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
    exit()
 | 
					    exit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -151,7 +155,7 @@ def main():
 | 
				
			|||||||
    models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
 | 
					    models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
 | 
				
			||||||
    if len(argv) != 4:
 | 
					    if len(argv) != 4:
 | 
				
			||||||
        usage()
 | 
					        usage()
 | 
				
			||||||
    case, sample = argv[2], int(argv[3])
 | 
					    case, cluster_number = argv[2], int(argv[3])
 | 
				
			||||||
    data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
 | 
					    data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
 | 
				
			||||||
    individual_result, complete_results = create_result_dataframes()
 | 
					    individual_result, complete_results = create_result_dataframes()
 | 
				
			||||||
    case_data = construct_case(df=data, choice=case)
 | 
					    case_data = construct_case(df=data, choice=case)
 | 
				
			||||||
@ -161,13 +165,13 @@ def main():
 | 
				
			|||||||
            data=filtered_data,
 | 
					            data=filtered_data,
 | 
				
			||||||
            model=model,
 | 
					            model=model,
 | 
				
			||||||
            results=individual_result,
 | 
					            results=individual_result,
 | 
				
			||||||
            sample=sample,
 | 
					            cluster_number=cluster_number,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        complete_results = complete_results.append(
 | 
					        complete_results = complete_results.append(
 | 
				
			||||||
            individual_result.append(model_results)
 | 
					            individual_result.append(model_results)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    complete_results.set_index("model")
 | 
					    complete_results.set_index("model")
 | 
				
			||||||
    print(complete_results)
 | 
					    show_results(results=complete_results)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user