diff options
author | 2017-11-15 19:58:36 -0800 | |
---|---|---|
committer | 2017-11-15 20:02:55 -0800 | |
commit | 4916c64836d5f51d6b8878f429bc1622c465fcdf (patch) | |
tree | adb5b4770853321f441c4bfe29ba3aa8a42fdada /tensorflow/compiler | |
parent | fa15669fefdbe7e9a26ac2dd00bc7ce469ca60e1 (diff) |
[XLA] Adding kConditional opcode that represents a conditional HLO instruction.
PiperOrigin-RevId: 175919301
Diffstat (limited to 'tensorflow/compiler')
5 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 3d963a4b1e..d71a4b42c7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -970,6 +970,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: return kBrown; + case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kWhile: case HloOpcode::kCall: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 045abdac8b..f7b5b265d9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1210,6 +1210,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); break; + case HloOpcode::kConditional: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kSend: @@ -1603,6 +1604,7 @@ bool HloInstruction::IdenticalSlowPath( return dimensions() == other.dimensions(); // These opcodes are not yet supported. + case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: @@ -2355,6 +2357,7 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleSendDone(this); // These opcodes are not handled here. + case HloOpcode::kConditional: case HloOpcode::kTrace: break; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e0d02e0665..7b07027441 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kClamp, "clamp") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConditional, "conditional") \ V(kConstant, "constant") \ V(kConvert, "convert") \ V(kConvolution, "convolution") \ diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index dea47b1fd7..de4804996f 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -92,6 +92,7 @@ namespace xla { case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index a65e5a856f..0159d03b11 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -798,6 +798,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands[0], config ? *config : "")); break; } + case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kReducePrecision: case HloOpcode::kRng: |