author     avijit-nervana <avijit.chakraborty@intel.com>  2018-09-07 22:08:31 -0700
committer  avijit-nervana <avijit.chakraborty@intel.com>  2018-09-07 22:08:31 -0700
commit     c26c5e1217944448f1f4c2b97626fc4d7d6406d3 (patch)
tree       eef1276e70740301981c6b1c5992697251c6c3c9
parent     2032512ba1de376baadfa9f3983e3edbc67a6731 (diff)
parent     e970a022ef6a3602dd5c9ea15afa96a2291880b1 (diff)
Merge branch 'master' into avijit/add-cpu-backend
-rw-r--r--  tensorflow/compiler/jit/BUILD                              |   4
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.cc        | 175
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.h         |  31
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass_test.cc   | 133
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.cc                |   2
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.h                 |   3
-rw-r--r--  tensorflow/compiler/tests/dense_layer_test.py              |   7
-rw-r--r--  tensorflow/compiler/tests/jit_test.py                      |   5
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.cc               |  20
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.h                |   8
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt            | 127
-rw-r--r--  tensorflow/core/ops/ops.pbtxt                              |  51
12 files changed, 498 insertions(+), 68 deletions(-)
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index de7cd26d1d..a989f15a1c 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -395,6 +395,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
+ "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)
@@ -480,6 +481,7 @@ tf_cc_test(
":common",
":compilation_passes",
":xla_cluster_util",
+ ":xla_gpu_device",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
@@ -496,6 +498,8 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
+ "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 584c963f71..10fc9e85d9 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/algorithm/container.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/gtl/flatset.h"
@@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) {
return Status::OK();
}
-} // namespace
-Status PartiallyDeclusterPass::Run(
- const GraphOptimizationPassOptions& options) {
- // NB! In this pass we assume the only XLA-auto-clusterable operations that
- // may have side effects are resource variable operations so we don't cluster
- // those. The pass will have to be updated if this assumption becomes
- // invalid.
-
- Graph* graph = options.graph->get();
+bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
+// Clones nodes to outside their cluster to avoid device-to-host copies. For
+// instance, converts this:
+//
+// .....
+// |
+// v
+// A_Clustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// to:
+//
+// .....
+// | |
+// | +-------------+
+// | |
+// v v
+// A_Clustered A_Unclustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
// When deciding whether to decluster a particular node, we base our decision
// on if we've decided that some of its consumers have to be declustered too.
// Iterating the graph in post-order guarantees that consumers have been
// visited before producers.
std::vector<Node*> post_order;
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
- /*edge_filter=*/[](const Edge& edge) {
- return !edge.src()->IsNextIteration();
- });
+ /*edge_filter=*/NotBackedge);
gtl::FlatSet<Node*> nodes_to_partially_decluster;
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
if (VLOG_IS_ON(3)) {
for (Node* n : post_order) {
@@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run(
}
nodes_to_partially_decluster.clear();
- TF_RETURN_IF_ERROR(FindNodesToDecluster(
- **options.graph, &nodes_to_partially_decluster, post_order));
+ TF_RETURN_IF_ERROR(
+ FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
CHECK(nodes_to_partially_decluster.empty());
return Status::OK();
}
+
+bool IsIntraClusterEdge(const Edge& edge) {
+ absl::optional<absl::string_view> src_cluster_name =
+ GetXlaClusterForNode(*edge.src());
+ absl::optional<absl::string_view> dst_cluster_name =
+ GetXlaClusterForNode(*edge.dst());
+ return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
+}
+
+Status MustCompileNode(const Node* n, bool* result) {
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
+
+ const XlaOpRegistry::DeviceRegistration* registration;
+ if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
+ *result = false;
+ } else {
+ *result = registration->requires_compilation;
+ }
+
+ return Status::OK();
+}
+
+// Declusters nodes to reduce the number of times we think we need to recompile
+// a TensorFlow graph.
+//
+// Abstractly, if we have a cluster of this form:
+//
+// x0 = arg0
+// x1 = arg1
+// ...
+// shape = f(x0, x1, ...)
+// result = Reshape(input=<something>, new_shape=shape)
+//
+// then pulling `f` out of the cluster may reduce the number of compilations and
+// will never increase the number of compilations.
+//
+// We may reduce the number of compilations if f is many to one. For instance
+// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
+// compilations if f is in the cluster but only one compilation if f is outside
+// the cluster.
+//
+// Declustering f will increase the number of compilations only if f is a
+// one-to-many "function" i.e. isn't a function at all. RNG is one possible
+// example, depending on how we look at it. But we never create clusters where
+// such f's would be marked as must-be-constant.
+//
+// We assume here that the extra repeated (repeated compared to a clustered f
+// where it will always be constant folded) host-side computation of f does not
+// regress performance in any significant manner. We will have to revisit this
+// algorithm with a more complex cost model if this assumption turns out to be
+// incorrect.
+Status DeclusterNodesToReduceRecompilations(Graph* graph) {
+ std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
+ *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge));
+
+ std::vector<Node*> rpo;
+ GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/NotBackedge);
+ for (Node* n : rpo) {
+ if (!compile_time_const_nodes[n->id()]) {
+ continue;
+ }
+
+ absl::string_view cluster_name = *GetXlaClusterForNode(*n);
+ bool node_on_cluster_edge =
+ absl::c_all_of(n->in_edges(), [&](const Edge* e) {
+ absl::optional<absl::string_view> incoming_cluster =
+ GetXlaClusterForNode(*e->src());
+ return !incoming_cluster || *incoming_cluster != cluster_name;
+ });
+
+ // We don't want to decluster F in a graph like
+ //
+ // Input -> OP -> Shape -> F -> Reshape
+ //
+ // Doing so will break up the cluster. Even if we were okay with breaking
+ // up the cluster we will at least have to relabel the two clusters to have
+ // different cluster names.
+ //
+ // We may want to revisit this in the future: we may have cases where OP is
+ // a small computation that does not benefit from XLA while XLA can optimize
+ // everything that follows the Reshape. In these cases it may be wise to
+ // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
+ // function.
+ //
+ // Note that we do do the right thing for graphs like:
+ //
+ // Input -> F0 -> F1 -> Reshape
+ //
+ // Since we iterate in RPO, we'll first encounter F0, decluster it, then
+ // encounter F1, decluster it and so on.
+ if (node_on_cluster_edge) {
+ bool must_compile_node;
+ TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
+ if (!must_compile_node) {
+ VLOG(3) << "Declustering must-be-constant node " << n->name();
+ RemoveFromXlaCluster(n);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status PartiallyDeclusterPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ // NB! In this pass we assume the only XLA-auto-clusterable operations that
+ // may have side effects are resource variable operations so we don't cluster
+ // those. The pass will have to be updated if this assumption becomes
+ // invalid.
+
+ Graph* graph = options.graph->get();
+
+ TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph));
+ TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph));
+
+ return Status::OK();
+}
} // namespace tensorflow
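
The two new sub-passes above decluster a node only under specific structural conditions: the copy-avoidance pass clones a clustered node whose output would otherwise require a device-to-host copy, and the recompilation-reduction pass removes a must-be-constant node only when every incoming edge originates outside its own cluster (the absl::c_all_of check), so declustering can never split a cluster. The following self-contained sketch models that "node on cluster edge" predicate with plain STL types; ToyNode and OnClusterEdge are illustrative names, not TensorFlow APIs.

#include <algorithm>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Each node records an optional XLA cluster name and its data inputs.
struct ToyNode {
  std::string name;
  std::optional<std::string> cluster;  // nullopt => not clustered
  std::vector<const ToyNode*> inputs;
};

// A clustered node sits on the edge of its cluster when none of its inputs
// come from the same cluster; declustering such a node cannot cut the
// cluster in two.
bool OnClusterEdge(const ToyNode& n) {
  if (!n.cluster.has_value()) return false;
  return std::all_of(n.inputs.begin(), n.inputs.end(),
                     [&](const ToyNode* in) {
                       return !in->cluster || *in->cluster != *n.cluster;
                     });
}

int main() {
  ToyNode shape_a{"shape_a", std::nullopt, {}};
  ToyNode shape_b{"shape_b", std::nullopt, {}};
  ToyNode shape{"shape", std::string("cluster_0"), {&shape_a, &shape_b}};
  ToyNode reshape{"reshape", std::string("cluster_0"), {&shape}};
  std::cout << OnClusterEdge(shape) << "\n";    // 1: all inputs unclustered
  std::cout << OnClusterEdge(reshape) << "\n";  // 0: fed by an in-cluster node
}

This mirrors why "shape" gets declustered in the DeclusterMustBeConstantNodes test further down, while declustering stops at nodes whose producers are in the same cluster (DeclusteringStopsAtMetadataOps).
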
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h
index 6949b5028e..cfc4ddb563 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.h
+++ b/tensorflow/compiler/jit/partially_decluster_pass.h
@@ -20,34 +20,11 @@ limitations under the License.
namespace tensorflow {
-// Clones nodes from within a cluster to outside the cluster if profitable.
+// Clones or moves nodes from within a cluster to outside the cluster if
+// profitable. There are two reasons why we do this:
//
-// Today this only clones to avoid device-to-host copies, but in the future we
-// may consider other reasons to clone. For instance, we convert this:
-//
-// .....
-// |
-// v
-// A_Clustered ====> C_Unclustered
-// |
-// v
-// B_Clustered
-//
-// to:
-//
-// .....
-// | |
-// | +-------------+
-// | |
-// v v
-// A_Clustered A_Unclustered ====> C_Unclustered
-// |
-// v
-// B_Clustered
-//
-// where the ===> arrow has a hostmem source and destination and would entail a
-// device to host copy if the source and destination were not in the same XLA
-// cluster.
+// - Reducing device-to-host copies.
+// - Reducing the number of XLA recompilations.
class PartiallyDeclusterPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
index f61a955c22..35872daa65 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "absl/memory/memory.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
@@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -82,7 +84,9 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
// Assign all nodes to the CPU device.
static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
+ if (n->assigned_device_name().empty()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
}
GraphOptimizationPassOptions opt_options;
@@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
return pass.Run(opt_options);
}
-const Node* FindNodeByName(const Graph& graph, const string& name) {
- for (const Node* node : graph.nodes()) {
+Node* FindNodeByName(const Graph& graph, const string& name) {
+ for (Node* node : graph.nodes()) {
if (node->name() == name) {
return node;
}
@@ -279,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
"ClusteredProducer0/declustered");
EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
}
+
+void AddToCluster(absl::Span<Node* const> nodes,
+ absl::string_view cluster_name) {
+ for (Node* n : nodes) {
+ n->AddAttr(kXlaClusterAttr, string(cluster_name));
+ }
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+ auto graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt);
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT,
+ ops::Placeholder::Attrs{});
+ Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b);
+ Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul);
+
+ Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()},
+ "cluster_0");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+}
+
+TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({reshape.node()}, "cluster_0");
+ AddToCluster({shape.node()}, "cluster_1");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ const Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32,
+ ops::Placeholder::Attrs{});
+ Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b);
+
+ Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
+ DT_FLOAT, ops::Placeholder::Attrs{});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape);
+
+ AddToCluster({shape.node(), reshape.node()}, "cluster_0");
+
+ std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
+ TF_ASSERT_OK(s.ToGraph(graph.get()));
+
+ // This is needed to register the XLA_GPU device.
+ std::vector<Device*> devices;
+ TF_ASSERT_OK(DeviceFactory::AddDevices(
+ SessionOptions(), "/job:localhost/replica:0/task:0", &devices));
+
+ // Scope::ToGraph loses the assigned device name since it goes through
+ // GraphDef/NodeDef which does not have a field for the assigned device name.
+ Node* n = FindNodeByName(*graph, "shape");
+ ASSERT_NE(n, nullptr);
+ n->set_assigned_device_name(
+ "/job:localhost/replica:0/task:0/device:XLA_GPU:0");
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+
+ EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
+
+ for (Device* d : devices) {
+ delete d;
+ }
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 03380e9406..f85121ca27 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -210,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) {
node_def->mutable_attr()->erase(kXlaClusterAttr);
}
+void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); }
+
Status AdjustCycleDetectionGraphForResourceOps(
const Graph* graph, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index debd9038c7..94c96ac7c5 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -53,6 +53,9 @@ absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node);
// Removes `node_def` from its XLA cluster (by clearing its _XlaCluster attribute).
void RemoveFromXlaCluster(NodeDef* node_def);
+// Removes `node` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(Node* node);
+
// Returns true if `node` has a DT_RESOURCE typed input or output.
bool HasResourceInputOrOutput(const Node& node);
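
A hypothetical usage sketch for the new Node* overload, assuming a TensorFlow source build: strip the cluster assignment from every clustered node placed on a given device. GetXlaClusterForNode and RemoveFromXlaCluster are the functions declared above; DeclusterNodesOnDevice is an illustrative helper, not part of this change.

#include <string>

#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph.h"

// Clears the _XlaCluster attribute of every clustered node assigned to
// `device`.
void DeclusterNodesOnDevice(tensorflow::Graph* graph,
                            const std::string& device) {
  for (tensorflow::Node* n : graph->nodes()) {
    if (n->assigned_device_name() == device &&
        tensorflow::GetXlaClusterForNode(*n).has_value()) {
      tensorflow::RemoveFromXlaCluster(n);
    }
  }
}
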
diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py
index 04f3b3ef49..0af74c2d8f 100644
--- a/tensorflow/compiler/tests/dense_layer_test.py
+++ b/tensorflow/compiler/tests/dense_layer_test.py
@@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase):
Dense layer should be compiled into a single XlaLaunch op in auto-jit mode.
"""
- os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit")
+ os.environ["TF_XLA_FLAGS"] = (
+ "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
config = config_pb2.ConfigProto()
config.graph_options.optimizer_options.global_jit_level = (
config_pb2.OptimizerOptions.ON_1)
@@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(1, XlaLaunchOpCount(labels))
- self.assertFalse(InLabels(labels, "ListDiff"))
+ self.assertFalse(InLabels(labels, "MatMult"))
def testDenseLayerJitScopeDefinedShape(self):
"""Tests that the dense layer node is properly compiled in jit scope.
@@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase):
labels = GetRunMetadataLabels(run_metadata)
self.assertEqual(2, XlaLaunchOpCount(labels))
- self.assertFalse(InLabels(labels, "ListDiff"))
+ self.assertFalse(InLabels(labels, "MatMult"))
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
index 6e0db54b7a..0839fb123e 100644
--- a/tensorflow/compiler/tests/jit_test.py
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase):
def testElementWiseClustering(self):
arg0 = np.random.rand(2, 2).astype(np.float32)
arg1 = np.random.rand(2, 2).astype(np.float32)
- os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true "
- "--tf_xla_cpu_global_jit")
+ os.environ["TF_XLA_FLAGS"] = (
+ "--tf_xla_fusion_only=true "
+ "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", ""))
tf_op, tf_count = self.simpleTest(arg0, arg1,
config_pb2.OptimizerOptions.OFF)
self.assertEqual(0, tf_count)
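
Both Python tests now prepend their flags to whatever TF_XLA_FLAGS the environment or test harness has already set, instead of overwriting it. The same append-don't-clobber pattern is sketched in C++ below for reference; AppendToTfXlaFlags is a hypothetical helper, and setenv is POSIX-only.

#include <cstdlib>
#include <string>

// Puts `extra` in front of any existing TF_XLA_FLAGS value rather than
// discarding it, matching the behavior the two tests above switch to.
void AppendToTfXlaFlags(const std::string& extra) {
  std::string value = extra;
  if (const char* existing = std::getenv("TF_XLA_FLAGS")) {
    value += " ";
    value += existing;
  }
  setenv("TF_XLA_FLAGS", value.c_str(), /*overwrite=*/1);
}
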
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index e8673d7790..922ae7c79a 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -26,8 +26,9 @@ namespace tensorflow {
// Backwards dataflow analysis that finds arguments to a graph that must be
// compile-time constants.
Status BackwardsConstAnalysis(const Graph& g,
- std::vector<bool>* compile_time_const_args,
- std::vector<bool>* compile_time_const_nodes) {
+ std::vector<bool>* compile_time_const_arg_indices,
+ std::vector<bool>* compile_time_const_nodes,
+ std::function<bool(const Edge&)> edge_filter) {
// Operators that don't look at the data of their inputs, just the shapes.
const std::unordered_set<string> metadata_ops = {
"Rank",
@@ -45,8 +46,7 @@ Status BackwardsConstAnalysis(const Graph& g,
}
Status status;
- auto visit = [&status, &metadata_ops, compile_time_const_nodes,
- compile_time_const_args](Node* node) {
+ auto visit = [&](Node* node) {
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
@@ -59,13 +59,13 @@ Status BackwardsConstAnalysis(const Graph& g,
int index;
status = GetNodeAttr(node->attrs(), "index", &index);
if (!status.ok()) return;
- if (compile_time_const_args) {
- (*compile_time_const_args)[index] = true;
+ if (compile_time_const_arg_indices) {
+ (*compile_time_const_arg_indices)[index] = true;
}
return;
}
for (const Edge* pred : node->in_edges()) {
- if (!pred->IsControlEdge()) {
+ if (!pred->IsControlEdge() && edge_filter(*pred)) {
(*compile_time_const_nodes)[pred->src()->id()] = true;
}
}
@@ -88,7 +88,8 @@ Status BackwardsConstAnalysis(const Graph& g,
for (Edge const* edge : node->in_edges()) {
if (edge->dst_input() >= name_range->second.first &&
- edge->dst_input() < name_range->second.second) {
+ edge->dst_input() < name_range->second.second &&
+ edge_filter(*edge)) {
(*compile_time_const_nodes)[edge->src()->id()] = true;
}
}
@@ -97,7 +98,8 @@ Status BackwardsConstAnalysis(const Graph& g,
// Post-order traversal visits nodes in reverse topological order for an
// acyclic graph.
- DFS(g, {}, visit);
+ DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
+ [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
return status;
}
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
index af57e5a403..49b3c6d413 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.h
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -32,9 +32,13 @@ namespace tensorflow {
//
// The ids of the nodes in `graph` that must be constant are returned in
// `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
-Status BackwardsConstAnalysis(const Graph& graph,
+//
+// Only propagate const-ness along edges for which `edge_filter` returns true.
+Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
- std::vector<bool>* compile_time_const_nodes);
+ std::vector<bool>* compile_time_const_nodes,
+ std::function<bool(const Edge&)> edge_filter =
+ [](const Edge& e) { return true; });
} // namespace tensorflow
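
A usage sketch for the new edge_filter parameter, assuming a TensorFlow source build: the defaulted argument keeps the previous behavior (const-ness propagates across every non-control edge), while an explicit predicate restricts which edges carry the constraint backwards; the partially_decluster pass passes IsIntraClusterEdge so the requirement never leaks out of a cluster. The back-edge filter below is the same predicate the DFS inside const_analysis.cc now uses; FindConstNodes is an illustrative name.

#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

tensorflow::Status FindConstNodes(const tensorflow::Graph& graph) {
  std::vector<bool> const_nodes(graph.num_node_ids());

  // Default filter: every non-control edge propagates const-ness.
  TF_RETURN_IF_ERROR(tensorflow::BackwardsConstAnalysis(
      graph, /*compile_time_const_arg_indices=*/nullptr, &const_nodes));

  // Explicit filter: ignore NextIteration back-edges.
  TF_RETURN_IF_ERROR(tensorflow::BackwardsConstAnalysis(
      graph, /*compile_time_const_arg_indices=*/nullptr, &const_nodes,
      [](const tensorflow::Edge& e) { return !e.src()->IsNextIteration(); }));

  return tensorflow::Status::OK();
}
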
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index c32d6f84f5..34e6b5560b 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -35790,6 +35790,42 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV2"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
@@ -35817,6 +35853,46 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV3"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
@@ -35855,6 +35931,57 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV4"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "valid_outputs"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "NonMaxSuppressionWithOverlaps"
input_arg {
name: "overlaps"
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index aeb03c5952..c00c0030e6 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -17098,11 +17098,11 @@ op {
name: "NonMaxSuppressionV2"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17116,16 +17116,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17143,16 +17156,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17175,6 +17201,19 @@ op {
type: DT_INT32
}
attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
name: "pad_to_max_output_size"
type: "bool"
default_value {