aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/generator_dataset_op.cc
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2018-09-26 09:13:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 09:22:32 -0700
commitc3203eb8bf0d7ae9dce133f982884622f666c681 (patch)
treeae3279ca47a974e2d4d5ef7a06ed4efc4246a3dc /tensorflow/core/kernels/data/generator_dataset_op.cc
parent01512356e10ab87887e3c7b69f9ed3e5a8397f76 (diff)
The GeneratorDataset init function was being run during Initialization which is a blocking Op. Moving it to the GetNext call which is a non blocking async op.
This means we can now revert back to the default of 2 threads per GPU. PiperOrigin-RevId: 214614839
Diffstat (limited to 'tensorflow/core/kernels/data/generator_dataset_op.cc')
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc9
1 files changed, 7 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 71a36314a0..b4367d5a11 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -86,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
return Status::OK();
}
@@ -96,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ initialized_ = true;
+ }
+
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -123,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};