diff options
author | 2018-07-12 13:10:15 -0700 | |
---|---|---|
committer | 2018-07-12 13:23:02 -0700 | |
commit | 0ef634190dc2e49e4002a841185fc850b80cc1b9 (patch) | |
tree | 4bf6758b9681210cfb2b8316561af692be22cb5b /tensorflow | |
parent | 513af8ee971237267c00b3e0f9f0dea05503f70a (diff) |
[tf.data] Handling checkpointing of optimized input pipelines correctly.
PiperOrigin-RevId: 204350306
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py | 6 | ||||
-rw-r--r-- | tensorflow/core/kernels/data/optimize_dataset_op.cc | 40 |
2 files changed, 29 insertions, 17 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index 3bb9723bbc..21eebccd11 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -35,8 +35,6 @@ class OptimizeDatasetTest(test.TestCase): with self.test_session() as sess: graph = graph_pb2.GraphDef().FromString( sess.run(dataset._as_serialized_graph())) - self.assertTrue( - all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -50,8 +48,6 @@ class OptimizeDatasetTest(test.TestCase): with self.test_session() as sess: graph = graph_pb2.GraphDef().FromString( sess.run(dataset._as_serialized_graph())) - self.assertTrue( - all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -65,8 +61,6 @@ class OptimizeDatasetTest(test.TestCase): with self.test_session() as sess: graph = graph_pb2.GraphDef().FromString( sess.run(dataset._as_serialized_graph())) - self.assertTrue( - any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 81be69105e..276f5f89c8 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -53,23 +53,30 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK( ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations)); Dataset* dataset = - new Dataset(ctx, optimizations, output_types_, output_shapes_); - OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, input)); + new Dataset(ctx, input, optimizations, output_types_, output_shapes_); + OP_REQUIRES_OK(ctx, dataset->Optimize(ctx)); *output = dataset; } private: class Dataset : public GraphDatasetBase { public: - Dataset(OpKernelContext* ctx, const std::vector<string>& optimizations, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const std::vector<string>& optimizations, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes) : GraphDatasetBase(ctx), + input_(input), optimizations_(optimizations), output_types_(output_types), - output_shapes_(output_shapes) {} + output_shapes_(output_shapes) { + input_->Ref(); + } - ~Dataset() override { input_->Unref(); } + ~Dataset() override { + input_->Unref(); + optimized_input_->Unref(); + } std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { @@ -77,15 +84,17 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { new Iterator({this, strings::StrCat(prefix, "::Optimize")})); } - Status Optimize(OpKernelContext* ctx, const DatasetBase* input) { + Status Optimize(OpKernelContext* ctx) { GraphDefBuilder b; DatasetGraphDefBuilder db(&b); Node* input_node = nullptr; - TF_RETURN_IF_ERROR(db.AddParentDataset(ctx, input, &input_node)); + TF_RETURN_IF_ERROR(db.AddParentDataset(ctx, input_, &input_node)); string output_node = input_node->name(); GraphDef graph_def; TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def)); + VLOG(3) << "Before optimization: " << graph_def.DebugString(); TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node)); + VLOG(3) << "After optimization: " << graph_def.DebugString(); flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), graph_def.library())); Graph graph(OpRegistry::Global()); @@ -94,8 +103,9 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { GraphRunner graph_runner(ctx->function_library()->device()); TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {}, {output_node}, &outputs)); - TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &input_)); - input_->Ref(); + TF_RETURN_IF_ERROR( + GetDatasetFromVariantTensor(outputs[0], &optimized_input_)); + optimized_input_->Ref(); return Status::OK(); } @@ -127,7 +137,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + return dataset()->optimized_input_->MakeIterator(ctx, prefix(), + &input_impl_); } Status GetNextInternal(IteratorContext* ctx, @@ -199,6 +210,12 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { tensorflow::grappler::VirtualCluster cluster(device_map); // Run optimizer. + if (VLOG_IS_ON(2)) { + LOG(INFO) << "Performing the following optimizations:"; + for (const string& optimization : optimizations_) { + LOG(INFO) << " " << optimization; + } + } TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer( *grappler_item, rewriter_config, ctx->device(), &cluster, graph_def)); @@ -213,8 +230,9 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - DatasetBase* input_; + DatasetBase* optimized_input_; std::shared_ptr<FunctionLibraryDefinition> flib_def_; + const DatasetBase* input_; const std::vector<string> optimizations_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; |