diff options
Diffstat (limited to 'tensorflow/core/framework/tensor_slice_test.cc')
-rw-r--r-- | tensorflow/core/framework/tensor_slice_test.cc | 246 |
1 files changed, 246 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_slice_test.cc b/tensorflow/core/framework/tensor_slice_test.cc new file mode 100644 index 0000000000..5f718a56b6 --- /dev/null +++ b/tensorflow/core/framework/tensor_slice_test.cc @@ -0,0 +1,246 @@ +#include "tensorflow/core/framework/tensor_slice.h" + +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/logging.h" +#include <gtest/gtest.h> +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +// Basic tests +TEST(TensorSliceTest, Basic) { + { + // Repeatedly setting FullSlice should work. + TensorSlice s(3); + EXPECT_EQ("-:-:-", s.DebugString()); + + s.SetFullSlice(4); + EXPECT_EQ("-:-:-:-", s.DebugString()); + } +} + +// Testing for serialization and parsing for the string format of slices. +TEST(TensorSliceTest, Serialization) { + // Serialization + { + TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}}); + EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); + } + + { + TensorSliceProto proto; + // Define ptxt outside ASSERT_TRUE call to work around bug in some + // versions of gcc that breaks when you use raw string literals + // inside macro expansions. + // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 + const char* ptxt = R"PROTO( + extent { } + extent { start: 0 length: 10 } + extent { start: 14 length: 1 } + extent { } + )PROTO"; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); + TensorSlice s(proto); + EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); + } + + // Parsing + { + TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5"); + TensorSliceProto proto; + s.AsProto(&proto); + EXPECT_EQ( + "extent { } " + "extent { } " + "extent { start: 1 length: 3 } " + "extent { start: 4 length: 5 }", + proto.ShortDebugString()); + } + + // Failed parsing + { + TensorSlice slice; + Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice); + EXPECT_EQ( + "Invalid argument: " + "Expected a pair of numbers or '-' but got '4': " + "string = -:-:1,3:4:5", + s.ToString()); + } + { + TensorSlice slice; + Status s = TensorSlice::Parse("-:-1,3", &slice); + EXPECT_EQ( + "Invalid argument: " + "Expected non-negative start and positive length but got " + "start = -1, length = 3: string = -:-1,3", + s.ToString()); + } +} + +// Testing the slice intersection +TEST(TensorSliceTest, Intersection) { + // "EVERYTHING" intersects with everything + { + TensorSlice a = TensorSlice::ParseOrDie("-:-"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("1,2:3,4", c.DebugString()); + } + + { + TensorSlice a = TensorSlice::ParseOrDie("-:-"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); + TensorSlice c; + EXPECT_TRUE(b.Intersect(a, &c)); + EXPECT_EQ("1,2:3,4", c.DebugString()); + } + + // Overlap at all dimensions + { + TensorSlice a = TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10"); + TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString()); + } + + // A mixture of everything and non-trivial slices + { + TensorSlice a = TensorSlice::ParseOrDie("-:1,1"); + TensorSlice b = TensorSlice::ParseOrDie("-:0,2"); + TensorSlice c; + EXPECT_TRUE(a.Intersect(b, &c)); + EXPECT_EQ("-:1,1", c.DebugString()); + } + + // No overlap on dimension 3: "3,1" and "4,5" don't intersect + { + TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6"); + TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6"); + TensorSlice c; + EXPECT_FALSE(a.Intersect(b, &c)); + EXPECT_EQ("", c.DebugString()); + } + // No intersection when there are different dimensions + { + TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-"); + TensorSlice b = TensorSlice::ParseOrDie("-:-"); + TensorSlice c; + EXPECT_FALSE(a.Intersect(b, &c)); + EXPECT_EQ("", c.DebugString()); + } +} + +// Testing applying a slice to a tensor shape +TEST(TensorSliceTest, SliceTensorShape) { + // A proper application + { + TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6"); + TensorShape x({2, 4, 5, 8}); + TensorShape y; + EXPECT_OK(a.SliceTensorShape(x, &y)); + EXPECT_EQ( + "dim { size: 1 } " + "dim { size: 4 } " + "dim { size: 1 } " + "dim { size: 6 }", + y.DebugString()); + } + + // An invalid application -- dimension 2 is out of bound + { + TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-"); + TensorShape x({2, 4, 5, 8}); + TensorShape y; + EXPECT_EQ( + "Internal: " + "Extent in dimension 1 out of bounds: " + "shape = dim { size: 2 } " + "dim { size: 4 } " + "dim { size: 5 } " + "dim { size: 8 }, " + "slice = 1,1:1,4:-:-", + a.SliceTensorShape(x, &y).ToString()); + EXPECT_EQ("", y.DebugString()); + } +} + +// Testing the computation of relative slices. +TEST(TensorSliceTest, ComputeRelative) { + // Easy case: base is "everything" + { + TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-"); + TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("-:1,2:-:3,4", relative.DebugString()); + } + + // A slightly more complicated case + { + TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1"); + TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1"); + TensorSlice relative; + base.ComputeRelative(sub, &relative); + EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString()); + } +} + +TEST(TensorSliceTest, ExtentLength) { + TensorSliceProto proto; + // Define ptxt outside ASSERT_TRUE call to work around bug in some + // versions of gcc that breaks when you use raw string literals + // inside macro expansions. + // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 + const char* ptxt = R"PROTO( + extent { } + extent { start: 0 length: 10 } + extent { start: 14 length: 1 } + extent { } + )PROTO"; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1))); + EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2))); + EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0))); + EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1))); + EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2))); + EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3))); +} + +TEST(TensorSliceTest, Deserialization) { + // Serialization of + // extent { length: 5 } + // extent { start: 0 length: 10 } + // extent { start: 14 length: 1 } + // extent { start: 1 } + // extent { } + // in proto2 and proto3: + const char pb2[] = + "\x0A\x02\x10\x05\x0A\x04\x08\x00" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + const char pb3[] = + "\x0A\x02\x10\x05\x0A\x02" + "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; + // (The difference is that in the proto3 version, "start: 0" isn't included + // since 0 is start's default value.) + + TensorSliceProto proto2; + ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1)); + TensorSlice ts2(proto2); + + TensorSliceProto proto3; + ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1)); + TensorSlice ts3(proto3); + + // Both serializations should be interpreted the same. + EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString()); + EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString()); +} + +} // namespace +} // namespace tensorflow |