diff options
author | Tong Shen <endlessroad@google.com> | 2018-09-12 16:33:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 16:37:13 -0700 |
commit | acc32e741935545d8e600a67361c388d14556538 (patch) | |
tree | e26d8490ef47fe41eac5d21c556a62a1ff7cbe5f /tensorflow/compiler/tf2xla | |
parent | 99c35081f054f8d111c1512a0acb4b76686c102a (diff) |
Generate "While" node instead of "XlaWhile" node.
PiperOrigin-RevId: 212725134
Diffstat (limited to 'tensorflow/compiler/tf2xla')
5 files changed, 27 insertions, 46 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index e29a4c0603..d549e7bb59 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -560,6 +560,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", + "//tensorflow/cc:functional_ops", "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index ca64f3f226..db256e577a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -1285,13 +1285,6 @@ Status FunctionalizeCond::FunctionalizeInternal() { std::vector<int> switch_ids; std::vector<Node*> merge_order; DFS(*graph_, nullptr, [&](Node* n) { - // Nodes marked with _xla_outside_compilation are skipped, because they need - // to be executed on host with regular TF executor, which does not support - // XlaIf/XlaWhile. - if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) { - return; - } - if (IsSwitch(n)) { switch_ids.push_back(n->id()); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index f1cbcdf617..ba99205640 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -35,11 +35,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, // This pass looks at the graph and all associated FunctionDefs, and turns // traditional control flow structure (Switch/Merge/etc.) into functional -// control flow structure (XlaIf/XlaWhile). -// -// Notice that control flow structure marked with _xla_outside_compilation are -// skipped, because they need to be executed on host with regular TF executor, -// which does not support XlaIf/XlaWhile. +// control flow structure (If/While). class FunctionalizeControlFlowPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index c068a4110c..c3841f996f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/functional_ops.h" #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" @@ -112,16 +113,12 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, - std::initializer_list<Input>{less, y, x}, then_fn, - else_fn, {DT_INT32}); + auto if_op = ops::If(scope.WithOpName(op_name), less, + std::initializer_list<Input>{less, y, x}, {DT_INT32}, + then_fn, else_fn); auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); - // TODO(jpienaar): Create wrapper for IfOp. - for (NodeDef& n : *expected.mutable_node()) { - if (n.op() == "XlaIf") n.set_op("If"); - } TF_EXPECT_GRAPH_EQ(expected, graph_def); } @@ -177,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) { Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, NameAttrList* body) { for (const NodeDef& node : graph.node()) { - if (node.op() == "XlaWhile") { + if (node.op() == "While") { const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); *cond = *result; @@ -186,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, return Status::OK(); } } - return errors::NotFound("No XlaWhile node found in graph"); + return errors::NotFound("No While node found in graph"); } // Graph: @@ -255,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list<Input>{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list<Input>{source}, cond_fn, body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -392,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list<Input>{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list<Input>{source}, cond_fn, body_fn); GraphDef expected; TF_ASSERT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -483,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { Scope scope = Scope::NewRootScope().ExitOnError(); auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list<Input>{source}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list<Input>{source}, cond_fn, body_fn); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); TF_EXPECT_GRAPH_EQ(expected, graph_def); @@ -625,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) { auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); auto while_op = - ops::XlaWhile(scope.WithOpName("while/LoopCond"), - std::initializer_list<Input>{x, y}, cond_fn, body_fn); + ops::While(scope.WithOpName("while/LoopCond"), + std::initializer_list<Input>{x, y}, cond_fn, body_fn); auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); GraphDef expected; @@ -864,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0); - auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), - std::initializer_list<Input>{zero, y, x, var}, - outer_cond_fn, outer_body_fn); + auto while_op = ops::While(scope.WithOpName("outer/LoopCond"), + std::initializer_list<Input>{zero, y, x, var}, + outer_cond_fn, outer_body_fn); auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); GraphDef expected; TF_EXPECT_OK(scope.ToGraphDef(&expected)); @@ -921,9 +918,9 @@ TEST(FunctionalizeControlFlow, Complex) { auto one_j = ops::Const<int32>( scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); auto while_op = - ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), - std::initializer_list<Input>{one_j, arg1, arg2, arg3}, - inner_cond_fn, inner_body_fn); + ops::While(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list<Input>{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); auto one_outer = ops::Const<int32>( scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 2173e15e03..7c3ad448ef 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -519,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, // Builds a While operator. NodeDef while_def; - NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library); + NodeDefBuilder builder(frame->loop_cond->name(), "While", library); builder.Attr("T", arg_types); builder.Attr("cond", cond_name); builder.Attr("body", body_name); @@ -650,14 +650,8 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, continue; } - // Nodes marked with _xla_outside_compilation are skipped, because they need - // to be executed on host with regular TF executor, which does not support - // XlaIf/XlaWhile. - string name; - if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) { - TF_RETURN_IF_ERROR( - FunctionalizeLoop(lookup_library, graph, frame, library)); - } + TF_RETURN_IF_ERROR( + FunctionalizeLoop(lookup_library, graph, frame, library)); // If the parent has no remaining children, add it to the worklist. --frame->parent->num_children; @@ -668,9 +662,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library, // There should be no cycle at this point, since while loops have been removed // from graph. - // Check that the newly added XlaWhile nodes don't feed into themselves. + // Check that the newly added While nodes don't feed into themselves. for (const Node* node : graph->op_nodes()) { - if (node->def().op() == "XlaWhile") { + if (node->def().op() == "While") { TF_RETURN_WITH_CONTEXT_IF_ERROR( CheckNodeNotInCycle(node, graph->num_node_ids()), "Functionalizing loop failed."); |