diff options
author | 2018-09-27 14:25:18 -0700 | |
---|---|---|
committer | 2018-09-27 14:30:11 -0700 | |
commit | d0397c3314600da0c9cdc300ae87483331d54298 (patch) | |
tree | 8ec45f483facaa2219324a19a7f74e38e2e0a5e2 /tensorflow/contrib/lite/delegates/flex | |
parent | cc83067469bc30bba55932c587f31ef68f15792f (diff) |
Rename TFLite Eager delegate -> Flex delegate
PiperOrigin-RevId: 214835588
Diffstat (limited to 'tensorflow/contrib/lite/delegates/flex')
18 files changed, 2239 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD new file mode 100644 index 0000000000..bf5d91899c --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/BUILD @@ -0,0 +1,200 @@ +# +# This is a TF Lite delegate that is powered by TensorFlow's Eager. +# +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "buffer_map", + srcs = ["buffer_map.cc"], + hdrs = ["buffer_map.h"], + deps = [ + ":util", + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite:kernel_api", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + ], + "//conditions:default": [ + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + ], + }), +) + +tf_cc_test( + name = "buffer_map_test", + size = "small", + srcs = ["buffer_map_test.cc"], + deps = [ + ":buffer_map", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "delegate", + srcs = [ + "delegate.cc", + ], + hdrs = [ + "delegate.h", + ], + deps = [ + ":buffer_map", + ":delegate_data", + ":kernel", + ":util", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite:util", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + ], + }), +) + +tf_cc_test( + name = "delegate_test", + size = "small", + srcs = ["delegate_test.cc"], + deps = [ + ":delegate", + ":test_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "delegate_data", + srcs = ["delegate_data.cc"], + hdrs = ["delegate_data.h"], + deps = [ + ":buffer_map", + "//tensorflow/core/common_runtime/eager:context", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite", + ], + "//conditions:default": [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + ], + }), +) + +tf_cc_test( + name = "delegate_data_test", + size = "small", + srcs = ["delegate_data_test.cc"], + deps = [ + ":delegate_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "kernel", + srcs = ["kernel.cc"], + hdrs = ["kernel.h"], + deps = [ + ":delegate_data", + ":util", + "@flatbuffers", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:execute", + "//tensorflow/core/common_runtime/eager:tensor_handle", + ] + select({ + # TODO(b/111881878): The android_tensorflow_lib target pulls in the full + # set of core TensorFlow kernels. We may want to revisit this dependency + # to allow selective registration via build targets. + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + "//conditions:default": [ + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:framework", + "//tensorflow/core:tensorflow", + ], + }), +) + +tf_cc_test( + name = "kernel_test", + size = "small", + srcs = ["kernel_test.cc"], + deps = [ + ":delegate_data", + ":kernel", + ":test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = True, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", + "@flatbuffers", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "//tensorflow/c:c_api_internal", + "//tensorflow/contrib/lite/c:c_api_internal", + "//tensorflow/contrib/lite:kernel_api", + ] + select({ + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib_lite_no_runtime", + ], + "//conditions:default": [ + "//tensorflow/core:lib", + "//tensorflow/core:framework", + ], + }), +) + +tf_cc_test( + name = "util_test", + size = "small", + srcs = ["util_test.cc"], + deps = [ + ":util", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/delegates/flex/buffer_map.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc new file mode 100644 index 0000000000..63e39196d9 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" +#include "tensorflow/core/framework/allocation_description.pb.h" +#include "tensorflow/core/framework/log_memory.h" + +namespace tflite { +namespace flex { +namespace { +// A tensor buffer that is allocated, deallocated and populated by TF Lite. +class TfLiteTensorBuffer : public tensorflow::TensorBuffer { + public: + explicit TfLiteTensorBuffer(const TfLiteTensor* tensor) { + len_ = tensor->bytes; + // TODO(ahentz): if we can guarantee that TF Lite allocated tensors with + // the same alignment as TensorFlow (EIGEN_MAX_ALIGN_BYTES), then we can + // potentially eliminate the copy below. + data_ = + tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len_); + if (data_ != nullptr) { + if (tensorflow::LogMemory::IsEnabled()) { + tensorflow::LogMemory::RecordRawAllocation( + "TfLiteTensorBuffer_New", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len_, + data_, tensorflow::cpu_allocator()); + } + std::memcpy(data_, tensor->data.raw, tensor->bytes); + } + } + + ~TfLiteTensorBuffer() override { + if (tensorflow::LogMemory::IsEnabled() && data_ != nullptr) { + tensorflow::LogMemory::RecordRawDeallocation( + "TfLiteTensorBuffer_Delete", + tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data_, + tensorflow::cpu_allocator(), false); + } + tensorflow::cpu_allocator()->DeallocateRaw(data_); + } + + void* data() const override { return data_; } + size_t size() const override { return len_; } + + TensorBuffer* root_buffer() override { return this; } + void FillAllocationDescription( + tensorflow::AllocationDescription* proto) const override { + tensorflow::int64 rb = size(); + proto->set_requested_bytes(rb); + proto->set_allocator_name(tensorflow::cpu_allocator()->Name()); + } + + // Prevents input forwarding from mutating this buffer. + bool OwnsMemory() const override { return false; } + + private: + void* data_; + size_t len_; +}; +} // namespace + +BufferMap::BufferMap() {} + +BufferMap::~BufferMap() {} + +bool BufferMap::HasTensor(int tensor_index) const { + return id_to_tensor_.count(tensor_index) != 0; +} + +tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const { + return id_to_tensor_.at(tensor_index); +} + +void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) { + tensorflow::TensorShape shape; + int num_dims = tensor->dims->size; + for (int i = 0; i < num_dims; ++i) { + shape.AddDim(tensor->dims->data[i]); + } + // TODO(ahentz): we assume this is a new tensor and allocate a new buffer + // for it. This is not always the best approach. For example, this might + // be a reallocation after resizing tensors. In that case we would be + // preferable to somehow reuse the buffer. + auto* buf = new TfLiteTensorBuffer(tensor); + tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor( + GetTensorFlowDataType(tensor->type), shape, buf); + buf->Unref(); + + SetFromTensorFlow(tensor_index, std::move(t)); +} + +void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) { + id_to_tensor_[tensor_index] = std::move(tensor); +} + +} // namespace flex +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/flex/buffer_map.h b/tensorflow/contrib/lite/delegates/flex/buffer_map.h new file mode 100644 index 0000000000..4ce886568a --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/buffer_map.h @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ + +#include <map> + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tflite { +namespace flex { + +// Maps a TF Lite tensor index into a TensorFlow tensor. +// +// The TF Lite interpreter assigns integer indices to each of its tensors, but +// the Flex delegate deals in terms of TensorFlow tensors. This class maps +// from indices to tensors and allows the creation of new tensors to be +// associated with a given index. +class BufferMap { + public: + BufferMap(); + ~BufferMap(); + + // Returns true if the given 'tensor_index' has a corresponding + // tensorflow::Tensor. + bool HasTensor(int tensor_index) const; + + // Returns the tensorflow::Tensor associated with the given 'tensor_index'. + // Precondition: HasTensor() is true. + tensorflow::Tensor GetTensor(int tensor_index) const; + + // Associates the given tensorflow::Tensor with the given 'tensor_index'. + // Note that tensorflow Tensors share data buffers, so this method is only a + // shallow copy. + void SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor); + + // Same as above but creates a new tensorflow::Tensor with a copy of the + // given TfLiteTensor's data. + void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor); + + private: + std::map<int, tensorflow::Tensor> id_to_tensor_; +}; + +} // namespace flex +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_BUFFER_MAP_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc new file mode 100644 index 0000000000..bb80e25e80 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/buffer_map_test.cc @@ -0,0 +1,174 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/util.h" + +namespace tflite { +namespace flex { +namespace { + +using ::testing::ElementsAre; + +// A bit of RAII to simplify handling of TfLiteTensors in the tests. +using UniqueTfLiteTensor = + std::unique_ptr<TfLiteTensor, std::function<void(TfLiteTensor*)>>; + +template <typename T> +UniqueTfLiteTensor MakeLiteTensor(const std::vector<int>& shape, + const std::vector<T>& data) { + auto tensor = UniqueTfLiteTensor(new TfLiteTensor, [](TfLiteTensor* t) { + TfLiteTensorDataFree(t); + TfLiteIntArrayFree(t->dims); + delete t; + }); + tensor->allocation_type = kTfLiteDynamic; + tensor->type = typeToTfLiteType<T>(); + tensor->dims = ConvertVectorToTfLiteIntArray(shape); + tensor->data.raw = nullptr; + TfLiteTensorRealloc(data.size() * sizeof(T), tensor.get()); + memcpy(tensor->data.raw, data.data(), data.size() * sizeof(T)); + return tensor; +} + +template <typename T> +tensorflow::Tensor MakeTensor(const std::vector<int>& shape, + const std::vector<T>& data) { + BufferMap buffer_map; // BufferMap is the easiest way to build the tensor. + UniqueTfLiteTensor t1 = MakeLiteTensor<T>(shape, data); + buffer_map.SetFromTfLite(0, t1.get()); + return buffer_map.GetTensor(0); +} + +std::vector<tensorflow::int64> GetTensorShape(const tensorflow::Tensor& t) { + std::vector<tensorflow::int64> shape(t.dims()); + for (int i = 0; i < t.dims(); ++i) { + shape[i] = t.dim_size(i); + } + return shape; +} + +template <typename T> +std::vector<T> GetTensorData(const tensorflow::Tensor& t) { + const T* data = t.flat<T>().data(); + return std::vector<T>(data, data + t.NumElements()); +} + +TEST(BufferMapTest, EmptyBuffer) { + BufferMap buffer_map; + EXPECT_FALSE(buffer_map.HasTensor(0)); +} + +TEST(BufferMapTest, SetFromTfLite) { + BufferMap buffer_map; + + UniqueTfLiteTensor t = + MakeLiteTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + buffer_map.SetFromTfLite(0, t.get()); + ASSERT_TRUE(buffer_map.HasTensor(0)); + + EXPECT_THAT(GetTensorData<float>(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); + + // Also check details of the tensor. + tensorflow::Tensor out_tensor = buffer_map.GetTensor(0); + ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT); + ASSERT_EQ(out_tensor.NumElements(), 6); + ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); +} + +TEST(BufferMapTest, SetFromTfLiteTwice) { + UniqueTfLiteTensor t1 = + MakeLiteTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + + BufferMap buffer_map; + buffer_map.SetFromTfLite(0, t1.get()); + buffer_map.SetFromTfLite(0, t2.get()); + + EXPECT_THAT(GetTensorData<int>(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, SetFromTensorFlow) { + tensorflow::Tensor t1 = + MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + + EXPECT_THAT(GetTensorData<float>(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); + + // Also check details of the tensor. + tensorflow::Tensor out_tensor = buffer_map.GetTensor(0); + ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT); + ASSERT_EQ(out_tensor.NumElements(), 6); + ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3)); +} + +TEST(BufferMapTest, SetFromTensorFlowTwice) { + tensorflow::Tensor t1 = + MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + tensorflow::Tensor t2 = MakeTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + buffer_map.SetFromTensorFlow(0, t2); + + EXPECT_THAT(GetTensorData<int>(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, TfLiteOverwritesTensorFlow) { + tensorflow::Tensor t1 = + MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + + BufferMap buffer_map; + buffer_map.SetFromTensorFlow(0, t1); + buffer_map.SetFromTfLite(0, t2.get()); + + EXPECT_THAT(GetTensorData<int>(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 3, 0, 0, 1, 2)); +} + +TEST(BufferMapTest, TensorFlowOverwritesTfLite) { + tensorflow::Tensor t1 = + MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0}); + UniqueTfLiteTensor t2 = + MakeLiteTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2}); + BufferMap buffer_map; + buffer_map.SetFromTfLite(0, t2.get()); + buffer_map.SetFromTensorFlow(0, t1); + + EXPECT_THAT(GetTensorData<float>(buffer_map.GetTensor(0)), + ElementsAre(0, 0, 0, 0.123f, 0, 0)); +} + +} // namespace +} // namespace flex +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc new file mode 100644 index 0000000000..ba065a8ff5 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/delegate.h" + +#include <vector> + +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" +#include "tensorflow/contrib/lite/delegates/flex/kernel.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" +#include "tensorflow/contrib/lite/util.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace flex { +namespace delegate { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteDelegate* delegate) { + // Get the nodes in the current execution plan. Interpreter owns this array. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + + // Add all custom ops starting with "Flex" to list of supported nodes. + std::vector<int> supported_nodes; + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + + if (IsFlexOp(registration->custom_name)) { + supported_nodes.push_back(node_index); + } + } + + // Request TFLite to partition the graph and make kernels for each independent + // subgraph. + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + context->ReplaceSubgraphsWithDelegateKernels(context, GetKernel(), + size_and_nodes, delegate); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +TfLiteStatus CopyFromBufferHandle(TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, void* data, + size_t size) { + BufferMap* buffer_map = + reinterpret_cast<DelegateData*>(delegate->data_)->GetBufferMap(context); + + if (!buffer_map->HasTensor(buffer_handle)) { + context->ReportError(context, "Invalid tensor index %d.", buffer_handle); + return kTfLiteError; + } + + tensorflow::Tensor t = buffer_map->GetTensor(buffer_handle); + tensorflow::StringPiece t_data = t.tensor_data(); + + if (size != t_data.size()) { + context->ReportError( + context, "Not enough space to store TensorFlow's aligned buffer."); + return kTfLiteError; + } + + memcpy(data, t_data.data(), t_data.size()); + return kTfLiteOk; +} + +} // namespace delegate +} // namespace flex + +std::unique_ptr<FlexDelegate> FlexDelegate::Create() { + std::unique_ptr<flex::DelegateData> delegate_data; + if (!flex::DelegateData::Create(&delegate_data).ok()) { + fprintf(stderr, "Unable to initialize TensorFlow context.\n"); + return nullptr; + } + + return std::unique_ptr<FlexDelegate>( + new FlexDelegate(std::move(delegate_data))); +} + +FlexDelegate::FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data) + : TfLiteDelegate{ + /*data_=*/delegate_data.get(), + /*nullptr,*/ &flex::delegate::Prepare, + /*CopyFromBufferHandle=*/&flex::delegate::CopyFromBufferHandle, + /*CopyToBufferHandle=*/nullptr, + /*FreeBufferHandle=*/nullptr}, + delegate_data_(std::move(delegate_data)) {} + +FlexDelegate::~FlexDelegate() {} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.h b/tensorflow/contrib/lite/delegates/flex/delegate.h new file mode 100644 index 0000000000..1017780dc7 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/delegate.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" + +namespace tflite { + +// WARNING: This is an experimental interface that is subject to change. +// Delegate that can be used to extract parts of a graph that are designed to be +// executed by TensorFlow's runtime via Eager. +// +// The interpreter must be constructed after the FlexDelegate and destructed +// before the FlexDelegate. This delegate may be used with multiple +// interpreters, but it is *not* thread-safe. +// +// Usage: +// auto delegate = FlexDelegate::Create(); +// ... build interpreter ... +// +// if (delegate) { +// interpreter->ModifyGraphWithDelegate( +// delegate.get(), /*allow_dynamic_tensors=*/true); +// } +// ... run inference ... +// ... destroy interpreter ... +// ... destroy delegate ... +class FlexDelegate : public TfLiteDelegate { + public: + // Creates a delegate that supports TF ops. + // + // If the underyling TF Flex context creation fails, returns null. + static std::unique_ptr<FlexDelegate> Create(); + + ~FlexDelegate(); + + private: + explicit FlexDelegate(std::unique_ptr<flex::DelegateData> delegate_data); + + std::unique_ptr<flex::DelegateData> delegate_data_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_data.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc new file mode 100644 index 0000000000..8f985f770c --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" + +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace flex { +tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) { + std::vector<tensorflow::Device*> devices; + + TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices( + tensorflow::SessionOptions(), "/job:localhost/replica:0/task:0", + &devices)); + + std::unique_ptr<tensorflow::DeviceMgr> device_mgr( + new tensorflow::DeviceMgr(devices)); + // Note that Rendezvous is ref-counted so it will be automatically deleted. + tensorflow::Rendezvous* rendezvous = + new tensorflow::IntraProcessRendezvous(device_mgr.get()); + data->reset(new DelegateData(new tensorflow::EagerContext( + tensorflow::SessionOptions(), + tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, + /*async=*/false, std::move(device_mgr), rendezvous))); + return tensorflow::Status(); +} + +DelegateData::DelegateData(tensorflow::EagerContext* eager_context) + : eager_context_(eager_context) {} + +DelegateData::~DelegateData() {} + +} // namespace flex +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_data.h b/tensorflow/contrib/lite/delegates/flex/delegate_data.h new file mode 100644 index 0000000000..8d75f0b0ef --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/delegate_data.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ + +#include "tensorflow/contrib/lite/delegates/flex/buffer_map.h" +#include "tensorflow/core/common_runtime/eager/context.h" + +namespace tflite { +namespace flex { + +// Data kept by the Flex delegate for the lifetime of an Interpreter. +class DelegateData { + public: + // Create a new DelegateData, initialized with a newly-created EagerContext. + static tensorflow::Status Create(std::unique_ptr<DelegateData>* data); + + ~DelegateData(); + + // The EagerContext that is required for execution of Flex Ops. + tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); } + + // Map from TF Lite tensor index to TensorFlow tensor for a given context. + BufferMap* GetBufferMap(const TfLiteContext* context) { + return &buffer_map_[context]; + } + + private: + explicit DelegateData(tensorflow::EagerContext* eager_context); + + std::unique_ptr<tensorflow::EagerContext> eager_context_; + // TODO(b/112439500): Clean up stale BufferMap instances after adding the + // necessary cleanup hook from a TfLiteContext to a TfLiteDelegate. + std::unordered_map<const TfLiteContext*, BufferMap> buffer_map_; +}; + +} // namespace flex +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_DELEGATE_DATA_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc new file mode 100644 index 0000000000..30b10f435a --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/delegate_data_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace flex { +namespace { + +TEST(DelegateDataTest, Basic) { + std::unique_ptr<DelegateData> data; + // We only check for success because it is hard to make initialization fail. + // It only happens if we manage to not link the CPU device factory into the + // binary. + EXPECT_TRUE(DelegateData::Create(&data).ok()); + + TfLiteContext dummy_context1 = {}; + TfLiteContext dummy_context2 = {}; + EXPECT_NE(data->GetEagerContext(), nullptr); + EXPECT_NE(data->GetBufferMap(&dummy_context1), nullptr); + EXPECT_NE(data->GetBufferMap(&dummy_context1), + data->GetBufferMap(&dummy_context2)); +} + +} // namespace +} // namespace flex +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/flex/delegate_test.cc b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc new file mode 100644 index 0000000000..1813952cef --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/delegate_test.cc @@ -0,0 +1,246 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/delegate.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/delegates/flex/test_util.h" + +namespace tflite { +namespace flex { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +class DelegateTest : public testing::FlexModelTest { + public: + DelegateTest() { + delegate_ = FlexDelegate::Create(); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~DelegateTest() override { + // The delegate needs to be destructed after the interpreter because the + // interpreter references data contained in the delegate. + interpreter_.reset(); + delegate_.reset(); + } + + void ConfigureDelegate() { + ASSERT_EQ(interpreter_->ModifyGraphWithDelegate( + delegate_.get(), /*allow_dynamic_tensors=*/true), + kTfLiteOk); + } + + private: + std::unique_ptr<FlexDelegate> delegate_; +}; + +TEST_F(DelegateTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply the delegate. + ConfigureDelegate(); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + ASSERT_EQ(GetType(8), kTfLiteFloat32); +} + +TEST_F(DelegateTest, NonFloatTypeInference) { + AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2}); + + AddTfOp(testing::kAdd, {0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2}); + SetTypedValues<int>(0, {1, 2, 3, 4}); + SetShape(1, {2, 2}); + SetTypedValues<int>(1, {4, 3, 2, 1}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2)); + ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5)); + ASSERT_EQ(GetType(2), kTfLiteInt32); +} + +TEST_F(DelegateTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(DelegateTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +TEST_F(DelegateTest, OnlyTFLite) { + // Only TFLite single op model. + AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3}); + AddTfLiteMulOp({0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(1, {2, 2, 1}); + SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); + ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); +} + +TEST_F(DelegateTest, MultipleInvokeCalls) { + // Call Invoke() multiple times on the same model. + AddTensors(10, {0, 1}, {2}, kTfLiteFloat32, {3}); + AddTfLiteMulOp({0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(1, {2, 2, 1}); + SetValues(1, {1.0f, 2.0f, 3.0f, 4.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); + ASSERT_THAT(GetValues(2), ElementsAre(1.1f, 4.4f, 9.9f, 17.6f)); + + SetShape(0, {2, 2, 1}); + SetValues(1, {4.0f, 3.0f, 2.0f, 1.0f}); + SetShape(1, {2, 2, 1}); + SetValues(0, {4.4f, 3.3f, 2.2f, 1.1f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2, 1)); + ASSERT_THAT(GetValues(2), ElementsAre(17.6f, 9.9f, 4.4f, 1.1f)); +} + +TEST_F(DelegateTest, MultipleInterpretersSameDelegate) { + // Build a graph, configure the delegate and set inputs. + { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + ConfigureDelegate(); + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + } + + // Create a new interpreter, inject into the test framework and build + // a different graph using the *same* delegate. + std::unique_ptr<Interpreter> interpreter(new Interpreter(&error_reporter_)); + interpreter_.swap(interpreter); + { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfLiteMulOp({4, 5}, {6}); + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + ConfigureDelegate(); + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + } + + // Swap back in the first interpreter and validate inference. + interpreter_.swap(interpreter); + { + ASSERT_TRUE(Invoke()); + EXPECT_THAT(GetShape(8), ElementsAre(2, 1)); + EXPECT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + } + + // Swap in the second interpreter and validate inference. + interpreter_.swap(interpreter); + { + ASSERT_TRUE(Invoke()); + EXPECT_THAT(GetShape(9), ElementsAre(1)); + EXPECT_THAT(GetValues(9), ElementsAre(10.0f)); + } +} + +} // namespace +} // namespace flex +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/flex/kernel.cc b/tensorflow/contrib/lite/delegates/flex/kernel.cc new file mode 100644 index 0000000000..e4f1aea990 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/kernel.cc @@ -0,0 +1,299 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/kernel.h" + +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/flex/util.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/execute.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" + +// Note: this is part of TF Lite's Flex delegation code which is to be +// completed soon. + +// This is the TF Lite op that is created by the flex delegate to handle +// execution of a supported subgraph. The usual flow is that the delegate +// informs the interpreter of supported nodes in a graph, and each supported +// subgraph is replaced with one instance of this kernel. +// +// The kernel is initialized with TfLiteDelegateParams from which we retrieve +// the global EagerContext and BufferMap, as well as a list of inputs and +// outputs to the subgraph. Those are used to build the OpData, with a list of +// TensorFlow Ops that should be executed in order (which we call an OpNode). +// +// For each node included in the subgraph, we query the interpreter and +// retrieve the associated NodeDef, which is then used to configure the +// corresponding TensorFlow/Eager Op. + +namespace tflite { +namespace flex { +namespace kernel { + +// Controls the lifetime of tensor handles in a vector. +class VectorOfHandles { + public: + explicit VectorOfHandles(int num_elements) : vector_(num_elements, nullptr) {} + + ~VectorOfHandles() { + for (auto* handle : vector_) { + if (handle) handle->Unref(); + } + } + + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>* GetVector() { + return &vector_; + } + + tensorflow::TensorHandle* GetHandle(int index) { return vector_[index]; } + + private: + tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_; +}; + +// Executes the TensorFlow op given by 'op_name', with the attributes specified +// in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'. +tensorflow::Status ExecuteFlexOp(tensorflow::EagerContext* eager_context, + BufferMap* buffer_map, const string& op_name, + const tensorflow::NodeDef& nodedef, + const std::vector<int>& inputs, + const std::vector<int>& outputs) { + const tensorflow::AttrTypeMap* attr_types; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types), + " (while processing attributes of '", op_name, "')"); + + tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types); + for (const auto& attr : nodedef.attr()) { + op.MutableAttrs()->Set(attr.first, attr.second); + } + + for (int input_index : inputs) { + if (!buffer_map->HasTensor(input_index)) { + return tensorflow::errors::Internal( + "Cannot read from invalid tensor index ", input_index); + } + auto* handle = new tensorflow::TensorHandle( + buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr); + op.AddInput(handle); + handle->Unref(); + } + + int num_retvals = outputs.size(); + VectorOfHandles retvals(num_retvals); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + EagerExecute(&op, retvals.GetVector(), &num_retvals), + " (while executing '", op_name, "' via Eager)"); + + if (num_retvals != outputs.size()) { + return tensorflow::errors::Internal( + "Unexpected number of outputs from EagerExecute"); + } + + for (int i = 0; i < num_retvals; ++i) { + const tensorflow::Tensor* tensor = nullptr; + TF_RETURN_IF_ERROR(retvals.GetHandle(i)->Tensor(&tensor)); + buffer_map->SetFromTensorFlow(outputs[i], *tensor); + } + + return tensorflow::Status::OK(); +} + +// A single node within the larger 'op'. Note that this kernel executes many +// TensorFlow ops within a single TF Lite op. +struct OpNode { + // The name of the TensorFlow op to execute. + string name; + // The corresponding NodeDef, containing the attributes for the op. + tensorflow::NodeDef nodedef; + // List of inputs, as TF Lite tensor indices. + std::vector<int> inputs; + // List of outputs, as TF Lite tensor indices. + std::vector<int> outputs; +}; + +// The Larger 'op', which contains all the nodes in a supported subgraph. +struct OpData { + tensorflow::EagerContext* eager_context; + BufferMap* buffer_map; + std::vector<OpNode> nodes; + std::vector<int> subgraph_inputs; + std::vector<int> subgraph_outputs; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + + const TfLiteDelegateParams* params = + reinterpret_cast<const TfLiteDelegateParams*>(buffer); + CHECK(params); + CHECK(params->delegate); + CHECK(params->delegate->data_); + op_data->eager_context = + reinterpret_cast<DelegateData*>(params->delegate->data_) + ->GetEagerContext(); + op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_) + ->GetBufferMap(context); + + CHECK(params->output_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { + op_data->subgraph_outputs.push_back(tensor_index); + } + + CHECK(params->input_tensors); + for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { + op_data->subgraph_inputs.push_back(tensor_index); + } + + CHECK(params->nodes_to_replace); + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + + op_data->nodes.push_back(OpNode()); + OpNode& node_data = op_data->nodes.back(); + + node_data.name = ""; + if (node->custom_initial_data) { + // The flexbuffer contains a vector where the first elements is the + // op name and the second is a serialized NodeDef. + const flexbuffers::Vector& v = + flexbuffers::GetRoot( + reinterpret_cast<const uint8_t*>(node->custom_initial_data), + node->custom_initial_data_size) + .AsVector(); + + node_data.name = v[0].AsString().str(); + if (!node_data.nodedef.ParseFromString(v[1].AsString().str())) { + // We will just leave the nodedef empty and error out in Eval(). + node_data.nodedef.Clear(); + } + } + + // Fill NodeDef with defaults if it's a valid op. + const tensorflow::OpRegistrationData* op_reg_data; + auto tf_status = tensorflow::OpRegistry::Global()->LookUp( + node_data.nodedef.op(), &op_reg_data); + if (tf_status.ok()) { + AddDefaultsToNodeDef(op_reg_data->op_def, &node_data.nodedef); + } + + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + node_data.inputs.push_back(input_index); + } + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + node_data.outputs.push_back(output_index); + } + } + + return op_data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast<OpData*>(node->user_data); + TF_LITE_ENSURE_MSG( + context, op_data->eager_context != nullptr, + "Failed to initialize eager context. This often happens when a CPU " + "device has not been registered, presumably because some symbols from " + "tensorflow/core:core_cpu_impl were not linked into the binary."); + + // Whenever we find a constant tensor, insert it in the buffer map. + BufferMap* buffer_map = op_data->buffer_map; + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (IsConstantTensor(tensor)) { + if (!buffer_map->HasTensor(tensor_index)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + } + + // All output tensors are allocated by TensorFlow/Eager, so we + // mark them as kTfLiteDynamic. + for (auto tensor_index : op_data->subgraph_outputs) { + SetTensorToDynamic(&context->tensors[tensor_index]); + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast<OpData*>(node->user_data); + BufferMap* buffer_map = op_data->buffer_map; + tensorflow::EagerContext* eager_context = op_data->eager_context; + + // Insert a tensor in the buffer map for all inputs that are not constant. + // Constants were handled in Prepare() already. + for (auto tensor_index : op_data->subgraph_inputs) { + TfLiteTensor* tensor = &context->tensors[tensor_index]; + if (!IsConstantTensor(tensor)) { + buffer_map->SetFromTfLite(tensor_index, tensor); + } + } + + // Execute the TensorFlow Ops sequentially. + for (const auto& node_data : op_data->nodes) { + if (node_data.nodedef.op().empty()) { + context->ReportError(context, "Invalid NodeDef in Flex op '%s'", + node_data.name.c_str()); + return kTfLiteError; + } + auto status = + ExecuteFlexOp(eager_context, buffer_map, node_data.name, + node_data.nodedef, node_data.inputs, node_data.outputs); + TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); + } + + for (auto tensor_index : op_data->subgraph_outputs) { + if (!buffer_map->HasTensor(tensor_index)) { + context->ReportError(context, "Cannot write to invalid tensor index %d", + tensor_index); + return kTfLiteError; + } + + TfLiteTensor* tensor = &context->tensors[tensor_index]; + TF_LITE_ENSURE_OK( + context, + CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor)); + tensor->buffer_handle = tensor_index; + tensor->data_is_stale = true; + } + + return kTfLiteOk; +} + +} // namespace kernel + +TfLiteRegistration GetKernel() { + TfLiteRegistration registration{&kernel::Init, &kernel::Free, + &kernel::Prepare, &kernel::Eval, + nullptr, kTfLiteBuiltinDelegate}; + return registration; +} + +} // namespace flex +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/flex/kernel.h b/tensorflow/contrib/lite/delegates/flex/kernel.h new file mode 100644 index 0000000000..ac9313a37b --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/kernel.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ + +#include "tensorflow/contrib/lite/c/c_api_internal.h" + +namespace tflite { +namespace flex { + +// Return the registration object used to initialize and execute ops that will +// be delegated to TensorFlow's Eager runtime. This TF Lite op is created by +// the flex delegate to handle execution of a supported subgraph. The usual +// flow is that the delegate informs the interpreter of supported nodes in a +// graph, and each supported subgraph is replaced with one instance of this +// kernel. +TfLiteRegistration GetKernel(); + +} // namespace flex +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_KERNEL_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/kernel_test.cc b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc new file mode 100644 index 0000000000..94a6f8b61a --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/kernel_test.cc @@ -0,0 +1,230 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/kernel.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/delegates/flex/delegate_data.h" +#include "tensorflow/contrib/lite/delegates/flex/test_util.h" + +namespace tflite { +namespace flex { +namespace { + +using ::testing::ContainsRegex; +using ::testing::ElementsAre; + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteDelegate* delegate, + const std::vector<int>& supported_nodes) { + TfLiteIntArray* size_and_nodes = + ConvertVectorToTfLiteIntArray(supported_nodes); + TF_LITE_ENSURE_STATUS(context->ReplaceSubgraphsWithDelegateKernels( + context, flex::GetKernel(), size_and_nodes, delegate)); + TfLiteIntArrayFree(size_and_nodes); + return kTfLiteOk; +} + +class KernelTest : public testing::FlexModelTest { + public: + KernelTest() { + CHECK(DelegateData::Create(&delegate_data_).ok()); + interpreter_.reset(new Interpreter(&error_reporter_)); + } + + ~KernelTest() override { + // The data needs to be released before the interpreter because the + // interpreter references the data. + delegate_data_.reset(); + interpreter_.reset(); + } + + template <typename T> + void ConfigureDelegate(T prepare_function) { + delegate_.data_ = delegate_data_.get(); + delegate_.FreeBufferHandle = nullptr; + delegate_.Prepare = prepare_function; + delegate_.CopyFromBufferHandle = [](TfLiteContext* context, + TfLiteDelegate* delegate, + TfLiteBufferHandle buffer_handle, + void* data, size_t size) { + auto* delegate_data = reinterpret_cast<DelegateData*>(delegate->data_); + tensorflow::StringPiece values = delegate_data->GetBufferMap(context) + ->GetTensor(buffer_handle) + .tensor_data(); + memcpy(data, values.data(), values.size()); + return kTfLiteOk; + }; + CHECK(interpreter_->ModifyGraphWithDelegate( + &delegate_, /*allow_dynamic_tensors=*/true) == kTfLiteOk); + } + + private: + std::unique_ptr<DelegateData> delegate_data_; + TfLiteDelegate delegate_; +}; + +TEST_F(KernelTest, FullGraph) { + // Define the graph. + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfOp(testing::kMul, {6, 7}, {8}); + + // Apply Delegate. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3, 4}); + }); + + // Define inputs. + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, BadTensorFlowOp) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + AddTfOp(testing::kNonExistent, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while processing attributes of 'NonExistentOp'")); +} + +TEST_F(KernelTest, BadNumberOfOutputs) { + AddTensors(3, {0}, {1, 2}, kTfLiteFloat32, {3}); + AddTfOp(testing::kIdentity, {0}, {1, 2}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Unexpected number of outputs")); +} + +TEST_F(KernelTest, IncompatibleNodeDef) { + AddTensors(2, {0}, {1}, kTfLiteFloat32, {3}); + + // Cast is a TF op, but we don't add the proper nodedef to it in AddTfOp. + AddTfOp(testing::kIncompatibleNodeDef, {0}, {1}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("while executing 'Cast' via Eager")); +} + +TEST_F(KernelTest, WrongSetOfNodes) { + AddTensors(4, {0}, {3}, kTfLiteFloat32, {3}); + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfLiteMulOp({1, 2}, {3}); + + // Specify that testing::kMul (#1) is supported when it actually isn't. + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_FALSE(Invoke()); + ASSERT_THAT(error_reporter().error_messages(), + ContainsRegex("Invalid NodeDef in Flex op")); +} + +TEST_F(KernelTest, MixedGraph) { + AddTensors(9, {0, 3}, {8}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + AddTfOp(testing::kAdd, {1, 4}, {6}); + AddTfOp(testing::kAdd, {2, 5}, {7}); + AddTfLiteMulOp({6, 7}, {8}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 3}); + }); + + SetShape(0, {2, 2, 1}); + SetValues(0, {1.1f, 2.2f, 3.3f, 4.4f}); + SetShape(3, {2, 2, 1}); + SetValues(3, {1.1f, 2.2f, 3.3f, 4.4f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); + ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); +} + +TEST_F(KernelTest, SplitGraph) { + AddTensors(10, {0}, {9}, kTfLiteFloat32, {3}); + + AddTfOp(testing::kUnpack, {0}, {1, 2}); + AddTfOp(testing::kAdd, {1, 2}, {3}); + AddTfOp(testing::kUnpack, {3}, {4, 5}); + + AddTfLiteMulOp({4, 5}, {6}); + + AddTfOp(testing::kUnpack, {6}, {7, 8}); + AddTfOp(testing::kAdd, {7, 8}, {9}); + + ConfigureDelegate([](TfLiteContext* context, TfLiteDelegate* delegate) { + return GenericPrepare(context, delegate, {0, 1, 2, 4, 5}); + }); + + SetShape(0, {2, 2, 2, 1}); + SetValues(0, {3.0f, 1.0f, 0.5f, -1.0f, 0.0f, 1.0f, 1.5f, 3.0f}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(9), ElementsAre(1)); + ASSERT_THAT(GetValues(9), ElementsAre(10.0f)); +} + +} // namespace +} // namespace flex +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/delegates/flex/test_util.cc b/tensorflow/contrib/lite/delegates/flex/test_util.cc new file mode 100644 index 0000000000..69c336a01a --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/test_util.cc @@ -0,0 +1,157 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/delegates/flex/test_util.h" + +#include "absl/memory/memory.h" +#include "flatbuffers/flexbuffers.h" // TF:flatbuffers +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace flex { +namespace testing { + +bool FlexModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } + +void FlexModelTest::SetShape(int tensor_index, const std::vector<int>& values) { + ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); + ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk); +} + +std::vector<int> FlexModelTest::GetShape(int tensor_index) { + std::vector<int> result; + auto* dims = interpreter_->tensor(tensor_index)->dims; + result.reserve(dims->size); + for (int i = 0; i < dims->size; ++i) { + result.push_back(dims->data[i]); + } + return result; +} + +TfLiteType FlexModelTest::GetType(int tensor_index) { + return interpreter_->tensor(tensor_index)->type; +} + +void FlexModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs, + const std::vector<int>& outputs, TfLiteType type, + const std::vector<int>& dims) { + interpreter_->AddTensors(num_tensors); + for (int i = 0; i < num_tensors; ++i) { + TfLiteQuantizationParams quant; + // Suppress explicit output type specification to ensure type inference + // works properly. + if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) { + type = kTfLiteFloat32; + } + CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type, + /*name=*/"", + /*dims=*/dims, quant), + kTfLiteOk); + } + + CHECK_EQ(interpreter_->SetInputs(inputs), kTfLiteOk); + CHECK_EQ(interpreter_->SetOutputs(outputs), kTfLiteOk); +} + +void FlexModelTest::AddTfLiteMulOp(const std::vector<int>& inputs, + const std::vector<int>& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_MUL; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* o = &context->tensors[node->outputs->data[0]]; + return context->ResizeTensor(context, o, TfLiteIntArrayCopy(i0->dims)); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + auto* i0 = &context->tensors[node->inputs->data[0]]; + auto* i1 = &context->tensors[node->inputs->data[1]]; + auto* o = &context->tensors[node->outputs->data[0]]; + for (int i = 0; i < o->bytes / sizeof(float); ++i) { + o->data.f[i] = i0->data.f[i] * i1->data.f[i]; + } + return kTfLiteOk; + }; + + CHECK_EQ(interpreter_->AddNodeWithParameters(inputs, outputs, nullptr, 0, + nullptr, ®), + kTfLiteOk); +} + +void FlexModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs, + const std::vector<int>& outputs) { + auto attr = [](const string& key, const string& value) { + return " attr{ key: '" + key + "' value {" + value + "}}"; + }; + + // Crude type attribution, will need fleshing out as more tests are added. + // TODO(b/113613439): Use nodedef string utilities to properly handle + // all types. + string type_attribute = attr("T", "type: DT_FLOAT"); + if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) { + type_attribute = attr("T", "type: DT_INT32"); + } + + if (op == kUnpack) { + string attributes = + type_attribute + attr("num", "i: 2") + attr("axis", "i: 0"); + AddTfOp("FlexUnpack", "Unpack", attributes, inputs, outputs); + } else if (op == kIdentity) { + string attributes = type_attribute; + AddTfOp("FlexIdentity", "Identity", attributes, inputs, outputs); + } else if (op == kAdd) { + string attributes = type_attribute; + AddTfOp("FlexAdd", "Add", attributes, inputs, outputs); + } else if (op == kMul) { + string attributes = type_attribute; + AddTfOp("FlexMul", "Mul", attributes, inputs, outputs); + } else if (op == kNonExistent) { + AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); + } else if (op == kIncompatibleNodeDef) { + // "Cast" op is created without attributes - making it incompatible. + AddTfOp("FlexCast", "Cast", "", inputs, outputs); + } +} + +void FlexModelTest::AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, + const std::vector<int>& inputs, + const std::vector<int>& outputs) { + static TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.builtin_code = BuiltinOperator_CUSTOM; + reg.custom_name = tflite_name; + + tensorflow::NodeDef nodedef; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString( + nodedef_str + " op: '" + tf_name + "'", &nodedef)); + string serialized_nodedef; + CHECK(nodedef.SerializeToString(&serialized_nodedef)); + flexbuffers::Builder fbb; + fbb.Vector([&]() { + fbb.String(nodedef.op()); + fbb.String(serialized_nodedef); + }); + fbb.Finish(); + + flexbuffers_.push_back(fbb.GetBuffer()); + auto& buffer = flexbuffers_.back(); + CHECK_EQ(interpreter_->AddNodeWithParameters( + inputs, outputs, reinterpret_cast<const char*>(buffer.data()), + buffer.size(), nullptr, ®), + kTfLiteOk); +} + +} // namespace testing +} // namespace flex +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/flex/test_util.h b/tensorflow/contrib/lite/delegates/flex/test_util.h new file mode 100644 index 0000000000..a8c81b90a3 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/test_util.h @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace flex { +namespace testing { + +enum TfOpType { + kUnpack, + kIdentity, + kAdd, + kMul, + // Represents an op that does not exist in TensorFlow. + kNonExistent, + // Represents an valid TensorFlow op where the NodeDef is incompatible. + kIncompatibleNodeDef, +}; + +// This class creates models with TF and TFLite ops. In order to use this class +// to test the Flex delegate, implement a function that calls +// interpreter->ModifyGraphWithDelegate. +class FlexModelTest : public ::testing::Test { + public: + FlexModelTest() {} + ~FlexModelTest() {} + + bool Invoke(); + + // Sets the (typed) tensor's values at the given index. + template <typename T> + void SetTypedValues(int tensor_index, const std::vector<T>& values) { + memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(), + values.size() * sizeof(T)); + } + + // Returns the (typed) tensor's values at the given index. + template <typename T> + std::vector<T> GetTypedValues(int tensor_index) { + const TfLiteTensor* t = interpreter_->tensor(tensor_index); + const T* tdata = interpreter_->typed_tensor<T>(tensor_index); + return std::vector<T>(tdata, tdata + t->bytes / sizeof(T)); + } + + // Sets the tensor's values at the given index. + void SetValues(int tensor_index, const std::vector<float>& values) { + SetTypedValues<float>(tensor_index, values); + } + + // Returns the tensor's values at the given index. + std::vector<float> GetValues(int tensor_index) { + return GetTypedValues<float>(tensor_index); + } + + // Sets the tensor's shape at the given index. + void SetShape(int tensor_index, const std::vector<int>& values); + + // Returns the tensor's shape at the given index. + std::vector<int> GetShape(int tensor_index); + + // Returns the tensor's type at the given index. + TfLiteType GetType(int tensor_index); + + const TestErrorReporter& error_reporter() const { return error_reporter_; } + + // Adds `num_tensor` tensors to the model. `inputs` contains the indices of + // the input tensors and `outputs` contains the indices of the output + // tensors. All tensors are set to have `type` and `dims`. + void AddTensors(int num_tensors, const std::vector<int>& inputs, + const std::vector<int>& outputs, TfLiteType type, + const std::vector<int>& dims); + + // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors + // and `outputs` contains the indices of the output tensors. + void AddTfLiteMulOp(const std::vector<int>& inputs, + const std::vector<int>& outputs); + + // Adds a TensorFlow op. `inputs` contains the indices of the + // input tensors and `outputs` contains the indices of the output tensors. + // This function is limited to the set of ops defined in TfOpType. + void AddTfOp(TfOpType op, const std::vector<int>& inputs, + const std::vector<int>& outputs); + + protected: + std::unique_ptr<Interpreter> interpreter_; + TestErrorReporter error_reporter_; + + private: + // Helper method to add a TensorFlow op. tflite_names needs to start with + // "Flex" in order to work with the Flex delegate. + void AddTfOp(const char* tflite_name, const string& tf_name, + const string& nodedef_str, const std::vector<int>& inputs, + const std::vector<int>& outputs); + + std::vector<std::vector<uint8_t>> flexbuffers_; +}; + +} // namespace testing +} // namespace flex +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/util.cc b/tensorflow/contrib/lite/delegates/flex/util.cc new file mode 100644 index 0000000000..829bc388bf --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/util.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/util.h" + +namespace tflite { +namespace flex { + +TfLiteStatus ConvertStatus(TfLiteContext* context, + const tensorflow::Status& status) { + if (!status.ok()) { + context->ReportError(context, "%s", status.error_message().c_str()); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus CopyShapeAndType(TfLiteContext* context, + const tensorflow::Tensor& src, + TfLiteTensor* tensor) { + tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype())); + if (tensor->type == kTfLiteNoType) { + context->ReportError(context, + "TF Lite does not support TensorFlow data type: %s", + DataTypeString(src.dtype()).c_str()); + return kTfLiteError; + } + + int num_dims = src.dims(); + TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims); + for (int j = 0; j < num_dims; ++j) { + // We need to cast from TensorFlow's int64 to TF Lite's int32. Let's + // make sure there's no overflow. + if (src.dim_size(j) >= std::numeric_limits<int>::max()) { + context->ReportError(context, + "Dimension value in TensorFlow shape is larger than " + "supported by TF Lite"); + TfLiteIntArrayFree(shape); + return kTfLiteError; + } + shape->data[j] = static_cast<int>(src.dim_size(j)); + } + return context->ResizeTensor(context, tensor, shape); +} + +TF_DataType GetTensorFlowDataType(TfLiteType type) { + switch (type) { + case kTfLiteNoType: + return TF_FLOAT; + case kTfLiteFloat32: + return TF_FLOAT; + case kTfLiteInt16: + return TF_INT16; + case kTfLiteInt32: + return TF_INT32; + case kTfLiteUInt8: + return TF_UINT8; + case kTfLiteInt64: + return TF_INT64; + case kTfLiteComplex64: + return TF_COMPLEX64; + case kTfLiteString: + return TF_STRING; + case kTfLiteBool: + return TF_BOOL; + } +} + +TfLiteType GetTensorFlowLiteType(TF_DataType type) { + switch (type) { + case TF_FLOAT: + return kTfLiteFloat32; + case TF_INT16: + return kTfLiteInt16; + case TF_INT32: + return kTfLiteInt32; + case TF_UINT8: + return kTfLiteUInt8; + case TF_INT64: + return kTfLiteInt64; + case TF_COMPLEX64: + return kTfLiteComplex64; + case TF_STRING: + return kTfLiteString; + case TF_BOOL: + return kTfLiteBool; + default: + return kTfLiteNoType; + } +} + +} // namespace flex +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/flex/util.h b/tensorflow/contrib/lite/delegates/flex/util.h new file mode 100644 index 0000000000..7f910e7316 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/util.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/contrib/lite/c/c_api_internal.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tflite { +namespace flex { + +// Converts a tensorflow:Status into a TfLiteStatus. If the original status +// represented an error, reports it using the given 'context'. +TfLiteStatus ConvertStatus(TfLiteContext* context, + const tensorflow::Status& status); + +// Copies the given shape and type of the TensorFlow 'src' tensor into a TF Lite +// 'tensor'. Logs an error and returns kTfLiteError if the shape or type can't +// be converted. +TfLiteStatus CopyShapeAndType(TfLiteContext* context, + const tensorflow::Tensor& src, + TfLiteTensor* tensor); + +// Returns the TF C API Data type that corresponds to the given TfLiteType. +TF_DataType GetTensorFlowDataType(TfLiteType type); + +// Returns the TfLiteType that corresponds to the given TF C API Data type. +TfLiteType GetTensorFlowLiteType(TF_DataType); + +} // namespace flex +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_FLEX_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/flex/util_test.cc b/tensorflow/contrib/lite/delegates/flex/util_test.cc new file mode 100644 index 0000000000..5f049e7b0a --- /dev/null +++ b/tensorflow/contrib/lite/delegates/flex/util_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/flex/util.h" + +#include <cstdarg> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/string.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace flex { +namespace { + +using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; +using tensorflow::Tensor; +using ::testing::ElementsAre; + +struct TestContext : public TfLiteContext { + string error; + std::vector<int> new_size; +}; + +void ReportError(TfLiteContext* context, const char* format, ...) { + TestContext* c = static_cast<TestContext*>(context); + const size_t kBufferSize = 1024; + char temp_buffer[kBufferSize]; + + va_list args; + va_start(args, format); + vsnprintf(temp_buffer, kBufferSize, format, args); + va_end(args); + + c->error = temp_buffer; +} + +TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + TestContext* c = static_cast<TestContext*>(context); + c->new_size.clear(); + for (int i = 0; i < new_size->size; ++i) { + c->new_size.push_back(new_size->data[i]); + } + TfLiteIntArrayFree(new_size); + return kTfLiteOk; +} + +TEST(UtilTest, ConvertStatus) { + TestContext context; + context.ReportError = ReportError; + + EXPECT_EQ(ConvertStatus(&context, tensorflow::errors::Internal("Some Error")), + kTfLiteError); + EXPECT_EQ(context.error, "Some Error"); + + context.error.clear(); + EXPECT_EQ(ConvertStatus(&context, tensorflow::Status()), kTfLiteOk); + EXPECT_TRUE(context.error.empty()); +} + +TEST(UtilTest, CopyShapeAndType) { + TestContext context; + context.ReportError = ReportError; + context.ResizeTensor = ResizeTensor; + + TfLiteTensor dst; + + EXPECT_EQ(CopyShapeAndType(&context, Tensor(), &dst), kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(0)); + EXPECT_EQ(dst.type, kTfLiteFloat32); + + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1, 2}), &dst), + kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteFloat32); + + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_INT32, {1, 2}), &dst), + kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteInt32); + + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst), + kTfLiteError); + EXPECT_EQ(context.error, + "Dimension value in TensorFlow shape is larger than supported by " + "TF Lite"); + + EXPECT_EQ( + CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst), + kTfLiteError); + EXPECT_EQ(context.error, + "TF Lite does not support TensorFlow data type: half"); +} + +TEST(UtilTest, TypeConversionsFromTFLite) { + EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType)); + EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32)); + EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16)); + EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32)); + EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8)); + EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64)); + EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64)); + EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString)); + EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); +} + +TEST(UtilTest, TypeConversionsFromTensorFlow) { + EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT)); + EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16)); + EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32)); + EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8)); + EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64)); + EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64)); + EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING)); + EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL)); + EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE)); + EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT)); +} + +} // namespace +} // namespace flex +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} |