aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-08-09 14:01:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 14:04:59 -0700
commit8d545ce994af060c9a1dada3c061d2cb60e24519 (patch)
tree2048ea6f84fa1994d18b11b1f8ac3ffb39cdbe62 /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent874437315670566611808674ec5a0741ae557314 (diff)
[XLA] Make sure backends that don't support variadic reduce reject it.
PiperOrigin-RevId: 208106767
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc9
1 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index a093ffc7c1..1e81cbde35 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -545,6 +545,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
switch (root->opcode()) {
case HloOpcode::kTuple:
case HloOpcode::kReduce: {
+ if (root->opcode() == HloOpcode::kReduce &&
+ ShapeUtil::IsTuple(root->shape())) {
+ // TODO(b/112040122): Support variadic reduce.
+ return Unimplemented("Variadic reduce is not supported on GPU");
+ }
VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
std::vector<std::unique_ptr<Thunk>> thunks;
ArraySlice<HloInstruction*> output_instructions =
@@ -1694,6 +1699,10 @@ Status IrEmitterUnnested::EmitReductionToVector(
}
Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
+ // TODO(b/112040122): Support multi-output reduce.
+ if (!ShapeUtil::IsArray(reduce->shape())) {
+ return Unimplemented("Multi-output reduce is not supported on GPU");
+ }
auto input = reduce->operand(0);
auto init_value = reduce->operand(1);
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce(reduce->dimensions());