diff options
author | 2015-12-08 14:55:13 -0800 | |
---|---|---|
committer | 2015-12-08 14:55:13 -0800 | |
commit | 2c3738db9c4df83adc1aff29f5cb0e9735dd5eac (patch) | |
tree | 0d187dafeddfabaf9cd5ac4b491001dd0d639ee4 /tensorflow/python/client/graph_util_test.py | |
parent | ddd4aaf5286de24ba70402ee0ec8b836d3aed8c7 (diff) |
TensorFlow: Upstream changes to git.
Change 109730179
Add support for selecting partition strategy in tf.nn.embedding_lookup and related ops, and allow unequally-sized shards to be used as input.
Change 109729548
TensorFlow: add RELEASE.md notes for 0.6.0.
Change 109728185
Make seq2seq_test non-flaky by setting python and numpy random seed.
Change 109725913
Refactor slot creation in optimizers and moving averages to separate file
Change 109718024
TensorFlow: reduce runtime of seq2seq_test from ~30s to ~18s.
Change 109712251
More performance improvement for convnet on GPU.
+ Switch forward convolution format to NCHW.
+ Allocate scratch space for forward- and backward- convolutions.
+ Users can use "TF_CUDNN_WORKSPACE_LIMIT_IN_MB" to configure the scratch space
limit. The default limit in 1GB.
Change 109710898
Added extract_sub_graph utility function
Base CL: 109731609
Diffstat (limited to 'tensorflow/python/client/graph_util_test.py')
-rw-r--r-- | tensorflow/python/client/graph_util_test.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/python/client/graph_util_test.py b/tensorflow/python/client/graph_util_test.py index 6b7dba60bc..73265361cd 100644 --- a/tensorflow/python/client/graph_util_test.py +++ b/tensorflow/python/client/graph_util_test.py @@ -20,6 +20,8 @@ from __future__ import print_function import tensorflow.python.platform +import tensorflow as tf + from tensorflow.python.client import graph_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -140,6 +142,34 @@ class DeviceFunctionsTest(googletest.TestCase): self.assertEqual(const_4.device, "/device:CPU:1") self.assertEqual(const_5.device, "/replica:0") + def testExtractSubGraph(self): + graph_def = tf.GraphDef() + n1 = graph_def.node.add() + n1.name = "n1" + n1.input.extend(["n5"]) + n2 = graph_def.node.add() + n2.name = "n2" + # Take the first output of the n1 node as the input. + n2.input.extend(["n1:0"]) + n3 = graph_def.node.add() + n3.name = "n3" + # Add a control input (which isn't really needed by the kernel, but + # rather to enforce execution order between nodes). + n3.input.extend(["^n2"]) + n4 = graph_def.node.add() + n4.name = "n4" + + # It is fine to have a loops in the graph as well. + n5 = graph_def.node.add() + n5.name = "n5" + n5.input.extend(["n1"]) + + sub_graph = graph_util.extract_sub_graph(graph_def, ["n3"]) + self.assertEqual("n1", sub_graph.node[0].name) + self.assertEqual("n2", sub_graph.node[1].name) + self.assertEqual("n3", sub_graph.node[2].name) + self.assertEqual("n5", sub_graph.node[3].name) + if __name__ == "__main__": googletest.main() |