Plot confusion matrix and simplify plot selection
This commit is contained in:
		
							parent
							
								
									bed9f1d0e9
								
							
						
					
					
						commit
						4af2c35c2e
					
				@ -7,7 +7,7 @@ from sklearn.neighbors import KNeighborsClassifier
 | 
			
		||||
from sklearn.preprocessing import scale
 | 
			
		||||
from sklearn.svm import LinearSVC
 | 
			
		||||
from sklearn.tree import DecisionTreeClassifier
 | 
			
		||||
from seaborn import set_theme
 | 
			
		||||
from seaborn import set_theme, set_style, heatmap, FacetGrid
 | 
			
		||||
from matplotlib.pyplot import *
 | 
			
		||||
from pandas import DataFrame
 | 
			
		||||
 | 
			
		||||
@ -53,41 +53,41 @@ def predict_data(data, target, model, results):
 | 
			
		||||
    return populated_results
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_roc_auc_curve(model, results):
 | 
			
		||||
    rounded_auc = round(results.loc[model]["auc"], 3)
 | 
			
		||||
    plot(
 | 
			
		||||
        results.loc[model]["fpr"],
 | 
			
		||||
        results.loc[model]["tpr"],
 | 
			
		||||
        label=f"{model} , AUC={rounded_auc}",
 | 
			
		||||
    )
 | 
			
		||||
def plot_roc_auc_curve(results):
 | 
			
		||||
    fig = figure(figsize=(8, 6))
 | 
			
		||||
    for model in results.index:
 | 
			
		||||
        rounded_auc = round(results.loc[model]["auc"], 3)
 | 
			
		||||
        plot(
 | 
			
		||||
            results.loc[model]["fpr"],
 | 
			
		||||
            results.loc[model]["tpr"],
 | 
			
		||||
            label=f"{model} , AUC={rounded_auc}",
 | 
			
		||||
        )
 | 
			
		||||
    xticks(arange(0.0, 1.0, step=0.1))
 | 
			
		||||
    yticks(arange(0.0, 1.0, step=0.1))
 | 
			
		||||
    legend(loc="lower right")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_confusion_matrix(model, results):
 | 
			
		||||
    matrix = results.loc[model]["confusion_matrix"]
 | 
			
		||||
    classes = ["Negative", "Positive"]
 | 
			
		||||
    for item in matrix:
 | 
			
		||||
        text(x=0.5, y=0.5, s=item)
 | 
			
		||||
    xticks(ticks=arange(len(classes)), labels=classes)
 | 
			
		||||
    yticks(ticks=arange(len(classes)), labels=classes)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def choose_plot_type(type, model, results):
 | 
			
		||||
    if type == "roc":
 | 
			
		||||
        plot_roc_auc_curve(model, results)
 | 
			
		||||
    elif type == "confusion_matrix":
 | 
			
		||||
        plot_confusion_matrix(model, results)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
 | 
			
		||||
    fig = figure(figsize=(8, 6))
 | 
			
		||||
    for model in results.index:
 | 
			
		||||
        choose_plot_type(type, model, results)
 | 
			
		||||
    xlabel(x_axis)
 | 
			
		||||
    ylabel(y_axis)
 | 
			
		||||
    fig_title = "ROC AUC curve"
 | 
			
		||||
    title(fig_title)
 | 
			
		||||
    xlabel("False positive rate")
 | 
			
		||||
    ylabel("True positive rate")
 | 
			
		||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def plot_confusion_matrix(results):
 | 
			
		||||
    set_style("white")
 | 
			
		||||
    matrix = results.filter(items=["model", "confusion_matrix"])
 | 
			
		||||
    fig, axes = subplots(nrows=1, ncols=5, figsize=(8, 6))
 | 
			
		||||
    for i in range(len(axes)):
 | 
			
		||||
        heatmap(
 | 
			
		||||
            ax=axes[i],
 | 
			
		||||
            data=matrix.iloc[i]["confusion_matrix"],
 | 
			
		||||
            cmap="Blues",
 | 
			
		||||
            square=True,
 | 
			
		||||
            annot=True,
 | 
			
		||||
            cbar=False,
 | 
			
		||||
        )
 | 
			
		||||
        axes[i].set_title(matrix.index[i])
 | 
			
		||||
    fig_title = "Confusion Matrix"
 | 
			
		||||
    suptitle(fig_title)
 | 
			
		||||
    show()
 | 
			
		||||
    fig.savefig(f"docs/assets/{fig_title.replace(' ', '_').lower()}.png")
 | 
			
		||||
 | 
			
		||||
@ -95,20 +95,8 @@ def plot_individual_figure(results, type, x_axis, y_axis, fig_title):
 | 
			
		||||
# TODO Add cross_val_score
 | 
			
		||||
def plot_all_figures(results):
 | 
			
		||||
    set_theme()
 | 
			
		||||
    plot_individual_figure(
 | 
			
		||||
        results,
 | 
			
		||||
        type="roc",
 | 
			
		||||
        x_axis="False positive rate",
 | 
			
		||||
        y_axis="True positive rate",
 | 
			
		||||
        fig_title="ROC AUC curve",
 | 
			
		||||
    )
 | 
			
		||||
    plot_individual_figure(
 | 
			
		||||
        results,
 | 
			
		||||
        type="confusion_matrix",
 | 
			
		||||
        x_axis="Predicted values",
 | 
			
		||||
        y_axis="Real values",
 | 
			
		||||
        fig_title="Confusion Matrix",
 | 
			
		||||
    )
 | 
			
		||||
    plot_roc_auc_curve(results=results)
 | 
			
		||||
    plot_confusion_matrix(results=results)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_result_dataframes():
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user