aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/experimental
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-08-15 13:47:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 13:51:01 -0700
commit6b13a15e0b9906abbb66f87d83db291d0099cb43 (patch)
treed076dfaff5cc98a84a61719d5b5ff9ea610bed8d /tensorflow/contrib/lite/experimental
parentd2875ea71373d05c645587a83dd870fa8a0ec070 (diff)
Ops API to author TFLite ops as functions directly.
PiperOrigin-RevId: 208875580
Diffstat (limited to 'tensorflow/contrib/lite/experimental')
-rw-r--r--tensorflow/contrib/lite/experimental/ops/BUILD32
-rw-r--r--tensorflow/contrib/lite/experimental/ops/README.md14
-rw-r--r--tensorflow/contrib/lite/experimental/ops/ops.py87
-rw-r--r--tensorflow/contrib/lite/experimental/ops/ops_test.py109
4 files changed, 242 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/experimental/ops/BUILD b/tensorflow/contrib/lite/experimental/ops/BUILD
new file mode 100644
index 0000000000..6bdc0c1506
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/ops/BUILD
@@ -0,0 +1,32 @@
+package(default_visibility = [
+ "//visibility:private",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "ops",
+ srcs = [
+ "ops.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "ops_test",
+ srcs = ["ops_test.py"],
+ tags = [
+ # The test fails with CUDA on CPU. A lot of other tensorflow tests
+ # also use this flag. Unsure if this is a real problem or false
+ # alarm.
+ # TODO(ycling): Investigate if this is a real problem before open
+ # sourcing.
+ "no_cuda_on_cpu_tap",
+ ],
+ deps = [
+ ":ops",
+ "//tensorflow/contrib/lite/experimental/pb2lite/python:converter_wrapper",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/ops/README.md b/tensorflow/contrib/lite/experimental/ops/README.md
new file mode 100644
index 0000000000..723d2f3e49
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/ops/README.md
@@ -0,0 +1,14 @@
+The TensorFlow Lite ops API provides functions to author TFLite ops in a
+TensorFlow directily.
+
+Each python function creates a TensorFlow Function node in the graph.
+The node has a `_tflite_function_name` attribute to annotate which TensorFlow
+Lite op should be used.
+
+The graph can be used in TensorFlow for training and inference directly.
+After training is done, user can freeze the graph, and use the converter under
+the `experimental/pb2lite` directory to convert the graph to TensorFlow Lite
+model format.
+
+Warning: Everything in this directory is experimental and highly subject to
+changes.
diff --git a/tensorflow/contrib/lite/experimental/ops/ops.py b/tensorflow/contrib/lite/experimental/ops/ops.py
new file mode 100644
index 0000000000..23fc31d68d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/ops/ops.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Generating TF functions which can be 1:1 mapped to TFLite ops."""
+
+import tensorflow as tf
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.util.compat import as_bytes
+
+tfe = tf.contrib.eager
+
+
+def _annotate_tflite_op(outputs, op_name, options=None):
+ """Annotates a TF Node with TFLite options.
+
+ Args:
+ outputs: The output(s) of the op. Can be a `Tensor` or a list of `Tensor`s.
+ op_name: The TensorFlow Lite Op name. E.g. 'FULLY_CONNECTED'.
+ options: A `dict` which contains TensorFlow Lite options.
+
+ Raises:
+ ValueError: If unsupported option types are used.
+ """
+
+ # The outputs may be a `Tensor` or a list of tensors.
+ if isinstance(outputs, tf.Tensor):
+ op = outputs.op
+ else:
+ op = outputs[0].op
+
+ # pylint: disable=protected-access
+ # Note: `as_bytes` conversion is required for Python 3.
+ op._set_attr(
+ '_tflite_function_name', attr_value_pb2.AttrValue(s=as_bytes(op_name)))
+ if options:
+ for key, value in options.items():
+ if isinstance(value, str):
+ op._set_attr(key, attr_value_pb2.AttrValue(s=as_bytes(value)))
+ elif isinstance(value, int):
+ op._set_attr(key, attr_value_pb2.AttrValue(i=value))
+ else:
+ raise ValueError('Unsupported option value type %s' % value.__class__)
+ # pylint: enable=protected-access
+
+
+# TODO(ycling): Generate this interface with FlatBuffer reflection
+# functionality and extend the coverage to all TFLiteops.
+# TODO(ycling): Support optional tensors (e.g. not using missing bias).
+def fully_connected(x, weights, bias, fused_activation_function='NONE'):
+ """Create a TF function node equalivent to TFLite FULLY_CONNECTED op."""
+ options = {'_fused_activation_function': fused_activation_function}
+
+ @tfe.defun
+ def tf_lite_fully_connected(x, weights, bias):
+ """The TFLite FULLY_CONNECTED logic wrapped in a TF Function."""
+ # TFLite FULLY_CONNECTED definition is different from TF matmul.
+ # The weights are transposed. Therefore we need to transpose the
+ # weights inside the TF function to simulate TFLite behavior.
+ transposed_weights = tf.transpose(weights)
+
+ y = tf.matmul(x, transposed_weights) + bias
+ if fused_activation_function == 'RELU':
+ y = tf.nn.relu(y)
+ elif fused_activation_function == 'NONE':
+ # Do nothing.
+ pass
+ else:
+ # TODO(ycling): Support other activation functions.
+ raise Exception('Unsupported fused_activation_function "%s"' %
+ fused_activation_function)
+ return y
+
+ output = tf_lite_fully_connected(x, weights, bias)
+ _annotate_tflite_op(output, 'FULLY_CONNECTED', options=options)
+
+ return output
diff --git a/tensorflow/contrib/lite/experimental/ops/ops_test.py b/tensorflow/contrib/lite/experimental/ops/ops_test.py
new file mode 100644
index 0000000000..6c632d745d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/ops/ops_test.py
@@ -0,0 +1,109 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.lite.experimental.ops import ops
+from tensorflow.contrib.lite.experimental.pb2lite.python import converter_wrapper as converter
+from tensorflow.contrib.lite.python.interpreter import Interpreter
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+from tensorflow.python.util.compat import as_bytes
+
+ERROR_THRESHOLD = 1e-6
+
+
+def _get_tf_operations():
+ return tf.get_default_graph().get_operations()
+
+
+class OpsTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ np.random.seed(0)
+
+ def test_fully_connected(self):
+ num_batches = 10
+ num_input_channels = 20
+ num_output_channels = 5
+
+ # No operation initially.
+ self.assertEqual(len(_get_tf_operations()), 0)
+
+ # Create 1 operation (Placeholder for the input).
+ input_shape = (num_batches, num_input_channels)
+ x = tf.placeholder(tf.float32, shape=input_shape)
+ self.assertEqual(len(_get_tf_operations()), 1)
+
+ # Defining weights and bias as constants. It should add 2 more
+ # nodes into the graph.
+ weights_shape = (num_output_channels, num_input_channels)
+ weights_value = np.random.rand(*weights_shape).astype(np.float32)
+ weights = tf.constant(weights_value, dtype=tf.float32)
+ bias_shape = (num_output_channels,)
+ bias_value = np.random.rand(*bias_shape).astype(np.float32)
+ bias = tf.constant(bias_value, dtype=tf.float32)
+ self.assertEqual(len(_get_tf_operations()), 3)
+
+ # Call the function to construct a TF Function node which is equivalent
+ # to TFLite FULLY_CONNECTED node.
+ output = ops.fully_connected(
+ x, weights, bias, fused_activation_function='RELU')
+
+ # Exactly one op should be added. It should be a function containing 2-3 ops
+ # (matmul, add, relu).
+ operations = _get_tf_operations()
+ self.assertEqual(len(operations), 4)
+
+ op = operations[-1]
+ node_def = op.node_def
+ # Note: `as_bytes` conversion is required for Python 3.
+ self.assertEqual(node_def.attr['_tflite_function_name'].s,
+ as_bytes('FULLY_CONNECTED'))
+ self.assertEqual(node_def.attr['_fused_activation_function'].s,
+ as_bytes('RELU'))
+
+ # Try to run the TF session to get the output value.
+ input_value = np.random.rand(*input_shape).astype(np.float32)
+ with tf.Session() as sess:
+ output_value = sess.run(output, feed_dict={x: input_value})
+ graph_def = sess.graph_def
+
+ # Convert the GraphDef to FlatBuffer.
+ flatbuffer_data = converter.Convert(graph_def.SerializeToString())
+
+ # Construct an interpreter with the FlatBuffer.
+ interpreter = Interpreter(model_content=flatbuffer_data)
+
+ # Invoke the interpreter.
+ input_details = interpreter.get_input_details()
+ input_index = input_details[0]['index']
+ interpreter.resize_tensor_input(input_index, input_shape)
+ interpreter.allocate_tensors()
+ interpreter.set_tensor(input_index, input_value)
+ interpreter.invoke()
+
+ # Get the output from the interpreter, and compare it with the result from
+ # TensorFlow.
+ output_details = interpreter.get_output_details()
+ tflite_output_value = interpreter.get_tensor(output_details[0]['index'])
+
+ max_error = np.max(np.abs(tflite_output_value - output_value))
+
+ self.assertTrue(max_error < ERROR_THRESHOLD)
+
+
+if __name__ == '__main__':
+ test.main()