aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/tensor.cc2
-rw-r--r--tensorflow/core/framework/tensor.h2
-rw-r--r--tensorflow/core/framework/tensor_test.cc3
3 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 3df677675e..1dea6da911 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -813,7 +813,7 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
}
Tensor Tensor::SubSlice(int64 index) const {
- CHECK_GE(dims(), 2); // Crash ok.
+ CHECK_GE(dims(), 1); // Crash ok.
CHECK_LE(0, index); // Crash ok.
int64 dim0_size = shape_.dim_size(0);
CHECK_LE(index, dim0_size); // Crash ok.
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 8a0c70fef2..d0f9eb56e2 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -219,7 +219,7 @@ class Tensor {
/// must check the returned tensor's alignment before calling certain
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
///
- /// REQUIRES: `dims()` >= 2
+ /// REQUIRES: `dims()` >= 1
/// REQUIRES: `0 <= dim0_start < dim_size(0)`
Tensor SubSlice(int64 index) const;
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 0bfa53e6c5..c596604143 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1246,6 +1246,9 @@ TEST(Tensor, SubSlice_Basic) {
EXPECT_EQ(&tx(5, j, k), &ty(j, k));
}
}
+ Tensor z = y.SubSlice(3).SubSlice(31);
+ auto tz = z.unaligned_flat<float>();
+ EXPECT_EQ(*tz.data(), 5.0);
}
{
// Test unaligned access via a SubSlice.