diff options
author | 2017-10-13 07:00:42 -0700 | |
---|---|---|
committer | 2017-10-13 07:05:06 -0700 | |
commit | 1c241e5ba7fa7068f9cf8f925638b170db57c438 (patch) | |
tree | d1bf5bb507023162d4c79a86c6d72d6d4f36cc09 /tensorflow | |
parent | a3b2d6f395ef3f66c9ccd8578e94243e49f76576 (diff) |
[XLA] Add ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical operators.
PiperOrigin-RevId: 172091595
Diffstat (limited to 'tensorflow')
15 files changed, 264 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 206af290c6..dcbdb3525e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -972,6 +972,24 @@ ComputationDataHandle ComputationBuilder::Not( return UnaryOp(UNOP_NOT, operand); } +ComputationDataHandle ComputationBuilder::ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); +} + ComputationDataHandle ComputationBuilder::Abs( const ComputationDataHandle& operand) { return UnaryOp(UNOP_ABS, operand); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 94b03502f9..cdd9c8847f 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -472,6 +472,16 @@ class ComputationBuilder { ComputationDataHandle Not(const ComputationDataHandle& operand); + ComputationDataHandle ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + ComputationDataHandle ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + ComputationDataHandle ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {}); + // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. ComputationDataHandle Reduce( diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 8c864f3d07..5b1dbf439c 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -167,6 +167,21 @@ class DfsHloVisitor { HloInstruction* rhs) { return HandleElementwiseBinary(or_); } + virtual Status HandleShiftLeft(HloInstruction* shift_left, + HloInstruction* lhs, HloInstruction* rhs) { + return HandleElementwiseBinary(shift_left); + } + virtual Status HandleShiftRightArithmetic( + HloInstruction* shift_right_arithmetic, HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_arithmetic); + } + virtual Status HandleShiftRightLogical(HloInstruction* shift_right_logical, + HloInstruction* lhs, + HloInstruction* rhs) { + return HandleElementwiseBinary(shift_right_logical); + } + virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { return HandleElementwiseUnary(reduce_precision); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index fb4d233d04..44f709bede 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -568,6 +568,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kShiftLeft: + return ir_builder_->CreateShl(lhs_value, rhs_value); + case HloOpcode::kShiftRightArithmetic: + return ir_builder_->CreateAShr(lhs_value, rhs_value); + case HloOpcode::kShiftRightLogical: + return ir_builder_->CreateLShr(lhs_value, rhs_value); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -830,6 +836,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr<llvm::Value*> { const HloInstruction* lhs = hlo->operand(0); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 20dba60f4e..5fd891835d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -387,6 +387,93 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return lhs_elem << rhs_elem; + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftLeft<ReturnT>(shl, lhs, rhs); + } + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_signed<NativeT>::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast<NativeT>(static_cast<SignedT>(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightArithmetic<ReturnT>(shra, lhs, rhs); + } + + template <typename NativeT, + typename std::enable_if< + std::is_integral<NativeT>::value && + !std::is_same<NativeT, bool>::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + typedef typename std::make_unsigned<NativeT>::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<!std::is_integral<NativeT>::value || + std::is_same<NativeT, bool>::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs, + HloInstruction* rhs) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl, HloInstruction* lhs, + HloInstruction* rhs) override { + return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs); + } + Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) override { std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op = diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 20fc85c0e9..24e390529e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -789,6 +789,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b18280552d..72f4d0715d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -163,6 +163,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case (HloOpcode::kSubtract): case (HloOpcode::kAnd): case (HloOpcode::kOr): + case (HloOpcode::kShiftLeft): + case (HloOpcode::kShiftRightArithmetic): + case (HloOpcode::kShiftRightLogical): break; default: LOG(FATAL) << "Invalid binary instruction opcode " @@ -905,6 +908,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kRemainder: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: CHECK_EQ(new_operands.size(), 2); return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); // Ternary ops. @@ -1293,6 +1299,9 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSubtract: @@ -1984,6 +1993,13 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleAnd(this, operands_[0], operands_[1]); case HloOpcode::kOr: return visitor->HandleOr(this, operands_[0], operands_[1]); + case HloOpcode::kShiftLeft: + return visitor->HandleShiftLeft(this, operands_[0], operands_[1]); + case HloOpcode::kShiftRightArithmetic: + return visitor->HandleShiftRightArithmetic(this, operands_[0], + operands_[1]); + case HloOpcode::kShiftRightLogical: + return visitor->HandleShiftRightLogical(this, operands_[0], operands_[1]); case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: @@ -2344,6 +2360,9 @@ bool HloInstruction::IsElementwiseBinary() const { case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; default: return false; @@ -2393,6 +2412,9 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kSubtract: case HloOpcode::kAnd: case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return true; // Ternary elementwise operations. diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index ab5e5463fa..d1ae5f776d 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -104,6 +104,9 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(ShiftLeft); +HLO_MATCHER(ShiftRightLogical); +HLO_MATCHER(ShiftRightArithmetic); HLO_MATCHER(Sign); HLO_MATCHER(Slice); HLO_MATCHER(Sort); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index d3d78f4a99..e98012ec0c 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -147,6 +147,12 @@ string HloOpcodeString(HloOpcode opcode) { return "select"; case HloOpcode::kSend: return "send"; + case HloOpcode::kShiftLeft: + return "shift-left"; + case HloOpcode::kShiftRightArithmetic: + return "shift-right-arithmetic"; + case HloOpcode::kShiftRightLogical: + return "shift-right-logical"; case HloOpcode::kSign: return "sign"; case HloOpcode::kSin: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 9c26f360fb..057d4f6ea7 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -88,6 +88,9 @@ enum class HloOpcode { kSelect, kSelectAndScatter, kSend, + kShiftLeft, + kShiftRightArithmetic, + kShiftRightLogical, kSign, kSin, kSlice, diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index e08e4e4d69..7e46d79ba4 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -69,6 +69,9 @@ namespace xla { case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index a091a067c1..f3c8e3aff3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -117,6 +117,12 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_OR; case HloOpcode::kAnd: return BINOP_AND; + case HloOpcode::kShiftLeft: + return BINOP_SHIFT_LEFT; + case HloOpcode::kShiftRightArithmetic: + return BINOP_SHIFT_RIGHT_ARITHMETIC; + case HloOpcode::kShiftRightLogical: + return BINOP_SHIFT_RIGHT_LOGICAL; default: LOG(FATAL) << "unhandled opcode " << opcode; } @@ -748,6 +754,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_DIV: case BINOP_REM: case BINOP_MUL: + case BINOP_SHIFT_LEFT: + case BINOP_SHIFT_RIGHT_ARITHMETIC: + case BINOP_SHIFT_RIGHT_LOGICAL: return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 317817d022..b3506b72bf 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -115,6 +115,12 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kOr; case BINOP_AND: return HloOpcode::kAnd; + case BINOP_SHIFT_LEFT: + return HloOpcode::kShiftLeft; + case BINOP_SHIFT_RIGHT_ARITHMETIC: + return HloOpcode::kShiftRightArithmetic; + case BINOP_SHIFT_RIGHT_LOGICAL: + return HloOpcode::kShiftRightLogical; default: LOG(FATAL) << "unhandled operation " << binop; } diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index eb931dcff3..a62b13e04f 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -739,6 +739,72 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { ComputeAndCompareR1<uint32>(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1<int32>({static_cast<int32>(0x12345678), + static_cast<int32>(0xF0001000), 1, 3, 77}); + auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1<int32>( + &builder, + {static_cast<int32>(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1<int32>({static_cast<int32>(0x92345678), + static_cast<int32>(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1<int32>(&builder, + {static_cast<int32>(0xF9234567), + static_cast<int32>(0x00100010), 0, 0, 19}, + {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1<int32>({static_cast<int32>(0x92345678), + static_cast<int32>(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1<int32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({0x12345678, 0xF0001000, 1, 3, 77}); + auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1<uint32>( + &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1<uint32>(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1<uint32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 876b073b3f..0d7e583bed 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -713,6 +713,10 @@ enum BinaryOperation { // Logical operators BINOP_AND = 18; BINOP_OR = 19; + + BINOP_SHIFT_LEFT = 20; + BINOP_SHIFT_RIGHT_ARITHMETIC = 21; + BINOP_SHIFT_RIGHT_LOGICAL = 22; } message BinaryOpRequest { |