author: A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-02 12:58:08 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>  2018-03-02 13:02:11 -0800
commit: e0fac18b63e80963d42cb1e39243d84ae86ae01a (patch)
tree: 049a596b5fb4a3dab53a7bfb11a8cbb248b93bbc /tensorflow/python/grappler
parent: 85daa2e4553e49ca6ab2fbb412b18c23b5399524 (diff)
Automated g4 rollback of changelist 187582263
PiperOrigin-RevId: 187657654
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--  tensorflow/python/grappler/memory_optimizer_test.py | 32
1 file changed, 31 insertions(+), 1 deletion(-)
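
For orientation, here is a minimal sketch of how the renamed memory_optimizer_target_node_name_scope option is handed to the Grappler optimizer, mirroring the RewriterConfig usage in the test below. The meta-graph export via tf.train.export_meta_graph() is an illustrative assumption; the test itself builds its meta-graph with an internal _GetMetaGraph helper.

# Illustrative sketch only; mirrors the configuration exercised in the diff below.
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.grappler import tf_optimizer

# Assumption: a graph whose gradient ops live under a "gradients/" name scope
# has already been built in the default graph.
meta_graph = tf.train.export_meta_graph()

config = rewriter_config_pb2.RewriterConfig(
    disable_model_pruning=True,
    memory_optimization=rewriter_config_pb2.RewriterConfig.RECOMPUTATION_HEURISTICS,
    # Renamed by this change: target nodes are now selected by name scope
    # rather than by a raw name prefix.
    memory_optimizer_target_node_name_scope='gradients/')

rewritten_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)
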
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index 948911f099..4df959ce04 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -162,7 +162,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
memory_optimization=rewriter_config_pb2.RewriterConfig.
RECOMPUTATION_HEURISTICS,
- memory_optimizer_target_node_name_prefix='optimizer/gradients/'),
+ # Checks that the name scope "gradients/" also matches sub-scopes.
+ memory_optimizer_target_node_name_scope='gradients/'),
original_metagraph)
self.assertGreater(
len(rewritten_graph_def.node),
@@ -176,6 +177,35 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
len([node for node in rewritten_graph_def.node
if 'Recomputed/' in node.name]))
+ def testRewritingNameScopedGradientNamesScope(self):
+ """Tests that rewriting does not occur with a non-matching name scope."""
+ (original_metagraph, _, _,
+ _) = self._GetMetaGraph(optimizer_scope_name='foo/bar')
+ rewritten_graph_def = tf_optimizer.OptimizeGraph(
+ rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True,
+ constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
+ dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF,
+ arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
+ memory_optimization=rewriter_config_pb2.RewriterConfig.
+ RECOMPUTATION_HEURISTICS,
+ # This should not match anything.
+ memory_optimizer_target_node_name_scope='r/gradients/'),
+ original_metagraph)
+ self.assertEqual(
+ len(rewritten_graph_def.node), len(original_metagraph.graph_def.node))
+ self.assertEqual(0,
+ len([
+ node for node in original_metagraph.graph_def.node
+ if 'Recomputed/' in node.name
+ ]))
+ self.assertEqual(0,
+ len([
+ node for node in rewritten_graph_def.node
+ if 'Recomputed/' in node.name
+ ]))
+
def _GetMemoryOptimizerSessionConfig(self):
rewrite_options = rewriter_config_pb2.RewriterConfig(
disable_model_pruning=True,