from random import seed

from numpy import argmax
from tensorflow import one_hot
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Input, Masking
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed

from hyperparameters import Hyperparameters
from preprocessing import BASES, dataset_creation, decode_sequence, encode_sequence


def build_model(hyperparams) -> Model:
    """
    Build and compile the dense sequence-correction model.

    Parameters
    ----------
    hyperparams : Hyperparameters
        Provides ``max_length``, ``l2_rate`` and ``learning_rate``
        used to size and regularize the network.

    Returns
    -------
    Model
        A compiled Keras ``Sequential`` model emitting a softmax over
        the possible bases at each position.
    """
    model = Sequential(
        [
            # FIX: Keras `Input(shape=...)` describes a single sample; the
            # batch axis is added implicitly. The original included
            # hyperparams.batch_size in the shape, which made the model
            # expect a spurious extra dimension on every batch.
            Input(shape=(hyperparams.max_length, len(BASES))),
            # Skip padded positions; assumes the preprocessing pipeline pads
            # with -1 — TODO confirm against dataset_creation.
            Masking(mask_value=-1),
            Dense(
                units=256, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
            ),
            Dropout(rate=0.3),
            Dense(
                units=128, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
            ),
            Dropout(rate=0.3),
            Dense(
                units=64, activation="relu", kernel_regularizer=l2(hyperparams.l2_rate)
            ),
            Dropout(rate=0.3),
            # One probability per base, per position.
            Dense(units=len(BASES), activation="softmax"),
        ]
    )
    model.compile(
        optimizer=Adam(hyperparams.learning_rate),
        loss=categorical_crossentropy,
        metrics=["accuracy", "AUC"],
    )
    return model


def show_metrics(model, eval_dataset, test_dataset) -> None:
    """
    Evaluate the model on the evaluation and test datasets and print
    the resulting metric values, one line per dataset.
    """
    # Evaluate silently (verbose=0) and report each split in turn.
    for label, dataset in (("Eval", eval_dataset), ("Test", test_dataset)):
        metrics = model.evaluate(dataset, verbose=0)
        print(f"{label} metrics {metrics}")


def train_model(data_file, label_file, seed_value=42) -> None:
    """
    Build the datasets and the model, fit the model, report its metrics,
    and persist the trained model to disk.

    Parameters
    ----------
    data_file : str
        Path to the input sequences.
    label_file : str
        Path to the corresponding labels.
    seed_value : int, optional
        Seed applied to both Python's and TensorFlow's RNGs, by default 42.
    """
    # Seed both RNG sources up front so runs are reproducible.
    seed(seed_value)
    set_seed(seed_value)

    hp = Hyperparameters(data_file=data_file, label_file=label_file)
    training_set, validation_set, holdout_set = dataset_creation(hp)
    network = build_model(hp)

    print("Training the model")
    network.fit(training_set, epochs=hp.epochs, validation_data=validation_set)
    print("Training complete. Obtaining the model's metrics...")

    show_metrics(network, validation_set, holdout_set)
    # Written in TensorFlow's SavedModel directory format.
    network.save("trained_model")


async def infer_sequence(sequence) -> str:
    """
    Predict the corrected sequence for *sequence* using the trained model.

    Loads the saved model from ``trained_model``, runs a forward pass on
    the one-hot-encoded input, and decodes the per-position argmax back
    into a base string.
    """
    # NOTE(review): declared async for the caller's benefit, but nothing
    # below is awaited — every step runs synchronously.
    trained = load_model("trained_model")
    encoded = encode_sequence(sequence)
    one_hot_input = one_hot(encoded, depth=len(BASES))
    probabilities = trained.predict(one_hot_input)
    # Most likely base at each position, mapped back to letters.
    predicted_indices = argmax(probabilities, axis=1)
    return decode_sequence(predicted_indices)