aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-10-01 13:40:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 13:44:35 -0700
commitc86f5941359526b91d85daf844e94ff5d39b2d6c (patch)
treeaf0e32582187d30a58e1da7c6e18f01ccb701c36 /tensorflow/core
parent1630584951975479dee852cf6f7603fe6819fde1 (diff)
Make cond_v2 If op lowering work in a defun + eager.
Prior to this change, the lowering pass assumed that the If op functions would be available in the If op's graph. If the If op is defined in a defun and then called via eager execution, the functions will be in the eager context, but not in the defun's graph. This change makes the lowering pass correctly use the function library passed in by the caller via GraphOptimizationPassOptions. PiperOrigin-RevId: 215271990
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc43
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.h5
-rw-r--r--tensorflow/core/common_runtime/lower_if_op_test.cc4
3 files changed, 32 insertions, 20 deletions
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index dfce7c23e7..a02084f223 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -38,11 +38,12 @@ class CondBuilder {
public:
enum Branch { kElseBranch = 0, kThenBranch = 1 };
- // Create a CondBuilder to create the lowering of If op. that has then and
+ // Create a CondBuilder to create the lowered form of `if_op` with then and
// else functions named `then_fn_name` and `else_fn_name` respectively in the
- // given graph.
+ // `graph`. The functions should be available in `flib`.
CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph);
+ const string& else_fn_name, const FunctionLibraryDefinition& flib,
+ Graph* graph);
// Constructs the basic conditional control flow using switch and merge nodes.
Status CreatePivotNodes();
@@ -89,6 +90,7 @@ class CondBuilder {
Node* then_call_node_;
Node* else_call_node_;
Graph* graph_;
+ const FunctionLibraryDefinition& flib_;
string name_;
NodeBuilder then_call_builder_;
@@ -96,9 +98,11 @@ class CondBuilder {
};
CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph)
+ const string& else_fn_name,
+ const FunctionLibraryDefinition& flib, Graph* graph)
: if_op_(if_op),
graph_(graph),
+ flib_(flib),
name_(if_op->name()),
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
@@ -193,15 +197,15 @@ Status CondBuilder::AddOutputs() {
return Status::OK();
}
-Status InlineCallInGraph(Node* n, Graph* g) {
- const auto& lib = g->flib_def();
- const FunctionDef* fdef = lib.Find(n->type_string());
+Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
+ Graph* g) {
+ const FunctionDef* fdef = flib.Find(n->type_string());
CHECK(fdef != nullptr);
FunctionBody* fbody;
TF_RETURN_IF_ERROR(
- FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
- [&lib](const string& op, const OpDef** sig) {
- return lib.LookUpOpDef(op, sig);
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &flib,
+ [&flib](const string& op, const OpDef** sig) {
+ return flib.LookUpOpDef(op, sig);
},
&fbody));
// TODO(jpienaar): Improve this interface to make the need to delete it
@@ -219,8 +223,8 @@ Status CondBuilder::BuildLoweredIfOutput() {
}
Status CondBuilder::InlineCallNodes() {
- TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_));
- TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, flib_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, flib_, graph_));
return Status::OK();
}
@@ -240,6 +244,12 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
return errors::Internal("Lowering If op requires a graph to be available.");
}
+ FunctionLibraryDefinition* flib = options.flib_def;
+ if (flib == nullptr) {
+ return errors::Internal(
+ "Lowering If op requires a FunctionLibraryDefinition to be available.");
+ }
+
// Match all the nodes that need to be rewritten.
gtl::InlinedVector<Node*, 2> matches;
for (Node* n : g->op_nodes()) {
@@ -251,12 +261,14 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
}
}
for (Node* n : matches) {
- TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ TF_RETURN_IF_ERROR(RewriteNode(n, *flib, g));
}
return Status::OK();
}
-Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
+Status LowerIfOpPass::RewriteNode(Node* n,
+ const FunctionLibraryDefinition& flib,
+ Graph* g) {
const AttrValue* then_attr = n->attrs().Find("then_branch");
if (then_attr == nullptr) {
return errors::InvalidArgument("Then branch function missing");
@@ -266,7 +278,8 @@ Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
return errors::InvalidArgument("Else branch function missing");
}
- CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g);
+ CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib,
+ g);
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
TF_RETURN_IF_ERROR(cb.AddInputs());
TF_RETURN_IF_ERROR(cb.AddOutputs());
diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h
index a9ef39ae5c..5ab1123e3f 100644
--- a/tensorflow/core/common_runtime/lower_if_op.h
+++ b/tensorflow/core/common_runtime/lower_if_op.h
@@ -29,8 +29,9 @@ class LowerIfOpPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
private:
- // Rewrite the given If node `n` in graph `g` to use the switch-merge form.
- Status RewriteNode(Node* n, Graph* g);
+ // Rewrite the given If node `n` in graph `g` to use the switch-merge
+ // form. `flib` should contain the branch functions referenced by `n`.
+ Status RewriteNode(Node* n, const FunctionLibraryDefinition& flib, Graph* g);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index 319a617b32..044a355d06 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -36,9 +36,7 @@ namespace tensorflow {
namespace {
Status Rewrite(std::unique_ptr<Graph>* graph) {
- FunctionDefLibrary flib;
- FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
-
+ FunctionLibraryDefinition flib_def((*graph)->flib_def());
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.flib_def = &flib_def;