diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 21:11:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 21:15:59 -0700 |
commit | 7fd14feb9cbc690b362633639b27393576472c79 (patch) | |
tree | 7cba9491bb8ed3a476a017fab1e8e14e99d37673 /tensorflow/contrib/lite/kernels | |
parent | 97cba0b88cb3ce6a3f3cc66a8c4fd414bd3ac1a8 (diff) |
Kernel signature reworking, remove Dims from tensor functions.
PiperOrigin-RevId: 214883775
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/tensor.h | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h | 29 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/tensor_test.cc | 36 |
3 files changed, 20 insertions, 49 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index 765c3a03ef..689cea03e7 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -37,10 +37,6 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) { : nullptr; } -inline Dims<4> GetTensorDims(std::vector<int32_t> data) { - return GetTensorDims(data.data(), data.size()); -} - inline RuntimeShape GetTensorShape(std::vector<int32_t> data) { return RuntimeShape(data.size(), data.data()); } diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h index 5e688ce452..9f5b33d217 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h @@ -86,35 +86,6 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } -// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object -// even if the original tensors were not 4D. We should consider rewriting them -// to take a more generic 'shape' object. -inline Dims<4> GetTensorDims(const int data[], const int size) { - Dims<4> d; - for (int i = 0; i < 4; ++i) { - int src = size - i - 1; - if (src >= 0) { - d.sizes[i] = data[src]; - } else { - d.sizes[i] = 1; - } - } - d.strides[0] = 1; - for (int i = 1; i < 4; i++) { - d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; - } - return d; -} - -inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { - if (tensor == nullptr) { - return Dims<4>(); - } - - auto* dims = tensor->dims; - return GetTensorDims(dims->data, dims->size); -} - inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { if (tensor == nullptr) { return RuntimeShape(); diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc index bf2068d320..2ed73ba82d 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc @@ -21,28 +21,32 @@ namespace { using ::testing::ElementsAre; -TEST(TensorTest, GetTensorDims4D) { - Dims<4> d = GetTensorDims({2, 3, 4, 5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +TEST(TensorTest, GetTensorShape4D) { + RuntimeShape d = GetTensorShape({2, 3, 4, 5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(2, 3, 4, 5)); } -TEST(TensorTest, GetTensorDims3D) { - Dims<4> d = GetTensorDims({3, 4, 5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +TEST(TensorTest, GetTensorShape3D) { + RuntimeShape d = GetTensorShape({3, 4, 5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(3, 4, 5)); } -TEST(TensorTest, GetTensorDims2D) { - Dims<4> d = GetTensorDims({4, 5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20)); +TEST(TensorTest, GetTensorShape2D) { + RuntimeShape d = GetTensorShape({4, 5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(4, 5)); } -TEST(TensorTest, GetTensorDims1D) { - Dims<4> d = GetTensorDims({5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5)); +TEST(TensorTest, GetTensorShape1D) { + RuntimeShape d = GetTensorShape({5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(5)); } } // namespace |