aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-10-13 07:00:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 07:05:06 -0700
commit1c241e5ba7fa7068f9cf8f925638b170db57c438 (patch)
treed1bf5bb507023162d4c79a86c6d72d6d4f36cc09 /tensorflow
parenta3b2d6f395ef3f66c9ccd8578e94243e49f76576 (diff)
[XLA] Add ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical operators.
PiperOrigin-RevId: 172091595
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc18
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h10
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h15
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc87
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h3
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc3
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc9
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc6
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc66
-rw-r--r--tensorflow/compiler/xla/xla_data.proto4
15 files changed, 264 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 206af290c6..dcbdb3525e 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -972,6 +972,24 @@ ComputationDataHandle ComputationBuilder::Not(
return UnaryOp(UNOP_NOT, operand);
}
+ComputationDataHandle ComputationBuilder::ShiftLeft(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::ShiftRightLogical(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions);
+}
+
ComputationDataHandle ComputationBuilder::Abs(
const ComputationDataHandle& operand) {
return UnaryOp(UNOP_ABS, operand);
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 94b03502f9..cdd9c8847f 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -472,6 +472,16 @@ class ComputationBuilder {
ComputationDataHandle Not(const ComputationDataHandle& operand);
+ ComputationDataHandle ShiftLeft(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ ComputationDataHandle ShiftRightArithmetic(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+ ComputationDataHandle ShiftRightLogical(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
ComputationDataHandle Reduce(
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 8c864f3d07..5b1dbf439c 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -167,6 +167,21 @@ class DfsHloVisitor {
HloInstruction* rhs) {
return HandleElementwiseBinary(or_);
}
+ virtual Status HandleShiftLeft(HloInstruction* shift_left,
+ HloInstruction* lhs, HloInstruction* rhs) {
+ return HandleElementwiseBinary(shift_left);
+ }
+ virtual Status HandleShiftRightArithmetic(
+ HloInstruction* shift_right_arithmetic, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(shift_right_arithmetic);
+ }
+ virtual Status HandleShiftRightLogical(HloInstruction* shift_right_logical,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(shift_right_logical);
+ }
+
virtual Status HandleReducePrecision(HloInstruction* reduce_precision) {
return HandleElementwiseUnary(reduce_precision);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index fb4d233d04..44f709bede 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -568,6 +568,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr:
return ir_builder_->CreateOr(lhs_value, rhs_value);
+ case HloOpcode::kShiftLeft:
+ return ir_builder_->CreateShl(lhs_value, rhs_value);
+ case HloOpcode::kShiftRightArithmetic:
+ return ir_builder_->CreateAShr(lhs_value, rhs_value);
+ case HloOpcode::kShiftRightLogical:
+ return ir_builder_->CreateLShr(lhs_value, rhs_value);
default:
return Unimplemented("binary integer op '%s'",
HloOpcodeString(op->opcode()).c_str());
@@ -830,6 +836,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
const HloInstruction* lhs = hlo->operand(0);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 20dba60f4e..5fd891835d 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -387,6 +387,93 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shl],
+ ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
+ return lhs_elem << rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return InvalidArgument("Unsupported type for ShiftLeft");
+ }
+
+ Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ return HandleShiftLeft<ReturnT>(shl, lhs, rhs);
+ }
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ typedef typename std::make_signed<NativeT>::type SignedT;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shr],
+ ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
+ return static_cast<NativeT>(static_cast<SignedT>(lhs_elem) >>
+ rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return InvalidArgument("Unsupported type for ShiftRightArithmetic");
+ }
+
+ Status HandleShiftRightArithmetic(HloInstruction* shra, HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ return HandleShiftRightArithmetic<ReturnT>(shra, lhs, rhs);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ typedef typename std::make_unsigned<NativeT>::type UnsignedT;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shr],
+ ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
+ return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
+ rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return InvalidArgument("Unsupported type for ShiftRightLogical");
+ }
+
+ Status HandleShiftRightLogical(HloInstruction* shrl, HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs);
+ }
+
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) override {
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 20fc85c0e9..24e390529e 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -789,6 +789,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSelect:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index b18280552d..72f4d0715d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -163,6 +163,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case (HloOpcode::kSubtract):
case (HloOpcode::kAnd):
case (HloOpcode::kOr):
+ case (HloOpcode::kShiftLeft):
+ case (HloOpcode::kShiftRightArithmetic):
+ case (HloOpcode::kShiftRightLogical):
break;
default:
LOG(FATAL) << "Invalid binary instruction opcode "
@@ -905,6 +908,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
CHECK_EQ(new_operands.size(), 2);
return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
// Ternary ops.
@@ -1293,6 +1299,9 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSelect:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSubtract:
@@ -1984,6 +1993,13 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleAnd(this, operands_[0], operands_[1]);
case HloOpcode::kOr:
return visitor->HandleOr(this, operands_[0], operands_[1]);
+ case HloOpcode::kShiftLeft:
+ return visitor->HandleShiftLeft(this, operands_[0], operands_[1]);
+ case HloOpcode::kShiftRightArithmetic:
+ return visitor->HandleShiftRightArithmetic(this, operands_[0],
+ operands_[1]);
+ case HloOpcode::kShiftRightLogical:
+ return visitor->HandleShiftRightLogical(this, operands_[0], operands_[1]);
case HloOpcode::kConcatenate:
return visitor->HandleConcatenate(this, operands_);
case HloOpcode::kConvert:
@@ -2344,6 +2360,9 @@ bool HloInstruction::IsElementwiseBinary() const {
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
return true;
default:
return false;
@@ -2393,6 +2412,9 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
return true;
// Ternary elementwise operations.
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index ab5e5463fa..d1ae5f776d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -104,6 +104,9 @@ HLO_MATCHER(Rng);
HLO_MATCHER(Select);
HLO_MATCHER(SelectAndScatter);
HLO_MATCHER(Send);
+HLO_MATCHER(ShiftLeft);
+HLO_MATCHER(ShiftRightLogical);
+HLO_MATCHER(ShiftRightArithmetic);
HLO_MATCHER(Sign);
HLO_MATCHER(Slice);
HLO_MATCHER(Sort);
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index d3d78f4a99..e98012ec0c 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -147,6 +147,12 @@ string HloOpcodeString(HloOpcode opcode) {
return "select";
case HloOpcode::kSend:
return "send";
+ case HloOpcode::kShiftLeft:
+ return "shift-left";
+ case HloOpcode::kShiftRightArithmetic:
+ return "shift-right-arithmetic";
+ case HloOpcode::kShiftRightLogical:
+ return "shift-right-logical";
case HloOpcode::kSign:
return "sign";
case HloOpcode::kSin:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 9c26f360fb..057d4f6ea7 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -88,6 +88,9 @@ enum class HloOpcode {
kSelect,
kSelectAndScatter,
kSend,
+ kShiftLeft,
+ kShiftRightArithmetic,
+ kShiftRightLogical,
kSign,
kSin,
kSlice,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index e08e4e4d69..7e46d79ba4 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -69,6 +69,9 @@ namespace xla {
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSelect:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a091a067c1..f3c8e3aff3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -117,6 +117,12 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
return BINOP_OR;
case HloOpcode::kAnd:
return BINOP_AND;
+ case HloOpcode::kShiftLeft:
+ return BINOP_SHIFT_LEFT;
+ case HloOpcode::kShiftRightArithmetic:
+ return BINOP_SHIFT_RIGHT_ARITHMETIC;
+ case HloOpcode::kShiftRightLogical:
+ return BINOP_SHIFT_RIGHT_LOGICAL;
default:
LOG(FATAL) << "unhandled opcode " << opcode;
}
@@ -748,6 +754,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
case BINOP_DIV:
case BINOP_REM:
case BINOP_MUL:
+ case BINOP_SHIFT_LEFT:
+ case BINOP_SHIFT_RIGHT_ARITHMETIC:
+ case BINOP_SHIFT_RIGHT_LOGICAL:
return InferElementwiseBinaryOpShape(operation, lhs, rhs,
broadcast_dimensions);
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 317817d022..b3506b72bf 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -115,6 +115,12 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
return HloOpcode::kOr;
case BINOP_AND:
return HloOpcode::kAnd;
+ case BINOP_SHIFT_LEFT:
+ return HloOpcode::kShiftLeft;
+ case BINOP_SHIFT_RIGHT_ARITHMETIC:
+ return HloOpcode::kShiftRightArithmetic;
+ case BINOP_SHIFT_RIGHT_LOGICAL:
+ return HloOpcode::kShiftRightLogical;
default:
LOG(FATAL) << "unhandled operation " << binop;
}
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index eb931dcff3..a62b13e04f 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -739,6 +739,72 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
ComputeAndCompareR1<uint32>(&builder, {}, {});
}
+XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR1<int32>({static_cast<int32>(0x12345678),
+ static_cast<int32>(0xF0001000), 1, 3, 77});
+ auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 15});
+ auto out = builder.ShiftLeft(a, b);
+
+ ComputeAndCompareR1<int32>(
+ &builder,
+ {static_cast<int32>(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
+ static_cast<int32>(0x10001000), 1, 3, 77});
+ auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 2});
+ auto out = builder.ShiftRightArithmetic(a, b);
+
+ ComputeAndCompareR1<int32>(&builder,
+ {static_cast<int32>(0xF9234567),
+ static_cast<int32>(0x00100010), 0, 0, 19},
+ {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR1<int32>({static_cast<int32>(0x92345678),
+ static_cast<int32>(0x10001000), 1, 3, 77});
+ auto b = builder.ConstantR1<int32>({4, 8, 2, 7, 5});
+ auto out = builder.ShiftRightLogical(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({0x12345678, 0xF0001000, 1, 3, 77});
+ auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 15});
+ auto out = builder.ShiftLeft(a, b);
+
+ ComputeAndCompareR1<uint32>(
+ &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77});
+ auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 2});
+ auto out = builder.ShiftRightArithmetic(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>({0x92345678, 0x10001000, 1, 3, 77});
+ auto b = builder.ConstantR1<uint32>({4, 8, 2, 7, 5});
+ auto out = builder.ShiftRightLogical(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {});
+}
+
XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
SetFastMathDisabled(true);
ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 876b073b3f..0d7e583bed 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -713,6 +713,10 @@ enum BinaryOperation {
// Logical operators
BINOP_AND = 18;
BINOP_OR = 19;
+
+ BINOP_SHIFT_LEFT = 20;
+ BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
+ BINOP_SHIFT_RIGHT_LOGICAL = 22;
}
message BinaryOpRequest {