diff options
author | 2017-10-13 07:00:42 -0700 | |
---|---|---|
committer | 2017-10-13 07:05:06 -0700 | |
commit | 1c241e5ba7fa7068f9cf8f925638b170db57c438 (patch) | |
tree | d1bf5bb507023162d4c79a86c6d72d6d4f36cc09 /tensorflow/compiler/xla/service | |
parent | a3b2d6f395ef3f66c9ccd8578e94243e49f76576 (diff) |
[XLA] Add ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical operators.
PiperOrigin-RevId: 172091595
Diffstat (limited to 'tensorflow/compiler/xla/service')
11 files changed, 166 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 8c864f3d07..5b1dbf439c 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -167,6 +167,21 @@ class DfsHloVisitor { HloInstruction* rhs) { return HandleElementwiseBinary(or_); } + virtual Status HandleShiftLeft(HloInstruction* shift_left, + HloInstruction* lhs, HloInstruction* rhs) { + return HandleElementwiseBinary(shift_left); + } + virtual Status HandleShiftRightArithmetic( + HloInstruction* shift_right_arithmetic, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_arithmetic); + } + virtual Status HandleShiftRightLogical(HloInstruction* shift_right_logical, + HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_logical); + } + virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { return HandleElementwiseUnary(reduce_precision); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index fb4d233d04..44f709bede 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -568,6 +568,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kShiftLeft: + return ir_builder_->CreateShl(lhs_value, rhs_value); + case HloOpcode::kShiftRightArithmetic: + return ir_builder_->CreateAShr(lhs_value, rhs_value); + case HloOpcode::kShiftRightLogical: + return ir_builder_->CreateLShr(lhs_value, rhs_value); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -830,6 +836,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr<llvm::Value*> { const HloInstruction* lhs = hlo->operand(0); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 20dba60f4e..5fd891835d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -387,6 +387,93 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return lhs_elem << rhs_elem; + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftLeft<ReturnT>(shl, lhs, rhs); + } + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_signed<NativeT>::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast<NativeT>(static_cast<SignedT>(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightArithmetic<ReturnT>(shra, lhs, rhs); + } + + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_unsigned<NativeT>::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs); + } + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) override { std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op = diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 20fc85c0e9..24e390529e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -789,6 +789,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b18280552d..72f4d0715d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -163,6 +163,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case (HloOpcode::kSubtract): case (HloOpcode::kAnd): case (HloOpcode::kOr): + case (HloOpcode::kShiftLeft): + case (HloOpcode::kShiftRightArithmetic): + case (HloOpcode::kShiftRightLogical): break; default: LOG(FATAL) << "Invalid binary instruction opcode " @@ -905,6 +908,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: CHECK_EQ(new_operands.size(), 2); return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); // Ternary ops. @@ -1293,6 +1299,9 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSubtract: @@ -1984,6 +1993,13 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleAnd(this, operands_[0], operands_[1]); case HloOpcode::kOr: return visitor->HandleOr(this, operands_[0], operands_[1]); + case HloOpcode::kShiftLeft: + return visitor->HandleShiftLeft(this, operands_[0], operands_[1]); + case HloOpcode::kShiftRightArithmetic: + return visitor->HandleShiftRightArithmetic(this, operands_[0], + operands_[1]); + case HloOpcode::kShiftRightLogical: + return visitor->HandleShiftRightLogical(this, operands_[0], operands_[1]); case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: @@ -2344,6 +2360,9 @@ bool HloInstruction::IsElementwiseBinary() const { case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; default: return false; @@ -2393,6 +2412,9 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; // Ternary elementwise operations. diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index ab5e5463fa..d1ae5f776d 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -104,6 +104,9 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(ShiftLeft); +HLO_MATCHER(ShiftRightLogical); +HLO_MATCHER(ShiftRightArithmetic); HLO_MATCHER(Sign); HLO_MATCHER(Slice); HLO_MATCHER(Sort); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index d3d78f4a99..e98012ec0c 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -147,6 +147,12 @@ string HloOpcodeString(HloOpcode opcode) { return "select"; case HloOpcode::kSend: return "send"; + case HloOpcode::kShiftLeft: + return "shift-left"; + case HloOpcode::kShiftRightArithmetic: + return "shift-right-arithmetic"; + case HloOpcode::kShiftRightLogical: + return "shift-right-logical"; case HloOpcode::kSign: return "sign"; case HloOpcode::kSin: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 9c26f360fb..057d4f6ea7 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -88,6 +88,9 @@ enum class HloOpcode { kSelect, kSelectAndScatter, kSend, + kShiftLeft, + kShiftRightArithmetic, + kShiftRightLogical, kSign, kSin, kSlice, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index e08e4e4d69..7e46d79ba4 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -69,6 +69,9 @@ namespace xla { case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a091a067c1..f3c8e3aff3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -117,6 +117,12 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_OR; case HloOpcode::kAnd: return BINOP_AND; + case HloOpcode::kShiftLeft: + return BINOP_SHIFT_LEFT; + case HloOpcode::kShiftRightArithmetic: + return BINOP_SHIFT_RIGHT_ARITHMETIC; + case HloOpcode::kShiftRightLogical: + return BINOP_SHIFT_RIGHT_LOGICAL; default: LOG(FATAL) << "unhandled opcode " << opcode; } @@ -748,6 +754,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_DIV: case BINOP_REM: case BINOP_MUL: + case BINOP_SHIFT_LEFT: + case BINOP_SHIFT_RIGHT_ARITHMETIC: + case BINOP_SHIFT_RIGHT_LOGICAL: return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 317817d022..b3506b72bf 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -115,6 +115,12 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kOr; case BINOP_AND: return HloOpcode::kAnd; + case BINOP_SHIFT_LEFT: + return HloOpcode::kShiftLeft; + case BINOP_SHIFT_RIGHT_ARITHMETIC: + return HloOpcode::kShiftRightArithmetic; + case BINOP_SHIFT_RIGHT_LOGICAL: + return HloOpcode::kShiftRightLogical; default: LOG(FATAL) << "unhandled operation " << binop; } |