aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-06-26 10:34:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-26 10:41:28 -0700
commit1fa8b4247a15079cb60da74a3c86ed5eef7f1b93 (patch)
treeb550bb00e828062fcfdfe11fded165325970a079
parentbdc928cdf2d5e9b6846ecfdf4b3767f275f4240d (diff)
Properly handle RefEnter, RefExit and RefNextIteration nodes.
PiperOrigin-RevId: 160162338
-rw-r--r--tensorflow/core/grappler/op_types.cc5
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc5
-rw-r--r--tensorflow/core/grappler/utils/BUILD1
-rw-r--r--tensorflow/core/grappler/utils/topological_sort.cc3
7 files changed, 16 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 51146011b0..700a66d3a6 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -50,6 +50,11 @@ bool IsNoOp(const NodeDef& node) {
return op == "NoOp";
}
+bool IsNextIteration(const NodeDef& node) {
+ const auto& op = node.op();
+ return op == "NextIteration" || op == "RefNextIteration";
+}
+
bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op();
return op == "Placeholder" || op == "PlaceholderV2" ||
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index b2102c688d..cba44e905e 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -26,6 +26,7 @@ bool IsConstant(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsIdentity(const NodeDef& node);
bool IsMerge(const NodeDef& node);
+bool IsNextIteration(const NodeDef& node);
bool IsNoOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
bool IsRecv(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 3a705b85f7..5a7f11e149 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -197,6 +197,7 @@ cc_library(
":static_schedule",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 9f6e3aee71..2b6d9859b3 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -101,8 +101,9 @@ string AsControlDependency(const NodeDef& node) {
ConstantFolding::ConstantFolding() {
ops_to_preserve_ = std::regex(
- "Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|Enter|Exit|"
- "NextIteration|.*Quantized.*");
+ "Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|"
+ "Enter|RefEnter|Exit|RefExit|NextIteration|RefNextIteration|"
+ ".*Quantized.*");
}
string ConstantFolding::AddControlDependency(const string& input_name) {
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index eb3692f319..16a638a7d3 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include "tensorflow/core/grappler/utils.h"
@@ -556,8 +557,8 @@ static const NodeDef* FindSwapTrigger(
// Don't jump over frames, since adding a control dependency from one frame
// to the next isn't supported. Don't go through branches, since we don't
// know whether they'll be executed or not.
- if (input_node->op() == "NextIteration" || input_node->op() == "Switch" ||
- input_node->op() == "Merge") {
+ if (IsNextIteration(*input_node) || IsSwitch(*input_node) ||
+ IsMerge(*input_node)) {
continue;
}
auto it2 = execution_times.find(input_node);
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 8839f07bc5..fd3894553b 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -48,6 +48,7 @@ cc_library(
deps = [
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
],
)
diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc
index fdf7fb0d3d..9c5d27f3c5 100644
--- a/tensorflow/core/grappler/utils/topological_sort.cc
+++ b/tensorflow/core/grappler/utils/topological_sort.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <deque>
#include <unordered_map>
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
@@ -35,7 +36,7 @@ void TopologicalSort(GraphDef* graph) {
if (node.op() == "Merge") {
ready_inputs[&node] = 0;
for (const auto& input : node.input()) {
- if (node_map.GetNode(input)->op() == "NextIteration") {
+ if (IsNextIteration(*node_map.GetNode(input))) {
ready_inputs[&node]++;
}
}