diff options
-rw-r--r-- | tensorflow/core/framework/tensor_slice.h | 4 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_slice_test.cc | 4 |
2 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h index 8c4a2adeb3..fca40e0894 100644 --- a/tensorflow/core/framework/tensor_slice.h +++ b/tensorflow/core/framework/tensor_slice.h @@ -94,7 +94,9 @@ class TensorSlice { } // If we have a full slice along dimension "d". - bool IsFullAt(int d) const { return lengths_[d] < 0; } + bool IsFullAt(int d) const { + return lengths_[d] == kFullExtent && starts_[d] == 0; + } // If this is a full slice, i.e. IsFullAt(d) for every d. bool IsFull() const; diff --git a/tensorflow/core/framework/tensor_slice_test.cc b/tensorflow/core/framework/tensor_slice_test.cc index e26c840998..bb32fa0724 100644 --- a/tensorflow/core/framework/tensor_slice_test.cc +++ b/tensorflow/core/framework/tensor_slice_test.cc @@ -273,8 +273,8 @@ TEST(TensorSliceTest, Deserialization) { TensorSlice ts3(proto3); // Both serializations should be interpreted the same. - EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString()); - EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString()); + EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts2.DebugString()); + EXPECT_EQ("0,5:0,10:14,1:1,-1:-", ts3.DebugString()); } TEST(TensorSliceTest, UpdateToCover) { |