aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-24 11:04:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 11:11:38 -0700
commit3bb3257a5f9675e6c094b9a6318d96d1bc27fc94 (patch)
tree730f2680227453abd863c1ed2661cab32635e167 /tensorflow/core/framework
parentf7017ef769bd603b61f25dfffc772e2153a9f076 (diff)
Add functionality to SubSlice a tensor.
PiperOrigin-RevId: 214295534
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/tensor.cc22
-rw-r--r--tensorflow/core/framework/tensor.h19
-rw-r--r--tensorflow/core/framework/tensor_test.cc36
3 files changed, 77 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index eb9c79ff2d..3df677675e 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -812,6 +812,28 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
return ret;
}
+Tensor Tensor::SubSlice(int64 index) const {
+ CHECK_GE(dims(), 2); // Crash ok.
+ CHECK_LE(0, index); // Crash ok.
+ int64 dim0_size = shape_.dim_size(0);
+ CHECK_LE(index, dim0_size); // Crash ok.
+ Tensor ret;
+ ret.shape_ = shape_;
+ ret.shape_.RemoveDim(0);
+ ret.set_dtype(dtype());
+ ret.buf_ = nullptr;
+ if (dim0_size > 0) {
+ const int64 elems_per_dim0 = NumElements() / dim0_size;
+ const int64 delta = index * elems_per_dim0;
+ const int64 num_elems = elems_per_dim0;
+ if (buf_) {
+ DataType dt = dtype();
+ CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
+ }
+ }
+ return ret;
+}
+
bool Tensor::FromProto(const TensorProto& proto) {
return FromProto(cpu_allocator(), proto);
}
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index e412329498..8a0c70fef2 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -200,10 +200,29 @@ class Tensor {
/// must check the returned tensor's alignment before calling certain
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
///
+ /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor
+ /// also with N dimensions. If you want to select a sub tensor, see SubSlice.
+ ///
/// REQUIRES: `dims()` >= 1
/// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
+ /// \brief Select a subslice from this tensor along the 1st dimension.
+ ///
+ /// When fed with an N-dimensional tensor, this method returns a tensor with
+ /// N-1 dimensions, where the returned tensor is a subslice of the input
+ /// tensor along the first dimension. The N-1 dimensions of the returned
+ /// tensor are the last N-1 dimensions of the input tensor.
+ ///
+ /// NOTE: The returned tensor may not satisfy the same alignment
+ /// requirement as this tensor depending on the shape. The caller
+ /// must check the returned tensor's alignment before calling certain
+ /// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
+ ///
+ /// REQUIRES: `dims()` >= 2
+ /// REQUIRES: `0 <= dim0_start < dim_size(0)`
+ Tensor SubSlice(int64 index) const;
+
/// \brief Parse `other` and construct the tensor.
/// Returns `true` iff the parsing succeeds. If the parsing fails,
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index fc05c86990..0bfa53e6c5 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1228,6 +1228,42 @@ TEST(Tensor, Slice_Basic) {
}
}
+TEST(Tensor, SubSlice_Basic) {
+ { // General
+ Tensor x(DT_FLOAT, TensorShape({10, 4, 36}));
+ // Fills in known values.
+ for (int i = 0; i < 10; ++i) {
+ x.SubSlice(i).flat<float>().setConstant(i * 1.f);
+ }
+ // A simple sub-slice along dim0.
+ Tensor y = x.SubSlice(5);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 36})));
+ auto tx = x.tensor<float, 3>();
+ auto ty = y.tensor<float, 2>();
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 36; ++k) {
+ EXPECT_EQ(ty(j, k), 5.0);
+ EXPECT_EQ(&tx(5, j, k), &ty(j, k));
+ }
+ }
+ }
+ {
+ // Test unaligned access via a SubSlice.
+ Tensor x(DT_FLOAT, TensorShape({30, 5}));
+ x.flat<float>().setConstant(0.0);
+
+ // Take an unaligned subslice.
+ Tensor y = x.SubSlice(1);
+#if EIGEN_MAX_ALIGN_BYTES > 0
+ EXPECT_FALSE(y.IsAligned());
+#endif
+ y.unaligned_flat<float>().setConstant(1.0);
+ for (int64 i = 0; i < y.NumElements(); ++i) {
+ EXPECT_EQ(1.0, y.unaligned_flat<float>()(i));
+ }
+ }
+}
+
template <typename T>
Tensor MkTensor(DataType dt, const TensorShape& shape,
std::vector<T> init_values) {