diff options
author | 2017-08-19 16:33:35 -0700 | |
---|---|---|
committer | 2017-08-19 16:36:37 -0700 | |
commit | 8903e5fc72aba856c5567d09b41340a5e32d4f8f (patch) | |
tree | ea018a74d0422c073f895f57b1c9ffbba52b355b /tensorflow/compiler/xla/shape_util_test.cc | |
parent | c572ca84f40bce6a6114fea2ae4c5da174ffd73b (diff) |
[XLA] Make ShapeUtil::ParseShapeString more complete.
Handle tuples, nested tuples, more element types.
PiperOrigin-RevId: 165826211
Diffstat (limited to 'tensorflow/compiler/xla/shape_util_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/shape_util_test.cc | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 69ef6175cc..9635e5ad2e 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -78,6 +78,30 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { + string shape_string = "(f32[1572864],s8[5120,1024])"; + Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + Shape expected = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), + ShapeUtil::MakeShape(S8, {5120, 1024})}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { + string shape_string = "(f32[1],(f32[2]), f32[3])"; + Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + Shape expected = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {1}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeShape(F32, {3}), + }); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); |