locimend/src/model.py

85 lines
2.9 KiB
Python

from random import seed
from numpy import argmax
from tensorflow import one_hot
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Input, Masking
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed
from hyperparameters import Hyperparameters
from preprocessing import BASES, dataset_creation, decode_sequence, encode_sequence
def build_model(hyperparams) -> Model:
    """
    Build and compile the fully connected Keras model

    The network is a Sequential stack made of an input, a Masking layer
    (mask value -1), three ReLU Dense layers of 256, 128 and 64 units, each
    L2-regularised with hyperparams.l2_rate and followed by Dropout(0.3), and
    a softmax Dense layer with one unit per entry in BASES. No convolutional
    layers are involved.

    The model is compiled with Adam (hyperparams.learning_rate), categorical
    cross-entropy loss, and the "accuracy" and "AUC" metrics.
    """
    hidden_widths = (256, 128, 64)

    # NOTE(review): the input shape includes hyperparams.batch_size, while
    # Keras documents Input's `shape` as excluding the batch axis -- confirm
    # this matches what dataset_creation actually yields.
    stack = [
        Input(shape=(hyperparams.batch_size, hyperparams.max_length, len(BASES))),
        Masking(mask_value=-1),
    ]

    # Every hidden block is a regularised Dense layer followed by dropout.
    for width in hidden_widths:
        stack.append(
            Dense(
                units=width,
                activation="relu",
                kernel_regularizer=l2(hyperparams.l2_rate),
            )
        )
        stack.append(Dropout(rate=0.3))

    # Output layer: one probability per base.
    stack.append(Dense(units=len(BASES), activation="softmax"))

    network = Sequential(stack)
    network.compile(
        optimizer=Adam(hyperparams.learning_rate),
        loss=categorical_crossentropy,
        metrics=["accuracy", "AUC"],
    )
    return network
def show_metrics(model, eval_dataset, test_dataset) -> None:
    """
    Evaluate the model on the evaluation and test datasets and print the results

    Both evaluations run silently (verbose=0) and are completed before
    anything is printed; the evaluation dataset is processed first.
    """
    eval_metrics, test_metrics = [
        model.evaluate(dataset, verbose=0)
        for dataset in (eval_dataset, test_dataset)
    ]
    print(f"Eval metrics {eval_metrics}")
    print(f"Test metrics {test_metrics}")
def train_model(data_file, label_file, seed_value=42) -> None:
    """
    Train a freshly built model end to end and save it to disk

    Steps: seed Python's `random` module and TensorFlow with seed_value,
    build the Hyperparameters from the given files, create the train / eval /
    test datasets, fit the model for hyperparams.epochs epochs (validating on
    the eval split), print the eval and test metrics, and finally save the
    model under "trained_model".
    """
    # Seed the stdlib RNG first, then TensorFlow's.
    for seeder in (seed, set_seed):
        seeder(seed_value)

    config = Hyperparameters(data_file=data_file, label_file=label_file)
    training_set, validation_set, holdout_set = dataset_creation(config)
    network = build_model(config)

    print("Training the model")
    network.fit(
        training_set,
        epochs=config.epochs,
        validation_data=validation_set,
    )

    print("Training complete. Obtaining the model's metrics...")
    show_metrics(network, validation_set, holdout_set)

    # infer_sequence loads the model back from this same path.
    network.save("trained_model")
async def infer_sequence(sequence) -> str:
    """
    Run the saved model on a sequence and return the decoded prediction

    The model is loaded from "trained_model" on every call. The sequence is
    encoded with encode_sequence, one-hot encoded with depth len(BASES), fed
    to the model, and the argmax along axis 1 of the prediction is passed to
    decode_sequence.

    Although declared as a coroutine, the body contains no await, so loading
    and predicting run synchronously once the coroutine is awaited.
    """
    trained = load_model("trained_model")
    model_input = one_hot(encode_sequence(sequence), depth=len(BASES))
    # Pick the most probable class index for each row of the prediction.
    class_indices = argmax(trained.predict(model_input), axis=1)
    return decode_sequence(class_indices)