aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-16 17:15:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-16 17:19:28 -0700
commit9d2a432ce74eab4c439fe8c60389e4da9d6c92b2 (patch)
treed094564097a0e399ea5c32ae30d7863128c018e7 /tensorflow/compiler/xla
parent8e7c2c180c4b75768c707d66c9080ee06ee72773 (diff)
Add plumbing for a ReducePrecision operation.
This CL is the first part of a series that adds a ReducePrecision operation for experimenting with the effects of reduced-precision storage of intermediate values. ReducePrecision is a Unary operation parameterized on floating-point exponent and mantissa bit sizes, and rounds the input data as if it were converted to a floating-point value with the given bit sizes and then converted back to "normal" F32 data. Using arbitrary parameterized values to describe the lower-precision value type, rather than hardcoding this as a reduction to IEEE f16, allows us to do more flexible experiments -- e.g., "Is this training error due to the reduced mantissa precision, or due to the reduced exponent range?" or "Is this a smooth degradation with reduced precision or is there a sudden drop at some value?" -- which may suggest software mitigations for the effects. This version of the CL adds the kReducePrecision instruction opcode, and the overall plumbing to support the operation. To allow testing, it includes an exceptionally simple implementation of the actual operation that returns "unimplemented" except for the exponent and mantissa bit sizes where it is a complete no-op. PiperOrigin-RevId: 159295615
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc22
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h5
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h5
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc26
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc25
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h27
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/service.cc5
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc24
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h6
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc47
-rw-r--r--tensorflow/compiler/xla/service/user_computation.h4
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc23
-rw-r--r--tensorflow/compiler/xla/xla_data.proto9
20 files changed, 246 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index cefa4af23c..49b0e164b0 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -1514,6 +1514,28 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
return ParseOpResponse(s, &response);
}
+ComputationDataHandle ComputationBuilder::ReducePrecision(
+ const ComputationDataHandle& operand, const int exponent_bits,
+ const int mantissa_bits) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ReducePrecisionRequest request;
+ *request.mutable_operand() = operand;
+ request.set_exponent_bits(exponent_bits);
+ request.set_mantissa_bits(mantissa_bits);
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_reduce_precision_request() = request;
+ AddOpMetadata(&op_request);
+ OpResponse response;
+
+ VLOG(2) << "making reduce-precision request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
void ComputationBuilder::Send(const ComputationDataHandle& operand,
const ChannelHandle& handle) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 13b44a71a5..6a87784f6a 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -598,6 +598,11 @@ class ComputationBuilder {
const Computation& body,
const ComputationDataHandle& init);
+ // Enqueues a ReducePrecision node onto the computation.
+ ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand,
+ const int exponent_bits,
+ const int mantissa_bits);
+
// Enqueues a Send node onto the computation, to send the given operand to
// a Recv instruction that shares the same channel handle.
void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 40ff037e73..1f58562ac2 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -171,6 +171,11 @@ class DfsHloVisitor {
HloInstruction* lhs, HloInstruction* rhs) {
return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr, lhs, rhs);
}
+ virtual Status HandleReducePrecision(HloInstruction* reduce_precision,
+ HloInstruction* operand) {
+ return HandleElementwiseUnary(reduce_precision, HloOpcode::kReducePrecision,
+ operand);
+ }
virtual Status HandleInfeed(HloInstruction* infeed) = 0;
virtual Status HandleOutfeed(HloInstruction* outfeed) = 0;
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index c99ebceb45..bea1329d40 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -385,6 +385,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
}
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
+ const HloInstruction* hlo, llvm::Value* x) const {
+ if (hlo->operand(0)->shape().element_type() != F32) {
+ return Unimplemented("reduce-precision only implemented for F32");
+ }
+ // As a preliminary implementation, we only implement this for the case
+ // where it is a no-op -- that is, where the exponent and mantissa bit
+ // counts are equal to the (IEEE f32) bit counts for the input values.
+ if (hlo->exponent_bits() != 8) {
+ return Unimplemented("reduce-precision requires 8 exponent bits");
+ }
+ if (hlo->mantissa_bits() != 23) {
+ return Unimplemented("reduce-precision requires 23 mantissa bits");
+ }
+
+ return x;
+}
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const {
@@ -742,6 +760,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
ElementwiseSourceIndex(index, *hlo, 2)));
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
};
+ case HloOpcode::kReducePrecision:
+ return [this, hlo, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+ operand_to_generator.at(hlo->operand(0))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ return EmitReducePrecision(hlo, operand_value);
+ };
case HloOpcode::kConcatenate:
return [this, hlo, &operand_to_generator](
const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 2576d3823e..bb9117ca61 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -84,6 +84,9 @@ class ElementalIrEmitter {
virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
llvm::Value* value) const;
+ virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo,
+ llvm::Value* x) const;
+
// A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
// the target array index, computes the source array index of its
// `operand_no`-th operand.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 7f88474a27..cbabf00913 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -100,6 +100,11 @@ Status HloCostAnalysis::HandleClamp(HloInstruction* clamp,
return HandleElementwiseOp(clamp);
}
+Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo,
+ HloInstruction* operand) {
+ return HandleElementwiseOp(hlo);
+}
+
Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) {
current_bytes_accessed_ = 0;
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 30f553a81f..f14baf6da2 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -58,6 +58,8 @@ class HloCostAnalysis : public DfsHloVisitor {
HloInstruction* lhs, HloInstruction* rhs) override;
Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
HloInstruction* arg, HloInstruction* max) override;
+ Status HandleReducePrecision(HloInstruction* hlo,
+ HloInstruction* operand) override;
Status HandleConcatenate(
HloInstruction* concatenate,
tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 48c33d62c5..9fe4d85f8b 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -318,6 +318,11 @@ string InstructionSequenceGraph(
StrAppend(&name, "<br/>",
"custom_call_target=", instruction->custom_call_target());
break;
+ case HloOpcode::kReducePrecision:
+ // Make ReducePrecision ops a bit more visible, since typically they
+ // will be inserted as modifications to an existing graph.
+ color = kDarkRed;
+ break;
}
// Create instruction node with appropriate label, shape, and color.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6bb9e9a9e6..33bb29e16d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -228,6 +228,19 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateReducePrecision(const Shape& shape,
+ HloInstruction* operand,
+ const int exponent_bits,
+ const int mantissa_bits) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape));
+ instruction->AppendOperand(operand);
+ instruction->exponent_bits_ = exponent_bits;
+ instruction->mantissa_bits_ = mantissa_bits;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(const Shape& shape,
HloInstruction* operand) {
auto instruction =
@@ -796,6 +809,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kConvert:
CHECK_EQ(new_operands.size(), 1);
return CreateConvert(shape, new_operands[0]);
+ case HloOpcode::kReducePrecision:
+ CHECK_EQ(new_operands.size(), 1);
+ return CreateReducePrecision(shape, new_operands[0], exponent_bits_,
+ mantissa_bits_);
case HloOpcode::kConvolution:
CHECK_EQ(new_operands.size(), 2);
return CreateConvolve(shape, new_operands[0], new_operands[1], *window_,
@@ -1171,6 +1188,11 @@ bool HloInstruction::Identical(
case HloOpcode::kConvert:
return shape().element_type() == other.shape().element_type();
+ // A reduce-precision operation is determined by the bit sizes.
+ case HloOpcode::kReducePrecision:
+ return exponent_bits() == other.exponent_bits() &&
+ mantissa_bits() == other.mantissa_bits();
+
// Convolution has a window and dimensions.
case HloOpcode::kConvolution:
return protobuf_util::ProtobufEquals(window(), other.window()) &&
@@ -1855,6 +1877,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) {
return visitor->HandleTranspose(this);
case HloOpcode::kReverse:
return visitor->HandleReverse(this, operands_[0]);
+ case HloOpcode::kReducePrecision:
+ return visitor->HandleReducePrecision(this, operands_[0]);
case HloOpcode::kSlice:
return visitor->HandleSlice(this, operands_[0]);
case HloOpcode::kDynamicSlice:
@@ -2106,6 +2130,7 @@ bool HloInstruction::IsElementwise() const {
case HloOpcode::kLog:
case HloOpcode::kLogicalNot:
case HloOpcode::kNegate:
+ case HloOpcode::kReducePrecision:
case HloOpcode::kSign:
case HloOpcode::kTanh:
return true;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index cb19c84814..fb80ca86af 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -131,6 +131,13 @@ class HloInstruction {
const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
+ // Creates a reduce-precision op, where operand is the data to reduce in
+ // precision, and exponent_bits and mantissa_bits describe the precision to
+ // reduce it to.
+ static std::unique_ptr<HloInstruction> CreateReducePrecision(
+ const Shape& shape, HloInstruction* operand, const int exponent_bits,
+ const int mantissa_bits);
+
// Creates a cross replica sum op.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
const Shape& shape, HloInstruction* operand);
@@ -668,6 +675,22 @@ class HloInstruction {
return dynamic_slice_sizes_;
}
+ // Returns the number of exponent bits for a reduce-precision node.
+ //
+ // Precondition: opcode() == HloOpcode::kReducePrecision
+ int32 exponent_bits() const {
+ CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
+ return exponent_bits_;
+ }
+
+ // Returns the number of mantissa bits for a reduce-precision node.
+ //
+ // Precondition: opcode() == HloOpcode::kReducePrecision
+ int32 mantissa_bits() const {
+ CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
+ return mantissa_bits_;
+ }
+
// Returns data on the window in a windowed operation such as
// convolution.
const Window& window() const {
@@ -864,6 +887,10 @@ class HloInstruction {
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
+ // The bit sizes for a reduce-precision operation.
+ int32 exponent_bits_;
+ int32 mantissa_bits_;
+
// Describes the [start, start + size) range size for a dynamic slice
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 342c43dc5a..4d68d0d088 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -116,6 +116,8 @@ string HloOpcodeString(HloOpcode opcode) {
return "recv";
case HloOpcode::kReduce:
return "reduce";
+ case HloOpcode::kReducePrecision:
+ return "reduce-precision";
case HloOpcode::kReduceWindow:
return "reduce-window";
case HloOpcode::kRemainder:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 8e0fa7b4f1..d1263219c0 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -76,6 +76,7 @@ enum class HloOpcode {
kPower,
kRecv,
kReduce,
+ kReducePrecision,
kReduceWindow,
kRemainder,
kReshape,
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 9bace7edaa..52da222ab9 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -65,6 +65,7 @@ namespace xla {
case HloOpcode::kNegate:
case HloOpcode::kOutfeed:
case HloOpcode::kPad:
+ case HloOpcode::kReducePrecision:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSelect:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 5812d3e487..2c3a0a1a25 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1307,6 +1307,11 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
computation->AddReduceInstruction(arg->reduce_request(), *to_apply);
break;
}
+ case OpRequest::kReducePrecisionRequest: {
+ handle_status = computation->AddReducePrecisionInstruction(
+ arg->reduce_precision_request());
+ break;
+ }
case OpRequest::kReduceWindowRequest: {
TF_ASSIGN_OR_RETURN(UserComputation * to_apply,
computation_tracker_.Resolve(
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 2508f4c13d..88fbfdf33c 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -298,6 +298,30 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
}
+/* static */ StatusOr<Shape> ShapeInference::InferReducePrecisionShape(
+ const Shape& operand_shape, const int exponent_bits,
+ const int mantissa_bits) {
+ if (!ShapeUtil::ElementIsFloating(operand_shape)) {
+ return InvalidArgument(
+ "expected element type in shape to be floating point for "
+ "ReducePrecision operation; got %s",
+ PrimitiveType_Name(operand_shape.element_type()).c_str());
+ }
+ if (exponent_bits < 1) {
+ // One exponent bit is necessary to distinguish 0 from infinity. Having
+ // no exponent bits doesn't produce a sensible number, so we require at
+ // least one.
+ return InvalidArgument("expected exponent_bits >= 1; got %d",
+ exponent_bits);
+ }
+ if (mantissa_bits < 0) {
+ // A number with no mantissa bits is still meaningful, however.
+ return InvalidArgument("expected non-negative mantissa_bits; got %d",
+ mantissa_bits);
+ }
+ return operand_shape;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
const Shape& operand_shape, const Shape& padding_value_shape,
const PaddingConfig& padding_config) {
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 8bd3585133..55c60e149d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -171,6 +171,12 @@ class ShapeInference {
static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
PrimitiveType new_element_type);
+ // Helper that validates the input data type for a reduce-precision operation,
+ // and returns the result shape.
+ static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
+ const int exponent_bits,
+ const int mantissa_bits);
+
// Helper that infers the shape produced by a pad operation based on the
// padding configuration.
static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 3e7942075c..8014c64953 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -881,6 +881,34 @@ StatusOr<ComputationDataHandle> UserComputation::AddConvertInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddReducePrecisionInstruction(
+ const ReducePrecisionRequest& reduce_precision_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookUpRequest(reduce_precision_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape new_shape,
+ ShapeInference::InferReducePrecisionShape(
+ operand->output_shape(), reduce_precision_request.exponent_bits(),
+ reduce_precision_request.mantissa_bits()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = new_shape;
+ *request.mutable_request()->mutable_reduce_precision_request() =
+ reduce_precision_request;
+
+ VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << reduce_precision_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction(
const ConvolveRequest& convolve_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -2180,6 +2208,13 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kReducePrecisionRequest: {
+ const ReducePrecisionRequest& reduce_precision_request =
+ request.request().reduce_precision_request();
+ apply(reduce_precision_request.operand());
+ break;
+ }
+
case OpRequest::kTraceRequest: {
const TraceRequest& trace_request = request.request().trace_request();
apply(trace_request.operand());
@@ -2767,6 +2802,18 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kReducePrecisionRequest: {
+ const ReducePrecisionRequest& reduce_precision_request =
+ request.request().reduce_precision_request();
+ HloInstruction* operand =
+ lookup_instruction(reduce_precision_request.operand());
+ auto exponent_bits = reduce_precision_request.exponent_bits();
+ auto mantissa_bits = reduce_precision_request.mantissa_bits();
+ hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision(
+ request.output_shape(), operand, exponent_bits, mantissa_bits));
+ break;
+ }
+
case OpRequest::kTraceRequest: {
const TraceRequest& trace_request = request.request().trace_request();
HloInstruction* operand = lookup_instruction(trace_request.operand());
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index a8bedf20b5..9bb7bf491a 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -116,6 +116,10 @@ class UserComputation {
const MapRequest& map_request,
const UserComputation& to_apply_computation);
+ // Enqueues a reduce-precision instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddReducePrecisionInstruction(
+ const ReducePrecisionRequest& reduce_precision_request);
+
// Enqueues a convolution instruction onto this user computation.
StatusOr<ComputationDataHandle> AddConvolveInstruction(
const ConvolveRequest& convolve_request);
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index b6088f8d08..f43d4b6f57 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -1864,6 +1864,29 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
ArrayElementwiseOpTestParamCount,
::testing::Values(127, 128, 129, 17 * 4096));
+XLA_TEST_F(ArrayElementwiseOpTest, ReducePrecisionNoOpF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 25.5f});
+ auto reduce_precision = builder.ReducePrecision(a, 8, 23);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ReducePrecisionNoOpParamF32) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> a_values = {-2.5f, 25.5f};
+
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
+
+ auto reduce_precision = builder.ReducePrecision(a_param, 8, 23);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {a_data.get()});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index a95ac968dd..95c1f0995b 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -743,6 +743,12 @@ message VariadicOpRequest {
repeated ComputationDataHandle operands = 3;
}
+message ReducePrecisionRequest {
+ ComputationDataHandle operand = 1;
+ int32 exponent_bits = 2;
+ int32 mantissa_bits = 3;
+}
+
message SendRequest {
ComputationDataHandle operand = 1;
ChannelHandle channel_handle = 2;
@@ -774,6 +780,7 @@ message OpRequest {
MapRequest map_request = 15;
PadRequest pad_request = 16;
ParameterRequest parameter_request = 17;
+ ReducePrecisionRequest reduce_precision_request = 36;
ReduceRequest reduce_request = 18;
ReduceWindowRequest reduce_window_request = 19;
ReshapeRequest reshape_request = 20;
@@ -791,7 +798,7 @@ message OpRequest {
RecvRequest recv_request = 31;
OutfeedRequest outfeed_request = 32;
BatchNormTrainingRequest batch_norm_training_request = 35;
- // Next: 36
+ // Next: 37
}
}