Commit 858a569

Internal change

PiperOrigin-RevId: 372390520
Change-Id: I1f0caa5bbda11862310a7c85e77f5df9e8fc3709

yangustc07 authored and tensorflower-gardener committed May 6, 2021
1 parent 080bc01 commit 858a569

Showing 2 changed files with 19 additions and 9 deletions.
15 changes: 6 additions & 9 deletions tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
@@ -251,8 +251,6 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
Reader(const Params& params, int64 start_index)
: DatasetIterator<Dataset>(params), start_index_(start_index) {}

-      ~Reader() override { input_->Unref(); }

Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);

@@ -301,11 +299,6 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
}
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(reader_output[0], &input_));

-        // We need to take a reference here as we will use the input_ and
-        // its iterator.
-        input_->Ref();

ashahab (Contributor), May 6, 2021:
Is this safe? This reference incrementing and decrementing is being done in all other dataset ops.
Also, given that this is a private member variable and a shared resource, how would multiple threads know that this is being referenced?

yangustc07 (Author, Member), May 6, 2021:
Yes, this is safe. The Ref and Unref are done in the other datasets, but not in the iterators. When an iterator is constructed, the dataset is Refed here:

params_.dataset->Ref();

In this case, input_->MakeIterator will Ref the input_ dataset.

Hope this helps.
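The ownership convention described here can be sketched as a toy Python model (the class and method names below are illustrative stand-ins, not the real TensorFlow API): the iterator's constructor takes the reference and its destructor releases it, so `MakeIterator` alone keeps the count balanced without an extra `Ref` in the caller.

```python
class Dataset:
    """Toy stand-in for a ref-counted tf.data DatasetBase."""

    def __init__(self):
        self.refs = 1  # the creator holds the initial reference

    def Ref(self):
        self.refs += 1

    def Unref(self):
        self.refs -= 1
        assert self.refs >= 0, "refcount underflow (double Unref)"


class Iterator:
    """Toy stand-in for DatasetIterator: Refs its dataset on construction
    and Unrefs it on destruction, mirroring params_.dataset->Ref()."""

    def __init__(self, dataset):
        dataset.Ref()
        self.dataset = dataset

    def destroy(self):
        self.dataset.Unref()


input_ = Dataset()
it = Iterator(input_)   # models input_->MakeIterator(...): count goes 1 -> 2
it.destroy()            # iterator torn down: count back to 1
```

Because construction and destruction are paired inside the iterator, an extra `Ref` in the Reader would only be balanced if its own destructor's `Unref` always ran against a successfully initialized `input_`, which the cancellation path violates.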

ashahab (Contributor), May 6, 2021:
Do you know why this error only shows up for prefetch + repeat?

ashahab (Contributor), May 6, 2021:
Also thanks for pointing out where params_.dataset->Ref() and Unref() are being done. I was seeing that the iterator refers to its creator dataset but didn't see the construction and destruction.

ashahab (Contributor), May 6, 2021:
Do you have any plans to backport this to 2.4.*?

yangustc07 (Author, Member), May 6, 2021:
I tried to print the error message returned by the TF_RETURN_IF_ERROR. Those are "Cancelled" errors.

thread 140435431581440: SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal
*** SIGSEGV (@(nil)), see go/stacktraces#s15 received by PID 8264 (TID 9451) on cpu 2; stack trace: ***
thread 140435725035264: ShuffleDatasetBase input_impl_->GetNext = Cancelled: 
thread 140435725035264: BatchDatasetOp::input_impl_->GetNext = Cancelled: 
thread 140435725035264: InfiniteRepeatOp::input_impl_->GetNext = Cancelled: 
thread 140435725035264: PrefetchThread input_impl_->GetNext = Cancelled: 
thread 140435725035264: PrefetchThread Wait for a slot in the buffer
PC: @     0x55a826e841a4  (unknown)  tensorflow::data::experimental::SnapshotDatasetV2Op::Dataset::Iterator::Reader::GetNextInternal()

My interpretation is that when PrefetchOp cancels its threads, the snapshot op, which is running in another thread (140435431581440), may still try to GetNext or destruct itself. But the initialization wasn't successful due to the cancellation, so GetNext or the destructor dereferences a null pointer.
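That failed-initialization path can be modeled in a few lines of Python (illustrative only; the names mirror, but are not, the real C++ members): a cancellation error returned partway through Initialize leaves `input_impl_` unset, and any later GetNext dereferences it.

```python
class Reader:
    """Toy model of the snapshot Reader iterator (illustrative)."""

    def __init__(self):
        self.input_ = None       # like the fix: members start as nullptr
        self.input_impl_ = None

    def initialize(self, cancelled):
        if cancelled:
            # models TF_RETURN_IF_ERROR firing before input_impl_ is set
            return "Cancelled"
        self.input_ = object()
        self.input_impl_ = object()
        return "OK"

    def get_next(self):
        # without a guard, this is the segfault: input_impl_ is null
        if self.input_impl_ is None:
            raise RuntimeError("dereferenced null input_impl_")
        return "element"


r = Reader()
status = r.initialize(cancelled=True)  # prefetch cancellation hits mid-init
try:
    r.get_next()
    crashed = False
except RuntimeError:
    crashed = True
```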

yangustc07 (Author, Member), May 7, 2021:
I have forwarded your backport request to the managers. I'll let you know once they decide if it's ok to patch or cherrypick.

yangustc07 (Author, Member), May 7, 2021:
We're going to backport it to 2.4 and cherrypick into 2.5. Hope this helps.

ashahab (Contributor), May 7, 2021:
@yangustc07 Thanks a lot!
BTW, how did you get so much information out of the TF_RETURN_IF_ERROR macro? That seems like a great debugging tool. I redefined it but only got "Cancelled" and not the stack trace.

yangustc07 (Author, Member), May 7, 2021:
I tried adding printing statements to each op. Please let me know if you find a better way :)
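The per-op print statements described here can be approximated with a small status-logging helper, sketched in Python (`TRACE`, `return_if_error`, and the loop below are purely illustrative, not TF code): each call site tags the propagating error with its thread and operator, which is the information the trace above recovered.

```python
import threading

TRACE = []  # collected log lines, one per op the error propagates through


def return_if_error(status, where):
    """Toy analogue of a TF_RETURN_IF_ERROR that also logs its call site."""
    if status != "OK":
        TRACE.append(f"thread {threading.get_ident()}: {where} = {status}")
        return status
    return None


# Simulate the nested GetNext chain seen in the stack trace above:
for op in ("ShuffleDatasetBase", "BatchDatasetOp",
           "InfiniteRepeatOp", "PrefetchThread"):
    return_if_error("Cancelled", f"{op} input_impl_->GetNext")
```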

ashahab (Contributor), May 7, 2021:
Great! If you can point me to the backport commit/PR, that'd be great!

yangustc07 (Author, Member), May 12, 2021:
I created #49121.

yangustc07 (Author, Member), May 12, 2021:
The cherrypick for 2.5 is #49124.


return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

@@ -337,7 +330,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {

std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

-    DatasetBase* input_ TF_GUARDED_BY(mu_);
+    DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;

std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
TF_GUARDED_BY(mu_);
@@ -614,7 +607,11 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
bool* end_of_sequence) override {
mutex_lock l(mu_);
if (iterator_ == nullptr) {
-        TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
+        Status s = InitializeIterator(ctx, /*reader=*/nullptr);
+        if (!s.ok()) {
+          iterator_.reset();
+          return s;
+        }
}
index_++;
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
13 changes: 13 additions & 0 deletions tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
@@ -413,6 +413,19 @@ def testReadOptimizableUsingFlatMap(self):
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())

+  @combinations.generate(test_base.default_test_combinations())
+  def testRepeatAndPrefetch(self):
+    """This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
+    dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
+    dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
+    dataset = dataset.shuffle(buffer_size=16)
+    dataset = dataset.batch(16)
+    dataset = dataset.repeat()
+    dataset = dataset.prefetch(1)
+    next_element = self.getNext(dataset)
+    for _ in range(30):
+      self.evaluate(next_element())


class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase,
parameterized.TestCase):

1 comment on commit 858a569

yangustc07 (Member, Author):

The original commit message was lost due to an internal link. It was:

[tf.data] Fix snapshot segfault when using repeat and prefetch.

Fixes: https://github.com/tensorflow/tensorflow/issues/48903.

`input_->MakeIterator` refs the dataset in
https://github.com/tensorflow/tensorflow/blob/a9cf3a0e4b419630f0183b0cc4e48e0641a62721/tensorflow/core/framework/dataset.cc#L679. So
we don't need to call `input_->Ref()`. Otherwise, if
`SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize` returns an error,
`input_->Ref()` isn't called, but the destructor still calls `input_->Unref()`.

If `InitializeIterator` returns an error, the iterator_ needs to be reset to
nullptr. Otherwise, if GetNextInternal is called a second time,
`iterator_->GetNext` may dereference a null `input_impl_`.
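The first failure mode in this message (the unmatched Unref) can be made concrete with a toy Python model of the pre-fix Reader (illustrative classes, not the real TF code): when Initialize errors out before `input_` is set and Refed, the destructor's unconditional Unref touches a pointer that was never assigned.

```python
class Dataset:
    """Toy ref-counted dataset."""

    def __init__(self):
        self.refs = 1

    def Ref(self):
        self.refs += 1

    def Unref(self):
        self.refs -= 1


class BuggyReader:
    """Pre-fix behavior: the destructor always Unrefs, but the matching
    Ref only happens late in Initialize."""

    def __init__(self):
        self.input_ = None  # the pre-fix C++ member wasn't even nulled

    def initialize(self, dataset, fail_early):
        if fail_early:
            return "Cancelled"  # returns before input_ is set or Refed
        self.input_ = dataset
        dataset.Ref()
        return "OK"

    def destroy(self):
        # models ~Reader() { input_->Unref(); }
        self.input_.Unref()


d = Dataset()
r = BuggyReader()
status = r.initialize(d, fail_early=True)
try:
    r.destroy()            # stands in for the C++ segfault/underflow
    crashed = False
except AttributeError:
    crashed = True
```

The fix removes both the extra Ref and the unconditional Unref, leaving the balanced pair inside MakeIterator and the iterator destructor.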
