Skip to content

Commit

Permalink
Merge pull request #49124 from tensorflow/mm-cherrypick-tf-data-segfault-fix-to-r2.5
Browse files Browse the repository at this point in the history

[tf.data][cherrypick] Fix snapshot segfault when using repeat and prefetch
  • Loading branch information
mihaimaruseac committed May 12, 2021
2 parents 2107b1d + 16b8139 commit a4dfb8d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
16 changes: 6 additions & 10 deletions tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
Expand Up @@ -201,8 +201,6 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader

explicit Reader(const Params& params, int64 start_index);

~Reader() override;

Status Initialize(IteratorContext* ctx) override;

Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
Expand All @@ -222,7 +220,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader

std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

DatasetBase* input_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;

std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
TF_GUARDED_BY(mu_);
Expand Down Expand Up @@ -468,7 +466,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
bool* end_of_sequence) {
mutex_lock l(mu_);
if (iterator_ == nullptr) {
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
Status s = InitializeIterator(ctx, nullptr);
if (!s.ok()) {
iterator_.reset();
return s;
}
}
index_++;
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
Expand Down Expand Up @@ -547,8 +549,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
int64 start_index)
: DatasetIterator<Dataset>(params), start_index_(start_index) {}

// Drops the reference that Initialize() takes on the dataset returned by the
// reader function.
//
// `input_` is only assigned once Initialize() reaches
// GetDatasetFromVariantTensor(); when Initialize() returned early or was never
// called there is nothing to release, so a null `input_` is skipped rather
// than dereferenced.
// NOTE(review): the guard relies on `input_` being initialized to nullptr at
// its declaration -- confirm the in-class initializer is present.
SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() {
  if (input_ != nullptr) {
    input_->Unref();
  }
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
IteratorContext* ctx) {
mutex_lock l(mu_);
Expand Down Expand Up @@ -597,10 +597,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
}
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));

// We need to take a reference here as we will use the input_ and
// its iterator.
input_->Ref();

return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
Expand Up @@ -413,6 +413,19 @@ def testReadOptimizableUsingFlatMap(self):
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())

@combinations.generate(test_base.default_test_combinations())
def testRepeatAndPrefetch(self):
  """Regression test for github.com/tensorflow/tensorflow/issues/48903.

  Builds a snapshot -> shuffle -> batch -> repeat -> prefetch pipeline over
  16 random rows and pulls 30 batches from it.
  """
  source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
  pipeline = (
      source.apply(snapshot.snapshot(self._snapshot_dir))
      .shuffle(buffer_size=16)
      .batch(16)
      .repeat()
      .prefetch(1))
  fetch_next = self.getNext(pipeline)
  # One batch holds the whole 16-row input, so 30 fetches force the repeated
  # pipeline to be iterated many times.
  for _ in range(30):
    self.evaluate(fetch_next())


class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase,
parameterized.TestCase):
Expand Down

0 comments on commit a4dfb8d

Please sign in to comment.