Remove imputation of values from part 2
This commit is contained in:
		
							parent
							
								
									8b2ce6b5c9
								
							
						
					
					
						commit
						59895f4b8a
					
				@ -89,7 +89,6 @@ def plot_confusion_matrix(results):
 | 
				
			|||||||
        axes[i].set_title(matrix.index[i])
 | 
					        axes[i].set_title(matrix.index[i])
 | 
				
			||||||
    fig_title = "Confusion Matrix"
 | 
					    fig_title = "Confusion Matrix"
 | 
				
			||||||
    suptitle(fig_title)
 | 
					    suptitle(fig_title)
 | 
				
			||||||
    show()
 | 
					 | 
				
			||||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
					    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -106,7 +105,6 @@ def plot_attributes_correlation(data, target):
 | 
				
			|||||||
        axes[i].set_title(transformed_data.columns[i])
 | 
					        axes[i].set_title(transformed_data.columns[i])
 | 
				
			||||||
    fig_title = "Attribute's correlation"
 | 
					    fig_title = "Attribute's correlation"
 | 
				
			||||||
    suptitle(fig_title)
 | 
					    suptitle(fig_title)
 | 
				
			||||||
    show()
 | 
					 | 
				
			||||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
					    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -8,18 +8,6 @@ def replace_values(df):
 | 
				
			|||||||
    return df
 | 
					    return df
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def process_na(df, action):
 | 
					 | 
				
			||||||
    if action == "drop":
 | 
					 | 
				
			||||||
        return df.dropna()
 | 
					 | 
				
			||||||
    elif action == "fill":
 | 
					 | 
				
			||||||
        return replace_values(df)
 | 
					 | 
				
			||||||
    else:
 | 
					 | 
				
			||||||
        print("Unknown action selected. The choices are: ")
 | 
					 | 
				
			||||||
        print("fill: fills the na values with the mean")
 | 
					 | 
				
			||||||
        print("drop: drops the na values")
 | 
					 | 
				
			||||||
        exit()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def filter_dataframe(df):
 | 
					def filter_dataframe(df):
 | 
				
			||||||
    relevant_columns = [
 | 
					    relevant_columns = [
 | 
				
			||||||
        "TOT_HERIDOS_LEVES",
 | 
					        "TOT_HERIDOS_LEVES",
 | 
				
			||||||
@ -39,8 +27,8 @@ def normalize_numerical_values(df):
 | 
				
			|||||||
    return df
 | 
					    return df
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def parse_data(source, action):
 | 
					def parse_data(source):
 | 
				
			||||||
    df = read_csv(filepath_or_buffer=source, na_values="?")
 | 
					    df = read_csv(filepath_or_buffer=source, na_values="?")
 | 
				
			||||||
    processed_df = process_na(df=df, action=action)
 | 
					    processed_df = df.dropna()
 | 
				
			||||||
    normalized_df = normalize_numerical_values(df=processed_df)
 | 
					    normalized_df = normalize_numerical_values(df=processed_df)
 | 
				
			||||||
    return normalized_df
 | 
					    return normalized_df
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,6 @@ from sys import argv
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from matplotlib.pyplot import *
 | 
					from matplotlib.pyplot import *
 | 
				
			||||||
from pandas import DataFrame
 | 
					from pandas import DataFrame
 | 
				
			||||||
from seaborn import clustermap, set_style, set_theme, pairplot
 | 
					 | 
				
			||||||
from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
					from sklearn.metrics import silhouette_score, calinski_harabasz_score
 | 
				
			||||||
from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
 | 
					from sklearn.cluster import KMeans, Birch, SpectralClustering, MeanShift, DBSCAN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -49,47 +48,6 @@ def predict_data(data, model, cluster_number, results):
 | 
				
			|||||||
    return populated_results
 | 
					    return populated_results
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def plot_heatmap(results):
 | 
					 | 
				
			||||||
    fig = figure(figsize=(20, 10))
 | 
					 | 
				
			||||||
    results.reset_index()
 | 
					 | 
				
			||||||
    matrix = results["prediction"]
 | 
					 | 
				
			||||||
    print(matrix)
 | 
					 | 
				
			||||||
    clustermap(
 | 
					 | 
				
			||||||
        data=matrix,
 | 
					 | 
				
			||||||
        cmap="mako",
 | 
					 | 
				
			||||||
        metric="euclidean",
 | 
					 | 
				
			||||||
        annot=True,
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    fig_title = "Heatmap"
 | 
					 | 
				
			||||||
    title(fig_title)
 | 
					 | 
				
			||||||
    show()
 | 
					 | 
				
			||||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def plot_scatter_plot(results):
 | 
					 | 
				
			||||||
    fig = figure(figsize=(20, 10))
 | 
					 | 
				
			||||||
    matrix = results.filter(items=["input", "prediction"])
 | 
					 | 
				
			||||||
    pairplot(
 | 
					 | 
				
			||||||
        data=results,
 | 
					 | 
				
			||||||
        vars=matrix,
 | 
					 | 
				
			||||||
        hue="prediction",
 | 
					 | 
				
			||||||
        palette="Paired",
 | 
					 | 
				
			||||||
        diag_kind="hist",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    fig_title = "Scatter plot"
 | 
					 | 
				
			||||||
    title(fig_title)
 | 
					 | 
				
			||||||
    show()
 | 
					 | 
				
			||||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def show_results(results):
 | 
					 | 
				
			||||||
    set_theme()
 | 
					 | 
				
			||||||
    set_style("white")
 | 
					 | 
				
			||||||
    plot_heatmap(results=results)
 | 
					 | 
				
			||||||
    plot_scatter_plot(results=results)
 | 
					 | 
				
			||||||
    print(results)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def create_result_dataframes():
 | 
					def create_result_dataframes():
 | 
				
			||||||
    results = DataFrame(
 | 
					    results = DataFrame(
 | 
				
			||||||
        columns=[
 | 
					        columns=[
 | 
				
			||||||
@ -153,10 +111,10 @@ def usage():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
    models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
 | 
					    models = ["kmeans", "birch", "spectral", "meanshift", "dbscan"]
 | 
				
			||||||
    if len(argv) != 4:
 | 
					    if len(argv) != 3:
 | 
				
			||||||
        usage()
 | 
					        usage()
 | 
				
			||||||
    case, cluster_number = argv[2], int(argv[3])
 | 
					    case, cluster_number = argv[1], int(argv[2])
 | 
				
			||||||
    data = parse_data(source="data/accidentes_2013.csv", action=str(argv[1]))
 | 
					    data = parse_data(source="data/accidentes_2013.csv")
 | 
				
			||||||
    individual_result, complete_results = create_result_dataframes()
 | 
					    individual_result, complete_results = create_result_dataframes()
 | 
				
			||||||
    case_data = construct_case(df=data, choice=case)
 | 
					    case_data = construct_case(df=data, choice=case)
 | 
				
			||||||
    filtered_data = filter_dataframe(df=case_data)
 | 
					    filtered_data = filter_dataframe(df=case_data)
 | 
				
			||||||
@ -171,7 +129,7 @@ def main():
 | 
				
			|||||||
            individual_result.append(model_results)
 | 
					            individual_result.append(model_results)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    complete_results.set_index("model")
 | 
					    complete_results.set_index("model")
 | 
				
			||||||
    show_results(results=complete_results)
 | 
					    print(complete_results)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user