82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
from random import seed
|
|
|
|
from tensorflow.keras import Model, Sequential
|
|
from tensorflow.keras.layers import *
|
|
from tensorflow.keras.losses import categorical_crossentropy
|
|
from tensorflow.keras.optimizers import Adam
|
|
from tensorflow.keras.regularizers import l2
|
|
from tensorflow.random import set_seed
|
|
|
|
from hyperparameters import Hyperparameters
|
|
from preprocessing import BASES, dataset_creation
|
|
|
|
|
|
def build_model(hyperparams) -> Model:
    """
    Build and compile the CNN model

    The layers are stacked in this order:

    * an input of shape ``(None, len(BASES))``;
    * ``Conv1D`` (16 filters, kernel size 5, ReLU, L2-regularised kernel)
      followed by ``MaxPool1D`` (pool size 3, stride 1);
    * ``Conv1D`` (16 filters, kernel size 3, ReLU, L2-regularised kernel)
      followed by ``MaxPool1D`` (pool size 3, stride 1);
    * ``GlobalAveragePooling1D``;
    * two ``Dense`` layers (16 units, ReLU, L2-regularised kernel), each
      followed by ``Dropout`` with rate 0.3;
    * a final ``Dense`` layer with ``len(BASES)`` units and softmax activation.

    Args:
        hyperparams: object exposing ``l2_rate`` (factor given to every ``l2``
            kernel regulariser) and ``learning_rate`` (given to ``Adam``).

    Returns:
        The ``Sequential`` model, compiled with the Adam optimizer, the
        categorical cross-entropy loss and the "accuracy" metric.
    """
    model = Sequential(
        [
            # The first axis is left as None, so no sequence length is fixed.
            Input(shape=(None, len(BASES))),
            Conv1D(
                filters=16,
                kernel_size=5,
                activation="relu",
                kernel_regularizer=l2(hyperparams.l2_rate),
            ),
            MaxPool1D(pool_size=3, strides=1),
            Conv1D(
                filters=16,
                kernel_size=3,
                activation="relu",
                kernel_regularizer=l2(hyperparams.l2_rate),
            ),
            MaxPool1D(pool_size=3, strides=1),
            # Averages over the sequence axis before the dense layers.
            GlobalAveragePooling1D(),
            Dense(
                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
            ),
            Dropout(rate=0.3),
            Dense(
                units=16, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
            ),
            Dropout(rate=0.3),
            Dense(units=len(BASES), activation="softmax"),
        ]
    )
    model.compile(
        optimizer=Adam(hyperparams.learning_rate),
        loss=categorical_crossentropy,
        metrics=["accuracy"],
    )
    return model
|
|
|
|
|
|
def show_metrics(model, eval_dataset, test_dataset) -> None:
    """
    Show the model metrics

    Evaluates ``model`` on the evaluation dataset and then on the test
    dataset (both with ``verbose=0``), and prints one summary line for each.

    Args:
        model: object whose ``evaluate(dataset, verbose=0)`` returns an
            indexable result with the loss at position 0 and the accuracy at
            position 1.
        eval_dataset: dataset passed to the first ``evaluate`` call.
        test_dataset: dataset passed to the second ``evaluate`` call.
    """
    # Both evaluations run before anything is printed.
    eval_metrics = model.evaluate(eval_dataset, verbose=0)
    test_metrics = model.evaluate(test_dataset, verbose=0)
    print(f"Final eval metrics - loss: {eval_metrics[0]} - accuracy: {eval_metrics[1]}")
    print(f"Final test metrics - loss: {test_metrics[0]} - accuracy: {test_metrics[1]}")
|
|
|
|
|
|
def run(data_file, label_file, seed_value=42) -> None:
    """
    Create a dataset, a model and run training and evaluation on it

    Args:
        data_file: forwarded to ``Hyperparameters`` as ``data_file``.
        label_file: forwarded to ``Hyperparameters`` as ``label_file``.
        seed_value: seed given to both ``random.seed`` and TensorFlow's
            ``set_seed`` before any data or model is created.
    """
    # Seed both random number generators first, before data or model exist.
    seed(seed_value)
    set_seed(seed_value)

    hyperparams = Hyperparameters(data_file=data_file, label_file=label_file)
    train_data, eval_data, test_data = dataset_creation(hyperparams)
    model = build_model(hyperparams)

    print("Training the model")
    model.fit(train_data, epochs=hyperparams.epochs, validation_data=eval_data)

    print("Training complete. Obtaining the model's metrics...")
    show_metrics(model, eval_data, test_data)
|
|
|
|
|
|
if __name__ == "__main__":
    # When executed as a script, train and evaluate on these data files
    # using the default seed.
    run(data_file="data/curesim-HVR.fastq", label_file="data/HVR.fastq")
|