aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 09:33:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 09:37:31 -0700
commit3d5fa1f7f85e8cbd39227e921960fa36539ba3cd (patch)
tree59cd6fbfcf1a9da15fcff4e042ba4dc2d7c28e80 /tensorflow
parentb22cfe55abc6700d9d9492be4316da4e74e3549d (diff)
Disable removing pairs of transposes across chains, while debugging breakage in bayesflow.
PiperOrigin-RevId: 200568541
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD4
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer_stage.h8
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc10
5 files changed, 23 insertions, 11 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 20887bc218..1b18087cdf 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -210,8 +210,7 @@ cc_library(
hdrs = ["graph_optimizer_stage.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
@@ -225,6 +224,7 @@ tf_cuda_cc_test(
deps = [
":graph_optimizer_stage",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 51110b4bda..c41b152d21 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1084,8 +1084,11 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
+ // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
+ if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
+ }
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
@@ -2713,7 +2716,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
}
const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- graph_properties_.get(), node_map_.get());
+ graph_properties_.get(), node_map_.get(),
+ opt_level_);
const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
// Stop pipeline after first stage returning non-empty simplified tensor name.
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index ff96cb6480..fe70c7db5c 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -1510,7 +1510,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdentityTransposesThroughChain) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphDef output;
- ArithmeticOptimizer optimizer;
+ ArithmeticOptimizer optimizer(RewriterConfig::AGGRESSIVE);
EnableOnlyRemoveIdentityTranspose(&optimizer);
OptimizeAndPrune(&optimizer, &item, &output);
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
index 2fbdd76a77..2afb5df431 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -44,16 +45,19 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name);
struct GraphOptimizerContext {
GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
GraphDef* optimized_graph,
- GraphProperties* graph_properties, NodeMap* node_map)
+ GraphProperties* graph_properties, NodeMap* node_map,
+ RewriterConfig::Toggle opt_level)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
graph_properties(graph_properties),
- node_map(node_map) {}
+ node_map(node_map),
+ opt_level(opt_level) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
GraphProperties* graph_properties;
NodeMap* node_map;
+ RewriterConfig::Toggle opt_level;
};
Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
index 3f5ab87a5a..34f28c7c27 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
@@ -59,7 +60,8 @@ TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ nullptr,
/*graph_properties*/ nullptr,
- /*node_name*/ nullptr);
+ /*node_name*/ nullptr,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
const auto node = ParseNodeScopeAndName("a/b/c/Add");
@@ -94,7 +96,8 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map);
+ /*node_name*/ &node_map,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;
@@ -133,7 +136,8 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
/*optimized_graph*/ &item.graph,
/*graph_properties*/ &properties,
- /*node_name*/ &node_map);
+ /*node_name*/ &node_map,
+ /*opt_level*/ RewriterConfig::ON);
FakeOptimizerStage stage("my_opt", "my_stg", ctx);
NodeDef* add_node;