diff --git a/src/preprocessing.py b/src/preprocessing.py index a2673f8..ca72b67 100644 --- a/src/preprocessing.py +++ b/src/preprocessing.py @@ -81,14 +81,14 @@ def transform_features(parsed_features) -> List[Tensor]: """ Cast and transform the parsed features of an Example into a list of Tensors """ + sparse_features = ["sequence", "label"] + for feature in sparse_features: + parsed_features[feature] = cast(parsed_features[feature], int32) + parsed_features[feature] = to_dense(parsed_features[feature]) for base in BASES: parsed_features[f"{base}_counts"] = cast( parsed_features[f"{base}_counts"], int32 ) - parsed_features["sequence"] = cast(parsed_features["sequence"], int32) - parsed_features["label"] = cast(parsed_features["label"], int32) - parsed_features["sequence"] = to_dense(parsed_features["sequence"]) - parsed_features["label"] = to_dense(parsed_features["label"]) features = list(parsed_features.values())[:-1] return features