# locimend/src/model.py
#
# 82 lines
# 2.6 KiB
# Python

from random import seed
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import *
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed
from hyperparameters import Hyperparameters
from preprocessing import BASES, dataset_creation
def build_model(hyperparams) -> Model:
    """
    Assemble and compile the 1D convolutional network.

    The stack is: an input accepting sequences of any length with
    ``len(BASES)`` features per position, two convolution + max-pooling
    stages (16 filters each, kernel sizes 5 then 3), a global average
    pooling layer, two 16-unit dense layers each followed by dropout (0.3),
    and a softmax output with ``len(BASES)`` units.

    Every convolutional and hidden dense kernel is L2-regularized with
    ``hyperparams.l2_rate``. The model is compiled with Adam at
    ``hyperparams.learning_rate``, categorical cross-entropy loss and the
    accuracy metric.

    :param hyperparams: object exposing ``l2_rate`` and ``learning_rate``
    :return: the compiled Keras model
    """
    n_bases = len(BASES)

    def weight_penalty():
        # Each layer gets its own regularizer instance.
        return l2(hyperparams.l2_rate)

    def conv_stage(kernel_size):
        # Convolution followed by an overlapping (stride 1) max-pooling window.
        return [
            Conv1D(
                filters=16,
                kernel_size=kernel_size,
                activation="relu",
                kernel_regularizer=weight_penalty(),
            ),
            MaxPool1D(pool_size=3, strides=1),
        ]

    def hidden_stage():
        # Regularized dense layer followed by dropout.
        return [
            Dense(units=16, activation="relu", kernel_regularizer=weight_penalty()),
            Dropout(rate=0.3),
        ]

    # Layers are instantiated in exactly the order they appear in the network.
    network = Sequential(
        [
            Input(shape=(None, n_bases)),
            *conv_stage(kernel_size=5),
            *conv_stage(kernel_size=3),
            GlobalAveragePooling1D(),
            *hidden_stage(),
            *hidden_stage(),
            Dense(units=n_bases, activation="softmax"),
        ]
    )

    optimizer = Adam(hyperparams.learning_rate)
    network.compile(
        optimizer=optimizer,
        loss=categorical_crossentropy,
        metrics=["accuracy"],
    )
    return network
def show_metrics(model, eval_dataset, test_dataset) -> None:
    """
    Evaluate the model on both datasets and print loss and accuracy.

    Both evaluations are run (silently, ``verbose=0``) before anything is
    printed. The first two entries of each ``model.evaluate`` result are
    reported as loss and accuracy respectively.

    :param model: object providing ``evaluate(dataset, verbose=...)``
    :param eval_dataset: dataset passed to the first evaluation
    :param test_dataset: dataset passed to the second evaluation
    """
    eval_scores = model.evaluate(eval_dataset, verbose=0)
    test_scores = model.evaluate(test_dataset, verbose=0)

    eval_loss, eval_accuracy = eval_scores[0], eval_scores[1]
    print(f"Final eval metrics - loss: {eval_loss} - accuracy: {eval_accuracy}")

    test_loss, test_accuracy = test_scores[0], test_scores[1]
    print(f"Final test metrics - loss: {test_loss} - accuracy: {test_accuracy}")
def run(data_file, label_file, seed_value=42) -> None:
    """
    Train the model end to end and report its final metrics.

    Seeds Python's ``random`` module and TensorFlow's global seed with
    ``seed_value``, builds the hyperparameters from the two file paths,
    creates the train/eval/test datasets, builds the model, fits it for
    ``hyperparams.epochs`` epochs (validating on the eval dataset) and
    finally prints the eval and test metrics.

    :param data_file: path forwarded to ``Hyperparameters`` as ``data_file``
    :param label_file: path forwarded to ``Hyperparameters`` as ``label_file``
    :param seed_value: seed applied to both random number generators
    """
    # Seed Python first, then TensorFlow, before any data or model is created.
    for apply_seed in (seed, set_seed):
        apply_seed(seed_value)

    config = Hyperparameters(data_file=data_file, label_file=label_file)
    train_split, eval_split, test_split = dataset_creation(config)
    network = build_model(config)

    print("Training the model")
    network.fit(
        train_split,
        epochs=config.epochs,
        validation_data=eval_split,
    )

    print("Training complete. Obtaining the model's metrics...")
    show_metrics(network, eval_split, test_split)
if __name__ == "__main__":
    # Script entry point: run training and evaluation on the default FASTQ
    # files with the default seed.
    run(
        data_file="data/curesim-HVR.fastq",
        label_file="data/HVR.fastq",
    )