aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-05-31 16:06:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 16:15:46 -0700
commit2c38e7c770c3b4a32a123452ced31e24a0297342 (patch)
tree6bb9081f2072c7e2d81f1a61cbe97bf10a8041c7
parent05c050218b676227fbc0fd24e053f76380ac218e (diff)
Add utility for converting FunctionDef to GraphDef and _FuncGraph.
PiperOrigin-RevId: 198795625
-rw-r--r--tensorflow/python/BUILD32
-rw-r--r--tensorflow/python/framework/function_def_to_graph.py189
-rw-r--r--tensorflow/python/framework/function_def_to_graph_test.py184
3 files changed, 405 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b15c5291f5..569403fa9a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -718,6 +718,38 @@ py_library(
)
py_library(
+ name = "function_def_to_graph",
+ srcs = ["framework/function_def_to_graph.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework",
+ ":function",
+ ":op_def_registry",
+ ":tensor_shape",
+ ":versions",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "function_def_to_graph_test",
+ size = "small",
+ srcs = ["framework/function_def_to_graph_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":dtypes",
+ ":framework_ops",
+ ":function_def_to_graph",
+ ":graph_to_function_def",
+ ":math_ops",
+ ":test_ops",
+ ],
+)
+
+py_library(
name = "graph_util",
srcs = [
"framework/graph_util.py",
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
new file mode 100644
index 0000000000..4fecc41343
--- /dev/null
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -0,0 +1,189 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Utlity to convert FunctionDef to GraphDef and Graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.framework import versions_pb2
+from tensorflow.python.framework import function
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import versions
+
+
+def function_def_to_graph(fdef, input_shapes=None):
+ """Converts a FunctionDef to a function._FuncGraph (sub-class Graph).
+
+ The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set.
+ The input tensors are represented as placeholders.
+
+ Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be
+ set by the caller.
+
+ Args:
+ fdef: FunctionDef.
+ input_shapes: Optional. A list of TensorShape objects of the shapes of
+ function inputs. If specified, its length must match length of
+ `fdef.signature.input_arg`. If a shape is None, the corresponding input
+ placeholder will have unknown shape.
+
+ Returns:
+ A _FuncGraph.
+ """
+ func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access
+ graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
+ fdef, input_shapes)
+
+ with func_graph.as_default():
+ # Add all function nodes to the graph.
+ importer.import_graph_def(graph_def, name="")
+
+ # Initialize fields specific to _FuncGraph.
+
+ # inputs
+ input_tensor_names = [
+ nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
+ ]
+ func_graph.inputs = [
+ func_graph.get_tensor_by_name(name) for name in input_tensor_names
+ ]
+
+ # outputs
+ output_tensor_names = [
+ nested_to_flat_tensor_name[fdef.ret[arg.name]]
+ for arg in fdef.signature.output_arg
+ ]
+ func_graph.outputs = [
+ func_graph.get_tensor_by_name(name) for name in output_tensor_names
+ ]
+
+ return func_graph
+
+
+def function_def_to_graph_def(fdef, input_shapes=None):
+ """Convert a FunctionDef to a GraphDef.
+
+ Steps:
+ 1. Creates placeholder nodes corresponding to inputs in
+ `FunctionDef.signature.input_arg`.
+ 2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
+ 3. Renames inputs of all nodes to use the convention of GraphDef instead of
+ FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
+ in FunctionDefs is different from GraphDefs.
+
+ Args:
+ fdef: FunctionDef.
+ input_shapes: Optional. A list of TensorShape objects of the shapes of
+ function inputs. If specified, its length must match length of
+ `fdef.signature.input_arg`. If a shape is None, the corresponding input
+ placeholder will have unknown shape.
+
+ Returns:
+ A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
+ from nested tensor names (in FunctionDef) to flattened names (in GraphDef).
+
+ Raises:
+ ValueError: If the length of input_shapes does not match the number of
+ input_args or if the FunctionDef is invalid.
+ """
+ graph_def = graph_pb2.GraphDef()
+ graph_def.versions.CopyFrom(
+ versions_pb2.VersionDef(
+ producer=versions.GRAPH_DEF_VERSION,
+ min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))
+
+ if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
+ raise ValueError("Length of input_shapes must match the number of " +
+ "input_args. len(input_shapes): {} len(input_arg): {}".
+ format(len(input_shapes), len(fdef.signature.input_arg)))
+
+ # 1. Create placeholders for input nodes.
+ for i, arg_def in enumerate(fdef.signature.input_arg):
+ node_def = graph_def.node.add()
+ node_def.name = arg_def.name
+ node_def.op = "Placeholder"
+ node_def.attr["dtype"].type = arg_def.type
+ if input_shapes and input_shapes[i] is not None:
+ node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())
+
+ # 2. Copy all body NodeDefs to the GraphDef.
+ graph_def.node.extend(fdef.node_def)
+
+ # 3. Perform the renaming.
+
+ # Build the tensor name mapping then flatten the tensor names.
+ # See comment on `FunctionDef.node_def` on how the tensor naming in
+ # FunctionDefs is different from GraphDefs.
+ nested_to_flat_tensor_name = {}
+
+ for arg_def in fdef.signature.input_arg:
+ nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
+
+ for node_def in fdef.node_def:
+ op_def = op_def_registry.get_registered_ops().get(node_def.op)
+ if not op_def:
+ # TODO(b/80470245): Support functions which refer other functions.
+ raise NotImplementedError(
+ "No op registered for {},".format(node_def.op) +
+ " it may be a function. function_def_to_graph_def " +
+ "currently does not support converting functions with " +
+ "references to other graph functions.")
+
+ for attr in op_def.attr:
+ if attr.type in ("func", "list(func)"):
+ # TODO(b/80470245): Support functions which refer other functions.
+ raise NotImplementedError("Unsupported attr {} ".format(attr.name) +
+ " with type {}".format(attr.type) +
+ " in op {}. ".format(op_def.name) +
+ "function_def_to_graph_def currently does " +
+ "not support converting functions with " +
+ "references to other graph functions.")
+
+ # Iterate over output_args in op_def to build the map.
+ # Index of the output tensor in the flattened list of *all* output
+ # tensors of the op.
+ flattened_index = 0
+ for arg_def in op_def.output_arg:
+ num_args = _get_num_args(arg_def, node_def)
+ for i in range(num_args):
+ # Map tensor names from "node_name:output_arg_name:index" to
+ # "node_name:flattened_index".
+ nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
+ flat_name = "{}:{}".format(node_def.name, flattened_index)
+ nested_to_flat_tensor_name[nested_name] = flat_name
+ flattened_index += 1
+
+ # Update inputs of all nodes in graph.
+ for node_def in graph_def.node:
+ for i in range(len(node_def.input)):
+ node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]
+
+ return graph_def, nested_to_flat_tensor_name
+
+
+# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
+def _get_num_args(arg_def, node_def):
+ if arg_def.number_attr:
+ return node_def.attr[arg_def.number_attr].i
+ elif arg_def.type_list_attr:
+ return len(node_def.attr[arg_def.type_list_attr].list.type)
+ elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
+ return 1
+ else:
+ raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
new file mode 100644
index 0000000000..0f4e6ef54f
--- /dev/null
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -0,0 +1,184 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.function_def_to_graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function_def_to_graph
+from tensorflow.python.framework import graph_to_function_def
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FunctionDefToGraphTest(test.TestCase):
+
+ def _build_function_def(self):
+ with ops.Graph().as_default() as g:
+ # Inputs
+ x = array_ops.placeholder(dtypes.float32, name="x")
+ y = array_ops.placeholder(dtypes.float32, name="y")
+
+ # Outputs
+ sum_squares = math_ops.add_n(
+ [math_ops.pow(x, 2), math_ops.pow(y, 2)], name="sum_squares")
+ sum_cubes = math_ops.add_n(
+ [math_ops.pow(x, 3), math_ops.pow(y, 3)], name="sum_cubes")
+ fdef = graph_to_function_def.graph_to_function_def(
+ g,
+ g.get_operations(),
+ [x, y], # Inputs
+ [sum_squares, sum_cubes]) # Outputs.
+ fdef.signature.name = "_whats_in_a_name"
+ return fdef
+
+ def testInputsAndOutputs(self):
+ fdef = self._build_function_def()
+ g = function_def_to_graph.function_def_to_graph(fdef)
+ self.assertEqual(g.name, "_whats_in_a_name")
+ with self.test_session(graph=g) as sess:
+ inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3})
+ self.assertSequenceEqual(inputs, [2.0, 3.0])
+ outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3})
+ self.assertSequenceEqual(outputs, [13.0, 35.0])
+
+ def testShapes(self):
+ fdef = self._build_function_def()
+
+ g = function_def_to_graph.function_def_to_graph(fdef)
+ self.assertIsNone(g.inputs[0].shape.dims) # Unknown dims.
+ self.assertIsNone(g.inputs[1].shape.dims) # Unknown dims.
+ self.assertIsNone(g.outputs[0].shape.dims) # Unknown dims.
+ self.assertIsNone(g.outputs[1].shape.dims) # Unknown dims.
+
+ g = function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes=[tensor_shape.vector(5),
+ tensor_shape.vector(5)])
+ self.assertSequenceEqual(g.inputs[0].shape.dims, [5])
+ self.assertSequenceEqual(g.inputs[1].shape.dims, [5])
+ self.assertSequenceEqual(g.outputs[0].shape.dims, [5])
+ self.assertSequenceEqual(g.outputs[1].shape.dims, [5])
+
+ g = function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes=[None, tensor_shape.matrix(5, 7)])
+ print(g.as_graph_def())
+ self.assertIsNone(g.inputs[0].shape.dims)
+ self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7])
+ self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7])
+ self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7])
+
+ # Should raise a ValueError if the length of input_shapes does not match
+ # the number of input args in FunctionDef.signature.input_arg.
+ with self.assertRaises(ValueError):
+ g = function_def_to_graph.function_def_to_graph(
+ fdef, input_shapes=[tensor_shape.matrix(5, 7)])
+
+
+class FunctionDefToGraphDefTest(test.TestCase):
+
+ def _build_function_def(self):
+ with ops.Graph().as_default() as g:
+ # Inputs: x y z
+ # |\ | /
+ # | \ | /
+ # | foo_1 list_output
+ # | / \ / \
+ # | d_1 e_1 a:1 a:0
+ # | \ | / |
+ # | \ | / |
+ # | foo_2 |
+ # | / \ |
+ # Outputs: x d_2 e_2 a:0
+
+ x = array_ops.placeholder(dtypes.float32, name="x")
+ y = array_ops.placeholder(dtypes.int32, name="y")
+ z = array_ops.placeholder(dtypes.int32, name="z")
+
+ d_1, e_1 = test_ops._op_def_lib.apply_op(
+ "Foo1", name="foo_1", a=x, b=y, c=z)
+
+ list_output0, list_output1 = test_ops.list_output(
+ T=[dtypes.int32, dtypes.int32], name="list_output")
+
+ d_2, e_2 = test_ops.foo1(a=d_1, b=e_1, c=list_output1, name="foo_2")
+
+ fdef = graph_to_function_def.graph_to_function_def(
+ g,
+ g.get_operations(),
+ [x, y, z], # Inputs
+ [x, d_2, e_2, list_output0]) # Outputs.
+
+ # Assert that the FunctionDef was correctly built.
+ assert len(fdef.node_def) == 3 # 2 Foo1 nodes and 1 ListOutput node.
+ assert fdef.node_def[0].op == "Foo1"
+ assert fdef.node_def[0].input == ["x", "y", "z"]
+ assert fdef.node_def[1].op == "ListOutput"
+ assert not fdef.node_def[1].input
+ assert fdef.node_def[2].op == "Foo1"
+ assert fdef.node_def[2].input == [
+ "foo_1:d:0", "foo_1:e:0", "list_output:a:1"
+ ]
+ return fdef
+
+ def testTensorNames(self):
+ fdef = self._build_function_def()
+ g, tensor_name_map = function_def_to_graph.function_def_to_graph_def(fdef)
+
+ # Verify that inputs of body nodes are correctly renamed.
+ # foo_1
+ self.assertSequenceEqual(g.node[3].input, ["x:0", "y:0", "z:0"])
+ # foo_2
+ self.assertSequenceEqual(g.node[5].input,
+ ["foo_1:0", "foo_1:1", "list_output:1"])
+
+ # Verify that the `tensor_name_map` has the correct mapping.
+ self.assertDictEqual(
+ tensor_name_map, {
+ "x": "x:0",
+ "y": "y:0",
+ "z": "z:0",
+ "foo_1:d:0": "foo_1:0",
+ "foo_1:e:0": "foo_1:1",
+ "list_output:a:0": "list_output:0",
+ "list_output:a:1": "list_output:1",
+ "foo_2:d:0": "foo_2:0",
+ "foo_2:e:0": "foo_2:1",
+ })
+
+ def testShapes(self):
+ fdef = self._build_function_def()
+ g, _ = function_def_to_graph.function_def_to_graph_def(
+ fdef,
+ input_shapes=[tensor_shape.scalar(),
+ tensor_shape.vector(5), None])
+ self.assertEqual("shape" in g.node[0].attr, True)
+ self.assertSequenceEqual(
+ tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), [])
+ self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
+ self.assertEqual("shape" in g.node[1].attr, True)
+ self.assertSequenceEqual(
+ tensor_shape.TensorShape(g.node[1].attr["shape"].shape).as_list(), [5])
+ self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False)
+ self.assertFalse("shape" in g.node[2].attr)
+
+
+if __name__ == "__main__":
+ test.main()