aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-21 14:39:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-21 14:47:25 -0800
commit4fe280c59a71e85b73e9947063147743adf2ff2b (patch)
tree3a7d4e6f32e474a2219acbb74d0467e02de67e8a /tensorflow/compiler/xla
parent3117a107285256c2f3071e87a105a1cd8a92e823 (diff)
Added optional string argument to infeed HLO op.
Change: 145188452
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc4
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h12
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc6
-rw-r--r--tensorflow/compiler/xla/xla_data.proto3
6 files changed, 26 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index 2b8b0b6ae5..a33b316833 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -765,13 +765,15 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
return ParseOpResponse(s, &response);
}
-ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape) {
+ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
+ const string& config) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return ComputationDataHandle();
}
InfeedRequest request;
*request.mutable_shape() = shape;
+ *request.mutable_config() = config;
OpRequest op_request;
*op_request.mutable_computation() = computation_.handle();
*op_request.mutable_infeed_request() = request;
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index a74257eae3..e3389b0882 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -345,7 +345,7 @@ class ComputationBuilder {
// Enqueues an infeed instruction onto the computation, which reads data of
// the given shape from the infeed buffer of the device.
- ComputationDataHandle Infeed(const Shape& shape);
+ ComputationDataHandle Infeed(const Shape& shape, const string& config = "");
// Enqueues a call instruction onto the computation.
ComputationDataHandle Call(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 2c808e4d09..46af52017e 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -222,8 +222,10 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
- const Shape& shape) {
- return WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
+ const Shape& shape, const string& config) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
+ instruction->set_infeed_config(config);
+ return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e9c653c108..07b3fb386d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -132,7 +132,8 @@ class HloInstruction {
// Creates an infeed instruction, which reads data of the given shape from the
// Infeed interface of the device.
- static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape);
+ static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
+ const string& config);
// Creates a send instruction with the given channel id, which sends the
// operand data to a unique receive instruction in another computation that
@@ -456,6 +457,12 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
int64 channel_id() const { return channel_id_; }
+ // Returns the infeed configuration string. The infeed configuration includes
+ // any metadata needed for the backend compiler (e.g., infeed buffer address)
+ // and is target-dependent.
+ string infeed_config() const { return infeed_config_; }
+ void set_infeed_config(const string& config) { infeed_config_ = config; }
+
// Returns a tag to be used in tracing.
//
// Precondition: opcode() == HloOpcode::kTrace
@@ -799,6 +806,9 @@ class HloInstruction {
// Only present for kSend or kRecv.
int64 channel_id_ = -1;
+ // The string representation of the infeed configuration.
+ string infeed_config_;
+
// String identifier for instruction.
string name_;
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 19e26f7fb9..b0dd555414 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -1802,8 +1802,10 @@ HloInstruction* ComputationLowerer::Visit(
}
case OpRequest::kInfeedRequest: {
- hlo_instruction = hlo_builder_.AddInstruction(
- HloInstruction::CreateInfeed(request.output_shape()));
+ const InfeedRequest& infeed_request = request.request().infeed_request();
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateInfeed(
+ request.output_shape(), infeed_request.config()));
break;
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 4a19d86e77..e142186319 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -380,6 +380,9 @@ message ConvolveRequest {
message InfeedRequest {
// The shape of the data returned by reading the device's infeed buffer.
Shape shape = 2;
+
+ // Additional infeed configuration for the backend.
+ string config = 3;
}
message CallRequest {