| author | Andrew Selle <aselle@google.com> | 2018-08-15 13:01:23 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 13:06:01 -0700 |
| commit | ec5f4771e42972c31faaa39354d785891de9f91d (patch) | |
| tree | 92f7d09a66c64f94114a5495cd95c984fb73d3d1 | |
| parent | edf1f507f1874131ad453e03a5e0bffb9a189d73 (diff) | |
Automated rollback of commit e045fd284075001e7f7585c581c4444e4e850534
PiperOrigin-RevId: 208868027
| -rw-r--r-- | tensorflow/contrib/lite/python/BUILD | 2 |
| -rw-r--r-- | tensorflow/contrib/lite/python/convert_test.py | 93 |
| -rw-r--r-- | tensorflow/contrib/lite/python/op_hint.py | 894 |

3 files changed, 111 insertions, 878 deletions
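For orientation before the diff: this rollback returns `convert_op_hints_to_stubs` to a session-based signature and `add_inputs`/`add_outputs` to plain positional forms. A minimal sketch of the restored workflow, adapted from the `op_hint.py` module docstring shown below (TF 1.x contrib API; the single-element unpacking and the tensor names are illustrative, not part of the commit):

```python
import tensorflow as tf  # TF 1.x; OpHint is also exposed as tf.contrib.lite.OpHint
from tensorflow.contrib.lite.python import op_hint

def tflite_cool_activation(x):
  # Wrap a subgraph as one "cool_activation" function for later fusing by toco.
  custom = op_hint.OpHint("cool_activation")
  x, = custom.add_inputs(x)   # Identity stand-in tagged with _tflite_function_input_index
  y = tf.sigmoid(x) * x
  y, = custom.add_outputs(y)  # Identity stand-in tagged with _tflite_function_output_index
  return y

image = tf.placeholder(tf.float32, (1, 16, 16, 1))
output = tf.identity(tflite_cool_activation(image), name="output")

with tf.Session() as sess:
  # After this rollback the converter takes a live Session again,
  # not the graph_def= keyword that the rolled-back revision introduced.
  stubbed_graph_def = op_hint.convert_op_hints_to_stubs(sess)
```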
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 2dd5522300..860aff9e7e 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -112,9 +112,7 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         "//tensorflow/contrib/framework:framework_py",
-        "//tensorflow/contrib/graph_editor:graph_editor_py",
         "//tensorflow/core:protos_all_py",
-        "//tensorflow/python:framework",
         "//tensorflow/python:platform",
     ],
 )
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index ae578378f9..dc21a9b669 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -113,13 +113,12 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
       # and 1 final output).
       self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
 
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
+      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
 
       self.assertCountEqual(
           self._getGraphOpTypes(
              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
+              output_nodes=[op_hint._tensor_name_base(output)]),
          ["cool_activation", "Const", "Identity"])
 
   def testScaleAndBiasAndIdentity(self):
@@ -140,13 +139,12 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
       # +1 for the final output
       self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
 
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
+      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
 
       self.assertCountEqual(
           self._getGraphOpTypes(
              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
+              output_nodes=[op_hint._tensor_name_base(output)]),
          ["scale_and_bias_and_identity", "Const", "Identity", "Pack"])
 
   def testTwoFunctions(self):
@@ -155,7 +153,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
     b = array_ops.constant([1.])
     def _double_values(x):
       custom = op_hint.OpHint("add_test")
-      x, = custom.add_inputs(x)
+      x = custom.add_inputs(x)
       output = math_ops.multiply(x, x)
       output, = custom.add_outputs(output)
       return output
@@ -166,90 +164,13 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
       # make sure one identity for each input (2) and output (2) => 2 + 2
       # +1 for the final output
       self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
+      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
       self.assertCountEqual(
           self._getGraphOpTypes(
              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
+              output_nodes=[op_hint._tensor_name_base(output)]),
          ["add_test", "Const", "Identity", "Add"])
 
-  def _get_input_index(self, x):
-    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i
-
-  def _get_output_index(self, x):
-    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i
-
-  def _get_sort_index(self, x):
-    return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i
-
-  def testTags(self):
-    """Test if multiple args with the same tag are grouped."""
-    a = array_ops.constant([1.])
-    b = array_ops.constant([2.])
-    c = array_ops.constant([3.])
-    d = array_ops.constant([4.])
-    custom = op_hint.OpHint("test_tag")
-    a = custom.add_input(a, tag="mytag",
-                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    b, = custom.add_inputs(b)
-    c = custom.add_input(c, tag="mytag",
-                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    d = custom.add_input(d, tag="mytag2",
-                         aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
-    custom.add_outputs([res])
-    with self.test_session():
-      self.assertEqual(self._get_input_index(a), 0L)
-      self.assertEqual(self._get_sort_index(a), 0L)
-      self.assertEqual(self._get_input_index(b), 1L)
-      self.assertEqual(self._get_input_index(c), 0L)
-      self.assertEqual(self._get_sort_index(c), 1L)
-
-  def testOverrideIndex(self):
-    a = array_ops.constant([1.])
-    b = array_ops.constant([2.])
-    c = array_ops.constant([3.])
-    custom = op_hint.OpHint("test_override")
-    b = custom.add_input(b)  # should auto assign 0
-    a = custom.add_input(a, index_override=1)
-    c = custom.add_input(c)  # should auto assign 2
-    with self.test_session():
-      self.assertEqual(self._get_input_index(a), 1L)
-      self.assertEqual(self._get_input_index(b), 0L)
-      self.assertEqual(self._get_input_index(c), 2L)
-
-  def testAggregate(self):
-    a = array_ops.constant([3., 4.])
-    b = array_ops.constant([5., 6.])
-    hint = op_hint.OpHint("agg")
-    a0, a1 = array_ops.unstack(a)
-    b0, b1 = array_ops.unstack(b)
-
-    a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-
-    c0 = math_ops.add(a0, b0, name="addleft")
-    c1 = math_ops.add(a1, b1, name="addright")
-    c0 = hint.add_output(
-        c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-    c1 = hint.add_output(
-        c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
-
-    curr = array_ops.stack([c0, c1])
-    output = array_ops.identity(curr, name="FINAL_OUTPUT")
-    with self.test_session() as sess:
-      stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
-          graph_def=sess.graph_def)
-      print(stubbed_graphdef)
-      self.assertCountEqual(
-          self._getGraphOpTypes(
-              stubbed_graphdef,
-              output_nodes=[op_hint._tensor_name_base(output.name)]),
-          ["agg", "Const", "Identity"])
-
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py
index e439442f21..7908689ce4 100644
--- a/tensorflow/contrib/lite/python/op_hint.py
+++ b/tensorflow/contrib/lite/python/op_hint.py
@@ -25,9 +25,9 @@ Example:
   def tflite_cool_activation(input):
     # A cool activation function.
     custom = tf.contrib.lite.OpHint("cool_activation")
-    input, = custom.add_inputs(input)
+    input = custom.add_inputs(input)
     output = tf.sigmoid(input) * input
-    output, = custom.add_outputs(output)
+    custom.add_outputs(output)
     return output
 
   image = tf.placeholder(tf.float32, (1, 16, 16, 1))
@@ -64,24 +64,17 @@ ops don't actually exist in the normal TensorFlow runtime, but will be
 understood by toco later.
 """
 
-# TODO(aselle): Make this use generic graph transformations.
-# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
-
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import collections as _collections
-import copy as _copy
+import itertools as _itertools
 import uuid as _uuid
 
+from tensorflow.contrib import framework as _framework
 from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
-from tensorflow.core.framework import graph_pb2 as _graph_pb2
-from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
 from tensorflow.python.framework import ops as _ops
-# TODO(aselle): publicize these apis if we continue to use these.
-from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
-from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
 from tensorflow.python.ops import array_ops as _array_ops
 from tensorflow.python.util.all_util import remove_undocumented
@@ -104,172 +97,11 @@ class OpHint(object):
   constructs, this mechanism can be retired and changed to use python defun's.
   """
 
-  # Attr constants that are used for representation in the GraphDef. These
-  # will be used on every Identity op that is involved in a total OpHint.
-
-  # Name of the OpHint function (cosmetic).
+  # Attr constants that are used for representation in the GraphDef
   FUNCTION_NAME_ATTR = "_tflite_function_name"
-  # UUID of the function (each OpHint gets a new uuid).
   FUNCTION_UUID_ATTR = "_tflite_function_uuid"
-  # The index index of the input (or nothing if it is an output).
   FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
-  # The output index of the output (or nothing if it is an input).
   FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
-  # An index that orders aggregate arguments. Aggregate arguments are ones
-  # that are separate but will be fused horizontally. For example a static LSTM
-  # has a lstm cell for each time step. Each one has a separate opHint, but a
-  # fused SequentialLSTM will treat this as a single tensor.
-  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
-  # The way in which multiple parts of the aggregate argument will be joined
-  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
-  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
-  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
-  # On fused OpHint stub, the order of inputs that the final LSTM call will
-  # have. What this means is that the TensorFlow order might be
-  # "foo", "bar", "stuff" and you might want the TF lite op order to be
-  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
-  # attribute to [2, 0, 1, -1].
-  TFLITE_INPUT_INDICES = "_tflite_input_indices"
-
-  # Types of aggregations
-  # stack: stacks all ophints with matching tags. i.e. for a static rnn.
-  #   specifically, this is good for an input or output to a static rnn cell.
-  AGGREGATE_STACK = "stack"
-  # first: only takes the first output (one with lowest sort index)
-  #   of matching tags. This is good for the input state to an RNN.
-  AGGREGATE_FIRST = "first"
-  # aggregation last takes only the last tag (one with highest sort index).
-  #   This is good for an output value on the last stack item of a
-  #   static rnn.
-  AGGREGATE_LAST = "last"
-
-  class OpHintArgumentTracker(object):
-    """Conceptually tracks indices of arguments of "OpHint functions".
-
-    The inputs and arguments of these functions both use an instance
-    of the class so they can have independent numbering."""
-
-    def __init__(self, function_name, unique_function_id, node_name_prefix,
-                 attr_name):
-      """Initialize ophint argument.
-
-      Args:
-        function_name: Name of the function that this tracks arguments for.
-        unique_function_id: UUID of function that this tracks arguments for.
-        node_name_prefix: How identities that are created are named.
-        attr_name: Name of attribute to use to store the index for this hint.
-          i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
-      """
-
-      # The global index is the argument index of the op. This is in contrast
-      # to the sort index which is the sequence number of a particular instance
-      # of a given global index. For example, you may have called add hint
-      # twice with the tag "foo". Then the global index will be 0 for both
-      # and the sort index will be 0 for the first added and 1 for the second.
-      self._function_name = function_name
-      self._unique_function_id = unique_function_id
-      self._next_global_index = 0  # The absolute global index
-      self._used_global_indices = set()
-      self._tag_to_global_index = {}  # The argument index a given tag maps to
-      self._tag_to_next_sort_index = {}  # The current index for each tag
-      self._node_name_prefix = node_name_prefix
-      self._attr_name = attr_name
-
-    def _get_new_global_index(self, index_override):
-      """Return the next unused argument index in order or use an override.
-
-      Args:
-        index_override: An index to use instead of the next available or None
-          to use the next available.
-
-      Returns:
-        A valid global_index to use for the next hint argument.
-
-      Raises:
-        ValueError: If the index_override is already used by another hint.
-      """
-      if index_override is None:
-        global_index = self._next_global_index
-      else:
-        if index_override in self._used_global_indices:
-          raise ValueError("Index %d was already used by another call to add")
-        global_index = index_override
-      # Make next_global_index valid
-      self._used_global_indices.add(global_index)
-      while self._next_global_index in self._used_global_indices:
-        self._next_global_index += 1
-      return global_index
-
-    def add(self, arg, tag=None, name=None, aggregate=None,
-            index_override=None):
-      """Return a wrapped tensor of an input tensor as an argument.
-
-      Args:
-        arg: A TensorFlow tensor that should be considered an argument.
-        tag: String tag to identify arguments that should be packed.
-        name: Name of argument. This is included in the Identity hint op names.
-        aggregate: Strategy to aggregate.
-          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
-          and OpHint.AGGREGATE_STACK.
-          Note, aggregate is only valid if tag is specified.
-        index_override: Specify what input/output index should this be in the
-          final stub. i.e. add(arg0, index=1); add(arg1, index=0) wil make the
-          final stub be as stub_func(inputs[arg1, arg0], outputs=[]) rather than
-          the default call order based ordering.
-
-      Returns:
-        A tensor representing the wrapped argument.
-
-      Raises:
-        ValueError: When indices are not consistent.
- """ - - # Find the appropriate index - if tag is None: - if aggregate is not None: - raise ValueError("You must specify `tag` if using aggregate.") - global_index = self._get_new_global_index(index_override) - sort_index = None - else: - if aggregate is None: - raise ValueError("You must specify `aggregate` if using tag.") - if tag not in self._tag_to_global_index: - self._tag_to_global_index[tag] = ( - self._get_new_global_index(index_override)) - self._tag_to_next_sort_index[tag] = 0 - elif (index_override and - index_override != self._tag_to_global_index[tag]): - raise ValueError( - "Tag %r was called with two indices %r and %r" % - (tag, index_override, self._tag_to_global_index[tag])) - global_index = self._tag_to_global_index[tag] - sort_index = self._tag_to_next_sort_index[tag] - self._tag_to_next_sort_index[tag] += 1 - - uuid = self._unique_function_id - name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name, - uuid, global_index, sort_index, name) - identity_op = _array_ops.identity(arg, name=name) - - # pylint: disable=protected-access - identity_op.op._set_attr( - OpHint.FUNCTION_NAME_ATTR, - _attr_value_pb2.AttrValue(s=self._function_name)) - identity_op.op._set_attr( - OpHint.FUNCTION_UUID_ATTR, - _attr_value_pb2.AttrValue(s=self._unique_function_id)) - identity_op.op._set_attr( - self._attr_name, _attr_value_pb2.AttrValue(i=global_index)) - if sort_index is not None: - identity_op.op._set_attr( - OpHint.FUNCTION_SORT_INDEX_ATTR, - _attr_value_pb2.AttrValue(i=sort_index)) - if aggregate is not None: - identity_op.op._set_attr( - OpHint.FUNCTION_AGGREGATE_ATTR, - _attr_value_pb2.AttrValue(s=str(aggregate))) - # pylint: enable=protected-access - return identity_op def __init__(self, function_name, **kwargs): """Create a OpHint. @@ -280,14 +112,10 @@ class OpHint(object): """ self._function_name = function_name self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough? + self._curr_input_index = 0 + self._curr_output_index = 0 self._attrs_to_store_later = kwargs self._stored_attrs = False - self._inputs = OpHint.OpHintArgumentTracker( - self._function_name, self._unique_function_id, "InputHint", - OpHint.FUNCTION_INPUT_INDEX_ATTR) - self._outputs = OpHint.OpHintArgumentTracker( - self._function_name, self._unique_function_id, "OutputHint", - OpHint.FUNCTION_OUTPUT_INDEX_ATTR) def _setattr(self, dest_op, name, value): tensor_value = _ops.convert_to_tensor(value) @@ -296,278 +124,68 @@ class OpHint(object): tensor=tensor_value.op.node_def.attr["value"].tensor)) # pylint: enable=protected-access - def add_input(self, *args, **kwargs): - """Add a wrapped input argument to the hint. - - Args: - *args: The input tensor. - **kwargs: - "name" label - "tag" a tag to group multiple arguments that will be aggregated. I.e. - a string like 'cool_input'. Basically multiple inputs can be added - to the same hint for parallel operations that will eventually be - combined. An example would be static_rnn which creates multiple copies - of state or inputs. - "aggregate" aggregation strategy that is valid only for tag non None. - Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, - and OpHint.AGGREGATE_STACK. - "index_override" The global index to use. This corresponds to the - argument order in the final stub that will be generated. - Returns: - The wrapped input tensor. - """ - return self._inputs.add(*args, **kwargs) - - def add_output(self, *args, **kwargs): - """Add a wrapped output argument to the hint. 
-
-    Args:
-      *args: The output tensor.
-      **kwargs:
-        "name" label
-        "tag" a tag to group multiple arguments that will be aggregated. I.e.
-          a string like 'cool_input'. Basically multiple inputs can be added
-          to the same hint for parallel operations that will eventually be
-          combined. An example would be static_rnn which creates multiple copies
-          of state or inputs.
-        "aggregate" aggregation strategy that is valid only for tag non None.
-          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
-          and OpHint.AGGREGATE_STACK.
-        "index_override" The global index to use. This corresponds to the
-          argument order in the final stub that will be generated.
-    Returns:
-      The wrapped output tensor.
-    """
-    return self._outputs.add(*args, **kwargs)
-
-  def add_inputs(self, *args, **kwargs):
+  def add_inputs(self, *args):
     """Add a sequence of inputs to the function invocation.
 
     Args:
       *args: List of inputs to be converted (should be Tf.Tensor).
-      **kwargs: This allows 'names' which should be a list of names.
     Returns:
       Wrapped inputs (identity standins that have additional metadata). These
       are also are also tf.Tensor's.
     """
-    if "names" in kwargs:
-      return [
-          self._inputs.add(arg, name=name)
-          for arg, name in zip(args, kwargs["names"])
-      ]
-    else:
-      return [self._inputs.add(arg) for arg in args]
-
-  def add_outputs(self, *args, **kwargs):
+
+    def augmented_identity(arg):
+      identity_op = _array_ops.identity(arg)
+      # pylint: disable=protected-access
+      identity_op.op._set_attr(
+          OpHint.FUNCTION_NAME_ATTR,
+          _attr_value_pb2.AttrValue(s=self._function_name))
+      identity_op.op._set_attr(
+          OpHint.FUNCTION_UUID_ATTR,
+          _attr_value_pb2.AttrValue(s=self._unique_function_id))
+      identity_op.op._set_attr(
+          OpHint.FUNCTION_INPUT_INDEX_ATTR,
+          _attr_value_pb2.AttrValue(i=self._curr_input_index))
+      # pylint: enable=protected-access
+      self._curr_input_index += 1
+      return identity_op
+
+    return [augmented_identity(arg) for arg in args]
+
+  def add_outputs(self, *args):
     """Add a sequence of outputs to the function invocation.
 
     Args:
       *args: List of outputs to be converted (should be tf.Tensor).
-      **kwargs: See
     Returns:
       Wrapped outputs (identity standins that have additional metadata). These
       are also tf.Tensor's.
     """
-    if "names" in kwargs:
-      return [
-          self._outputs.add(arg, name=name)
-          for arg, name in zip(args, kwargs["names"])
-      ]
-    else:
-      return [self._outputs.add(arg) for arg in args]
-
-
-class _LiteOperand(object):
-  """Abstract operand for a tflite hint function.
-
-  This is a base class that handles representing arguments to an OpHint.
-  It also is able to serialize operands to the stubbed graph_def.
-  Child classes are responsible for being able to
-  store information about the hint identity operators. They are also responsible
-  for knowing how to serialize to output graphdefs.
-
-  Typically this will be implemented by holding one or more identity nodes
-  that were previously discovered as hints.
-  """
-
-  def aggregate_and_return_name_for_input(self, out_graphdef):
-    """This adds the node(s) to out_graphdef and returns the input node name.
-
-    Args:
-      out_graphdef: A graphdef that is ready to have this input added.
-
-    Returns:
-      The the output that the stub should use as an input for this operand.
-
-    Raises:
-      RuntimeError: if the method is not implemented.
- """ - del out_graphdef - raise RuntimeError("Unimplemented abstract method.") - def aggregate_and_return_name_for_output(self, fused_op_name, output_index, - out_graphdef): - """Add node(s) to graph representing output operands and returns type. - - Args: - fused_op_name: name of the fused op stub name. - output_index: Output index that we are currently processing from stub. - out_graphdef: The destination graphdef we are currently building up. - - Returns: - The datatype of this identity. - - Raises: - RuntimeError: if the method is not implemented. - """ - del fused_op_name, output_index, out_graphdef - raise RuntimeError("Unimplemented abstract method.") - - -class _LiteSingleOperand(_LiteOperand): - """A simple operand that is non-aggregated (i.e. most hints).""" - - def __init__(self, node): - _LiteOperand.__init__(self) - self.node = node - self.name = _tensor_name_base(node.name) - - def flatten(self): - return [self.name] - - def aggregate_and_return_name_for_input(self, out_graphdef): - return self.name - - def aggregate_and_return_name_for_output(self, fused_op_name, index, - out_graphdef): - output_node = _copy.deepcopy(self.node) - del output_node.input[:] - output_node.input.append(_tensorflow_output_name(fused_op_name, index)) - out_graphdef.node.extend([output_node]) - return self.node.attr["type"].i - - def __str__(self): - return str(self.name) - - -class _LiteAggregateOperand(_LiteOperand): - """An operand for a tflite hint function that is aggregated from many. - - For example, an LSTM is a grid of operators that are all related. Inputs - going into them may need to be fused, so they should all be tracked as - related arguments. - """ - - def __init__(self, aggregation): - _LiteOperand.__init__(self) - self.aggregation = aggregation - self.names = {} - self.nodes = {} - self.flattened = None - - def add(self, sort, node): - self.names[sort] = _tensor_name_base(node.name) - self.nodes[sort] = node - - def flatten_nodes(self): - """Return a list of all the node protos in aggregation sorted order.""" - if not self.flattened: - self.flattened = [None] * len(self.nodes) - for idx, node in self.nodes.iteritems(): - self.flattened[idx] = node - for n in self.nodes: - if n is None: - raise RuntimeError("Aggregate was missing argument.") - if self.aggregation == OpHint.AGGREGATE_FIRST: - self.flattened = self.flattened[:1] - elif self.aggregation == OpHint.AGGREGATE_LAST: - self.flattened = self.flattened[-1:] - elif self.aggregation == OpHint.AGGREGATE_STACK: - pass - else: - raise ValueError( - "Invalid aggregation type %r specified" % self.aggregation) - return self.flattened - - def flatten(self): - """Return a list of all node names in aggregation sorted sorter.""" - return [_tensor_name_base(x.name) for x in self.flatten_nodes()] - - def aggregate_and_return_name_for_input(self, out_graphdef): - """This adds the nodes to out_graphdef and returns an aggregated output. - - In particular, if you have 4 inputs to a hint stub, this will be the - node that you can use as an output. I.e. you have 4 timesteps from a - static rnn, then a fused UnidriecitonalLSTM will expect 1 input with - all 4 time steps. So here we make a pack and return the output name of - that pack. - - Args: - out_graphdef: A graphdef that is ready to have this input added. - - Returns: - The name of a pack that aggregates this node. 
- """ - flattened = self.flatten_nodes() - if len(flattened) == 1: - return _tensor_name_base(flattened[0].name) - else: - new_node = _node_def_pb2.NodeDef() - new_node.op = "Pack" - new_node.name = "OpHintStack-%s" % flattened[0].name - new_node.attr["N"].i = len(flattened) - new_node.attr["T"].type = flattened[0].attr["T"].type - for discrete in flattened: - new_node.input.append(_tensor_name_base(discrete.name)) - out_graphdef.node.extend([new_node]) - return new_node.name - - def aggregate_and_return_name_for_output(self, fused_op_name, output_index, - out_graphdef): - """This adds to `out_graphdef` all the unaggregated outputs. - - I.e. we are outputting from a fused stub, but we need to make it compatible - with the unfused original graph so we insert an unpack. Ideally in a later - stage the unpack -> pack sequences will be removed. + def augmented_identity(arg): + identity_op = _array_ops.identity(arg) + # pylint: disable=protected-access + identity_op.op._set_attr( + OpHint.FUNCTION_NAME_ATTR, + _attr_value_pb2.AttrValue(s=self._function_name)) + identity_op.op._set_attr( + OpHint.FUNCTION_UUID_ATTR, + _attr_value_pb2.AttrValue(s=self._unique_function_id)) + identity_op.op._set_attr( + OpHint.FUNCTION_OUTPUT_INDEX_ATTR, + _attr_value_pb2.AttrValue(i=self._curr_output_index)) + # pylint: enable=protected-access + self._curr_output_index += 1 + return identity_op - Args: - fused_op_name: The name of the stub we are in the process of fusing. - output_index: The output output_index this object represents. - out_graphdef: The graphdef we are in the process of buildings + wrapped_outputs = [augmented_identity(arg) for arg in args] - Returns: - The type of the aggregated output (so we can finish building the stub - op). - """ - flattened = self.flatten_nodes() - if len(flattened) == 1: - temp_op = _LiteSingleOperand(flattened[0]) - return temp_op.aggregate_and_return_name_for_output( - fused_op_name, output_index, out_graphdef) - else: - stack_node = _node_def_pb2.NodeDef() - stack_node.op = "Unpack" - stack_node.name = "OpHintUnstack-%s" % flattened[0].name - stack_node.attr["num"].i = len(flattened) - output_type = flattened[0].attr["T"].type - stack_node.attr["T"].type = output_type - stack_node.input.append(_tensorflow_output_name( - fused_op_name, output_index)) - out_graphdef.node.extend([stack_node]) - - for idx, discrete in enumerate(flattened): - output_node = _copy.deepcopy(discrete) - del output_node.input[:] - output_node.input.append(_tensorflow_output_name(stack_node.name, idx)) - out_graphdef.node.extend([output_node]) - - return output_type + if not self._stored_attrs: + for key, value in self._attrs_to_store_later.iteritems(): + self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value) + self._stored_attrs = True - def __str__(self): - s = "\t\t\tAGGREGATE %s\n" % self.aggregation - for sort, val in self.names.iteritems(): - s += "\t\t\t%d: %s\n" % (sort, val) - return s + return wrapped_outputs class _LiteFuncCall(object): @@ -594,87 +212,46 @@ class _LiteFuncCall(object): self.uuid = None self.params = {} - def flattened_inputs_and_outputs(self): - """Return a list of inputs and outputs in a flattened format. - - Returns: - Tuple of (inputs, outputs). where input and output i a list of names. 
- """ - def _flatten(input_or_output_dict): - flattened_items = [] - for item in input_or_output_dict.values(): - flattened_items.extend(item.flatten()) - return flattened_items - - return _flatten(self.inputs), _flatten(self.outputs) - def __str__(self): - def format_args(items): - s = "" - for idx, item in items.iteritems(): - s += ("\t\t%d:\n" % idx) + str(item) - return s + return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % ( + self.function_name, self.uuid, self.inputs, self.outputs) - inputs_str = "\tInputs\n" + format_args(self.inputs) - outputs_str = "\tOutputs\n" + format_args(self.outputs) - return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" - % (self.function_name, self.uuid, inputs_str, outputs_str)) - - -def _find_all_hints_in_graph_def(graphdef): +def _find_all_hints_in_graph_def(session): """Look at the current default graph and return a list of LiteFuncCall objs. Args: - graphdef: A TensorFlow graph_def to look for LiteFuncCalls. + session: A TensorFlow session that contains the graph to convert. Returns: a list of `LifeFuncCall` objects in the form """ func_calls = _collections.defaultdict(_LiteFuncCall) - - for node in graphdef.node: - attr = node.attr - # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip - uuid = attr[OpHint.FUNCTION_UUID_ATTR].s - if (OpHint.FUNCTION_UUID_ATTR not in attr - or not attr[OpHint.FUNCTION_UUID_ATTR].s): - continue - - # Start building function - call_def = func_calls[uuid] - call_def.uuid = uuid - call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s - # Get sorting and aggregation information - - sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i - if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None) - if sort == -1: sort = None - aggregation = None - if OpHint.FUNCTION_AGGREGATE_ATTR in attr: - aggregation = attr[OpHint.FUNCTION_AGGREGATE_ATTR].s - - # Add the input or output - def put_operand(stuff, index, sort, operand, aggregation): - """Add a given index into the function structure.""" - if sort is None: - stuff[index] = _LiteSingleOperand(operand) - else: - if index not in stuff: - stuff[index] = _LiteAggregateOperand(aggregation) - stuff[index].add(sort, operand) - - if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: - put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i, - sort, node, aggregation) - if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr: - put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i, - sort, node, aggregation) - - # Remember attributes - for a in attr: - if a.startswith("_tflite_attr_"): - call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor + seen_ops = set() + + for op in session.graph.get_operations(): + for operand in _itertools.chain(op.inputs, op.outputs): + if operand in seen_ops: + continue + seen_ops.add(operand) + attr = operand.op.node_def.attr + uuid = attr[OpHint.FUNCTION_UUID_ATTR].s + if OpHint.FUNCTION_UUID_ATTR not in attr: + continue + call_def = func_calls[uuid] + call_def.uuid = uuid + if OpHint.FUNCTION_UUID_ATTR in attr: + call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s + if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: + call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand + if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr: + call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand + + for a in attr: + if a.startswith("_tflite_attr_"): + # TODO(aselle): Remember the attribute tensors so we can put them + # in collapse. 
+          call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
 
   return func_calls
@@ -690,305 +267,42 @@ def _tensor_name_base(full_tensor_name):
   Returns:
     A name without any device assignment.
   """
-  if full_tensor_name.startswith("^"):
-    return full_tensor_name[1:]
-  return full_tensor_name.split(":")[0]
-
-
-def _tensorflow_output_name(tensor_name, output_index):
-  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
-                                                          output_index)
-
-
-# TODO(aselle): This should be converted to grappler in the future.
-def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
-                           name_to_input_name):
-  """Checks to make sure node only connects to predecessor graph through inputs.
-
-  Args:
-    n: Node to check
-    reachable_by_input: Nodes that are reachable by all inputs of subgraph
-    input_nodes_set: The set of nodes that are "inputs".
-    name_to_input_name: Maps from name to the list of inputs.
-
-  Raises:
-    TypeError: If the given node uses items past inputs directly.
-  """
-  next_to_visit = [n]
-  visited = set()
-  while next_to_visit:
-    current_node = next_to_visit.pop()
-    visited.add(current_node)
-    if (current_node in reachable_by_input
-        and current_node not in input_nodes_set):
-      raise TypeError(
-          "Node %s uses input %s not in input_nodes." % (n, current_node))
-    if current_node not in input_nodes_set:
-      next_to_visit += [
-          input_node for input_node in name_to_input_name[current_node]
-          if input_node not in visited
-      ]
-
-
-# TODO(aselle): This should be converted to grappler in the future.
-def _convert_single_op_hint_to_stub(call, graph_def):
-  """Given a graph_def, converts `call` into a stub and returns a new graph_def.
-
-  Args:
-    call: A single function call to be converted.
-    graph_def: A graph_def to use as input (that hass call obviously).
-  Returns:
-    A new transformed graph-def that has call as a stub (single op).
-
-  Note: after this process, the graph_def can no longer be loaded into
-  the tensorflow runtime, so all future manipulations are done in graph_def
-  level.
-  """
-  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
-      graph_def)
-  input_names, output_names = call.flattened_inputs_and_outputs()
-
-  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
-  reachable_by_output = _bfs_for_reachable_nodes(output_names,
-                                                 name_to_input_name)
-  input_nodes_set = set(input_names)
-  output_nodes_set = set(output_names)
-  nodes_after_fuse = []
-  nodes_deleted_by_fuse = set()
-  # Classify each node. We want to keep everything reachable by input, but
-  # we don't know if things that are not reachable by output or input (things
-  # after fusing).
-  for node in graph_def.node:
-    n = _tensor_name_base(node.name)
-    if n in reachable_by_output:
-      if n not in reachable_by_input and n not in output_nodes_set:
-        # n is an internal node. Check to make sure it is really internal.
-        # TODO(aselle): this could be done more efficiently by flooding
-        # the graph first.
-        _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
-                               name_to_input_name)
-        nodes_deleted_by_fuse.add(n)
-    elif n not in reachable_by_input:
-      # n is a node that after all the fusings, so keep it.
-      nodes_after_fuse.append(n)
-    else:
-      # n is a node that is randomly in the graph but not connected to
-      # the chain of dependencies.
-      pass
-
-  # Make a new graphdef with all the pre-input and input nodes
-  out = _graph_pb2.GraphDef()
-  reachable_by_input_sorted = sorted(
-      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
-  for node in reachable_by_input_sorted:
-    out.node.extend([_copy.deepcopy(name_to_node[node])])
-
-  # Create any stacks to aggregate arguments into to a single input
-  # i.e. for static_rnn's.
-  # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
-  sorted_input_indices = call.inputs.keys()
-  sorted_input_indices.sort()
-  sorted_output_indices = call.outputs.keys()
-  sorted_output_indices.sort()
-  new_node = _node_def_pb2.NodeDef()
-  # Delegate to each operand to produce the proper new input for this stub node.
-  # In particular, an aggregate input will now be a Pack of some previously
-  # non-fused things.
-  for input_index in sorted_input_indices:
-    inputs = call.inputs[input_index]
-    new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
-  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
-
-  # Ceate the function
-  new_node.op = call.function_name
-  new_node.name = call.uuid
-  out.node.extend([new_node])
-
-  # Now call each output argument to give them a chance to make the proper
-  # output type and add it to our new_node.
-  output_dtypes = []
-  for output_index in sorted_output_indices:
-    output = call.outputs[output_index]
-    output_dtype = (
-        output.aggregate_and_return_name_for_output(new_node.name, output_index,
-                                                    out))
-    output_dtypes.append(output_dtype)
-  new_node.attr["_output_types"].list.type[:] = output_dtypes
-  # TODO(aselle): what is right here?
-  new_node.attr["_output_quantized"].b = False
-
-  # Add post output nodes that do not depend on the outputs
-  for n in nodes_after_fuse:
-    should_keep = True
-    for input_name in name_to_input_name[n]:
-      if input_name in nodes_deleted_by_fuse:
-        should_keep = False
-    if should_keep:
-      out.node.extend([_copy.deepcopy(name_to_node[n])])
-
-  # Misc. graph_def data that needs copying.
-  out.library.CopyFrom(graph_def.library)
-  out.versions.CopyFrom(graph_def.versions)
-
-  return out
-
-
-# TODO(aselle): This should be converted to grappler in the future.
-def _remove_one_redundant_stack_unstack(in_graph_def):
-  """Removes a stack->unstack pattern from in_graph_def in a returned graph.
-
-  Args:
-    in_graph_def: Graph def to use as input.
-  Returns:
-    Simplified tuple (graph_def, changed_something) where changed_something
-    is true if anything was done.
-  """
-  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
-      in_graph_def)
-  del name_to_seq_num
-
-  # TODO(aselle): Make this not hardcoded.
-  do_generic_pack_unpack = True
-
-  out = _graph_pb2.GraphDef()
-  out.library.CopyFrom(in_graph_def.library)
-  out.versions.CopyFrom(in_graph_def.versions)
-  for n in in_graph_def.node:
-    node_name = _tensor_name_base(n.name)
-    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
-      continue
-    next_to_visit = [node_name]
-    visited = set()
-
-    unpack_nodes = set()
-    pack_node = node_name
-
-    # Find a pattern of unstack connected to a stack (with identities
-    # in between.
-    matches_pattern = True
-    is_hint_created_stack = False
-    while next_to_visit:
-      current_node_name = next_to_visit[0]
-      visited.add(current_node_name)
-      del next_to_visit[0]
-      node = name_to_node[current_node_name]
-      is_op_hint_stack = node.name.startswith("OpHintStack")
-      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
-      if (node.op == "Identity" or is_op_hint_stack
-          or (do_generic_pack_unpack and node.op == "Pack")):
-        is_hint_created_stack |= is_op_hint_stack
-        next_to_visit += [
-            input_node for input_node in name_to_input_name[current_node_name]
-            if input_node not in visited
-        ]
-      elif (is_op_hint_unstack
-            or (do_generic_pack_unpack and node.op == "Unpack")):
-        unpack_nodes.add(node.name)
-        is_hint_created_stack &= is_op_hint_unstack
-      else:
-        matches_pattern = False
-        break
-      visited.add(node.name)
-
-    if matches_pattern and len(unpack_nodes) == 1:
-      pack_node = node_name
-
-      # Check to see if anyone depends on the intermediate identity or the
-      # Unstacked form
-      no_external_dependency = True
-      for other_n in in_graph_def.node:
-        if other_n.name in visited: continue
-        for input_tensor in name_to_input_name[other_n.name]:
-          input_op = _tensor_name_base(input_tensor)
-          if input_op in visited and input_op != pack_node:
-            no_external_dependency = False
-      # Proceed with the substitution if the stack/unstack pair was created
-      # through hints, or that it was not, but nobody is consuming things
-      # between the stack and unstack.
-      if is_hint_created_stack or no_external_dependency:
-        end = unpack_nodes.pop()
-        end_input = name_to_node[end].input[0]
-        # All nodes that depend on the final stack need to be redone to use
-        for other_n in in_graph_def.node:
-          node_name = _tensor_name_base(other_n.name)
-          if node_name not in visited:
-            new_node = _copy.deepcopy(other_n)
-            new_node.input[:] = [
-                (end_input if stripped == pack_node else
-                 non_stripped) for stripped, non_stripped in zip(
-                     name_to_input_name[node_name], new_node.input[:])
-            ]
-            out.node.extend([new_node])
-        return out, True
-  return in_graph_def, False
-
-
-def _remove_redundant_stack_unstack(graph_def):
-  curr = graph_def
-  del graph_def
-  changed_stuff = True
-  while changed_stuff:
-    curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
-  return curr
-
-
-def _convert_op_hints_to_stubs_helper(
-    graph_def, write_callback=lambda sess, graph_def: None):
-  """Converts a graph_def to a new graph_def where all op hints are stubbed.
+  return full_tensor_name.name.split(":")[0]
-
-  Args:
-    graph_def: A graph def that we should convert.
-    write_callback: A function pointer that can be used to write intermediate
-      steps of graph transformation (optional).
-  Returns:
-    A new stubbed graph_def.
-  """
-  hints = _find_all_hints_in_graph_def(graph_def)
-  curr_graph_def = graph_def
-  del graph_def  # prevent using graph_def again (common source of error)
-  for hint in hints.itervalues():
-    curr_graph_def = _convert_single_op_hint_to_stub(
-        hint, curr_graph_def)
-    write_callback(curr_graph_def, "initial")
-  # The stubbing process can create stacks/unstacks in the case of LSTMs
-  # remove them.
-  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
-  return curr_graph_def
-
-
-def convert_op_hints_to_stubs(session=None,
-                              graph_def=None,
-                              write_callback=lambda graph_def, comments: None):
+def convert_op_hints_to_stubs(session):
   """Converts a graphdef with LiteOp hints into stub operations.
 
   This is used to prepare for toco conversion of complex intrinsic usages.
 
-  Note: only one of session or graph_def should be used, not both.
-
   Args:
     session: A TensorFlow session that contains the graph to convert.
-    graph_def: A graph def that we should convert.
-    write_callback: A function pointer that can be used to write intermediate
-      steps of graph transformation (optional).
 
   Returns:
     A new graphdef with all ops contained in OpHints being replaced by
     a single op call with the right parameters.
-  Raises:
-    ValueError: If both session and graph_def are provided.
   """
-
-  if session is not None and graph_def is not None:
-    raise ValueError("Provide only one of session and graph_def.")
-
-  if session is not None:
-    return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
-  elif graph_def is not None:
-    return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
-  else:
-    raise ValueError("Must specify session or graph_def as input.")
-
-
-_allowed_symbols = [
-    "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new"
-]
+  hints = _find_all_hints_in_graph_def(session)
+  current_graph_def = session.graph_def
+  for call in hints.values():
+    input_names = [None] * len(call.inputs)
+    output_names = [None] * len(call.outputs)
+    output_dtypes = [None] * len(call.outputs)
+    output_quantized = False
+    for input_index, tensor in call.inputs.items():
+      input_names[input_index] = _tensor_name_base(tensor)
+    for output_index, tensor in call.outputs.items():
+      output_names[output_index] = _tensor_name_base(tensor)
+      output_dtypes[output_index] = tensor.dtype.as_datatype_enum
+    # TODO(aselle): Support quantized flag properly
+    current_graph_def = _framework.fuse_op(
+        current_graph_def, input_names, output_names, output_dtypes,
+        output_quantized, call.uuid, call.function_name)
+    for node in current_graph_def.node:
+      if node.name == call.uuid:
+        for param, tensor in call.params.items():
+          node.attr[param].tensor.CopyFrom(tensor)
+  return current_graph_def
+
+
+_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"]
 
 remove_undocumented(__name__, _allowed_symbols)
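Both sides of this diff discover hints the same basic way: every `add_inputs`/`add_outputs` call drops an Identity op whose NodeDef carries the `_tflite_function_*` attributes defined on `OpHint`, and `_find_all_hints_in_graph_def` groups those ops by function UUID. A hypothetical inspection helper along the same lines (`summarize_op_hints` is not part of the module; only the attribute names are taken from the code above):

```python
import collections

def summarize_op_hints(graph):
  """Group hint-tagged Identity ops in a tf.Graph by function UUID (sketch)."""
  calls = collections.defaultdict(
      lambda: {"name": None, "inputs": {}, "outputs": {}})
  for op in graph.get_operations():
    attr = op.node_def.attr
    # Ops participating in an OpHint carry a _tflite_function_uuid attribute.
    if "_tflite_function_uuid" not in attr or not attr["_tflite_function_uuid"].s:
      continue
    call = calls[attr["_tflite_function_uuid"].s]
    call["name"] = attr["_tflite_function_name"].s
    if "_tflite_function_input_index" in attr:
      call["inputs"][attr["_tflite_function_input_index"].i] = op.name
    if "_tflite_function_output_index" in attr:
      call["outputs"][attr["_tflite_function_output_index"].i] = op.name
  return dict(calls)
```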