author     A. Unique TensorFlower <nobody@tensorflow.org>  2016-02-02 16:51:43 -0800
committer  Vijay Vasudevan <vrv@google.com>  2016-02-03 09:34:58 -0800
commit     e830638148e203a2d9cf491e4693d35661a360d1 (patch)
tree       11368b2e6b8a9d125f37050484da640983dbecfc
parent     08b09699ad5f8b57add1f83461568b647c42129a (diff)
Refactor the logic to apply optimization into a common module.
Change: 113692577
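
In short: the optimization passes that DirectSession::CreateGraphs and the function library runtime previously wired up by hand (common subexpression elimination, constant folding, dead/identity-node removal, function inlining) now live in a single GraphOptimizer class configured by the OptimizerOptions proto message. A minimal sketch of the resulting call pattern, assuming a FunctionLibraryRuntime* lib and a Graph* g built elsewhere; RunPasses is a hypothetical helper, not part of this change, and it mirrors the rewritten OptimizeGraph() in function.cc below:

    #include "tensorflow/core/common_runtime/graph_optimizer.h"

    // Sketch only. GraphOptimizer may replace *g with a freshly copied Graph
    // once the enabled passes stop making progress (bounded number of rounds).
    void RunPasses(tensorflow::FunctionLibraryRuntime* lib, tensorflow::Graph** g) {
      tensorflow::OptimizerOptions opts;
      opts.set_do_common_subexpression_elimination(true);
      opts.set_do_function_inlining(true);
      opts.set_do_constant_folding(true);
      tensorflow::GraphOptimizer optimizer(opts);
      optimizer.Optimize(lib, g);
    }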
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc  45
-rw-r--r--  tensorflow/core/common_runtime/function.cc  143
-rw-r--r--  tensorflow/core/common_runtime/function.h  4
-rw-r--r--  tensorflow/core/common_runtime/function_test.cc  3
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.cc  91
-rw-r--r--  tensorflow/core/common_runtime/graph_optimizer.h  43
-rw-r--r--  tensorflow/core/framework/config.proto  56
-rw-r--r--  tensorflow/core/graph/algorithm.cc  6
-rw-r--r--  tensorflow/core/graph/algorithm.h  4
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc  21
-rw-r--r--  tensorflow/core/graph/graph_constructor.h  10
-rw-r--r--  tensorflow/core/graph/optimizer_cse.cc  11
-rw-r--r--  tensorflow/core/graph/optimizer_cse.h  4
-rw-r--r--  tensorflow/core/ops/array_grad.cc  23
-rw-r--r--  tensorflow/python/framework/function_test.py  50
-rw-r--r--  tensorflow/python/framework/gen_docs_combined.py  2
16 files changed, 334 insertions(+), 182 deletions(-)
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 17eb94f495..3e775c5c48 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_placer.h"
#include "tensorflow/core/framework/function.h"
@@ -612,6 +613,9 @@ Status DirectSession::GetOrCreateExecutors(
ek->func_defs = fdefs;
ek->items.reserve(graphs.size());
auto runner = [this](Executor::Args::Closure c) { SchedClosure(c); };
+ const auto& optimizer_opts =
+ options_.config.graph_options().optimizer_options();
+ GraphOptimizer optimizer(optimizer_opts);
for (const auto& graph : graphs) {
const string& partition_name = graph.first;
Graph* partition_graph = graph.second;
@@ -623,8 +627,8 @@ Status DirectSession::GetOrCreateExecutors(
ek->items.resize(ek->items.size() + 1);
auto* item = &(ek->items.back());
- item->flib =
- NewFunctionLibraryRuntime(device, runner, graph_def_version, fdefs);
+ item->flib = NewFunctionLibraryRuntime(device, runner, graph_def_version,
+ fdefs, optimizer_opts);
LocalExecutorParams params;
params.device = device;
@@ -646,6 +650,7 @@ Status DirectSession::GetOrCreateExecutors(
// Do nothing because 'kernel' is owned by opseg above.
};
+ optimizer.Optimize(lib, &partition_graph);
s = NewLocalExecutor(params, partition_graph, &item->executor);
if (!s.ok()) {
return s;
@@ -715,38 +720,11 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
RunStateArgs* run_state_args) {
std::unique_ptr<FunctionLibraryDefinition> fdefs;
std::unique_ptr<Graph> graph;
- GraphConstructorOptions opts{
- options_.config.graph_options().optimizer_options()};
-
- std::unordered_set<StringPiece, StringPiece::Hasher> keep_nodes;
- for (const string& feed : feeds) {
- keep_nodes.insert(ParseTensorName(feed).first);
- }
- for (const string& fetch : fetches) {
- keep_nodes.insert(ParseTensorName(fetch).first);
- }
- for (const string& target_node : target_nodes) {
- keep_nodes.insert(target_node);
- }
-
- if (opts.optimizer_do_cse) {
- // Prevent CSE from eliminating nodes that will be required during
- // RewriteGraphForExecution, below.
- opts.cse_consider_function = [&keep_nodes](const Node* n) {
- return n->IsConstant() && !keep_nodes.count(n->name());
- };
- }
-
- if (opts.optimizer_do_constant_folding) {
- opts.constant_folding_opts.consider = [&keep_nodes](const Node* n) {
- return keep_nodes.count(n->name()) > 0;
- };
- }
-
{
mutex_lock l(graph_def_lock_);
fdefs.reset(new FunctionLibraryDefinition(graph_def_.library()));
graph.reset(new Graph(fdefs.get()));
+ GraphConstructorOptions opts;
TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def_, graph.get()));
}
@@ -760,13 +738,6 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
graph.get(), feeds, fetches, target_nodes,
device_set_.client_device()->attributes()));
- if (opts.optimizer_do_constant_folding) {
- bool constant_folded =
- DoConstantFolding(opts.constant_folding_opts, graph.get());
- VLOG(2) << (constant_folded ? "Folded some constants"
- : "Found no constant folding opportunity");
- }
-
GraphDef graph_def;
graph->ToGraphDef(&graph_def);
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 85052ae132..d83b293c2a 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
@@ -238,13 +239,32 @@ class RetvalOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_GPU), RetvalOp);
+class PassOn : public OpKernel {
+ public:
+ explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
+ errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
+ " vs. ", ctx->num_outputs()));
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ ctx->set_output(i, ctx->input(i));
+ }
+ }
+};
+REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_GPU), PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);
+REGISTER_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_GPU), PassOn);
+
static const FunctionLibraryRuntime::Handle kInvalidHandle = -1;
class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
public:
FunctionLibraryRuntimeImpl(Device* device, Runner runner,
int graph_def_version,
- const FunctionLibraryDefinition* lib_def);
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options);
~FunctionLibraryRuntimeImpl() override;
@@ -268,6 +288,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
Runner runner_ = nullptr;
const int graph_def_version_;
const FunctionLibraryDefinition* const lib_def_;
+ GraphOptimizer optimizer_;
std::function<Status(const string&, const OpDef**)> get_func_sig_;
std::function<Status(const NodeDef&, OpKernel**)> create_kernel_;
@@ -303,11 +324,13 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
Device* device, Runner runner, int graph_def_version,
- const FunctionLibraryDefinition* lib_def)
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options)
: device_(device),
runner_(runner),
graph_def_version_(graph_def_version),
- lib_def_(lib_def) {
+ lib_def_(lib_def),
+ optimizer_(optimizer_options) {
get_func_sig_ = [this](const string& op, const OpDef** sig) {
Status s;
*sig = lib_def_->LookUp(op, &s);
@@ -366,19 +389,21 @@ class CallOp : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(CallOp);
};
-class SymbolicGradientOp : public OpKernel {
+class SymbolicGradientOp : public AsyncOpKernel {
public:
SymbolicGradientOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), handle_(kInvalidHandle) {}
+ : AsyncOpKernel(ctx), handle_(kInvalidHandle) {}
~SymbolicGradientOp() override {}
- void Compute(OpKernelContext* ctx) override {
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
FunctionLibraryRuntime* lib = ctx->function_library();
- OP_REQUIRES(ctx, lib != nullptr,
- errors::Internal("No function library is provided."));
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."),
+ done);
- OP_REQUIRES_OK(ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_));
+ OP_REQUIRES_OK_ASYNC(
+ ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done);
FunctionLibraryRuntime::Options opts;
std::vector<Tensor> args;
@@ -387,20 +412,19 @@ class SymbolicGradientOp : public OpKernel {
args.push_back(ctx->input(i));
}
std::vector<Tensor>* rets = new std::vector<Tensor>;
- Notification n;
- lib->Run(opts, handle_, args, rets, [ctx, rets, &n](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- } else {
- CHECK_EQ(rets->size(), ctx->num_outputs());
- for (size_t i = 0; i < rets->size(); ++i) {
- ctx->set_output(i, (*rets)[i]);
- }
- }
- delete rets;
- n.Notify();
- });
- n.WaitForNotification();
+ lib->Run(opts, handle_, args, rets,
+ [ctx, done, rets](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else {
+ CHECK_EQ(rets->size(), ctx->num_outputs());
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ done();
+ });
}
private:
@@ -541,7 +565,8 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
void DumpGraph(StringPiece label, const Graph* g) {
// TODO(zhifengc): Change Graph to record #nodes.
- VLOG(1) << "Graph " << label << " #edges " << g->edges().size();
+ VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
+ << g->edges().size();
if (VLOG_IS_ON(2)) {
for (const auto& line : str_util::Split(DebugString(g), '\n')) {
VLOG(2) << "|| " << line;
@@ -549,54 +574,13 @@ void DumpGraph(StringPiece label, const Graph* g) {
}
}
-static void SimplifyGraph(Graph* g) {
- if (RemoveListArrayConverter(g)) {
- DumpGraph("RemoveListArrayConverter", g);
- }
- bool changed;
- do {
- changed = false;
- if (RemoveDeadNodes(g)) {
- changed = true;
- DumpGraph("RemoveDeadNodes", g);
- }
- if (RemoveIdentityNodes(g)) {
- changed = true;
- DumpGraph("RemoveIdentityNodes", g);
- }
- FixupSourceAndSinkEdges(g);
- OptimizeCSE(g, nullptr);
- DumpGraph("OptimizeCSE", g);
- } while (changed);
-}
-
void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
- for (const Node* n : (*g)->nodes()) {
- if (n->IsControlFlow()) {
- VLOG(2) << "Skip OptimizeGraph due to control flow ops.";
- return;
- }
- }
-
- DumpGraph("Initial", *g);
-
- // Run SimplifyGraph at least once to rewrite away ops such as
- // _ListToArray, _ArrayToList, etc.
- SimplifyGraph(*g);
-
- const int kNumInlineRounds = 10;
- for (int i = 0; i < kNumInlineRounds; ++i) {
- if (!ExpandInlineFunctions(lib, *g)) break;
- DumpGraph("ExpandInlineFunctions", *g);
- SimplifyGraph(*g);
- }
-
- // Makes a copy so that we densify node ids.
- Graph* copy = new Graph((*g)->op_registry());
- CopyGraph(**g, copy);
- delete *g;
- *g = copy;
- DumpGraph("ReCopy", *g);
+ OptimizerOptions opts;
+ opts.set_do_common_subexpression_elimination(true);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ optimizer.Optimize(lib, g);
}
Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
@@ -604,7 +588,8 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
CHECK_NOTNULL(fbody);
Graph* g = new Graph(lib_def_);
CopyGraph(*fbody->graph, g);
- OptimizeGraph(this, &g);
+
+ optimizer_.Optimize(this, &g);
// Creates an executor based on the g. This must be done without
// holding mu_ because create_kernel_ calls back into the library.
@@ -697,12 +682,14 @@ bool FunctionLibraryRuntimeImpl::IsDefined(const string& function_name) {
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
Device* device, Runner runner, int graph_def_version,
- const FunctionLibraryDefinition* lib_def) {
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options) {
return new FunctionLibraryRuntimeImpl(device, runner, graph_def_version,
- lib_def);
+ lib_def, optimizer_options);
}
bool RemoveDeadNodes(Graph* g) {
+ VLOG(2) << "Removing dead nodes";
std::vector<bool> visited(g->num_node_ids(), false);
std::deque<Node*> q;
for (auto n : g->nodes()) {
@@ -750,6 +737,7 @@ const Edge* GetTheOnlyDataEdge(const EdgeSet& edges) {
} // end namespace
bool RemoveIdentityNodes(Graph* g) {
+ VLOG(2) << "Removing identity nodes";
bool removed_any = false;
gtl::InlinedVector<Node*, 8> matches;
for (Node* n : g->nodes()) {
@@ -775,6 +763,7 @@ bool RemoveIdentityNodes(Graph* g) {
}
bool RemoveListArrayConverter(Graph* g) {
+ VLOG(2) << "Removing list array converter";
gtl::InlinedVector<Node*, 8> matches;
for (Node* n : g->nodes()) {
if ((n->type_string() == "_ListToArray") ||
@@ -904,10 +893,14 @@ static void InlineFunctionBody(Graph* g, Node* caller,
// If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
// remember 'y' in node_map[x->id()].
std::vector<Node*> node_map(fbody->graph->num_node_ids());
+ Status s;
for (Node* n : fbody->graph->nodes()) {
if (n->IsSource() || n->IsSink()) continue;
CHECK(n->IsOp());
- node_map[n->id()] = g->CopyNode(n);
+ NodeDef ndef = n->def();
+ ndef.set_name(strings::StrCat(caller->name(), "/", ndef.name()));
+ node_map[n->id()] = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
}
for (const Edge* e : fbody->graph->edges()) {
if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index d8c93fa313..ae19da1b34 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <functional>
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/config.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -34,7 +35,8 @@ typedef std::function<void()> Closure;
typedef std::function<void(Closure)> Runner;
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
Device* device, Runner runner, int graph_def_version,
- const FunctionLibraryDefinition* lib_def);
+ const FunctionLibraryDefinition* lib_def,
+ const OptimizerOptions& optimizer_options);
// FunctionLibraryRuntime::GetFunctionBody returns a description of an
// instantiated function that is represented as a Graph with arg/ret
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 1d27864a3b..e8e77256ac 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -148,8 +148,9 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
delete lib_def_;
lib_def_ = new FunctionLibraryDefinition(proto);
delete lib_;
+ OptimizerOptions opts;
lib_ = NewFunctionLibraryRuntime(device_, FunctionTestSchedClosure,
- TF_GRAPH_DEF_VERSION, lib_def_);
+ TF_GRAPH_DEF_VERSION, lib_def_, opts);
}
Status Run(const string& name, InstantiateAttrValueSlice attrs,
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
new file mode 100644
index 0000000000..d5a8fd604f
--- /dev/null
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -0,0 +1,91 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
+
+#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/optimizer_cse.h"
+
+namespace tensorflow {
+
+GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
+ if (opts_.opt_level() >= OptimizerOptions::L1) {
+ opts_.set_do_common_subexpression_elimination(true);
+ }
+ if (opts_.opt_level() >= OptimizerOptions::L2) {
+ opts_.set_do_constant_folding(true);
+ }
+}
+
+GraphOptimizer::~GraphOptimizer() {}
+
+void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Graph** graph) {
+ Graph* g = *graph;
+ for (const Node* n : g->nodes()) {
+ if (n->IsControlFlow()) {
+ VLOG(2) << "Skip optimization if there is any control flow ops";
+ }
+ }
+
+ DumpGraph("Initial", g);
+ bool changed = true;
+ const int kMaxRounds = 10;
+ for (int rounds = 0; rounds < kMaxRounds; ++rounds) {
+ changed = false;
+ if (opts_.do_function_inlining() && RemoveListArrayConverter(g)) {
+ DumpGraph("RemoveListArrayConverter", g);
+ changed = true;
+ }
+ if (opts_.do_function_inlining() && RemoveDeadNodes(g)) {
+ DumpGraph("RemoveDeadNodes", g);
+ changed = true;
+ }
+ if (opts_.do_function_inlining() && RemoveIdentityNodes(g)) {
+ DumpGraph("RemoveIdentityNodes", g);
+ changed = true;
+ }
+ if (opts_.do_constant_folding()) {
+ ConstantFoldingOptions cf_opts;
+ if (DoConstantFolding(cf_opts, g)) {
+ DumpGraph("ConstFolding", g);
+ changed = true;
+ }
+ }
+ if (opts_.do_function_inlining() && FixupSourceAndSinkEdges(g)) {
+ DumpGraph("FixupSourceAndSinkEdges", g);
+ changed = true;
+ }
+ if (opts_.do_common_subexpression_elimination() &&
+ OptimizeCSE(g, nullptr)) {
+ DumpGraph("OptimizeCSE", g);
+ changed = true;
+ }
+ if (opts_.do_function_inlining() && ExpandInlineFunctions(runtime, g)) {
+ DumpGraph("ExpandInlineFunctions", g);
+ changed = true;
+ }
+ if (!changed) break;
+ }
+
+ Graph* copy = new Graph(g->op_registry());
+ CopyGraph(*g, copy);
+ delete g;
+ *graph = copy;
+ DumpGraph("ReCopy", *graph);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
new file mode 100644
index 0000000000..bbe643a43e
--- /dev/null
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -0,0 +1,43 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
+
+#include "tensorflow/core/framework/config.pb.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class GraphOptimizer {
+ public:
+ GraphOptimizer(const OptimizerOptions& opts);
+ ~GraphOptimizer();
+
+ // Applies optimization passes specified in 'opts' to 'graph'.
+ // Maybe replace *graph with a new graph object.
+ void Optimize(FunctionLibraryRuntime* runtime, Graph** graph);
+
+ private:
+ OptimizerOptions opts_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizer);
+};
+
+} // end namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_OPTIMIZER_H_
diff --git a/tensorflow/core/framework/config.proto b/tensorflow/core/framework/config.proto
index 5218126676..6e0188226b 100644
--- a/tensorflow/core/framework/config.proto
+++ b/tensorflow/core/framework/config.proto
@@ -26,6 +26,36 @@ message GPUOptions {
int64 deferred_deletion_bytes = 3;
};
+// Options passed to the graph optimizer
+message OptimizerOptions {
+ // If true, optimize the graph using common subexpression elimination.
+ bool do_common_subexpression_elimination = 1;
+
+ // If true, perform constant folding optimization on the graph.
+ bool do_constant_folding = 2;
+
+ // If true, perform function inlining on the graph.
+ bool do_function_inlining = 4;
+
+ // Optimization level
+ enum Level {
+ // L1 is the default level.
+ // Optimization performed at L1 :
+ // 1. Common subexpression elimination
+ L1 = 0;
+
+ // Optimization performed at L2 :
+ // 1. Common subexpression elimination
+ // 2. Constant folding
+ L2 = 2;
+
+ // No optimizations
+ L0 = -1;
+ }
+
+ Level opt_level = 3;
+}
+
message GraphOptions {
// Removed, use optimizer_options below.
reserved "skip_common_subexpression_elimination";
@@ -35,31 +65,7 @@ message GraphOptions {
// (Currently ignored.)
bool enable_recv_scheduling = 2;
- // Options passed to the graph optimizer
- message OptimizerOptions {
- // If true, optimize the graph using common subexpression elimination.
- bool do_common_subexpression_elimination = 1;
-
- // If true, perform constant folding optimization on the graph.
- bool do_constant_folding = 2;
-
- // Optimization level
- enum Level {
- // L1 is the default level.
- // Optimization performed at L1 :
- // 1. Common subexpression elimination
- L1 = 0;
- // Optimization performed at L2 :
- // 1. Common subexpression elimination
- // 2. Constant folding
- L2 = 2;
- // No optimizations
- L0 = -1;
- }
-
- Level opt_level = 3;
- }
-
+ // Options controlling how graph is optimized.
OptimizerOptions optimizer_options = 3;
};
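
Note how opt_level composes with the individual flags: in the GraphOptimizer constructor added above, L1 (the default level) forces common subexpression elimination on, L2 additionally forces constant folding, and L0 forces nothing, leaving only the flags that were set explicitly. An illustrative sketch, assuming the generated C++ class for this proto:

    tensorflow::OptimizerOptions opts;
    opts.set_opt_level(tensorflow::OptimizerOptions::L2);
    // After GraphOptimizer's constructor applies the level:
    //   do_common_subexpression_elimination -> true (forced by >= L1)
    //   do_constant_folding                 -> true (forced by >= L2)
    //   do_function_inlining                -> unchanged (false unless set explicitly)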
diff --git a/tensorflow/core/graph/algorithm.cc b/tensorflow/core/graph/algorithm.cc
index 0df3f1c3db..f47e7df961 100644
--- a/tensorflow/core/graph/algorithm.cc
+++ b/tensorflow/core/graph/algorithm.cc
@@ -150,17 +150,21 @@ void PruneForReverseReachability(Graph* g,
FixupSourceAndSinkEdges(g);
}
-void FixupSourceAndSinkEdges(Graph* g) {
+bool FixupSourceAndSinkEdges(Graph* g) {
// Connect all nodes with no incoming edges to source.
// Connect all nodes with no outgoing edges to sink.
+ bool changed = false;
for (Node* n : g->nodes()) {
if (!n->IsSource() && n->in_edges().empty()) {
g->AddControlEdge(g->source_node(), n);
+ changed = true;
}
if (!n->IsSink() && n->out_edges().empty()) {
g->AddControlEdge(n, g->sink_node());
+ changed = true;
}
}
+ return changed;
}
} // namespace tensorflow
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
index 5851e7b7ab..f57e83e686 100644
--- a/tensorflow/core/graph/algorithm.h
+++ b/tensorflow/core/graph/algorithm.h
@@ -55,7 +55,9 @@ void PruneForReverseReachability(Graph* g,
// Connect all nodes with no incoming edges to source.
// Connect all nodes with no outgoing edges to sink.
-void FixupSourceAndSinkEdges(Graph* g);
+//
+// Returns true if and only if 'g' is mutated.
+bool FixupSourceAndSinkEdges(Graph* g);
} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index cc702eeda5..e1d80782b6 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -382,15 +382,15 @@ bool GraphConstructor::TypeValidateEdge(const Edge* edge) {
return true;
}
-static void SetDoCSE(const GraphOptions::OptimizerOptions& optimizer_opt,
- bool force, GraphConstructorOptions* graph_opt) {
+static void SetDoCSE(const OptimizerOptions& optimizer_opt, bool force,
+ GraphConstructorOptions* graph_opt) {
graph_opt->optimizer_do_cse =
force || optimizer_opt.do_common_subexpression_elimination();
}
-static void SetDoConstantFolding(
- const GraphOptions::OptimizerOptions& optimizer_opt, bool force,
- GraphConstructorOptions* graph_opt) {
+static void SetDoConstantFolding(const OptimizerOptions& optimizer_opt,
+ bool force,
+ GraphConstructorOptions* graph_opt) {
graph_opt->optimizer_do_constant_folding =
force || optimizer_opt.do_constant_folding();
}
@@ -401,18 +401,19 @@ static void SetDoConstantFolding(
// GraphConstructorOptions functions
// ----------------------------------------------------------------------------
-GraphConstructorOptions::GraphConstructorOptions(
- const GraphOptions::OptimizerOptions& opts) {
+GraphConstructorOptions::GraphConstructorOptions() {}
+
+GraphConstructorOptions::GraphConstructorOptions(const OptimizerOptions& opts) {
// Set the individually specified options first.
SetDoCSE(opts, false, this);
SetDoConstantFolding(opts, false, this);
// Set options that the level signifies
- if (opts.opt_level() == GraphOptions::OptimizerOptions::L0) {
+ if (opts.opt_level() == OptimizerOptions::L0) {
// No optimizations performed.
- } else if (opts.opt_level() == GraphOptions::OptimizerOptions::L1) {
+ } else if (opts.opt_level() == OptimizerOptions::L1) {
SetDoCSE(opts, true, this);
- } else if (opts.opt_level() == GraphOptions::OptimizerOptions::L2) {
+ } else if (opts.opt_level() == OptimizerOptions::L2) {
SetDoCSE(opts, true, this);
SetDoConstantFolding(opts, true, this);
}
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index e779e5631a..bf3f818e1e 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -33,8 +33,8 @@ struct ConstantFoldingOptions {
// Construct a graph *g out of a GraphDef gdef. Returns non-OK on
// error, in which case *g is left in an incomplete state.
struct GraphConstructorOptions {
- explicit GraphConstructorOptions(
- const GraphOptions::OptimizerOptions& opts = Level0());
+ GraphConstructorOptions();
+ explicit GraphConstructorOptions(const OptimizerOptions& opts);
// If true, allows internal ops in the GraphDef.
bool allow_internal_ops = false;
@@ -59,12 +59,6 @@ struct GraphConstructorOptions {
bool optimizer_do_constant_folding = false;
ConstantFoldingOptions constant_folding_opts;
-
- static GraphOptions::OptimizerOptions Level0() {
- GraphOptions::OptimizerOptions ret;
- ret.set_opt_level(GraphOptions::OptimizerOptions::L0);
- return ret;
- }
};
extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
const GraphDef& gdef, Graph* g);
diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc
index 26ed552677..564b47542f 100644
--- a/tensorflow/core/graph/optimizer_cse.cc
+++ b/tensorflow/core/graph/optimizer_cse.cc
@@ -52,7 +52,7 @@ class OptimizerCSE {
public:
explicit OptimizerCSE(Graph* g) : g_(g) {}
- void Optimize(std::function<bool(const Node*)> consider_fn);
+ bool Optimize(std::function<bool(const Node*)> consider_fn);
private:
struct Scratch;
@@ -180,7 +180,7 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
return true;
}
-void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
+bool OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
// This very simple implementation works if the whole graph is one
// giant basic block (because we just traverse nodes in a
// topological order). We'll need to do something more
@@ -202,6 +202,7 @@ void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
// Scratch space for Equivalent calls. Allocated here and passed in to
// Equivalent to avoid allocation inside the loop below.
+ bool changed = false;
Scratch scratch;
for (Node* n : order) {
if (!n->IsOp()) continue;
@@ -224,13 +225,15 @@ void OptimizerCSE::Optimize(std::function<bool(const Node*)> consider_fn) {
g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input());
}
g_->RemoveNode(n);
+ changed = true;
}
}
+ return changed;
}
-void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
+bool OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn) {
OptimizerCSE opt(g);
- opt.Optimize(consider_fn);
+ return opt.Optimize(consider_fn);
}
} // namespace tensorflow
diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h
index d62ea48b5c..310c906982 100644
--- a/tensorflow/core/graph/optimizer_cse.h
+++ b/tensorflow/core/graph/optimizer_cse.h
@@ -27,7 +27,9 @@ namespace tensorflow {
// "consider_fn" is not nullptr, then only nodes for which
// consider_fn(node) returns true will be considered for combining
// during the common subexpression elimination.
-extern void OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn);
+//
+// Returns true if and only if 'g' is mutated.
+extern bool OptimizeCSE(Graph* g, std::function<bool(const Node*)> consider_fn);
} // namespace tensorflow
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 66118489dd..45c1152a33 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -237,4 +237,27 @@ Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("_ListToArray", ListToArrayGrad);
+Status FillGrad(const AttrSlice& attrs, FunctionDef* g) {
+ *g = FDH::Define(
+ // Arg defs
+ {"dims: int32", "x: T", "dy: T"},
+ // Ret val defs
+ {"d_dims: int32", "dx: T"},
+ // Attr defs
+ {"T: {float, double}"},
+ // Nodes
+ {
+ {{"d_dims"}, "ZerosLike", {"dims"}, {{"T", DT_INT32}}},
+ FDH::Const("zero", 0),
+ {{"rank"}, "Rank", {"dy"}, {{"T", "$T"}}},
+ FDH::Const("one", 1),
+ {{"r"}, "Range", {"zero", "rank", "one"}, {}},
+ // dx = sum(dy)
+ {{"dx"}, "Sum", {"dy", "r"}, {{"T", "$T"}}},
+ });
+ VLOG(1) << "FillGrad " << DebugString(*g);
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("Fill", FillGrad);
+
} // end namespace tensorflow
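
For reference, the math behind FillGrad: y = Fill(dims, x) sets every element of y to the scalar x, so the gradient with respect to x is the sum of dy over all of its elements, which is what Sum(dy, Range(0, Rank(dy), 1)) computes; dims only carries shape information, so its gradient is ZerosLike(dims).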
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 865d30a516..5b278d1d5f 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -29,7 +29,6 @@ import tensorflow as tf
from tensorflow.python.framework import function
# pylint: disable=unused-import
from tensorflow.python.ops import functional_ops
-
# pylint: enable=unused-import
@@ -384,9 +383,22 @@ class UnrollLSTMTest(tf.test.TestCase):
return LSTMLoop10(weights, inp)
+ def _OptimizerOptions(self):
+ ret = []
+ for cse in [False, True]:
+ for inline in [False, True]:
+ for cfold in [False, True]:
+ ret.append(tf.ConfigProto(graph_options=tf.GraphOptions(
+ optimizer_options=tf.OptimizerOptions(
+ opt_level=tf.OptimizerOptions.L0,
+ do_common_subexpression_elimination=cse,
+ do_function_inlining=inline,
+ do_constant_folding=cfold))))
+ return ret
+
def testUnrollLSTM(self):
# Run one step of the unrolled lstm graph.
- def RunForward(mode):
+ def RunForward(mode, cfg=None):
g = tf.Graph()
start = time.time()
with g.as_default():
@@ -397,20 +409,22 @@ class UnrollLSTMTest(tf.test.TestCase):
finish = time.time()
print("time: ", finish - start, " txt size: ", len(str(gdef)),
"gdef bin size: ", len(gdef.SerializeToString()))
- with g.as_default(), tf.Session() as sess:
+ with g.as_default(), tf.Session(config=cfg) as sess:
return sess.run(m)
mv0 = RunForward("complete")
- mv1 = RunForward("cell")
- mv2 = RunForward("loop")
- mv3 = RunForward("loop10")
- self.assertAllClose(mv0, mv1, rtol=1e-4)
- self.assertAllClose(mv0, mv2, rtol=1e-4)
- self.assertAllClose(mv0, mv3, rtol=1e-4)
+ for cfg in self._OptimizerOptions():
+ print("cfg = ", cfg)
+ mv1 = RunForward("cell", cfg)
+ mv2 = RunForward("loop", cfg)
+ mv3 = RunForward("loop10", cfg)
+ self.assertAllClose(mv0, mv1, rtol=1e-4)
+ self.assertAllClose(mv0, mv2, rtol=1e-4)
+ self.assertAllClose(mv0, mv3, rtol=1e-4)
def testUnrollLSTMGrad(self):
# Run one step of the unrolled lstm graph.
- def RunForwardBackward(mode):
+ def RunForwardBackward(mode, cfg=None):
g = tf.Graph()
start = time.time()
with g.as_default():
@@ -423,16 +437,18 @@ class UnrollLSTMTest(tf.test.TestCase):
finish = time.time()
print("time: ", finish - start, " txt size: ", len(str(gdef)),
"gdef bin size: ", len(gdef.SerializeToString()))
- with g.as_default(), tf.Session() as sess:
+ with g.as_default(), tf.Session(config=cfg) as sess:
return sess.run(dw)
d0 = RunForwardBackward("complete")
- d1 = RunForwardBackward("cell")
- d2 = RunForwardBackward("loop")
- d3 = RunForwardBackward("loop10")
- self.assertAllClose(d0, d1, rtol=1e-4)
- self.assertAllClose(d0, d2, rtol=1e-4)
- self.assertAllClose(d0, d3, rtol=1e-4)
+ for cfg in self._OptimizerOptions():
+ print("cfg = ", cfg)
+ d1 = RunForwardBackward("cell", cfg)
+ d2 = RunForwardBackward("loop", cfg)
+ d3 = RunForwardBackward("loop10", cfg)
+ self.assertAllClose(d0, d1, rtol=1e-4)
+ self.assertAllClose(d0, d2, rtol=1e-4)
+ self.assertAllClose(d0, d3, rtol=1e-4)
if __name__ == "__main__":
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 939beb053d..6694ae59ba 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -121,7 +121,7 @@ _hidden_symbols = ["Event", "LogMessage", "Summary", "SessionLog", "xrange",
"HistogramProto", "ConfigProto", "NodeDef", "GraphDef",
"GPUOptions", "GraphOptions", "SessionInterface",
"BaseSession", "NameAttrList", "AttrValue",
- "TensorArray"]
+ "TensorArray", "OptimizerOptions"]
def main(unused_argv):
if not FLAGS.out_dir: