diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-06-26 10:34:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-26 10:41:28 -0700 |
commit | 1fa8b4247a15079cb60da74a3c86ed5eef7f1b93 (patch) | |
tree | b550bb00e828062fcfdfe11fded165325970a079 | |
parent | bdc928cdf2d5e9b6846ecfdf4b3767f275f4240d (diff) |
Properly handle RefEnter, RefExit and RefNextIteration nodes.
PiperOrigin-RevId: 160162338
-rw-r--r-- | tensorflow/core/grappler/op_types.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/grappler/op_types.h | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/memory_optimizer.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/topological_sort.cc | 3 |
7 files changed, 16 insertions, 5 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 51146011b0..700a66d3a6 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -50,6 +50,11 @@ bool IsNoOp(const NodeDef& node) { return op == "NoOp"; } +bool IsNextIteration(const NodeDef& node) { + const auto& op = node.op(); + return op == "NextIteration" || op == "RefNextIteration"; +} + bool IsPlaceholder(const NodeDef& node) { const auto op = node.op(); return op == "Placeholder" || op == "PlaceholderV2" || diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index b2102c688d..cba44e905e 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -26,6 +26,7 @@ bool IsConstant(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsIdentity(const NodeDef& node); bool IsMerge(const NodeDef& node); +bool IsNextIteration(const NodeDef& node); bool IsNoOp(const NodeDef& node); bool IsPlaceholder(const NodeDef& node); bool IsRecv(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 3a705b85f7..5a7f11e149 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -197,6 +197,7 @@ cc_library( ":static_schedule", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:topological_sort", diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 9f6e3aee71..2b6d9859b3 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -101,8 +101,9 @@ string AsControlDependency(const NodeDef& node) { ConstantFolding::ConstantFolding() { ops_to_preserve_ = std::regex( - "Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|Enter|Exit|" - "NextIteration|.*Quantized.*"); + "Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|" + "Enter|RefEnter|Exit|RefExit|NextIteration|RefNextIteration|" + ".*Quantized.*"); } string ConstantFolding::AddControlDependency(const string& input_name) { diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index eb3692f319..16a638a7d3 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/graph_rewriter.h" #include "tensorflow/core/grappler/optimizers/static_schedule.h" #include "tensorflow/core/grappler/utils.h" @@ -556,8 +557,8 @@ static const NodeDef* FindSwapTrigger( // Don't jump over frames, since adding a control dependency from one frame // to the next isn't supported. Don't go through branches, since we don't // know whether they'll be executed or not. - if (input_node->op() == "NextIteration" || input_node->op() == "Switch" || - input_node->op() == "Merge") { + if (IsNextIteration(*input_node) || IsSwitch(*input_node) || + IsMerge(*input_node)) { continue; } auto it2 = execution_times.find(input_node); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 8839f07bc5..fd3894553b 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -48,6 +48,7 @@ cc_library( deps = [ "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", ], ) diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index fdf7fb0d3d..9c5d27f3c5 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -17,6 +17,7 @@ limitations under the License. #include <deque> #include <unordered_map> #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" namespace tensorflow { @@ -35,7 +36,7 @@ void TopologicalSort(GraphDef* graph) { if (node.op() == "Merge") { ready_inputs[&node] = 0; for (const auto& input : node.input()) { - if (node_map.GetNode(input)->op() == "NextIteration") { + if (IsNextIteration(*node_map.GetNode(input))) { ready_inputs[&node]++; } } |