author    Andrew Selle <aselle@google.com>  2018-08-16 11:23:41 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-16 11:51:03 -0700
commit    020ce87723b43f96372fff79c1b8d9f989286409 (patch)
tree      fdd704609c6c92027679b86d59fe3b035eff60d2
parent    09e272a6a5c3359b671a068f4dac2bca8b312358 (diff)
Automated rollback of commit ec5f4771e42972c31faaa39354d785891de9f91d
PiperOrigin-RevId: 209016586
-rw-r--r--  tensorflow/contrib/lite/python/BUILD              3
-rw-r--r--  tensorflow/contrib/lite/python/convert_test.py   93
-rw-r--r--  tensorflow/contrib/lite/python/op_hint.py       898
3 files changed, 883 insertions, 111 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 860aff9e7e..47f0c8e9a2 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -112,8 +112,11 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework",
"//tensorflow/python:platform",
+ "//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index dc21a9b669..bc05514cec 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -113,12 +113,13 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
# and 1 final output).
self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["cool_activation", "Const", "Identity"])
def testScaleAndBiasAndIdentity(self):
@@ -139,12 +140,13 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["scale_and_bias_and_identity", "Const", "Identity", "Pack"])
def testTwoFunctions(self):
@@ -153,7 +155,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
b = array_ops.constant([1.])
def _double_values(x):
custom = op_hint.OpHint("add_test")
- x = custom.add_inputs(x)
+ x, = custom.add_inputs(x)
output = math_ops.multiply(x, x)
output, = custom.add_outputs(output)
return output
@@ -164,13 +166,90 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
# make sure one identity for each input (2) and output (2) => 2 + 2
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
- stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
- output_nodes=[op_hint._tensor_name_base(output)]),
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
["add_test", "Const", "Identity", "Add"])
+ def _get_input_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_INPUT_INDEX_ATTR].i
+
+ def _get_output_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i
+
+ def _get_sort_index(self, x):
+ return x.op.node_def.attr[op_hint.OpHint.FUNCTION_SORT_INDEX_ATTR].i
+
+ def testTags(self):
+ """Test if multiple args with the same tag are grouped."""
+ a = array_ops.constant([1.])
+ b = array_ops.constant([2.])
+ c = array_ops.constant([3.])
+ d = array_ops.constant([4.])
+ custom = op_hint.OpHint("test_tag")
+ a = custom.add_input(a, tag="mytag",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b, = custom.add_inputs(b)
+ c = custom.add_input(c, tag="mytag",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ d = custom.add_input(d, tag="mytag2",
+ aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
+ custom.add_outputs([res])
+ with self.test_session():
+ self.assertEqual(self._get_input_index(a), 0)
+ self.assertEqual(self._get_sort_index(a), 0)
+ self.assertEqual(self._get_input_index(b), 1)
+ self.assertEqual(self._get_input_index(c), 0)
+ self.assertEqual(self._get_sort_index(c), 1)
+
+ def testOverrideIndex(self):
+ a = array_ops.constant([1.])
+ b = array_ops.constant([2.])
+ c = array_ops.constant([3.])
+ custom = op_hint.OpHint("test_override")
+ b = custom.add_input(b) # should auto assign 0
+ a = custom.add_input(a, index_override=1)
+ c = custom.add_input(c) # should auto assign 2
+ with self.test_session():
+ self.assertEqual(self._get_input_index(a), 1)
+ self.assertEqual(self._get_input_index(b), 0)
+ self.assertEqual(self._get_input_index(c), 2)
+
+ def testAggregate(self):
+ a = array_ops.constant([3., 4.])
+ b = array_ops.constant([5., 6.])
+ hint = op_hint.OpHint("agg")
+ a0, a1 = array_ops.unstack(a)
+ b0, b1 = array_ops.unstack(b)
+
+ a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b0 = hint.add_input(b0, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ b1 = hint.add_input(b1, tag="n", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+ c0 = math_ops.add(a0, b0, name="addleft")
+ c1 = math_ops.add(a1, b1, name="addright")
+ c0 = hint.add_output(
+ c0, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+ c1 = hint.add_output(
+ c1, tag="out", aggregate=op_hint.OpHint.AGGREGATE_STACK)
+
+ curr = array_ops.stack([c0, c1])
+ output = array_ops.identity(curr, name="FINAL_OUTPUT")
+ with self.test_session() as sess:
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
+ graph_def=sess.graph_def)
+ print(stubbed_graphdef)
+ self.assertCountEqual(
+ self._getGraphOpTypes(
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output.name)]),
+ ["agg", "Const", "Identity"])
+
if __name__ == "__main__":
test.main()
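
For orientation, the aggregation pattern exercised by testAggregate above looks roughly like this outside the test harness; a minimal sketch, assuming TF 1.x graph mode and the module paths used in this change:

import tensorflow as tf
from tensorflow.contrib.lite.python import op_hint

a0, a1 = tf.unstack(tf.constant([3., 4.]))
hint = op_hint.OpHint("agg")
# Inputs that share a tag are stacked into a single stub argument.
a0 = hint.add_input(a0, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
a1 = hint.add_input(a1, tag="c", aggregate=op_hint.OpHint.AGGREGATE_STACK)
out = hint.add_output(tf.add(a0, a1), tag="out",
                      aggregate=op_hint.OpHint.AGGREGATE_STACK)
output = tf.identity(out, name="FINAL_OUTPUT")

with tf.Session() as sess:
  # The conversion now takes a GraphDef directly (see op_hint.py below).
  stubbed = op_hint.convert_op_hints_to_stubs(graph_def=sess.graph_def)
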
diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py
index 7908689ce4..8c920132e5 100644
--- a/tensorflow/contrib/lite/python/op_hint.py
+++ b/tensorflow/contrib/lite/python/op_hint.py
@@ -25,9 +25,9 @@ Example:
def tflite_cool_activation(input):
# A cool activation function.
custom = tf.contrib.lite.OpHint("cool_activation")
- input = custom.add_inputs(input)
+ input, = custom.add_inputs(input)
output = tf.sigmoid(input) * input
- custom.add_outputs(output)
+ output, = custom.add_outputs(output)
return output
image = tf.placeholder(tf.float32, (1, 16, 16, 1))
@@ -64,18 +64,27 @@ ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later.
"""
+# TODO(aselle): Make this use generic graph transformations.
+# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name.
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections as _collections
-import itertools as _itertools
+import copy as _copy
import uuid as _uuid
+import six as _six
-from tensorflow.contrib import framework as _framework
from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
from tensorflow.python.framework import ops as _ops
+# TODO(aselle): publicize these apis if we continue to use these.
+from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
+from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops as _array_ops
+from tensorflow.python.util import compat as _compat
from tensorflow.python.util.all_util import remove_undocumented
@@ -97,11 +106,174 @@ class OpHint(object):
constructs, this mechanism can be retired and changed to use python defun's.
"""
- # Attr constants that are used for representation in the GraphDef
+ # Attr constants that are used for representation in the GraphDef. These
+ # will be used on every Identity op that is involved in a total OpHint.
+
+ # Name of the OpHint function (cosmetic).
FUNCTION_NAME_ATTR = "_tflite_function_name"
+ # UUID of the function (each OpHint gets a new uuid).
FUNCTION_UUID_ATTR = "_tflite_function_uuid"
+  # The input index of the input (or nothing if it is an output).
FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
+ # The output index of the output (or nothing if it is an input).
FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
+ # An index that orders aggregate arguments. Aggregate arguments are ones
+  # that are separate but will be fused horizontally. For example, a static
+  # LSTM has an LSTM cell for each time step. Each one has a separate OpHint,
+  # but a fused SequentialLSTM will treat these as a single tensor.
+ FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
+ # The way in which multiple parts of the aggregate argument will be joined
+ # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
+ # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
+ FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
+ # On fused OpHint stub, the order of inputs that the final LSTM call will
+ # have. What this means is that the TensorFlow order might be
+ # "foo", "bar", "stuff" and you might want the TF lite op order to be
+ # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
+ # attribute to [2, 0, 1, -1].
+ TFLITE_INPUT_INDICES = "_tflite_input_indices"
+
+ # Types of aggregations
+ # stack: stacks all ophints with matching tags. i.e. for a static rnn.
+ # specifically, this is good for an input or output to a static rnn cell.
+ AGGREGATE_STACK = _compat.as_bytes("stack")
+ # first: only takes the first output (one with lowest sort index)
+ # of matching tags. This is good for the input state to an RNN.
+ AGGREGATE_FIRST = _compat.as_bytes("first")
+  # last: takes only the last output (one with the highest sort index) of
+  # matching tags. This is good for an output value on the last stack item
+  # of a static rnn.
+ AGGREGATE_LAST = _compat.as_bytes("last")
+
+ class OpHintArgumentTracker(object):
+ """Conceptually tracks indices of arguments of "OpHint functions".
+
+    The inputs and outputs of these functions each use an instance
+    of this class so they can have independent numbering."""
+
+ def __init__(self, function_name, unique_function_id, node_name_prefix,
+ attr_name):
+ """Initialize ophint argument.
+
+ Args:
+ function_name: Name of the function that this tracks arguments for.
+ unique_function_id: UUID of function that this tracks arguments for.
+ node_name_prefix: How identities that are created are named.
+ attr_name: Name of attribute to use to store the index for this hint.
+ i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
+ """
+
+ # The global index is the argument index of the op. This is in contrast
+ # to the sort index which is the sequence number of a particular instance
+ # of a given global index. For example, you may have called add hint
+ # twice with the tag "foo". Then the global index will be 0 for both
+ # and the sort index will be 0 for the first added and 1 for the second.
+ self._function_name = function_name
+ self._unique_function_id = unique_function_id
+ self._next_global_index = 0 # The absolute global index
+ self._used_global_indices = set()
+ self._tag_to_global_index = {} # The argument index a given tag maps to
+ self._tag_to_next_sort_index = {} # The current index for each tag
+ self._node_name_prefix = node_name_prefix
+ self._attr_name = attr_name
+
+ def _get_new_global_index(self, index_override):
+ """Return the next unused argument index in order or use an override.
+
+ Args:
+ index_override: An index to use instead of the next available or None
+ to use the next available.
+
+ Returns:
+ A valid global_index to use for the next hint argument.
+
+ Raises:
+ ValueError: If the index_override is already used by another hint.
+ """
+ if index_override is None:
+ global_index = self._next_global_index
+ else:
+ if index_override in self._used_global_indices:
+ raise ValueError("Index %d was already used by another call to add")
+ global_index = index_override
+ # Make next_global_index valid
+ self._used_global_indices.add(global_index)
+ while self._next_global_index in self._used_global_indices:
+ self._next_global_index += 1
+ return global_index
+
+ def add(self, arg, tag=None, name=None, aggregate=None,
+ index_override=None):
+ """Return a wrapped tensor of an input tensor as an argument.
+
+ Args:
+ arg: A TensorFlow tensor that should be considered an argument.
+ tag: String tag to identify arguments that should be packed.
+ name: Name of argument. This is included in the Identity hint op names.
+ aggregate: Strategy to aggregate.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ Note, aggregate is only valid if tag is specified.
+        index_override: Specify what input/output index this argument should
+          have in the final stub. i.e. add(arg0, index_override=1);
+          add(arg1, index_override=0) will make the final stub be
+          stub_func(inputs=[arg1, arg0], outputs=[]) rather than the default
+          call-order-based ordering.
+
+ Returns:
+ A tensor representing the wrapped argument.
+
+ Raises:
+ ValueError: When indices are not consistent.
+ """
+
+ # Find the appropriate index
+ if tag is None:
+ if aggregate is not None:
+ raise ValueError("You must specify `tag` if using aggregate.")
+ global_index = self._get_new_global_index(index_override)
+ sort_index = None
+ else:
+ if aggregate is None:
+ raise ValueError("You must specify `aggregate` if using tag.")
+ if tag not in self._tag_to_global_index:
+ self._tag_to_global_index[tag] = (
+ self._get_new_global_index(index_override))
+ self._tag_to_next_sort_index[tag] = 0
+ elif (index_override and
+ index_override != self._tag_to_global_index[tag]):
+ raise ValueError(
+ "Tag %r was called with two indices %r and %r" %
+ (tag, index_override, self._tag_to_global_index[tag]))
+ global_index = self._tag_to_global_index[tag]
+ sort_index = self._tag_to_next_sort_index[tag]
+ self._tag_to_next_sort_index[tag] += 1
+
+ uuid = self._unique_function_id
+ name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
+ uuid, global_index, sort_index, name)
+ identity_op = _array_ops.identity(arg, name=name)
+
+ # pylint: disable=protected-access
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_NAME_ATTR,
+ _attr_value_pb2.AttrValue(
+ s=_compat.as_bytes(self._function_name)))
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_UUID_ATTR,
+ _attr_value_pb2.AttrValue(
+ s=_compat.as_bytes(self._unique_function_id)))
+ identity_op.op._set_attr(
+ self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
+ if sort_index is not None:
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_SORT_INDEX_ATTR,
+ _attr_value_pb2.AttrValue(i=sort_index))
+ if aggregate is not None:
+ identity_op.op._set_attr(
+ OpHint.FUNCTION_AGGREGATE_ATTR,
+ _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate))))
+ # pylint: enable=protected-access
+ return identity_op
def __init__(self, function_name, **kwargs):
"""Create a OpHint.
@@ -112,10 +284,14 @@ class OpHint(object):
"""
self._function_name = function_name
self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough?
- self._curr_input_index = 0
- self._curr_output_index = 0
self._attrs_to_store_later = kwargs
self._stored_attrs = False
+ self._inputs = OpHint.OpHintArgumentTracker(
+ self._function_name, self._unique_function_id, "InputHint",
+ OpHint.FUNCTION_INPUT_INDEX_ATTR)
+ self._outputs = OpHint.OpHintArgumentTracker(
+ self._function_name, self._unique_function_id, "OutputHint",
+ OpHint.FUNCTION_OUTPUT_INDEX_ATTR)
def _setattr(self, dest_op, name, value):
tensor_value = _ops.convert_to_tensor(value)
@@ -124,68 +300,278 @@ class OpHint(object):
tensor=tensor_value.op.node_def.attr["value"].tensor))
# pylint: enable=protected-access
- def add_inputs(self, *args):
+ def add_input(self, *args, **kwargs):
+ """Add a wrapped input argument to the hint.
+
+ Args:
+ *args: The input tensor.
+ **kwargs:
+ "name" label
+ "tag" a tag to group multiple arguments that will be aggregated. I.e.
+ a string like 'cool_input'. Basically multiple inputs can be added
+ to the same hint for parallel operations that will eventually be
+ combined. An example would be static_rnn which creates multiple copies
+ of state or inputs.
+ "aggregate" aggregation strategy that is valid only for tag non None.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ "index_override" The global index to use. This corresponds to the
+ argument order in the final stub that will be generated.
+ Returns:
+ The wrapped input tensor.
+ """
+ return self._inputs.add(*args, **kwargs)
+
+ def add_output(self, *args, **kwargs):
+ """Add a wrapped output argument to the hint.
+
+ Args:
+ *args: The output tensor.
+ **kwargs:
+ "name" label
+ "tag" a tag to group multiple arguments that will be aggregated. I.e.
+ a string like 'cool_input'. Basically multiple inputs can be added
+ to the same hint for parallel operations that will eventually be
+ combined. An example would be static_rnn which creates multiple copies
+ of state or inputs.
+ "aggregate" aggregation strategy that is valid only for tag non None.
+ Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
+ and OpHint.AGGREGATE_STACK.
+ "index_override" The global index to use. This corresponds to the
+ argument order in the final stub that will be generated.
+ Returns:
+ The wrapped output tensor.
+ """
+ return self._outputs.add(*args, **kwargs)
+
+ def add_inputs(self, *args, **kwargs):
"""Add a sequence of inputs to the function invocation.
Args:
*args: List of inputs to be converted (should be tf.Tensor).
+ **kwargs: This allows 'names' which should be a list of names.
Returns:
Wrapped inputs (identity standins that have additional metadata). These
are also tf.Tensor's.
"""
-
- def augmented_identity(arg):
- identity_op = _array_ops.identity(arg)
- # pylint: disable=protected-access
- identity_op.op._set_attr(
- OpHint.FUNCTION_NAME_ATTR,
- _attr_value_pb2.AttrValue(s=self._function_name))
- identity_op.op._set_attr(
- OpHint.FUNCTION_UUID_ATTR,
- _attr_value_pb2.AttrValue(s=self._unique_function_id))
- identity_op.op._set_attr(
- OpHint.FUNCTION_INPUT_INDEX_ATTR,
- _attr_value_pb2.AttrValue(i=self._curr_input_index))
- # pylint: enable=protected-access
- self._curr_input_index += 1
- return identity_op
-
- return [augmented_identity(arg) for arg in args]
-
- def add_outputs(self, *args):
+ if "names" in kwargs:
+ return [
+ self._inputs.add(arg, name=name)
+ for arg, name in zip(args, kwargs["names"])
+ ]
+ else:
+ return [self._inputs.add(arg) for arg in args]
+
+ def add_outputs(self, *args, **kwargs):
"""Add a sequence of outputs to the function invocation.
Args:
*args: List of outputs to be converted (should be tf.Tensor).
+      **kwargs: See add_inputs.
Returns:
Wrapped outputs (identity standins that have additional metadata). These
are also tf.Tensor's.
"""
+ if "names" in kwargs:
+ return [
+ self._outputs.add(arg, name=name)
+ for arg, name in zip(args, kwargs["names"])
+ ]
+ else:
+ return [self._outputs.add(arg) for arg in args]
+
+
+class _LiteOperand(object):
+ """Abstract operand for a tflite hint function.
+
+  This is a base class that represents arguments to an OpHint and knows how
+  to serialize operands into the stubbed graph_def. Child classes are
+  responsible for storing information about the hint identity operators and
+  for serializing themselves to output graph_defs.
+
+ Typically this will be implemented by holding one or more identity nodes
+ that were previously discovered as hints.
+ """
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ """This adds the node(s) to out_graphdef and returns the input node name.
+
+ Args:
+ out_graphdef: A graphdef that is ready to have this input added.
+
+ Returns:
+      The output name that the stub should use as an input for this operand.
+
+ Raises:
+ RuntimeError: if the method is not implemented.
+ """
+ del out_graphdef
+ raise RuntimeError("Unimplemented abstract method.")
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
+ out_graphdef):
+ """Add node(s) to graph representing output operands and returns type.
+
+ Args:
+      fused_op_name: Name of the fused op stub.
+ output_index: Output index that we are currently processing from stub.
+ out_graphdef: The destination graphdef we are currently building up.
+
+ Returns:
+ The datatype of this identity.
+
+ Raises:
+ RuntimeError: if the method is not implemented.
+ """
+ del fused_op_name, output_index, out_graphdef
+ raise RuntimeError("Unimplemented abstract method.")
- def augmented_identity(arg):
- identity_op = _array_ops.identity(arg)
- # pylint: disable=protected-access
- identity_op.op._set_attr(
- OpHint.FUNCTION_NAME_ATTR,
- _attr_value_pb2.AttrValue(s=self._function_name))
- identity_op.op._set_attr(
- OpHint.FUNCTION_UUID_ATTR,
- _attr_value_pb2.AttrValue(s=self._unique_function_id))
- identity_op.op._set_attr(
- OpHint.FUNCTION_OUTPUT_INDEX_ATTR,
- _attr_value_pb2.AttrValue(i=self._curr_output_index))
- # pylint: enable=protected-access
- self._curr_output_index += 1
- return identity_op
- wrapped_outputs = [augmented_identity(arg) for arg in args]
+class _LiteSingleOperand(_LiteOperand):
+ """A simple operand that is non-aggregated (i.e. most hints)."""
- if not self._stored_attrs:
- for key, value in self._attrs_to_store_later.iteritems():
- self._setattr(wrapped_outputs[0], "_tflite_attr_" + key, value)
- self._stored_attrs = True
+ def __init__(self, node):
+ _LiteOperand.__init__(self)
+ self.node = node
+ self.name = _tensor_name_base(node.name)
- return wrapped_outputs
+ def flatten(self):
+ return [self.name]
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ return self.name
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, index,
+ out_graphdef):
+ output_node = _copy.deepcopy(self.node)
+ del output_node.input[:]
+ output_node.input.append(_tensorflow_output_name(fused_op_name, index))
+ out_graphdef.node.extend([output_node])
+ return self.node.attr["type"].i
+
+ def __str__(self):
+ return str(self.name)
+
+
+class _LiteAggregateOperand(_LiteOperand):
+ """An operand for a tflite hint function that is aggregated from many.
+
+ For example, an LSTM is a grid of operators that are all related. Inputs
+ going into them may need to be fused, so they should all be tracked as
+ related arguments.
+ """
+
+ def __init__(self, aggregation):
+ _LiteOperand.__init__(self)
+ self.aggregation = aggregation
+ self.names = {}
+ self.nodes = {}
+ self.flattened = None
+
+ def add(self, sort, node):
+ self.names[sort] = _tensor_name_base(node.name)
+ self.nodes[sort] = node
+
+ def flatten_nodes(self):
+ """Return a list of all the node protos in aggregation sorted order."""
+ if not self.flattened:
+ self.flattened = [None] * len(self.nodes)
+ for idx, node in _six.iteritems(self.nodes):
+ self.flattened[idx] = node
+      for n in self.flattened:
+ if n is None:
+ raise RuntimeError("Aggregate was missing argument.")
+ if self.aggregation == OpHint.AGGREGATE_FIRST:
+ self.flattened = self.flattened[:1]
+ elif self.aggregation == OpHint.AGGREGATE_LAST:
+ self.flattened = self.flattened[-1:]
+ elif self.aggregation == OpHint.AGGREGATE_STACK:
+ pass
+ else:
+ raise ValueError(
+ "Invalid aggregation type %r specified" % self.aggregation)
+ return self.flattened
+
+ def flatten(self):
+ """Return a list of all node names in aggregation sorted sorter."""
+ return [_tensor_name_base(x.name) for x in self.flatten_nodes()]
+
+ def aggregate_and_return_name_for_input(self, out_graphdef):
+ """This adds the nodes to out_graphdef and returns an aggregated output.
+
+ In particular, if you have 4 inputs to a hint stub, this will be the
+    node that you can use as an output. I.e. if you have 4 time steps from a
+    static rnn, then a fused UnidirectionalLSTM will expect 1 input with
+ all 4 time steps. So here we make a pack and return the output name of
+ that pack.
+
+ Args:
+ out_graphdef: A graphdef that is ready to have this input added.
+
+ Returns:
+ The name of a pack that aggregates this node.
+ """
+ flattened = self.flatten_nodes()
+ if len(flattened) == 1:
+ return _tensor_name_base(flattened[0].name)
+ else:
+ new_node = _node_def_pb2.NodeDef()
+ new_node.op = "Pack"
+ new_node.name = "OpHintStack-%s" % flattened[0].name
+ new_node.attr["N"].i = len(flattened)
+ new_node.attr["T"].type = flattened[0].attr["T"].type
+ for discrete in flattened:
+ new_node.input.append(_tensor_name_base(discrete.name))
+ out_graphdef.node.extend([new_node])
+ return new_node.name
+
+ def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
+ out_graphdef):
+ """This adds to `out_graphdef` all the unaggregated outputs.
+
+ I.e. we are outputting from a fused stub, but we need to make it compatible
+ with the unfused original graph so we insert an unpack. Ideally in a later
+ stage the unpack -> pack sequences will be removed.
+
+ Args:
+ fused_op_name: The name of the stub we are in the process of fusing.
+      output_index: The output index this object represents.
+      out_graphdef: The graphdef we are in the process of building.
+
+ Returns:
+ The type of the aggregated output (so we can finish building the stub
+ op).
+ """
+ flattened = self.flatten_nodes()
+ if len(flattened) == 1:
+ temp_op = _LiteSingleOperand(flattened[0])
+ return temp_op.aggregate_and_return_name_for_output(
+ fused_op_name, output_index, out_graphdef)
+ else:
+ stack_node = _node_def_pb2.NodeDef()
+ stack_node.op = "Unpack"
+ stack_node.name = "OpHintUnstack-%s" % flattened[0].name
+ stack_node.attr["num"].i = len(flattened)
+ output_type = flattened[0].attr["T"].type
+ stack_node.attr["T"].type = output_type
+ stack_node.input.append(_tensorflow_output_name(
+ fused_op_name, output_index))
+ out_graphdef.node.extend([stack_node])
+
+ for idx, discrete in enumerate(flattened):
+ output_node = _copy.deepcopy(discrete)
+ del output_node.input[:]
+ output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
+ out_graphdef.node.extend([output_node])
+
+ return output_type
+
+ def __str__(self):
+ s = "\t\t\tAGGREGATE %s\n" % self.aggregation
+    for sort, val in _six.iteritems(self.names):
+ s += "\t\t\t%d: %s\n" % (sort, val)
+ return s
class _LiteFuncCall(object):
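
As a usage note for the `names` keyword that add_inputs/add_outputs accept above, a minimal sketch (the tensor and hint names are illustrative):

import tensorflow as tf
from tensorflow.contrib.lite.python import op_hint

x, scale, bias = tf.constant([1.]), tf.constant([2.]), tf.constant([3.])
custom = op_hint.OpHint("scale_and_bias")
# Each positional argument is paired with the corresponding entry in `names`.
x, scale, bias = custom.add_inputs(x, scale, bias,
                                   names=["x", "scale", "bias"])
y, = custom.add_outputs(scale * x + bias, names=["y"])
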
@@ -212,46 +598,87 @@ class _LiteFuncCall(object):
self.uuid = None
self.params = {}
+ def flattened_inputs_and_outputs(self):
+ """Return a list of inputs and outputs in a flattened format.
+
+ Returns:
+      Tuple of (inputs, outputs), where each is a list of names.
+ """
+ def _flatten(input_or_output_dict):
+ flattened_items = []
+ for item in input_or_output_dict.values():
+ flattened_items.extend(item.flatten())
+ return flattened_items
+
+ return _flatten(self.inputs), _flatten(self.outputs)
+
def __str__(self):
- return "tflite function %s call %s\n\tinputs: %r\n\toutputs: %r" % (
- self.function_name, self.uuid, self.inputs, self.outputs)
+ def format_args(items):
+ s = ""
+      for idx, item in _six.iteritems(items):
+ s += ("\t\t%d:\n" % idx) + str(item)
+ return s
+
+ inputs_str = "\tInputs\n" + format_args(self.inputs)
+ outputs_str = "\tOutputs\n" + format_args(self.outputs)
+ return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s"
+ % (self.function_name, self.uuid, inputs_str, outputs_str))
-def _find_all_hints_in_graph_def(session):
+
+def _find_all_hints_in_graph_def(graphdef):
"""Look at the current default graph and return a list of LiteFuncCall objs.
Args:
- session: A TensorFlow session that contains the graph to convert.
+ graphdef: A TensorFlow graph_def to look for LiteFuncCalls.
Returns:
a list of `_LiteFuncCall` objects.
"""
func_calls = _collections.defaultdict(_LiteFuncCall)
- seen_ops = set()
-
- for op in session.graph.get_operations():
- for operand in _itertools.chain(op.inputs, op.outputs):
- if operand in seen_ops:
- continue
- seen_ops.add(operand)
- attr = operand.op.node_def.attr
- uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
- if OpHint.FUNCTION_UUID_ATTR not in attr:
- continue
- call_def = func_calls[uuid]
- call_def.uuid = uuid
- if OpHint.FUNCTION_UUID_ATTR in attr:
- call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
- if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
- call_def.inputs[attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i] = operand
- if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
- call_def.outputs[attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i] = operand
-
- for a in attr:
- if a.startswith("_tflite_attr_"):
- # TODO(aselle): Remember the attribute tensors so we can put them
- # in collapse.
- call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
+
+ for node in graphdef.node:
+ attr = node.attr
+ # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
+ uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
+ if (OpHint.FUNCTION_UUID_ATTR not in attr
+ or not attr[OpHint.FUNCTION_UUID_ATTR].s):
+ continue
+
+ # Start building function
+ call_def = func_calls[uuid]
+ call_def.uuid = uuid
+ call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
+ # Get sorting and aggregation information
+
+ sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
+ if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
+ if sort == -1: sort = None
+ aggregation = None
+ if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
+ aggregation = attr[OpHint.FUNCTION_AGGREGATE_ATTR].s
+
+ # Add the input or output
+ def put_operand(stuff, index, sort, operand, aggregation):
+ """Add a given index into the function structure."""
+ if sort is None:
+ stuff[index] = _LiteSingleOperand(operand)
+ else:
+ if index not in stuff:
+ stuff[index] = _LiteAggregateOperand(aggregation)
+ stuff[index].add(sort, operand)
+
+ if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
+ put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
+ sort, node, aggregation)
+ if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
+ put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
+ sort, node, aggregation)
+
+ # Remember attributes
+ for a in attr:
+ if a.startswith("_tflite_attr_"):
+ call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor
return func_calls
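
A quick way to inspect what the scanner above discovered; `_find_all_hints_in_graph_def` and `_LiteFuncCall` are private helpers, so treat this purely as an illustration (assumes `sess` is a tf.Session whose graph contains OpHints):

hints = op_hint._find_all_hints_in_graph_def(sess.graph_def)
for uuid, call in hints.items():
  inputs, outputs = call.flattened_inputs_and_outputs()
  print(call.function_name, uuid, inputs, outputs)
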
@@ -267,42 +694,305 @@ def _tensor_name_base(full_tensor_name):
Returns:
A name without any device assignment.
"""
- return full_tensor_name.name.split(":")[0]
+ if full_tensor_name.startswith("^"):
+ return full_tensor_name[1:]
+ return full_tensor_name.split(":")[0]
+
+
+def _tensorflow_output_name(tensor_name, output_index):
+ return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
+ output_index)
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+ name_to_input_name):
+ """Checks to make sure node only connects to predecessor graph through inputs.
+
+ Args:
+ n: Node to check
+ reachable_by_input: Nodes that are reachable by all inputs of subgraph
+ input_nodes_set: The set of nodes that are "inputs".
+ name_to_input_name: Maps from name to the list of inputs.
+
+ Raises:
+    TypeError: If the given node directly uses nodes beyond the given inputs.
+ """
+ next_to_visit = [n]
+ visited = set()
+ while next_to_visit:
+ current_node = next_to_visit.pop()
+ visited.add(current_node)
+ if (current_node in reachable_by_input
+ and current_node not in input_nodes_set):
+ raise TypeError(
+ "Node %s uses input %s not in input_nodes." % (n, current_node))
+ if current_node not in input_nodes_set:
+ next_to_visit += [
+ input_node for input_node in name_to_input_name[current_node]
+ if input_node not in visited
+ ]
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _convert_single_op_hint_to_stub(call, graph_def):
+ """Given a graph_def, converts `call` into a stub and returns a new graph_def.
+ Args:
+ call: A single function call to be converted.
+    graph_def: A graph_def to use as input (one that contains `call`).
+  Returns:
+    A new transformed graph_def that has `call` as a stub (a single op).
-def convert_op_hints_to_stubs(session):
+  Note: after this process, the graph_def can no longer be loaded into
+    the TensorFlow runtime, so all future manipulations are done at the
+    graph_def level.
+ """
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ graph_def)
+ input_names, output_names = call.flattened_inputs_and_outputs()
+
+ reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
+ reachable_by_output = _bfs_for_reachable_nodes(output_names,
+ name_to_input_name)
+ input_nodes_set = set(input_names)
+ output_nodes_set = set(output_names)
+ nodes_after_fuse = []
+ nodes_deleted_by_fuse = set()
+ # Classify each node. We want to keep everything reachable by input, but
+  # we don't yet know whether nodes that are not reachable by output or input
+  # are needed (they come after the fusing).
+ for node in graph_def.node:
+ n = _tensor_name_base(node.name)
+ if n in reachable_by_output:
+ if n not in reachable_by_input and n not in output_nodes_set:
+ # n is an internal node. Check to make sure it is really internal.
+ # TODO(aselle): this could be done more efficiently by flooding
+ # the graph first.
+ _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
+ name_to_input_name)
+ nodes_deleted_by_fuse.add(n)
+ elif n not in reachable_by_input:
+      # n is a node that comes after all the fusings, so keep it.
+ nodes_after_fuse.append(n)
+ else:
+ # n is a node that is randomly in the graph but not connected to
+ # the chain of dependencies.
+ pass
+
+ # Make a new graphdef with all the pre-input and input nodes
+ out = _graph_pb2.GraphDef()
+ reachable_by_input_sorted = sorted(
+ list(reachable_by_input), key=lambda n: name_to_seq_num[n])
+ for node in reachable_by_input_sorted:
+ out.node.extend([_copy.deepcopy(name_to_node[node])])
+
+  # Create any stacks to aggregate arguments into a single input,
+ # i.e. for static_rnn's.
+ # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1
+ sorted_input_indices = list(call.inputs.keys())
+ sorted_input_indices.sort()
+ sorted_output_indices = list(call.outputs.keys())
+ sorted_output_indices.sort()
+ new_node = _node_def_pb2.NodeDef()
+ # Delegate to each operand to produce the proper new input for this stub node.
+ # In particular, an aggregate input will now be a Pack of some previously
+ # non-fused things.
+ for input_index in sorted_input_indices:
+ inputs = call.inputs[input_index]
+ new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
+ new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
+
+  # Create the function
+ new_node.op = call.function_name
+ new_node.name = call.uuid
+ out.node.extend([new_node])
+
+ # Now call each output argument to give them a chance to make the proper
+ # output type and add it to our new_node.
+ output_dtypes = []
+ for output_index in sorted_output_indices:
+ output = call.outputs[output_index]
+ output_dtype = (
+ output.aggregate_and_return_name_for_output(new_node.name, output_index,
+ out))
+ output_dtypes.append(output_dtype)
+ new_node.attr["_output_types"].list.type[:] = output_dtypes
+ # TODO(aselle): what is right here?
+ new_node.attr["_output_quantized"].b = False
+
+ # Add post output nodes that do not depend on the outputs
+ for n in nodes_after_fuse:
+ should_keep = True
+ for input_name in name_to_input_name[n]:
+ if input_name in nodes_deleted_by_fuse:
+ should_keep = False
+ if should_keep:
+ out.node.extend([_copy.deepcopy(name_to_node[n])])
+
+ # Misc. graph_def data that needs copying.
+ out.library.CopyFrom(graph_def.library)
+ out.versions.CopyFrom(graph_def.versions)
+
+ return out
+
+
+# TODO(aselle): This should be converted to grappler in the future.
+def _remove_one_redundant_stack_unstack(in_graph_def):
+ """Removes a stack->unstack pattern from in_graph_def in a returned graph.
+
+ Args:
+ in_graph_def: Graph def to use as input.
+ Returns:
+    A tuple (simplified_graph_def, changed_something), where changed_something
+    is True if anything was changed.
+ """
+ name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
+ in_graph_def)
+ del name_to_seq_num
+
+ # TODO(aselle): Make this not hardcoded.
+ do_generic_pack_unpack = True
+
+ out = _graph_pb2.GraphDef()
+ out.library.CopyFrom(in_graph_def.library)
+ out.versions.CopyFrom(in_graph_def.versions)
+ for n in in_graph_def.node:
+ node_name = _tensor_name_base(n.name)
+ if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
+ continue
+ next_to_visit = [node_name]
+ visited = set()
+
+ unpack_nodes = set()
+ pack_node = node_name
+
+ # Find a pattern of unstack connected to a stack (with identities
+    # in between).
+ matches_pattern = True
+ is_hint_created_stack = False
+ while next_to_visit:
+ current_node_name = next_to_visit[0]
+ visited.add(current_node_name)
+ del next_to_visit[0]
+ node = name_to_node[current_node_name]
+ is_op_hint_stack = node.name.startswith("OpHintStack")
+ is_op_hint_unstack = node.name.startswith("OpHintUnstack")
+ if (node.op == "Identity" or is_op_hint_stack
+ or (do_generic_pack_unpack and node.op == "Pack")):
+ is_hint_created_stack |= is_op_hint_stack
+ next_to_visit += [
+ input_node for input_node in name_to_input_name[current_node_name]
+ if input_node not in visited
+ ]
+ elif (is_op_hint_unstack
+ or (do_generic_pack_unpack and node.op == "Unpack")):
+ unpack_nodes.add(node.name)
+ is_hint_created_stack &= is_op_hint_unstack
+ else:
+ matches_pattern = False
+ break
+ visited.add(node.name)
+
+ if matches_pattern and len(unpack_nodes) == 1:
+ pack_node = node_name
+
+ # Check to see if anyone depends on the intermediate identity or the
+ # Unstacked form
+ no_external_dependency = True
+ for other_n in in_graph_def.node:
+ if other_n.name in visited: continue
+ for input_tensor in name_to_input_name[other_n.name]:
+ input_op = _tensor_name_base(input_tensor)
+ if input_op in visited and input_op != pack_node:
+ no_external_dependency = False
+ # Proceed with the substitution if the stack/unstack pair was created
+      # through hints, or if it was not but nobody is consuming anything
+ # between the stack and unstack.
+ if is_hint_created_stack or no_external_dependency:
+ end = unpack_nodes.pop()
+ end_input = name_to_node[end].input[0]
+        # All nodes that depend on the final stack need to be rewired to
+        # consume the unpack's input directly.
+ for other_n in in_graph_def.node:
+ node_name = _tensor_name_base(other_n.name)
+ if node_name not in visited:
+ new_node = _copy.deepcopy(other_n)
+ new_node.input[:] = [
+ (end_input if stripped == pack_node else
+ non_stripped) for stripped, non_stripped in zip(
+ name_to_input_name[node_name], new_node.input[:])
+ ]
+ out.node.extend([new_node])
+ return out, True
+ return in_graph_def, False
+
+
+def _remove_redundant_stack_unstack(graph_def):
+ curr = graph_def
+ del graph_def
+ changed_stuff = True
+ while changed_stuff:
+ curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
+ return curr
+
+
+def _convert_op_hints_to_stubs_helper(
+    graph_def, write_callback=lambda graph_def, comment: None):
+ """Converts a graph_def to a new graph_def where all op hints are stubbed.
+
+ Args:
+ graph_def: A graph def that we should convert.
+ write_callback: A function pointer that can be used to write intermediate
+ steps of graph transformation (optional).
+ Returns:
+ A new stubbed graph_def.
+ """
+
+ hints = _find_all_hints_in_graph_def(graph_def)
+ curr_graph_def = graph_def
+ del graph_def # prevent using graph_def again (common source of error)
+ for hint in _six.itervalues(hints):
+ curr_graph_def = _convert_single_op_hint_to_stub(
+ hint, curr_graph_def)
+ write_callback(curr_graph_def, "initial")
+  # The stubbing process can create stacks/unstacks in the case of LSTMs;
+ # remove them.
+ curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
+ return curr_graph_def
+
+
+def convert_op_hints_to_stubs(session=None,
+ graph_def=None,
+ write_callback=lambda graph_def, comments: None):
"""Converts a graphdef with LiteOp hints into stub operations.
This is used to prepare for toco conversion of complex intrinsic usages.
+ Note: only one of session or graph_def should be used, not both.
Args:
session: A TensorFlow session that contains the graph to convert.
+ graph_def: A graph def that we should convert.
+ write_callback: A function pointer that can be used to write intermediate
+ steps of graph transformation (optional).
Returns:
A new graphdef with all ops contained in OpHints being replaced by
a single op call with the right parameters.
+ Raises:
+ ValueError: If both session and graph_def are provided.
"""
- hints = _find_all_hints_in_graph_def(session)
- current_graph_def = session.graph_def
- for call in hints.values():
- input_names = [None] * len(call.inputs)
- output_names = [None] * len(call.outputs)
- output_dtypes = [None] * len(call.outputs)
- output_quantized = False
- for input_index, tensor in call.inputs.items():
- input_names[input_index] = _tensor_name_base(tensor)
- for output_index, tensor in call.outputs.items():
- output_names[output_index] = _tensor_name_base(tensor)
- output_dtypes[output_index] = tensor.dtype.as_datatype_enum
- # TODO(aselle): Support quantized flag properly
- current_graph_def = _framework.fuse_op(
- current_graph_def, input_names, output_names, output_dtypes,
- output_quantized, call.uuid, call.function_name)
- for node in current_graph_def.node:
- if node.name == call.uuid:
- for param, tensor in call.params.items():
- node.attr[param].tensor.CopyFrom(tensor)
- return current_graph_def
-
-
-_allowed_symbols = ["OpHint", "convert_op_hints_to_stubs"]
+
+ if session is not None and graph_def is not None:
+ raise ValueError("Provide only one of session and graph_def.")
+
+ if session is not None:
+ return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
+ elif graph_def is not None:
+ return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
+ else:
+ raise ValueError("Must specify session or graph_def as input.")
+
+
+_allowed_symbols = [
+ "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new"
+]
remove_undocumented(__name__, _allowed_symbols)
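
To summarize the reworked entry point: either a session or a graph_def may be supplied (not both), and `write_callback` receives each intermediate GraphDef plus a comment string. A sketch, assuming a session `sess` whose graph already contains OpHints:

from tensorflow.contrib.lite.python import op_hint

def dump(graph_def, comment):
  print("stage:", comment, "nodes:", len(graph_def.node))

# Either hand over a live session...
stubbed = op_hint.convert_op_hints_to_stubs(session=sess)
# ...or a plain GraphDef, optionally with a callback for intermediate stages:
stubbed = op_hint.convert_op_hints_to_stubs(graph_def=sess.graph_def,
                                            write_callback=dump)
# Supplying both session and graph_def raises ValueError.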