tf.data.experimental.snapshot segfault when using repeat and prefetch #48903

Closed
ashahab opened this issue May 4, 2021 · 14 comments · Fixed by #49121 or #49124
Labels: comp:data, TF 2.4, type:bug

Comments

ashahab (Contributor) commented May 4, 2021


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution: Linux CentOS 7
  • TensorFlow installed from (source or binary): 2.4.0
  • TensorFlow version: 2.4.0
  • Python version: 3.7.7
  • CUDA/cuDNN version: N/A


Describe the current behavior
Using the following simple script, we can see a segmentation fault:

import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices(np.random.rand(16, 1024))
dataset = dataset.apply(
    tf.data.experimental.snapshot('snapshot'))
dataset = dataset.shuffle(buffer_size=16)
dataset = dataset.batch(16)
dataset = dataset.repeat()
dataset = dataset.prefetch(1)
def run(dataset):
    iterator = iter(dataset)
    for _ in range(30):
        next(iterator)
for _ in range(10):
    run(dataset) 

If we run it with TensorFlow 2.4.0 (or TensorFlow 2.4.1), the output is:

...
2021-05-04 11:04:17.989897: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-05-04 11:04:17.990504: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2596985000 Hz
Segmentation fault (core dumped)

If any one of snapshot, repeat, or prefetch is removed, the segfault does not occur.

Describe the expected behavior
The expected behavior is that there is no segmentation fault.

Contributing - Do you want to contribute a PR? (yes/no): yes

Standalone code to reproduce the issue

The script shown above under "Describe the current behavior" is the bare-minimum reproduction.

Other info / logs
Analyzing the core dump, this is the truncated stack trace:

#0  0x00007fa2236c08af in tensorflow::data::experimental::SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#1  0x00007fa2236c0971 in tensorflow::data::experimental::SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#2  0x00007fa2236c04aa in tensorflow::data::experimental::SnapshotDatasetV2Op::Dataset::Iterator::~Iterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#3  0x00007fa2222eefee in tensorflow::data::MapDatasetOp::Dataset::Iterator::~Iterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#4  0x00007fa222335867 in tensorflow::data::ShuffleDatasetOpBase::ShuffleDatasetBase::Iterator::~Iterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#5  0x00007fa2222c13a9 in tensorflow::data::BatchDatasetOp::Dataset::Iterator::~Iterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#6  0x00007fa22232b529 in tensorflow::data::RepeatDatasetOp::Dataset::ForeverIterator::~ForeverIterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#7  0x00007fa223e7e385 in tensorflow::data::PrefetchDatasetOp::Dataset::Iterator::~Iterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#8  0x00007fa223771615 in tensorflow::data::experimental::(anonymous namespace)::MaxIntraOpParallelismDatasetOp::Dataset::Iterator::~Iterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#9  0x00007fa2222fb665 in tensorflow::data::ModelDatasetOp::Dataset::Iterator::~Iterator() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#10 0x00007fa223e441ab in std::_Sp_counted_ptr_inplace<tensorflow::data::IteratorResource::State, std::allocator<tensorflow::data::IteratorResource::State>, (__gnu_cxx::_Lock_policy)2>::_M_dispose() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#11 0x00007fa21d44b1f6 in std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release() ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#12 0x00007fa223e4dc62 in tensorflow::data::IteratorResource::~IteratorResource() () from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#13 0x00007fa223e4dd51 in tensorflow::data::IteratorResource::~IteratorResource() () from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#14 0x00007fa2199ac086 in tensorflow::ResourceMgr::ResourceAndName::~ResourceAndName() () from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/../libtensorflow_framework.so.2
#15 0x00007fa2199ae73f in tensorflow::ResourceMgr::DoDelete(std::string const&, unsigned long long, std::string const&, std::string const&) ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/../libtensorflow_framework.so.2
#16 0x00007fa2199aeb89 in tensorflow::ResourceMgr::Delete(tensorflow::ResourceHandle const&) ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/../libtensorflow_framework.so.2
#17 0x00007fa223e4f684 in tensorflow::data::DeleteIteratorOp::DoCompute(tensorflow::OpKernelContext*) ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#18 0x00007fa223e444b1 in tensorflow::data::HybridAsyncOpKernel::Compute(tensorflow::OpKernelContext*) ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#19 0x00007fa22396409b in tensorflow::KernelAndDeviceOp::Run(tensorflow::ScopedStepContainer*, tensorflow::EagerKernelArgs const&, std::vector<absl::lts_2020_02_25::variant<tensorflow::Tensor, tensorflow::TensorShape>, std::allocator<absl::lts_2020_02_25::variant<tensorflow::Tensor, tensorflow::TensorShape> > >*, tensorflow::CancellationManager*, absl::lts_2020_02_25::optional<tensorflow::EagerRemoteFunctionParams> const&) ()
   from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#20 0x00007fa22391f359 in tensorflow::EagerKernelExecute(tensorflow::EagerContext*, absl::lts_2020_02_25::InlinedVector<tensorflow::TensorHandle*, 4ul, std::allocator<tensorflow::TensorHandle*> > const&, absl::lts_2020_02_25::optional<tensorflow::EagerRemoteFunctionParams> const&, std::unique_ptr<tensorflow::KernelAndDevice, tensorflow::core::RefCountDeleter> const&, tensorflow::GraphCollector*, tensorflow::CancellationManager*, absl::lts_2020_02_25::Span<tensorflow::TensorHandle*>) () from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#21 0x00007fa2239202c0 in tensorflow::ExecuteNode::Run() () from /home/ashahab/dev/tensorflow-build_trunk/tmp/tf-venv/lib/python3.7/site-packages/tensorflow/python/_pywrap_tensorflow_internal.so
#22 0x00007fa22395d14f in tensorflow::EagerExecutor::SyncExecute(tensorflow::EagerNode*) ()
ashahab added the type:bug label May 4, 2021
UsharaniPagadala added the comp:data and TF 2.4 labels May 5, 2021
UsharaniPagadala commented May 5, 2021

@ashahab I was able to run the code in TF v2.4.1, TF 2.5.0rc1, and TF-nightly (2.6.0) with no errors. Please find the gist here and let us know if it helps. Thanks!

UsharaniPagadala added the stat:awaiting response label May 5, 2021
ashahab (Contributor, Author) commented May 5, 2021

@UsharaniPagadala I think this is a race condition. I don't know about the execution environment of the notebook and whether it allows true multi-threading.
Here are the steps I followed on a 12-core Azure NV12 box:

$ pip install tensorflow==2.4.1
$ python segfault.py
...
2021-05-05 04:57:51.915818: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
2021-05-05 04:57:51.915930: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1261] Device interconnect StreamExecutor with strength 1 edge matrix:
2021-05-05 04:57:51.915948: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1267]
2021-05-05 04:57:51.966764: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2021-05-05 04:57:51.967255: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2596985000 Hz
Segmentation fault (core dumped)

Let me know if, after following the above steps, you still don't see the segfault.

UsharaniPagadala removed the stat:awaiting response label May 5, 2021
UsharaniPagadala commented:

@jvishnuvardhan I was able to run the code in TF 2.4, TF 2.5.0rc1, and TF-nightly. Please find the gist here. Thanks!

yangustc07 (Member) commented:

Thanks for the link. I can also reproduce the issue.

jvishnuvardhan removed their assignment May 5, 2021
ashahab (Contributor, Author) commented May 5, 2021

@yangustc07 Were you able to see the segmentation fault?

yangustc07 (Member) commented:

Yes, I can see the segmentation fault and I'm working on a fix. Input is welcome if you have more information.

ashahab (Contributor, Author) commented May 5, 2021

@yangustc07 Thanks for reproducing the issue.
Yes, I have more input.
This happens only when all five of these are combined: snapshot, shuffle, batch, repeat, and prefetch. I have been unable to remove any of them to narrow down the problem; dropping any one makes the segfault disappear.

I added logging where the snapshot Reader acquires its input reference (input_->Ref()) and where it releases it (input_->Unref()). Every thread except one first executes input_->Ref() and then input_->Unref() (I logged the thread IDs). However, the last thread invokes input_->Unref() without the prior input_->Ref(); it ends up calling input_->Unref() on a null (0) input_.
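
To make that concrete, here is a minimal standalone C++ sketch of the pattern (simplified stand-in types, not TensorFlow's actual classes): an Initialize() that can fail before taking the reference, paired with a destructor that releases it unconditionally.

// Standalone sketch of the mismatch described above; these are
// simplified stand-ins, not TensorFlow's actual classes.
#include <cstdio>

class RefCounted {
 public:
  void Ref() { ++count_; }
  void Unref() {
    if (--count_ == 0) std::printf("refcount hit 0: object destroyed\n");
    if (count_ < 0) std::printf("UNDERFLOW: Unref() without matching Ref()\n");
  }
 private:
  int count_ = 1;  // the owner's reference
};

class Reader {
 public:
  explicit Reader(RefCounted* input) : input_(input) {}
  // fail_early models Initialize() returning an error (e.g. cancelled)
  // before the input_->Ref() line is ever reached.
  bool Initialize(bool fail_early) {
    if (fail_early) return false;  // no Ref() on this path
    input_->Ref();
    return true;
  }
  ~Reader() { input_->Unref(); }  // BUG: runs even if Initialize() failed
 private:
  RefCounted* input_;
};

int main() {
  RefCounted input;  // the owner (e.g. the dataset) holds one reference
  {
    Reader ok(&input);
    ok.Initialize(/*fail_early=*/false);  // Ref() and Unref() balance out
  }
  {
    Reader bad(&input);
    bad.Initialize(/*fail_early=*/true);  // no Ref() taken...
  }                                       // ...yet ~Reader() still Unref()s
  input.Unref();  // the owner's release now underflows the count
}

The premature "refcount hit 0" in the failing case is the analogue of the snapshot input being torn down while another owner still believes it holds a reference.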

yangustc07 (Member) commented:
Tried to debug this more. The reason one thread does not call input_->Ref() is that SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize returns a cancelled error somewhere. In that case, the destructor shouldn't call input_->Unref(), and there shouldn't be any calls to Reader::GetNextInternal().
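
A minimal sketch of the guard this diagnosis implies (hypothetical names, reusing the RefCounted stand-in from the sketch above; not the actual TensorFlow patch): record whether Initialize() got far enough to take the reference, and gate both the destructor's Unref() and GetNextInternal() on that flag.

// Hypothetical guard sketch (not the actual patch), reusing the
// RefCounted stand-in from the previous sketch.
class GuardedReader {
 public:
  explicit GuardedReader(RefCounted* input) : input_(input) {}
  bool Initialize(bool cancelled) {
    if (cancelled) return false;  // failed init: initialized_ stays false
    input_->Ref();
    initialized_ = true;
    return true;
  }
  bool GetNextInternal() {
    if (!initialized_) return false;  // never touch input_ after a failed init
    // ... read the next snapshot element from input_ ...
    return true;
  }
  ~GuardedReader() {
    if (initialized_) input_->Unref();  // release only what was acquired
  }
 private:
  RefCounted* input_;
  bool initialized_ = false;
};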

ashahab (Contributor, Author) commented May 6, 2021 via email

yangustc07 (Member) commented May 6, 2021

Yes, thanks for the note. I have updated the earlier comment. I have a better fix now.

ashahab (Contributor, Author) commented May 6, 2021 via email

ashahab (Contributor, Author) commented May 6, 2021

@yangustc07 do you have a fix?

yangustc07 (Member) commented May 6, 2021

Yes, I just submitted 858a569. I'm trying to see why it changed my commit message to "internal change." The original description was:

[tf.data] Fix snapshot segfault when using repeat and prefetch.

Fixes: https://github.com/tensorflow/tensorflow/issues/48903.

`input_->MakeIterator` refs the dataset in
https://github.com/tensorflow/tensorflow/blob/a9cf3a0e4b419630f0183b0cc4e48e0641a62721/tensorflow/core/framework/dataset.cc#L679. So
we don't need to call `input_->Ref()`. Otherwise, if
`SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize` returns an error,
`input_->Ref()` isn't called, but the destructor still calls `input_->Unref()`.

If `InitializeIterator` returns an error, `iterator_` needs to be reset to
nullptr. Otherwise, if `GetNextInternal` is called a second time,
`iterator_->GetNext` may dereference a null `input_impl_`.
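
The first point means the explicit input_->Ref() in Initialize can simply be dropped, since the reference taken inside input_->MakeIterator already balances the destructor's Unref(). For the second point, here is a simplified sketch of the shape of the fix (stand-in types, not the actual commit):

#include <cstdio>
#include <memory>

// Simplified sketch of the second half of the fix (not the actual commit):
// if initialization fails, reset iterator_ to nullptr so a second call
// re-initializes instead of calling GetNext() on a half-built iterator.
struct InnerIterator {
  bool GetNext() { std::puts("GetNext ok"); return true; }
};

class FixedReader {
 public:
  bool GetNextInternal(bool init_fails) {
    if (iterator_ == nullptr) {
      iterator_ = std::make_unique<InnerIterator>();
      if (init_fails) {     // InitializeIterator() returned an error
        iterator_.reset();  // the fix: don't keep the partial iterator
        return false;
      }
    }
    return iterator_->GetNext();
  }
 private:
  std::unique_ptr<InnerIterator> iterator_;
};

int main() {
  FixedReader r;
  r.GetNextInternal(/*init_fails=*/true);   // fails cleanly, iterator_ reset
  r.GetNextInternal(/*init_fails=*/false);  // retry re-initializes, succeeds
}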
