aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-17 08:08:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 08:12:12 -0700
commit17b9b3f957654178b33c2b13a83b5f6b0690f16d (patch)
tree020a91069f3cd39820fcc6af812643703ec41ce2 /tensorflow/contrib
parent09d35773d849b7f6aed72adbb3bf7eb79648a49a (diff)
A map from TF Lite tensor indices to TensorFlow tensors.
PiperOrigin-RevId: 204912661
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD31
-rw-r--r--tensorflow/contrib/lite/delegates/eager/buffer_map.cc105
-rw-r--r--tensorflow/contrib/lite/delegates/eager/buffer_map.h59
-rw-r--r--tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc172
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.cc23
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.h5
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util_test.cc17
-rw-r--r--tensorflow/contrib/lite/util.h7
8 files changed, 415 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 066b106215..270d83d188 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -8,10 +8,41 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
cc_library(
+ name = "buffer_map",
+ srcs = ["buffer_map.cc"],
+ hdrs = ["buffer_map.h"],
+ deps = [
+ ":util",
+ "//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "buffer_map_test",
+ size = "small",
+ srcs = ["buffer_map_test.cc"],
+ tags = [
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":buffer_map",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
+ "//tensorflow/c:c_api_internal",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/core:framework",
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
new file mode 100644
index 0000000000..e4a780b735
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.cc
@@ -0,0 +1,105 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/delegates/eager/util.h"
+#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/log_memory.h"
+
+namespace tflite {
+namespace {
+// A tensor buffer that is allocated, deallocated and populated by TF Lite.
+class TfLiteTensorBuffer : public tensorflow::TensorBuffer {
+ public:
+ explicit TfLiteTensorBuffer(const TfLiteTensor* tensor) {
+ len_ = tensor->bytes;
+ // TODO(ahentz): if we can guarantee that TF Lite allocated tensors with
+ // the same alignment as TensorFlow (EIGEN_MAX_ALIGN_BYTES), then we can
+ // potentially eliminate the copy below.
+ data_ =
+ tensorflow::cpu_allocator()->AllocateRaw(EIGEN_MAX_ALIGN_BYTES, len_);
+ if (data_ != nullptr) {
+ if (tensorflow::LogMemory::IsEnabled()) {
+ tensorflow::LogMemory::RecordRawAllocation(
+ "TfLiteTensorBuffer_New",
+ tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, len_,
+ data_, tensorflow::cpu_allocator());
+ }
+ std::memcpy(data_, tensor->data.raw, tensor->bytes);
+ }
+ }
+
+ ~TfLiteTensorBuffer() override {
+ if (tensorflow::LogMemory::IsEnabled() && data_ != nullptr) {
+ tensorflow::LogMemory::RecordRawDeallocation(
+ "TfLiteTensorBuffer_Delete",
+ tensorflow::LogMemory::EXTERNAL_TENSOR_ALLOCATION_STEP_ID, data_,
+ tensorflow::cpu_allocator(), false);
+ }
+ tensorflow::cpu_allocator()->DeallocateRaw(data_);
+ }
+
+ void* data() const override { return data_; }
+ size_t size() const override { return len_; }
+
+ TensorBuffer* root_buffer() override { return this; }
+ void FillAllocationDescription(
+ tensorflow::AllocationDescription* proto) const override {
+ tensorflow::int64 rb = size();
+ proto->set_requested_bytes(rb);
+ proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
+ }
+
+ // Prevents input forwarding from mutating this buffer.
+ bool OwnsMemory() const override { return false; }
+
+ private:
+ void* data_;
+ size_t len_;
+};
+} // namespace
+
+BufferMap::BufferMap() {}
+
+BufferMap::~BufferMap() {}
+
+bool BufferMap::HasTensor(int tensor_index) const {
+ return id_to_tensor_.count(tensor_index) != 0;
+}
+
+tensorflow::Tensor BufferMap::GetTensor(int tensor_index) const {
+ return id_to_tensor_.at(tensor_index);
+}
+
+void BufferMap::SetFromTfLite(int tensor_index, const TfLiteTensor* tensor) {
+ tensorflow::TensorShape shape;
+ int num_dims = tensor->dims->size;
+ for (int i = 0; i < num_dims; ++i) {
+ shape.AddDim(tensor->dims->data[i]);
+ }
+ auto* buf = new TfLiteTensorBuffer(tensor);
+ tensorflow::Tensor t = tensorflow::TensorCApi::MakeTensor(
+ GetTensorFlowDataType(tensor->type), shape, buf);
+ buf->Unref();
+
+ SetFromTensorFlow(tensor_index, std::move(t));
+}
+
+void BufferMap::SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor) {
+ id_to_tensor_[tensor_index] = std::move(tensor);
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
new file mode 100644
index 0000000000..922f67f574
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
+
+#include <map>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tflite {
+
+// Maps a TF Lite tensor index into a TensorFlow tensor.
+//
+// The TF Lite interpreter assigns integer indices to each of its tensors, but
+// the Eager delegate deals in terms of TensorFlow tensors. This class maps
+// from indices to tensors and allows the creation of new tensors to be
+// associated with a given index.
+class BufferMap {
+ public:
+ BufferMap();
+ ~BufferMap();
+
+ // Returns true if the given 'tensor_index' has a corresponding
+ // tensorflow::Tensor.
+ bool HasTensor(int tensor_index) const;
+
+ // Returns the tensorflow::Tensor associated with the given 'tensor_index'.
+ // Precondition: HasTensor() is true.
+ tensorflow::Tensor GetTensor(int tensor_index) const;
+
+ // Associates the given tensorflow::Tensor with the given 'tensor_index'.
+ // Note that tensorflow Tensors share data buffers, so this method is only a
+ // shallow copy.
+ void SetFromTensorFlow(int tensor_index, tensorflow::Tensor tensor);
+
+ // Same as above but creates a new tensorflow::Tensor with a copy of the
+ // given TfLiteTensor's data.
+ void SetFromTfLite(int tensor_index, const TfLiteTensor* tensor);
+
+ private:
+ std::map<int, tensorflow::Tensor> id_to_tensor_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_BUFFER_MAP_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc
new file mode 100644
index 0000000000..c447eeaa05
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map_test.cc
@@ -0,0 +1,172 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+// A bit of RAII to simplify handling of TfLiteTensors in the tests.
+using UniqueTfLiteTensor =
+ std::unique_ptr<TfLiteTensor, std::function<void(TfLiteTensor*)>>;
+
+template <typename T>
+UniqueTfLiteTensor MakeLiteTensor(const std::vector<int>& shape,
+ const std::vector<T>& data) {
+ auto tensor = UniqueTfLiteTensor(new TfLiteTensor, [](TfLiteTensor* t) {
+ TfLiteTensorDataFree(t);
+ TfLiteIntArrayFree(t->dims);
+ delete t;
+ });
+ tensor->allocation_type = kTfLiteDynamic;
+ tensor->type = typeToTfLiteType<T>();
+ tensor->dims = ConvertVectorToTfLiteIntArray(shape);
+ tensor->data.raw = nullptr;
+ TfLiteTensorRealloc(data.size() * sizeof(T), tensor.get());
+ memcpy(tensor->data.raw, data.data(), data.size() * sizeof(T));
+ return tensor;
+}
+
+template <typename T>
+tensorflow::Tensor MakeTensor(const std::vector<int>& shape,
+ const std::vector<T>& data) {
+ BufferMap buffer_map; // BufferMap is the easiest way to build the tensor.
+ UniqueTfLiteTensor t1 = MakeLiteTensor<T>(shape, data);
+ buffer_map.SetFromTfLite(0, t1.get());
+ return buffer_map.GetTensor(0);
+}
+
+std::vector<int64> GetTensorShape(const tensorflow::Tensor& t) {
+ std::vector<int64> shape(t.dims());
+ for (int i = 0; i < t.dims(); ++i) {
+ shape[i] = t.dim_size(i);
+ }
+ return shape;
+}
+
+template <typename T>
+std::vector<T> GetTensorData(const tensorflow::Tensor& t) {
+ const T* data = t.flat<T>().data();
+ return std::vector<T>(data, data + t.NumElements());
+}
+
+TEST(BufferMapTest, EmptyBuffer) {
+ BufferMap buffer_map;
+ EXPECT_FALSE(buffer_map.HasTensor(0));
+}
+
+TEST(BufferMapTest, SetFromTfLite) {
+ BufferMap buffer_map;
+
+ UniqueTfLiteTensor t =
+ MakeLiteTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
+ buffer_map.SetFromTfLite(0, t.get());
+ ASSERT_TRUE(buffer_map.HasTensor(0));
+
+ EXPECT_THAT(GetTensorData<float>(buffer_map.GetTensor(0)),
+ ElementsAre(0, 0, 0, 0.123f, 0, 0));
+
+ // Also check details of the tensor.
+ tensorflow::Tensor out_tensor = buffer_map.GetTensor(0);
+ ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT);
+ ASSERT_EQ(out_tensor.NumElements(), 6);
+ ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3));
+}
+
+TEST(BufferMapTest, SetFromTfLiteTwice) {
+ UniqueTfLiteTensor t1 =
+ MakeLiteTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
+ UniqueTfLiteTensor t2 =
+ MakeLiteTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2});
+
+ BufferMap buffer_map;
+ buffer_map.SetFromTfLite(0, t1.get());
+ buffer_map.SetFromTfLite(0, t2.get());
+
+ EXPECT_THAT(GetTensorData<int>(buffer_map.GetTensor(0)),
+ ElementsAre(0, 0, 0, 3, 0, 0, 1, 2));
+}
+
+TEST(BufferMapTest, SetFromTensorFlow) {
+ tensorflow::Tensor t1 =
+ MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
+
+ BufferMap buffer_map;
+ buffer_map.SetFromTensorFlow(0, t1);
+
+ EXPECT_THAT(GetTensorData<float>(buffer_map.GetTensor(0)),
+ ElementsAre(0, 0, 0, 0.123f, 0, 0));
+
+ // Also check details of the tensor.
+ tensorflow::Tensor out_tensor = buffer_map.GetTensor(0);
+ ASSERT_EQ(out_tensor.dtype(), tensorflow::DT_FLOAT);
+ ASSERT_EQ(out_tensor.NumElements(), 6);
+ ASSERT_THAT(GetTensorShape(out_tensor), ElementsAre(1, 2, 1, 3));
+}
+
+TEST(BufferMapTest, SetFromTensorFlowTwice) {
+ tensorflow::Tensor t1 =
+ MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
+ tensorflow::Tensor t2 = MakeTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2});
+ BufferMap buffer_map;
+ buffer_map.SetFromTensorFlow(0, t1);
+ buffer_map.SetFromTensorFlow(0, t2);
+
+ EXPECT_THAT(GetTensorData<int>(buffer_map.GetTensor(0)),
+ ElementsAre(0, 0, 0, 3, 0, 0, 1, 2));
+}
+
+TEST(BufferMapTest, TfLiteOverwritesTensorFlow) {
+ tensorflow::Tensor t1 =
+ MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
+ UniqueTfLiteTensor t2 =
+ MakeLiteTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2});
+
+ BufferMap buffer_map;
+ buffer_map.SetFromTensorFlow(0, t1);
+ buffer_map.SetFromTfLite(0, t2.get());
+
+ EXPECT_THAT(GetTensorData<int>(buffer_map.GetTensor(0)),
+ ElementsAre(0, 0, 0, 3, 0, 0, 1, 2));
+}
+
+TEST(BufferMapTest, TensorFlowOverwritesTfLite) {
+ tensorflow::Tensor t1 =
+ MakeTensor<float>({1, 2, 1, 3}, {0, 0, 0, 0.123f, 0, 0});
+ UniqueTfLiteTensor t2 =
+ MakeLiteTensor<int>({1, 2, 4}, {0, 0, 0, 3, 0, 0, 1, 2});
+ BufferMap buffer_map;
+ buffer_map.SetFromTfLite(0, t2.get());
+ buffer_map.SetFromTensorFlow(0, t1);
+
+ EXPECT_THAT(GetTensorData<float>(buffer_map.GetTensor(0)),
+ ElementsAre(0, 0, 0, 0.123f, 0, 0));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index 04a852e515..e1879bdaff 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -44,4 +44,27 @@ TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
return context->ResizeTensor(context, tensor, shape);
}
+TF_DataType GetTensorFlowDataType(TfLiteType type) {
+ switch (type) {
+ case kTfLiteNoType:
+ return TF_FLOAT;
+ case kTfLiteFloat32:
+ return TF_FLOAT;
+ case kTfLiteInt16:
+ return TF_INT16;
+ case kTfLiteInt32:
+ return TF_INT32;
+ case kTfLiteUInt8:
+ return TF_UINT8;
+ case kTfLiteInt64:
+ return TF_INT64;
+ case kTfLiteComplex64:
+ return TF_COMPLEX64;
+ case kTfLiteString:
+ return TF_STRING;
+ case kTfLiteBool:
+ return TF_BOOL;
+ }
+}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index 2696ca8d0d..12b33b9b49 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -15,6 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
+#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -30,6 +31,10 @@ TfLiteStatus ConvertStatus(TfLiteContext* context,
// error and returns kTfLiteError if the shape can't be converted.
TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
TfLiteTensor* tensor);
+
+// Returns the TF C API Data type that corresponds to the given TfLiteType.
+TF_DataType GetTensorFlowDataType(TfLiteType type);
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index 563f82dec3..53ed4db972 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -23,6 +23,8 @@ limitations under the License.
namespace tflite {
namespace {
+using tensorflow::DT_FLOAT;
+using tensorflow::Tensor;
using ::testing::ElementsAre;
struct TestContext : public TfLiteContext {
@@ -72,9 +74,6 @@ TEST(UtilTest, CopyShape) {
context.ReportError = ReportError;
context.ResizeTensor = ResizeTensor;
- using tensorflow::DT_FLOAT;
- using tensorflow::Tensor;
-
TfLiteTensor dst;
EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk);
@@ -90,6 +89,18 @@ TEST(UtilTest, CopyShape) {
"TF Lite");
}
+TEST(UtilTest, TypeConversions) {
+ EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType));
+ EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32));
+ EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16));
+ EXPECT_EQ(TF_INT32, GetTensorFlowDataType(kTfLiteInt32));
+ EXPECT_EQ(TF_UINT8, GetTensorFlowDataType(kTfLiteUInt8));
+ EXPECT_EQ(TF_INT64, GetTensorFlowDataType(kTfLiteInt64));
+ EXPECT_EQ(TF_COMPLEX64, GetTensorFlowDataType(kTfLiteComplex64));
+ EXPECT_EQ(TF_STRING, GetTensorFlowDataType(kTfLiteString));
+ EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index 89d9b4f5cf..3c4801183b 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -26,12 +26,17 @@ limitations under the License.
namespace tflite {
-// Converts a `std::vector` to a `TfLiteIntArray`.
+// Converts a `std::vector` to a `TfLiteIntArray`. The caller takes ownership
+// of the returned pointer.
TfLiteIntArray* ConvertVectorToTfLiteIntArray(const std::vector<int>& input);
+// Converts an array (of the given size) to a `TfLiteIntArray`. The caller
+// takes ownership of the returned pointer, and must make sure 'dims' has at
+// least 'rank' elemnts.
TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims);
// Checks whether a `TfLiteIntArray` and an int array have matching elements.
+// The caller must guarantee that 'b' has at least 'b_size' elements.
bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size,
const int* b);