diff options
author | 2018-06-14 09:33:17 -0700 | |
---|---|---|
committer | 2018-06-14 09:37:31 -0700 | |
commit | 3d5fa1f7f85e8cbd39227e921960fa36539ba3cd (patch) | |
tree | 59cd6fbfcf1a9da15fcff4e042ba4dc2d7c28e80 /tensorflow | |
parent | b22cfe55abc6700d9d9492be4316da4e74e3549d (diff) |
Disable removing pairs of transposes across chains, while debugging breakage in bayesflow.
PiperOrigin-RevId: 200568541
Diffstat (limited to 'tensorflow')
5 files changed, 23 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 20887bc218..1b18087cdf 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -210,8 +210,7 @@ cc_library( hdrs = ["graph_optimizer_stage.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", @@ -225,6 +224,7 @@ tf_cuda_cc_test( deps = [ ":graph_optimizer_stage", "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 51110b4bda..c41b152d21 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1084,8 +1084,11 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); NodeDef* tail = node; - tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, - *ctx().nodes_to_preserve); + // TODO(rmlarsen): Enable after debugging breakage in Bayesflow. + if (ctx().opt_level == RewriterConfig::AGGRESSIVE) { + tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, + *ctx().nodes_to_preserve); + } NodeDef* first_transpose; TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); @@ -2713,7 +2716,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { } const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_, - graph_properties_.get(), node_map_.get()); + graph_properties_.get(), node_map_.get(), + opt_level_); const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify); // Stop pipeline after first stage returning non-empty simplified tensor name. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index ff96cb6480..fe70c7db5c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -1510,7 +1510,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphDef output; - ArithmeticOptimizer optimizer; + ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE); EnableOnlyRemoveIdentityTranspose(&optimizer); OptimizeAndPrune(&optimizer, &item, &output); diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index 2fbdd76a77..2afb5df431 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace grappler { @@ -44,16 +45,19 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name); struct GraphOptimizerContext { GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve, GraphDef* optimized_graph, - GraphProperties* graph_properties, NodeMap* node_map) + GraphProperties* graph_properties, NodeMap* node_map, + RewriterConfig::Toggle opt_level) : nodes_to_preserve(nodes_to_preserve), optimized_graph(optimized_graph), graph_properties(graph_properties), - node_map(node_map) {} + node_map(node_map), + opt_level(opt_level) {} const std::unordered_set<string>* nodes_to_preserve; GraphDef* optimized_graph; GraphProperties* graph_properties; NodeMap* node_map; + RewriterConfig::Toggle opt_level; }; Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc index 3f5ab87a5a..34f28c7c27 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace grappler { @@ -59,7 +60,8 @@ TEST_F(GraphOptimizerStageTest, OptimizedNodeName) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ nullptr, /*graph_properties*/ nullptr, - /*node_name*/ nullptr); + /*node_name*/ nullptr, + /*opt_level*/ RewriterConfig::ON); FakeOptimizerStage stage("my_opt", "my_stg", ctx); const auto node = ParseNodeScopeAndName("a/b/c/Add"); @@ -94,7 +96,8 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, - /*node_name*/ &node_map); + /*node_name*/ &node_map, + /*opt_level*/ RewriterConfig::ON); FakeOptimizerStage stage("my_opt", "my_stg", ctx); NodeDef* add_node; @@ -133,7 +136,8 @@ TEST_F(GraphOptimizerStageTest, AddNodes) { GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr, /*optimized_graph*/ &item.graph, /*graph_properties*/ &properties, - /*node_name*/ &node_map); + /*node_name*/ &node_map, + /*opt_level*/ RewriterConfig::ON); FakeOptimizerStage stage("my_opt", "my_stg", ctx); NodeDef* add_node; |