Parallelize dataset transformations

2021-06-24 19:30:46 +02:00
parent b2f20f2070
commit e9582d0883


@@ -5,6 +5,7 @@ from Bio.SeqIO import parse
 from numpy.random import random
 from tensorflow import Tensor, int64
 from tensorflow.data import TFRecordDataset
+from tensorflow.data import AUTOTUNE, TFRecordDataset
 from tensorflow.io import TFRecordWriter, VarLenFeature, parse_single_example
 from tensorflow.sparse import to_dense
 from tensorflow.train import Example, Feature, Features, Int64List
@@ -104,7 +105,7 @@ def read_dataset(filepath) -> TFRecordDataset:
     Read TFRecords files and generate a dataset
     """
     data_input = TFRecordDataset(filenames=filepath)
-    dataset = data_input.map(map_func=process_input)
+    dataset = data_input.map(map_func=process_input, num_parallel_calls=AUTOTUNE)
     shuffled_dataset = dataset.shuffle(buffer_size=10000, seed=42)
     batched_dataset = shuffled_dataset.batch(batch_size=BATCH_SIZE).repeat(count=EPOCHS)
     return batched_dataset
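
For reference, a minimal, self-contained sketch of the same idiom on a toy dataset: passing num_parallel_calls=AUTOTUNE to map lets tf.data choose the degree of parallelism for the transformation at runtime instead of processing elements sequentially. The process_input below is a hypothetical stand-in, not this project's TFRecord parser, and the batch size is arbitrary.

    from tensorflow.data import AUTOTUNE, Dataset

    def process_input(x):
        # Hypothetical per-element transformation; the project's real
        # process_input parses a serialized TFRecord example instead.
        return x * 2

    # AUTOTUNE lets tf.data pick how many map calls to run in parallel
    # based on available CPU, rather than a hard-coded worker count.
    dataset = (
        Dataset.range(10_000)
        .map(map_func=process_input, num_parallel_calls=AUTOTUNE)
        .shuffle(buffer_size=10_000, seed=42)
        .batch(batch_size=64)
    )

    for batch in dataset.take(1):
        print(batch.shape)  # (64,)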