aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-04-03 17:52:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 17:55:29 -0700
commitbde05fdf247ea6311414677f55260f7e8085718f (patch)
treeabfcb745d793fa4bd34605bcbd7ac55a3e0539cd /tensorflow/python/grappler
parent5b652f57709d30d883570a82ac500051d8bfe1e6 (diff)
Fix a shape inference bug.
PiperOrigin-RevId: 191528009
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py47
1 files changed, 46 insertions, 1 deletions
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 3ee4d7807e..1c0f072dd3 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -17,12 +17,16 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.grappler import item as gitem
from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -74,6 +78,47 @@ class PyWrapOptimizeGraphTest(test.TestCase):
self.assertEqual(a2.op.name, optimized_graph.node[3].name)
self.assertEqual('Variable/Assign', optimized_graph.node[4].name)
+ def testLoops(self):
+ g = ops.Graph()
+ with g.as_default():
+
+ def _Cond(_, counter):
+ return counter < end
+
+ def _Body(buf, counter):
+ buf = array_ops.concat([buf, [counter]], 0)
+ counter += 1
+ return [buf, counter]
+
+ start = array_ops.placeholder(shape=[], dtype=dtypes.int32)
+ end = array_ops.placeholder(shape=[], dtype=dtypes.int32)
+ init_buf = array_ops.zeros(shape=[0], dtype=dtypes.int32)
+ loop_vars = [init_buf, start]
+ shape_inv = [
+ tensor_shape.TensorShape([None]),
+ tensor_shape.TensorShape([])
+ ]
+ buf, _ = control_flow_ops.while_loop(_Cond, _Body, loop_vars, shape_inv)
+
+ f = -array_ops.ones_like(buf, optimize=False)
+ buf_shape = array_ops.shape(buf)
+ f_shape = array_ops.shape(f)
+ ops.add_to_collection('train_op', buf_shape)
+ ops.add_to_collection('train_op', f_shape)
+
+ # Optimize the graph.
+ mg = meta_graph.create_meta_graph_def(graph=g)
+ rewriter_config = rewriter_config_pb2.RewriterConfig()
+ optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)
+ mg.graph_def.CopyFrom(optimized_graph)
+
+ # Check that the nodes referenced in various collections have been preserved
+ item = gitem.Item(mg)
+ props = item.GetOpProperties()
+ buf_prop = props[buf.op.name]
+ f_prop = props[f.op.name]
+ self.assertEqual(buf_prop, f_prop)
+
if __name__ == '__main__':
test.main()