"""
Build, train and evaluate a small 1D CNN on the datasets produced by
`preprocessing.dataset_creation`.
"""
from random import seed

from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
    Conv1D,
    Dense,
    Dropout,
    GlobalAveragePooling1D,
    Input,
    MaxPool1D,
)
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed

from hyperparameters import Hyperparameters
from preprocessing import BASES, dataset_creation


def build_model(hyperparams: Hyperparameters) -> Model:
    """
    Build and compile the CNN model.

    The network is two `Conv1D` + `MaxPool1D` stages, a global average
    pooling, two `Dense` + `Dropout` stages and a softmax output with one
    unit per entry of `BASES`.

    Args:
        hyperparams: object providing `l2_rate` (L2 penalty applied to every
            convolutional and hidden dense kernel) and `learning_rate`
            (passed to the Adam optimizer).

    Returns:
        The compiled Keras model (categorical cross-entropy loss, accuracy
        metric).
    """
    model = Sequential(
        [
            # `None` leaves the sequence-length axis unspecified; the last
            # axis has one channel per entry of BASES.
            Input(shape=(None, len(BASES))),
            Conv1D(
                filters=16,
                kernel_size=5,
                activation="relu",
                kernel_regularizer=l2(hyperparams.l2_rate),
            ),
            MaxPool1D(pool_size=3, strides=1),
            Conv1D(
                filters=16,
                kernel_size=3,
                activation="relu",
                kernel_regularizer=l2(hyperparams.l2_rate),
            ),
            MaxPool1D(pool_size=3, strides=1),
            # Averages over the sequence axis so the dense head receives a
            # fixed-size vector whatever the input length.
            GlobalAveragePooling1D(),
            Dense(
                units=16,
                activation="relu",
                kernel_regularizer=l2(hyperparams.l2_rate),
            ),
            Dropout(rate=0.3),
            Dense(
                units=16,
                activation="relu",
                kernel_regularizer=l2(hyperparams.l2_rate),
            ),
            Dropout(rate=0.3),
            Dense(units=len(BASES), activation="softmax"),
        ]
    )
    model.compile(
        optimizer=Adam(hyperparams.learning_rate),
        loss=categorical_crossentropy,
        metrics=["accuracy"],
    )
    return model


def show_metrics(model, eval_dataset, test_dataset) -> None:
    """
    Evaluate the model on both datasets and print loss and accuracy.

    Assumes `model.evaluate` returns a sequence whose first two items are the
    loss and the accuracy, which is the case for a model compiled the way
    `build_model` compiles it.

    Args:
        model: a compiled model exposing `evaluate(dataset, verbose=0)`.
        eval_dataset: dataset reported as "eval" metrics.
        test_dataset: dataset reported as "test" metrics.
    """
    eval_metrics = model.evaluate(eval_dataset, verbose=0)
    test_metrics = model.evaluate(test_dataset, verbose=0)
    print(f"Final eval metrics - loss: {eval_metrics[0]} - accuracy: {eval_metrics[1]}")
    print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")


def run(data_file, label_file, seed_value: int = 42) -> None:
    """
    Create the datasets and the model, then train and evaluate it.

    Args:
        data_file: forwarded to `Hyperparameters` as `data_file`.
        label_file: forwarded to `Hyperparameters` as `label_file`.
        seed_value: seed given to Python's `random` module and to TensorFlow.
    """
    # Seed both RNGs before any dataset or weight creation.
    # NOTE(review): NumPy's global RNG is not seeded here; confirm whether
    # `dataset_creation` relies on it if full reproducibility is required.
    seed(seed_value)
    set_seed(seed_value)

    hyperparams = Hyperparameters(data_file=data_file, label_file=label_file)
    train_data, eval_data, test_data = dataset_creation(hyperparams)
    model = build_model(hyperparams)

    print("Training the model")
    model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)

    print("Training complete. Obtaining the model's metrics...")
    show_metrics(model, eval_data, test_data)


if __name__ == "__main__":
    run(data_file="data/curesim-HVR.fastq", label_file="data/HVR.fastq")