aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-10-13 07:00:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-13 07:05:06 -0700
commit1c241e5ba7fa7068f9cf8f925638b170db57c438 (patch)
treed1bf5bb507023162d4c79a86c6d72d6d4f36cc09 /tensorflow/compiler/xla/service
parenta3b2d6f395ef3f66c9ccd8578e94243e49f76576 (diff)
[XLA] Add ShiftLeft, ShiftRightArithmetic, and ShiftRightLogical operators.
PiperOrigin-RevId: 172091595
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h15
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc87
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h3
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc3
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc9
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc6
11 files changed, 166 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 8c864f3d07..5b1dbf439c 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -167,6 +167,21 @@ class DfsHloVisitor {
HloInstruction* rhs) {
return HandleElementwiseBinary(or_);
}
+ virtual Status HandleShiftLeft(HloInstruction* shift_left,
+ HloInstruction* lhs, HloInstruction* rhs) {
+ return HandleElementwiseBinary(shift_left);
+ }
+ virtual Status HandleShiftRightArithmetic(
+ HloInstruction* shift_right_arithmetic, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(shift_right_arithmetic);
+ }
+ virtual Status HandleShiftRightLogical(HloInstruction* shift_right_logical,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(shift_right_logical);
+ }
+
virtual Status HandleReducePrecision(HloInstruction* reduce_precision) {
return HandleElementwiseUnary(reduce_precision);
}
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index fb4d233d04..44f709bede 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -568,6 +568,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr:
return ir_builder_->CreateOr(lhs_value, rhs_value);
+ case HloOpcode::kShiftLeft:
+ return ir_builder_->CreateShl(lhs_value, rhs_value);
+ case HloOpcode::kShiftRightArithmetic:
+ return ir_builder_->CreateAShr(lhs_value, rhs_value);
+ case HloOpcode::kShiftRightLogical:
+ return ir_builder_->CreateLShr(lhs_value, rhs_value);
default:
return Unimplemented("binary integer op '%s'",
HloOpcodeString(op->opcode()).c_str());
@@ -830,6 +836,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
const HloInstruction* lhs = hlo->operand(0);
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 20dba60f4e..5fd891835d 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -387,6 +387,93 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
};
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shl],
+ ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) {
+ return lhs_elem << rhs_elem;
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return InvalidArgument("Unsupported type for ShiftLeft");
+ }
+
+ Status HandleShiftLeft(HloInstruction* shl, HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ return HandleShiftLeft<ReturnT>(shl, lhs, rhs);
+ }
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ typedef typename std::make_signed<NativeT>::type SignedT;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shr],
+ ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
+ return static_cast<NativeT>(static_cast<SignedT>(lhs_elem) >>
+ rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftRightArithmetic(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return InvalidArgument("Unsupported type for ShiftRightArithmetic");
+ }
+
+ Status HandleShiftRightArithmetic(HloInstruction* shra, HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ return HandleShiftRightArithmetic<ReturnT>(shra, lhs, rhs);
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_integral<NativeT>::value &&
+ !std::is_same<NativeT, bool>::value>::type* = nullptr>
+ Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ typedef typename std::make_unsigned<NativeT>::type UnsignedT;
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[shr],
+ ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) {
+ return static_cast<NativeT>(static_cast<UnsignedT>(lhs_elem) >>
+ rhs_elem);
+ }));
+ return Status::OK();
+ }
+
+ template <typename NativeT,
+ typename std::enable_if<!std::is_integral<NativeT>::value ||
+ std::is_same<NativeT, bool>::value>::type* =
+ nullptr>
+ Status HandleShiftRightLogical(HloInstruction* shr, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return InvalidArgument("Unsupported type for ShiftRightLogical");
+ }
+
+ Status HandleShiftRightLogical(HloInstruction* shrl, HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ return HandleShiftRightLogical<ReturnT>(shrl, lhs, rhs);
+ }
+
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) override {
std::function<ReturnT(ReturnT, ReturnT, ReturnT)> clamp_op =
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 20fc85c0e9..24e390529e 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -789,6 +789,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSelect:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index b18280552d..72f4d0715d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -163,6 +163,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case (HloOpcode::kSubtract):
case (HloOpcode::kAnd):
case (HloOpcode::kOr):
+ case (HloOpcode::kShiftLeft):
+ case (HloOpcode::kShiftRightArithmetic):
+ case (HloOpcode::kShiftRightLogical):
break;
default:
LOG(FATAL) << "Invalid binary instruction opcode "
@@ -905,6 +908,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
CHECK_EQ(new_operands.size(), 2);
return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
// Ternary ops.
@@ -1293,6 +1299,9 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kSelect:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSubtract:
@@ -1984,6 +1993,13 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleAnd(this, operands_[0], operands_[1]);
case HloOpcode::kOr:
return visitor->HandleOr(this, operands_[0], operands_[1]);
+ case HloOpcode::kShiftLeft:
+ return visitor->HandleShiftLeft(this, operands_[0], operands_[1]);
+ case HloOpcode::kShiftRightArithmetic:
+ return visitor->HandleShiftRightArithmetic(this, operands_[0],
+ operands_[1]);
+ case HloOpcode::kShiftRightLogical:
+ return visitor->HandleShiftRightLogical(this, operands_[0], operands_[1]);
case HloOpcode::kConcatenate:
return visitor->HandleConcatenate(this, operands_);
case HloOpcode::kConvert:
@@ -2344,6 +2360,9 @@ bool HloInstruction::IsElementwiseBinary() const {
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
return true;
default:
return false;
@@ -2393,6 +2412,9 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kSubtract:
case HloOpcode::kAnd:
case HloOpcode::kOr:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
return true;
// Ternary elementwise operations.
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index ab5e5463fa..d1ae5f776d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -104,6 +104,9 @@ HLO_MATCHER(Rng);
HLO_MATCHER(Select);
HLO_MATCHER(SelectAndScatter);
HLO_MATCHER(Send);
+HLO_MATCHER(ShiftLeft);
+HLO_MATCHER(ShiftRightLogical);
+HLO_MATCHER(ShiftRightArithmetic);
HLO_MATCHER(Sign);
HLO_MATCHER(Slice);
HLO_MATCHER(Sort);
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index d3d78f4a99..e98012ec0c 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -147,6 +147,12 @@ string HloOpcodeString(HloOpcode opcode) {
return "select";
case HloOpcode::kSend:
return "send";
+ case HloOpcode::kShiftLeft:
+ return "shift-left";
+ case HloOpcode::kShiftRightArithmetic:
+ return "shift-right-arithmetic";
+ case HloOpcode::kShiftRightLogical:
+ return "shift-right-logical";
case HloOpcode::kSign:
return "sign";
case HloOpcode::kSin:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 9c26f360fb..057d4f6ea7 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -88,6 +88,9 @@ enum class HloOpcode {
kSelect,
kSelectAndScatter,
kSend,
+ kShiftLeft,
+ kShiftRightArithmetic,
+ kShiftRightLogical,
kSign,
kSin,
kSlice,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index e08e4e4d69..7e46d79ba4 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -69,6 +69,9 @@ namespace xla {
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSelect:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index a091a067c1..f3c8e3aff3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -117,6 +117,12 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
return BINOP_OR;
case HloOpcode::kAnd:
return BINOP_AND;
+ case HloOpcode::kShiftLeft:
+ return BINOP_SHIFT_LEFT;
+ case HloOpcode::kShiftRightArithmetic:
+ return BINOP_SHIFT_RIGHT_ARITHMETIC;
+ case HloOpcode::kShiftRightLogical:
+ return BINOP_SHIFT_RIGHT_LOGICAL;
default:
LOG(FATAL) << "unhandled opcode " << opcode;
}
@@ -748,6 +754,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
case BINOP_DIV:
case BINOP_REM:
case BINOP_MUL:
+ case BINOP_SHIFT_LEFT:
+ case BINOP_SHIFT_RIGHT_ARITHMETIC:
+ case BINOP_SHIFT_RIGHT_LOGICAL:
return InferElementwiseBinaryOpShape(operation, lhs, rhs,
broadcast_dimensions);
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 317817d022..b3506b72bf 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -115,6 +115,12 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
return HloOpcode::kOr;
case BINOP_AND:
return HloOpcode::kAnd;
+ case BINOP_SHIFT_LEFT:
+ return HloOpcode::kShiftLeft;
+ case BINOP_SHIFT_RIGHT_ARITHMETIC:
+ return HloOpcode::kShiftRightArithmetic;
+ case BINOP_SHIFT_RIGHT_LOGICAL:
+ return HloOpcode::kShiftRightLogical;
default:
LOG(FATAL) << "unhandled operation " << binop;
}