diff options
Diffstat (limited to 'tensorflow/contrib/graph_editor/tests/transform_test.py')
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index 6b26869236..9a06431320 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -58,6 +58,20 @@ class TransformTest(tf.test.TestCase): self.assertEqual(t.name, t_.name) self.assertEqual(info.original(t_), t) + def test_copy_assert(self): + tf.reset_default_graph() + a = tf.constant(1) + b = tf.constant(1) + eq = tf.equal(a, b) + assert_op = tf.Assert(eq, [a, b]) + with tf.control_dependencies([assert_op]): + _ = tf.add(a, b) + sgv = ge.make_view([assert_op, eq.op, a.op, b.op]) + copier = ge.Transformer() + copied_sgv, info = copier(sgv, sgv.graph, "", "") + new_assert_op = info.transformed(assert_op) + self.assertIsNotNone(new_assert_op) + def test_transform(self): transformer = ge.Transformer() def my_transform_op_handler(info, op): |