Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolves coredump caused by tf.data.experimental.save with prefetch #49383

Merged
merged 1 commit into from Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions tensorflow/core/kernels/data/experimental/io_ops.cc
Expand Up @@ -253,7 +253,11 @@ class LoadDatasetOp::Dataset : public DatasetBase {
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}

~Iterator() override { input_->Unref(); }
~Iterator() override {
if (input_) {
input_->Unref();
}
}

Status Initialize(IteratorContext* ctx) override {
mutex_lock l(mu_);
Expand Down Expand Up @@ -331,7 +335,7 @@ class LoadDatasetOp::Dataset : public DatasetBase {
}

mutex mu_;
DatasetBase* input_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
};
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/data/experimental/kernel_tests/io_test.py
Expand Up @@ -17,6 +17,7 @@
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import shutil

Expand Down Expand Up @@ -111,6 +112,20 @@ def testOptionalElementSpec(self):
dataset_loaded = io.load(self._test_dir)
self.assertDatasetsEqual(dataset, dataset_loaded)

@combinations.generate(test_base.eager_only_combinations())
def testRepeatAndPrefetch(self):
"""This test reproduces github.com/tensorflow/tensorflow/issues/49165"""
dataset1 = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
io.save(dataset1, self._test_dir)
dataset = io.load(self._test_dir)
dataset = dataset.shuffle(buffer_size=16)
dataset = dataset.batch(16)
dataset = dataset.repeat()
dataset = dataset.prefetch(1)
next_element = self.getNext(dataset)
for _ in range(30):
self.evaluate(next_element())


if __name__ == "__main__":
test.main()