aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/shape_util_test.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc45
1 files changed, 45 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index b6f30af381..e5dd62ae9a 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
+#include <numeric>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
@@ -22,12 +23,23 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
using ::testing::ElementsAre;
+TEST(ShapeUtilTest, ShapeIndexViewTest) {
+ ShapeIndex index = {1, 2, 3, 4};
+ ShapeIndexView index_view(index, 1);
+ EXPECT_EQ(3, index_view.size());
+ EXPECT_EQ(ShapeIndexView({2, 3, 4}), index_view);
+ EXPECT_EQ(ShapeIndexView({3, 4}), index_view.ConsumeFront());
+ EXPECT_EQ(ShapeIndexView({2, 3}), index_view.ConsumeBack());
+}
+
TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) {
Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1));
@@ -322,6 +334,17 @@ TEST(ShapeUtilTest, IncompatibleScalarVsTuple) {
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1));
}
+TEST(ShapeUtilTest, OpaqueVsArray) {
+ Shape shape1 = ShapeUtil::MakeShape(F32, {5, 7});
+ Shape shape2 = ShapeUtil::MakeOpaqueShape();
+ EXPECT_FALSE(ShapeUtil::Compatible(shape1, shape2));
+ EXPECT_FALSE(ShapeUtil::Compatible(shape2, shape1));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape1, shape2));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1));
+}
+
TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) {
Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30});
shape1.mutable_layout()->add_padded_dimensions(10);
@@ -821,6 +844,28 @@ TEST(ShapeUtilTest, HasDegenerateDimensions) {
ShapeUtil::HasDegenerateDimensions(ShapeUtil::MakeShape(F32, {3, 0, 5})));
}
+TEST(ShapeUtilTest, PermuteDimensionsLayout) {
+ std::vector<int64> layout(3);
+ std::iota(layout.begin(), layout.end(), 0);
+ do {
+ Shape s = ShapeUtil::MakeShapeWithLayout(F32, {10, 100, 1000}, layout);
+ SCOPED_TRACE(tensorflow::strings::StrCat("s=", ShapeUtil::HumanString(s)));
+
+ std::vector<int64> permutation(3);
+ std::iota(permutation.begin(), permutation.end(), 0);
+ do {
+ SCOPED_TRACE(tensorflow::strings::StrCat(
+ "permutation=", tensorflow::str_util::Join(permutation, ",")));
+
+ // TransposeIsBitcast takes the inverse of the permutation that
+ // PermuteDimensions takes.
+ EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(
+ s, ShapeUtil::PermuteDimensions(permutation, s),
+ InversePermutation(permutation)));
+ } while (std::next_permutation(permutation.begin(), permutation.end()));
+ } while (std::next_permutation(layout.begin(), layout.end()));
+}
+
TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),