author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-09 20:05:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-09 20:08:54 -0700
commit    854ae599743a1e92a31ad49cfe42c6454cefd3b9 (patch)
tree      1ff75695f61c5eb3353e739295e81f76bbe28a64
parent    58fcfc98cd59ae3952399fc55380b8733df08df9 (diff)
Use OpHints to support TfLite UnidirectionalSequenceLstm and add an e2e test.

Support peephole and num_proj as well.

PiperOrigin-RevId: 216467578
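At a high level, the flow this change enables is: annotate the LSTM with `lite.OpHint` (done inside `TFLiteLSTMCell` below), freeze the trained graph, fuse the hinted subgraph into a single `UnidirectionalSequenceLstm` node, and hand the result to toco. A minimal sketch, assuming the contrib.lite API used elsewhere in this change (the toy model below is illustrative, not the test's real model):

```python
import tensorflow as tf
from tensorflow.contrib.lite.experimental.examples.lstm.tflite_lstm import (
    TFLiteLSTMCell)

# Toy model: one OpHint-annotated LSTM over 28 steps of 28 features.
x = tf.placeholder(tf.float32, [1, 28, 28], name="INPUT")
cell = TFLiteLSTMCell(64, name="rnn")
outputs, _ = tf.nn.static_rnn(cell, tf.unstack(x, 28, 1), dtype=tf.float32)
out = tf.identity(outputs[-1], name="OUTPUT")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Freeze the variables, then replace the hinted subgraph with a single
  # UnidirectionalSequenceLstm stub node that toco knows how to import.
  frozen = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, ["OUTPUT"])
  stubbed = tf.contrib.lite.convert_op_hints_to_stubs(graph_def=frozen)
  tflite_model = tf.contrib.lite.toco_convert(stubbed, [x], [out])
```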
-rw-r--r--  tensorflow/contrib/lite/experimental/examples/lstm/BUILD                                |  40
-rw-r--r--  tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py                       | 396
-rw-r--r--  tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py | 226
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc        |   6
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc             |  47
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc                                       |  44
-rw-r--r--  tensorflow/contrib/lite/toco/model.h                                                    |   6
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc                                         |  39
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc                                            |   5
-rw-r--r--  tensorflow/tools/pip_package/pip_smoke_test.py                                          |   4
10 files changed, 811 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/BUILD b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD
new file mode 100644
index 0000000000..2125f218ca
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD
@@ -0,0 +1,40 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "tflite_lstm",
+ srcs = ["tflite_lstm.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/python:framework",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "unidirectional_sequence_lstm_test",
+ size = "large",
+ srcs = ["unidirectional_sequence_lstm_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ ],
+ deps = [
+ ":tflite_lstm",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/tools:optimize_for_inference",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py
new file mode 100644
index 0000000000..2357743266
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py
@@ -0,0 +1,396 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TfLite LSTMCell wrapper.
+
+TODO(renjieliu): Find a better home for this one.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow as tf
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
+from tensorflow.python.layers import base as base_layer
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.platform import tf_logging as logging
+
+
+class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
+ """Long short-term memory unit (LSTM) recurrent network cell.
+
+ This is used only for TfLite; it provides hints, and it also puts the
+ variables in the form desired by the tflite ops (transposed and separated).
+
+ The default non-peephole implementation is based on:
+
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
+
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
+
+ The peephole implementation is based on:
+
+ https://research.google.com/pubs/archive/43905.pdf
+
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
+ "Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+ The class uses optional peep-hole connections, optional cell clipping, and
+ an optional projection layer.
+
+ Note that this cell is not optimized for performance. Please use
+ `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or
+ `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for
+ better performance on CPU.
+ """
+
+ def __init__(self,
+ num_units,
+ use_peepholes=False,
+ cell_clip=None,
+ initializer=None,
+ num_proj=None,
+ proj_clip=None,
+ num_unit_shards=None,
+ num_proj_shards=None,
+ forget_bias=1.0,
+ state_is_tuple=True,
+ activation=None,
+ reuse=None,
+ name=None,
+ dtype=None):
+ """Initialize the parameters for an LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell.
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
+ cell_clip: (optional) A float value, if provided the cell state is clipped
+ by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
+ provided, then the projected values are clipped elementwise to within
+ `[-proj_clip, proj_clip]`.
+ num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a
+ variable_scope partitioner instead.
+ num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a
+ variable_scope partitioner instead.
+ forget_bias: Biases of the forget gate are initialized by default to 1 in
+ order to reduce the scale of forgetting at the beginning of the
+ training. Must set it manually to `0.0` when restoring from CudnnLSTM
+ trained checkpoints.
+ state_is_tuple: If True, accepted and returned states are 2-tuples of the
+ `c_state` and `m_state`. If False, they are concatenated along the
+ column axis. This latter behavior will soon be deprecated.
+ activation: Activation function of the inner states. Default: `tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables in
+ an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ name: String, the name of the layer. Layers with the same name will share
+ weights, but to avoid mistakes we require reuse=True in such cases.
+ dtype: Default dtype of the layer (default of `None` means use the type of
+ the first input). Required when `build` is called before `call`. When
+ restoring from CudnnLSTM-trained checkpoints, use
+ `CudnnCompatibleLSTMCell` instead.
+ """
+ super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+ # TODO(raziel): decide if we want to just support tuples (yes please!).
+ if not state_is_tuple:
+ logging.warn(
+ "%s: Using a concatenated state is slower and will soon be "
+ "deprecated. Use state_is_tuple=True.", self)
+ if num_unit_shards is not None or num_proj_shards is not None:
+ logging.warn(
+ "%s: The num_unit_shards and proj_unit_shards parameters are "
+ "deprecated and will be removed in Jan 2017. "
+ "Use a variable scope with a partitioner instead.", self)
+
+ # Inputs must be 2-dimensional.
+ # TODO(raziel): layers stuff -- chop if un-layerizing Op.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self._tflite_wrapper = lite.OpHint("UnidirectionalSequenceLstm")
+
+ self._num_units = num_units
+ self._use_peepholes = use_peepholes
+ self._cell_clip = cell_clip
+ self._initializer = initializer
+ self._num_proj = num_proj
+ self._proj_clip = proj_clip
+ self._num_unit_shards = num_unit_shards
+ self._num_proj_shards = num_proj_shards
+ self._forget_bias = forget_bias
+ self._state_is_tuple = state_is_tuple
+ self._activation = activation or math_ops.tanh
+
+ self._output_size = num_proj if num_proj else num_units
+ self._state_size = (
+ tf.nn.rnn_cell.LSTMStateTuple(num_units, self._output_size)
+ if state_is_tuple else num_units + self._output_size)
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ def build(self, inputs_shape):
+ """Build TfLite LSTM cell graph.
+
+ Args:
+ inputs_shape: The inputs_shape must be known; it is of shape
+ [batch_size, input_size].
+
+ Raises:
+ ValueError: if the inputs_shape is invalid.
+ """
+ if len(inputs_shape) != 2 or inputs_shape[1].value is None:
+ raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape)
+
+ input_depth = inputs_shape[1].value
+ maybe_partitioner = (
+ partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
+ if self._num_unit_shards is not None else None)
+ input_weight_shape = [self._num_units, input_depth]
+ cell_weight_shape = [self._num_units, self._output_size]
+ bias_shape = [self._num_units]
+
+ def add_variable_wrapped(name, shape, initializer, index, partitioner):
+ var = self.add_variable(
+ name, shape=shape, initializer=initializer, partitioner=partitioner)
+ return self._tflite_wrapper.add_input(
+ var, name=name, index_override=index)
+
+ weight_initializer = self._initializer
+ if self.dtype is None:
+ bias_initializer = init_ops.zeros_initializer
+ else:
+ bias_initializer = init_ops.zeros_initializer(dtype=self.dtype)
+
+ self.input_to_input_w = add_variable_wrapped(
+ "input_to_input_w", input_weight_shape, weight_initializer, 1,
+ maybe_partitioner)
+ self.input_to_forget_w = add_variable_wrapped(
+ "input_to_forget_w", input_weight_shape, weight_initializer, 2,
+ maybe_partitioner)
+ self.input_to_cell_w = add_variable_wrapped(
+ "input_to_cell_w", input_weight_shape, weight_initializer, 3,
+ maybe_partitioner)
+ self.input_to_output_w = add_variable_wrapped(
+ "input_to_output_w", input_weight_shape, weight_initializer, 4,
+ maybe_partitioner)
+ self.cell_to_input_w = add_variable_wrapped(
+ "cell_to_input_w", cell_weight_shape, weight_initializer, 5,
+ maybe_partitioner)
+ self.cell_to_forget_w = add_variable_wrapped(
+ "cell_to_forget_w", cell_weight_shape, weight_initializer, 6,
+ maybe_partitioner)
+ self.cell_to_cell_w = add_variable_wrapped(
+ "cell_to_cell_w", cell_weight_shape, weight_initializer, 7,
+ maybe_partitioner)
+ self.cell_to_output_w = add_variable_wrapped(
+ "cell_to_output_w", cell_weight_shape, weight_initializer, 8,
+ maybe_partitioner)
+
+ self.input_bias = add_variable_wrapped(
+ "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner)
+ self.forget_bias = add_variable_wrapped(
+ "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner)
+ self.cell_bias = add_variable_wrapped(
+ "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner)
+ self.output_bias = add_variable_wrapped(
+ "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner)
+
+ # Indices 9, 10, 11 are the peephole weights.
+ # f stands for forget, i stands for input, and o stands for output.
+ if self._use_peepholes:
+ self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units],
+ self._initializer, 9,
+ maybe_partitioner)
+ self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units],
+ self._initializer, 10,
+ maybe_partitioner)
+ self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units],
+ self._initializer, 11,
+ maybe_partitioner)
+
+ # index 16 for proj kernel.
+ if self._num_proj is not None:
+ maybe_proj_partitioner = (
+ partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
+ if self._num_proj_shards is not None else None)
+ self._proj_kernel = add_variable_wrapped(
+ "projection/kernel", [self._num_proj, self._num_units],
+ self._initializer,
+ 16,
+ partitioner=maybe_proj_partitioner)
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Run one step of LSTM.
+
+ Args:
+ inputs: input Tensor, 2D, `[batch, num_units]`.
+ state: if `state_is_tuple` is False, this must be a state Tensor, `2-D,
+ [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple
+ of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`.
+
+ Returns:
+ A tuple containing:
+
+ - A `2-D, [batch, output_dim]` Tensor representing the output of the
+ LSTM after reading `inputs` when the previous state was `state`.
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - Tensor(s) representing the new state of LSTM after reading `inputs` when
+ the previous state was `state`. Same type and shape(s) as `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ inputs = self._tflite_wrapper.add_input(
+ inputs, tag="input", name="input", aggregate="stack", index_override=0)
+
+ # Make sure the inputs and the weights have the same dtype.
+ assert inputs.dtype == self.input_to_input_w.dtype
+
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+ sigmoid = math_ops.sigmoid
+
+ if self._state_is_tuple:
+ (c_prev, m_prev) = state
+ else:
+ c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+ m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+ # Note: For TfLite, the cell state is at index 19 while the activation
+ # state is at index 18.
+ c_prev = self._tflite_wrapper.add_input(
+ c_prev,
+ tag="c_prev",
+ name="c_prev",
+ aggregate="first",
+ index_override=19)
+ m_prev = self._tflite_wrapper.add_input(
+ m_prev,
+ tag="m_prev",
+ name="m_prev",
+ aggregate="first",
+ index_override=18)
+
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1)
+
+ # i stands for the input gate activation.
+ # f stands for the forget gate activation.
+ # o stands for the output gate activation.
+ # j stands for the candidate cell input.
+ # c is the new cell state.
+ # m is the output.
+ i = nn_ops.bias_add(
+ tf.matmul(
+ inputs_and_m_prev,
+ tf.concat([self.input_to_input_w, self.cell_to_input_w], axis=1),
+ transpose_b=True), self.input_bias)
+ f = nn_ops.bias_add(
+ tf.matmul(
+ inputs_and_m_prev,
+ tf.concat([self.input_to_forget_w, self.cell_to_forget_w], axis=1),
+ transpose_b=True), self.forget_bias)
+ o = nn_ops.bias_add(
+ tf.matmul(
+ inputs_and_m_prev,
+ tf.concat([self.input_to_output_w, self.cell_to_output_w], axis=1),
+ transpose_b=True), self.output_bias)
+ j = nn_ops.bias_add(
+ tf.matmul(
+ inputs_and_m_prev,
+ tf.concat([self.input_to_cell_w, self.cell_to_cell_w], axis=1),
+ transpose_b=True), self.cell_bias)
+
+ # Diagonal connections
+ if self._use_peepholes:
+ c = (
+ sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
+ sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
+ else:
+ c = (
+ sigmoid(f + self._forget_bias) * c_prev +
+ sigmoid(i) * self._activation(j))
+
+ if self._cell_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+ # pylint: enable=invalid-unary-operand-type
+ if self._use_peepholes:
+ m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
+ else:
+ m = sigmoid(o) * self._activation(c)
+
+ if self._num_proj is not None:
+ transposed_proj_kernel = tf.transpose(self._proj_kernel)
+ m = math_ops.matmul(m, transposed_proj_kernel)
+
+ if self._proj_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+ # pylint: enable=invalid-unary-operand-type
+
+ c = self._tflite_wrapper.add_output(
+ c, tag="c", name="c", aggregate="last", index_override=1)
+ m = self._tflite_wrapper.add_output(
+ m, tag="m", name="m", index_override=2, aggregate="stack")
+
+ new_state = (
+ tf.nn.rnn_cell.LSTMStateTuple(c, m)
+ if self._state_is_tuple else array_ops.concat([c, m], 1))
+ return m, new_state
+
+ def get_config(self):
+ config = {
+ "num_units": self._num_units,
+ "use_peepholes": self._use_peepholes,
+ "cell_clip": self._cell_clip,
+ "initializer": initializers.serialize(self._initializer),
+ "num_proj": self._num_proj,
+ "proj_clip": self._proj_clip,
+ "num_unit_shards": self._num_unit_shards,
+ "num_proj_shards": self._num_proj_shards,
+ "forget_bias": self._forget_bias,
+ "state_is_tuple": self._state_is_tuple,
+ "activation": activations.serialize(self._activation),
+ "reuse": self._reuse,
+ }
+ base_config = super(TFLiteLSTMCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
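For reference, the arithmetic in `call` above is the standard LSTM update with optional peepholes and projection. A compact summary using the code's gate names, where $b_{\mathrm{forget}}$ is the scalar `forget_bias` argument (the learned per-gate biases are already folded in via `bias_add`), and the peephole terms $w_i, w_f, w_o$ and the projection apply only when `use_peepholes` / `num_proj` are set:

$$
\begin{aligned}
i,\, f,\, o,\, j &= [x_t,\; m_{t-1}]\,W_\bullet^{\top} + b_\bullet \\
c_t &= \sigma(f + b_{\mathrm{forget}} + w_f \odot c_{t-1}) \odot c_{t-1}
      + \sigma(i + w_i \odot c_{t-1}) \odot \tanh(j) \\
m_t &= \sigma(o + w_o \odot c_t) \odot \tanh(c_t), \qquad
m_t \leftarrow m_t\,W_{\mathrm{proj}}^{\top}
\end{aligned}
$$

with $c_t$ clipped to $[-\texttt{cell\_clip}, \texttt{cell\_clip}]$ and the projected $m_t$ clipped to $[-\texttt{proj\_clip}, \texttt{proj\_clip}]$ when those arguments are set.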
diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
new file mode 100644
index 0000000000..2ca977518c
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -0,0 +1,226 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tempfile
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+from tensorflow.python.tools import optimize_for_inference_lib
+
+# Number of steps to train model.
+TRAIN_STEPS = 1
+
+CONFIG = tf.ConfigProto(device_count={"GPU": 0})
+
+
+class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ tf.reset_default_graph()
+ # Import MNIST dataset
+ self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+
+ # Define constants
+ # Unrolled through 28 time steps
+ self.time_steps = 28
+ # Rows of 28 pixels
+ self.n_input = 28
+ # Learning rate for Adam optimizer
+ self.learning_rate = 0.001
+ # MNIST is meant to be classified in 10 classes (0-9).
+ self.n_classes = 10
+ # Batch size
+ self.batch_size = 16
+ # LSTM units.
+ self.num_units = 64
+
+ def buildLstmLayer(self):
+ return tf.nn.rnn_cell.MultiRNNCell([
+ TFLiteLSTMCell(
+ self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"),
+ TFLiteLSTMCell(self.num_units, num_proj=64, forget_bias=0, name="rnn2"),
+ TFLiteLSTMCell(
+ self.num_units // 2,
+ use_peepholes=True,
+ num_proj=64,
+ forget_bias=0,
+ name="rnn3"),
+ TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4")
+ ])
+
+ def buildModel(self, lstm_layer, is_dynamic_rnn, is_train):
+ # Weights and biases for output softmax layer.
+ out_weights = tf.Variable(
+ tf.random_normal([self.num_units, self.n_classes]))
+ out_bias = tf.Variable(tf.random_normal([self.n_classes]))
+
+ # input image placeholder
+ x = tf.placeholder(
+ "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE")
+
+ # For dynamic_rnn, train with dynamic_rnn and run inference with
+ # static_rnn. x is shaped [batch_size, time_steps, num_inputs].
+ if is_dynamic_rnn:
+ if is_train:
+ lstm_input = x
+ outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32")
+ outputs = tf.unstack(outputs, axis=1)
+ else:
+ lstm_input = tf.unstack(x, self.time_steps, 1)
+ outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
+ else:
+ lstm_input = tf.unstack(x, self.time_steps, 1)
+ outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
+
+ # Compute logits by multiplying outputs[-1], of shape
+ # [batch_size, num_units], by the softmax layer's out_weights, of shape
+ # [num_units, n_classes], and adding out_bias.
+ prediction = tf.matmul(outputs[-1], out_weights) + out_bias
+ output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS")
+
+ return x, prediction, output_class
+
+ def trainModel(self, x, prediction, output_class, sess):
+ # input label placeholder
+ y = tf.placeholder("float", [None, self.n_classes])
+ # Loss function
+ loss = tf.reduce_mean(
+ tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y))
+ # Optimization
+ opt = tf.train.AdamOptimizer(
+ learning_rate=self.learning_rate).minimize(loss)
+
+ # Initialize variables
+ init = tf.global_variables_initializer()
+ sess.run(init)
+ for _ in range(TRAIN_STEPS):
+ batch_x, batch_y = self.mnist.train.next_batch(
+ batch_size=self.batch_size, shuffle=False)
+
+ batch_x = batch_x.reshape((self.batch_size, self.time_steps,
+ self.n_input))
+ sess.run(opt, feed_dict={x: batch_x, y: batch_y})
+
+ def saveAndRestoreModel(self, lstm_layer, sess, saver, is_dynamic_rnn):
+ model_dir = tempfile.mkdtemp()
+ saver.save(sess, model_dir)
+
+ # Reset the graph.
+ tf.reset_default_graph()
+ x, prediction, output_class = self.buildModel(
+ lstm_layer, is_dynamic_rnn, is_train=False)
+
+ new_sess = tf.Session(config=CONFIG)
+ saver = tf.train.Saver()
+ saver.restore(new_sess, model_dir)
+ return x, prediction, output_class, new_sess
+
+ def getInferenceResult(self, x, output_class, sess):
+ b1, _ = self.mnist.train.next_batch(batch_size=1)
+ sample_input = np.reshape(b1, (1, self.time_steps, self.n_input))
+
+ expected_output = sess.run(output_class, feed_dict={x: sample_input})
+ frozen_graph = tf.graph_util.convert_variables_to_constants(
+ sess, sess.graph_def, [output_class.op.name])
+ return sample_input, expected_output, frozen_graph
+
+ def tfliteInvoke(self, graph, test_inputs, outputs):
+ tf.reset_default_graph()
+ # Turn the input into a placeholder with batch size 1.
+ tflite_input = tf.placeholder(
+ "float", [1, self.time_steps, self.n_input], name="INPUT_IMAGE_LITE")
+ tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input})
+ with tf.Session() as sess:
+ curr = sess.graph_def
+ curr = tf.contrib.lite.convert_op_hints_to_stubs(graph_def=curr)
+
+ curr = optimize_for_inference_lib.optimize_for_inference(
+ curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"],
+ [tf.float32.as_datatype_enum])
+
+ tflite = tf.contrib.lite.toco_convert(
+ curr, [tflite_input], [outputs], allow_custom_ops=False)
+ interpreter = tf.contrib.lite.Interpreter(model_content=tflite)
+
+ try:
+ interpreter.allocate_tensors()
+ except ValueError:
+ self.fail("Failed to allocate tensors.")
+
+ input_index = (interpreter.get_input_details()[0]["index"])
+ interpreter.set_tensor(input_index, test_inputs)
+ interpreter.invoke()
+ output_index = (interpreter.get_output_details()[0]["index"])
+ result = interpreter.get_tensor(output_index)
+ # Reset all variables so they will not pollute other inferences.
+ interpreter.reset_all_variables()
+ return result
+
+ def testStaticRnnMultiRnnCell(self):
+ sess = tf.Session(config=CONFIG)
+
+ x, prediction, output_class = self.buildModel(
+ self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True)
+ self.trainModel(x, prediction, output_class, sess)
+
+ saver = tf.train.Saver()
+ x, prediction, output_class, new_sess = self.saveAndRestoreModel(
+ self.buildLstmLayer(), sess, saver, is_dynamic_rnn=False)
+
+ test_inputs, expected_output, frozen_graph = self.getInferenceResult(
+ x, output_class, new_sess)
+
+ result = self.tfliteInvoke(frozen_graph, test_inputs, output_class)
+ self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3))
+
+ def testDynamicRnnMultiRnnCell(self):
+ sess = tf.Session(config=CONFIG)
+
+ x, prediction, output_class = self.buildModel(
+ self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True)
+ self.trainModel(x, prediction, output_class, sess)
+
+ # Since we don't yet support OpHints for dynamic_rnn, we load the model
+ # back in as a static model. This requires the variables to have the same
+ # names as if they had been trained as a static model, so we strip the
+ # "while/rnn/" prefix and the ":0" suffix from the variable names.
+ variables_to_save = {}
+ for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
+ op_name = i.name
+ if op_name.startswith("while/rnn/"):
+ op_name = op_name.split("while/rnn/")[1]
+ if op_name.endswith(":0"):
+ op_name = op_name.split(":0")[0]
+ variables_to_save[op_name] = i
+ saver = tf.train.Saver(variables_to_save)
+
+ x, prediction, output_class, new_sess = self.saveAndRestoreModel(
+ self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True)
+
+ test_inputs, expected_output, frozen_graph = self.getInferenceResult(
+ x, output_class, new_sess)
+
+ result = self.tfliteInvoke(frozen_graph, test_inputs, output_class)
+ self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3))
+
+
+if __name__ == "__main__":
+ test.main()
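The renaming loop in `testDynamicRnnMultiRnnCell` above is the crux of reusing dynamically trained weights in a static graph: variables created inside `tf.nn.dynamic_rnn`'s while loop carry a `while/rnn/` name prefix that their static counterparts lack. A standalone illustration of that string surgery (the variable name here is hypothetical, but follows the pattern the test strips):

```python
# A dynamic_rnn variable name, as it appears in GLOBAL_VARIABLES.
name = "while/rnn/multi_rnn_cell/cell_0/rnn1/input_to_input_w:0"
# Strip the while-loop prefix and the tensor output suffix so the
# checkpoint key matches what the rebuilt static graph will look up.
if name.startswith("while/rnn/"):
  name = name.split("while/rnn/")[1]
if name.endswith(":0"):
  name = name.split(":0")[0]
print(name)  # multi_rnn_cell/cell_0/rnn1/input_to_input_w
```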
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 40cd6dea82..47faa20a29 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -239,6 +239,12 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op,
}
break;
}
+ case OperatorType::kUnidirectionalSequenceLstm: {
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+ if (data_type != ArrayDataType::kFloat) return ::tensorflow::Status::OK();
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 5496e2093e..e861df2b3d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
.copy_shape(activ_temp_shape);
}
+void ProcessUnidirectionalSequenceLstmOperator(
+ Model* model, UnidirectionalSequenceLstmOperator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
+
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return;
+ }
+
+ // TODO(renjieliu): check the inputs, as well as all kinds of weights.
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ const int batch_size = input_shape.dims(1);
+ const int timesteps = input_shape.dims(0);
+
+ const auto& recurrent_to_output_weights_array =
+ model->GetArray(op->inputs[8]);
+ // Yield until input dims have been resolved.
+ if (!recurrent_to_output_weights_array.has_shape()) {
+ return;
+ }
+
+ constexpr int kInputActivationStateTensor = 18;
+ constexpr int kInputCellStateTensor = 19;
+ // b/115961645: This is a hack to work around the state inputs carrying
+ // constant buffers; reset the buffers so the inputs are not treated as
+ // constant arrays.
+ model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
+ model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
+
+ const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
+ const int output_size = output_weights_shape.dims(1);
+
+ Shape* output_shape = output_array.mutable_shape();
+ output_shape->ReplaceDims({timesteps, batch_size, output_size});
+}
+
void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
@@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
ProcessResizeBilinearOperator(model,
static_cast<ResizeBilinearOperator*>(op));
break;
+ case OperatorType::kUnidirectionalSequenceLstm:
+ ProcessUnidirectionalSequenceLstmOperator(
+ model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
+ break;
case OperatorType::kLstmCell:
ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
break;
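The shape logic in `ProcessUnidirectionalSequenceLstmOperator` keeps the (time-major) input's time and batch dimensions and takes the last dimension from the recurrent-to-output weights (input 8). A small Python rendering of the same arithmetic, with illustrative numbers matching the MNIST test above:

```python
input_shape = [28, 16, 28]              # [timesteps, batch_size, input_size]
recurrent_to_output_w_shape = [64, 64]  # [num_units, output_size]

timesteps, batch_size = input_shape[0], input_shape[1]
output_size = recurrent_to_output_w_shape[1]
output_shape = [timesteps, batch_size, output_size]
assert output_shape == [28, 16, 64]
```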
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 32f22e1ea0..6b195cc992 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/session_options.h"
@@ -2002,6 +2003,48 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator(
return tensorflow::Status::OK();
}
+// This isn't a TensorFlow builtin op. Currently this node can only be generated
+// with the TfLite OpHint API.
+tensorflow::Status ConvertUnidirectionalSequenceLstm(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm");
+
+ auto* op = new UnidirectionalSequenceLstmOperator();
+ const auto& indices = GetListAttr(node, "_tflite_input_indices");
+ if (indices.i_size() != node.input().size()) {
+ return tensorflow::errors::InvalidArgument("Input size does not match.");
+ }
+
+ // The number of inputs needs to match the TfLite UnidirectionalSequenceLstm
+ // implementation, which expects 20 input tensors.
+ const int kInputsSize = 20;
+
+ op->inputs.resize(kInputsSize);
+ std::vector<bool> done(kInputsSize);
+ int idx = 0;
+ for (const string& input : node.input()) {
+ int real_index = indices.i(idx);
+ op->inputs[real_index] = input;
+ done[real_index] = true;
+ idx++;
+ }
+
+ for (int idx = 0; idx < done.size(); idx++) {
+ if (!done[idx]) {
+ string optional_name = node.name() + "_" + std::to_string(idx);
+ model->CreateOptionalArray(optional_name);
+ op->inputs[idx] = optional_name;
+ }
+ }
+
+ // There are three outputs; only the last one is required.
+ op->outputs.push_back(node.name() + ":2");
+ model->operators.emplace_back(op);
+
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -2121,6 +2164,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
{"Unpack", ConvertUnpackOperator},
{"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>},
+ {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm},
});
}
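The scatter loop in `ConvertUnidirectionalSequenceLstm` fills the 20 TfLite input slots from the hinted node's inputs using the `_tflite_input_indices` attribute, then creates optional arrays for any slot the hint did not provide. A minimal Python rendering of the same logic (the input names here are hypothetical):

```python
kInputsSize = 20
node_name = "lstm_hint"
node_inputs = ["x", "input_to_input_w", "input_to_forget_w"]  # hypothetical
tflite_input_indices = [0, 1, 2]  # from the "_tflite_input_indices" attr

inputs = [None] * kInputsSize
for real_index, name in zip(tflite_input_indices, node_inputs):
  inputs[real_index] = name                   # scatter into the TfLite slot
for idx in range(kInputsSize):
  if inputs[idx] is None:                     # slot not provided by the hint
    inputs[idx] = "%s_%d" % (node_name, idx)  # optional array placeholder
```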
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 61f1f095e9..f3b84430db 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -58,6 +58,7 @@ enum class OperatorType : uint8 {
kL2Normalization,
kL2Pool,
kLstmCell,
+ kUnidirectionalSequenceLstm,
kLocalResponseNormalization,
kLog,
kLogistic,
@@ -635,6 +636,11 @@ struct LstmCellOperator : Operator {
KernelType kernel_type;
};
+struct UnidirectionalSequenceLstmOperator : Operator {
+ UnidirectionalSequenceLstmOperator()
+ : Operator(OperatorType::kUnidirectionalSequenceLstm) {}
+};
+
// Element-wise multiplication operator.
//
// Inputs:
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index ed37535fe0..e08a61d357 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -741,6 +741,42 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
}
};
+class UnidirectionalSequenceLstm
+ : public BuiltinOperator<
+ UnidirectionalSequenceLstmOperator,
+ ::tflite::UnidirectionalSequenceLSTMOptions,
+ ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ // The current toco converter only supports tanh activation and no clipping.
+ return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
+ *builder, /*fused_activation_function=*/
+ ::tflite::ActivationFunctionType_TANH,
+ /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ // Only tanh activation is supported, so check that the tflite type is tanh.
+ DCHECK(options.fused_activation_function() ==
+ ::tflite::ActivationFunctionType_TANH);
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+
+ std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const override {
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
+ mutating_input_variables[kInputActivationStateTensor] = true;
+ mutating_input_variables[kInputCellStateTensor] = true;
+ return mutating_input_variables;
+ }
+};
+
class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
@@ -1435,6 +1471,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
OperatorType::kFakeQuant));
ops.push_back(
MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
+ ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
+ ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ OperatorType::kUnidirectionalSequenceLstm));
ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
OperatorType::kOneHot));
ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
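The two mutating slots flagged by `GetMutatingInputVariables` are the variable state tensors, matching `index_override=18/19` in `TFLiteLSTMCell.call` and the `kInputActivationStateTensor`/`kInputCellStateTensor` constants in the shape propagation above. A sketch of the same flag vector in Python:

```python
def mutating_input_variables(num_inputs=20):
  """Mirrors GetMutatingInputVariables for UnidirectionalSequenceLstm."""
  kInputActivationStateTensor = 18  # activation state (m_prev)
  kInputCellStateTensor = 19        # cell state (c_prev)
  flags = [False] * num_inputs
  flags[kInputActivationStateTensor] = True
  flags[kInputCellStateTensor] = True
  return flags

assert sum(mutating_input_variables()) == 2
```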
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 083a96ad9d..61aa311212 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -407,6 +407,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
HANDLE_OPERATORTYPENAME_CASE(Unpack)
HANDLE_OPERATORTYPENAME_CASE(ZerosLike)
+ HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
@@ -898,12 +899,12 @@ void CheckNoMissingArray(const Model& model) {
void FixNoMissingArray(Model* model) {
for (const auto& op : model->operators) {
for (const auto& input : op->inputs) {
- if (!model->HasArray(input)) {
+ if (!model->HasArray(input) && !model->IsOptionalArray(input)) {
model->GetOrCreateArray(input);
}
}
for (const auto& output : op->outputs) {
- if (!model->HasArray(output)) {
+ if (!model->HasArray(output) && !model->IsOptionalArray(output)) {
model->GetOrCreateArray(output);
}
}
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index c6ef82ccdc..45106b35fc 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -85,6 +85,10 @@ BLACKLIST = [
# contrib
"//tensorflow/contrib/session_bundle:session_bundle_half_plus_two",
"//tensorflow/contrib/keras:testing_utils",
+ "//tensorflow/contrib/lite/experimental/examples/lstm:tflite_lstm",
+ "//tensorflow/contrib/lite/experimental/examples/lstm:tflite_lstm.py",
+ "//tensorflow/contrib/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test", # pylint:disable=line-too-long
+ "//tensorflow/contrib/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test.py", # pylint:disable=line-too-long
"//tensorflow/contrib/lite/python:interpreter",
"//tensorflow/contrib/lite/python:interpreter_test",
"//tensorflow/contrib/lite/python:interpreter.py",