aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Mingxing Tan <tanmingxing@google.com>2018-06-20 12:41:19 -0700
committerGravatar Mingxing Tan <tanmingxing@google.com>2018-06-20 12:41:19 -0700
commit8f1f0a8e4eaa5ae7593dc596b9b69a6cd88fa16a (patch)
tree7a341bfcec45046067e1753b2fe11e3f543ccfbe /tensorflow/python/grappler
parent39ea5a7044a16b868e38717b358c46d6e3191373 (diff)
parent4fdb7cc4f92e76a168810e9b420bf1b90eb544e9 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/layout_optimizer_test.py7
-rw-r--r--tensorflow/python/grappler/memory_optimizer_test.py6
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py3
3 files changed, 13 insertions, 3 deletions
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index af5d709f7e..7d07c77c79 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -158,6 +158,7 @@ def _get_config(layout_optimizer=True):
layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
# do not remove duplicated nodes
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF)
+ rewrite_options.min_graph_nodes = -1
graph_options = config_pb2.GraphOptions(
rewrite_options=rewrite_options, build_cost_model=1)
config = config_pb2.ConfigProto(graph_options=graph_options)
@@ -1443,7 +1444,8 @@ class LayoutOptimizerTest(test.TestCase):
def testGradient(self):
meta_graph = _simple_metagraph()
rewrite_options = rewriter_config_pb2.RewriterConfig(
- layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.ON,
+ min_graph_nodes=-1)
optimized_graph = tf_optimizer.OptimizeGraph(
rewrite_options, meta_graph, cluster=_get_cluster())
@@ -1457,7 +1459,8 @@ class LayoutOptimizerTest(test.TestCase):
def testDepthwise(self):
meta_graph = _simple_metagraph(depthwise=True)
rewrite_options = rewriter_config_pb2.RewriterConfig(
- layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.ON,
+ min_graph_nodes=-1)
optimized_graph = tf_optimizer.OptimizeGraph(
rewrite_options, meta_graph, cluster=_get_cluster())
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index 7ed4b128e4..b658edff2d 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -76,7 +76,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
disable_model_pruning=True,
meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE,
constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
- memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
+ memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL,
+ min_graph_nodes=-1)
graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
self.assertEqual(len(graph.node), graph_size + 2)
@@ -133,6 +134,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ min_graph_nodes=-1,
memory_optimization=rewriter_config_pb2.RewriterConfig.
RECOMPUTATION_HEURISTICS), original_metagraph)
self.assertGreater(
@@ -158,6 +160,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ min_graph_nodes=-1,
memory_optimization=rewriter_config_pb2.RewriterConfig.
RECOMPUTATION_HEURISTICS,
# Checks that name scope "gradients/" also match sub-scope.
@@ -297,6 +300,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
if 'Recomputed/' in node.name]))
rewritten_graph_def = tf_optimizer.OptimizeGraph(
rewriter_config_pb2.RewriterConfig(
+ min_graph_nodes=-1,
memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL),
metagraph)
self.assertEqual(
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 1c0f072dd3..5a9afe7257 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -47,6 +47,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
rewriter_config = rewriter_config_pb2.RewriterConfig()
rewriter_config.optimizers.append('constfold')
+ rewriter_config.min_graph_nodes = -1
graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
@@ -68,6 +69,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
# Optimize the graph.
mg = meta_graph.create_meta_graph_def(graph=g)
rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.min_graph_nodes = -1
optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
# Check that the nodes referenced in various collections have been preserved
@@ -109,6 +111,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
# Optimize the graph.
mg = meta_graph.create_meta_graph_def(graph=g)
rewriter_config = rewriter_config_pb2.RewriterConfig()
+ rewriter_config.min_graph_nodes = -1
optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
mg.graph_def.CopyFrom(optimized_graph)