diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-09 20:05:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 20:08:54 -0700 |
commit | 854ae599743a1e92a31ad49cfe42c6454cefd3b9 (patch) | |
tree | 1ff75695f61c5eb3353e739295e81f76bbe28a64 | |
parent | 58fcfc98cd59ae3952399fc55380b8733df08df9 (diff) |
Use Ophints to support TfLite UnidirectionaSequenceLstm and add an e2e test.
Support peephole and num_proj as well.
PiperOrigin-RevId: 216467578
10 files changed, 811 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/BUILD b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD new file mode 100644 index 0000000000..2125f218ca --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/BUILD @@ -0,0 +1,40 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "tflite_lstm", + srcs = ["tflite_lstm.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/python:lite", + "//tensorflow/python:framework", + "@six_archive//:six", + ], +) + +py_test( + name = "unidirectional_sequence_lstm_test", + size = "large", + srcs = ["unidirectional_sequence_lstm_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_oss", + "no_pip", + ], + deps = [ + ":tflite_lstm", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/lite/python:lite", + "//tensorflow/examples/tutorials/mnist:input_data", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python/tools:optimize_for_inference", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py new file mode 100644 index 0000000000..2357743266 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/tflite_lstm.py @@ -0,0 +1,396 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TfLite LSTMCell wrapper. + +TODO(renjieliu): Find a better home for this one. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf + +from tensorflow.contrib.lite.python import lite +from tensorflow.python.keras import activations +from tensorflow.python.keras import initializers +from tensorflow.python.layers import base as base_layer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.platform import tf_logging as logging + + +class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + This is used only for TfLite, it provides hints and it also makes the + variables in the desired for the tflite ops (transposed and seaparated). + + The default non-peephole implementation is based on: + + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf + + Felix Gers, Jurgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. + + The peephole implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + + Hasim Sak, Andrew Senior, and Francoise Beaufays. + "Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + + The class uses optional peep-hole connections, optional cell clipping, and + an optional projection layer. + + Note that this cell is not optimized for performance. Please use + `tf.contrib.cudnn_rnn.CudnnLSTM` for better performance on GPU, or + `tf.contrib.rnn.LSTMBlockCell` and `tf.contrib.rnn.LSTMBlockFusedCell` for + better performance on CPU. + """ + + def __init__(self, + num_units, + use_peepholes=False, + cell_clip=None, + initializer=None, + num_proj=None, + proj_clip=None, + num_unit_shards=None, + num_proj_shards=None, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None, + name=None, + dtype=None): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell. + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + num_unit_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + num_proj_shards: Deprecated, will be removed by Jan. 2017. Use a + variable_scope partitioner instead. + forget_bias: Biases of the forget gate are initialized by default to 1 in + order to reduce the scale of forgetting at the beginning of the + training. Must set it manually to `0.0` when restoring from CudnnLSTM + trained checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of the + `c_state` and `m_state`. If False, they are concatenated along the + column axis. This latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables in + an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will share + weights, but to avoid mistakes we require reuse=True in such cases. + dtype: Default dtype of the layer (default of `None` means use the type of + the first input). Required when `build` is called before `call`. When + restoring from CudnnLSTM-trained checkpoints, use + `CudnnCompatibleLSTMCell` instead. + """ + super(TFLiteLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + # TODO(raziel): decide if we want to just support tuples (yes please!). + if not state_is_tuple: + logging.warn( + "%s: Using a concatenated state is slower and will soon be " + "deprecated. Use state_is_tuple=True.", self) + if num_unit_shards is not None or num_proj_shards is not None: + logging.warn( + "%s: The num_unit_shards and proj_unit_shards parameters are " + "deprecated and will be removed in Jan 2017. " + "Use a variable scope with a partitioner instead.", self) + + # Inputs must be 2-dimensional. + # TODO(raziel): layers stuff -- chop if un-layerizing Op. + self.input_spec = base_layer.InputSpec(ndim=2) + + self._tflite_wrapper = lite.OpHint("UnidirectionalSequenceLstm") + + self._num_units = num_units + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializer + self._num_proj = num_proj + self._proj_clip = proj_clip + self._num_unit_shards = num_unit_shards + self._num_proj_shards = num_proj_shards + self._forget_bias = forget_bias + self._state_is_tuple = state_is_tuple + self._activation = activation or math_ops.tanh + + self._output_size = num_proj if num_proj else num_units + self._state_size = ( + tf.nn.rnn_cell.LSTMStateTuple(num_units, self._output_size) + if state_is_tuple else num_units + self._output_size) + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + def build(self, inputs_shape): + """Build TfLite LSTM cell graph. + + Args: + inputs_shape: The inputs_shape must be known, and is [batch_size, + input_size] shape. + + Raises: + ValueError: if the inputs_shape is invalid. + """ + if len(inputs_shape) != 2 or inputs_shape[1].value is None: + raise ValueError("Invalid inputs_shape, saw shape: %s" % inputs_shape) + + input_depth = inputs_shape[1].value + maybe_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_unit_shards) + if self._num_unit_shards is not None else None) + input_weight_shape = [self._num_units, input_depth] + cell_weight_shape = [self._num_units, self._output_size] + bias_shape = [self._num_units] + + def add_variable_wrapped(name, shape, initializer, index, partitioner): + var = self.add_variable( + name, shape=shape, initializer=initializer, partitioner=partitioner) + return self._tflite_wrapper.add_input( + var, name="name", index_override=index) + + weight_initializer = self._initializer + if self.dtype is None: + bias_initializer = init_ops.zeros_initializer + else: + bias_initializer = init_ops.zeros_initializer(dtype=self.dtype) + + self.input_to_input_w = add_variable_wrapped( + "input_to_input_w", input_weight_shape, weight_initializer, 1, + maybe_partitioner) + self.input_to_forget_w = add_variable_wrapped( + "input_to_forget_w", input_weight_shape, weight_initializer, 2, + maybe_partitioner) + self.input_to_cell_w = add_variable_wrapped( + "input_to_cell_w", input_weight_shape, weight_initializer, 3, + maybe_partitioner) + self.input_to_output_w = add_variable_wrapped( + "input_to_output_w", input_weight_shape, weight_initializer, 4, + maybe_partitioner) + self.cell_to_input_w = add_variable_wrapped( + "cell_to_input_w", cell_weight_shape, weight_initializer, 5, + maybe_partitioner) + self.cell_to_forget_w = add_variable_wrapped( + "cell_to_forget_w", cell_weight_shape, weight_initializer, 6, + maybe_partitioner) + self.cell_to_cell_w = add_variable_wrapped( + "cell_to_cell_w", cell_weight_shape, weight_initializer, 7, + maybe_partitioner) + self.cell_to_output_w = add_variable_wrapped( + "cell_to_output_w", cell_weight_shape, weight_initializer, 8, + maybe_partitioner) + + self.input_bias = add_variable_wrapped( + "input_bias", bias_shape, bias_initializer, 12, maybe_partitioner) + self.forget_bias = add_variable_wrapped( + "forget_bias", bias_shape, bias_initializer, 13, maybe_partitioner) + self.cell_bias = add_variable_wrapped( + "cell_bias", bias_shape, bias_initializer, 14, maybe_partitioner) + self.output_bias = add_variable_wrapped( + "output_bias", bias_shape, bias_initializer, 15, maybe_partitioner) + + # index 9, 10, 11. + # f stands for forget, i stands for input and o stands for output. + if self._use_peepholes: + self._w_f_diag = add_variable_wrapped("w_f_diag", [self._num_units], + self._initializer, 9, + maybe_partitioner) + self._w_i_diag = add_variable_wrapped("w_i_diag", [self._num_units], + self._initializer, 10, + maybe_partitioner) + self._w_o_diag = add_variable_wrapped("w_o_diag", [self._num_units], + self._initializer, 11, + maybe_partitioner) + + # index 16 for proj kernel. + if self._num_proj is not None: + maybe_proj_partitioner = ( + partitioned_variables.fixed_size_partitioner(self._num_proj_shards) + if self._num_proj_shards is not None else None) + self._proj_kernel = add_variable_wrapped( + "projection/kernel", [self._num_proj, self._num_units], + self._initializer, + 16, + partitioner=maybe_proj_partitioner) + + self.built = True + + def call(self, inputs, state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, 2D, `[batch, num_units]`. + state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, + [batch, state_size]`. If `state_is_tuple` is True, this must be a tuple + of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch, output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + inputs = self._tflite_wrapper.add_input( + inputs, tag="input", name="input", aggregate="stack", index_override=0) + + # Make sure inputs and bias_initializer has the same type. + assert inputs.dtype == self.input_to_input_w.dtype + + num_proj = self._num_units if self._num_proj is None else self._num_proj + sigmoid = math_ops.sigmoid + + if self._state_is_tuple: + (c_prev, m_prev) = state + else: + c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + + # Note: For TfLite, cell_state is at index 19 while activation state at + # index 18. + c_prev = self._tflite_wrapper.add_input( + c_prev, + tag="c_prev", + name="c_prev", + aggregate="first", + index_override=19) + m_prev = self._tflite_wrapper.add_input( + m_prev, + tag="m_prev", + name="m_prev", + aggregate="first", + index_override=18) + + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + inputs_and_m_prev = array_ops.concat([inputs, m_prev], axis=1) + + # i stands for input gate. + # f stands for forget gate activation. + # o outputs. + # j output of LSTM unit. + # c is the final state. + # m is the output. + i = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_input_w, self.cell_to_input_w], axis=1), + transpose_b=True), self.input_bias) + f = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_forget_w, self.cell_to_forget_w], axis=1), + transpose_b=True), self.forget_bias) + o = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_output_w, self.cell_to_output_w], axis=1), + transpose_b=True), self.output_bias) + j = nn_ops.bias_add( + tf.matmul( + inputs_and_m_prev, + tf.concat([self.input_to_cell_w, self.cell_to_cell_w], axis=1), + transpose_b=True), self.cell_bias) + + # Diagonal connections + if self._use_peepholes: + c = ( + sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) + else: + c = ( + sigmoid(f + self._forget_bias) * c_prev + + sigmoid(i) * self._activation(j)) + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + m = sigmoid(o + self._w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + transposed_proj_kernel = tf.transpose(self._proj_kernel) + m = math_ops.matmul(m, transposed_proj_kernel) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + c = self._tflite_wrapper.add_output( + c, tag="c", name="c", aggregate="last", index_override=1) + m = self._tflite_wrapper.add_output( + m, tag="m", name="m", index_override=2, aggregate="stack") + + new_state = ( + tf.nn.rnn_cell.LSTMStateTuple(c, m) + if self._state_is_tuple else array_ops.concat([c, m], 1)) + return m, new_state + + def get_config(self): + config = { + "num_units": self._num_units, + "use_peepholes": self._use_peepholes, + "cell_clip": self._cell_clip, + "initializer": initializers.serialize(self._initializer), + "num_proj": self._num_proj, + "proj_clip": self._proj_clip, + "num_unit_shards": self._num_unit_shards, + "num_proj_shards": self._num_proj_shards, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(TFLiteLSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py new file mode 100644 index 0000000000..2ca977518c --- /dev/null +++ b/tensorflow/contrib/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py @@ -0,0 +1,226 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tempfile +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell +from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.tools import optimize_for_inference_lib + +# Number of steps to train model. +TRAIN_STEPS = 1 + +CONFIG = tf.ConfigProto(device_count={"GPU": 0}) + + +class UnidirectionalSequenceLstmTest(test_util.TensorFlowTestCase): + + def setUp(self): + tf.reset_default_graph() + # Import MNIST dataset + self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) + + # Define constants + # Unrolled through 28 time steps + self.time_steps = 28 + # Rows of 28 pixels + self.n_input = 28 + # Learning rate for Adam optimizer + self.learning_rate = 0.001 + # MNIST is meant to be classified in 10 classes(0-9). + self.n_classes = 10 + # Batch size + self.batch_size = 16 + # Lstm Units. + self.num_units = 64 + + def buildLstmLayer(self): + return tf.nn.rnn_cell.MultiRNNCell([ + TFLiteLSTMCell( + self.num_units, use_peepholes=True, forget_bias=0, name="rnn1"), + TFLiteLSTMCell(self.num_units, num_proj=64, forget_bias=0, name="rnn2"), + TFLiteLSTMCell( + self.num_units // 2, + use_peepholes=True, + num_proj=64, + forget_bias=0, + name="rnn3"), + TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4") + ]) + + def buildModel(self, lstm_layer, is_dynamic_rnn, is_train): + # Weights and biases for output softmax layer. + out_weights = tf.Variable( + tf.random_normal([self.num_units, self.n_classes])) + out_bias = tf.Variable(tf.random_normal([self.n_classes])) + + # input image placeholder + x = tf.placeholder( + "float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE") + + # For dynamic_rnn, train with dynamic_rnn and inference with static_rnn. + # x is shaped [batch_size,time_steps,num_inputs] + if is_dynamic_rnn: + if is_train: + lstm_input = x + outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32") + outputs = tf.unstack(outputs, axis=1) + else: + lstm_input = tf.unstack(x, self.time_steps, 1) + outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") + else: + lstm_input = tf.unstack(x, self.time_steps, 1) + outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32") + + # Compute logits by multiplying outputs[-1] of shape [batch_size,num_units] + # by the softmax layer's out_weight of shape [num_units,n_classes] + # plus out_bias + prediction = tf.matmul(outputs[-1], out_weights) + out_bias + output_class = tf.nn.softmax(prediction, name="OUTPUT_CLASS") + + return x, prediction, output_class + + def trainModel(self, x, prediction, output_class, sess): + # input label placeholder + y = tf.placeholder("float", [None, self.n_classes]) + # Loss function + loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=y)) + # Optimization + opt = tf.train.AdamOptimizer( + learning_rate=self.learning_rate).minimize(loss) + + # Initialize variables + init = tf.global_variables_initializer() + sess.run(init) + for _ in range(TRAIN_STEPS): + batch_x, batch_y = self.mnist.train.next_batch( + batch_size=self.batch_size, shuffle=False) + + batch_x = batch_x.reshape((self.batch_size, self.time_steps, + self.n_input)) + sess.run(opt, feed_dict={x: batch_x, y: batch_y}) + + def saveAndRestoreModel(self, lstm_layer, sess, saver, is_dynamic_rnn): + model_dir = tempfile.mkdtemp() + saver.save(sess, model_dir) + + # Reset the graph. + tf.reset_default_graph() + x, prediction, output_class = self.buildModel( + lstm_layer, is_dynamic_rnn, is_train=False) + + new_sess = tf.Session(config=CONFIG) + saver = tf.train.Saver() + saver.restore(new_sess, model_dir) + return x, prediction, output_class, new_sess + + def getInferenceResult(self, x, output_class, sess): + b1, _ = self.mnist.train.next_batch(batch_size=1) + sample_input = np.reshape(b1, (1, self.time_steps, self.n_input)) + + expected_output = sess.run(output_class, feed_dict={x: sample_input}) + frozen_graph = tf.graph_util.convert_variables_to_constants( + sess, sess.graph_def, [output_class.op.name]) + return sample_input, expected_output, frozen_graph + + def tfliteInvoke(self, graph, test_inputs, outputs): + tf.reset_default_graph() + # Turn the input into placeholder of shape 1 + tflite_input = tf.placeholder( + "float", [1, self.time_steps, self.n_input], name="INPUT_IMAGE_LITE") + tf.import_graph_def(graph, name="", input_map={"INPUT_IMAGE": tflite_input}) + with tf.Session() as sess: + curr = sess.graph_def + curr = tf.contrib.lite.convert_op_hints_to_stubs(graph_def=curr) + + curr = optimize_for_inference_lib.optimize_for_inference( + curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"], + [tf.float32.as_datatype_enum]) + + tflite = tf.contrib.lite.toco_convert( + curr, [tflite_input], [outputs], allow_custom_ops=False) + interpreter = tf.contrib.lite.Interpreter(model_content=tflite) + + try: + interpreter.allocate_tensors() + except ValueError: + assert False + + input_index = (interpreter.get_input_details()[0]["index"]) + interpreter.set_tensor(input_index, test_inputs) + interpreter.invoke() + output_index = (interpreter.get_output_details()[0]["index"]) + result = interpreter.get_tensor(output_index) + # Reset all variables so it will not pollute other inferences. + interpreter.reset_all_variables() + return result + + def testStaticRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True) + self.trainModel(x, prediction, output_class, sess) + + saver = tf.train.Saver() + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), sess, saver, is_dynamic_rnn=False) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + + def testDynamicRnnMultiRnnCell(self): + sess = tf.Session(config=CONFIG) + + x, prediction, output_class = self.buildModel( + self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True) + self.trainModel(x, prediction, output_class, sess) + + # Since we don't yet support OpHints for dynamic, we will load the model + # back in as a static model. This requires the variables to have the same + # names as if they were trained as a static. Thus, we get rid of while/rnn + # names. + variables_to_save = {} + for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): + op_name = i.name + if op_name.startswith("while/rnn/"): + op_name = op_name.split("while/rnn/")[1] + if op_name.endswith(":0"): + op_name = op_name.split(":0")[0] + variables_to_save[op_name] = i + saver = tf.train.Saver(variables_to_save) + + x, prediction, output_class, new_sess = self.saveAndRestoreModel( + self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True) + + test_inputs, expected_output, frozen_graph = self.getInferenceResult( + x, output_class, new_sess) + + result = self.tfliteInvoke(frozen_graph, test_inputs, output_class) + self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-3)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 40cd6dea82..47faa20a29 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -239,6 +239,12 @@ void SetDataTypeForAllOutputs(Model* model, Operator* op, } break; } + case OperatorType::kUnidirectionalSequenceLstm: { + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + if (data_type != ArrayDataType::kFloat) return ::tensorflow::Status::OK(); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 5496e2093e..e861df2b3d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -946,6 +946,49 @@ void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) { .copy_shape(activ_temp_shape); } +void ProcessUnidirectionalSequenceLstmOperator( + Model* model, UnidirectionalSequenceLstmOperator* op) { + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes + return; + } + + // TODO(renjieliu): check the inputs, as well as all kinds of weights. + const auto& input_array = model->GetArray(op->inputs[0]); + // Yield until input dims have been resolved. + if (!input_array.has_shape()) { + return; + } + const auto& input_shape = input_array.shape(); + const int batch_size = input_shape.dims(1); + const int timestamp = input_shape.dims(0); + + const auto& recurrent_to_output_weights_array = + model->GetArray(op->inputs[8]); + // Yield until input dims have been resolved. + if (!recurrent_to_output_weights_array.has_shape()) { + return; + } + + constexpr int kInputActivationStateTensor = 18; + constexpr int kInputCellStateTensor = 19; + // b(115961645): This is a hack to work around. + model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset(); + model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset(); + + const auto& output_weights_shape = recurrent_to_output_weights_array.shape(); + const int output_size = output_weights_shape.dims(1); + + Shape* output_shape = output_array.mutable_shape(); + output_shape->ReplaceDims({timestamp, batch_size, output_size}); +} + void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) { const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. @@ -1800,6 +1843,10 @@ void ProcessUnpackOperator(Model* model, UnpackOperator* op) { ProcessResizeBilinearOperator(model, static_cast<ResizeBilinearOperator*>(op)); break; + case OperatorType::kUnidirectionalSequenceLstm: + ProcessUnidirectionalSequenceLstmOperator( + model, static_cast<UnidirectionalSequenceLstmOperator*>(op)); + break; case OperatorType::kLstmCell: ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op)); break; diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 32f22e1ea0..6b195cc992 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" @@ -2002,6 +2003,48 @@ tensorflow::Status ConvertCTCBeamSearchDecoderOperator( return tensorflow::Status::OK(); } +// This isn't a TensorFlow builtin op. Currently this node can only be generated +// with TfLite OpHint API. +tensorflow::Status ConvertUnidirectionalSequenceLstm( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + DCHECK_EQ(node.op(), "UnidirectionalSequenceLstm"); + + auto* op = new UnidirectionalSequenceLstmOperator(); + const auto& indices = GetListAttr(node, "_tflite_input_indices"); + if (indices.i_size() != node.input().size()) { + return tensorflow::errors::InvalidArgument("Input size does not match."); + } + + // The input size needs to be the same as the TfLite UniDirectionalSequence + // Lstm implementation. + const int kInputsSize = 20; + + op->inputs.resize(kInputsSize); + std::vector<bool> done(kInputsSize); + int idx = 0; + for (const string& input : node.input()) { + int real_index = indices.i(idx); + op->inputs[real_index] = (input); + done[real_index] = true; + idx++; + } + + for (int idx = 0; idx < done.size(); idx++) { + if (!done[idx]) { + string optional_name = node.name() + "_" + std::to_string(idx); + model->CreateOptionalArray(optional_name); + op->inputs[idx] = optional_name; + } + } + + // There're three outputs, only the last one is required. + op->outputs.push_back(node.name() + ":2"); + model->operators.emplace_back(op); + + return tensorflow::Status::OK(); +} + } // namespace namespace internal { @@ -2121,6 +2164,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Transpose", ConvertSimpleOperator<TransposeOperator, 2>}, {"Unpack", ConvertUnpackOperator}, {"ZerosLike", ConvertSimpleOperator<TensorFlowZerosLikeOperator, 1>}, + {"UnidirectionalSequenceLstm", ConvertUnidirectionalSequenceLstm}, }); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 61f1f095e9..f3b84430db 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -58,6 +58,7 @@ enum class OperatorType : uint8 { kL2Normalization, kL2Pool, kLstmCell, + kUnidirectionalSequenceLstm, kLocalResponseNormalization, kLog, kLogistic, @@ -635,6 +636,11 @@ struct LstmCellOperator : Operator { KernelType kernel_type; }; +struct UnidirectionalSequenceLstmOperator : Operator { + UnidirectionalSequenceLstmOperator() + : Operator(OperatorType::kUnidirectionalSequenceLstm) {} +}; + // Element-wise multiplication operator. // // Inputs: diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index ed37535fe0..e08a61d357 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -741,6 +741,42 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, } }; +class UnidirectionalSequenceLstm + : public BuiltinOperator< + UnidirectionalSequenceLstmOperator, + ::tflite::UnidirectionalSequenceLSTMOptions, + ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateUnidirectionalSequenceLSTMOptions( + *builder, /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*cell_clip=*/0.0, + /*proj_clip=*/0.0); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + // Only support tanh activation, so check that tflite type is tanh. + DCHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); + } + + int GetVersion(const Operator& op) const override { return 1; } + + std::vector<bool> GetMutatingInputVariables( + const Operator& op) const override { + std::vector<bool> mutating_input_variables(op.inputs.size(), false); + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + return mutating_input_variables; + } +}; + class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions, ::tflite::BuiltinOptions_ReducerOptions> { public: @@ -1435,6 +1471,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( OperatorType::kFakeQuant)); ops.push_back( MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>( + ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + OperatorType::kUnidirectionalSequenceLstm)); ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK, diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 083a96ad9d..61aa311212 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -407,6 +407,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder) HANDLE_OPERATORTYPENAME_CASE(Unpack) HANDLE_OPERATORTYPENAME_CASE(ZerosLike) + HANDLE_OPERATORTYPENAME_CASE(UnidirectionalSequenceLstm) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -898,12 +899,12 @@ void CheckNoMissingArray(const Model& model) { void FixNoMissingArray(Model* model) { for (const auto& op : model->operators) { for (const auto& input : op->inputs) { - if (!model->HasArray(input)) { + if (!model->HasArray(input) && !model->IsOptionalArray(input)) { model->GetOrCreateArray(input); } } for (const auto& output : op->outputs) { - if (!model->HasArray(output)) { + if (!model->HasArray(output) && !model->IsOptionalArray(output)) { model->GetOrCreateArray(output); } } diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index c6ef82ccdc..45106b35fc 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -85,6 +85,10 @@ BLACKLIST = [ # contrib "//tensorflow/contrib/session_bundle:session_bundle_half_plus_two", "//tensorflow/contrib/keras:testing_utils", + "//tensorflow/contrib/lite/experimental/examples/lstm:tflite_lstm", + "//tensorflow/contrib/lite/experimental/examples/lstm:tflite_lstm.py", + "//tensorflow/contrib/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test", # pylint:disable=line-too-long + "//tensorflow/contrib/lite/experimental/examples/lstm:unidirectional_sequence_lstm_test.py", # pylint:disable=line-too-long "//tensorflow/contrib/lite/python:interpreter", "//tensorflow/contrib/lite/python:interpreter_test", "//tensorflow/contrib/lite/python:interpreter.py", |