diff options
author | Benoit Steiner <bsteiner@google.com> | 2018-04-03 17:52:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-03 17:55:29 -0700 |
commit | bde05fdf247ea6311414677f55260f7e8085718f (patch) | |
tree | abfcb745d793fa4bd34605bcbd7ac55a3e0539cd /tensorflow/python/grappler | |
parent | 5b652f57709d30d883570a82ac500051d8bfe1e6 (diff) |
Fix a shape inference bug.
PiperOrigin-RevId: 191528009
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/tf_optimizer_test.py | 47 |
1 files changed, 46 insertions, 1 deletions
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 3ee4d7807e..1c0f072dd3 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -17,12 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.grappler import item as gitem from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -74,6 +78,47 @@ class PyWrapOptimizeGraphTest(test.TestCase): self.assertEqual(a2.op.name, optimized_graph.node[3].name) self.assertEqual('Variable/Assign', optimized_graph.node[4].name) + def testLoops(self): + g = ops.Graph() + with g.as_default(): + + def _Cond(_, counter): + return counter < end + + def _Body(buf, counter): + buf = array_ops.concat([buf, [counter]], 0) + counter += 1 + return [buf, counter] + + start = array_ops.placeholder(shape=[], dtype=dtypes.int32) + end = array_ops.placeholder(shape=[], dtype=dtypes.int32) + init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32) + loop_vars = [init_buf, start] + shape_inv = [ + tensor_shape.TensorShape([None]), + tensor_shape.TensorShape([]) + ] + buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv) + + f = -array_ops.ones_like(buf, optimize=False) + buf_shape = array_ops.shape(buf) + f_shape = array_ops.shape(f) + ops.add_to_collection('train_op', buf_shape) + ops.add_to_collection('train_op', f_shape) + + # Optimize the graph. + mg = meta_graph.create_meta_graph_def(graph=g) + rewriter_config = rewriter_config_pb2.RewriterConfig() + optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) + mg.graph_def.CopyFrom(optimized_graph) + + # Check that the nodes referenced in various collections have been preserved + item = gitem.Item(mg) + props = item.GetOpProperties() + buf_prop = props[buf.op.name] + f_prop = props[f.op.name] + self.assertEqual(buf_prop, f_prop) + if __name__ == '__main__': test.main() |