Plot attribute correlation
This commit is contained in:
		
							parent
							
								
									4af2c35c2e
								
							
						
					
					
						commit
						ef043650b5
					
				@ -1,15 +1,14 @@
 | 
				
			|||||||
from numpy import mean, arange
 | 
					from numpy import mean, arange
 | 
				
			||||||
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
 | 
					from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
 | 
				
			||||||
from sklearn.model_selection import cross_val_predict
 | 
					 | 
				
			||||||
from sklearn.naive_bayes import GaussianNB
 | 
					from sklearn.naive_bayes import GaussianNB
 | 
				
			||||||
from sklearn.neural_network import MLPClassifier
 | 
					from sklearn.neural_network import MLPClassifier
 | 
				
			||||||
from sklearn.neighbors import KNeighborsClassifier
 | 
					from sklearn.neighbors import KNeighborsClassifier
 | 
				
			||||||
from sklearn.preprocessing import scale
 | 
					from sklearn.preprocessing import scale
 | 
				
			||||||
from sklearn.svm import LinearSVC
 | 
					from sklearn.svm import LinearSVC
 | 
				
			||||||
from sklearn.tree import DecisionTreeClassifier
 | 
					from sklearn.tree import DecisionTreeClassifier
 | 
				
			||||||
from seaborn import set_theme, set_style, heatmap, FacetGrid
 | 
					from seaborn import set_theme, set_style, heatmap, countplot
 | 
				
			||||||
from matplotlib.pyplot import *
 | 
					from matplotlib.pyplot import *
 | 
				
			||||||
from pandas import DataFrame
 | 
					from pandas import DataFrame, cut
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from sys import argv
 | 
					from sys import argv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -92,11 +91,28 @@ def plot_confusion_matrix(results):
 | 
				
			|||||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
					    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO Add cross_val_score
 | 
					def plot_attributes_correlation(data, target):
 | 
				
			||||||
def plot_all_figures(results):
 | 
					    transformed_data = transform_dataframe(data, target)
 | 
				
			||||||
 | 
					    fig, axes = subplots(nrows=5, ncols=1, figsize=(8, 6))
 | 
				
			||||||
 | 
					    for i in range(len(axes)):
 | 
				
			||||||
 | 
					        countplot(
 | 
				
			||||||
 | 
					            ax=axes[i],
 | 
				
			||||||
 | 
					            x=transformed_data.columns[i],
 | 
				
			||||||
 | 
					            data=transformed_data,
 | 
				
			||||||
 | 
					            hue="Severity",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        axes[i].set_title(transformed_data.columns[i])
 | 
				
			||||||
 | 
					    fig_title = "Attribute's correlation"
 | 
				
			||||||
 | 
					    suptitle(fig_title)
 | 
				
			||||||
 | 
					    show()
 | 
				
			||||||
 | 
					    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def plot_all_figures(results, data, target):
 | 
				
			||||||
    set_theme()
 | 
					    set_theme()
 | 
				
			||||||
    plot_roc_auc_curve(results=results)
 | 
					    # plot_roc_auc_curve(results=results)
 | 
				
			||||||
    plot_confusion_matrix(results=results)
 | 
					    # plot_confusion_matrix(results=results)
 | 
				
			||||||
 | 
					    plot_attributes_correlation(data=data, target=target)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_result_dataframes():
 | 
					def create_result_dataframes():
 | 
				
			||||||
@ -127,6 +143,13 @@ def rename_model(model):
 | 
				
			|||||||
    return mapping[model]
 | 
					    return mapping[model]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def transform_dataframe(data, target):
 | 
				
			||||||
 | 
					    joined_df = data.join(target)
 | 
				
			||||||
 | 
					    binned_df = joined_df.copy()
 | 
				
			||||||
 | 
					    binned_df["Age"] = cut(x=joined_df["Age"], bins=[15, 30, 45, 60, 75])
 | 
				
			||||||
 | 
					    return binned_df
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def usage():
 | 
					def usage():
 | 
				
			||||||
    print("Usage: " + argv[0] + "<preprocessing action>")
 | 
					    print("Usage: " + argv[0] + "<preprocessing action>")
 | 
				
			||||||
    print("preprocessing actions:")
 | 
					    print("preprocessing actions:")
 | 
				
			||||||
@ -149,7 +172,7 @@ def main():
 | 
				
			|||||||
            individual_result.append(model_results)
 | 
					            individual_result.append(model_results)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    indexed_results = complete_results.set_index("model")
 | 
					    indexed_results = complete_results.set_index("model")
 | 
				
			||||||
    plot_all_figures(results=indexed_results)
 | 
					    plot_all_figures(results=indexed_results, data=data, target=target)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user