diff options
author | Alan Chiao <alanchiao@google.com> | 2018-10-04 08:30:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 08:34:41 -0700 |
commit | dcd7dd2d2e1ed7d8c26dd22dbbd2bac269c42e1e (patch) | |
tree | 06798fad9258383b59ed80e1c30a751495ceb229 /tensorflow/contrib | |
parent | 7b56d4ff7679ed59e3ea799054c5dcefd0600ab0 (diff) |
Sparse output fully connected custom op.
PiperOrigin-RevId: 215741296
Diffstat (limited to 'tensorflow/contrib')
3 files changed, 411 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index daaf6714cc..95e387814d 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -210,6 +210,7 @@ cc_library( "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "sparse_output_fully_connected.cc", "sparse_to_dense.cc", "split.cc", "squeeze.cc", @@ -334,6 +335,23 @@ tf_cc_test( ) tf_cc_test( + name = "sparse_output_fully_connected_test", + size = "small", + srcs = ["sparse_output_fully_connected_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + +tf_cc_test( name = "activations_test", size = "small", srcs = ["activations_test.cc"], diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc new file mode 100644 index 0000000000..843ed0768c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc @@ -0,0 +1,235 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// SparseOutputFullyConnected is a fully connected layer that uses a single +// row in the weights and bias via a lookup. 
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sparse_output_fully_connected {
+
+// Input tensor of size {n_batch, n_input}.
+constexpr int kInputTensor = 0;
+// Auxiliary input tensor of size {1}: the index of the weights row (and of
+// the bias element) to use.
+constexpr int kInputLookupTensor = 1;
+
+// Weights tensor of size {n_embeddings, n_input}.
+constexpr int kWeightsTensor = 2;
+// Bias tensor with n_embeddings elements.
+constexpr int kBiasTensor = 3;
+
+// Output tensor. Eval writes n_batch floats into it.
+constexpr int kOutputTensor = 0;
+
+// Temporary tensors. Prepare only wires these up for the hybrid path.
+enum TemporaryTensor {
+  kInputQuantized = 0,
+  kScalingFactors = 1,
+  kNumTemporaryTensors = 2
+};
+
+// Per-node op data.
+struct OpData {
+  // Index of the first of the kNumTemporaryTensors tensors added in Init.
+  int scratch_tensor_index;
+};
+
+// Allocates the per-node OpData and registers kNumTemporaryTensors scratch
+// tensors with the context. `buffer` and `length` are not used.
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  auto* data = new OpData;
+  context->AddTensors(context, /*tensors_to_add=*/kNumTemporaryTensors,
+                      &data->scratch_tensor_index);
+  return data;
+}
+
+// Releases the OpData allocated by Init.
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<OpData*>(buffer);
+}
+
+// Validates the node's tensors and, for the hybrid case (uint8 weights with a
+// float input), sets up the temporaries that hold the quantized input and the
+// per-batch scaling factors.
+//
+// Returns kTfLiteError (through the TF_LITE_ENSURE* macros) when a check or a
+// resize fails, kTfLiteOk otherwise.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+  const int n_batch = SizeOfDimension(input, 0);
+  const int n_input = SizeOfDimension(input, 1);
+
+  const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+  TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+  // Only support single lookup.
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(lookup, 0), 1);
+
+  const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 2);
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(weights, 1), n_input);
+
+  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+  // Both eval paths read the bias through data.f.
+  TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(weights, 0));
+
+  // Both eval paths write n_batch floats through output->data.f, so the
+  // output must be float and large enough to hold them.
+  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
+  TF_LITE_ENSURE(context, NumElements(output) >= n_batch);
+
+  const bool is_hybrid_op =
+      (weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
+
+  if (is_hybrid_op) {
+    TfLiteIntArrayFree(node->temporaries);
+    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+
+    // Allocate temporary tensors to store quantized values of input.
+    node->temporaries->data[kInputQuantized] = op_data->scratch_tensor_index;
+    TfLiteTensor* input_quantized =
+        GetTemporary(context, node, /*index=*/kInputQuantized);
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+
+    // Tell interpreter to allocate temporary tensors to store scaling factors.
+    node->temporaries->data[kScalingFactors] =
+        op_data->scratch_tensor_index + kScalingFactors;
+    TfLiteTensor* scaling_factors =
+        GetTemporary(context, node, /*index=*/kScalingFactors);
+    scaling_factors->type = kTfLiteFloat32;
+    scaling_factors->allocation_type = kTfLiteArenaRw;
+    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+    scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+      // ResizeTensor takes ownership of scaling_factors_size.
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+                                                       scaling_factors_size));
+    } else {
+      // The tensor already has the right shape, so nothing took ownership of
+      // the freshly created array; release it here instead of leaking it.
+      TfLiteIntArrayFree(scaling_factors_size);
+    }
+  }
+  return kTfLiteOk;
+}
+
+// Float path: for every batch entry computes
+//   output[b] = bias[lookup] + dot(weights[lookup, :], input[b, :]).
+// The caller must have checked that the lookup value is a valid row index.
+TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* lookup,
+                       const TfLiteTensor* weights, const TfLiteTensor* bias,
+                       TfLiteTensor* output) {
+  const int n_batch = SizeOfDimension(input, 0);
+  const int n_input = SizeOfDimension(input, 1);
+
+  const float* input_ptr_batch = input->data.f;
+
+  // Initialize pointer to right row according to lookup value.
+  const int32_t lookup_index = lookup->data.i32[0];
+  const float* weights_ptr = weights->data.f + lookup_index * n_input;
+
+  // Initialize output to bias.
+  if (bias) {
+    const float* bias_ptr = bias->data.f + lookup_index;
+    tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+  } else {
+    tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+  }
+
+  // A single weights row is used, hence m_rows == 1 and one result per batch.
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      weights_ptr, /*m_rows=*/1, n_input, input_ptr_batch, n_batch,
+      output->data.f, /*result_stride=*/1);
+
+  return kTfLiteOk;
+}
+
+// Hybrid path: the float input is symmetrically quantized to int8 one batch
+// entry at a time, multiplied against the looked-up row of the quantized
+// weights, and accumulated into the float output on top of the bias.
+// `scaling_factors` and `input_quantized` are the temporaries set up by
+// Prepare. The caller must have checked that the lookup value is a valid row
+// index.
+TfLiteStatus EvalHybrid(const TfLiteTensor* input, const TfLiteTensor* lookup,
+                        const TfLiteTensor* weights, const TfLiteTensor* bias,
+                        TfLiteTensor* scaling_factors,
+                        TfLiteTensor* input_quantized, TfLiteTensor* output) {
+  const int n_batch = SizeOfDimension(input, 0);
+  const int n_input = SizeOfDimension(input, 1);
+
+  const float* input_ptr_batch = input->data.f;
+  // Initialize the pointer to storage for quantized values and
+  // scaling factors.
+  int8_t* quantized_input_ptr_batch =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+
+  // Initialize pointer to right row according to lookup value.
+  const int32_t lookup_index = lookup->data.i32[0];
+  const int8_t* weights_ptr =
+      reinterpret_cast<const int8_t*>(weights->data.uint8) +
+      lookup_index * n_input;
+
+  // Initialize output to bias.
+  if (bias) {
+    const float* bias_ptr = bias->data.f + lookup_index;
+    tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+  } else {
+    tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+  }
+
+  // An all-zero input contributes nothing, so the output stays at the bias
+  // and the quantization and multiplication are skipped.
+  if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+    // Quantize input from float to int8.
+    float unused_min, unused_max;
+    for (int b = 0; b < n_batch; ++b) {
+      const int offset = b * n_input;
+      tensor_utils::SymmetricQuantizeFloats(
+          input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+          &unused_min, &unused_max, &scaling_factors_ptr[b]);
+      // Fold the weights scale in so that one factor per batch entry rescales
+      // the integer dot product.
+      scaling_factors_ptr[b] *= weights->params.scale;
+    }
+
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        weights_ptr, /*m_rows=*/1, n_input, quantized_input_ptr_batch,
+        scaling_factors_ptr, n_batch, output->data.f, /*result_stride=*/1);
+  }
+  return kTfLiteOk;
+}
+
+// Dispatches to the float or hybrid implementation based on the weights type.
+// Returns kTfLiteError when the lookup value is not a valid weights row or
+// when the weights type is unsupported.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+  const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  // The lookup value is runtime data that is used to offset into the weights
+  // and bias buffers; reject anything outside [0, n_embeddings) so a bad
+  // index cannot cause an out-of-bounds read.
+  const int32_t lookup_index = lookup->data.i32[0];
+  TF_LITE_ENSURE(context, lookup_index >= 0);
+  TF_LITE_ENSURE(context, lookup_index < SizeOfDimension(weights, 0));
+
+  switch (weights->type) {
+    case kTfLiteFloat32: {
+      return EvalFloat(input, lookup, weights, bias, output);
+    }
+    case kTfLiteUInt8: {
+      TfLiteTensor* input_quantized =
+          GetTemporary(context, node, /*index=*/kInputQuantized);
+      TfLiteTensor* scaling_factors =
+          GetTemporary(context, node, /*index=*/kScalingFactors);
+      return EvalHybrid(input, lookup, weights, bias, scaling_factors,
+                        input_quantized, output);
+    }
+    default:
+      context->ReportError(context, "Type %d is not currently supported.",
+                           weights->type);
+      return kTfLiteError;
+  }
+}
+
+}  // namespace sparse_output_fully_connected
+
+// Returns the registration (Init, Free, Prepare, Eval) for the custom op.
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED() {
+  static TfLiteRegistration r = {sparse_output_fully_connected::Init,
+                                 sparse_output_fully_connected::Free,
+                                 sparse_output_fully_connected::Prepare,
+                                 sparse_output_fully_connected::Eval};
+  return &r;
+}
+
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
new file mode 100644
index 0000000000..365986a5c1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse output fully connected op.
+#include <iomanip>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Single-op model around the SPARSE_OUTPUT_FULLY_CONNECTED custom op. It owns
+// the tensor indices and the setters shared by the float and hybrid variants;
+// the variants only differ in how the weights are populated.
+class BaseSparseOutputFullyConnectedOpModel : public SingleOpModel {
+ public:
+  // Registers the four inputs (input, lookup, weights, bias) and the output,
+  // then builds the interpreter. The lookup tensor always has shape {1} and
+  // the bias gets one element per weights row.
+  BaseSparseOutputFullyConnectedOpModel(
+      const TensorData& input, const TensorData& weights,
+      const TensorData& output = {TensorType_FLOAT32}) {
+    input_ = AddInput(input);
+    lookup_ = AddInput({TensorType_INT32, {1}});
+    weights_ = AddInput(weights);
+    const int num_weight_rows = GetShape(weights_)[0];
+    bias_ = AddInput({TensorType_FLOAT32, {num_weight_rows}});
+    output_ = AddOutput(output);
+
+    // The custom op takes no options, but an (empty) options map is still
+    // required.
+    flexbuffers::Builder options_builder;
+    options_builder.Map([&]() {});
+    options_builder.Finish();
+
+    SetCustomOp("SPARSE_OUTPUT_FULLY_CONNECTED", options_builder.GetBuffer(),
+                Register_SPARSE_OUTPUT_FULLY_CONNECTED);
+    BuildInterpreter({GetShape(input_), GetShape(lookup_), GetShape(weights_),
+                      GetShape(bias_)});
+  }
+
+  // Fills the input tensor.
+  void SetInput(const std::vector<float>& values) {
+    PopulateTensor(input_, values);
+  }
+
+  // Fills the lookup tensor with the row index to use.
+  void SetLookup(const std::vector<int32>& indices) {
+    PopulateTensor(lookup_, indices);
+  }
+
+  // Fills the bias tensor.
+  void SetBias(const std::vector<float>& values) {
+    PopulateTensor(bias_, values);
+  }
+
+  // Returns the contents of the output tensor.
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  // Tensor indices handed out by AddInput / AddOutput.
+  int input_;
+  int lookup_;
+  int weights_;
+  int bias_;
+  int output_;
+};
+
+// Variant whose weights are stored as plain floats.
+class FloatSparseOutputFullyConnectedOpModel
+    : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+  using BaseSparseOutputFullyConnectedOpModel::
+      BaseSparseOutputFullyConnectedOpModel;
+
+  void SetWeights(const std::vector<float>& values) {
+    PopulateTensor(weights_, values);
+  }
+};
+
+// Variant whose float weights are symmetrically quantized before being stored.
+class HybridSparseOutputFullyConnectedOpModel
+    : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+  using BaseSparseOutputFullyConnectedOpModel::
+      BaseSparseOutputFullyConnectedOpModel;
+
+  void SetWeights(const std::vector<float>& values) {
+    SymmetricQuantizeAndPopulate(weights_, values);
+  }
+};
+
+// Float weights: row 2 is selected, so the expected output is
+// (-1*1 + 0*2 + 1*3 + 2*4 + 3*5) + 3.0 = 28.
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestFloat) {
+  FloatSparseOutputFullyConnectedOpModel model({TensorType_FLOAT32, {1, 5}},
+                                               {TensorType_FLOAT32, {3, 5}},
+                                               {TensorType_FLOAT32, {}});
+
+  model.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+  model.SetLookup({2});
+
+  model.SetWeights({
+      -1.0, 0.0, 1.0, 2.0, 3.0,  //
+      0.0,  1.0, 2.0, 3.0, 4.0,  //
+      1.0,  2.0, 3.0, 4.0, 5.0,  //
+  });
+
+  model.SetBias({1.0, 2.0, 3.0});
+
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({28}));
+}
+
+// Same data as the float test, but with quantized (uint8) weights.
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestHybrid) {
+  HybridSparseOutputFullyConnectedOpModel model({TensorType_FLOAT32, {1, 5}},
+                                                {TensorType_UINT8, {3, 5}},
+                                                {TensorType_FLOAT32, {}});
+
+  model.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+  model.SetLookup({2});
+
+  model.SetWeights({
+      -1.0, 0.0, 1.0, 2.0, 3.0,  //
+      0.0,  1.0, 2.0, 3.0, 4.0,  //
+      1.0,  2.0, 3.0, 4.0, 5.0,  //
+  });
+
+  model.SetBias({1.0, 2.0, 3.0});
+
+  model.Invoke();
+
+  // Quantization error makes the result 28.0552 rather than 28.
+  //
+  // Input -> -42, 0, 42, 85, 127 with scale factor of 127/3.
+  // Looked up weights -> 25, 51, 76, 102, 127 with scale factor of 127/5.
+  //
+  // (-42 * 25 + 0 * 51 + 42 * 76 + 85 * 102 + 127 * 127) * (3*5/127^2) + 3.0
+  // gives us the expected result.
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray(ArrayFloatNear({28}, 0.0553)));
+}
+
+}  // namespace
+}  // namespace custom
+}  // namespace ops
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+} |