From 75ca952f5b8436fa8262b629918374de8dca56f3 Mon Sep 17 00:00:00 2001
From: coolneng <akasroua@gmail.com>
Date: Wed, 23 Jun 2021 18:27:19 +0200
Subject: [PATCH] Remove dense Tensor transformation

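With the dense conversion removed, process_input now returns the
SparseTensor values that parse_single_example produces for the
VarLenFeature fields. Consumers that still need dense inputs can
convert at pipeline time instead; a minimal, illustrative sketch
(load_dataset is a hypothetical helper, not part of this patch):

    # Illustrative sketch only -- load_dataset is a hypothetical helper,
    # not introduced by this patch.
    from preprocessing import process_input  # assumes src/ is on PYTHONPATH
    from tensorflow import sparse
    from tensorflow.data import TFRecordDataset

    def load_dataset(path):
        # parse each serialized Example into (sequence, label) SparseTensors
        dataset = TFRecordDataset(path).map(process_input)
        # densify per element, only where a dense model input is required
        return dataset.map(
            lambda seq, lab: (sparse.to_dense(seq), sparse.to_dense(lab))
        )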
---
 src/preprocessing.py | 32 +++++++-------------------------
 1 file changed, 7 insertions(+), 25 deletions(-)

diff --git a/src/preprocessing.py b/src/preprocessing.py
index df705a6..04cabbe 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -1,11 +1,10 @@
-from typing import Dict, List, Tuple
+from typing import List, Tuple
 
 from Bio.SeqIO import parse
 from numpy.random import random
 from tensorflow import Tensor, int64
 from tensorflow.data import TFRecordDataset
 from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
-from tensorflow.sparse import to_dense
 from tensorflow.train import Example, Feature, Features, Int64List
 
 from constants import *
@@ -38,44 +37,28 @@ def read_fastq(data_file, label_file) -> List[bytes]:
     examples = []
     with open(data_file) as data, open(label_file) as labels:
         for element, label in zip(parse(data, "fastq"), parse(labels, "fastq")):
-            example = generate_example(
-                sequence=str(element.seq),
-                label=str(label.seq),
-            )
+            example = generate_example(sequence=str(element.seq), label=str(label.seq))
             examples.append(example)
     return examples
 
 
-def create_dataset(
-    data_file, label_file, train_eval_test_split=[0.8, 0.1, 0.1]
-) -> None:
+def create_dataset(data_file, label_file, dataset_split=[0.8, 0.1, 0.1]) -> None:
     """
-    Create a training, evaluation and test dataset with a 80/10/30 split respectively
+    Create training, evaluation and test datasets with an 80/10/10 split respectively
     """
     data = read_fastq(data_file, label_file)
     with TFRecordWriter(TRAIN_DATASET) as training, TFRecordWriter(
         TEST_DATASET
     ) as test, TFRecordWriter(EVAL_DATASET) as evaluation:
         for element in data:
-            if random() < train_eval_test_split[0]:
+            if random() < dataset_split[0]:
                 training.write(element)
-            elif random() < train_eval_test_split[0] + train_eval_test_split[1]:
+            elif random() < dataset_split[0] + dataset_split[1]:
                 evaluation.write(element)
             else:
                 test.write(element)
 
 
-def transform_features(parsed_features) -> Dict[str, Tensor]:
-    """
-    Transform the parsed features of an Example into a list of dense Tensors
-    """
-    features = {}
-    sparse_features = ["sequence", "label"]
-    for element in sparse_features:
-        features[element] = to_dense(parsed_features[element])
-    return features
-
-
 def process_input(byte_string) -> Tuple[Tensor, Tensor]:
     """
     Parse a byte-string into an Example object
@@ -84,8 +66,7 @@ def process_input(byte_string) -> Tuple[Tensor, Tensor]:
         "sequence": VarLenFeature(dtype=int64),
         "label": VarLenFeature(dtype=int64),
     }
-    parsed_features = parse_single_example(byte_string, features=schema)
-    features = transform_features(parsed_features)
+    features = parse_single_example(byte_string, features=schema)
     return features["sequence"], features["label"]