Plot confusion matrix and simplify plot selection
This commit is contained in:
		
							parent
							
								
									bed9f1d0e9
								
							
						
					
					
						commit
						4af2c35c2e
					
				@ -7,7 +7,7 @@ from sklearn.neighbors import KNeighborsClassifier
 | 
				
			|||||||
from sklearn.preprocessing import scale
 | 
					from sklearn.preprocessing import scale
 | 
				
			||||||
from sklearn.svm import LinearSVC
 | 
					from sklearn.svm import LinearSVC
 | 
				
			||||||
from sklearn.tree import DecisionTreeClassifier
 | 
					from sklearn.tree import DecisionTreeClassifier
 | 
				
			||||||
from seaborn import set_theme
 | 
					from seaborn import set_theme, set_style, heatmap, FacetGrid
 | 
				
			||||||
from matplotlib.pyplot import *
 | 
					from matplotlib.pyplot import *
 | 
				
			||||||
from pandas import DataFrame
 | 
					from pandas import DataFrame
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -53,41 +53,41 @@ def predict_data(data, target, model, results):
 | 
				
			|||||||
    return populated_results
 | 
					    return populated_results
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def plot_roc_auc_curve(model, results):
 | 
					def plot_roc_auc_curve(results):
 | 
				
			||||||
    rounded_auc = round(results.loc[model]["auc"], 3)
 | 
					    fig = figure(figsize=(8, 6))
 | 
				
			||||||
    plot(
 | 
					    for model in results.index:
 | 
				
			||||||
        results.loc[model]["fpr"],
 | 
					        rounded_auc = round(results.loc[model]["auc"], 3)
 | 
				
			||||||
        results.loc[model]["tpr"],
 | 
					        plot(
 | 
				
			||||||
        label=f"{model} , AUC={rounded_auc}",
 | 
					            results.loc[model]["fpr"],
 | 
				
			||||||
    )
 | 
					            results.loc[model]["tpr"],
 | 
				
			||||||
 | 
					            label=f"{model} , AUC={rounded_auc}",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
    xticks(arange(0.0, 1.0, step=0.1))
 | 
					    xticks(arange(0.0, 1.0, step=0.1))
 | 
				
			||||||
    yticks(arange(0.0, 1.0, step=0.1))
 | 
					    yticks(arange(0.0, 1.0, step=0.1))
 | 
				
			||||||
    legend(loc="lower right")
 | 
					    legend(loc="lower right")
 | 
				
			||||||
 | 
					    fig_title = "ROC AUC curve"
 | 
				
			||||||
 | 
					 | 
				
			||||||
def plot_confusion_matrix(model, results):
 | 
					 | 
				
			||||||
    matrix = results.loc[model]["confusion_matrix"]
 | 
					 | 
				
			||||||
    classes = ["Negative", "Positive"]
 | 
					 | 
				
			||||||
    for item in matrix:
 | 
					 | 
				
			||||||
        text(x=0.5, y=0.5, s=item)
 | 
					 | 
				
			||||||
    xticks(ticks=arange(len(classes)), labels=classes)
 | 
					 | 
				
			||||||
    yticks(ticks=arange(len(classes)), labels=classes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def choose_plot_type(type, model, results):
 | 
					 | 
				
			||||||
    if type == "roc":
 | 
					 | 
				
			||||||
        plot_roc_auc_curve(model, results)
 | 
					 | 
				
			||||||
    elif type == "confusion_matrix":
 | 
					 | 
				
			||||||
        plot_confusion_matrix(model, results)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
 | 
					 | 
				
			||||||
    fig = figure(figsize=(8, 6))
 | 
					 | 
				
			||||||
    for model in results.index:
 | 
					 | 
				
			||||||
        choose_plot_type(type, model, results)
 | 
					 | 
				
			||||||
    xlabel(x_axis)
 | 
					 | 
				
			||||||
    ylabel(y_axis)
 | 
					 | 
				
			||||||
    title(fig_title)
 | 
					    title(fig_title)
 | 
				
			||||||
 | 
					    xlabel("False positive rate")
 | 
				
			||||||
 | 
					    ylabel("True positive rate")
 | 
				
			||||||
 | 
					    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def plot_confusion_matrix(results):
 | 
				
			||||||
 | 
					    set_style("white")
 | 
				
			||||||
 | 
					    matrix = results.filter(items=["model", "confusion_matrix"])
 | 
				
			||||||
 | 
					    fig, axes = subplots(nrows=1, ncols=5, figsize=(8, 6))
 | 
				
			||||||
 | 
					    for i in range(len(axes)):
 | 
				
			||||||
 | 
					        heatmap(
 | 
				
			||||||
 | 
					            ax=axes[i],
 | 
				
			||||||
 | 
					            data=matrix.iloc[i]["confusion_matrix"],
 | 
				
			||||||
 | 
					            cmap="Blues",
 | 
				
			||||||
 | 
					            square=True,
 | 
				
			||||||
 | 
					            annot=True,
 | 
				
			||||||
 | 
					            cbar=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        axes[i].set_title(matrix.index[i])
 | 
				
			||||||
 | 
					    fig_title = "Confusion Matrix"
 | 
				
			||||||
 | 
					    suptitle(fig_title)
 | 
				
			||||||
    show()
 | 
					    show()
 | 
				
			||||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
					    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -95,20 +95,8 @@ def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
 | 
				
			|||||||
# TODO Add cross_val_score
 | 
					# TODO Add cross_val_score
 | 
				
			||||||
def plot_all_figures(results):
 | 
					def plot_all_figures(results):
 | 
				
			||||||
    set_theme()
 | 
					    set_theme()
 | 
				
			||||||
    plot_individual_figure(
 | 
					    plot_roc_auc_curve(results=results)
 | 
				
			||||||
        results,
 | 
					    plot_confusion_matrix(results=results)
 | 
				
			||||||
        type="roc",
 | 
					 | 
				
			||||||
        x_axis="False positive rate",
 | 
					 | 
				
			||||||
        y_axis="True positive rate",
 | 
					 | 
				
			||||||
        fig_title="ROC AUC curve",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    plot_individual_figure(
 | 
					 | 
				
			||||||
        results,
 | 
					 | 
				
			||||||
        type="confusion_matrix",
 | 
					 | 
				
			||||||
        x_axis="Predicted values",
 | 
					 | 
				
			||||||
        y_axis="Real values",
 | 
					 | 
				
			||||||
        fig_title="Confusion Matrix",
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def create_result_dataframes():
 | 
					def create_result_dataframes():
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user