author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-04 18:57:57 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-04 19:00:21 -0700
commit     76801dda9b4766d729ab88267ee47f48d05eafb7 (patch)
tree       cc7d07e0e6457509c1bd8b3a7a206fbb27451ad2 /tensorflow
parent     35c8574e49aadcf16d009717e1d31fcce148db02 (diff)
Enable XLA fusions as a Grappler optimization.
PiperOrigin-RevId: 199230907
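
The new pass registers itself with Grappler under the name "xla-fusion" (see xla_fusion_optimizer.cc below). A minimal sketch of how a client could request it through a session's RewriterConfig follows; the function name is hypothetical and the proto accessors are assumptions inferred from the custom-optimizer fields referenced in this change, not part of the commit itself.

    // Sketch only: enable the "xla-fusion" Grappler optimizer for a session.
    // Assumes ConfigProto/RewriterConfig expose the custom_optimizers list.
    #include "tensorflow/core/protobuf/config.pb.h"

    tensorflow::ConfigProto MakeConfigWithXlaFusion() {
      tensorflow::ConfigProto config;
      auto* rewriter_config =
          config.mutable_graph_options()->mutable_rewrite_options();
      rewriter_config->add_custom_optimizers()->set_name("xla-fusion");
      return config;
    }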
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/jit/BUILD | 46
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc | 161
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.cc | 161
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.h | 46
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer.cc | 321
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer.h | 49
-rw-r--r--  tensorflow/compiler/jit/xla_fusion_optimizer_test.cc | 183
-rw-r--r--  tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc | 100
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.h | 4
10 files changed, 889 insertions, 184 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 6d6c030a26..ab8cd8f4bc 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
# Target that bundles up the XLA CPU and GPU JIT devices.
cc_library(
@@ -312,6 +313,7 @@ cc_library(
":common",
":shape_inference_helpers",
":union_find",
+ ":xla_cluster_util",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/kernels:parallel_check_op",
"//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
@@ -333,6 +335,18 @@ cc_library(
)
cc_library(
+ name = "xla_cluster_util",
+ srcs = ["xla_cluster_util.cc"],
+ hdrs = ["xla_cluster_util.h"],
+ deps = [
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:graph",
+ "//tensorflow/core/kernels:bounds_check",
+ ],
+)
+
+cc_library(
name = "union_find",
hdrs = ["union_find.h"],
)
@@ -408,6 +422,38 @@ tf_cc_test(
],
)
+cc_library(
+ name = "xla_fusion_optimizer",
+ srcs = ["xla_fusion_optimizer.cc"],
+ hdrs = ["xla_fusion_optimizer.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":common",
+ ":union_find",
+ ":xla_cluster_util",
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/core:core_cpu_base",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "xla_fusion_optimizer_test",
+ srcs = ["xla_fusion_optimizer_test.cc"],
+ deps = [
+ ":common",
+ ":xla_cluster_util",
+ ":xla_fusion_optimizer",
+ "//tensorflow/core:graph",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
+
# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library.
cc_header_only_library(
name = "xla_jit_headers_lib",
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 07ee93d79e..74468266b9 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -41,9 +42,6 @@ limitations under the License.
namespace tensorflow {
-const char* const kXlaClusterAttr = "_XlaCluster";
-const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
-
namespace {
// Returns true if, when executed in TensorFlow, `node` is guaranteed to forward
@@ -191,16 +189,6 @@ bool IsCompilableCall(const NodeDef& call_def,
return true;
}
-// Returns the DeviceType corresponding to 'device'.
-Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
- DeviceNameUtils::ParsedName parsed;
- if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
- return errors::Internal("Malformed assigned device '", device, "'");
- }
- *device_type = DeviceType(parsed.type);
- return Status::OK();
-}
-
// Tests whether `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node) {
return std::find(node.input_types().begin(), node.input_types().end(),
@@ -209,18 +197,11 @@ bool HasResourceInputOrOutput(const Node& node) {
DT_RESOURCE) != node.output_types().end();
}
-struct NodeCompare {
- bool operator()(const Node* a, const Node* b) const {
- return a->id() < b->id();
- }
-};
-using OrderedNodeSet = std::set<Node*, NodeCompare>;
-
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
-// TODO(hpucha): Consider a black list instead of a white list as
-// implemented below.
+// TODO(hpucha): Remove this code since this functionality is subsumed by
+// Grappler XlaFusionOptimizer.
bool IsXlaFusable(const NodeDef& node) {
static const std::unordered_set<std::string>* elementwise_ops =
new std::unordered_set<std::string>(
@@ -390,7 +371,7 @@ Status FindCompilationCandidates(
for (Node* node : graph.op_nodes()) {
sorted_nodes.push_back(node);
}
- std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare());
+ std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
for (Node* node : sorted_nodes) {
VLOG(2) << "Fuel: " << fuel;
@@ -405,9 +386,13 @@ Status FindCompilationCandidates(
DeviceType device_type("");
TF_RETURN_IF_ERROR(
- DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
+ DeviceToDeviceType(node->assigned_device_name(), &device_type));
- if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
+ if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
+ VLOG(2) << "Compilation rejected node: not compilable " << node->name()
+ << ": " << node->type_string();
+ continue;
+ }
const XlaOpRegistry::DeviceRegistration* registration;
CHECK(
@@ -456,46 +441,6 @@ struct Cluster {
int representative = -1;
};
-// Returns a string describing how an edge from src to dst would
-// create a cycle.
-string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src,
- int dst) {
- int32 max_path_size = graph.num_node_ids() + 1;
- std::vector<int32> path(max_path_size);
- int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data());
- if (path_size == 0) {
- return "";
- }
-
- auto node_name = [&cycles, &graph](int node_id) {
- if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
- return string("(null)");
- }
- auto* node = graph.FindNodeId(node_id);
- if (node == nullptr) {
- return string("(null)");
- }
- return node->name();
- };
-
- string description;
- strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
- node_name(dst), " would create a cycle.\n");
- path.resize(path_size);
- for (int32 node_id : path) {
- string ascii_art;
- if (node_id == dst) {
- ascii_art = "+-> ";
- } else if (node_id != src) {
- ascii_art = "| ";
- } else {
- ascii_art = "+-- ";
- }
- strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
- }
- return description;
-}
-
} // anonymous namespace
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
@@ -601,84 +546,13 @@ Status MarkForCompilationPass::RunImpl(
: Env::Default(),
is_compilable_fn, &compilation_candidates));
- GraphCycles cycles;
- for (int i = 0; i < graph->num_node_ids(); ++i) {
- // We rely on the node IDs in the cycle detection graph being consecutive
- // integers starting from 0.
- CHECK_EQ(i, cycles.NewNode());
+ if (compilation_candidates.empty()) {
+ VLOG(2) << "No compilable candidates";
+ return Status::OK();
}
- // Compute the loop structure of the graph.
- std::vector<ControlFlowInfo> control_flow_info;
- TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
-
- // The clustering code must avoid adding cycles to the graph to prevent
- // deadlock. However, the graph may contain loops, which would trigger the
- // cycle detection code. To handle loops, we alter the structure of the cycle
- // detection graph, disconnecting each loop from the enclosing graph.
- // Specifically, we:
- // * add a new "frame" node for each loop.
- // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
- // to/from the corresponding frame node. In essence, we collapse the loop
- // into a single node for the purpose of cycle detection in the enclosing
- // graph.
- // * the body of the loop should now be disconnected from the rest of the
- // graph; we make it acyclic by breaking loop backedges (edges outgoing from
- // "NextIteration" nodes.
-
- // Map from frame name strings to node IDs in the cycle detection graph.
- std::unordered_map<string, int> frame_nodes;
-
- // Get the cycle graph node ID for frame 'frame_name', or add one if none
- // exists.
- auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
- int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
- if (frame_id < 0) {
- // The emplace succeeded; we have not allocated a frame node yet.
- frame_id = cycles.NewNode();
- }
- return frame_id;
- };
-
- for (Edge const* edge : graph->edges()) {
- if (edge->dst()->IsEnter()) {
- // Lift edges to an "Enter" node to the corresponding frame node.
- const string& frame_name =
- control_flow_info[edge->dst()->id()].frame_name;
- int dst = GetOrAddFrameNodeId(frame_name);
- if (!cycles.InsertEdge(edge->src()->id(), dst)) {
- return errors::Internal(
- "Cycle detected when adding enter->frame edge: ",
- DescribeCycle(cycles, *graph, edge->src()->id(), dst));
- }
- continue;
- }
- if (edge->src()->IsExit()) {
- // Lift edges from an "Exit" node to the corresponding frame node.
- const string& frame_name =
- control_flow_info[edge->src()->id()].frame_name;
- int src = GetOrAddFrameNodeId(frame_name);
- if (!cycles.InsertEdge(src, edge->dst()->id())) {
- return errors::Internal(
- "Cycle detected when adding frame->exit edge: ",
- DescribeCycle(cycles, *graph, src, edge->dst()->id()));
- }
- // Drop the original edge.
- continue;
- }
- if (edge->src()->IsNextIteration()) {
- // Break loop back-edges.
- continue;
- }
- if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
- // This should never happen. All cycles in the graph should contain
- // a control flow operator.
- return errors::Internal(
- "Found cycle in graph without control flow operator during XLA "
- "compilation: ",
- DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
- }
- }
+ GraphCycles cycles;
+ TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
// Each compilation candidate belongs to a cluster. The cluster's
// representative
@@ -696,6 +570,9 @@ Status MarkForCompilationPass::RunImpl(
// Repeatedly contract edges between clusters that are on the same device,
// provided the contraction would not create a cycle.
+ //
+ // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for
+ // example, from the Grappler fusion pass).
while (!worklist.empty()) {
int from = worklist.front()->Get().representative;
worklist.pop_front();
@@ -804,7 +681,7 @@ Status MarkForCompilationPass::RunImpl(
// compilation.
DeviceType device_type("");
TF_RETURN_IF_ERROR(
- DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
const XlaOpRegistry::DeviceRegistration* registration;
XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration);
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
new file mode 100644
index 0000000000..70bd10336b
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -0,0 +1,161 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+const char* const kXlaClusterAttr = "_XlaCluster";
+const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation";
+
+namespace {
+// Returns a string describing how an edge from src to dst would
+// create a cycle.
+string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src,
+ int dst) {
+ int32 max_path_size = graph.num_node_ids() + 1;
+ std::vector<int32> path(max_path_size);
+ int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data());
+ if (path_size == 0) {
+ return "";
+ }
+
+ auto node_name = [cycles, &graph](int node_id) {
+ if (!FastBoundsCheck(node_id, graph.num_node_ids())) {
+ return string("(null)");
+ }
+ auto* node = graph.FindNodeId(node_id);
+ if (node == nullptr) {
+ return string("(null)");
+ }
+ return node->name();
+ };
+
+ string description;
+ strings::StrAppend(&description, "Edge from ", node_name(src), " to ",
+ node_name(dst), " would create a cycle.\n");
+ path.resize(path_size);
+ for (int32 node_id : path) {
+ string ascii_art;
+ if (node_id == dst) {
+ ascii_art = "+-> ";
+ } else if (node_id != src) {
+ ascii_art = "| ";
+ } else {
+ ascii_art = "+-- ";
+ }
+ strings::StrAppend(&description, ascii_art, node_name(node_id), "\n");
+ }
+ return description;
+}
+} // namespace
+
+Status DeviceToDeviceType(const string& device, DeviceType* device_type) {
+ DeviceNameUtils::ParsedName parsed;
+ if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
+ return errors::Internal("Malformed assigned device '", device, "'");
+ }
+ *device_type = DeviceType(parsed.type);
+ return Status::OK();
+}
+
+Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
+ for (int i = 0; i < graph->num_node_ids(); ++i) {
+ // We rely on the node IDs in the cycle detection graph being consecutive
+ // integers starting from 0.
+ CHECK_EQ(i, cycles->NewNode());
+ }
+
+ // Compute the loop structure of the graph.
+ std::vector<ControlFlowInfo> control_flow_info;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
+
+ // The clustering code must avoid adding cycles to the graph to prevent
+ // deadlock. However, the graph may contain loops, which would trigger the
+ // cycle detection code. To handle loops, we alter the structure of the cycle
+ // detection graph, disconnecting each loop from the enclosing graph.
+ // Specifically, we:
+ // * add a new "frame" node for each loop.
+ // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
+ // to/from the corresponding frame node. In essence, we collapse the loop
+ // into a single node for the purpose of cycle detection in the enclosing
+ // graph.
+ // * the body of the loop should now be disconnected from the rest of the
+ // graph; we make it acyclic by breaking loop backedges (edges outgoing from
+ // "NextIteration" nodes.
+
+ // Map from frame name strings to node IDs in the cycle detection graph.
+ std::unordered_map<string, int> frame_nodes;
+
+ // Get the cycle graph node ID for frame 'frame_name', or add one if none
+ // exists.
+ auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) {
+ int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
+ if (frame_id < 0) {
+ // The emplace succeeded; we have not allocated a frame node yet.
+ frame_id = cycles->NewNode();
+ }
+ return frame_id;
+ };
+
+ for (Edge const* edge : graph->edges()) {
+ if (edge->dst()->IsEnter()) {
+ // Lift edges to an "Enter" node to the corresponding frame node.
+ const string& frame_name =
+ control_flow_info[edge->dst()->id()].frame_name;
+ int dst = GetOrAddFrameNodeId(frame_name);
+ if (!cycles->InsertEdge(edge->src()->id(), dst)) {
+ return errors::Internal(
+ "Cycle detected when adding enter->frame edge: ",
+ DescribeCycle(cycles, *graph, edge->src()->id(), dst));
+ }
+ continue;
+ }
+ if (edge->src()->IsExit()) {
+ // Lift edges from an "Exit" node to the corresponding frame node.
+ const string& frame_name =
+ control_flow_info[edge->src()->id()].frame_name;
+ int src = GetOrAddFrameNodeId(frame_name);
+ if (!cycles->InsertEdge(src, edge->dst()->id())) {
+ return errors::Internal(
+ "Cycle detected when adding frame->exit edge: ",
+ DescribeCycle(cycles, *graph, src, edge->dst()->id()));
+ }
+ // Drop the original edge.
+ continue;
+ }
+ if (edge->src()->IsNextIteration()) {
+ // Break loop back-edges.
+ continue;
+ }
+ if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) {
+ // This should never happen. All cycles in the graph should contain
+ // a control flow operator.
+ return errors::Internal(
+ "Found cycle in graph without control flow operator during XLA "
+ "compilation: ",
+ DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id()));
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
new file mode 100644
index 0000000000..5b673bdc27
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains utilities for clustering compilable graph nodes via XLA.
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+// The attribute that marks nodes to be grouped into functions by the
+// encapsulate subgraphs pass.
+extern const char* const kXlaClusterAttr;
+
+// The attribute that marks nodes in a cluster to be placed outside the xla
+// compilation by the encapsulate subgraphs pass.
+extern const char* const kXlaOutsideCompilationAttr;
+
+using OrderedNodeSet = std::set<Node*, NodeComparatorID>;
+
+// Returns the DeviceType corresponding to 'device'.
+Status DeviceToDeviceType(const string& device, DeviceType* device_type);
+
+// Creates a graph representation to enable cycle detection when clustering.
+// This representation handles loops in graph by disconnecting each loop from
+// the enclosing graph.
+Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
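
The two helpers exported above are consumed by both mark_for_compilation_pass.cc and the new Grappler optimizer. A condensed sketch of the intended call pattern, assuming xla_cluster_util.h is included and `graph` is a Graph* whose nodes carry assigned device names; the function name is hypothetical:

    // Sketch only: how a clustering pass is expected to use the shared helpers.
    Status BuildClusteringState(const Graph* graph) {
      GraphCycles cycles;
      TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles));
      for (Node* node : graph->op_nodes()) {
        DeviceType device_type("");
        TF_RETURN_IF_ERROR(
            DeviceToDeviceType(node->assigned_device_name(), &device_type));
        // ... filter candidates by device_type, then contract edges in `cycles`.
      }
      return Status::OK();
    }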
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
new file mode 100644
index 0000000000..96016521ea
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc
@@ -0,0 +1,321 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
+
+#include <atomic>
+#include <deque>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/compiler/jit/union_find.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+
+namespace tensorflow {
+
+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+static bool IsShapeConsumerOp(const Node& node) {
+ return node.type_string() == "Shape" || node.type_string() == "ShapeN" ||
+ node.type_string() == "Rank" || node.type_string() == "Size";
+}
+
+// Returns true if the op can be decomposed into XLA ops for which
+// there are fusable elemental implementations.
+bool IsXlaFusable(const NodeDef& node) {
+ static const std::unordered_set<std::string>* elementwise_ops =
+ new std::unordered_set<std::string>(
+ {// tf2xla/kernels/aggregate_ops.cc
+ "AddN",
+ // tf2xla/kernels/binary_ops.cc
+ "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv",
+ "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift",
+ "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv",
+ "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference",
+ "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater",
+ "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad",
+ "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual",
+ // tf2xla/kernels/unary_ops.cc
+ "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin",
+ "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp",
+ "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal",
+ "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round",
+ "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt",
+ "Square", "Tan", "Tanh", "Real", "Imag",
+ // tf2xla/kernels/bcast_ops.cc
+ "BroadcastArgs", "BroadcastGradientArgs",
+ // tf2xla/kernels/bias_ops.cc
+ "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/,
+ // tf2xla/kernels/cast_op.cc
+ "Cast",
+ // tf2xla/kernels/concat_op.cc
+ "Concat", "ConcatV2", "ConcatOffset",
+ // tf2xla/kernels/const_op.cc
+ "Const",
+ // tf2xla/kernels/elu_op.cc
+ "Elu", "EluGrad", "Selu", "SeluGrad",
+ // tf2xla/kernels/fill_op.cc
+ "Fill",
+ // tf2xla/kernels/identity_op.cc
+ "Identity", "IdentityN", "PreventGradient",
+ "StopGradient", /*"Snapshot",*/
+ // tf2xla/kernels/index_ops.cc
+ "ArgMax", "ArgMin",
+ // tf2xla/kernels/mirror_pad_op.cc
+ "MirrorPad",
+ // tf2xla/kernels/one_hot_op.cc
+ "OneHot",
+ // tf2xla/kernels/pack_op.cc
+ "Pack",
+ // tf2xla/kernels/pad_op.cc
+ "Pad", "PadV2",
+ // tf2xla/kernels/relu_op.cc
+ "Relu", "Relu6", "ReluGrad", "Relu6Grad",
+ // tf2xla/kernels/reshape_op.cc
+ "Reshape",
+ // tf2xla/kernels/reverse_op.cc
+ "Reverse", "ReverseV2",
+ // tf2xla/kernels/reverse_sequence_op.cc
+ "ReverseSequence",
+ // tf2xla/kernels/shape_op.cc
+ "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze",
+ "ZerosLike", "OnesLike",
+ // tf2xla/kernels/slice_op.cc
+ "Slice",
+ // tf2xla/kernels/split_op.cc
+ "Split", "SplitV",
+ // tf2xla/kernels/strided_slice_op.cc
+ "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign",
+ // tf2xla/kernels/tile_ops.cc
+ "Tile",
+ // tf2xla/kernels/transpose_op.cc
+ "Transpose", "InvertPermutation",
+ // tf2xla/kernels/unpack_op.cc
+ "Unpack"});
+
+ return elementwise_ops->count(node.op()) > 0;
+}
+
+Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster,
+ const grappler::GrapplerItem& item,
+ GraphDef* output) {
+ VLOG(2) << "Here at fusion optimizer";
+
+ // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op.
+ // Once that happens, the expected interaction between this optimizer and when
+ // the global_jit_level is set is as follows: Fusion optimizer will replace
+ // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can
+ // be further compiled where possible via mark_for_compilation_pass. Note that
+ // this might lead to inefficient clustering, and it is best to use either the
+ // fusion optimizer or the global_jit flag, and not combine the two.
+
+ // Create a Graph out of GraphDef. This is required currently because the
+ // helpers around clustering, encapsulation etc work on graphs.
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+ Graph graph(function_library);
+ ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
+ shape_refiner.set_require_shape_inference_fns(false);
+ shape_refiner.set_disable_constant_propagation(true);
+ ImportGraphDefOptions options;
+ // Graph optimization happens at the late stage of graph execution, when
+ // colocation constraints are already validated previously and the device
+ // placement of nodes has also completed, so there is no need to validate
+ // colocation constraints again.
+ options.validate_colocation_constraints = false;
+ options.validate_shape = false;
+ TF_RETURN_IF_ERROR(
+ ImportGraphDef(options, item.graph, &graph, &shape_refiner));
+
+ // Collect nodes that can be fused via XLA, while ignoring those that
+ // explicitly ask for XLA: (*) nodes that are marked to be compiled
+ // explicitly. (*) nodes assigned to XLA device.
+ OrderedNodeSet compilation_candidates;
+ for (Node* node : graph.op_nodes()) {
+ // If there is a _XlaCompile annotation, ignore the node if it is
+ // true. Nodes are marked with this attr via experimental_jit_scope, and
+ // will be handled by the mark_for_compilation pass.
+ bool compile = false;
+ Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
+ if (status.ok() && compile) {
+ continue;
+ }
+ // If there is already a _XlaCluster annotation, ignore the node. Nodes are
+ // marked with this attr to indicate they are already part of a cluster and
+ // hence ignored.
+ status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile);
+ if (status.ok()) {
+ continue;
+ }
+
+ // If there is an explicit XLA device placement, ignore the node.
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type));
+ if (device_type.type_string().find("XLA") != string::npos) continue;
+
+ // Assume all fusable ops are registered.
+ // TODO(hpucha): Check for registration if possible.
+ if (!IsXlaFusable(node->def())) {
+ continue;
+ }
+
+ compilation_candidates.insert(node);
+ }
+
+ if (compilation_candidates.empty()) {
+ VLOG(2) << "No compilable candidates";
+ *output = item.graph;
+ return Status::OK();
+ }
+
+ GraphCycles cycles;
+ TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles));
+
+ // TODO(hpucha): Make clustering more robust. There are two known issues that
+ // we need to mitigate: (a) Non-resource variables can cause deadlocks
+ // when clustering changes order of execution. See b/77263461 for a specific
+ // example. (b) Queue operations can also cause deadlocks. See b/77261498 for
+ // example.
+
+ struct Cluster {
+ // Identifies the node that represents this cluster in the cycle detection
+ // graph.
+ int representative = -1;
+ };
+
+ // Each compilation candidate belongs to a cluster. The cluster's
+ // representative names the node in the 'cycles' graph that represents the
+ // cluster.
+ std::vector<UnionFind<Cluster>> clusters(graph.num_node_ids());
+ std::deque<UnionFind<Cluster>*> worklist;
+ for (Node* node : compilation_candidates) {
+ Cluster& cluster = clusters[node->id()].Get();
+ cluster.representative = node->id();
+ worklist.push_back(&clusters[node->id()]);
+ }
+
+ // Repeatedly contract edges between clusters that are on the same device,
+ // provided the contraction would not create a cycle. This is a simplified
+ // version of the clustering in mark_for_compilation_pass that also deals with
+ // nodes that are explicitly tagged to be compiled/clustered.
+ while (!worklist.empty()) {
+ int from = worklist.front()->Get().representative;
+ worklist.pop_front();
+
+ Node* node_from = graph.FindNodeId(from);
+ if (node_from->IsControlFlow()) {
+ // Control flow nodes aren't compilation candidates and should never
+ // appear.
+ return errors::Internal(
+ "Found control flow node in clustering worklist: ",
+ node_from->type_string());
+ }
+ for (int to : cycles.Successors(from)) {
+ if (to >= graph.num_node_ids()) {
+ // Node is a "frame" node that is present only in the cycle detection
+ // graph. No clustering is possible.
+ continue;
+ }
+ Node* node_to = graph.FindNodeId(to);
+ if (compilation_candidates.find(node_to) ==
+ compilation_candidates.cend()) {
+ continue;
+ }
+
+ // Do not cluster across devices.
+ if (node_from->def().device() != node_to->def().device()) {
+ VLOG(2) << "Devices " << node_from->def().device() << " "
+ << node_to->def().device();
+ VLOG(2) << "Device names " << node_from->assigned_device_name() << " "
+ << node_to->assigned_device_name();
+ continue;
+ }
+
+ // Ops that consume shapes cannot be the root of a cluster. This is an
+ // optimization.
+ if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
+ continue;
+ }
+
+ // If contracting the edge would create a cycle, bail out.
+ // However, just because we can't merge the clusters now does not mean
+ // we won't be able to merge them in the future.
+ // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
+ // 1->3. But if we first contract 1->2 then we can later contract 1->3.
+ if (!cycles.ContractEdge(from, to)) continue;
+
+ // Merge the clusters. ContractEdge uses 'from' as the number of the
+ // merged node, so make sure 'from' is the chosen representative.
+ clusters[from].Merge(&clusters[to]);
+
+ worklist.push_back(&clusters[from]);
+ break;
+ }
+ }
+
+ // Count the number of non-trivial elements in each cluster.
+ std::vector<int> effective_cluster_sizes(graph.num_node_ids());
+ for (const Node* n : compilation_candidates) {
+ int cluster = clusters[n->id()].Get().representative;
+ // Identity nodes will be removed if the node gets marked for compilation.
+ // Therefore we don't want to count them towards the effective cluster size.
+ if (n->def().op() != "Identity") {
+ effective_cluster_sizes[cluster]++;
+ }
+ }
+
+ const int min_cluster_size = 2;
+ int num_clusters = 0;
+ for (auto size : effective_cluster_sizes) {
+ if (size >= min_cluster_size) {
+ VLOG(3) << "Cluster " << num_clusters << " " << size;
+ num_clusters++;
+ }
+ }
+
+ // Names for each cluster.
+ std::unordered_map<int, string> cluster_names;
+ // Sequence number generator to ensure clusters have unique names.
+ static std::atomic<int64> cluster_sequence_num;
+
+ for (Node* n : compilation_candidates) {
+ int cluster = clusters[n->id()].Get().representative;
+
+ // Compile if this is a cluster of >= min_cluster_size compilable operators.
+ if (effective_cluster_sizes[cluster] >= min_cluster_size) {
+ string& name = cluster_names[cluster];
+
+ if (name.empty()) {
+ name = strings::StrCat("cluster_", cluster_sequence_num++);
+ }
+ n->AddAttr(kXlaClusterAttr, name);
+ VLOG(3) << "Assigning node " << n->name() << " to cluster " << name;
+ }
+ }
+
+ graph.ToGraphDef(output);
+ return Status::OK();
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion");
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h
new file mode 100644
index 0000000000..3d2309e782
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+
+// Optimizes graphs by fusing ops where possible, resulting in more efficient
+// execution.
+class XlaFusionOptimizer : public grappler::CustomGraphOptimizer {
+ public:
+ XlaFusionOptimizer() {}
+ ~XlaFusionOptimizer() override {}
+
+ Status Init(
+ const RewriterConfig_CustomGraphOptimizer* config = nullptr) override {
+ return Status::OK();
+ }
+
+ string name() const override { return "xla-fusion"; };
+
+ Status Optimize(grappler::Cluster* cluster,
+ const grappler::GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override {
+ // Nothing to do for XlaFusionOptimizer.
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_
diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
new file mode 100644
index 0000000000..5736760a87
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_fusion_optimizer.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("UncompilableNullary").Output("o: float");
+REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
+
+class XlaFusionOptimizerTest : public grappler::GrapplerTest {
+ protected:
+ std::unordered_map<string, string> GetClusters(const GraphDef& graph) {
+ std::unordered_map<string, string> ids;
+ for (const NodeDef& node : graph.node()) {
+ string cluster;
+ if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) {
+ CHECK(!cluster.empty());
+ ids[node.name()] = cluster;
+ }
+ }
+ return ids;
+ }
+};
+
+TEST_F(XlaFusionOptimizerTest, Chains) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a =
+ ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
+ Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
+ Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
+ Node* d =
+ ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
+ Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
+ ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_EQ(4, clusters.size());
+ EXPECT_EQ(clusters["B"], clusters["C"]);
+ EXPECT_EQ(clusters["E"], clusters["F"]);
+ EXPECT_NE(clusters["B"], clusters["E"]);
+ EXPECT_TRUE(clusters.find("A") == clusters.cend());
+ EXPECT_TRUE(clusters.find("D") == clusters.cend());
+}
+
+TEST_F(XlaFusionOptimizerTest, FusableOps) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
+ Node* b = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
+
+ Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C"));
+ ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
+ ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
+
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_EQ(2, clusters.size());
+ EXPECT_EQ(clusters["C"], clusters["E"]);
+ EXPECT_TRUE(clusters.find("D") == clusters.cend());
+}
+
+TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT));
+ Node* b = ops::SourceOp(
+ "Placeholder",
+ builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT));
+
+ Node* c = ops::BinaryOp(
+ "Add", a, b,
+ builder.opts().WithName("C").WithDevice("/device:XLA_CPU"));
+ ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D"));
+ Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E"));
+ ops::UnaryOp("Cos", e,
+ builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true));
+
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_TRUE(clusters.empty());
+}
+
+TEST_F(XlaFusionOptimizerTest, UncompilableCycles) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor()));
+ Node* b =
+ ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
+ ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
+
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_TRUE(clusters.empty());
+}
+
+TEST_F(XlaFusionOptimizerTest, CompilableCycles) {
+ GraphDef graph;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor()));
+ Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
+ ops::BinaryOp("Mul", a, b, builder.opts().WithName("C"));
+ TF_ASSERT_OK(builder.ToGraphDef(&graph));
+ }
+ grappler::GrapplerItem item;
+ item.graph = graph;
+
+ XlaFusionOptimizer optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ auto clusters = GetClusters(output);
+ EXPECT_EQ(3, clusters.size());
+ EXPECT_EQ(clusters["A"], clusters["B"]);
+ EXPECT_EQ(clusters["A"], clusters["C"]);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
index 3148a5f809..0b8e0b692a 100644
--- a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h
@@ -50,7 +50,7 @@ class CustomGraphOptimizerRegistrar {
#define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \
namespace { \
- static CustomGraphOptimizerRegistrar \
+ static ::tensorflow::grappler::CustomGraphOptimizerRegistrar \
MyCustomGraphOptimizerClass##_registrar( \
[]() { return new MyCustomGraphOptimizerClass; }, (name)); \
} // namespace
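
With the registrar fully qualified, REGISTER_GRAPH_OPTIMIZER_AS can be invoked from files outside the tensorflow::grappler namespace, which is what xla_fusion_optimizer.cc relies on. The sketch below shows how a registered optimizer would then be retrieved by its Grappler name; CreateByNameOrNull is assumed to be the registry's lookup method and is not part of this diff.

    // Sketch only: look up the optimizer registered under "xla-fusion".
    #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"

    std::unique_ptr<tensorflow::grappler::CustomGraphOptimizer> opt =
        tensorflow::grappler::CustomGraphOptimizerRegistry::CreateByNameOrNull(
            "xla-fusion");
    if (opt != nullptr) {
      TF_CHECK_OK(opt->Init());
    }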
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e6622486eb..143d9dc1c6 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -217,23 +217,9 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
bool is_optimized = false;
GraphOptimizationResult optimization_result(item.id);
+ GraphOptimizer* fusion_optimizer = nullptr;
+ GraphOptimizer* sa_optimizer = nullptr;
- // ScopedAllocatorOptimizer must run last, so move it to the
- // end of optimizers and run only on the last iteration.
- {
- int sa_index = 0;
- for (; sa_index < optimizers.size(); ++sa_index) {
- if (optimizers[sa_index]->name() == "scoped_allocator_optimizer") {
- break;
- }
- }
- const int last_index = optimizers.size() - 1;
- if (sa_index < last_index) {
- optimizers[last_index].swap(optimizers[sa_index]);
- }
- }
-
- const int last_iteration = NumIterations(cfg_) - 1;
for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
VLOG(4) << "Starting optimization iteration " << iteration + 1;
@@ -241,37 +227,40 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
// Some must run only on the last iteration.
- if (optimizer->name() == "scoped_allocator_optimizer" &&
- iteration != last_iteration)
+ if (optimizer->name() == "scoped_allocator_optimizer") {
+ if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
+ continue;
+ }
+ if (optimizer->name() == "xla-fusion") {
+ if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get();
continue;
-
- uint64 start_us = Env::Default()->NowMicros();
- // This swaps the current optimized_graph into optimized item and
- // resets optimized_graph to an empty graph.
- optimized_graph->Swap(&optimized_item.graph);
- *optimized_graph = GraphDef();
- Status status =
- optimizer->Optimize(cluster, optimized_item, optimized_graph);
- uint64 end_us = Env::Default()->NowMicros();
-
- string result;
- if (!status.ok()) {
- optimized_graph->Swap(&optimized_item.graph);
- result = status.ToString();
- } else {
- is_optimized = true;
- float duration_ms = (end_us - start_us) / 1000.0f;
- result = strings::StrCat(
- PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph),
- ", time = ", duration_ms, "ms.");
}
- VLOG(4) << optimizer->name() << ": " << result;
- OptimizerResult optimizer_result{optimizer->name(), result};
- optimization_result.results.push_back(optimizer_result);
+ Status status = RunOptimizer(optimizer.get(), cluster, &optimized_item,
+ optimized_graph, &optimization_result);
+ if (status.ok()) is_optimized = true;
}
}
+ // Run fusion optimizer if requested after all other optimizers since: 1) it
+ // doesn't need to be called more than once. 2) we don't want subsequent
+ // optimization passes to break the fusion clusters. We could potentially
+ // encapsulate the fusion clusters right away, but that will prevent a lot of
+ // optimizations from taking place since we don't have shape inference for
+ // functions, and we can't optimize across function boundaries.
+ if (fusion_optimizer != nullptr) {
+ Status status = RunOptimizer(fusion_optimizer, cluster, &optimized_item,
+ optimized_graph, &optimization_result);
+ if (status.ok()) is_optimized = true;
+ }
+
+ // ScopedAllocatorOptimizer must run last.
+ if (sa_optimizer != nullptr) {
+ Status status = RunOptimizer(sa_optimizer, cluster, &optimized_item,
+ optimized_graph, &optimization_result);
+ if (status.ok()) is_optimized = true;
+ }
+
// Record graph optimization result.
optimization_results_.push_back(optimization_result);
@@ -286,6 +275,35 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
return Status::OK();
}
+Status MetaOptimizer::RunOptimizer(
+ GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
+ GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
+ uint64 start_us = Env::Default()->NowMicros();
+ // This swaps the current optimized_graph into optimized item and
+ // resets optimized_graph to an empty graph.
+ optimized_graph->Swap(&optimized_item->graph);
+ *optimized_graph = GraphDef();
+ Status status =
+ optimizer->Optimize(cluster, *optimized_item, optimized_graph);
+ uint64 end_us = Env::Default()->NowMicros();
+
+ string result;
+ if (!status.ok()) {
+ optimized_graph->Swap(&optimized_item->graph);
+ result = status.ToString();
+ } else {
+ float duration_ms = (end_us - start_us) / 1000.0f;
+ result = strings::StrCat(
+ PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
+ ", time = ", duration_ms, "ms.");
+ }
+ VLOG(4) << optimizer->name() << ": " << result;
+
+ OptimizerResult optimizer_result{optimizer->name(), result};
+ optimization_result->results.push_back(optimizer_result);
+ return status;
+}
+
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
optimization_results_.clear();
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index e736dd174e..151a54cbdf 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -72,6 +72,10 @@ class MetaOptimizer : public GraphOptimizer {
std::vector<OptimizerResult> results;
};
+ Status RunOptimizer(GraphOptimizer* optimizer, Cluster* cluster,
+ GrapplerItem* optimized_item, GraphDef* optimized_graph,
+ GraphOptimizationResult* optimization_result);
+
std::vector<GraphOptimizationResult> optimization_results_;
};