aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-06-13 14:19:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 14:24:30 -0700
commite1296c15a32cac020160a1c89002dc561333c66b (patch)
tree62ef334470d3484d166ea583eddea10fc24d1718 /tensorflow/compiler/xla/service/shape_inference_test.cc
parentbf920de58a3ccb2cfe6642be9c487c3fcb13ccae (diff)
Fix assumptions that a Shape must be a tuple or an array.
A TOKEN primitive type was added with cl/199215963 and XLA also has an OPAQUE primitive type. However, in many places in XLA we assume either a tuple or array. This CL fixes many of those instances, but some may remain. Identified instances were discovered by searching for IsTuple or IsArray so the set of fixes is not exhaustive. Also opportunistically addressed a couple potential points of confusion in the ShapeUtil interface: (1) Rename ShapeUtil::HasZeroElements to ShapeUtil::IsZeroElementArray. The point of confusion here is that tuples can also have zero elements and HasZeroElements would check fail on tuple shapes. Method no longer check fails if the given shape is not an array. (2) ShapeUtil::IsNil now returns true only for empty tuples. Previously it also returned true for zero-element array types which was confusing because ShapeUtil::MakeNil creates an empty tuple. PiperOrigin-RevId: 200452672
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 6d017dffe2..bafe14d6f4 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1311,7 +1311,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(
inferred_status_error4.status().error_message(),
- HasSubstr("Expected non-tuple argument for operand of concatenation"));
+ HasSubstr("Expected array argument for operand of concatenation"));
const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
@@ -1387,7 +1387,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
ShapeInference::InferReverseShape(tuple_shape, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- HasSubstr("Expected non-tuple argument"));
+ HasSubstr("Expected array argument"));
}
TEST_F(ShapeInferenceTest, Call) {
@@ -1686,7 +1686,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for input"))
+ HasSubstr("Expected array argument for input"))
<< statusor.status();
}
@@ -1700,7 +1700,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for gather indices"))
+ HasSubstr("Expected array argument for gather indices"))
<< statusor.status();
}