diff options
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/compiler/jit/partially_decluster_pass.cc | 175 | ||||
-rw-r--r-- | tensorflow/compiler/jit/partially_decluster_pass.h | 31 | ||||
-rw-r--r-- | tensorflow/compiler/jit/partially_decluster_pass_test.cc | 133 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_cluster_util.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_cluster_util.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tests/dense_layer_test.py | 7 | ||||
-rw-r--r-- | tensorflow/compiler/tests/jit_test.py | 5 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/const_analysis.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/const_analysis.h | 8 | ||||
-rw-r--r-- | tensorflow/core/ops/compat/ops_history.v1.pbtxt | 127 | ||||
-rw-r--r-- | tensorflow/core/ops/ops.pbtxt | 51 |
12 files changed, 498 insertions, 68 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index de7cd26d1d..a989f15a1c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -395,6 +395,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", ], ) @@ -480,6 +481,7 @@ tf_cc_test( ":common", ":compilation_passes", ":xla_cluster_util", + ":xla_gpu_device", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", @@ -496,6 +498,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 584c963f71..10fc9e85d9 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -130,30 +133,47 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { return Status::OK(); } -} // namespace -Status PartiallyDeclusterPass::Run( - const GraphOptimizationPassOptions& options) { - // NB! In this pass we assume the only XLA-auto-clusterable operations that - // may have side effects are resource variable operations so we don't cluster - // those. The pass will have to be updated if this assumption becomes - // invalid. - - Graph* graph = options.graph->get(); +bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } +// Clones nodes to outside their cluster to avoid device-to-host copies. For +// instance, converts this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... +// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. // Iterating the graph in post-order guarantees that consumers have been // visited before producers. std::vector<Node*> post_order; GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), - /*edge_filter=*/[](const Edge& edge) { - return !edge.src()->IsNextIteration(); - }); + /*edge_filter=*/NotBackedge); gtl::FlatSet<Node*> nodes_to_partially_decluster; - TF_RETURN_IF_ERROR(FindNodesToDecluster( - **options.graph, &nodes_to_partially_decluster, post_order)); + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); if (VLOG_IS_ON(3)) { for (Node* n : post_order) { @@ -170,10 +190,133 @@ Status PartiallyDeclusterPass::Run( } nodes_to_partially_decluster.clear(); - TF_RETURN_IF_ERROR(FindNodesToDecluster( - **options.graph, &nodes_to_partially_decluster, post_order)); + TF_RETURN_IF_ERROR( + FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); CHECK(nodes_to_partially_decluster.empty()); return Status::OK(); } + +bool IsIntraClusterEdge(const Edge& edge) { + absl::optional<absl::string_view> src_cluster_name = + GetXlaClusterForNode(*edge.src()); + absl::optional<absl::string_view> dst_cluster_name = + GetXlaClusterForNode(*edge.dst()); + return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name; +} + +Status MustCompileNode(const Node* n, bool* result) { + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + *result = false; + } else { + *result = registration->requires_compilation; + } + + return Status::OK(); +} + +// Declusters nodes to reduce the number of times we think we need to recompile +// a TensorFlow graph. +// +// Abstractly, if we have a cluster of this form: +// +// x0 = arg0 +// x1 = arg1 +// ... +// shape = f(x0, x1, ...) +// result = Reshape(input=<something>, new_shape=shape) +// +// then pulling `f` out of the cluster may reduce the number of compilations and +// will never increase the number of compilations. +// +// We may reduce the number of compilations if f is many to one. For instance +// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different +// compilations if f is in the cluster but only one compilation if f is outside +// the cluster. +// +// Declustering f will increase the number of compilations only if f is a +// one-to-many "function" i.e. isn't a function at all. RNG is one possible +// example, depending on how we look at it. But we never create clusters where +// such f's would be marked as must-be-constant. +// +// We assume here that the extra repeated (repeated compared to a clustered f +// where it will always be constant folded) host-side computation of f does not +// regress performance in any significant manner. We will have to revisit this +// algorith with a more complex cost model if this assumption turns out to be +// incorrect. +Status DeclusterNodesToReduceRecompilations(Graph* graph) { + std::vector<bool> compile_time_const_nodes(graph->num_node_ids()); + TF_RETURN_IF_ERROR(BackwardsConstAnalysis( + *graph, nullptr, &compile_time_const_nodes, IsIntraClusterEdge)); + + std::vector<Node*> rpo; + GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/NotBackedge); + for (Node* n : rpo) { + if (!compile_time_const_nodes[n->id()]) { + continue; + } + + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + bool node_on_cluster_edge = + absl::c_all_of(n->in_edges(), [&](const Edge* e) { + absl::optional<absl::string_view> incoming_cluster = + GetXlaClusterForNode(*e->src()); + return !incoming_cluster || *incoming_cluster != cluster_name; + }); + + // We don't want to decluster F in a graph like + // + // Input -> OP -> Shape -> F -> Reshape + // + // Doing so will break up the cluster. Even if we were okay with breaking + // up the cluster we will at least have to relabel the two clusters to have + // different cluster names. + // + // We may want to revisit this in the future: we may have cases where OP is + // a small computation that does not benefit from XLA while XLA can optimize + // everything that follows the Reshape. In these cases it may be wise to + // remove Input, OP, Shape and F from the cluster, if F is a many-to-one + // function. + // + // Note that we do do the right thing for graphs like: + // + // Input -> F0 -> F1 -> Reshape + // + // Since we iterate in RPO, we'll first encounter F0, decluster it, then + // encounter F1, decluster it and so on. + if (node_on_cluster_edge) { + bool must_compile_node; + TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node)); + if (!must_compile_node) { + VLOG(3) << "Declustering must-be-constant node " << n->name(); + RemoveFromXlaCluster(n); + } + } + } + + return Status::OK(); +} + +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + TF_RETURN_IF_ERROR(PartiallyDeclusterToRemoveDeviceToHostCopies(graph)); + TF_RETURN_IF_ERROR(DeclusterNodesToReduceRecompilations(graph)); + + return Status::OK(); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h index 6949b5028e..cfc4ddb563 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.h +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -20,34 +20,11 @@ limitations under the License. namespace tensorflow { -// Clones nodes from within a cluster to outside the cluster if profitable. +// Clones or moves nodes from within a cluster to outside the cluster if +// profitable. There are two reasons why we do this: // -// Today this only clones to avoid device-to-host copies, but in the future we -// may consider other reasons to clone. For instance, we convert this: -// -// ..... -// | -// v -// A_Clustered ====> C_Unclustered -// | -// v -// B_Clustered -// -// to: -// -// ..... -// | | -// | +-------------+ -// | | -// v v -// A_Clustered A_Unclustered ====> C_Unclustered -// | -// v -// B_Clustered -// -// where the ===> arrow has a hostmem source and destination and would entail a -// device to host copy if the source and destination were not in the same XLA -// cluster. +// - Reducing device-to-host copies. +// - Reducing the number of XLA recompilations. class PartiallyDeclusterPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index f61a955c22..35872daa65 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/memory/memory.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -82,7 +84,9 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) { // Assign all nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } } GraphOptimizationPassOptions opt_options; @@ -91,8 +95,8 @@ Status PartiallyDecluster(std::unique_ptr<Graph>* graph) { return pass.Run(opt_options); } -const Node* FindNodeByName(const Graph& graph, const string& name) { - for (const Node* node : graph.nodes()) { +Node* FindNodeByName(const Graph& graph, const string& name) { + for (Node* node : graph.nodes()) { if (node->name() == name) { return node; } @@ -279,5 +283,128 @@ TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { "ClusteredProducer0/declustered"); EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); } + +void AddToCluster(absl::Span<Node* const> nodes, + absl::string_view cluster_name) { + for (Node* n : nodes) { + n->AddAttr(kXlaClusterAttr, string(cluster_name)); + } +} + +TEST(PartiallyDeclusterPassTest, DeclusterMustBeConstantNodes) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + auto graph = absl::make_unique<Graph>(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), absl::nullopt); +} + +TEST(PartiallyDeclusterPassTest, DeclusteringStopsAtMetadataOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input_a = ops::Placeholder(s.WithOpName("input_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output input_b = ops::Placeholder(s.WithOpName("shape_b"), DT_FLOAT, + ops::Placeholder::Attrs{}); + Output mul = ops::Mul(s.WithOpName("mul"), input_b, input_b); + Output shape_of_mul = ops::Shape(s.WithOpName("shape_of_mul"), mul); + + Output shape = ops::Add(s.WithOpName("shape"), shape_of_mul, input_a); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({mul.node(), shape_of_mul.node(), shape.node(), reshape.node()}, + "cluster_0"); + + std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); +} + +TEST(PartiallyDeclusterPassTest, EdgeAcrossDifferentClusters) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({reshape.node()}, "cluster_0"); + AddToCluster({shape.node()}, "cluster_1"); + + std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + const Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_1"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output shape_a = ops::Placeholder(s.WithOpName("shape_a"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape_b = ops::Placeholder(s.WithOpName("shape_b"), DT_INT32, + ops::Placeholder::Attrs{}); + Output shape = ops::Add(s.WithOpName("shape"), shape_a, shape_b); + + Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"), + DT_FLOAT, ops::Placeholder::Attrs{}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), reshape_input, shape); + + AddToCluster({shape.node(), reshape.node()}, "cluster_0"); + + std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global()); + TF_ASSERT_OK(s.ToGraph(graph.get())); + + // This is needed to register the XLA_GPU device. + std::vector<Device*> devices; + TF_ASSERT_OK(DeviceFactory::AddDevices( + SessionOptions(), "/job:localhost/replica:0/task:0", &devices)); + + // Scope::ToGraph loses the assigned device name since it goes through + // GraphDef/NodeDef which does not have a field for the assigned device name. + Node* n = FindNodeByName(*graph, "shape"); + ASSERT_NE(n, nullptr); + n->set_assigned_device_name( + "/job:localhost/replica:0/task:0/device:XLA_GPU:0"); + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + + EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0"); + + for (Device* d : devices) { + delete d; + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 03380e9406..f85121ca27 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -210,6 +210,8 @@ void RemoveFromXlaCluster(NodeDef* node_def) { node_def->mutable_attr()->erase(kXlaClusterAttr); } +void RemoveFromXlaCluster(Node* node) { node->ClearAttr(kXlaClusterAttr); } + Status AdjustCycleDetectionGraphForResourceOps( const Graph* graph, const FunctionLibraryDefinition* flib_def, const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index debd9038c7..94c96ac7c5 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -53,6 +53,9 @@ absl::optional<absl::string_view> GetXlaClusterForNode(const Node& node); // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). void RemoveFromXlaCluster(NodeDef* node_def); +// Removes `node` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(Node* node); + // Returns true if `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node); diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 04f3b3ef49..0af74c2d8f 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -58,7 +58,8 @@ class DenseLayerTest(test.TestCase): Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. """ - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) config = config_pb2.ConfigProto() config.graph_options.optimizer_options.global_jit_level = ( config_pb2.OptimizerOptions.ON_1) @@ -77,7 +78,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(1, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. @@ -128,7 +129,7 @@ class DenseLayerTest(test.TestCase): labels = GetRunMetadataLabels(run_metadata) self.assertEqual(2, XlaLaunchOpCount(labels)) - self.assertFalse(InLabels(labels, "ListDiff")) + self.assertFalse(InLabels(labels, "MatMult")) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 6e0db54b7a..0839fb123e 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -489,8 +489,9 @@ class ElementWiseFusionTest(test.TestCase): def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) arg1 = np.random.rand(2, 2).astype(np.float32) - os.environ["TF_XLA_FLAGS"] = ("--tf_xla_fusion_only=true " - "--tf_xla_cpu_global_jit") + os.environ["TF_XLA_FLAGS"] = ( + "--tf_xla_fusion_only=true " + "--tf_xla_cpu_global_jit " + os.environ.get("TF_XLA_FLAGS", "")) tf_op, tf_count = self.simpleTest(arg0, arg1, config_pb2.OptimizerOptions.OFF) self.assertEqual(0, tf_count) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index e8673d7790..922ae7c79a 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -26,8 +26,9 @@ namespace tensorflow { // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, - std::vector<bool>* compile_time_const_args, - std::vector<bool>* compile_time_const_nodes) { + std::vector<bool>* compile_time_const_arg_indices, + std::vector<bool>* compile_time_const_nodes, + std::function<bool(const Edge&)> edge_filter) { // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set<string> metadata_ops = { "Rank", @@ -45,8 +46,7 @@ Status BackwardsConstAnalysis(const Graph& g, } Status status; - auto visit = [&status, &metadata_ops, compile_time_const_nodes, - compile_time_const_args](Node* node) { + auto visit = [&](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. @@ -59,13 +59,13 @@ Status BackwardsConstAnalysis(const Graph& g, int index; status = GetNodeAttr(node->attrs(), "index", &index); if (!status.ok()) return; - if (compile_time_const_args) { - (*compile_time_const_args)[index] = true; + if (compile_time_const_arg_indices) { + (*compile_time_const_arg_indices)[index] = true; } return; } for (const Edge* pred : node->in_edges()) { - if (!pred->IsControlEdge()) { + if (!pred->IsControlEdge() && edge_filter(*pred)) { (*compile_time_const_nodes)[pred->src()->id()] = true; } } @@ -88,7 +88,8 @@ Status BackwardsConstAnalysis(const Graph& g, for (Edge const* edge : node->in_edges()) { if (edge->dst_input() >= name_range->second.first && - edge->dst_input() < name_range->second.second) { + edge->dst_input() < name_range->second.second && + edge_filter(*edge)) { (*compile_time_const_nodes)[edge->src()->id()] = true; } } @@ -97,7 +98,8 @@ Status BackwardsConstAnalysis(const Graph& g, // Post-order traversal visits nodes in reverse topological order for an // acyclic graph. - DFS(g, {}, visit); + DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{}, + [](const Edge& edge) { return !edge.src()->IsNextIteration(); }); return status; } diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index af57e5a403..49b3c6d413 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -32,9 +32,13 @@ namespace tensorflow { // // The ids of the nodes in `graph` that must be constant are returned in // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null. -Status BackwardsConstAnalysis(const Graph& graph, +// +// Only propagate const-ness along edges for which `edge_filter` returns true. +Status BackwardsConstAnalysis(const Graph& g, std::vector<bool>* compile_time_const_arg_indices, - std::vector<bool>* compile_time_const_nodes); + std::vector<bool>* compile_time_const_nodes, + std::function<bool(const Edge&)> edge_filter = + [](const Edge& e) { return true; }); } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index c32d6f84f5..34e6b5560b 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -35790,6 +35790,42 @@ op { } } op { + name: "NonMaxSuppressionV2" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } +} +op { name: "NonMaxSuppressionV3" input_arg { name: "boxes" @@ -35817,6 +35853,46 @@ op { } } op { + name: "NonMaxSuppressionV3" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + input_arg { + name: "score_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } +} +op { name: "NonMaxSuppressionV4" input_arg { name: "boxes" @@ -35855,6 +35931,57 @@ op { } } op { + name: "NonMaxSuppressionV4" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + input_arg { + name: "score_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + output_arg { + name: "valid_outputs" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { + name: "pad_to_max_output_size" + type: "bool" + default_value { + b: false + } + } +} +op { name: "NonMaxSuppressionWithOverlaps" input_arg { name: "overlaps" diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index aeb03c5952..c00c0030e6 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -17098,11 +17098,11 @@ op { name: "NonMaxSuppressionV2" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17116,16 +17116,29 @@ op { name: "selected_indices" type: DT_INT32 } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } } op { name: "NonMaxSuppressionV3" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17143,16 +17156,29 @@ op { name: "selected_indices" type: DT_INT32 } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } } op { name: "NonMaxSuppressionV4" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17175,6 +17201,19 @@ op { type: DT_INT32 } attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { name: "pad_to_max_output_size" type: "bool" default_value { |