Replace the sample-size argument with the number of clusters
This commit is contained in:
		
							parent
							
								
									b4e90c1174
								
							
						
					
					
						commit
						e63406c0a8
					
				@ -3,18 +3,18 @@ from sys import argv
 | 
			
		||||
 | 
			
		||||
from matplotlib.pyplot import *
 | 
			
		||||
from pandas import DataFrame
 | 
			
		||||
from seaborn import heatmap, set_style, set_theme, pairplot
 | 
			
		||||
from seaborn import clustermap, set_style, set_theme, pairplot
 | 
			
		||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
			
		||||
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
 | 
			
		||||
 | 
			
		||||
from preprocessing import parse_data, filter_dataframe
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def choose_model(model):
 | 
			
		||||
def choose_model(model, cluster_number):
 | 
			
		||||
    if model == "kmeans":
 | 
			
		||||
        return KMeans(random_state=42)
 | 
			
		||||
        return KMeans(n_clusters=cluster_number, random_state=42)
 | 
			
		||||
    elif model == "birch":
 | 
			
		||||
        return Birch()
 | 
			
		||||
        return Birch(n_clusters=cluster_number)
 | 
			
		||||
    elif model == "spectral":
 | 
			
		||||
        return SpectralClustering()
 | 
			
		||||
    elif model == "meanshift":
 | 
			
		||||
@ -23,9 +23,9 @@ def choose_model(model):
 | 
			
		||||
        return DBSCAN()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict_data(data, model, results, sample):
 | 
			
		||||
def predict_data(data, model, cluster_number, results):
 | 
			
		||||
    model_name = model
 | 
			
		||||
    model = choose_model(model)
 | 
			
		||||
    model = choose_model(model=model, cluster_number=cluster_number)
 | 
			
		||||
    start_time = time.time()
 | 
			
		||||
    prediction = model.fit_predict(data)
 | 
			
		||||
    execution_time = time.time() - start_time
 | 
			
		||||
@ -35,7 +35,6 @@ def predict_data(data, model, results, sample):
 | 
			
		||||
        X=data,
 | 
			
		||||
        labels=prediction,
 | 
			
		||||
        metric="euclidean",
 | 
			
		||||
        sample_size=sample,
 | 
			
		||||
        random_state=42,
 | 
			
		||||
    )
 | 
			
		||||
    populated_results = populate_results(
 | 
			
		||||
@ -52,10 +51,13 @@ def predict_data(data, model, results, sample):
 | 
			
		||||
 | 
			
		||||
def plot_heatmap(results):
 | 
			
		||||
    fig = figure(figsize=(20, 10))
 | 
			
		||||
    heatmap(
 | 
			
		||||
        data=results,
 | 
			
		||||
        cmap="Blues",
 | 
			
		||||
        square=True,
 | 
			
		||||
    results.reset_index()
 | 
			
		||||
    matrix = results["prediction"]
 | 
			
		||||
    print(matrix)
 | 
			
		||||
    clustermap(
 | 
			
		||||
        data=matrix,
 | 
			
		||||
        cmap="mako",
 | 
			
		||||
        metric="euclidean",
 | 
			
		||||
        annot=True,
 | 
			
		||||
    )
 | 
			
		||||
    fig_title = "Heatmap"
 | 
			
		||||
@ -66,10 +68,10 @@ def plot_heatmap(results):
 | 
			
		||||
 | 
			
		||||
def plot_scatter_plot(results):
 | 
			
		||||
    fig = figure(figsize=(20, 10))
 | 
			
		||||
    original_data = results.drop("prediction")
 | 
			
		||||
    matrix = results.filter(items=["input", "prediction"])
 | 
			
		||||
    pairplot(
 | 
			
		||||
        data=results,
 | 
			
		||||
        vars=original_data,
 | 
			
		||||
        vars=matrix,
 | 
			
		||||
        hue="prediction",
 | 
			
		||||
        palette="Paired",
 | 
			
		||||
        diag_kind="hist",
 | 
			
		||||
@ -138,12 +140,14 @@ def construct_case(df, choice):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def usage():
 | 
			
		||||
    print("Usage: " + argv[0] + "<preprocessing action> <case> <sample size>")
 | 
			
		||||
    print("Usage: " + argv[0] + "<preprocessing action> <case> <number of clusters>")
 | 
			
		||||
    print("preprocessing actions:")
 | 
			
		||||
    print("fill: fills the na values with the mean")
 | 
			
		||||
    print("drop: drops the na values")
 | 
			
		||||
    print("cases: choice of case study")
 | 
			
		||||
    print("sample size: size of the sample when computing the Silhouette Coefficient")
 | 
			
		||||
    print(
 | 
			
		||||
        "number of clusters: number of clusters for the algorithms that use a fixed number"
 | 
			
		||||
    )
 | 
			
		||||
    exit()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -151,7 +155,7 @@ def main():
 | 
			
		||||
    models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
 | 
			
		||||
    if len(argv) != 4:
 | 
			
		||||
        usage()
 | 
			
		||||
    case, sample = argv[2], int(argv[3])
 | 
			
		||||
    case, cluster_number = argv[2], int(argv[3])
 | 
			
		||||
    data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
 | 
			
		||||
    individual_result, complete_results = create_result_dataframes()
 | 
			
		||||
    case_data = construct_case(df=data, choice=case)
 | 
			
		||||
@ -161,13 +165,13 @@ def main():
 | 
			
		||||
            data=filtered_data,
 | 
			
		||||
            model=model,
 | 
			
		||||
            results=individual_result,
 | 
			
		||||
            sample=sample,
 | 
			
		||||
            cluster_number=cluster_number,
 | 
			
		||||
        )
 | 
			
		||||
        complete_results = complete_results.append(
 | 
			
		||||
            individual_result.append(model_results)
 | 
			
		||||
        )
 | 
			
		||||
    complete_results.set_index("model")
 | 
			
		||||
    print(complete_results)
 | 
			
		||||
    show_results(results=complete_results)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user