Diffstat (limited to 'tensorflow/core/kernels/data/optimize_dataset_op.cc')
-rw-r--r-- | tensorflow/core/kernels/data/optimize_dataset_op.cc | 92 |
1 files changed, 66 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 8965858e8d..276f5f89c8 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -54,8 +54,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
         ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
     Dataset* dataset =
         new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
-    core::ScopedUnref unref(dataset);
-    OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, output));
+    OP_REQUIRES_OK(ctx, dataset->Optimize(ctx));
+    *output = dataset;
   }
 
  private:
@@ -73,7 +73,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
       input_->Ref();
     }
 
-    ~Dataset() override { input_->Unref(); }
+    ~Dataset() override {
+      input_->Unref();
+      optimized_input_->Unref();
+    }
 
     std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
@@ -81,7 +84,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
           new Iterator({this, strings::StrCat(prefix, "::Optimize")}));
     }
 
-    Status Optimize(OpKernelContext* ctx, DatasetBase** output) {
+    Status Optimize(OpKernelContext* ctx) {
       GraphDefBuilder b;
       DatasetGraphDefBuilder db(&b);
       Node* input_node = nullptr;
@@ -89,18 +92,20 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
       string output_node = input_node->name();
       GraphDef graph_def;
       TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+      VLOG(3) << "Before optimization: " << graph_def.DebugString();
       TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
-
+      VLOG(3) << "After optimization: " << graph_def.DebugString();
+      flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
+                                                    graph_def.library()));
       Graph graph(OpRegistry::Global());
       TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
       std::vector<Tensor> outputs;
-      GraphRunner graph_runner(ctx->env());
-      // Once rewrites that add/modify functions are introduced, we will need
-      // persist the results in a function library runtime.
+      GraphRunner graph_runner(ctx->function_library()->device());
       TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
                                           {output_node}, &outputs));
-      TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], output));
-      (*output)->Ref();
+      TF_RETURN_IF_ERROR(
+          GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
+      optimized_input_->Ref();
       return Status::OK();
     }
 
@@ -113,6 +118,18 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
 
     string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
 
+   protected:
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+      Node* optimizations_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddVector(optimizations_, &optimizations_node));
+      TF_RETURN_IF_ERROR(
+          b->AddDataset(this, {input_graph_node, optimizations_node}, output));
+      return Status::OK();
+    }
+
    private:
     class Iterator : public DatasetIterator<Dataset> {
      public:
@@ -120,15 +137,38 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<Dataset>(params) {}
 
       Status Initialize(IteratorContext* ctx) override {
-        return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize"));
+        return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
+                                                         &input_impl_);
       }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
-        return errors::Unimplemented(
-            strings::StrCat(prefix(), "::GetNextInternal"));
+        IteratorContext::Params params;
+        params.env = ctx->env();
+        params.runner = *(ctx->runner());
+        params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+        params.lib = ctx->lib();
+        params.function_library = dataset()->flib_def_;
+        params.allocator_getter = ctx->allocator_getter();
+        IteratorContext iter_ctx(params);
+        return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        return Status::OK();
       }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
+     private:
+      std::unique_ptr<IteratorBase> input_impl_;
     };
 
     Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
@@ -136,16 +176,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
       // Add a fake sink node to allow rewriting the actual sink node.
       NodeDef* node = graph_def->mutable_node()->Add();
       node->set_name("FakeSink");
-      node->set_op("IdentityDataset");
+      node->set_op("SinkDataset");
       node->add_input(*output_node);
-      {
-        grappler::GraphView graph(graph_def);
-        NodeDef* sink = graph.GetNode(*output_node);
-        (*node->mutable_attr())["output_shapes"] =
-            sink->attr().at("output_shapes");
-        (*node->mutable_attr())["output_types"] =
-            sink->attr().at("output_types");
-      }
 
       // Create metagraph.
       MetaGraphDef meta_graph_def;
@@ -162,10 +194,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
       for (const string& optimization : optimizations_) {
         rewriter_config.add_optimizers(optimization);
       }
-      // If no optimizations were specified, supply a non-existent optimization
-      // to prevent Grappler from applying the default set of optimizations as
-      // some of them do not work out of the box at the moment (e.g. because we
-      // have no cost model for dataset ops).
+      // If no optimizations were specified, supply a non-existent
+      // optimization to prevent Grappler from applying the default set of
+      // optimizations as some of them do not work out of the box at the
+      // moment (e.g. because we have no cost model for dataset ops).
       if (optimizations_.empty()) {
         rewriter_config.add_optimizers("non-existent");
       }
@@ -178,6 +210,12 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
       tensorflow::grappler::VirtualCluster cluster(device_map);
 
       // Run optimizer.
+      if (VLOG_IS_ON(2)) {
+        LOG(INFO) << "Performing the following optimizations:";
+        for (const string& optimization : optimizations_) {
+          LOG(INFO) << " " << optimization;
+        }
+      }
       TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
           *grappler_item, rewriter_config, ctx->device(), &cluster, graph_def));
 
@@ -192,6 +230,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
       return Status::OK();
     }
 
+    DatasetBase* optimized_input_;
+    std::shared_ptr<FunctionLibraryDefinition> flib_def_;
     const DatasetBase* input_;
     const std::vector<string> optimizations_;
    const DataTypeVector output_types_;
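The hunks above touch only parts of ApplyOptimizations(), so the Grappler invocation is easier to follow pieced together in one place. The sketch below does that under stated assumptions: the helper name RewriteDatasetGraph, the "graph" item id, the "train_op" fetch collection, the default-constructed grappler::ItemConfig, and the include paths are guesses about the surrounding, unchanged code; only the SinkDataset fake sink, the RewriterConfig optimizer list, the "non-existent" fallback, the VirtualCluster, and the RunMetaOptimizer call are taken from this diff.

#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/grappler_item_builder.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {

// Hypothetical standalone version of the rewrite; graph_def is modified in
// place and cpu_device plays the role of ctx->device() in the kernel.
Status RewriteDatasetGraph(const std::vector<string>& optimizations,
                           const string& output_node, DeviceBase* cpu_device,
                           GraphDef* graph_def) {
  // Give Grappler a fetch node so it is free to rewrite the real output
  // dataset node; this commit switches the fake sink's op to SinkDataset.
  NodeDef* sink = graph_def->add_node();
  sink->set_name("FakeSink");
  sink->set_op("SinkDataset");
  sink->add_input(output_node);

  // Wrap the graph in a MetaGraphDef and mark the sink as a fetch target via
  // the "train_op" collection (assumption: this is how the unchanged code
  // tells the item builder what to fetch).
  MetaGraphDef meta_graph_def;
  *meta_graph_def.mutable_graph_def() = *graph_def;
  CollectionDef collection_def;
  collection_def.mutable_node_list()->add_value("FakeSink");
  (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;

  // Run only the explicitly requested optimizers.  An unknown optimizer name
  // keeps Grappler's default passes disabled, since not all of them handle
  // dataset ops yet.
  RewriterConfig rewriter_config;
  for (const string& optimization : optimizations) {
    rewriter_config.add_optimizers(optimization);
  }
  if (optimizations.empty()) {
    rewriter_config.add_optimizers("non-existent");
  }

  // Build a GrapplerItem and run the meta-optimizer, rewriting graph_def in
  // place.
  std::unique_ptr<grappler::GrapplerItem> item =
      grappler::GrapplerItemFromMetaGraphDef("graph", meta_graph_def,
                                             grappler::ItemConfig());
  if (!item) {
    return errors::Internal("Failed to build GrapplerItem.");
  }
  std::unordered_map<string, DeviceProperties> device_map;
  grappler::VirtualCluster cluster(device_map);
  return grappler::RunMetaOptimizer(*item, rewriter_config, cpu_device,
                                    &cluster, graph_def);
}

}  // namespace tensorflow

Note that the diagnostics added by this change are gated on VLOG: the before/after GraphDef dumps require verbosity 3 and the optimizer list requires verbosity 2, which can typically be enabled at runtime by setting the TF_CPP_MIN_VLOG_LEVEL environment variable.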
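Optimize() now keeps the rewritten pipeline alive itself: GraphRunner executes the optimized graph, GetDatasetFromVariantTensor unwraps the resulting DT_VARIANT scalar into optimized_input_, and the dataset takes a reference that the new destructor releases. Below is a minimal sketch of that variant round trip, assuming the helpers are declared in tensorflow/core/framework/dataset.h at this revision and using a hypothetical RoundTripDataset wrapper to stand in for the producer/consumer pair.

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

Status RoundTripDataset(DatasetBase* dataset, DatasetBase** out_dataset) {
  // Producer side: dataset kernels emit their result as a scalar DT_VARIANT
  // tensor that wraps the DatasetBase pointer.
  Tensor variant(DT_VARIANT, TensorShape({}));
  TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(dataset, &variant));

  // Consumer side (what Optimize() does with graph_runner's outputs[0]):
  // unwrap the pointer and take a reference so it outlives the tensor, just
  // as the kernel now does with optimized_input_->Ref().
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(variant, out_dataset));
  (*out_dataset)->Ref();
  return Status::OK();
}

}  // namespace tensorflow

Holding the rewritten function library in the new shared_ptr flib_def_ and threading it through IteratorContext::Params in GetNextInternal takes over the role of the removed comment about persisting function rewrites: functions added or modified by the optimizers stay resolvable for as long as the iterator is in use.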