author    | Andrew Selle <aselle@google.com>                 | 2018-08-16 11:23:41 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2018-08-16 11:51:03 -0700
commit    | 020ce87723b43f96372fff79c1b8d9f989286409 (patch)
tree      | fdd704609c6c92027679b86d59fe3b035eff60d2
parent    | 09e272a6a5c3359b671a068f4dac2bca8b312358 (diff)
Automated rollback of commit ec5f4771e42972c31faaa39354d785891de9f91d
PiperOrigin-RevId: 209016586
-rw-r--r-- | tensorflow/contrib/lite/python/BUILD           |   3
-rw-r--r-- | tensorflow/contrib/lite/python/convert_test.py |  93
-rw-r--r-- | tensorflow/contrib/lite/python/op_hint.py      | 898
3 files changed, 883 insertions, 111 deletions
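For orientation before the diff: the API this rollback restores lets the converter work on a serialized GraphDef rather than a live Session, and makes add_inputs/add_outputs return lists. A minimal sketch, assuming the tf.contrib.lite layout at this commit (the model below is hypothetical and not part of the change):

import tensorflow as tf
from tensorflow.contrib.lite.python import op_hint

with tf.Session() as sess:
  image = tf.placeholder(tf.float32, (1, 16, 16, 1))
  custom = op_hint.OpHint("cool_activation")
  image, = custom.add_inputs(image)   # now returns a list, hence the comma
  output = tf.sigmoid(image) * image
  output, = custom.add_outputs(output)

  # Old call form (still accepted): op_hint.convert_op_hints_to_stubs(sess)
  stubbed = op_hint.convert_op_hints_to_stubs(graph_def=sess.graph_def)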
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 860aff9e7e..47f0c8e9a2 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -112,8 +112,11 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/contrib/framework:framework_py",
+        "//tensorflow/contrib/graph_editor:graph_editor_py",
         "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:framework",
         "//tensorflow/python:platform",
+        "//tensorflow/python:util",
     ],
 )
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index dc21a9b669..bc05514cec 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -113,12 +113,13 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
       # and 1 final output).
       self.assertEqual(self._countIdentities(sess.graph_def.node), 4)

-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+          graph_def=sess.graph_def)

       self.assertCountEqual(
           self._getGraphOpTypes(
               stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output)]),
+              output_nodes=[op_hint._tensor_name_base(output.name)]),
           ["cool_activation", "Const", "Identity"])

   def testScaleAndBiasAndIdentity(self):
@@ -139,12 +140,13 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
       # +1 for the final output
       self.assertEqual(self._countIdentities(sess.graph_def.node), 6)

-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+          graph_def=sess.graph_def)

       self.assertCountEqual(
           self._getGraphOpTypes(
               stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output)]),
+              output_nodes=[op_hint._tensor_name_base(output.name)]),
           ["scale_and_bias_and_identity", "Const", "Identity", "Pack"])

   def testTwoFunctions(self):
@@ -153,7 +155,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
     b = array_ops.constant([1.])
     def _double_values(x):
       custom = op_hint.OpHint("add_test")
-      x = custom.add_inputs(x)
+      x, = custom.add_inputs(x)
       output = math_ops.multiply(x, x)
       output, = custom.add_outputs(output)
       return output
@@ -164,13 +166,90 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
       # make sure one identity for each input (2) and output (2) => 2 + 2
       # +1 for the final output
       self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+          graph_def=sess.graph_def)
       self.assertCountEqual(
           self._getGraphOpTypes(
               stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output)]),
+              output_nodes=[op_hint._tensor_name_base(output.name)]),
           ["add_test", "Const", "Identity", "Add"])

+  def _get_input_index(self, x):
+    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i
+
+  def _get_output_index(self, x):
+    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i
+
+  def _get_sort_index(self, x):
+    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i
+
+  def testTags(self):
+    """Test if multiple args with the same tag are grouped."""
+    a = array_ops.constant([1.])
+    b = array_ops.constant([2.])
+    c = array_ops.constant([3.])
+    d = array_ops.constant([4.])
+    custom = op_hint.OpHint("test_tag")
+    a = custom.add_input(a, tag="mytag",
+                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
+    b, = custom.add_inputs(b)
+    c = custom.add_input(c, tag="mytag",
+                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
+    d = custom.add_input(d, tag="mytag2",
+                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
+    res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
+    custom.add_outputs([res])
+    with self.test_session():
+      self.assertEqual(self._get_input_index(a), 0)
+      self.assertEqual(self._get_sort_index(a), 0)
+      self.assertEqual(self._get_input_index(b), 1)
+      self.assertEqual(self._get_input_index(c), 0)
+      self.assertEqual(self._get_sort_index(c), 1)
+
+  def testOverrideIndex(self):
+    a = array_ops.constant([1.])
+    b = array_ops.constant([2.])
+    c = array_ops.constant([3.])
+    custom = op_hint.OpHint("test_override")
+    b = custom.add_input(b)  # should auto assign 0
+    a = custom.add_input(a, index_override=1)
+    c = custom.add_input(c)  # should auto assign 2
+    with self.test_session():
+      self.assertEqual(self._get_input_index(a), 1)
+      self.assertEqual(self._get_input_index(b), 0)
+      self.assertEqual(self._get_input_index(c), 2)
+
+  def testAggregate(self):
+    a = array_ops.constant([3., 4.])
+    b = array_ops.constant([5., 6.])
+    hint = op_hint.OpHint("agg")
+    a0, a1 = array_ops.unstack(a)
+    b0, b1 = array_ops.unstack(b)
+
+    a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+    b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+    a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+    b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+    c0 = math_ops.add(a0, b0, name="addleft")
+    c1 = math_ops.add(a1, b1, name="addright")
+    c0 = hint.add_output(
+        c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+    c1 = hint.add_output(
+        c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+    curr = array_ops.stack([c0, c1])
+    output = array_ops.identity(curr, name="FINAL_OUTPUT")
+    with self.test_session() as sess:
+      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+          graph_def=sess.graph_def)
+      print(stubbed_graphdef)
+      self.assertCountEqual(
+          self._getGraphOpTypes(
+              stubbed_graphdef,
+              output_nodes=[op_hint._tensor_name_base(output.name)]),
+          ["agg", "Const", "Identity"])
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py
index 7908689ce4..8c920132e5 100644
--- a/tensorflow/contrib/lite/python/op_hint.py
+++ b/tensorflow/contrib/lite/python/op_hint.py
@@ -25,9 +25,9 @@ Example:
   def tflite_cool_activation(input):
     # A cool activation function.
     custom = tf.contrib.lite.OpHint("cool_activation")
-    input = custom.add_inputs(input)
+    input, = custom.add_inputs(input)
     output = tf.sigmoid(input) * input
-    custom.add_outputs(output)
+    output, = custom.add_outputs(output)
     return output

   image = tf.placeholder(tf.float32, (1, 16, 16, 1))
@@ -64,18 +64,27 @@ ops don't actually exist in the normal TensorFlow runtime, but
 will be understood by toco later.
 """

+# TODO(aselle): Make this use generic graph transformations.
+# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 import collections as _collections
-import itertools as _itertools
+import copy as _copy
 import uuid as _uuid

+import six as _six
+
-from tensorflow.contrib import framework as _framework
 from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
 from tensorflow.python.framework import ops as _ops
+# TODO(aselle): publicize these apis if we continue to use these.
+from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
+from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
 from tensorflow.python.ops import array_ops as _array_ops
+from tensorflow.python.util import compat as _compat
 from tensorflow.python.util.all_util import remove_undocumented
@@ -97,11 +106,174 @@ class OpHint(object):
   constructs, this mechanism can be retired and changed to use python defun's.
   """

-  # Attr constants that are used for representation in the GraphDef
+  # Attr constants that are used for representation in the GraphDef. These
+  # will be used on every Identity op that is involved in a total OpHint.
+
+  # Name of the OpHint function (cosmetic).
   FUNCTION_NAME_ATTR = "_tflite_function_name"
+  # UUID of the function (each OpHint gets a new uuid).
   FUNCTION_UUID_ATTR = "_tflite_function_uuid"
+  # The input index of the input (or nothing if it is an output).
   FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
+  # The output index of the output (or nothing if it is an input).
   FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
+  # An index that orders aggregate arguments. Aggregate arguments are ones
+  # that are separate but will be fused horizontally. For example a static LSTM
+  # has an lstm cell for each time step. Each one has a separate opHint, but a
+  # fused SequentialLSTM will treat this as a single tensor.
+  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
+  # The way in which multiple parts of the aggregate argument will be joined
+  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
+  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
+  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
+  # On fused OpHint stub, the order of inputs that the final LSTM call will
+  # have. What this means is that the TensorFlow order might be
+  # "foo", "bar", "stuff" and you might want the TF lite op order to be
+  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
+  # attribute to [2, 0, 1, -1].
+  TFLITE_INPUT_INDICES = "_tflite_input_indices"
+
+  # Types of aggregations
+  #  stack: stacks all ophints with matching tags, i.e. for a static rnn.
+  #   Specifically, this is good for an input or output to a static rnn cell.
+  AGGREGATE_STACK = _compat.as_bytes("stack")
+  #  first: only takes the first output (one with lowest sort index)
+  #   of matching tags. This is good for the input state to an RNN.
+  AGGREGATE_FIRST = _compat.as_bytes("first")
+  #  last: takes only the last tag (one with highest sort index).
+  #   This is good for an output value on the last stack item of a
+  #   static rnn.
+  AGGREGATE_LAST = _compat.as_bytes("last")
+
+  class OpHintArgumentTracker(object):
+    """Conceptually tracks indices of arguments of "OpHint functions".
+
+    The inputs and outputs of these functions both use an instance
+    of the class so they can have independent numbering."""
+
+    def __init__(self, function_name, unique_function_id, node_name_prefix,
+                 attr_name):
+      """Initialize ophint argument tracker.
+
+      Args:
+        function_name: Name of the function that this tracks arguments for.
+        unique_function_id: UUID of function that this tracks arguments for.
+        node_name_prefix: How identities that are created are named.
+        attr_name: Name of attribute to use to store the index for this hint,
+          i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
+      """
+
+      # The global index is the argument index of the op. This is in contrast
+      # to the sort index which is the sequence number of a particular instance
+      # of a given global index. For example, you may have called add hint
+      # twice with the tag "foo". Then the global index will be 0 for both
+      # and the sort index will be 0 for the first added and 1 for the second.
+      self._function_name = function_name
+      self._unique_function_id = unique_function_id
+      self._next_global_index = 0  # The absolute global index
+      self._used_global_indices = set()
+      self._tag_to_global_index = {}  # The argument index a given tag maps to
+      self._tag_to_next_sort_index = {}  # The current index for each tag
+      self._node_name_prefix = node_name_prefix
+      self._attr_name = attr_name
+
+    def _get_new_global_index(self, index_override):
+      """Return the next unused argument index in order or use an override.
+
+      Args:
+        index_override: An index to use instead of the next available or None
+          to use the next available.
+
+      Returns:
+        A valid global_index to use for the next hint argument.
+
+      Raises:
+        ValueError: If the index_override is already used by another hint.
+      """
+      if index_override is None:
+        global_index = self._next_global_index
+      else:
+        if index_override in self._used_global_indices:
+          raise ValueError(
+              "Index %d was already used by another call to add." %
+              index_override)
+        global_index = index_override
+      # Make next_global_index valid
+      self._used_global_indices.add(global_index)
+      while self._next_global_index in self._used_global_indices:
+        self._next_global_index += 1
+      return global_index
+
+    def add(self, arg, tag=None, name=None, aggregate=None,
+            index_override=None):
+      """Return a wrapped tensor of an input tensor as an argument.
+
+      Args:
+        arg: A TensorFlow tensor that should be considered an argument.
+        tag: String tag to identify arguments that should be packed.
+        name: Name of argument. This is included in the Identity hint op names.
+        aggregate: Strategy to aggregate.
+          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+          and OpHint.AGGREGATE_STACK.
+          Note, aggregate is only valid if tag is specified.
+        index_override: Specify what input/output index should this be in the
+          final stub. i.e. add(arg0, index=1); add(arg1, index=0) will make the
+          final stub be as stub_func(inputs=[arg1, arg0], outputs=[]) rather
+          than the default call order based ordering.
+
+      Returns:
+        A tensor representing the wrapped argument.
+
+      Raises:
+        ValueError: When indices are not consistent.
+ """ + + # Find the appropriate index + if tag is None: + if aggregate is not None: + raise ValueError("You must specify `tag` if using aggregate.") + global_index = self._get_new_global_index(index_override) + sort_index = None + else: + if aggregate is None: + raise ValueError("You must specify `aggregate` if using tag.") + if tag not in self._tag_to_global_index: + self._tag_to_global_index[tag] = ( + self._get_new_global_index(index_override)) + self._tag_to_next_sort_index[tag] = 0 + elif (index_override and + index_override != self._tag_to_global_index[tag]): + raise ValueError( + "Tag %r was called with two indices %r and %r" % + (tag, index_override, self._tag_to_global_index[tag])) + global_index = self._tag_to_global_index[tag] + sort_index = self._tag_to_next_sort_index[tag] + self._tag_to_next_sort_index[tag] += 1 + + uuid = self._unique_function_id + name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name, + uuid, global_index, sort_index, name) + identity_op = _array_ops.identity(arg, name=name) + + # pylint: disable=protected-access + identity_op.op._set_attr( + OpHint.FUNCTION_NAME_ATTR, + _attr_value_pb2.AttrValue( + s=_compat.as_bytes(self._function_name))) + identity_op.op._set_attr( + OpHint.FUNCTION_UUID_ATTR, + _attr_value_pb2.AttrValue( + s=_compat.as_bytes(self._unique_function_id))) + identity_op.op._set_attr( + self._attr_name, _attr_value_pb2.AttrValue(i=global_index)) + if sort_index is not None: + identity_op.op._set_attr( + OpHint.FUNCTION_SORT_INDEX_ATTR, + _attr_value_pb2.AttrValue(i=sort_index)) + if aggregate is not None: + identity_op.op._set_attr( + OpHint.FUNCTION_AGGREGATE_ATTR, + _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate)))) + # pylint: enable=protected-access + return identity_op def __init__(self, function_name, **kwargs): """Create a OpHint. @@ -112,10 +284,14 @@ class OpHint(object): """ self._function_name = function_name self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough? - self._curr_input_index = 0 - self._curr_output_index = 0 self._attrs_to_store_later = kwargs self._stored_attrs = False + self._inputs = OpHint.OpHintArgumentTracker( + self._function_name, self._unique_function_id, "InputHint", + OpHint.FUNCTION_INPUT_INDEX_ATTR) + self._outputs = OpHint.OpHintArgumentTracker( + self._function_name, self._unique_function_id, "OutputHint", + OpHint.FUNCTION_OUTPUT_INDEX_ATTR) def _setattr(self, dest_op, name, value): tensor_value = _ops.convert_to_tensor(value) @@ -124,68 +300,278 @@ class OpHint(object): tensor=tensor_value.op.node_def.attr["value"].tensor)) # pylint: enable=protected-access - def add_inputs(self, *args): + def add_input(self, *args, **kwargs): + """Add a wrapped input argument to the hint. + + Args: + *args: The input tensor. + **kwargs: + "name" label + "tag" a tag to group multiple arguments that will be aggregated. I.e. + a string like 'cool_input'. Basically multiple inputs can be added + to the same hint for parallel operations that will eventually be + combined. An example would be static_rnn which creates multiple copies + of state or inputs. + "aggregate" aggregation strategy that is valid only for tag non None. + Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, + and OpHint.AGGREGATE_STACK. + "index_override" The global index to use. This corresponds to the + argument order in the final stub that will be generated. + Returns: + The wrapped input tensor. 
+ """ + return self._inputs.add(*args, **kwargs) + + def add_output(self, *args, **kwargs): + """Add a wrapped output argument to the hint. + + Args: + *args: The output tensor. + **kwargs: + "name" label + "tag" a tag to group multiple arguments that will be aggregated. I.e. + a string like 'cool_input'. Basically multiple inputs can be added + to the same hint for parallel operations that will eventually be + combined. An example would be static_rnn which creates multiple copies + of state or inputs. + "aggregate" aggregation strategy that is valid only for tag non None. + Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, + and OpHint.AGGREGATE_STACK. + "index_override" The global index to use. This corresponds to the + argument order in the final stub that will be generated. + Returns: + The wrapped output tensor. + """ + return self._outputs.add(*args, **kwargs) + + def add_inputs(self, *args, **kwargs): """Add a sequence of inputs to the function invocation. Args: *args: List of inputs to be converted (should be Tf.Tensor). + **kwargs: This allows 'names' which should be a list of names. Returns: Wrapped inputs (identity standins that have additional metadata). These are also are also tf.Tensor's. """ - - def augmented_identity(arg): - identity_op = _array_ops.identity(arg) - # pylint: disable=protected-access - identity_op.op._set_attr( - OpHint.FUNCTION_NAME_ATTR, - _attr_value_pb2.AttrValue(s=self._function_name)) - identity_op.op._set_attr( - OpHint.FUNCTION_UUID_ATTR, - _attr_value_pb2.AttrValue(s=self._unique_function_id)) - identity_op.op._set_attr( - OpHint.FUNCTION_INPUT_INDEX_ATTR, - _attr_value_pb2.AttrValue(i=self._curr_input_index)) - # pylint: enable=protected-access - self._curr_input_index += 1 - return identity_op - - return [augmented_identity(arg) for arg in args] - - def add_outputs(self, *args): + if "names" in kwargs: + return [ + self._inputs.add(arg, name=name) + for arg, name in zip(args, kwargs["names"]) + ] + else: + return [self._inputs.add(arg) for arg in args] + + def add_outputs(self, *args, **kwargs): """Add a sequence of outputs to the function invocation. Args: *args: List of outputs to be converted (should be tf.Tensor). + **kwargs: See Returns: Wrapped outputs (identity standins that have additional metadata). These are also tf.Tensor's. """ + if "names" in kwargs: + return [ + self._outputs.add(arg, name=name) + for arg, name in zip(args, kwargs["names"]) + ] + else: + return [self._outputs.add(arg) for arg in args] + + +class _LiteOperand(object): + """Abstract operand for a tflite hint function. + + This is a base class that handles representing arguments to an OpHint. + It also is able to serialize operands to the stubbed graph_def. + Child classes are responsible for being able to + store information about the hint identity operators. They are also responsible + for knowing how to serialize to output graphdefs. + + Typically this will be implemented by holding one or more identity nodes + that were previously discovered as hints. + """ + + def aggregate_and_return_name_for_input(self, out_graphdef): + """This adds the node(s) to out_graphdef and returns the input node name. + + Args: + out_graphdef: A graphdef that is ready to have this input added. + + Returns: + The the output that the stub should use as an input for this operand. + + Raises: + RuntimeError: if the method is not implemented. 
+ """ + del out_graphdef + raise RuntimeError("Unimplemented abstract method.") + + def aggregate_and_return_name_for_output(self, fused_op_name, output_index, + out_graphdef): + """Add node(s) to graph representing output operands and returns type. + + Args: + fused_op_name: name of the fused op stub name. + output_index: Output index that we are currently processing from stub. + out_graphdef: The destination graphdef we are currently building up. + + Returns: + The datatype of this identity. + + Raises: + RuntimeError: if the method is not implemented. + """ + del fused_op_name, output_index, out_graphdef + raise RuntimeError("Unimplemented abstract method.") - def augmented_identity(arg): - identity_op = _array_ops.identity(arg) - # pylint: disable=protected-access - identity_op.op._set_attr( - OpHint.FUNCTION_NAME_ATTR, - _attr_value_pb2.AttrValue(s=self._function_name)) - identity_op.op._set_attr( - OpHint.FUNCTION_UUID_ATTR, - _attr_value_pb2.AttrValue(s=self._unique_function_id)) - identity_op.op._set_attr( - OpHint.FUNCTION_OUTPUT_INDEX_ATTR, - _attr_value_pb2.AttrValue(i=self._curr_output_index)) - # pylint: enable=protected-access - self._curr_output_index += 1 - return identity_op - wrapped_outputs = [augmented_identity(arg) for arg in args] +class _LiteSingleOperand(_LiteOperand): + """A simple operand that is non-aggregated (i.e. most hints).""" - if not self._stored_attrs: - for key, value in self._attrs_to_store_later.iteritems(): - self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value) - self._stored_attrs = True + def __init__(self, node): + _LiteOperand.__init__(self) + self.node = node + self.name = _tensor_name_base(node.name) - return wrapped_outputs + def flatten(self): + return [self.name] + + def aggregate_and_return_name_for_input(self, out_graphdef): + return self.name + + def aggregate_and_return_name_for_output(self, fused_op_name, index, + out_graphdef): + output_node = _copy.deepcopy(self.node) + del output_node.input[:] + output_node.input.append(_tensorflow_output_name(fused_op_name, index)) + out_graphdef.node.extend([output_node]) + return self.node.attr["type"].i + + def __str__(self): + return str(self.name) + + +class _LiteAggregateOperand(_LiteOperand): + """An operand for a tflite hint function that is aggregated from many. + + For example, an LSTM is a grid of operators that are all related. Inputs + going into them may need to be fused, so they should all be tracked as + related arguments. 
+ """ + + def __init__(self, aggregation): + _LiteOperand.__init__(self) + self.aggregation = aggregation + self.names = {} + self.nodes = {} + self.flattened = None + + def add(self, sort, node): + self.names[sort] = _tensor_name_base(node.name) + self.nodes[sort] = node + + def flatten_nodes(self): + """Return a list of all the node protos in aggregation sorted order.""" + if not self.flattened: + self.flattened = [None] * len(self.nodes) + for idx, node in _six.iteritems(self.nodes): + self.flattened[idx] = node + for n in self.nodes: + if n is None: + raise RuntimeError("Aggregate was missing argument.") + if self.aggregation == OpHint.AGGREGATE_FIRST: + self.flattened = self.flattened[:1] + elif self.aggregation == OpHint.AGGREGATE_LAST: + self.flattened = self.flattened[-1:] + elif self.aggregation == OpHint.AGGREGATE_STACK: + pass + else: + raise ValueError( + "Invalid aggregation type %r specified" % self.aggregation) + return self.flattened + + def flatten(self): + """Return a list of all node names in aggregation sorted sorter.""" + return [_tensor_name_base(x.name) for x in self.flatten_nodes()] + + def aggregate_and_return_name_for_input(self, out_graphdef): + """This adds the nodes to out_graphdef and returns an aggregated output. + + In particular, if you have 4 inputs to a hint stub, this will be the + node that you can use as an output. I.e. you have 4 timesteps from a + static rnn, then a fused UnidriecitonalLSTM will expect 1 input with + all 4 time steps. So here we make a pack and return the output name of + that pack. + + Args: + out_graphdef: A graphdef that is ready to have this input added. + + Returns: + The name of a pack that aggregates this node. + """ + flattened = self.flatten_nodes() + if len(flattened) == 1: + return _tensor_name_base(flattened[0].name) + else: + new_node = _node_def_pb2.NodeDef() + new_node.op = "Pack" + new_node.name = "OpHintStack-%s" % flattened[0].name + new_node.attr["N"].i = len(flattened) + new_node.attr["T"].type = flattened[0].attr["T"].type + for discrete in flattened: + new_node.input.append(_tensor_name_base(discrete.name)) + out_graphdef.node.extend([new_node]) + return new_node.name + + def aggregate_and_return_name_for_output(self, fused_op_name, output_index, + out_graphdef): + """This adds to `out_graphdef` all the unaggregated outputs. + + I.e. we are outputting from a fused stub, but we need to make it compatible + with the unfused original graph so we insert an unpack. Ideally in a later + stage the unpack -> pack sequences will be removed. + + Args: + fused_op_name: The name of the stub we are in the process of fusing. + output_index: The output output_index this object represents. + out_graphdef: The graphdef we are in the process of buildings + + Returns: + The type of the aggregated output (so we can finish building the stub + op). 
+ """ + flattened = self.flatten_nodes() + if len(flattened) == 1: + temp_op = _LiteSingleOperand(flattened[0]) + return temp_op.aggregate_and_return_name_for_output( + fused_op_name, output_index, out_graphdef) + else: + stack_node = _node_def_pb2.NodeDef() + stack_node.op = "Unpack" + stack_node.name = "OpHintUnstack-%s" % flattened[0].name + stack_node.attr["num"].i = len(flattened) + output_type = flattened[0].attr["T"].type + stack_node.attr["T"].type = output_type + stack_node.input.append(_tensorflow_output_name( + fused_op_name, output_index)) + out_graphdef.node.extend([stack_node]) + + for idx, discrete in enumerate(flattened): + output_node = _copy.deepcopy(discrete) + del output_node.input[:] + output_node.input.append(_tensorflow_output_name(stack_node.name, idx)) + out_graphdef.node.extend([output_node]) + + return output_type + + def __str__(self): + s = "\t\t\tAGGREGATE %s\n" % self.aggregation + for sort, val in self.names.iteritems(): + s += "\t\t\t%d: %s\n" % (sort, val) + return s class _LiteFuncCall(object): @@ -212,46 +598,87 @@ class _LiteFuncCall(object): self.uuid = None self.params = {} + def flattened_inputs_and_outputs(self): + """Return a list of inputs and outputs in a flattened format. + + Returns: + Tuple of (inputs, outputs). where input and output i a list of names. + """ + def _flatten(input_or_output_dict): + flattened_items = [] + for item in input_or_output_dict.values(): + flattened_items.extend(item.flatten()) + return flattened_items + + return _flatten(self.inputs), _flatten(self.outputs) + def __str__(self): - return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % ( - self.function_name, self.uuid, self.inputs, self.outputs) + def format_args(items): + s = "" + for idx, item in items.iteritems(): + s += ("\t\t%d:\n" % idx) + str(item) + return s + + inputs_str = "\tInputs\n" + format_args(self.inputs) + outputs_str = "\tOutputs\n" + format_args(self.outputs) + return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" + % (self.function_name, self.uuid, inputs_str, outputs_str)) -def _find_all_hints_in_graph_def(session): + +def _find_all_hints_in_graph_def(graphdef): """Look at the current default graph and return a list of LiteFuncCall objs. Args: - session: A TensorFlow session that contains the graph to convert. + graphdef: A TensorFlow graph_def to look for LiteFuncCalls. Returns: a list of `LifeFuncCall` objects in the form """ func_calls = _collections.defaultdict(_LiteFuncCall) - seen_ops = set() - - for op in session.graph.get_operations(): - for operand in _itertools.chain(op.inputs, op.outputs): - if operand in seen_ops: - continue - seen_ops.add(operand) - attr = operand.op.node_def.attr - uuid = attr[OpHint.FUNCTION_UUID_ATTR].s - if OpHint.FUNCTION_UUID_ATTR not in attr: - continue - call_def = func_calls[uuid] - call_def.uuid = uuid - if OpHint.FUNCTION_UUID_ATTR in attr: - call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s - if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: - call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand - if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr: - call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand - - for a in attr: - if a.startswith("_tflite_attr_"): - # TODO(aselle): Remember the attribute tensors so we can put them - # in collapse. 
-          call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
+
+  for node in graphdef.node:
+    attr = node.attr
+    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
+    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
+    if (OpHint.FUNCTION_UUID_ATTR not in attr
+        or not attr[OpHint.FUNCTION_UUID_ATTR].s):
+      continue
+
+    # Start building function
+    call_def = func_calls[uuid]
+    call_def.uuid = uuid
+    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
+    # Get sorting and aggregation information
+
+    sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
+            if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
+    if sort == -1:
+      sort = None
+    aggregation = None
+    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
+      aggregation = attr[OpHint.FUNCTION_AGGREGATE_ATTR].s
+
+    # Add the input or output
+    def put_operand(stuff, index, sort, operand, aggregation):
+      """Add a given index into the function structure."""
+      if sort is None:
+        stuff[index] = _LiteSingleOperand(operand)
+      else:
+        if index not in stuff:
+          stuff[index] = _LiteAggregateOperand(aggregation)
+        stuff[index].add(sort, operand)
+
+    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
+      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
+                  sort, node, aggregation)
+    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
+      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
+                  sort, node, aggregation)
+
+    # Remember attributes
+    for a in attr:
+      if a.startswith("_tflite_attr_"):
+        call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor

   return func_calls

@@ -267,42 +694,305 @@ def _tensor_name_base(full_tensor_name):
   Returns:
     A name without any device assignment.
   """
-  return full_tensor_name.name.split(":")[0]
+  if full_tensor_name.startswith("^"):
+    return full_tensor_name[1:]
+  return full_tensor_name.split(":")[0]
+
+
+def _tensorflow_output_name(tensor_name, output_index):
+  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
+                                                          output_index)
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+                           name_to_input_name):
+  """Checks to make sure node only connects to predecessor graph through inputs.
+
+  Args:
+    n: Node to check
+    reachable_by_input: Nodes that are reachable by all inputs of subgraph
+    input_nodes_set: The set of nodes that are "inputs".
+    name_to_input_name: Maps from name to the list of inputs.
+
+  Raises:
+    TypeError: If the given node uses items past inputs directly.
+  """
+  next_to_visit = [n]
+  visited = set()
+  while next_to_visit:
+    current_node = next_to_visit.pop()
+    visited.add(current_node)
+    if (current_node in reachable_by_input
+        and current_node not in input_nodes_set):
+      raise TypeError(
+          "Node %s uses input %s not in input_nodes." % (n, current_node))
+    if current_node not in input_nodes_set:
+      next_to_visit += [
+          input_node for input_node in name_to_input_name[current_node]
+          if input_node not in visited
+      ]
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _convert_single_op_hint_to_stub(call, graph_def):
+  """Given a graph_def, converts `call` into a stub and returns a new graph_def.
+
+  Args:
+    call: A single function call to be converted.
+    graph_def: A graph_def to use as input (one that obviously contains `call`).
+
+  Returns:
+    A new transformed graph-def that has call as a stub (single op).

-def convert_op_hints_to_stubs(session):
+  Note: after this process, the graph_def can no longer be loaded into
+    the tensorflow runtime, so all future manipulations are done in graph_def
+    level.
+  """
+  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+      graph_def)
+  input_names, output_names = call.flattened_inputs_and_outputs()
+
+  reachable_by_input = _bfs_for_reachable_nodes(input_names,
+                                                name_to_input_name)
+  reachable_by_output = _bfs_for_reachable_nodes(output_names,
+                                                 name_to_input_name)
+  input_nodes_set = set(input_names)
+  output_nodes_set = set(output_names)
+  nodes_after_fuse = []
+  nodes_deleted_by_fuse = set()
+  # Classify each node. We want to keep everything reachable by input;
+  # nodes reachable by output but not by input are internal to the fuse and
+  # get deleted; nodes reachable by neither come after the fusing and are kept.
+  for node in graph_def.node:
+    n = _tensor_name_base(node.name)
+    if n in reachable_by_output:
+      if n not in reachable_by_input and n not in output_nodes_set:
+        # n is an internal node. Check to make sure it is really internal.
+        # TODO(aselle): this could be done more efficiently by flooding
+        # the graph first.
+        _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+                               name_to_input_name)
+        nodes_deleted_by_fuse.add(n)
+    elif n not in reachable_by_input:
+      # n is a node that comes after all the fusings, so keep it.
+      nodes_after_fuse.append(n)
+    else:
+      # n is a node that is randomly in the graph but not connected to
+      # the chain of dependencies.
+      pass
+
+  # Make a new graphdef with all the pre-input and input nodes
+  out = _graph_pb2.GraphDef()
+  reachable_by_input_sorted = sorted(
+      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
+  for node in reachable_by_input_sorted:
+    out.node.extend([_copy.deepcopy(name_to_node[node])])
+
+  # Create any stacks to aggregate arguments into a single input
+  # i.e. for static_rnn's.
+  # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
+  sorted_input_indices = list(call.inputs.keys())
+  sorted_input_indices.sort()
+  sorted_output_indices = list(call.outputs.keys())
+  sorted_output_indices.sort()
+  new_node = _node_def_pb2.NodeDef()
+  # Delegate to each operand to produce the proper new input for this stub
+  # node. In particular, an aggregate input will now be a Pack of some
+  # previously non-fused things.
+  for input_index in sorted_input_indices:
+    inputs = call.inputs[input_index]
+    new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
+  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(
+      sorted_input_indices)
+
+  # Create the function
+  new_node.op = call.function_name
+  new_node.name = call.uuid
+  out.node.extend([new_node])
+
+  # Now call each output argument to give them a chance to make the proper
+  # output type and add it to our new_node.
+  output_dtypes = []
+  for output_index in sorted_output_indices:
+    output = call.outputs[output_index]
+    output_dtype = (
+        output.aggregate_and_return_name_for_output(new_node.name,
+                                                    output_index, out))
+    output_dtypes.append(output_dtype)
+  new_node.attr["_output_types"].list.type[:] = output_dtypes
+  # TODO(aselle): what is right here?
+  new_node.attr["_output_quantized"].b = False
+
+  # Add post output nodes that do not depend on the outputs
+  for n in nodes_after_fuse:
+    should_keep = True
+    for input_name in name_to_input_name[n]:
+      if input_name in nodes_deleted_by_fuse:
+        should_keep = False
+    if should_keep:
+      out.node.extend([_copy.deepcopy(name_to_node[n])])
+
+  # Misc. graph_def data that needs copying.
+  out.library.CopyFrom(graph_def.library)
+  out.versions.CopyFrom(graph_def.versions)
+
+  return out
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _remove_one_redundant_stack_unstack(in_graph_def):
+  """Removes a stack->unstack pattern from in_graph_def in a returned graph.
+
+  Args:
+    in_graph_def: Graph def to use as input.
+
+  Returns:
+    Simplified tuple (graph_def, changed_something) where changed_something
+    is true if anything was done.
+  """
+  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+      in_graph_def)
+  del name_to_seq_num
+
+  # TODO(aselle): Make this not hardcoded.
+  do_generic_pack_unpack = True
+
+  out = _graph_pb2.GraphDef()
+  out.library.CopyFrom(in_graph_def.library)
+  out.versions.CopyFrom(in_graph_def.versions)
+  for n in in_graph_def.node:
+    node_name = _tensor_name_base(n.name)
+    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
+      continue
+    next_to_visit = [node_name]
+    visited = set()
+
+    unpack_nodes = set()
+    pack_node = node_name
+
+    # Find a pattern of unstack connected to a stack (with identities
+    # in between).
+    matches_pattern = True
+    is_hint_created_stack = False
+    while next_to_visit:
+      current_node_name = next_to_visit[0]
+      visited.add(current_node_name)
+      del next_to_visit[0]
+      node = name_to_node[current_node_name]
+      is_op_hint_stack = node.name.startswith("OpHintStack")
+      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
+      if (node.op == "Identity" or is_op_hint_stack
+          or (do_generic_pack_unpack and node.op == "Pack")):
+        is_hint_created_stack |= is_op_hint_stack
+        next_to_visit += [
+            input_node for input_node in name_to_input_name[current_node_name]
+            if input_node not in visited
+        ]
+      elif (is_op_hint_unstack
+            or (do_generic_pack_unpack and node.op == "Unpack")):
+        unpack_nodes.add(node.name)
+        is_hint_created_stack &= is_op_hint_unstack
+      else:
+        matches_pattern = False
+        break
+      visited.add(node.name)
+
+    if matches_pattern and len(unpack_nodes) == 1:
+      pack_node = node_name
+
+      # Check to see if anyone depends on the intermediate identity or the
+      # Unstacked form
+      no_external_dependency = True
+      for other_n in in_graph_def.node:
+        if other_n.name in visited:
+          continue
+        for input_tensor in name_to_input_name[other_n.name]:
+          input_op = _tensor_name_base(input_tensor)
+          if input_op in visited and input_op != pack_node:
+            no_external_dependency = False
+      # Proceed with the substitution if the stack/unstack pair was created
+      # through hints, or if it was not but nobody consumes anything
+      # between the stack and unstack.
+      if is_hint_created_stack or no_external_dependency:
+        end = unpack_nodes.pop()
+        end_input = name_to_node[end].input[0]
+        # All nodes that depend on the final stack need to be redone to use
+        # the unstack's input instead.
+        for other_n in in_graph_def.node:
+          node_name = _tensor_name_base(other_n.name)
+          if node_name not in visited:
+            new_node = _copy.deepcopy(other_n)
+            new_node.input[:] = [
+                (end_input if stripped == pack_node else
+                 non_stripped) for stripped, non_stripped in zip(
+                     name_to_input_name[node_name], new_node.input[:])
+            ]
+            out.node.extend([new_node])
+        return out, True
+  return in_graph_def, False
+
+
+def _remove_redundant_stack_unstack(graph_def):
+  curr = graph_def
+  del graph_def
+  changed_stuff = True
+  while changed_stuff:
+    curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
+  return curr
+
+
+def _convert_op_hints_to_stubs_helper(
+    graph_def, write_callback=lambda graph_def, comments: None):
+  """Converts a graph_def to a new graph_def where all op hints are stubbed.
+
+  Args:
+    graph_def: A graph def that we should convert.
+    write_callback: A function pointer that can be used to write intermediate
+      steps of graph transformation (optional).
+
+  Returns:
+    A new stubbed graph_def.
+  """
+
+  hints = _find_all_hints_in_graph_def(graph_def)
+  curr_graph_def = graph_def
+  del graph_def  # prevent using graph_def again (common source of error)
+  for hint in _six.itervalues(hints):
+    curr_graph_def = _convert_single_op_hint_to_stub(
+        hint, curr_graph_def)
+    write_callback(curr_graph_def, "initial")
+  # The stubbing process can create stacks/unstacks in the case of LSTMs;
+  # remove them.
+  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
+  return curr_graph_def
+
+
+def convert_op_hints_to_stubs(session=None,
+                              graph_def=None,
+                              write_callback=lambda graph_def, comments: None):
   """Converts a graphdef with LiteOp hints into stub operations.

   This is used to prepare for toco conversion of complex intrinsic usages.
+  Note: only one of session or graph_def should be used, not both.

   Args:
     session: A TensorFlow session that contains the graph to convert.
+    graph_def: A graph def that we should convert.
+    write_callback: A function pointer that can be used to write intermediate
+      steps of graph transformation (optional).

   Returns:
     A new graphdef with all ops contained in OpHints being replaced by
     a single op call with the right parameters.
+
+  Raises:
+    ValueError: If both session and graph_def are provided.
""" - hints = _find_all_hints_in_graph_def(session) - current_graph_def = session.graph_def - for call in hints.values(): - input_names = [None] * len(call.inputs) - output_names = [None] * len(call.outputs) - output_dtypes = [None] * len(call.outputs) - output_quantized = False - for input_index, tensor in call.inputs.items(): - input_names[input_index] = _tensor_name_base(tensor) - for output_index, tensor in call.outputs.items(): - output_names[output_index] = _tensor_name_base(tensor) - output_dtypes[output_index] = tensor.dtype.as_datatype_enum - # TODO(aselle): Support quantized flag properly - current_graph_def = _framework.fuse_op( - current_graph_def, input_names, output_names, output_dtypes, - output_quantized, call.uuid, call.function_name) - for node in current_graph_def.node: - if node.name == call.uuid: - for param, tensor in call.params.items(): - node.attr[param].tensor.CopyFrom(tensor) - return current_graph_def - - -_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"] + + if session is not None and graph_def is not None: + raise ValueError("Provide only one of session and graph_def.") + + if session is not None: + return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback) + elif graph_def is not None: + return _convert_op_hints_to_stubs_helper(graph_def, write_callback) + else: + raise ValueError("Must specify session or graph_def as input.") + + +_allowed_symbols = [ + "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new" +] remove_undocumented(__name__, _allowed_symbols) |