aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-09-12 16:33:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 16:37:13 -0700
commitacc32e741935545d8e600a67361c388d14556538 (patch)
treee26d8490ef47fe41eac5d21c556a62a1ff7cbe5f /tensorflow/compiler/tf2xla
parent99c35081f054f8d111c1512a0acb4b76686c102a (diff)
Generate "While" node instead of "XlaWhile" node.
PiperOrigin-RevId: 212725134
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_cond.cc7
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.h6
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc43
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.cc16
5 files changed, 27 insertions, 46 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index e29a4c0603..d549e7bb59 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -560,6 +560,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
+ "//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/cc:xla_ops",
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index ca64f3f226..db256e577a 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -1285,13 +1285,6 @@ Status FunctionalizeCond::FunctionalizeInternal() {
std::vector<int> switch_ids;
std::vector<Node*> merge_order;
DFS(*graph_, nullptr, [&](Node* n) {
- // Nodes marked with _xla_outside_compilation are skipped, because they need
- // to be executed on host with regular TF executor, which does not support
- // XlaIf/XlaWhile.
- if (HasNodeAttr(n->def(), kXlaOutsideCompilationAttrName)) {
- return;
- }
-
if (IsSwitch(n)) {
switch_ids.push_back(n->id());
}
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
index f1cbcdf617..ba99205640 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h
@@ -35,11 +35,7 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library,
// This pass looks at the graph and all associated FunctionDefs, and turns
// traditional control flow structure (Switch/Merge/etc.) into functional
-// control flow structure (XlaIf/XlaWhile).
-//
-// Notice that control flow structure marked with _xla_outside_compilation are
-// skipped, because they need to be executed on host with regular TF executor,
-// which does not support XlaIf/XlaWhile.
+// control flow structure (If/While).
class FunctionalizeControlFlowPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override;
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
index c068a4110c..c3841f996f 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h"
@@ -112,16 +113,12 @@ TEST(FunctionalizeControlFlow, Conditional) {
auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
- auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
- std::initializer_list<Input>{less, y, x}, then_fn,
- else_fn, {DT_INT32});
+ auto if_op = ops::If(scope.WithOpName(op_name), less,
+ std::initializer_list<Input>{less, y, x}, {DT_INT32},
+ then_fn, else_fn);
auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
- // TODO(jpienaar): Create wrapper for IfOp.
- for (NodeDef& n : *expected.mutable_node()) {
- if (n.op() == "XlaIf") n.set_op("If");
- }
TF_EXPECT_GRAPH_EQ(expected, graph_def);
}
@@ -177,7 +174,7 @@ TEST(FunctionalizeControlFlow, Conditional) {
Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
NameAttrList* body) {
for (const NodeDef& node : graph.node()) {
- if (node.op() == "XlaWhile") {
+ if (node.op() == "While") {
const NameAttrList* result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result));
*cond = *result;
@@ -186,7 +183,7 @@ Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond,
return Status::OK();
}
}
- return errors::NotFound("No XlaWhile node found in graph");
+ return errors::NotFound("No While node found in graph");
}
// Graph:
@@ -255,8 +252,8 @@ TEST(FunctionalizeControlFlow, OneLoopVar) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -392,8 +389,8 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
GraphDef expected;
TF_ASSERT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -483,8 +480,8 @@ TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{source}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{source}, cond_fn, body_fn);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
TF_EXPECT_GRAPH_EQ(expected, graph_def);
@@ -625,8 +622,8 @@ TEST(FunctionalizeControlFlow, TwoLoopVars) {
auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32);
auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32);
auto while_op =
- ops::XlaWhile(scope.WithOpName("while/LoopCond"),
- std::initializer_list<Input>{x, y}, cond_fn, body_fn);
+ ops::While(scope.WithOpName("while/LoopCond"),
+ std::initializer_list<Input>{x, y}, cond_fn, body_fn);
auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]);
auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]);
GraphDef expected;
@@ -864,9 +861,9 @@ TEST(FunctionalizeControlFlow, Complex) {
auto zero = ops::Const<int32>(scope.WithOpName("outer/Const"), 0);
- auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"),
- std::initializer_list<Input>{zero, y, x, var},
- outer_cond_fn, outer_body_fn);
+ auto while_op = ops::While(scope.WithOpName("outer/LoopCond"),
+ std::initializer_list<Input>{zero, y, x, var},
+ outer_cond_fn, outer_body_fn);
auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]);
GraphDef expected;
TF_EXPECT_OK(scope.ToGraphDef(&expected));
@@ -921,9 +918,9 @@ TEST(FunctionalizeControlFlow, Complex) {
auto one_j = ops::Const<int32>(
scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1);
auto while_op =
- ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"),
- std::initializer_list<Input>{one_j, arg1, arg2, arg3},
- inner_cond_fn, inner_body_fn);
+ ops::While(scope.WithOpName("outer/LoopCond_1"),
+ std::initializer_list<Input>{one_j, arg1, arg2, arg3},
+ inner_cond_fn, inner_body_fn);
auto one_outer = ops::Const<int32>(
scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1);
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 2173e15e03..7c3ad448ef 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -519,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// Builds a While operator.
NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile", library);
+ NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
@@ -650,14 +650,8 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
continue;
}
- // Nodes marked with _xla_outside_compilation are skipped, because they need
- // to be executed on host with regular TF executor, which does not support
- // XlaIf/XlaWhile.
- string name;
- if (!HasNodeAttr(frame->loop_cond->def(), kXlaOutsideCompilationAttrName)) {
- TF_RETURN_IF_ERROR(
- FunctionalizeLoop(lookup_library, graph, frame, library));
- }
+ TF_RETURN_IF_ERROR(
+ FunctionalizeLoop(lookup_library, graph, frame, library));
// If the parent has no remaining children, add it to the worklist.
--frame->parent->num_children;
@@ -668,9 +662,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
// There should be no cycle at this point, since while loops have been removed
// from graph.
- // Check that the newly added XlaWhile nodes don't feed into themselves.
+ // Check that the newly added While nodes don't feed into themselves.
for (const Node* node : graph->op_nodes()) {
- if (node->def().op() == "XlaWhile") {
+ if (node->def().op() == "While") {
TF_RETURN_WITH_CONTEXT_IF_ERROR(
CheckNodeNotInCycle(node, graph->num_node_ids()),
"Functionalizing loop failed.");