Diffstat (limited to 'tensorflow/core/kernels/data/optimize_dataset_op.cc')
 tensorflow/core/kernels/data/optimize_dataset_op.cc | 92 +++++++++++++-------
 1 file changed, 66 insertions(+), 26 deletions(-)
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 8965858e8d..276f5f89c8 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -54,8 +54,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
ctx, ParseVectorArgument<string>(ctx, "optimizations", &optimizations));
Dataset* dataset =
new Dataset(ctx, input, optimizations, output_types_, output_shapes_);
- core::ScopedUnref unref(dataset);
- OP_REQUIRES_OK(ctx, dataset->Optimize(ctx, output));
+ OP_REQUIRES_OK(ctx, dataset->Optimize(ctx));
+ *output = dataset;
}
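
The crux of this hunk is an ownership change: previously the kernel optimized
eagerly and returned the already-optimized dataset, releasing its own
reference to the wrapper via core::ScopedUnref; now the wrapper itself is the
output, so the initial reference created by `new Dataset(...)` is handed to
*output instead of being dropped. A minimal, self-contained sketch of that
handoff, using hypothetical RefCounted/ScopedUnref stand-ins rather than
TensorFlow's real classes:

#include <atomic>

// Hypothetical stand-ins for tensorflow::core::RefCounted and ScopedUnref.
class RefCounted {
 public:
  RefCounted() : count_(1) {}  // creation yields one owned reference
  void Ref() { count_.fetch_add(1); }
  void Unref() {
    if (count_.fetch_sub(1) == 1) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  std::atomic<int> count_;
};

class ScopedUnref {
 public:
  explicit ScopedUnref(RefCounted* obj) : obj_(obj) {}
  ~ScopedUnref() { obj_->Unref(); }

 private:
  RefCounted* obj_;
};

struct Dataset : RefCounted {};

int main() {
  // Old pattern: the guard released the kernel's reference before returning,
  // because ownership had already passed to the separately-returned output.
  {
    Dataset* d = new Dataset;
    ScopedUnref unref(d);  // wrapper destroyed at end of scope
  }
  // New pattern: the initial reference is transferred to the output slot and
  // released later by the framework when the dataset is retired.
  Dataset* output = new Dataset;
  output->Unref();
}
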
private:
@@ -73,7 +73,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
input_->Ref();
}
- ~Dataset() override { input_->Unref(); }
+ ~Dataset() override {
+ input_->Unref();
+ optimized_input_->Unref();
+ }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
@@ -81,7 +84,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
new Iterator({this, strings::StrCat(prefix, "::Optimize")}));
}
- Status Optimize(OpKernelContext* ctx, DatasetBase** output) {
+ Status Optimize(OpKernelContext* ctx) {
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
@@ -89,18 +92,20 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
string output_node = input_node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ VLOG(3) << "Before optimization: " << graph_def.DebugString();
TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
-
+ VLOG(3) << "After optimization: " << graph_def.DebugString();
+ flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
+ graph_def.library()));
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
- GraphRunner graph_runner(ctx->env());
- // Once rewrites that add/modify functions are introduced, we will need
- // to persist the results in a function library runtime.
+ GraphRunner graph_runner(ctx->function_library()->device());
TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
{output_node}, &outputs));
- TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], output));
- (*output)->Ref();
+ TF_RETURN_IF_ERROR(
+ GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
+ optimized_input_->Ref();
return Status::OK();
}
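
Optimize() now materializes the rewritten pipeline internally: it serializes
the input dataset to a GraphDef, hands it to Grappler, imports and executes
the rewritten graph with a GraphRunner, and extracts the resulting dataset
from the output variant tensor, retaining it in optimized_input_. A hedged
sketch of the extraction step, built around the same GetDatasetFromVariantTensor
call the hunk uses (the helper and its name are illustrative, and the TF 1.x
dataset headers are assumed):

// Illustrative helper, not the op's actual code. GetDatasetFromVariantTensor
// yields a borrowed DatasetBase*, so the caller must Ref() it to keep the
// dataset alive once the variant tensor goes away.
Status ExtractOptimizedDataset(const std::vector<Tensor>& outputs,
                               DatasetBase** optimized) {
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], optimized));
  (*optimized)->Ref();
  return Status::OK();
}
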
@@ -113,6 +118,18 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* optimizations_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(optimizations_, &optimizations_node));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, optimizations_node}, output));
+ return Status::OK();
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -120,15 +137,38 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return errors::Unimplemented(strings::StrCat(prefix(), "::Initialize"));
+ return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
+ &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- return errors::Unimplemented(
- strings::StrCat(prefix(), "::GetNextInternal"));
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ params.lib = ctx->lib();
+ params.function_library = dataset()->flib_def_;
+ params.allocator_getter = ctx->allocator_getter();
+ IteratorContext iter_ctx(params);
+ return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
}
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
};
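
The rewritten Iterator is a thin pass-through: Initialize() opens an iterator
over the optimized input, and GetNextInternal() forwards each call through a
rebuilt IteratorContext whose function_library points at the flib_def_
captured during optimization, so functions introduced by graph rewrites
resolve at run time. A self-contained sketch of the delegation pattern with
hypothetical, simplified interfaces (not TensorFlow's):

#include <memory>
#include <string>
#include <vector>

struct Context {
  std::shared_ptr<const std::string> function_library;  // stand-in for flib_def_
};

class IteratorBase {
 public:
  virtual ~IteratorBase() = default;
  virtual bool GetNext(Context* ctx, std::vector<int>* out) = 0;
};

class CountingIterator : public IteratorBase {
 public:
  explicit CountingIterator(int limit) : limit_(limit) {}
  bool GetNext(Context*, std::vector<int>* out) override {
    if (next_ >= limit_) return false;  // end of sequence
    out->assign(1, next_++);
    return true;
  }

 private:
  int limit_;
  int next_ = 0;
};

class DelegatingIterator : public IteratorBase {
 public:
  DelegatingIterator(std::unique_ptr<IteratorBase> input,
                     std::shared_ptr<const std::string> flib)
      : input_(std::move(input)), flib_(std::move(flib)) {}

  bool GetNext(Context* /*outer*/, std::vector<int>* out) override {
    // Rebuild the context so the wrapped iterator sees the captured
    // function library instead of the caller's.
    Context inner;
    inner.function_library = flib_;
    return input_->GetNext(&inner, out);
  }

 private:
  std::unique_ptr<IteratorBase> input_;
  std::shared_ptr<const std::string> flib_;
};

int main() {
  DelegatingIterator it(std::make_unique<CountingIterator>(3),
                        std::make_shared<const std::string>("captured flib"));
  Context outer;
  std::vector<int> batch;
  while (it.GetNext(&outer, &batch)) {
    // consume batch[0]
  }
}
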
Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
@@ -136,16 +176,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
// Add a fake sink node to allow rewriting the actual sink node.
NodeDef* node = graph_def->mutable_node()->Add();
node->set_name("FakeSink");
- node->set_op("IdentityDataset");
+ node->set_op("SinkDataset");
node->add_input(*output_node);
- {
- grappler::GraphView graph(graph_def);
- NodeDef* sink = graph.GetNode(*output_node);
- (*node->mutable_attr())["output_shapes"] =
- sink->attr().at("output_shapes");
- (*node->mutable_attr())["output_types"] =
- sink->attr().at("output_types");
- }
// Create metagraph.
MetaGraphDef meta_graph_def;
@@ -162,10 +194,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
for (const string& optimization : optimizations_) {
rewriter_config.add_optimizers(optimization);
}
- // If no optimizations were specified, supply a non-existent optimization
- // to prevent Grappler from applying the default set of optimizations as
- // some of them do not work out of the box at the moment (e.g. because we
- // have no cost model for dataset ops).
+ // If no optimizations were specified, supply a non-existent
+ // optimization to prevent Grappler from applying the default set of
+ // optimizations as some of them do not work out of the box at the
+ // moment (e.g. because we have no cost model for dataset ops).
if (optimizations_.empty()) {
rewriter_config.add_optimizers("non-existent");
}
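
Grappler's meta-optimizer runs exactly the passes listed in
RewriterConfig.optimizers, but with an empty list it falls back to its default
suite, which the comment above explains the op must avoid. A hedged sketch of
building such a config (the helper function is illustrative; RewriterConfig
and add_optimizers are the real proto API):

#include <string>
#include <vector>

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Illustrative helper: run only the requested optimizers, and disable the
// defaults with a placeholder name when none were requested.
tensorflow::RewriterConfig MakeRewriterConfig(
    const std::vector<std::string>& optimizations) {
  tensorflow::RewriterConfig rewriter_config;
  for (const std::string& optimization : optimizations) {
    rewriter_config.add_optimizers(optimization);
  }
  if (optimizations.empty()) {
    // No registered optimizer matches this name, so nothing runs.
    rewriter_config.add_optimizers("non-existent");
  }
  return rewriter_config;
}
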
@@ -178,6 +210,12 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
tensorflow::grappler::VirtualCluster cluster(device_map);
// Run optimizer.
+ if (VLOG_IS_ON(2)) {
+ LOG(INFO) << "Performing the following optimizations:";
+ for (const string& optimization : optimizations_) {
+ LOG(INFO) << " " << optimization;
+ }
+ }
TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
*grappler_item, rewriter_config, ctx->device(), &cluster, graph_def));
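
RunMetaOptimizer is the entry point that applies the configured passes to the
dataset graph; the VirtualCluster stands in for real device information, which
these rewrites do not consult. A hedged sketch of the invocation, mirroring
the calls in this hunk (the wrapper function is illustrative and assumes the
TF 1.x grappler headers):

// Illustrative wrapper around the calls used above; not the op's exact code.
Status RunDatasetRewrites(const tensorflow::grappler::GrapplerItem& item,
                          const tensorflow::RewriterConfig& rewriter_config,
                          tensorflow::DeviceBase* cpu_device,
                          tensorflow::GraphDef* graph_def) {
  // An empty device map: dataset rewrites do not need a cost model.
  std::unordered_map<string, tensorflow::DeviceProperties> device_map;
  tensorflow::grappler::VirtualCluster cluster(device_map);
  return tensorflow::grappler::RunMetaOptimizer(item, rewriter_config,
                                                cpu_device, &cluster,
                                                graph_def);
}
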
@@ -192,6 +230,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
+ DatasetBase* optimized_input_;
+ std::shared_ptr<FunctionLibraryDefinition> flib_def_;
const DatasetBase* input_;
const std::vector<string> optimizations_;
const DataTypeVector output_types_;