- author: A. Unique TensorFlower <nobody@tensorflow.org> (2016-02-02 16:51:43 -0800)
- committer: Vijay Vasudevan <vrv@google.com> (2016-02-03 09:34:58 -0800)
- commit: e830638148e203a2d9cf491e4693d35661a360d1 (patch)
- tree: 11368b2e6b8a9d125f37050484da640983dbecfc
- parent: 08b09699ad5f8b57add1f83461568b647c42129a (diff)
Refactor the logic that applies graph optimizations into a common module.
Change: 113692577
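In short: the per-call-site optimization logic that previously lived in `DirectSession::CreateGraphs` and in `function.cc` (`SimplifyGraph`/`OptimizeGraph`) moves into a new `GraphOptimizer` class configured by the relocated `OptimizerOptions` proto. A minimal sketch of the resulting call pattern, using only names introduced in the diff below (surrounding setup and error handling elided):

```cpp
// Sketch of the pattern this change introduces: build an OptimizerOptions
// proto, construct a GraphOptimizer from it, and call Optimize(), which may
// replace *g with a new Graph object (see graph_optimizer.h below).
#include "tensorflow/core/common_runtime/graph_optimizer.h"

namespace tensorflow {

void RunOptimizationPasses(FunctionLibraryRuntime* lib, Graph** g) {
  OptimizerOptions opts;
  opts.set_do_common_subexpression_elimination(true);
  opts.set_do_function_inlining(true);
  opts.set_do_constant_folding(true);
  GraphOptimizer optimizer(opts);
  optimizer.Optimize(lib, g);  // runs the enabled passes in rounds until a fixed point
}

}  // namespace tensorflow
```

Both call sites now follow this shape: `DirectSession` builds the options from the session's `ConfigProto` and optimizes each partition graph, while `FunctionLibraryRuntimeImpl` holds a `GraphOptimizer` member and runs it on each instantiated function body.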
```
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc  |  45
-rw-r--r--  tensorflow/core/common_runtime/function.cc         | 143
-rw-r--r--  tensorflow/core/common_runtime/function.h          |   4
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc    |   3
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.cc  |  91
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.h   |  43
-rw-r--r--  tensorflow/core/framework/config.proto             |  56
-rw-r--r--  tensorflow/core/graph/algorithm.cc                 |   6
-rw-r--r--  tensorflow/core/graph/algorithm.h                  |   4
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc         |  21
-rw-r--r--  tensorflow/core/graph/graph_constructor.h          |  10
-rw-r--r--  tensorflow/core/graph/optimizer_cse.cc             |  11
-rw-r--r--  tensorflow/core/graph/optimizer_cse.h              |   4
-rw-r--r--  tensorflow/core/ops/array_grad.cc                  |  23
-rw-r--r--  tensorflow/python/framework/function_test.py       |  50
-rw-r--r--  tensorflow/python/framework/gen_docs_combined.py   |   2

16 files changed, 334 insertions(+), 182 deletions(-)
```
```diff
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 17eb94f495..3e775c5c48 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device_factory.h"
 #include "tensorflow/core/common_runtime/executor.h"
 #include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/common_runtime/session_factory.h"
 #include "tensorflow/core/common_runtime/simple_placer.h"
 #include "tensorflow/core/framework/function.h"
@@ -612,6 +613,9 @@ Status DirectSession::GetOrCreateExecutors(
   ek->func_defs = fdefs;
   ek->items.reserve(graphs.size());
   auto runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
+  const auto& optimizer_opts =
+      options_.config.graph_options().optimizer_options();
+  GraphOptimizer optimizer(optimizer_opts);
   for (const auto& graph : graphs) {
     const string& partition_name = graph.first;
     Graph* partition_graph = graph.second;
@@ -623,8 +627,8 @@ Status DirectSession::GetOrCreateExecutors(
     ek->items.resize(ek->items.size() + 1);
     auto* item = &(ek->items.back());
-    item->flib =
-        NewFunctionLibraryRuntime(device, runner, graph_def_version, fdefs);
+    item->flib = NewFunctionLibraryRuntime(device, runner, graph_def_version,
+                                           fdefs, optimizer_opts);
 
     LocalExecutorParams params;
     params.device = device;
@@ -646,6 +650,7 @@ Status DirectSession::GetOrCreateExecutors(
       // Do nothing because 'kernel' is owned by opseg above.
     };
 
+    optimizer.Optimize(lib, &partition_graph);
     s = NewLocalExecutor(params, partition_graph, &item->executor);
     if (!s.ok()) {
       return s;
@@ -715,38 +720,11 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
                                    RunStateArgs* run_state_args) {
   std::unique_ptr<FunctionLibraryDefinition> fdefs;
   std::unique_ptr<Graph> graph;
-  GraphConstructorOptions opts{
-      options_.config.graph_options().optimizer_options()};
-
-  std::unordered_set<StringPiece, StringPiece::Hasher> keep_nodes;
-  for (const string& feed : feeds) {
-    keep_nodes.insert(ParseTensorName(feed).first);
-  }
-  for (const string& fetch : fetches) {
-    keep_nodes.insert(ParseTensorName(fetch).first);
-  }
-  for (const string& target_node : target_nodes) {
-    keep_nodes.insert(target_node);
-  }
-
-  if (opts.optimizer_do_cse) {
-    // Prevent CSE from eliminating nodes that will be required during
-    // RewriteGraphForExecution, below.
-    opts.cse_consider_function = [&keep_nodes](const Node* n) {
-      return n->IsConstant() && !keep_nodes.count(n->name());
-    };
-  }
-
-  if (opts.optimizer_do_constant_folding) {
-    opts.constant_folding_opts.consider = [&keep_nodes](const Node* n) {
-      return keep_nodes.count(n->name()) > 0;
-    };
-  }
-
   {
     mutex_lock l(graph_def_lock_);
     fdefs.reset(new FunctionLibraryDefinition(graph_def_.library()));
     graph.reset(new Graph(fdefs.get()));
+    GraphConstructorOptions opts;
     TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, graph.get()));
   }
@@ -760,13 +738,6 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
       graph.get(), feeds, fetches, target_nodes,
       device_set_.client_device()->attributes()));
 
-  if (opts.optimizer_do_constant_folding) {
-    bool constant_folded =
-        DoConstantFolding(opts.constant_folding_opts, graph.get());
-    VLOG(2) << (constant_folded ? "Folded some constants"
-                                : "Found no constant folding opportunity");
-  }
-
   GraphDef graph_def;
   graph->ToGraphDef(&graph_def);
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 85052ae132..d83b293c2a 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/op.h"
@@ -238,13 +239,32 @@ class RetvalOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
 REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_GPU), RetvalOp);
 
+class PassOn : public OpKernel {
+ public:
+  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
+                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
+                                 " vs. ", ctx->num_outputs()));
+    for (int i = 0; i < ctx->num_inputs(); ++i) {
+      ctx->set_output(i, ctx->input(i));
+    }
+  }
+};
+REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_GPU), PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_GPU), PassOn);
+
 static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
 
 class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
  public:
   FunctionLibraryRuntimeImpl(Device* device, Runner runner,
                              int graph_def_version,
-                             const FunctionLibraryDefinition* lib_def);
+                             const FunctionLibraryDefinition* lib_def,
+                             const OptimizerOptions& optimizer_options);
 
   ~FunctionLibraryRuntimeImpl() override;
@@ -268,6 +288,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
   Runner runner_ = nullptr;
   const int graph_def_version_;
   const FunctionLibraryDefinition* const lib_def_;
+  GraphOptimizer optimizer_;
   std::function<Status(const string&, const OpDef**)> get_func_sig_;
   std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
@@ -303,11 +324,13 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
 FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
     Device* device, Runner runner, int graph_def_version,
-    const FunctionLibraryDefinition* lib_def)
+    const FunctionLibraryDefinition* lib_def,
+    const OptimizerOptions& optimizer_options)
     : device_(device),
       runner_(runner),
       graph_def_version_(graph_def_version),
-      lib_def_(lib_def) {
+      lib_def_(lib_def),
+      optimizer_(optimizer_options) {
   get_func_sig_ = [this](const string& op, const OpDef** sig) {
     Status s;
     *sig = lib_def_->LookUp(op, &s);
@@ -366,19 +389,21 @@ class CallOp : public AsyncOpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
 };
 
-class SymbolicGradientOp : public OpKernel {
+class SymbolicGradientOp : public AsyncOpKernel {
  public:
   SymbolicGradientOp(OpKernelConstruction* ctx)
-      : OpKernel(ctx), handle_(kInvalidHandle) {}
+      : AsyncOpKernel(ctx), handle_(kInvalidHandle) {}
 
   ~SymbolicGradientOp() override {}
 
-  void Compute(OpKernelContext* ctx) override {
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     FunctionLibraryRuntime* lib = ctx->function_library();
-    OP_REQUIRES(ctx, lib != nullptr,
-                errors::Internal("No function library is provided."));
+    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+                      errors::Internal("No function library is provided."),
+                      done);
 
-    OP_REQUIRES_OK(ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_));
+    OP_REQUIRES_OK_ASYNC(
+        ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done);
 
     FunctionLibraryRuntime::Options opts;
     std::vector<Tensor> args;
@@ -387,20 +412,19 @@ class SymbolicGradientOp : public OpKernel {
       args.push_back(ctx->input(i));
     }
     std::vector<Tensor>* rets = new std::vector<Tensor>;
-    Notification n;
-    lib->Run(opts, handle_, args, rets, [ctx, rets, &n](const Status& status) {
-      if (!status.ok()) {
-        ctx->SetStatus(status);
-      } else {
-        CHECK_EQ(rets->size(), ctx->num_outputs());
-        for (size_t i = 0; i < rets->size(); ++i) {
-          ctx->set_output(i, (*rets)[i]);
-        }
-      }
-      delete rets;
-      n.Notify();
-    });
-    n.WaitForNotification();
+    lib->Run(opts, handle_, args, rets,
+             [ctx, done, rets](const Status& status) {
+               if (!status.ok()) {
+                 ctx->SetStatus(status);
+               } else {
+                 CHECK_EQ(rets->size(), ctx->num_outputs());
+                 for (size_t i = 0; i < rets->size(); ++i) {
+                   ctx->set_output(i, (*rets)[i]);
+                 }
+               }
+               delete rets;
+               done();
+             });
   }
 
  private:
@@ -541,7 +565,8 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
 void DumpGraph(StringPiece label, const Graph* g) {
   // TODO(zhifengc): Change Graph to record #nodes.
-  VLOG(1) << "Graph " << label << " #edges " << g->edges().size();
+  VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
+          << g->edges().size();
   if (VLOG_IS_ON(2)) {
     for (const auto& line : str_util::Split(DebugString(g), '\n')) {
       VLOG(2) << "|| " << line;
@@ -549,54 +574,13 @@ void DumpGraph(StringPiece label, const Graph* g) {
   }
 }
 
-static void SimplifyGraph(Graph* g) {
-  if (RemoveListArrayConverter(g)) {
-    DumpGraph("RemoveListArrayConverter", g);
-  }
-  bool changed;
-  do {
-    changed = false;
-    if (RemoveDeadNodes(g)) {
-      changed = true;
-      DumpGraph("RemoveDeadNodes", g);
-    }
-    if (RemoveIdentityNodes(g)) {
-      changed = true;
-      DumpGraph("RemoveIdentityNodes", g);
-    }
-    FixupSourceAndSinkEdges(g);
-    OptimizeCSE(g, nullptr);
-    DumpGraph("OptimizeCSE", g);
-  } while (changed);
-}
-
 void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
-  for (const Node* n : (*g)->nodes()) {
-    if (n->IsControlFlow()) {
-      VLOG(2) << "Skip OptimizeGraph due to control flow ops.";
-      return;
-    }
-  }
-
-  DumpGraph("Initial", *g);
-
-  // Run SimplifyGraph at least once to rewrite away ops such as
-  // _ListToArray, _ArrayToList, etc.
-  SimplifyGraph(*g);
-
-  const int kNumInlineRounds = 10;
-  for (int i = 0; i < kNumInlineRounds; ++i) {
-    if (!ExpandInlineFunctions(lib, *g)) break;
-    DumpGraph("ExpandInlineFunctions", *g);
-    SimplifyGraph(*g);
-  }
-
-  // Makes a copy so that we densify node ids.
-  Graph* copy = new Graph((*g)->op_registry());
-  CopyGraph(**g, copy);
-  delete *g;
-  *g = copy;
-  DumpGraph("ReCopy", *g);
+  OptimizerOptions opts;
+  opts.set_do_common_subexpression_elimination(true);
+  opts.set_do_function_inlining(true);
+  opts.set_do_constant_folding(true);
+  GraphOptimizer optimizer(opts);
+  optimizer.Optimize(lib, g);
 }
 
 Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
@@ -604,7 +588,8 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
   CHECK_NOTNULL(fbody);
   Graph* g = new Graph(lib_def_);
   CopyGraph(*fbody->graph, g);
-  OptimizeGraph(this, &g);
+
+  optimizer_.Optimize(this, &g);
 
   // Creates an executor based on the g. This must be done without
   // holding mu_ because create_kernel_ calls back into the library.
@@ -697,12 +682,14 @@ bool FunctionLibraryRuntimeImpl::IsDefined(const string& function_name) {
 
 FunctionLibraryRuntime* NewFunctionLibraryRuntime(
     Device* device, Runner runner, int graph_def_version,
-    const FunctionLibraryDefinition* lib_def) {
+    const FunctionLibraryDefinition* lib_def,
+    const OptimizerOptions& optimizer_options) {
   return new FunctionLibraryRuntimeImpl(device, runner, graph_def_version,
-                                        lib_def);
+                                        lib_def, optimizer_options);
 }
 
 bool RemoveDeadNodes(Graph* g) {
+  VLOG(2) << "Removing dead nodes";
   std::vector<bool> visited(g->num_node_ids(), false);
   std::deque<Node*> q;
   for (auto n : g->nodes()) {
@@ -750,6 +737,7 @@ const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
 }  // end namespace
 
 bool RemoveIdentityNodes(Graph* g) {
+  VLOG(2) << "Removing identity nodes";
   bool removed_any = false;
   gtl::InlinedVector<Node*, 8> matches;
   for (Node* n : g->nodes()) {
@@ -775,6 +763,7 @@ bool RemoveIdentityNodes(Graph* g) {
 }
 
 bool RemoveListArrayConverter(Graph* g) {
+  VLOG(2) << "Removing list array converter";
   gtl::InlinedVector<Node*, 8> matches;
   for (Node* n : g->nodes()) {
     if ((n->type_string() == "_ListToArray") ||
@@ -904,10 +893,14 @@ static void InlineFunctionBody(Graph* g, Node* caller,
   // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
   // remember 'y' in node_map[x->id()].
   std::vector<Node*> node_map(fbody->graph->num_node_ids());
+  Status s;
   for (Node* n : fbody->graph->nodes()) {
     if (n->IsSource() || n->IsSink()) continue;
     CHECK(n->IsOp());
-    node_map[n->id()] = g->CopyNode(n);
+    NodeDef ndef = n->def();
+    ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
+    node_map[n->id()] = g->AddNode(ndef, &s);
+    TF_CHECK_OK(s);
   }
   for (const Edge* e : fbody->graph->edges()) {
     if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index d8c93fa313..ae19da1b34 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <functional>
 
 #include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/config.pb.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/graph/graph.h"
@@ -34,7 +35,8 @@ typedef std::function<void()> Closure;
 typedef std::function<void(Closure)> Runner;
 
 FunctionLibraryRuntime* NewFunctionLibraryRuntime(
     Device* device, Runner runner, int graph_def_version,
-    const FunctionLibraryDefinition* lib_def);
+    const FunctionLibraryDefinition* lib_def,
+    const OptimizerOptions& optimizer_options);
 
 // FunctionLibraryRuntime::GetFunctionBody returns a description of an
 // instantiated function that is represented as a Graph with arg/ret
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 1d27864a3b..e8e77256ac 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -148,8 +148,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
     delete lib_def_;
     lib_def_ = new FunctionLibraryDefinition(proto);
     delete lib_;
+    OptimizerOptions opts;
     lib_ = NewFunctionLibraryRuntime(device_, FunctionTestSchedClosure,
-                                     TF_GRAPH_DEF_VERSION, lib_def_);
+                                     TF_GRAPH_DEF_VERSION, lib_def_, opts);
   }
 
   Status Run(const string& name, InstantiateAttrValueSlice attrs,
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
new file mode 100644
index 0000000000..d5a8fd604f
--- /dev/null
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -0,0 +1,91 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
+
+#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+
+namespace tensorflow {
+
+GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
+  if (opts_.opt_level() >= OptimizerOptions::L1) {
+    opts_.set_do_common_subexpression_elimination(true);
+  }
+  if (opts_.opt_level() >= OptimizerOptions::L2) {
+    opts_.set_do_constant_folding(true);
+  }
+}
+
+GraphOptimizer::~GraphOptimizer() {}
+
+void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Graph** graph) {
+  Graph* g = *graph;
+  for (const Node* n : g->nodes()) {
+    if (n->IsControlFlow()) {
+      VLOG(2) << "Skip optimization if there is any control flow ops";
+    }
+  }
+
+  DumpGraph("Initial", g);
+  bool changed = true;
+  const int kMaxRounds = 10;
+  for (int rounds = 0; rounds < kMaxRounds; ++rounds) {
+    changed = false;
+    if (opts_.do_function_inlining() && RemoveListArrayConverter(g)) {
+      DumpGraph("RemoveListArrayConverter", g);
+      changed = true;
+    }
+    if (opts_.do_function_inlining() && RemoveDeadNodes(g)) {
+      DumpGraph("RemoveDeadNodes", g);
+      changed = true;
+    }
+    if (opts_.do_function_inlining() && RemoveIdentityNodes(g)) {
+      DumpGraph("RemoveIdentityNodes", g);
+      changed = true;
+    }
+    if (opts_.do_constant_folding()) {
+      ConstantFoldingOptions cf_opts;
+      if (DoConstantFolding(cf_opts, g)) {
+        DumpGraph("ConstFolding", g);
+        changed = true;
+      }
+    }
+    if (opts_.do_function_inlining() && FixupSourceAndSinkEdges(g)) {
+      DumpGraph("FixupSourceAndSinkEdges", g);
+      changed = true;
+    }
+    if (opts_.do_common_subexpression_elimination() &&
+        OptimizeCSE(g, nullptr)) {
+      DumpGraph("OptimizeCSE", g);
+      changed = true;
+    }
+    if (opts_.do_function_inlining() && ExpandInlineFunctions(runtime, g)) {
+      DumpGraph("ExpandInlineFunctions", g);
+      changed = true;
+    }
+    if (!changed) break;
+  }
+
+  Graph* copy = new Graph(g->op_registry());
+  CopyGraph(*g, copy);
+  delete g;
+  *graph = copy;
+  DumpGraph("ReCopy", *graph);
+}
+
+}  // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
new file mode 100644
index 0000000000..bbe643a43e
--- /dev/null
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -0,0 +1,43 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
+
+#include "tensorflow/core/framework/config.pb.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class GraphOptimizer {
+ public:
+  GraphOptimizer(const OptimizerOptions& opts);
+  ~GraphOptimizer();
+
+  // Applies optimization passes specified in 'opts' to 'graph'.
+  // Maybe replace *graph with a new graph object.
+  void Optimize(FunctionLibraryRuntime* runtime, Graph** graph);
+
+ private:
+  OptimizerOptions opts_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizer);
+};
+
+}  // end namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto
index 5218126676..6e0188226b 100644
--- a/tensorflow/core/framework/config.proto
+++ b/tensorflow/core/framework/config.proto
@@ -26,6 +26,36 @@ message GPUOptions {
   int64 deferred_deletion_bytes = 3;
 };
 
+// Options passed to the graph optimizer
+message OptimizerOptions {
+  // If true, optimize the graph using common subexpression elimination.
+  bool do_common_subexpression_elimination = 1;
+
+  // If true, perform constant folding optimization on the graph.
+  bool do_constant_folding = 2;
+
+  // If true, perform function inlining on the graph.
+  bool do_function_inlining = 4;
+
+  // Optimization level
+  enum Level {
+    // L1 is the default level.
+    // Optimization performed at L1 :
+    // 1. Common subexpression elimination
+    L1 = 0;
+
+    // Optimization performed at L2 :
+    // 1. Common subexpression elimination
+    // 2. Constant folding
+    L2 = 2;
+
+    // No optimizations
+    L0 = -1;
+  }
+
+  Level opt_level = 3;
+}
+
 message GraphOptions {
   // Removed, use optimizer_options below.
   reserved "skip_common_subexpression_elimination";
@@ -35,31 +65,7 @@ message GraphOptions {
   // (Currently ignored.)
   bool enable_recv_scheduling = 2;
 
-  // Options passed to the graph optimizer
-  message OptimizerOptions {
-    // If true, optimize the graph using common subexpression elimination.
-    bool do_common_subexpression_elimination = 1;
-
-    // If true, perform constant folding optimization on the graph.
-    bool do_constant_folding = 2;
-
-    // Optimization level
-    enum Level {
-      // L1 is the default level.
-      // Optimization performed at L1 :
-      // 1. Common subexpression elimination
-      L1 = 0;
-      // Optimization performed at L2 :
-      // 1. Common subexpression elimination
-      // 2. Constant folding
-      L2 = 2;
-      // No optimizations
-      L0 = -1;
-    }
-
-    Level opt_level = 3;
-  }
-
+  // Options controlling how graph is optimized.
   OptimizerOptions optimizer_options = 3;
 };
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
index 0df3f1c3db..f47e7df961 100644
--- a/tensorflow/core/graph/algorithm.cc
+++ b/tensorflow/core/graph/algorithm.cc
@@ -150,17 +150,21 @@ void PruneForReverseReachability(Graph* g,
   FixupSourceAndSinkEdges(g);
 }
 
-void FixupSourceAndSinkEdges(Graph* g) {
+bool FixupSourceAndSinkEdges(Graph* g) {
   // Connect all nodes with no incoming edges to source.
   // Connect all nodes with no outgoing edges to sink.
+  bool changed = false;
   for (Node* n : g->nodes()) {
     if (!n->IsSource() && n->in_edges().empty()) {
       g->AddControlEdge(g->source_node(), n);
+      changed = true;
     }
     if (!n->IsSink() && n->out_edges().empty()) {
       g->AddControlEdge(n, g->sink_node());
+      changed = true;
     }
   }
+  return changed;
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
index 5851e7b7ab..f57e83e686 100644
--- a/tensorflow/core/graph/algorithm.h
+++ b/tensorflow/core/graph/algorithm.h
@@ -55,7 +55,9 @@ void PruneForReverseReachability(Graph* g,
 // Connect all nodes with no incoming edges to source.
 // Connect all nodes with no outgoing edges to sink.
-void FixupSourceAndSinkEdges(Graph* g);
+//
+// Returns true if and only if 'g' is mutated.
+bool FixupSourceAndSinkEdges(Graph* g);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index cc702eeda5..e1d80782b6 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -382,15 +382,15 @@ bool GraphConstructor::TypeValidateEdge(const Edge* edge) {
   return true;
 }
 
-static void SetDoCSE(const GraphOptions::OptimizerOptions& optimizer_opt,
-                     bool force, GraphConstructorOptions* graph_opt) {
+static void SetDoCSE(const OptimizerOptions& optimizer_opt, bool force,
+                     GraphConstructorOptions* graph_opt) {
   graph_opt->optimizer_do_cse =
       force || optimizer_opt.do_common_subexpression_elimination();
 }
 
-static void SetDoConstantFolding(
-    const GraphOptions::OptimizerOptions& optimizer_opt, bool force,
-    GraphConstructorOptions* graph_opt) {
+static void SetDoConstantFolding(const OptimizerOptions& optimizer_opt,
+                                 bool force,
+                                 GraphConstructorOptions* graph_opt) {
   graph_opt->optimizer_do_constant_folding =
       force || optimizer_opt.do_constant_folding();
 }
@@ -401,18 +401,19 @@ static void SetDoConstantFolding(
 // GraphConstructorOptions functions
 // ----------------------------------------------------------------------------
 
-GraphConstructorOptions::GraphConstructorOptions(
-    const GraphOptions::OptimizerOptions& opts) {
+GraphConstructorOptions::GraphConstructorOptions() {}
+
+GraphConstructorOptions::GraphConstructorOptions(const OptimizerOptions& opts) {
   // Set the individually specified options first.
   SetDoCSE(opts, false, this);
   SetDoConstantFolding(opts, false, this);
 
   // Set options that the level signifies
-  if (opts.opt_level() == GraphOptions::OptimizerOptions::L0) {
+  if (opts.opt_level() == OptimizerOptions::L0) {
     // No optimizations performed.
-  } else if (opts.opt_level() == GraphOptions::OptimizerOptions::L1) {
+  } else if (opts.opt_level() == OptimizerOptions::L1) {
     SetDoCSE(opts, true, this);
-  } else if (opts.opt_level() == GraphOptions::OptimizerOptions::L2) {
+  } else if (opts.opt_level() == OptimizerOptions::L2) {
     SetDoCSE(opts, true, this);
     SetDoConstantFolding(opts, true, this);
   }
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index e779e5631a..bf3f818e1e 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -33,8 +33,8 @@ struct ConstantFoldingOptions {
 // Construct a graph *g out of a GraphDef gdef. Returns non-OK on
 // error, in which case *g is left in an incomplete state.
 struct GraphConstructorOptions {
-  explicit GraphConstructorOptions(
-      const GraphOptions::OptimizerOptions& opts = Level0());
+  GraphConstructorOptions();
+  explicit GraphConstructorOptions(const OptimizerOptions& opts);
 
   // If true, allows internal ops in the GraphDef.
   bool allow_internal_ops = false;
@@ -59,12 +59,6 @@ struct GraphConstructorOptions {
   bool optimizer_do_constant_folding = false;
 
   ConstantFoldingOptions constant_folding_opts;
-
-  static GraphOptions::OptimizerOptions Level0() {
-    GraphOptions::OptimizerOptions ret;
-    ret.set_opt_level(GraphOptions::OptimizerOptions::L0);
-    return ret;
-  }
 };
 extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                                      const GraphDef& gdef, Graph* g);
diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc
index 26ed552677..564b47542f 100644
--- a/tensorflow/core/graph/optimizer_cse.cc
+++ b/tensorflow/core/graph/optimizer_cse.cc
@@ -52,7 +52,7 @@ class OptimizerCSE {
  public:
   explicit OptimizerCSE(Graph* g) : g_(g) {}
 
-  void Optimize(std::function<bool(const Node*)> consider_fn);
+  bool Optimize(std::function<bool(const Node*)> consider_fn);
 
  private:
   struct Scratch;
@@ -180,7 +180,7 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
   return true;
 }
 
-void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
+bool OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
   // This very simple implementation works if the whole graph is one
   // giant basic block (because we just traverse nodes in a
   // topological order). We'll need to do something more
@@ -202,6 +202,7 @@ void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
   // Scratch space for Equivalent calls. Allocated here and passed in to
   // Equivalent to avoid allocation inside the loop below.
+  bool changed = false;
   Scratch scratch;
   for (Node* n : order) {
     if (!n->IsOp()) continue;
@@ -224,13 +225,15 @@ void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
         g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
       }
       g_->RemoveNode(n);
+      changed = true;
     }
   }
+  return changed;
 }
 
-void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
+bool OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
   OptimizerCSE opt(g);
-  opt.Optimize(consider_fn);
+  return opt.Optimize(consider_fn);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h
index d62ea48b5c..310c906982 100644
--- a/tensorflow/core/graph/optimizer_cse.h
+++ b/tensorflow/core/graph/optimizer_cse.h
@@ -27,7 +27,9 @@ namespace tensorflow {
 // "consider_fn" is not nullptr, then only nodes for which
 // consider_fn(node) returns true will be considered for combining
 // during the common subexpression elimination.
-extern void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn);
+//
+// Returns true if and only if 'g' is mutated.
+extern bool OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 66118489dd..45c1152a33 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -237,4 +237,27 @@ Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) {
 }
 REGISTER_OP_GRADIENT("_ListToArray", ListToArrayGrad);
 
+Status FillGrad(const AttrSlice& attrs, FunctionDef* g) {
+  *g = FDH::Define(
+      // Arg defs
+      {"dims: int32", "x: T", "dy: T"},
+      // Ret val defs
+      {"d_dims: int32", "dx: T"},
+      // Attr defs
+      {"T: {float, double}"},
+      // Nodes
+      {
+          {{"d_dims"}, "ZerosLike", {"dims"}, {{"T", DT_INT32}}},
+          FDH::Const("zero", 0),
+          {{"rank"}, "Rank", {"dy"}, {{"T", "$T"}}},
+          FDH::Const("one", 1),
+          {{"r"}, "Range", {"zero", "rank", "one"}, {}},
+          // dx = sum(dy)
+          {{"dx"}, "Sum", {"dy", "r"}, {{"T", "$T"}}},
+      });
+  VLOG(1) << "FillGrad " << DebugString(*g);
+  return Status::OK();
+}
+REGISTER_OP_GRADIENT("Fill", FillGrad);
+
 }  // end namespace tensorflow
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 865d30a516..5b278d1d5f 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -29,7 +29,6 @@ import tensorflow as tf
 from tensorflow.python.framework import function
 # pylint: disable=unused-import
 from tensorflow.python.ops import functional_ops
-
 # pylint: enable=unused-import
@@ -384,9 +383,22 @@ class UnrollLSTMTest(tf.test.TestCase):
     return LSTMLoop10(weights, inp)
 
+  def _OptimizerOptions(self):
+    ret = []
+    for cse in [False, True]:
+      for inline in [False, True]:
+        for cfold in [False, True]:
+          ret.append(tf.ConfigProto(graph_options=tf.GraphOptions(
+              optimizer_options=tf.OptimizerOptions(
+                  opt_level=tf.OptimizerOptions.L0,
+                  do_common_subexpression_elimination=cse,
+                  do_function_inlining=inline,
+                  do_constant_folding=cfold))))
+    return ret
+
   def testUnrollLSTM(self):
     # Run one step of the unrolled lstm graph.
-    def RunForward(mode):
+    def RunForward(mode, cfg=None):
       g = tf.Graph()
       start = time.time()
       with g.as_default():
@@ -397,20 +409,22 @@ class UnrollLSTMTest(tf.test.TestCase):
       finish = time.time()
       print("time: ", finish - start, " txt size: ", len(str(gdef)),
             "gdef bin size: ", len(gdef.SerializeToString()))
-      with g.as_default(), tf.Session() as sess:
+      with g.as_default(), tf.Session(config=cfg) as sess:
         return sess.run(m)
 
     mv0 = RunForward("complete")
-    mv1 = RunForward("cell")
-    mv2 = RunForward("loop")
-    mv3 = RunForward("loop10")
-    self.assertAllClose(mv0, mv1, rtol=1e-4)
-    self.assertAllClose(mv0, mv2, rtol=1e-4)
-    self.assertAllClose(mv0, mv3, rtol=1e-4)
+    for cfg in self._OptimizerOptions():
+      print("cfg = ", cfg)
+      mv1 = RunForward("cell", cfg)
+      mv2 = RunForward("loop", cfg)
+      mv3 = RunForward("loop10", cfg)
+      self.assertAllClose(mv0, mv1, rtol=1e-4)
+      self.assertAllClose(mv0, mv2, rtol=1e-4)
+      self.assertAllClose(mv0, mv3, rtol=1e-4)
 
   def testUnrollLSTMGrad(self):
     # Run one step of the unrolled lstm graph.
-    def RunForwardBackward(mode):
+    def RunForwardBackward(mode, cfg=None):
       g = tf.Graph()
       start = time.time()
       with g.as_default():
@@ -423,16 +437,18 @@ class UnrollLSTMTest(tf.test.TestCase):
       finish = time.time()
       print("time: ", finish - start, " txt size: ", len(str(gdef)),
             "gdef bin size: ", len(gdef.SerializeToString()))
-      with g.as_default(), tf.Session() as sess:
+      with g.as_default(), tf.Session(config=cfg) as sess:
         return sess.run(dw)
 
     d0 = RunForwardBackward("complete")
-    d1 = RunForwardBackward("cell")
-    d2 = RunForwardBackward("loop")
-    d3 = RunForwardBackward("loop10")
-    self.assertAllClose(d0, d1, rtol=1e-4)
-    self.assertAllClose(d0, d2, rtol=1e-4)
-    self.assertAllClose(d0, d3, rtol=1e-4)
+    for cfg in self._OptimizerOptions():
+      print("cfg = ", cfg)
+      d1 = RunForwardBackward("cell", cfg)
+      d2 = RunForwardBackward("loop", cfg)
+      d3 = RunForwardBackward("loop10", cfg)
+      self.assertAllClose(d0, d1, rtol=1e-4)
+      self.assertAllClose(d0, d2, rtol=1e-4)
+      self.assertAllClose(d0, d3, rtol=1e-4)
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 939beb053d..6694ae59ba 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -121,7 +121,7 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
                    "HistogramProto", "ConfigProto", "NodeDef", "GraphDef",
                    "GPUOptions", "GraphOptions", "SessionInterface",
                    "BaseSession", "NameAttrList", "AttrValue",
-                   "TensorArray"]
+                   "TensorArray", "OptimizerOptions"]
 
 def main(unused_argv):
   if not FLAGS.out_dir:
```
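For completeness, a sketch of enabling the relocated options from client code. The proto accessors follow from the `config.proto` change above; the `SessionOptions` struct and its header path are assumptions from the surrounding codebase, not part of this diff. The Python equivalent is the `_OptimizerOptions` helper added to `function_test.py` above.

```cpp
// Hypothetical client-side configuration of the new optimizer options.
// Note the Level semantics from GraphOptimizer's constructor: L1 (the
// default, value 0) already implies CSE, and L2 additionally implies
// constant folding; function inlining is only enabled by its explicit flag.
#include "tensorflow/core/framework/config.pb.h"
#include "tensorflow/core/public/session_options.h"  // assumed location of SessionOptions

tensorflow::SessionOptions MakeOptimizedSessionOptions() {
  tensorflow::SessionOptions options;
  tensorflow::OptimizerOptions* opt =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  opt->set_opt_level(tensorflow::OptimizerOptions::L2);  // CSE + const folding
  opt->set_do_function_inlining(true);                   // explicit per-pass flag
  return options;
}
```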