diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-02 12:58:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-02 13:02:11 -0800 |
commit | e0fac18b63e80963d42cb1e39243d84ae86ae01a (patch) | |
tree | 049a596b5fb4a3dab53a7bfb11a8cbb248b93bbc /tensorflow/python/grappler | |
parent | 85daa2e4553e49ca6ab2fbb412b18c23b5399524 (diff) |
Automated g4 rollback of changelist 187582263
PiperOrigin-RevId: 187657654
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/memory_optimizer_test.py | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index 948911f099..4df959ce04 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -162,7 +162,8 @@ class MemoryOptimizerRecomputeTest(test.TestCase): arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, memory_optimization=rewriter_config_pb2.RewriterConfig. RECOMPUTATION_HEURISTICS, - memory_optimizer_target_node_name_prefix='optimizer/gradients/'), + # Checks that name scope "gradients/" also match sub-scope. + memory_optimizer_target_node_name_scope='gradients/'), original_metagraph) self.assertGreater( len(rewritten_graph_def.node), @@ -176,6 +177,35 @@ class MemoryOptimizerRecomputeTest(test.TestCase): len([node for node in rewritten_graph_def.node if 'Recomputed/' in node.name])) + def testRewritingNameScopedGradientNamesScope(self): + """Tests that rewriting occurs with non-standard gradient names.""" + (original_metagraph, _, _, + _) = self._GetMetaGraph(optimizer_scope_name='foo/bar') + rewritten_graph_def = tf_optimizer.OptimizeGraph( + rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + constant_folding=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + layout_optimizer=rewriter_config_pb2.RewriterConfig.OFF, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + memory_optimization=rewriter_config_pb2.RewriterConfig. + RECOMPUTATION_HEURISTICS, + # This should not match anything. + memory_optimizer_target_node_name_scope='r/gradients/'), + original_metagraph) + self.assertEqual( + len(rewritten_graph_def.node), len(original_metagraph.graph_def.node)) + self.assertEqual(0, + len([ + node for node in original_metagraph.graph_def.node + if 'Recomputed/' in node.name + ])) + self.assertEqual(0, + len([ + node for node in rewritten_graph_def.node + if 'Recomputed/' in node.name + ])) + def _GetMemoryOptimizerSessionConfig(self): rewrite_options = rewriter_config_pb2.RewriterConfig( disable_model_pruning=True, |