aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/graph_util_test.py
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-12-08 14:55:13 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-08 14:55:13 -0800
commit2c3738db9c4df83adc1aff29f5cb0e9735dd5eac (patch)
tree0d187dafeddfabaf9cd5ac4b491001dd0d639ee4 /tensorflow/python/client/graph_util_test.py
parentddd4aaf5286de24ba70402ee0ec8b836d3aed8c7 (diff)
TensorFlow: Upstream changes to git.
Change 109730179 Add support for selecting partition strategy in tf.nn.embedding_lookup and related ops, and allow unequally-sized shards to be used as input. Change 109729548 TensorFlow: add RELEASE.md notes for 0.6.0. Change 109728185 Make seq2seq_test non-flaky by setting python and numpy random seed. Change 109725913 Refactor slot creation in optimizers and moving averages to separate file Change 109718024 TensorFlow: reduce runtime of seq2seq_test from ~30s to ~18s. Change 109712251 More performance improvement for convnet on GPU. + Switch forward convolution format to NCHW. + Allocate scratch space for forward- and backward- convolutions. + Users can use "TF_CUDNN_WORKSPACE_LIMIT_IN_MB" to configure the scratch space limit. The default limit in 1GB. Change 109710898 Added extract_sub_graph utility function Base CL: 109731609
Diffstat (limited to 'tensorflow/python/client/graph_util_test.py')
-rw-r--r--tensorflow/python/client/graph_util_test.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py
index 6b7dba60bc..73265361cd 100644
--- a/tensorflow/python/client/graph_util_test.py
+++ b/tensorflow/python/client/graph_util_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
import tensorflow.python.platform
+import tensorflow as tf
+
from tensorflow.python.client import graph_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -140,6 +142,34 @@ class DeviceFunctionsTest(googletest.TestCase):
self.assertEqual(const_4.device, "/device:CPU:1")
self.assertEqual(const_5.device, "/replica:0")
+ def testExtractSubGraph(self):
+ graph_def = tf.GraphDef()
+ n1 = graph_def.node.add()
+ n1.name = "n1"
+ n1.input.extend(["n5"])
+ n2 = graph_def.node.add()
+ n2.name = "n2"
+ # Take the first output of the n1 node as the input.
+ n2.input.extend(["n1:0"])
+ n3 = graph_def.node.add()
+ n3.name = "n3"
+ # Add a control input (which isn't really needed by the kernel, but
+ # rather to enforce execution order between nodes).
+ n3.input.extend(["^n2"])
+ n4 = graph_def.node.add()
+ n4.name = "n4"
+
+ # It is fine to have a loops in the graph as well.
+ n5 = graph_def.node.add()
+ n5.name = "n5"
+ n5.input.extend(["n1"])
+
+ sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"])
+ self.assertEqual("n1", sub_graph.node[0].name)
+ self.assertEqual("n2", sub_graph.node[1].name)
+ self.assertEqual("n3", sub_graph.node[2].name)
+ self.assertEqual("n5", sub_graph.node[3].name)
+
if __name__ == "__main__":
googletest.main()