aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/graph_util_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/graph_util_test.py')
-rw-r--r--tensorflow/python/framework/graph_util_test.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index 647ed1583a..0421837d49 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -188,6 +188,13 @@ class DeviceFunctionsTest(test.TestCase):
self.assertEqual("n3", sub_graph.node[2].name)
self.assertEqual("n5", sub_graph.node[3].name)
+ def testExtractSubGraphWithInvalidDestNodes(self):
+ graph_def = graph_pb2.GraphDef()
+ n1 = graph_def.node.add()
+ n1.name = "n1"
+ with self.assertRaisesRegexp(TypeError, "must be a list"):
+ graph_util.extract_sub_graph(graph_def, "n1")
+
def testConvertVariablesToConstsWithFunctions(self):
@function.Defun(dtypes.float32)
def plus_one(x):