aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor/tests/transform_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/graph_editor/tests/transform_test.py')
-rw-r--r--tensorflow/contrib/graph_editor/tests/transform_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 6b26869236..9a06431320 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -58,6 +58,20 @@ class TransformTest(tf.test.TestCase):
self.assertEqual(t.name, t_.name)
self.assertEqual(info.original(t_), t)
+ def test_copy_assert(self):
+ tf.reset_default_graph()
+ a = tf.constant(1)
+ b = tf.constant(1)
+ eq = tf.equal(a, b)
+ assert_op = tf.Assert(eq, [a, b])
+ with tf.control_dependencies([assert_op]):
+ _ = tf.add(a, b)
+ sgv = ge.make_view([assert_op, eq.op, a.op, b.op])
+ copier = ge.Transformer()
+ copied_sgv, info = copier(sgv, sgv.graph, "", "")
+ new_assert_op = info.transformed(assert_op)
+ self.assertIsNotNone(new_assert_op)
+
def test_transform(self):
transformer = ge.Transformer()
def my_transform_op_handler(info, op):