aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util_test.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-01-09 12:04:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-09 12:26:35 -0800
commit1e67c90e2caceeff82d09793d1ef5fa0300d219b (patch)
tree6567ea8b0fa01fcfcd608b7e4c636865d33c7032 /tensorflow/compiler/xla/shape_util_test.cc
parent7ad7e4dfae4344d6b955b5eb61dc4b6bb792f1b3 (diff)
Initial open-source release of XLA: Accelerated Linear Algebra.
XLA is a compiler-based linear algebra execution engine that targets CPUs, GPUs and custom accelerators. XLA is still experimental; we are releasing it early to get the community involved. Change: 143990941
Diffstat (limited to 'tensorflow/compiler/xla/shape_util_test.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc506
1 files changed, 506 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
new file mode 100644
index 0000000000..4e8a496e7e
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) {
+ Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
+ EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1));
+ EXPECT_EQ(2, ShapeUtil::GetDimension(matrix, -2));
+}
+
+TEST(ShapeUtilTest, GetDimensionHelperExampleInDocumentationTest) {
+ auto shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
+ ASSERT_EQ(4, ShapeUtil::GetDimension(shape, -1));
+}
+
+TEST(ShapeUtilTest, NegativeIndexOobFails) {
+ Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
+ ASSERT_DEATH(ShapeUtil::GetDimension(matrix, -3), "dimension_number >= 0");
+}
+
+TEST(ShapeUtilTest, Rank1DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3});
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, Rank2DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 2});
+ ASSERT_EQ(2, shape.dimensions(1));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, Rank3DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7});
+ ASSERT_EQ(7, shape.dimensions(2));
+ ASSERT_EQ(2, shape.dimensions(1));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, Rank4DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7, 8});
+ ASSERT_EQ(8, shape.dimensions(3));
+ ASSERT_EQ(7, shape.dimensions(2));
+ ASSERT_EQ(2, shape.dimensions(1));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, ParseShapeStringR2F32) {
+ string shape_string = "f32[123,456]";
+ Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
+ Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
+TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
+ Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
+ Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
+ ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2));
+}
+
+TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
+ Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
+ auto layout_1 = shape_1.mutable_layout();
+ layout_1->clear_minor_to_major();
+ layout_1->add_minor_to_major(0);
+ layout_1->add_minor_to_major(1);
+
+ Shape shape_2 = ShapeUtil::MakeShape(F32, {3, 2});
+ auto layout_2 = shape_2.mutable_layout();
+ layout_2->clear_minor_to_major();
+ layout_2->add_minor_to_major(1);
+ layout_2->add_minor_to_major(0);
+
+ EXPECT_FALSE(ShapeUtil::Equal(shape_1, shape_2));
+ EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2));
+}
+
+TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
+ Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
+ Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2});
+ EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2));
+}
+
+TEST(ShapeUtilTest, CompatibleTuples) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
+ EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
+ EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})});
+ EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {4, 2})});
+ EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, EmptyLayoutEqualsMissingLayout) {
+ // A shape with a missing layout should be equal to a shape with an empty
+ // layout.
+ Shape scalar1 = ShapeUtil::MakeShape(F32, {});
+ Shape scalar2 = ShapeUtil::MakeShape(F32, {});
+
+ EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2));
+
+ scalar1.clear_layout(); // Remove layout field.
+ scalar2.mutable_layout(); // Create empty layout field.
+
+ EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2));
+}
+
+TEST(ShapeUtilTest, ScalarUnpopulatedLayoutEqualsScalarLayout) {
+ Shape scalar_unpopulated = ShapeUtil::MakeShape(F32, {});
+ scalar_unpopulated.clear_layout();
+ ASSERT_FALSE(scalar_unpopulated.has_layout())
+ << ShapeUtil::HumanStringWithLayout(scalar_unpopulated);
+
+ const Shape scalar_populated = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
+ ASSERT_TRUE(scalar_populated.has_layout())
+ << ShapeUtil::HumanStringWithLayout(scalar_populated);
+
+ EXPECT_TRUE(ShapeUtil::Equal(scalar_unpopulated, scalar_populated));
+}
+
+TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
+ EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32));
+ EXPECT_EQ(4, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {})));
+ EXPECT_EQ(800, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {10, 20})));
+
+ EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64));
+ EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {})));
+ EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20})));
+}
+
+TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
+ EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32));
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 20});
+ EXPECT_EQ(800, ShapeUtil::ByteSizeOf(shape));
+
+ shape.mutable_layout()->add_padded_dimensions(15);
+ shape.mutable_layout()->add_padded_dimensions(21);
+ EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape));
+}
+
+TEST(ShapeUtilTest, NestedTuple) {
+ EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({})));
+ EXPECT_FALSE(ShapeUtil::IsNestedTuple(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({})})));
+ EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeTupleShape({})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeTupleShape({})})));
+}
+
+TEST(ShapeUtilTest, ElementsIn) {
+ EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {})));
+ EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 1})));
+ EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2, 1})));
+ EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 5})));
+ EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 0, 5})));
+ EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0, 3, 0})));
+ EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 3, 5})));
+ EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
+}
+
+TEST(ShapeUtilTest, HasZeroElements) {
+ EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {})));
+ EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_EQ(false,
+ ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1})));
+ EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_EQ(false,
+ ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1})));
+ EXPECT_EQ(false,
+ ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5})));
+ EXPECT_EQ(true,
+ ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5})));
+ EXPECT_EQ(true,
+ ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0})));
+ EXPECT_EQ(false,
+ ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5})));
+ EXPECT_EQ(false,
+ ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17})));
+}
+
+TEST(ShapeUtilTest, SameDimensions) {
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(S32, {})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(F32, {})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0}),
+ ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {2}),
+ ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {2})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0, 0}),
+ ShapeUtil::MakeShape(F32, {0})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(F32, {1})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {1, 0})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1, 1}),
+ ShapeUtil::MakeShape(F32, {1, 2})));
+}
+
+TEST(ShapeUtilTest, GetSubshape) {
+ // Test array shape.
+ Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123});
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(array_shape, {})));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ array_shape, *ShapeUtil::GetMutableSubshape(&array_shape, {})));
+
+ // Test tuple shape.
+ Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({array_shape, array_shape, array_shape});
+ EXPECT_TRUE(
+ ShapeUtil::Equal(tuple_shape, ShapeUtil::GetSubshape(tuple_shape, {})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {0})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {1})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {2})));
+
+ // Test nested tuple shape.
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape(
+ {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ array_shape})});
+ EXPECT_TRUE(ShapeUtil::Equal(nested_tuple_shape,
+ ShapeUtil::GetSubshape(nested_tuple_shape, {})));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ array_shape, ShapeUtil::GetSubshape(nested_tuple_shape, {0})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ ShapeUtil::GetSubshape(nested_tuple_shape, {1})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0})));
+}
+
+TEST(ShapeUtilTest, HumanString) {
+ Shape opaque = ShapeUtil::MakeOpaqueShape();
+ Shape scalar = ShapeUtil::MakeShape(F32, {});
+ Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
+ Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
+ Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
+ Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix});
+
+ EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
+ EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
+ EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
+ EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
+ EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
+ ShapeUtil::HumanString(tuple));
+ EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ ShapeUtil::HumanString(nested_tuple));
+
+ EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
+ EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar));
+ EXPECT_EQ("u32[1,2] {1,0}", ShapeUtil::HumanStringWithLayout(matrix));
+ EXPECT_EQ("s32[3,4] {0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
+ EXPECT_EQ("(opaque[], f32[], u32[1,2] {1,0}, s32[3,4] {0,1})",
+ ShapeUtil::HumanStringWithLayout(tuple));
+ EXPECT_EQ(
+ "((opaque[], f32[], u32[1,2] {1,0}, s32[3,4] {0,1}), u32[1,2] {1,0})",
+ ShapeUtil::HumanStringWithLayout(nested_tuple));
+
+ ProgramShape prog = ShapeUtil::MakeProgramShape(
+ {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
+ EXPECT_EQ(
+ "((unknown): opaque[], "
+ "(unknown): f32[], "
+ "(unknown): u32[1,2], "
+ "(unknown): s32[3,4], "
+ "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
+ "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ ShapeUtil::HumanString(prog));
+
+ prog.add_parameter_names("arg0");
+ prog.add_parameter_names("scalar");
+ prog.add_parameter_names("matrix");
+ prog.add_parameter_names("matrix2");
+ prog.add_parameter_names("tuple");
+ prog.add_parameter_names("nested_tuple");
+ EXPECT_EQ(
+ "(arg0: opaque[], "
+ "scalar: f32[], "
+ "matrix: u32[1,2], "
+ "matrix2: s32[3,4], "
+ "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
+ "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ ShapeUtil::HumanString(prog));
+}
+
+TEST(ShapeUtilTest, ForEachSubshapeArray) {
+ const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+ int calls = 0;
+ EXPECT_IS_OK(ShapeUtil::ForEachSubshape(
+ shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
+ EXPECT_EQ(&shape, &subshape);
+ EXPECT_TRUE(index.empty());
+ ++calls;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(1, calls);
+}
+
+TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) {
+ const Shape shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {42}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
+ ShapeUtil::MakeShape(PRED, {33})})});
+ int calls = 0;
+ EXPECT_IS_OK(ShapeUtil::ForEachSubshape(
+ shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
+ EXPECT_TRUE(
+ ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index)));
+ if (calls == 0) {
+ // Visitation should go from outside in.
+ EXPECT_TRUE(index.empty());
+ } else if (calls == 4) {
+ // Last visitation should be to the array with 33 elements.
+ EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape));
+ }
+ ++calls;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(5, calls);
+}
+
+TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) {
+ Shape shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {42}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
+ ShapeUtil::MakeShape(PRED, {33})})});
+ int calls = 0;
+ EXPECT_IS_OK(ShapeUtil::ForEachMutableSubshape(
+ &shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) {
+ // Pointer values should be equal
+ EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index));
+ if (calls == 0) {
+ // Visitation should go from outside in.
+ EXPECT_TRUE(index.empty());
+ } else if (calls == 4) {
+ // Last visitation should be to the array with 33 elements.
+ EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape));
+ }
+ ++calls;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(5, calls);
+}
+
+TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
+ Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4});
+ Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1});
+ Shape shape2 = ShapeUtil::MakeShape(S32, {3, 1, 12});
+ EXPECT_TRUE(std::get<0>(
+ ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1)));
+ EXPECT_FALSE(std::get<0>(
+ ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
+}
+
+TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
+ // All output dimensions should be unmodified. One of the input dimensions is
+ // modified because the input rank is larger by one.
+ EXPECT_EQ(3,
+ ShapeUtil::DimensionsUnmodifiedByReshape(
+ ShapeUtil::MakeShape(S32, {1, 1, 1, 1}),
+ ShapeUtil::MakeShape(S32, {1, 1, 1}))
+ .size());
+}
+
+TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) {
+ // All input dimensions should be unmodified. One of the output dimensions is
+ // modified because the output rank is larger by one.
+ EXPECT_EQ(3,
+ ShapeUtil::DimensionsUnmodifiedByReshape(
+ ShapeUtil::MakeShape(S32, {1, 1, 1}),
+ ShapeUtil::MakeShape(S32, {1, 1, 1, 1}))
+ .size());
+}
+
+TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) {
+ // The only matching dimension is the one with size 5.
+ // 4, 1, 3, 5, 6, 7
+ // |
+ // 2, 6, 1, 5, 1, 42
+ EXPECT_TRUE(
+ ContainersEqual(ShapeUtil::DimensionsUnmodifiedByReshape(
+ ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}),
+ ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})),
+ std::vector<std::pair<int64, int64>>({{3, 3}})));
+}
+
+TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) {
+ for (bool input_is_row_major : {true, false}) {
+ for (bool output_is_row_major : {true, false}) {
+ Layout input_layout = input_is_row_major ? LayoutUtil::MakeLayout({1, 0})
+ : LayoutUtil::MakeLayout({0, 1});
+ Layout output_layout = output_is_row_major
+ ? LayoutUtil::MakeLayout({1, 0})
+ : LayoutUtil::MakeLayout({0, 1});
+ // Suppose the input is logically (i.e. ignoring its layout)
+ // 0 1 2 3
+ // 4 5 6 7
+ // 8 9 10 11
+ //
+ // The reshape transforms the input to logically
+ // 0 1
+ // 2 3
+ // 4 5
+ // 6 7
+ // 8 9
+ // 10 11
+ //
+ // The input and the output have the same underlying data only if they
+ // are both row-major.
+ EXPECT_EQ(
+ ShapeUtil::ReshapeIsBitcast(
+ ShapeUtil::MakeShapeWithLayout(
+ F32, {3, 4}, AsInt64Slice(input_layout.minor_to_major())),
+ ShapeUtil::MakeShapeWithLayout(
+ F32, {6, 2}, AsInt64Slice(output_layout.minor_to_major()))),
+ input_is_row_major && output_is_row_major);
+ }
+ }
+}
+
+TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) {
+ EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(
+ ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {1, 0, 2}),
+ ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
+}
+
+TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
+ EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
+ ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),
+ ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
+}
+
+} // namespace
+} // namespace xla