from typing import Dict, List, Tuple

from Bio.pairwise2 import align
from Bio.SeqIO import parse
from numpy.random import random
from tensorflow import Tensor, int64, one_hot
from tensorflow.data import AUTOTUNE, TFRecordDataset
from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
from tensorflow.sparse import to_dense
from tensorflow.train import Example, Feature, Features, Int64List

# Alphabet used for integer encoding; "-" is the gap symbol introduced by alignment
BASES = "ACGT-"


def align_sequences(sequence, label) -> Tuple[str, str]:
    """
    Align the altered sequence with the reference sequence
    to obtain outputs of the same length
    """
    alignments = align.globalxx(label, sequence)
    best_alignment = alignments[0]
    # globalxx(label, sequence) returns (seqA, seqB, score, begin, end),
    # so seqA is the aligned label and seqB is the aligned sequence
    aligned_label, aligned_seq, _, _, _ = best_alignment
    return aligned_seq, aligned_label


def generate_example(sequence, label) -> bytes:
    """
    Serialize an aligned sequence/label pair into a binary-string Example
    """
    aligned_seq, aligned_label = align_sequences(sequence, label)
    schema = {
        "sequence": Feature(int64_list=Int64List(value=encode_sequence(aligned_seq))),
        "label": Feature(int64_list=Int64List(value=encode_sequence(aligned_label))),
    }
    example = Example(features=Features(feature=schema))
    return example.SerializeToString()


def encode_sequence(sequence) -> List[int]:
    """
    Encode the DNA sequence using the indices of the BASES constant
    """
    encoded_sequence = [BASES.index(element) for element in sequence]
    return encoded_sequence


def read_fastq(hyperparams) -> List[bytes]:
    """
    Parse a data FASTQ file and a label FASTQ file
    and generate a list of serialized Examples
    """
    examples = []
    with open(hyperparams.data_file) as data, open(hyperparams.label_file) as labels:
        for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
            example = generate_example(sequence=str(element.seq), label=str(label.seq))
            examples.append(example)
    return examples


def create_dataset(hyperparams, dataset_split=(0.8, 0.1, 0.1)) -> None:
    """
    Write the examples into training, evaluation and test TFRecord files
    with an 80/10/10 split respectively
    """
    data = read_fastq(hyperparams)
    with TFRecordWriter(hyperparams.train_dataset) as training, TFRecordWriter(
        hyperparams.test_dataset
    ) as test, TFRecordWriter(hyperparams.eval_dataset) as evaluation:
        for element in data:
            # Draw a single random number per example so that the two
            # thresholds partition [0, 1) into the requested proportions
            draw = random()
            if draw < dataset_split[0]:
                training.write(element)
            elif draw < dataset_split[0] + dataset_split[1]:
                evaluation.write(element)
            else:
                test.write(element)


def transform_features(parsed_features) -> Dict[str, Tensor]:
    """
    Transform the parsed features of an Example into dense one-hot encoded Tensors
    """
    features = {}
    sparse_features = ["sequence", "label"]
    for element in sparse_features:
        features[element] = to_dense(parsed_features[element])
        features[element] = one_hot(features[element], depth=len(BASES))
    return features


def process_input(byte_string) -> Tuple[Tensor, Tensor]:
    """
    Parse a serialized Example byte-string into a (sequence, label) Tensor pair
    """
    schema = {
        "sequence": VarLenFeature(dtype=int64),
        "label": VarLenFeature(dtype=int64),
    }
    parsed_features = parse_single_example(byte_string, features=schema)
    features = transform_features(parsed_features)
    return features["sequence"], features["label"]


def read_dataset(filepath, hyperparams) -> TFRecordDataset:
    """
    Read TFRecord files and generate a shuffled, batched dataset
    """
    data_input = TFRecordDataset(filenames=filepath)
    dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
    shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
    batched_dataset = shuffled_dataset.batch(batch_size=hyperparams.batch_size).repeat(
        count=hyperparams.epochs
    )
    return batched_dataset
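
# Illustration (not used by the pipeline): with BASES = "ACGT-", each base is
# mapped to its index and then expanded to a one-hot row of width len(BASES):
#
#   encode_sequence("ACG-")          -> [0, 1, 2, 4]
#   one_hot([0, 1, 2, 4], depth=5)   -> shape (4, 5) tensor, one 1 per row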
def dataset_creation(
    hyperparams,
) -> Tuple[TFRecordDataset, TFRecordDataset, TFRecordDataset]:
    """
    Generate the TFRecord files and split them into training, validation and test data
    """
    create_dataset(hyperparams)
    train_data = read_dataset(hyperparams.train_dataset, hyperparams)
    eval_data = read_dataset(hyperparams.eval_dataset, hyperparams)
    test_data = read_dataset(hyperparams.test_dataset, hyperparams)
    return train_data, eval_data, test_data
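
if __name__ == "__main__":
    # Usage sketch: `dataset_creation` only relies on the attributes accessed
    # above, so a SimpleNamespace is enough to drive it. All file paths and
    # hyperparameter values below are illustrative placeholders, not part of
    # the pipeline itself.
    from types import SimpleNamespace

    hyperparams = SimpleNamespace(
        data_file="reads.fastq",         # placeholder: sequenced reads
        label_file="references.fastq",   # placeholder: reference sequences
        train_dataset="train.tfrecord",  # placeholder output paths
        eval_dataset="eval.tfrecord",
        test_dataset="test.tfrecord",
        batch_size=32,                   # illustrative values
        epochs=10,
    )
    train_data, eval_data, test_data = dataset_creation(hyperparams)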