diff options
author | 2017-11-30 16:37:11 -0800 | |
---|---|---|
committer | 2017-11-30 16:41:01 -0800 | |
commit | b2db981a6731e978453862a73dab892bc674db68 (patch) | |
tree | c11a7c4038e2595268113c2859c1d0d3072ede4f /tensorflow/contrib/framework | |
parent | 0438ac79bdb503ed267bec2146e7136ac8e99ff9 (diff) |
Merge changes from github.
PiperOrigin-RevId: 177526301
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r-- | tensorflow/contrib/framework/python/framework/graph_util.py | 28 | ||||
-rw-r--r-- | tensorflow/contrib/framework/python/framework/graph_util_test.py | 14 |
2 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 9ba9c77b92..a18ff2320d 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -24,12 +24,14 @@ import six # pylint: disable=unused-import from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes from tensorflow.python.framework.graph_util_impl import _extract_graph_summary from tensorflow.python.framework.graph_util_impl import _node_name -__all__ = ["fuse_op"] + +__all__ = ["fuse_op", "get_placeholders"] def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, @@ -126,3 +128,27 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, out.library.CopyFrom(graph_def.library) out.versions.CopyFrom(graph_def.versions) return out + + +def get_placeholders(graph): + """Get placeholders of a graph. + + Args: + graph: A tf.Graph. + Returns: + A list contains all placeholders of given graph. + + Raises: + TypeError: If `graph` is not a tensorflow graph. + """ + + if not isinstance(graph, ops.Graph): + raise TypeError("Input graph needs to be a Graph: %s" % graph) + + # For each placeholder() call, there is a corresponding + # operation of type 'Placeholder' registered to the graph. + # The return value (a Tensor) of placeholder() is the + # first output of this operation in fact. + operations = graph.get_operations() + result = [i.outputs[0] for i in operations if i.type == "Placeholder"] + return result diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index 0c531fb290..b8a6d109e1 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -21,6 +21,9 @@ from tensorflow.contrib.framework.python.framework import graph_util from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -81,5 +84,16 @@ class GraphUtilTest(test.TestCase): self.assertEqual(fused_graph_def.node[4].name, 'E') +class GetPlaceholdersTest(test.TestCase): + + def test_get_placeholders(self): + with ops.Graph().as_default() as g: + placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)] + results = graph_util.get_placeholders(g) + self.assertEqual( + sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access + sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access + + if __name__ == '__main__': test.main() |