From 858a5698a7b4ee95befa2e9c3d7aaa0a8170ec54 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Thu, 6 May 2021 11:46:26 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 372390520 Change-Id: I1f0caa5bbda11862310a7c85e77f5df9e8fc3709 --- .../data/experimental/snapshot_dataset_op.cc | 15 ++++++--------- .../experimental/kernel_tests/snapshot_test.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index 1918fae7e0d5ea..0db8bbd0f93dd6 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -251,8 +251,6 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { Reader(const Params& params, int64 start_index) : DatasetIterator<Dataset>(params), start_index_(start_index) {} - ~Reader() override { input_->Unref(); } - Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); @@ -301,11 +299,6 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { } TF_RETURN_IF_ERROR( GetDatasetFromVariantTensor(reader_output[0], &input_)); - - // We need to take a reference here as we will use the input_ and - // its iterator. 
- input_->Ref(); - return input_->MakeIterator(ctx, this, prefix(), &input_impl_); } @@ -337,7 +330,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_); - DatasetBase* input_ TF_GUARDED_BY(mu_); + DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr; std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_ TF_GUARDED_BY(mu_); @@ -614,7 +607,11 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { bool* end_of_sequence) override { mutex_lock l(mu_); if (iterator_ == nullptr) { - TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr)); + Status s = InitializeIterator(ctx, /*reader=*/nullptr); + if (!s.ok()) { + iterator_.reset(); + return s; + } } index_++; return iterator_->GetNext(ctx, out_tensors, end_of_sequence); diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index fe6db1eb860444..f720966df99b4f 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -413,6 +413,19 @@ def testReadOptimizableUsingFlatMap(self): num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) + @combinations.generate(test_base.default_test_combinations()) + def testRepeatAndPrefetch(self): + """This test reproduces github.com/tensorflow/tensorflow/issues/48903.""" + dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32)) + dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) + dataset = dataset.shuffle(buffer_size=16) + dataset = dataset.batch(16) + dataset = dataset.repeat() + dataset = dataset.prefetch(1) + next_element = self.getNext(dataset) + for _ in range(30): + self.evaluate(next_element()) + class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase, parameterized.TestCase):