aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-03 14:07:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 14:12:38 -0700
commitbdd84aa59d3bdedc42647711e401229f489c7d25 (patch)
tree695399ae3fed6bc65177f38493dee35f5f74e116 /tensorflow/compiler/xla/service/shape_inference.cc
parenta6471888cc9dfe9c18d121149bc0516a3f423fbb (diff)
[TF:XLA] Split select HLO into array- and tuple-select.
Array select and tuple-select already are handled separately in all backends and HLO passes: Array select is an elementwise operation. The shapes of the to operands have the same dimensions. Tuple select does not define its own output, but instead forwards the true- or false- operand based on a scalar predicate operand. This CL reflects this by adding a new kTupleSelect HLO. The XLA builder interface stays the same and dispatches based on the operand shapes. No change in the operation semantics. This CL just splits the existing select operation into two opcodes and preserves the existing semantics. HLO cost analysis is fixed to handle the two ops appropriately. PiperOrigin-RevId: 203180342
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc37
1 files changed, 27 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 81f071ecc5..70edf7883f 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -929,6 +929,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InferClampShape(lhs, rhs, ehs);
case HloOpcode::kSelect:
return InferSelectShape(lhs, rhs, ehs);
+ case HloOpcode::kTupleSelect:
+ return InferTupleSelectShape(lhs, rhs, ehs);
default:
return InvalidArgument("Unknown operation %s.",
HloOpcodeString(opcode).c_str());
@@ -2267,15 +2269,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// broadcast from all operands, not just the predicate.
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
const Shape& pred, const Shape& on_true, const Shape& on_false) {
- bool compatible;
- if (ShapeUtil::IsTuple(on_true)) {
- // Select only defines the top-level buffer, so if it's a tuple, the two
- // input must match exactly.
- compatible = ShapeUtil::Compatible(on_true, on_false);
- } else {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
- }
- if (!compatible) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
return InvalidArgument(
"Operands to select must be the same shape; got %s and %s.",
ShapeUtil::HumanString(on_true).c_str(),
@@ -2287,7 +2281,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(pred).c_str());
}
if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) ||
- ShapeUtil::Rank(pred) == 0) {
+ ShapeUtil::IsScalar(pred)) {
// By this stage we know that pred's element type is PRED. Therefore, this
// check restricts pred to be a PRED scalar, or a PRED array with the same
// dimensions as on_true and on_false.
@@ -2301,6 +2295,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
+/* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
+ const Shape& pred, const Shape& on_true, const Shape& on_false) {
+ // Select only defines the top-level buffer, so if it's a tuple, the two
+ // input must match exactly.
+ if (!ShapeUtil::Compatible(on_true, on_false)) {
+ return InvalidArgument(
+ "Operands to tuple-select must be the same shape; got %s and %s.",
+ ShapeUtil::HumanString(on_true).c_str(),
+ ShapeUtil::HumanString(on_false).c_str());
+ }
+ if (pred.element_type() != PRED) {
+ return InvalidArgument(
+ "TupleSelect's pred operand must have PRED element type; got %s.",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ if (!ShapeUtil::IsScalar(pred)) {
+ return InvalidArgument(
+ "TupleSelect operation with non-scalar predicate: %s.",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ return on_true;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
const ProgramShape& to_apply) {