aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/functionalize_while.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/functionalize_while.cc')
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_while.cc17
1 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc
index 7f45e3bffa..7c3ad448ef 100644
--- a/tensorflow/compiler/tf2xla/functionalize_while.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_while.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/functionalize_cond.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -34,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
@@ -473,12 +475,19 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
}
}
- // Builds the condition and body functions.
+ // Builds the condition and body functions. Notice that we call
+ // FunctionalizeCond() on cond_graph and body_graph because we might have
+ // unfunctionalized "if" in cond_graph and body_graph. Functionalize them
+ // before they are encapsulated in FunctionDef.
std::unique_ptr<Graph> cond_graph;
TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+ FixupSourceAndSinkEdges(cond_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(cond_graph.get(), library));
DataTypeVector arg_types;
std::unique_ptr<Graph> body_graph;
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+ FixupSourceAndSinkEdges(body_graph.get());
+ TF_RETURN_IF_ERROR(FunctionalizeCond(body_graph.get(), library));
VLOG(2) << "Frame " << frame->name << " condition: "
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
@@ -510,7 +519,7 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library,
// Builds a While operator.
NodeDef while_def;
- NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+ NodeDefBuilder builder(frame->loop_cond->name(), "While", library);
builder.Attr("T", arg_types);
builder.Attr("cond", cond_name);
builder.Attr("body", body_name);
@@ -653,9 +662,9 @@ Status FunctionalizeWhileLoop(const FunctionLibraryDefinition* lookup_library,
// There should be no cycle at this point, since while loops have been removed
// from graph.
- // Check that the newly added XlaWhile nodes don't feed into themselves.
+ // Check that the newly added While nodes don't feed into themselves.
for (const Node* node : graph->op_nodes()) {
- if (node->def().op() == "XlaWhile") {
+ if (node->def().op() == "While") {
TF_RETURN_WITH_CONTEXT_IF_ERROR(
CheckNodeNotInCycle(node, graph->num_node_ids()),
"Functionalizing loop failed.");