diff options
6 files changed, 26 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 2b8b0b6ae5..a33b316833 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -765,13 +765,15 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( return ParseOpResponse(s, &response); } -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape) { +ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, + const string& config) { if (!first_error_.ok() || !PrepareComputation().ok()) { return ComputationDataHandle(); } InfeedRequest request; *request.mutable_shape() = shape; + *request.mutable_config() = config; OpRequest op_request; *op_request.mutable_computation() = computation_.handle(); *op_request.mutable_infeed_request() = request; diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index a74257eae3..e3389b0882 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -345,7 +345,7 @@ class ComputationBuilder { // Enqueues an infeed instruction onto the computation, which reads data of // the given shape from the infeed buffer of the device. - ComputationDataHandle Infeed(const Shape& shape); + ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); // Enqueues a call instruction onto the computation. ComputationDataHandle Call( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 2c808e4d09..46af52017e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -222,8 +222,10 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed( - const Shape& shape) { - return WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape)); + const Shape& shape, const string& config) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape)); + instruction->set_infeed_config(config); + return instruction; } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e9c653c108..07b3fb386d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -132,7 +132,8 @@ class HloInstruction { // Creates an infeed instruction, which reads data of the given shape from the // Infeed interface of the device. - static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape); + static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape, + const string& config); // Creates a send instruction with the given channel id, which sends the // operand data to a unique receive instruction in another computation that @@ -456,6 +457,12 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv int64 channel_id() const { return channel_id_; } + // Returns the infeed configuration string. The infeed configuration includes + // any metadata needed for the backend compiler (e.g., infeed buffer address) + // and is target-dependent. + string infeed_config() const { return infeed_config_; } + void set_infeed_config(const string& config) { infeed_config_ = config; } + // Returns a tag to be used in tracing. // // Precondition: opcode() == HloOpcode::kTrace @@ -799,6 +806,9 @@ class HloInstruction { // Only present for kSend or kRecv. int64 channel_id_ = -1; + // The string representation of the infeed configuration. + string infeed_config_; + // String identifier for instruction. string name_; diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 19e26f7fb9..b0dd555414 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1802,8 +1802,10 @@ HloInstruction* ComputationLowerer::Visit( } case OpRequest::kInfeedRequest: { - hlo_instruction = hlo_builder_.AddInstruction( - HloInstruction::CreateInfeed(request.output_shape())); + const InfeedRequest& infeed_request = request.request().infeed_request(); + hlo_instruction = + hlo_builder_.AddInstruction(HloInstruction::CreateInfeed( + request.output_shape(), infeed_request.config())); break; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 4a19d86e77..e142186319 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -380,6 +380,9 @@ message ConvolveRequest { message InfeedRequest { // The shape of the data returned by reading the device's infeed buffer. Shape shape = 2; + + // Additional infeed configuration for the backend. + string config = 3; } message CallRequest { |