author     Alan Chiao <alanchiao@google.com>    2018-10-04 08:30:22 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-10-04 08:34:41 -0700
commit     dcd7dd2d2e1ed7d8c26dd22dbbd2bac269c42e1e (patch)
tree       06798fad9258383b59ed80e1c30a751495ceb229
parent     7b56d4ff7679ed59e3ea799054c5dcefd0600ab0 (diff)
Sparse output fully connected custom op.
PiperOrigin-RevId: 215741296
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD                                  |  18
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc       | 235
-rw-r--r--  tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc  | 158
3 files changed, 411 insertions, 0 deletions
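In effect, the op evaluates a fully connected layer for a single output unit: the lookup input selects one row of the weights and the matching bias entry, and only that unit's activation is computed for each batch. A minimal reference sketch of the float-path semantics follows; the function name and flattened-buffer layout are illustrative, not part of this patch:

#include <vector>

// Illustrative reference (not in the patch):
//   output[b] = bias[lookup] + dot(weights[lookup, :], input[b, :])
std::vector<float> SparseOutputFullyConnectedRef(
    const std::vector<float>& input,    // flattened {n_batch, n_input}
    const std::vector<float>& weights,  // flattened {n_embeddings, n_input}
    const std::vector<float>& bias,     // {n_embeddings}
    int lookup, int n_batch, int n_input) {
  // Start every batch entry at the looked-up bias value.
  std::vector<float> output(n_batch, bias[lookup]);
  for (int b = 0; b < n_batch; ++b) {
    for (int i = 0; i < n_input; ++i) {
      output[b] += weights[lookup * n_input + i] * input[b * n_input + i];
    }
  }
  return output;
}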
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index daaf6714cc..95e387814d 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -210,6 +210,7 @@ cc_library(
"slice.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
+ "sparse_output_fully_connected.cc",
"sparse_to_dense.cc",
"split.cc",
"squeeze.cc",
@@ -334,6 +335,23 @@ tf_cc_test(
)
tf_cc_test(
+ name = "sparse_output_fully_connected_test",
+ size = "small",
+ srcs = ["sparse_output_fully_connected_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
new file mode 100644
index 0000000000..843ed0768c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// SparseOutputFullyConnected is a fully connected layer that computes a
+// single output unit: the lookup input selects one row of the weights and
+// the matching bias entry.
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sparse_output_fully_connected {
+
+// Input tensor of size {n_batch, n_input}.
+constexpr int kInputTensor = 0;
+// Auxiliary input (lookup) tensor of size {1}.
+constexpr int kInputLookupTensor = 1;
+
+// Weights tensor of size {n_embeddings, n_input}.
+constexpr int kWeightsTensor = 2;
+// Bias tensor of size {n_embeddings}.
+constexpr int kBiasTensor = 3;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kScalingFactors = 1,
+ kNumTemporaryTensors = 2
+};
+
+// Struct to hold op data.
+struct OpData {
+ int scratch_tensor_index;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ context->AddTensors(context, /*tensors_to_add=*/kNumTemporaryTensors,
+ &data->scratch_tensor_index);
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+ // Only support single lookup.
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(lookup, 0), 1);
+
+ const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(weights, 1), n_input);
+
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(weights, 0));
+
+ const bool is_hybrid_op =
+ (weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
+
+ if (is_hybrid_op) {
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+
+    // Allocate a temporary tensor to store the quantized input values.
+ node->temporaries->data[kInputQuantized] = op_data->scratch_tensor_index;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, /*index=*/kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+    // Tell the interpreter to allocate a temporary tensor for the scaling
+    // factors.
+ node->temporaries->data[kScalingFactors] =
+ op_data->scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, /*index=*/kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* lookup,
+ const TfLiteTensor* weights, const TfLiteTensor* bias,
+ TfLiteTensor* output) {
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const float* input_ptr_batch = input->data.f;
+
+  // Initialize the weights pointer to the row selected by the lookup value.
+ int32 lookup_index = lookup->data.i32[0];
+ const float* weights_ptr = weights->data.f + lookup_index * n_input;
+
+ // Initialize output to bias.
+ if (bias) {
+ float* bias_ptr = bias->data.f + lookup_index;
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_ptr, /*m_rows=*/1, n_input, input_ptr_batch, n_batch,
+ output->data.f, /*result_stride=*/1);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(const TfLiteTensor* input, const TfLiteTensor* lookup,
+ const TfLiteTensor* weights, const TfLiteTensor* bias,
+ TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* output) {
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const float* input_ptr_batch = input->data.f;
+  // Initialize pointers to the storage for the quantized input values and
+  // the scaling factors.
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+  // Initialize the weights pointer to the row selected by the lookup value.
+ int32 lookup_index = lookup->data.i32[0];
+ int8_t* weights_ptr =
+ reinterpret_cast<int8_t*>(weights->data.uint8) + lookup_index * n_input;
+
+ // Initialize output to bias.
+ if (bias) {
+ float* bias_ptr = bias->data.f + lookup_index;
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+ }
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Quantize input from float to int8.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors_ptr[b]);
+ scaling_factors_ptr[b] *= weights->params.scale;
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_ptr, /*m_rows=*/1, n_input, quantized_input_ptr_batch,
+ scaling_factors_ptr, n_batch, output->data.f, /*result_stride=*/1);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+ const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, lookup, weights, bias, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, /*index=*/kInputQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, /*index=*/kScalingFactors);
+ return EvalHybrid(input, lookup, weights, bias, scaling_factors,
+ input_quantized, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace sparse_output_fully_connected
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED() {
+ static TfLiteRegistration r = {sparse_output_fully_connected::Init,
+ sparse_output_fully_connected::Free,
+ sparse_output_fully_connected::Prepare,
+ sparse_output_fully_connected::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
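Being a custom op, this kernel is not part of the builtin resolver, so client code has to register it by name before building an interpreter. A minimal sketch of that registration, assuming the standard MutableOpResolver::AddCustom hook from kernels/register.h; the surrounding setup is illustrative, not part of this patch:

// Illustrative registration; the op name matches the one used by the test
// below.
tflite::ops::builtin::BuiltinOpResolver resolver;
resolver.AddCustom(
    "SPARSE_OUTPUT_FULLY_CONNECTED",
    tflite::ops::custom::Register_SPARSE_OUTPUT_FULLY_CONNECTED());
// Pass `resolver` to tflite::InterpreterBuilder as usual.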
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
new file mode 100644
index 0000000000..365986a5c1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse output fully connected op.
+#include <iomanip>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseSparseOutputFullyConnectedOpModel : public SingleOpModel {
+ public:
+ BaseSparseOutputFullyConnectedOpModel(const TensorData& input,
+ const TensorData& weights,
+ const TensorData& output = {
+ TensorType_FLOAT32}) {
+ input_ = AddInput(input);
+ lookup_ = AddInput({TensorType_INT32, {1}});
+ weights_ = AddInput(weights);
+ int bias_size = GetShape(weights_)[0];
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ output_ = AddOutput(output);
+
+ // Create empty (required) options map.
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {});
+ fbb.Finish();
+
+ SetCustomOp("SPARSE_OUTPUT_FULLY_CONNECTED", fbb.GetBuffer(),
+ Register_SPARSE_OUTPUT_FULLY_CONNECTED);
+ BuildInterpreter({GetShape(input_), GetShape(lookup_), GetShape(weights_),
+ GetShape(bias_)});
+ }
+
+ void SetInput(const std::vector<float>& data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetLookup(const std::vector<int32>& f) { PopulateTensor(lookup_, f); }
+
+ void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int lookup_;
+ int weights_;
+ int bias_;
+ int output_;
+};
+
+class FloatSparseOutputFullyConnectedOpModel
+ : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+ using BaseSparseOutputFullyConnectedOpModel::
+ BaseSparseOutputFullyConnectedOpModel;
+
+ void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
+};
+
+class HybridSparseOutputFullyConnectedOpModel
+ : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+ using BaseSparseOutputFullyConnectedOpModel::
+ BaseSparseOutputFullyConnectedOpModel;
+
+ void SetWeights(const std::vector<float>& f) {
+ SymmetricQuantizeAndPopulate(weights_, f);
+ }
+};
+
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestFloat) {
+ FloatSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+ {TensorType_FLOAT32, {3, 5}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+ m.SetLookup({2});
+
+ m.SetWeights({
+ -1.0, 0.0, 1.0, 2.0, 3.0, //
+ 0.0, 1.0, 2.0, 3.0, 4.0, //
+ 1.0, 2.0, 3.0, 4.0, 5.0, //
+ });
+
+ m.SetBias({1.0, 2.0, 3.0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({28}));
+}
+
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestHybrid) {
+ HybridSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+ {TensorType_UINT8, {3, 5}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+ m.SetLookup({2});
+
+ m.SetWeights({
+ -1.0, 0.0, 1.0, 2.0, 3.0, //
+ 0.0, 1.0, 2.0, 3.0, 4.0, //
+ 1.0, 2.0, 3.0, 4.0, 5.0, //
+ });
+
+ m.SetBias({1.0, 2.0, 3.0});
+
+ m.Invoke();
+
+ // We get 28.0552 instead of 28.
+ //
+ // Input -> -42, 0, 42, 85, 127 with scale factor of 127/3.
+ // Looked up weights -> 25, 51, 76, 102, 127 with scale factor of 127/5.
+ //
+ // (-42 * 25 + 0 * 51 + 42 * 76 + 85 * 102 + 127 * 127) * (3*5/127^2) + 3.0
+ // gives us the expected result.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({28}, 0.0553)));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}