Move TFRecord examples to the exported class.
PiperOrigin-RevId: 365149622
Change-Id: Ib7dfba397b5c2a9757cfd14d8a2537f38dadf35d
MarkDaoust authored and tensorflower-gardener committed Mar 25, 2021
1 parent 054d641 commit fc75295
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions tensorflow/python/data/ops/readers.py
@@ -245,48 +245,7 @@ def _filenames(self, value):


class _TFRecordDataset(dataset_ops.DatasetSource):
"""A `Dataset` comprising records from one or more TFRecord files.
This dataset loads TFRecords from file as bytes, exactly as they were written.
`TFRecordDataset` does not do any parsing or decoding on its own. Parsing and
decoding can be done by applying `Dataset.map` transformations after the
`TFRecordDataset`.
A minimal example is given below:
>>> import tempfile
>>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
>>> np.random.seed(0)
>>> # Write the records to a file.
... with tf.io.TFRecordWriter(example_path) as file_writer:
... for _ in range(4):
... x, y = np.random.random(), np.random.random()
...
... record_bytes = tf.train.Example(features=tf.train.Features(feature={
... "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
... "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
... })).SerializeToString()
... file_writer.write(record_bytes)
>>> # Read the data back out.
>>> def decode_fn(record_bytes):
... return tf.io.parse_single_example(
... # Data
... record_bytes,
...
... # Schema
... {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
... "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
... )
>>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
... print("x = {x:.4f}, y = {y:.4f}".format(**batch))
x = 0.5488, y = 0.7152
x = 0.6028, y = 0.5449
x = 0.4237, y = 0.6459
x = 0.4376, y = 0.8918
"""
"""A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self, filenames, compression_type=None, buffer_size=None):
    """Creates a `TFRecordDataset`.
@@ -374,7 +333,48 @@ def _transformation_name(self):

@tf_export("data.TFRecordDataset", v1=[])
class TFRecordDatasetV2(dataset_ops.DatasetV2):
"""A `Dataset` comprising records from one or more TFRecord files."""
"""A `Dataset` comprising records from one or more TFRecord files.
This dataset loads TFRecords from the files as bytes, exactly as they were
written.`TFRecordDataset` does not do any parsing or decoding on its own.
Parsing and decoding can be done by applying `Dataset.map` transformations
after the `TFRecordDataset`.
A minimal example is given below:
>>> import tempfile
>>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
>>> np.random.seed(0)
>>> # Write the records to a file.
... with tf.io.TFRecordWriter(example_path) as file_writer:
... for _ in range(4):
... x, y = np.random.random(), np.random.random()
...
... record_bytes = tf.train.Example(features=tf.train.Features(feature={
... "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
... "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
... })).SerializeToString()
... file_writer.write(record_bytes)
>>> # Read the data back out.
>>> def decode_fn(record_bytes):
... return tf.io.parse_single_example(
... # Data
... record_bytes,
...
... # Schema
... {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
... "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
... )
>>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
... print("x = {x:.4f}, y = {y:.4f}".format(**batch))
x = 0.5488, y = 0.7152
x = 0.6028, y = 0.5449
x = 0.4237, y = 0.6459
x = 0.4376, y = 0.8918
"""

  def __init__(self,
               filenames,
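The `_TFRecordDataset` constructor earlier in the diff shows `compression_type` and `buffer_size` parameters, and the exported `tf.data.TFRecordDataset` accepts the same ones. As a minimal sketch of how they come into play, not part of this commit, the snippet below writes and then reads a GZIP-compressed TFRecord file; the file path and feature name are illustrative, and eager execution (TF 2.x) is assumed.

import os
import tempfile

import numpy as np
import tensorflow as tf

# Illustrative path for this sketch; any writable location works.
gzip_path = os.path.join(tempfile.gettempdir(), "example_gzip.tfrecords")

# Write a few records with GZIP compression. `tf.io.TFRecordWriter` accepts a
# compression-type string (or a `tf.io.TFRecordOptions`) as its second argument.
with tf.io.TFRecordWriter(gzip_path, "GZIP") as file_writer:
  for value in np.random.random(4):
    record_bytes = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=[value])),
    })).SerializeToString()
    file_writer.write(record_bytes)

# Read the records back. `compression_type` must match what was written;
# `buffer_size` (in bytes) controls how much data is buffered while reading.
dataset = tf.data.TFRecordDataset(
    [gzip_path], compression_type="GZIP", buffer_size=10 * 1024 * 1024)
for raw_record in dataset:
  example = tf.train.Example.FromString(raw_record.numpy())
  print(example.features.feature["x"].float_list.value[:])

If the signature truncated above also carries `num_parallel_reads`, as the released `tf.data.TFRecordDataset` API does, reads can additionally be interleaved across multiple files by passing that argument.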
