diff options
author | 2018-05-31 16:06:15 -0700 | |
---|---|---|
committer | 2018-05-31 16:15:46 -0700 | |
commit | 2c38e7c770c3b4a32a123452ced31e24a0297342 (patch) | |
tree | 6bb9081f2072c7e2d81f1a61cbe97bf10a8041c7 | |
parent | 05c050218b676227fbc0fd24e053f76380ac218e (diff) |
Add utility for converting FunctionDef to GraphDef and _FuncGraph.
PiperOrigin-RevId: 198795625
-rw-r--r-- | tensorflow/python/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/python/framework/function_def_to_graph.py | 189 | ||||
-rw-r--r-- | tensorflow/python/framework/function_def_to_graph_test.py | 184 |
3 files changed, 405 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index b15c5291f5..569403fa9a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -718,6 +718,38 @@ py_library( ) py_library( + name = "function_def_to_graph", + srcs = ["framework/function_def_to_graph.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework", + ":function", + ":op_def_registry", + ":tensor_shape", + ":versions", + "//tensorflow/core:protos_all_py", + ], +) + +py_test( + name = "function_def_to_graph_test", + size = "small", + srcs = ["framework/function_def_to_graph_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":array_ops", + ":client_testlib", + ":dtypes", + ":framework_ops", + ":function_def_to_graph", + ":graph_to_function_def", + ":math_ops", + ":test_ops", + ], +) + +py_library( name = "graph_util", srcs = [ "framework/graph_util.py", diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py new file mode 100644 index 0000000000..4fecc41343 --- /dev/null +++ b/tensorflow/python/framework/function_def_to_graph.py @@ -0,0 +1,189 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Utlity to convert FunctionDef to GraphDef and Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.framework import versions_pb2 +from tensorflow.python.framework import function +from tensorflow.python.framework import importer +from tensorflow.python.framework import op_def_registry +from tensorflow.python.framework import versions + + +def function_def_to_graph(fdef, input_shapes=None): + """Converts a FunctionDef to a function._FuncGraph (sub-class Graph). + + The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set. + The input tensors are represented as placeholders. + + Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be + set by the caller. + + Args: + fdef: FunctionDef. + input_shapes: Optional. A list of TensorShape objects of the shapes of + function inputs. If specified, its length must match length of + `fdef.signature.input_arg`. If a shape is None, the corresponding input + placeholder will have unknown shape. + + Returns: + A _FuncGraph. + """ + func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access + graph_def, nested_to_flat_tensor_name = function_def_to_graph_def( + fdef, input_shapes) + + with func_graph.as_default(): + # Add all function nodes to the graph. + importer.import_graph_def(graph_def, name="") + + # Initialize fields specific to _FuncGraph. + + # inputs + input_tensor_names = [ + nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg + ] + func_graph.inputs = [ + func_graph.get_tensor_by_name(name) for name in input_tensor_names + ] + + # outputs + output_tensor_names = [ + nested_to_flat_tensor_name[fdef.ret[arg.name]] + for arg in fdef.signature.output_arg + ] + func_graph.outputs = [ + func_graph.get_tensor_by_name(name) for name in output_tensor_names + ] + + return func_graph + + +def function_def_to_graph_def(fdef, input_shapes=None): + """Convert a FunctionDef to a GraphDef. + + Steps: + 1. Creates placeholder nodes corresponding to inputs in + `FunctionDef.signature.input_arg`. + 2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`. + 3. Renames inputs of all nodes to use the convention of GraphDef instead of + FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming + in FunctionDefs is different from GraphDefs. + + Args: + fdef: FunctionDef. + input_shapes: Optional. A list of TensorShape objects of the shapes of + function inputs. If specified, its length must match length of + `fdef.signature.input_arg`. If a shape is None, the corresponding input + placeholder will have unknown shape. + + Returns: + A tuple of (GraphDef, dict<string, string>). The dict contains a mapping + from nested tensor names (in FunctionDef) to flattened names (in GraphDef). + + Raises: + ValueError: If the length of input_shapes does not match the number of + input_args or if the FunctionDef is invalid. + """ + graph_def = graph_pb2.GraphDef() + graph_def.versions.CopyFrom( + versions_pb2.VersionDef( + producer=versions.GRAPH_DEF_VERSION, + min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER)) + + if input_shapes and len(input_shapes) != len(fdef.signature.input_arg): + raise ValueError("Length of input_shapes must match the number of " + + "input_args. len(input_shapes): {} len(input_arg): {}". + format(len(input_shapes), len(fdef.signature.input_arg))) + + # 1. Create placeholders for input nodes. + for i, arg_def in enumerate(fdef.signature.input_arg): + node_def = graph_def.node.add() + node_def.name = arg_def.name + node_def.op = "Placeholder" + node_def.attr["dtype"].type = arg_def.type + if input_shapes and input_shapes[i] is not None: + node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto()) + + # 2. Copy all body NodeDefs to the GraphDef. + graph_def.node.extend(fdef.node_def) + + # 3. Perform the renaming. + + # Build the tensor name mapping then flatten the tensor names. + # See comment on `FunctionDef.node_def` on how the tensor naming in + # FunctionDefs is different from GraphDefs. + nested_to_flat_tensor_name = {} + + for arg_def in fdef.signature.input_arg: + nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name) + + for node_def in fdef.node_def: + op_def = op_def_registry.get_registered_ops().get(node_def.op) + if not op_def: + # TODO(b/80470245): Support functions which refer other functions. + raise NotImplementedError( + "No op registered for {},".format(node_def.op) + + " it may be a function. function_def_to_graph_def " + + "currently does not support converting functions with " + + "references to other graph functions.") + + for attr in op_def.attr: + if attr.type in ("func", "list(func)"): + # TODO(b/80470245): Support functions which refer other functions. + raise NotImplementedError("Unsupported attr {} ".format(attr.name) + + " with type {}".format(attr.type) + + " in op {}. ".format(op_def.name) + + "function_def_to_graph_def currently does " + + "not support converting functions with " + + "references to other graph functions.") + + # Iterate over output_args in op_def to build the map. + # Index of the output tensor in the flattened list of *all* output + # tensors of the op. + flattened_index = 0 + for arg_def in op_def.output_arg: + num_args = _get_num_args(arg_def, node_def) + for i in range(num_args): + # Map tensor names from "node_name:output_arg_name:index" to + # "node_name:flattened_index". + nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i) + flat_name = "{}:{}".format(node_def.name, flattened_index) + nested_to_flat_tensor_name[nested_name] = flat_name + flattened_index += 1 + + # Update inputs of all nodes in graph. + for node_def in graph_def.node: + for i in range(len(node_def.input)): + node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]] + + return graph_def, nested_to_flat_tensor_name + + +# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange. +def _get_num_args(arg_def, node_def): + if arg_def.number_attr: + return node_def.attr[arg_def.number_attr].i + elif arg_def.type_list_attr: + return len(node_def.attr[arg_def.type_list_attr].list.type) + elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID: + return 1 + else: + raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def))) diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py new file mode 100644 index 0000000000..0f4e6ef54f --- /dev/null +++ b/tensorflow/python/framework/function_def_to_graph_test.py @@ -0,0 +1,184 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.function_def_to_graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import graph_to_function_def +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class FunctionDefToGraphTest(test.TestCase): + + def _build_function_def(self): + with ops.Graph().as_default() as g: + # Inputs + x = array_ops.placeholder(dtypes.float32, name="x") + y = array_ops.placeholder(dtypes.float32, name="y") + + # Outputs + sum_squares = math_ops.add_n( + [math_ops.pow(x, 2), math_ops.pow(y, 2)], name="sum_squares") + sum_cubes = math_ops.add_n( + [math_ops.pow(x, 3), math_ops.pow(y, 3)], name="sum_cubes") + fdef = graph_to_function_def.graph_to_function_def( + g, + g.get_operations(), + [x, y], # Inputs + [sum_squares, sum_cubes]) # Outputs. + fdef.signature.name = "_whats_in_a_name" + return fdef + + def testInputsAndOutputs(self): + fdef = self._build_function_def() + g = function_def_to_graph.function_def_to_graph(fdef) + self.assertEqual(g.name, "_whats_in_a_name") + with self.test_session(graph=g) as sess: + inputs = sess.run(g.inputs, feed_dict={"x:0": 2, "y:0": 3}) + self.assertSequenceEqual(inputs, [2.0, 3.0]) + outputs = sess.run(g.outputs, feed_dict={"x:0": 2, "y:0": 3}) + self.assertSequenceEqual(outputs, [13.0, 35.0]) + + def testShapes(self): + fdef = self._build_function_def() + + g = function_def_to_graph.function_def_to_graph(fdef) + self.assertIsNone(g.inputs[0].shape.dims) # Unknown dims. + self.assertIsNone(g.inputs[1].shape.dims) # Unknown dims. + self.assertIsNone(g.outputs[0].shape.dims) # Unknown dims. + self.assertIsNone(g.outputs[1].shape.dims) # Unknown dims. + + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[tensor_shape.vector(5), + tensor_shape.vector(5)]) + self.assertSequenceEqual(g.inputs[0].shape.dims, [5]) + self.assertSequenceEqual(g.inputs[1].shape.dims, [5]) + self.assertSequenceEqual(g.outputs[0].shape.dims, [5]) + self.assertSequenceEqual(g.outputs[1].shape.dims, [5]) + + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[None, tensor_shape.matrix(5, 7)]) + print(g.as_graph_def()) + self.assertIsNone(g.inputs[0].shape.dims) + self.assertSequenceEqual(g.inputs[1].shape.dims, [5, 7]) + self.assertSequenceEqual(g.outputs[0].shape.dims, [5, 7]) + self.assertSequenceEqual(g.outputs[1].shape.dims, [5, 7]) + + # Should raise a ValueError if the length of input_shapes does not match + # the number of input args in FunctionDef.signature.input_arg. + with self.assertRaises(ValueError): + g = function_def_to_graph.function_def_to_graph( + fdef, input_shapes=[tensor_shape.matrix(5, 7)]) + + +class FunctionDefToGraphDefTest(test.TestCase): + + def _build_function_def(self): + with ops.Graph().as_default() as g: + # Inputs: x y z + # |\ | / + # | \ | / + # | foo_1 list_output + # | / \ / \ + # | d_1 e_1 a:1 a:0 + # | \ | / | + # | \ | / | + # | foo_2 | + # | / \ | + # Outputs: x d_2 e_2 a:0 + + x = array_ops.placeholder(dtypes.float32, name="x") + y = array_ops.placeholder(dtypes.int32, name="y") + z = array_ops.placeholder(dtypes.int32, name="z") + + d_1, e_1 = test_ops._op_def_lib.apply_op( + "Foo1", name="foo_1", a=x, b=y, c=z) + + list_output0, list_output1 = test_ops.list_output( + T=[dtypes.int32, dtypes.int32], name="list_output") + + d_2, e_2 = test_ops.foo1(a=d_1, b=e_1, c=list_output1, name="foo_2") + + fdef = graph_to_function_def.graph_to_function_def( + g, + g.get_operations(), + [x, y, z], # Inputs + [x, d_2, e_2, list_output0]) # Outputs. + + # Assert that the FunctionDef was correctly built. + assert len(fdef.node_def) == 3 # 2 Foo1 nodes and 1 ListOutput node. + assert fdef.node_def[0].op == "Foo1" + assert fdef.node_def[0].input == ["x", "y", "z"] + assert fdef.node_def[1].op == "ListOutput" + assert not fdef.node_def[1].input + assert fdef.node_def[2].op == "Foo1" + assert fdef.node_def[2].input == [ + "foo_1:d:0", "foo_1:e:0", "list_output:a:1" + ] + return fdef + + def testTensorNames(self): + fdef = self._build_function_def() + g, tensor_name_map = function_def_to_graph.function_def_to_graph_def(fdef) + + # Verify that inputs of body nodes are correctly renamed. + # foo_1 + self.assertSequenceEqual(g.node[3].input, ["x:0", "y:0", "z:0"]) + # foo_2 + self.assertSequenceEqual(g.node[5].input, + ["foo_1:0", "foo_1:1", "list_output:1"]) + + # Verify that the `tensor_name_map` has the correct mapping. + self.assertDictEqual( + tensor_name_map, { + "x": "x:0", + "y": "y:0", + "z": "z:0", + "foo_1:d:0": "foo_1:0", + "foo_1:e:0": "foo_1:1", + "list_output:a:0": "list_output:0", + "list_output:a:1": "list_output:1", + "foo_2:d:0": "foo_2:0", + "foo_2:e:0": "foo_2:1", + }) + + def testShapes(self): + fdef = self._build_function_def() + g, _ = function_def_to_graph.function_def_to_graph_def( + fdef, + input_shapes=[tensor_shape.scalar(), + tensor_shape.vector(5), None]) + self.assertEqual("shape" in g.node[0].attr, True) + self.assertSequenceEqual( + tensor_shape.TensorShape(g.node[0].attr["shape"].shape).as_list(), []) + self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False) + self.assertEqual("shape" in g.node[1].attr, True) + self.assertSequenceEqual( + tensor_shape.TensorShape(g.node[1].attr["shape"].shape).as_list(), [5]) + self.assertEqual(g.node[0].attr["shape"].shape.unknown_rank, False) + self.assertFalse("shape" in g.node[2].attr) + + +if __name__ == "__main__": + test.main() |