diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-24 11:04:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 11:11:38 -0700 |
commit | 3bb3257a5f9675e6c094b9a6318d96d1bc27fc94 (patch) | |
tree | 730f2680227453abd863c1ed2661cab32635e167 /tensorflow/core/framework | |
parent | f7017ef769bd603b61f25dfffc772e2153a9f076 (diff) |
Add functionality to SubSlice a tensor.
PiperOrigin-RevId: 214295534
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 22 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor.h | 19 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_test.cc | 36 |
3 files changed, 77 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index eb9c79ff2d..3df677675e 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -812,6 +812,28 @@ Tensor Tensor::Slice(int64 start, int64 limit) const { return ret; } +Tensor Tensor::SubSlice(int64 index) const { + CHECK_GE(dims(), 2); // Crash ok. + CHECK_LE(0, index); // Crash ok. + int64 dim0_size = shape_.dim_size(0); + CHECK_LE(index, dim0_size); // Crash ok. + Tensor ret; + ret.shape_ = shape_; + ret.shape_.RemoveDim(0); + ret.set_dtype(dtype()); + ret.buf_ = nullptr; + if (dim0_size > 0) { + const int64 elems_per_dim0 = NumElements() / dim0_size; + const int64 delta = index * elems_per_dim0; + const int64 num_elems = elems_per_dim0; + if (buf_) { + DataType dt = dtype(); + CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems)); + } + } + return ret; +} + bool Tensor::FromProto(const TensorProto& proto) { return FromProto(cpu_allocator(), proto); } diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index e412329498..8a0c70fef2 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -200,10 +200,29 @@ class Tensor { /// must check the returned tensor's alignment before calling certain /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). /// + /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor + /// also with N dimensions. If you want to select a sub tensor, see SubSlice. + /// /// REQUIRES: `dims()` >= 1 /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)` Tensor Slice(int64 dim0_start, int64 dim0_limit) const; + /// \brief Select a subslice from this tensor along the 1st dimension. + /// + /// When fed with an N-dimensional tensor, this method returns a tensor with + /// N-1 dimensions, where the returned tensor is a subslice of the input + /// tensor along the first dimension. The N-1 dimensions of the returned + /// tensor are the last N-1 dimensions of the input tensor. + /// + /// NOTE: The returned tensor may not satisfy the same alignment + /// requirement as this tensor depending on the shape. The caller + /// must check the returned tensor's alignment before calling certain + /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). + /// + /// REQUIRES: `dims()` >= 2 + /// REQUIRES: `0 <= dim0_start < dim_size(0)` + Tensor SubSlice(int64 index) const; + /// \brief Parse `other` and construct the tensor. /// Returns `true` iff the parsing succeeds. If the parsing fails, diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index fc05c86990..0bfa53e6c5 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -1228,6 +1228,42 @@ TEST(Tensor, Slice_Basic) { } } +TEST(Tensor, SubSlice_Basic) { + { // General + Tensor x(DT_FLOAT, TensorShape({10, 4, 36})); + // Fills in known values. + for (int i = 0; i < 10; ++i) { + x.SubSlice(i).flat<float>().setConstant(i * 1.f); + } + // A simple sub-slice along dim0. + Tensor y = x.SubSlice(5); + EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 36}))); + auto tx = x.tensor<float, 3>(); + auto ty = y.tensor<float, 2>(); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 36; ++k) { + EXPECT_EQ(ty(j, k), 5.0); + EXPECT_EQ(&tx(5, j, k), &ty(j, k)); + } + } + } + { + // Test unaligned access via a SubSlice. + Tensor x(DT_FLOAT, TensorShape({30, 5})); + x.flat<float>().setConstant(0.0); + + // Take an unaligned subslice. + Tensor y = x.SubSlice(1); +#if EIGEN_MAX_ALIGN_BYTES > 0 + EXPECT_FALSE(y.IsAligned()); +#endif + y.unaligned_flat<float>().setConstant(1.0); + for (int64 i = 0; i < y.NumElements(); ++i) { + EXPECT_EQ(1.0, y.unaligned_flat<float>()(i)); + } + } +} + template <typename T> Tensor MkTensor(DataType dt, const TensorShape& shape, std::vector<T> init_values) { |