aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar Sourabh Bajaj <sourabhbajaj@google.com>2017-11-30 16:37:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 16:41:01 -0800
commitb2db981a6731e978453862a73dab892bc674db68 (patch)
treec11a7c4038e2595268113c2859c1d0d3072ede4f /tensorflow/contrib/framework
parent0438ac79bdb503ed267bec2146e7136ac8e99ff9 (diff)
Merge changes from github.
PiperOrigin-RevId: 177526301
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/framework/graph_util.py28
-rw-r--r--tensorflow/contrib/framework/python/framework/graph_util_test.py14
2 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py
index 9ba9c77b92..a18ff2320d 100644
--- a/tensorflow/contrib/framework/python/framework/graph_util.py
+++ b/tensorflow/contrib/framework/python/framework/graph_util.py
@@ -24,12 +24,14 @@ import six
# pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
+from tensorflow.python.framework import ops
from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.framework.graph_util_impl import _node_name
-__all__ = ["fuse_op"]
+
+__all__ = ["fuse_op", "get_placeholders"]
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
@@ -126,3 +128,27 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
out.library.CopyFrom(graph_def.library)
out.versions.CopyFrom(graph_def.versions)
return out
+
+
+def get_placeholders(graph):
+ """Get placeholders of a graph.
+
+ Args:
+ graph: A tf.Graph.
+ Returns:
+ A list contains all placeholders of given graph.
+
+ Raises:
+ TypeError: If `graph` is not a tensorflow graph.
+ """
+
+ if not isinstance(graph, ops.Graph):
+ raise TypeError("Input graph needs to be a Graph: %s" % graph)
+
+ # For each placeholder() call, there is a corresponding
+ # operation of type 'Placeholder' registered to the graph.
+ # The return value (a Tensor) of placeholder() is the
+ # first output of this operation in fact.
+ operations = graph.get_operations()
+ result = [i.outputs[0] for i in operations if i.type == "Placeholder"]
+ return result
diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py
index 0c531fb290..b8a6d109e1 100644
--- a/tensorflow/contrib/framework/python/framework/graph_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py
@@ -21,6 +21,9 @@ from tensorflow.contrib.framework.python.framework import graph_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import types_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -81,5 +84,16 @@ class GraphUtilTest(test.TestCase):
self.assertEqual(fused_graph_def.node[4].name, 'E')
+class GetPlaceholdersTest(test.TestCase):
+
+ def test_get_placeholders(self):
+ with ops.Graph().as_default() as g:
+ placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)]
+ results = graph_util.get_placeholders(g)
+ self.assertEqual(
+ sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access
+ sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access
+
+
if __name__ == '__main__':
test.main()