From d34e291085b42a3dbdc6be2c89654d12a040be72 Mon Sep 17 00:00:00 2001
From: coolneng <akasroua@gmail.com>
Date: Tue, 1 Jun 2021 23:06:25 +0200
Subject: [PATCH] Generate a dataset from the TFRecords files

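Define a feature schema to parse the serialized Examples back into
tensors, and build an input pipeline that reads both TFRecords files,
shuffles the records, and batches and repeats them for training. The
number of epochs and the batch size are defined in constants.py.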
---
 src/constants.py     |  2 ++
 src/preprocessing.py | 29 +++++++++++++++++++++++++----
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/src/constants.py b/src/constants.py
index c5d64bc..a9be1d0 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -1,3 +1,5 @@
 BASES = "ACGT"
 TRAIN_DATASET = "data/train_data.tfrecords"
 TEST_DATASET = "data/test_data.tfrecords"
+EPOCHS = 1000
+BATCH_SIZE = 256
diff --git a/src/preprocessing.py b/src/preprocessing.py
index 68aad09..34cb1b6 100644
--- a/src/preprocessing.py
+++ b/src/preprocessing.py
@@ -1,12 +1,14 @@
+from typing import List
+
 from Bio.motifs import create
 from Bio.SeqIO import parse
 from numpy.random import random
-from tensorflow.io import TFRecordWriter
+from tensorflow import float32, string
 from tensorflow.data import TFRecordDataset
+from tensorflow.io import FixedLenFeature, TFRecordWriter, parse_single_example
 from tensorflow.train import BytesList, Example, Feature, Features, FloatList
-from typing import List
 
-from constants import TRAIN_DATASET, TEST_DATASET
+from constants import BATCH_SIZE, EPOCHS, TEST_DATASET, TRAIN_DATASET
 
 
 def generate_example(sequence, weight_matrix) -> bytes:
@@ -52,8 +54,27 @@ def create_dataset(filepath) -> None:
                 test.write(element)
 
 
+def process_input(byte_string):
+    # Feature schema for a single serialized Example
+    schema = {
+        "sequence": FixedLenFeature(shape=[], dtype=string),
+        "A_counts": FixedLenFeature(shape=[], dtype=float32),
+        "C_counts": FixedLenFeature(shape=[], dtype=float32),
+        "G_counts": FixedLenFeature(shape=[], dtype=float32),
+        "T_counts": FixedLenFeature(shape=[], dtype=float32),
+    }
+    return parse_single_example(byte_string, features=schema)
+
+
 def read_dataset():
-    pass
+    # Read both TFRecords files and decode each serialized Example
+    data_input = TFRecordDataset(filenames=[TRAIN_DATASET, TEST_DATASET])
+    dataset = data_input.map(map_func=process_input)
+    # Shuffle, batch and repeat the examples for training
+    shuffled_dataset = dataset.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
+    batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
+    return batched_dataset
 
 
 create_dataset("data/curesim-HVR.fastq")
+dataset = read_dataset()
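
For reference, a minimal sketch of how the resulting dataset could be
consumed, assuming the patch is applied and the data files referenced in
constants.py and at the bottom of preprocessing.py exist (importing the
module runs the create_dataset() and read_dataset() calls above):

    from preprocessing import dataset

    # Each element is a dict of batched tensors keyed by the schema in
    # process_input: "sequence" plus the four per-base count features.
    for batch in dataset.take(1):
        print(batch["sequence"].shape)  # (BATCH_SIZE,)
        print(batch["A_counts"].dtype)  # float32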