from typing import Dict, List, Tuple

from Bio.pairwise2 import align
from Bio.SeqIO import parse
from numpy.random import random
from tensorflow import Tensor, int64, one_hot
from tensorflow.data import AUTOTUNE, TFRecordDataset
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
from tensorflow.sparse import to_dense
from tensorflow.train import Example, Feature, Features, Int64List

# Alphabet used for integer encoding; "-" is the alignment gap character.
BASES = "ACGT-"


def align_sequences(sequence, label) -> Tuple[str, str]:
    """
    Align the altered sequence with the reference sequence so both outputs
    have the same length
    """
    alignments = align.globalxx(label, sequence)
    best_alignment = alignments[0]
    # pairwise2 alignments are (seqA, seqB, score, begin, end) tuples; the
    # label was passed as seqA above, so unpack it first to avoid swapping
    # the sequence and its reference.
    aligned_label, aligned_seq, _, _, _ = best_alignment
    return aligned_seq, aligned_label
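
# Illustrative example (not part of the pipeline): globalxx scores matches
# only and applies no gap penalty, and the two returned strings always have
# equal length. Aligning a read with a dropped base against its reference
# yields something like:
#
#   align_sequences("ACT", "ACGT")  # -> ("AC-T", "ACGT")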


def generate_example(sequence, label) -> bytes:
    """
    Serialize one aligned (sequence, label) pair into a binary Example string
    """
    aligned_seq, aligned_label = align_sequences(sequence, label)
    schema = {
        "sequence": Feature(int64_list=Int64List(value=encode_sequence(aligned_seq))),
        "label": Feature(int64_list=Int64List(value=encode_sequence(aligned_label))),
    }
    example = Example(features=Features(feature=schema))
    return example.SerializeToString()
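
# Illustrative example (not part of the pipeline): the returned byte-string
# is a self-contained record that TFRecordWriter can write to disk and that
# process_input below parses back into tensors.
#
#   record = generate_example(sequence="ACT", label="ACGT")
#   # record holds two int64 lists, one per aligned sequence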


def encode_sequence(sequence) -> List[int]:
    """
    Encode the DNA sequence using the indices of the BASES constant
    """
    encoded_sequence = [BASES.index(element) for element in sequence]
    return encoded_sequence
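
# Illustrative example (not part of the pipeline): each base maps to its
# index in BASES, so the gap character "-" encodes to 4.
#
#   encode_sequence("ACGT-")  # -> [0, 1, 2, 3, 4]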


def read_fastq(hyperparams) -> List[bytes]:
    """
    Parse the data and label FASTQ files into a list of serialized Examples
    """
    examples = []
    with open(hyperparams.data_file) as data, open(hyperparams.label_file) as labels:
        for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
            example = generate_example(sequence=str(element.seq), label=str(label.seq))
            examples.append(example)
    return examples


def create_dataset(hyperparams, dataset_split=(0.8, 0.1, 0.1)) -> None:
    """
    Write training, evaluation and test TFRecord files with an 80/10/10
    split by default
    """
    data = read_fastq(hyperparams)
    with TFRecordWriter(hyperparams.train_dataset) as training, TFRecordWriter(
        hyperparams.test_dataset
    ) as test, TFRecordWriter(hyperparams.eval_dataset) as evaluation:
        for element in data:
            # Draw a single random number per record; drawing a fresh one in
            # the elif branch would skew the split away from the requested
            # fractions.
            draw = random()
            if draw < dataset_split[0]:
                training.write(element)
            elif draw < dataset_split[0] + dataset_split[1]:
                evaluation.write(element)
            else:
                test.write(element)
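
# Worked example of the split thresholds with the default (0.8, 0.1, 0.1):
# a draw of 0.85 falls in [0.8, 0.9) and the record goes to evaluation,
# while a draw of 0.95 falls in [0.9, 1.0) and goes to test.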


def transform_features(parsed_features) -> Dict[str, Tensor]:
    """
    Transform the parsed features of an Example into dense one-hot Tensors
    """
    features = {}
    sparse_features = ["sequence", "label"]
    for element in sparse_features:
        features[element] = to_dense(parsed_features[element])
        features[element] = one_hot(features[element], depth=len(BASES))
    return features
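
# Illustrative example (not part of the pipeline): a sequence of length L
# arrives as a sparse int64 tensor, is densified, then one-hot encoded with
# depth len(BASES) = 5, giving a Tensor of shape (L, 5). For instance the
# encoding [0, 4] ("A-") becomes [[1, 0, 0, 0, 0], [0, 0, 0, 0, 1]].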


def process_input(byte_string) -> Tuple[Tensor, Tensor]:
    """
    Parse a serialized Example into a (sequence, label) pair of one-hot
    Tensors
    """
    schema = {
        "sequence": VarLenFeature(dtype=int64),
        "label": VarLenFeature(dtype=int64),
    }
    parsed_features = parse_single_example(byte_string, features=schema)
    features = transform_features(parsed_features)
    return features["sequence"], features["label"]


def read_dataset(filepath, hyperparams) -> TFRecordDataset:
    """
    Read a TFRecord file into a shuffled, batched and repeated dataset
    """
    data_input = TFRecordDataset(filenames=filepath)
    dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
    shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
    # batch() requires every element in a batch to share the same shape, so
    # the aligned sequences are assumed to have a common length.
    batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
        count=hyperparams.epochs
    )
    return batched_dataset


def dataset_creation(
    hyperparams,
) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
    """
    Generate the TFRecord files and read them back as the training,
    evaluation and test datasets
    """
    create_dataset(hyperparams)
    train_data = read_dataset(hyperparams.train_dataset, hyperparams)
    eval_data = read_dataset(hyperparams.eval_dataset, hyperparams)
    test_data = read_dataset(hyperparams.test_dataset, hyperparams)
    return train_data, eval_data, test_data
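

# Minimal usage sketch, assuming hypothetical paths and values: the
# hyperparams object only needs to expose the attributes referenced above,
# so a SimpleNamespace stands in for the project's real configuration class.
if __name__ == "__main__":
    from types import SimpleNamespace

    hyperparams = SimpleNamespace(
        data_file="reads.fastq",  # hypothetical input paths
        label_file="references.fastq",
        train_dataset="train.tfrecords",  # hypothetical output paths
        eval_dataset="eval.tfrecords",
        test_dataset="test.tfrecords",
        batch_size=32,  # hypothetical training values
        epochs=10,
    )
    train_data, eval_data, test_data = dataset_creation(hyperparams)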