diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-08-15 13:47:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 13:51:01 -0700 |
commit | 6b13a15e0b9906abbb66f87d83db291d0099cb43 (patch) | |
tree | d076dfaff5cc98a84a61719d5b5ff9ea610bed8d /tensorflow/contrib/lite/experimental | |
parent | d2875ea71373d05c645587a83dd870fa8a0ec070 (diff) |
Ops API to author TFLite ops as functions directly.
PiperOrigin-RevId: 208875580
Diffstat (limited to 'tensorflow/contrib/lite/experimental')
-rw-r--r-- | tensorflow/contrib/lite/experimental/ops/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/contrib/lite/experimental/ops/README.md | 14 | ||||
-rw-r--r-- | tensorflow/contrib/lite/experimental/ops/ops.py | 87 | ||||
-rw-r--r-- | tensorflow/contrib/lite/experimental/ops/ops_test.py | 109 |
4 files changed, 242 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/experimental/ops/BUILD b/tensorflow/contrib/lite/experimental/ops/BUILD new file mode 100644 index 0000000000..6bdc0c1506 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/ops/BUILD @@ -0,0 +1,32 @@ +package(default_visibility = [ + "//visibility:private", +]) + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "ops", + srcs = [ + "ops.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "ops_test", + srcs = ["ops_test.py"], + tags = [ + # The test fails with CUDA on CPU. A lot of other tensorflow tests + # also use this flag. Unsure if this is a real problem or false + # alarm. + # TODO(ycling): Investigate if this is a real problem before open + # sourcing. + "no_cuda_on_cpu_tap", + ], + deps = [ + ":ops", + "//tensorflow/contrib/lite/experimental/pb2lite/python:converter_wrapper", + ], +) diff --git a/tensorflow/contrib/lite/experimental/ops/README.md b/tensorflow/contrib/lite/experimental/ops/README.md new file mode 100644 index 0000000000..723d2f3e49 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/ops/README.md @@ -0,0 +1,14 @@ +The TensorFlow Lite ops API provides functions to author TFLite ops in a +TensorFlow directily. + +Each python function creates a TensorFlow Function node in the graph. +The node has a `_tflite_function_name` attribute to annotate which TensorFlow +Lite op should be used. + +The graph can be used in TensorFlow for training and inference directly. +After training is done, user can freeze the graph, and use the converter under +the `experimental/pb2lite` directory to convert the graph to TensorFlow Lite +model format. + +Warning: Everything in this directory is experimental and highly subject to +changes. diff --git a/tensorflow/contrib/lite/experimental/ops/ops.py b/tensorflow/contrib/lite/experimental/ops/ops.py new file mode 100644 index 0000000000..23fc31d68d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/ops/ops.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Generating TF functions which can be 1:1 mapped to TFLite ops.""" + +import tensorflow as tf +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.util.compat import as_bytes + +tfe = tf.contrib.eager + + +def _annotate_tflite_op(outputs, op_name, options=None): + """Annotates a TF Node with TFLite options. + + Args: + outputs: The output(s) of the op. Can be a `Tensor` or a list of `Tensor`s. + op_name: The TensorFlow Lite Op name. E.g. 'FULLY_CONNECTED'. + options: A `dict` which contains TensorFlow Lite options. + + Raises: + ValueError: If unsupported option types are used. + """ + + # The outputs may be a `Tensor` or a list of tensors. + if isinstance(outputs, tf.Tensor): + op = outputs.op + else: + op = outputs[0].op + + # pylint: disable=protected-access + # Note: `as_bytes` conversion is required for Python 3. + op._set_attr( + '_tflite_function_name', attr_value_pb2.AttrValue(s=as_bytes(op_name))) + if options: + for key, value in options.items(): + if isinstance(value, str): + op._set_attr(key, attr_value_pb2.AttrValue(s=as_bytes(value))) + elif isinstance(value, int): + op._set_attr(key, attr_value_pb2.AttrValue(i=value)) + else: + raise ValueError('Unsupported option value type %s' % value.__class__) + # pylint: enable=protected-access + + +# TODO(ycling): Generate this interface with FlatBuffer reflection +# functionality and extend the coverage to all TFLiteops. +# TODO(ycling): Support optional tensors (e.g. not using missing bias). +def fully_connected(x, weights, bias, fused_activation_function='NONE'): + """Create a TF function node equalivent to TFLite FULLY_CONNECTED op.""" + options = {'_fused_activation_function': fused_activation_function} + + @tfe.defun + def tf_lite_fully_connected(x, weights, bias): + """The TFLite FULLY_CONNECTED logic wrapped in a TF Function.""" + # TFLite FULLY_CONNECTED definition is different from TF matmul. + # The weights are transposed. Therefore we need to transpose the + # weights inside the TF function to simulate TFLite behavior. + transposed_weights = tf.transpose(weights) + + y = tf.matmul(x, transposed_weights) + bias + if fused_activation_function == 'RELU': + y = tf.nn.relu(y) + elif fused_activation_function == 'NONE': + # Do nothing. + pass + else: + # TODO(ycling): Support other activation functions. + raise Exception('Unsupported fused_activation_function "%s"' % + fused_activation_function) + return y + + output = tf_lite_fully_connected(x, weights, bias) + _annotate_tflite_op(output, 'FULLY_CONNECTED', options=options) + + return output diff --git a/tensorflow/contrib/lite/experimental/ops/ops_test.py b/tensorflow/contrib/lite/experimental/ops/ops_test.py new file mode 100644 index 0000000000..6c632d745d --- /dev/null +++ b/tensorflow/contrib/lite/experimental/ops/ops_test.py @@ -0,0 +1,109 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.lite.experimental.ops import ops +from tensorflow.contrib.lite.experimental.pb2lite.python import converter_wrapper as converter +from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test +from tensorflow.python.util.compat import as_bytes + +ERROR_THRESHOLD = 1e-6 + + +def _get_tf_operations(): + return tf.get_default_graph().get_operations() + + +class OpsTest(test_util.TensorFlowTestCase): + + def setUp(self): + np.random.seed(0) + + def test_fully_connected(self): + num_batches = 10 + num_input_channels = 20 + num_output_channels = 5 + + # No operation initially. + self.assertEqual(len(_get_tf_operations()), 0) + + # Create 1 operation (Placeholder for the input). + input_shape = (num_batches, num_input_channels) + x = tf.placeholder(tf.float32, shape=input_shape) + self.assertEqual(len(_get_tf_operations()), 1) + + # Defining weights and bias as constants. It should add 2 more + # nodes into the graph. + weights_shape = (num_output_channels, num_input_channels) + weights_value = np.random.rand(*weights_shape).astype(np.float32) + weights = tf.constant(weights_value, dtype=tf.float32) + bias_shape = (num_output_channels,) + bias_value = np.random.rand(*bias_shape).astype(np.float32) + bias = tf.constant(bias_value, dtype=tf.float32) + self.assertEqual(len(_get_tf_operations()), 3) + + # Call the function to construct a TF Function node which is equivalent + # to TFLite FULLY_CONNECTED node. + output = ops.fully_connected( + x, weights, bias, fused_activation_function='RELU') + + # Exactly one op should be added. It should be a function containing 2-3 ops + # (matmul, add, relu). + operations = _get_tf_operations() + self.assertEqual(len(operations), 4) + + op = operations[-1] + node_def = op.node_def + # Note: `as_bytes` conversion is required for Python 3. + self.assertEqual(node_def.attr['_tflite_function_name'].s, + as_bytes('FULLY_CONNECTED')) + self.assertEqual(node_def.attr['_fused_activation_function'].s, + as_bytes('RELU')) + + # Try to run the TF session to get the output value. + input_value = np.random.rand(*input_shape).astype(np.float32) + with tf.Session() as sess: + output_value = sess.run(output, feed_dict={x: input_value}) + graph_def = sess.graph_def + + # Convert the GraphDef to FlatBuffer. + flatbuffer_data = converter.Convert(graph_def.SerializeToString()) + + # Construct an interpreter with the FlatBuffer. + interpreter = Interpreter(model_content=flatbuffer_data) + + # Invoke the interpreter. + input_details = interpreter.get_input_details() + input_index = input_details[0]['index'] + interpreter.resize_tensor_input(input_index, input_shape) + interpreter.allocate_tensors() + interpreter.set_tensor(input_index, input_value) + interpreter.invoke() + + # Get the output from the interpreter, and compare it with the result from + # TensorFlow. + output_details = interpreter.get_output_details() + tflite_output_value = interpreter.get_tensor(output_details[0]['index']) + + max_error = np.max(np.abs(tflite_output_value - output_value)) + + self.assertTrue(max_error < ERROR_THRESHOLD) + + +if __name__ == '__main__': + test.main() |