aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-23 12:20:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 12:31:37 -0700
commitbb22a312ebd5f22be66bdcf678710721fbe129d1 (patch)
treef428d2bf690f620fd0e552e35ff925866af9b86f /tensorflow/compiler/xla/service/shape_inference.cc
parent04e66772eb48ade841e60de47a9cbc4ac01c7e67 (diff)
[XLA] Tighten up shape inference rules
kRoundNearestAfz should be treated like kCeil and kFloor. kClz is only reasonable on integral types, not floating point or predicate types. kAbs and kSign are only reasonable on signed types and complex. PiperOrigin-RevId: 209978375
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc49
1 files changed, 42 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 50f2080aa4..84918034fa 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -234,10 +234,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
switch (opcode) {
case HloOpcode::kFloor:
case HloOpcode::kCeil:
+ case HloOpcode::kRoundNearestAfz:
if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating for floor/ceil "
- "operation; got %s.",
+ "Expected element type in shape to be floating for %s operation; "
+ "got %s.",
+ HloOpcodeString(opcode).c_str(),
PrimitiveType_Name(shape.element_type()).c_str());
}
return shape;
@@ -251,8 +253,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
if (!ShapeUtil::ElementIsFloating(shape) &&
!ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating or complex for "
- "sin/cos/exp/log/tanh operation; got %s.",
+ "Expected element type in shape to be floating or complex for %s "
+ "operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
PrimitiveType_Name(shape.element_type()).c_str());
}
return shape;
@@ -265,19 +268,51 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
} else {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
- "real/imag operation; got %s.",
+ "%s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
PrimitiveType_Name(shape.element_type()).c_str());
}
case HloOpcode::kAbs:
if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
shape, primitive_util::ComplexComponentType(shape.element_type()));
+ } else if (ShapeUtil::ElementIsSigned(shape)) {
+ return shape;
+ } else {
+ return InvalidArgument(
+ "Expected element type in shape to be floating or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return shape;
case HloOpcode::kClz:
+ if (!ShapeUtil::ElementIsIntegral(shape)) {
+ return InvalidArgument(
+ "Expected an integral element type in argument to Clz "
+ "operation; got %s.",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return shape;
case HloOpcode::kNegate:
- case HloOpcode::kRoundNearestAfz:
+ if (!ShapeUtil::ElementIsIntegral(shape) &&
+ !ShapeUtil::ElementIsFloating(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be integral, floating or "
+ "complex for %s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return shape;
case HloOpcode::kSign:
+ if (!ShapeUtil::ElementIsSigned(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
+ return InvalidArgument(
+ "Expected element type in shape to be signed or complex for "
+ "%s operation; got %s.",
+ HloOpcodeString(opcode).c_str(),
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
return shape;
case HloOpcode::kNot: