author    Jiri Simsa <jsimsa@google.com>  2018-07-12 13:10:15 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-12 13:23:02 -0700
commit    0ef634190dc2e49e4002a841185fc850b80cc1b9 (patch)
tree      4bf6758b9681210cfb2b8316561af692be22cb5b /tensorflow
parent    513af8ee971237267c00b3e0f9f0dea05503f70a (diff)
[tf.data] Handling checkpointing of optimized input pipelines correctly.
PiperOrigin-RevId: 204350306
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py |  6
-rw-r--r--  tensorflow/core/kernels/data/optimize_dataset_op.cc                     | 40
2 files changed, 29 insertions, 17 deletions
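
For context, a minimal sketch of the kind of pipeline this change is about: a tf.data input pipeline that opts into static optimizations and whose iterator position is checkpointed. The contrib API names used below (optimization.optimize, make_saveable_from_iterator) and the "map_and_batch_fusion" optimization name are assumptions based on the TF 1.x contrib APIs of this era; they are not part of this patch.

    # Hedged sketch, not from this commit: assumes TF 1.x contrib tf.data APIs.
    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import optimization

    dataset = tf.data.Dataset.range(10).map(lambda x: x * x).batch(5)
    # Opt in to static optimizations; the rewriter may fuse Map+Batch into a
    # MapAndBatchDatasetV2 node, which the removed test assertions referenced.
    dataset = dataset.apply(optimization.optimize(["map_and_batch_fusion"]))

    iterator = dataset.make_initializable_iterator()
    get_next = iterator.get_next()

    # Make the iterator's position part of the checkpoint. With this commit,
    # OptimizeDataset keeps a reference to the original input for serialization
    # and uses the optimized dataset for iteration, so saving and restoring
    # such pipelines behaves correctly.
    saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
    saver = tf.train.Saver()

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      print(sess.run(get_next))              # first batch
      path = saver.save(sess, "/tmp/pipeline.ckpt")
      saver.restore(sess, path)
      print(sess.run(get_next))              # resumes from the saved position
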
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index 3bb9723bbc..21eebccd11 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -35,8 +35,6 @@ class OptimizeDatasetTest(test.TestCase):
with self.test_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -50,8 +48,6 @@ class OptimizeDatasetTest(test.TestCase):
with self.test_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- all([node.op != "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -65,8 +61,6 @@ class OptimizeDatasetTest(test.TestCase):
with self.test_session() as sess:
graph = graph_pb2.GraphDef().FromString(
sess.run(dataset._as_serialized_graph()))
- self.assertTrue(
- any([node.op == "MapAndBatchDatasetV2" for node in graph.node]))
self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 81be69105e..276f5f89c8 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -53,23 +53,30 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
Dataset* dataset =
- new Dataset(ctx, optimizations, output_types_, output_shapes_);
- OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, input));
+ new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
+ OP_REQUIRES_OK(ctx, dataset->Optimize(ctx));
*output = dataset;
}
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, const std::vector<string>& optimizations,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const std::vector<string>& optimizations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
: GraphDatasetBase(ctx),
+ input_(input),
optimizations_(optimizations),
output_types_(output_types),
- output_shapes_(output_shapes) {}
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
- ~Dataset() override { input_->Unref(); }
+ ~Dataset() override {
+ input_->Unref();
+ optimized_input_->Unref();
+ }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
@@ -77,15 +84,17 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
new Iterator({this, strings::StrCat(prefix, "::Optimize")}));
}
- Status Optimize(OpKernelContext* ctx, const DatasetBase* input) {
+ Status Optimize(OpKernelContext* ctx) {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
- TF_RETURN_IF_ERROR(db.AddParentDataset(ctx, input, &input_node));
+ TF_RETURN_IF_ERROR(db.AddParentDataset(ctx, input_, &input_node));
string output_node = input_node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ VLOG(3) << "Before optimization: " << graph_def.DebugString();
TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
+ VLOG(3) << "After optimization: " << graph_def.DebugString();
flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
graph_def.library()));
Graph graph(OpRegistry::Global());
@@ -94,8 +103,9 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
GraphRunner graph_runner(ctx->function_library()->device());
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
{output_node}, &outputs));
- TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &input_));
- input_->Ref();
+ TF_RETURN_IF_ERROR(
+ GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
+ optimized_input_->Ref();
return Status::OK();
}
@@ -127,7 +137,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
+ &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -199,6 +210,12 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
tensorflow::grappler::VirtualCluster cluster(device_map);
// Run optimizer.
+ if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "Performing the following optimizations:";
+ for (const string& optimization : optimizations_) {
+ LOG(INFO) << " " << optimization;
+ }
+ }
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
*grappler_item, rewriter_config, ctx->device(), &cluster, graph_def));
@@ -213,8 +230,9 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- DatasetBase* input_;
+ DatasetBase* optimized_input_;
std::shared_ptr<FunctionLibraryDefinition> flib_def_;
+ const DatasetBase* input_;
const std::vector<string> optimizations_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;