author    Jacques Pienaar <jpienaar@google.com>  2017-02-27 15:05:10 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-02-27 15:42:08 -0800
commit efc8f98d45df835bac2373e19f1da57e3a1ea2d0 (patch)
tree   28dabbd42fb3955e6569c0ae1d6954b14433d338 /tensorflow
parent 28554fbc756454cd2a1f6f6bda2b2cc86c68bcff (diff)
[XLA] Add basic outfeed support.
Change: 148699787
Diffstat (limited to 'tensorflow')
 tensorflow/compiler/xla/client/client.cc                    | 32
 tensorflow/compiler/xla/client/client.h                     |  9
 tensorflow/compiler/xla/client/computation_builder.cc       |  2
 tensorflow/compiler/xla/client/computation_builder.h        |  6
 tensorflow/compiler/xla/service/generic_transfer_manager.cc |  6
 tensorflow/compiler/xla/service/generic_transfer_manager.h  |  4
 tensorflow/compiler/xla/service/hlo_instruction.cc          | 10
 tensorflow/compiler/xla/service/hlo_instruction.h           | 10
 tensorflow/compiler/xla/service/layout_assignment.cc        |  9
 tensorflow/compiler/xla/service/service.cc                  | 49
 tensorflow/compiler/xla/service/service.h                   |  6
 tensorflow/compiler/xla/service/transfer_manager.h          |  6
 tensorflow/compiler/xla/service/user_computation.cc         | 16
 tensorflow/compiler/xla/service_interface.h                 |  4
 tensorflow/compiler/xla/xla.proto                           | 15
 tensorflow/compiler/xla/xla_data.proto                      |  3
 16 files changed, 160 insertions(+), 27 deletions(-)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 341c02f1a1..c4430dab65 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -132,6 +132,38 @@ Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
return Status::OK();
}
+StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
+ const Shape* shape_with_layout, int64 replica_id,
+ const DeviceHandle* device_handle) {
+ TransferFromOutfeedRequest request;
+ if (device_handle) {
+ *request.mutable_device_handle() = *device_handle;
+ }
+ request.set_replica_id(replica_id);
+ if (shape_with_layout != nullptr) {
+ *request.mutable_shape_with_layout() = *shape_with_layout;
+ }
+ TransferFromOutfeedResponse response;
+
+ VLOG(1) << "making transfer from outfeed request";
+ VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
+ Status s = stub_->TransferFromOutfeed(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";
+
+ if (!response.has_literal()) {
+ return FailedPrecondition(
+ "server provided response without a literal in "
+ "TransferToClient request");
+ }
+
+ return WrapUnique(response.release_literal());
+}
+
Status Client::ResetDevice() {
ResetDeviceRequest request;
ResetDeviceResponse response;
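For reference, a minimal sketch of calling the new client API (hypothetical
`client` pointer and shape; not part of this change):

  Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {3}, {0});
  StatusOr<std::unique_ptr<Literal>> transferred =
      client->TransferFromOutfeed(&shape, /*replica_id=*/0);
  if (!transferred.ok()) {
    LOG(ERROR) << transferred.status();
  } else {
    std::unique_ptr<Literal> literal = transferred.ConsumeValueOrDie();
  }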
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index f261de9d0d..ea166acc91 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -119,6 +119,15 @@ class Client {
Status TransferToInfeed(const Literal& literal, int64 replica_id = 0,
const DeviceHandle* device_handle = nullptr);
+ // Transfers a literal from the Outfeed of the device.
+ //
+ // device_handle and replica_id together specify a particular device: the
+ // device assigned to the given replica_id among the replicas that the
+ // given device handle belongs to.
+ StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed(
+ const Shape* shape_with_layout, int64 replica_id = 0,
+ const DeviceHandle* device_handle = nullptr);
+
// Resets the device, clearing all existing state on the device.
Status ResetDevice();
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index c4c91b7ea8..ae7695ade5 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -786,6 +786,7 @@ ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
}
void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
+ const Shape& shape,
const string& outfeed_config) {
if (!first_error_.ok() || !PrepareComputation().ok()) {
return;
@@ -794,6 +795,7 @@ void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
OutfeedRequest request;
request.set_outfeed_config(outfeed_config);
*request.mutable_operand() = operand;
+ *request.mutable_shape() = shape;
OpRequest op_request;
*op_request.mutable_outfeed_request() = request;
*op_request.mutable_computation() = computation_.handle();
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index b1a68e3687..a49e5a8843 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -352,13 +352,13 @@ class ComputationBuilder {
tensorflow::gtl::ArraySlice<int64> rhs_dilation,
const ConvolutionDimensionNumbers& dimension_numbers);
- // Enqueues an infeed instruction onto the computation, which reads data of
- // the given shape from the infeed buffer of the device.
+ // Enqueues an infeed instruction onto the computation, which writes data of
+ // the given shape to the infeed buffer of the device.
ComputationDataHandle Infeed(const Shape& shape, const string& config = "");
// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
- void Outfeed(const ComputationDataHandle& operand,
+ void Outfeed(const ComputationDataHandle& operand, const Shape& shape,
const string& outfeed_config);
// Enqueues a call instruction onto the computation.
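A sketch of the updated builder call, which now takes the shape (with layout)
that the outfeed is expected to produce (names illustrative):

  ComputationBuilder builder(client, "outfeed_example");
  auto data = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f});
  builder.Outfeed(data, ShapeUtil::MakeShapeWithLayout(F32, {3}, {0}),
                  /*outfeed_config=*/"");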
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 8f39ba8b1d..aa512f242a 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -162,6 +162,12 @@ Status GenericTransferManager::TransferLiteralToInfeed(
return Unimplemented("Infeed is not supported on GPU (b/30467474)");
}
+Status GenericTransferManager::TransferLiteralFromOutfeed(
+ perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+ Literal* literal) {
+ return Unimplemented("Outfeed is not supported on CPU/GPU (b/30467474)");
+}
+
Status GenericTransferManager::ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executors) {
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 06819d65c7..2fbdb94f06 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -55,6 +55,10 @@ class GenericTransferManager : public TransferManager {
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
const Literal& literal) override;
+ Status TransferLiteralFromOutfeed(
+ perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+ Literal* literal) override;
+
Status ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
executors) override;
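The generic implementation returns Unimplemented, so a backend with real
outfeed support would override the new hook roughly as follows
(DequeueOutfeedBuffer and CopyBufferToLiteral are hypothetical helpers):

  Status MyTransferManager::TransferLiteralFromOutfeed(
      perftools::gputools::StreamExecutor* executor,
      const Shape& literal_shape, Literal* literal) {
    // Block until the device pushes a buffer onto its outfeed queue.
    TF_ASSIGN_OR_RETURN(void* buffer, DequeueOutfeedBuffer(executor));
    // Copy the raw bytes into `literal`, honoring the requested layout.
    return CopyBufferToLiteral(buffer, literal_shape, literal);
  }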
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index b5438865cb..9951f59911 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -236,11 +236,13 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
- HloInstruction* operand, tensorflow::StringPiece outfeed_config) {
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::StringPiece outfeed_config) {
std::unique_ptr<HloInstruction> instruction =
WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil()));
instruction->AppendOperand(operand);
instruction->outfeed_config_ = outfeed_config.ToString();
+ instruction->outfeed_shape_ = shape;
return instruction;
}
@@ -1852,6 +1854,12 @@ Status HloInstruction::AcceptOrdered(
return visitor->FinishVisit(this);
}
+const Shape& HloInstruction::outfeed_shape() const {
+ DCHECK_EQ(opcode_, HloOpcode::kOutfeed);
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
+ return outfeed_shape_;
+}
+
const Shape& HloInstruction::shape() const {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
return shape_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index bafb402e9d..fbde4681af 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -137,7 +137,8 @@ class HloInstruction {
// Creates an outfeed instruction, which outputs data.
static std::unique_ptr<HloInstruction> CreateOutfeed(
- HloInstruction* operand, tensorflow::StringPiece outfeed_config);
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::StringPiece outfeed_config);
// Creates a send instruction with the given channel id, which sends the
// operand data to a unique receive instruction in another computation that
@@ -428,6 +429,10 @@ class HloInstruction {
// Precondition: opcode() == HloOpcode::kOutfeed
const string& outfeed_config() const;
+ // Returns the shape for the Outfeed instruction.
+ // Precondition: opcode() == HloOpcode::kOutfeed
+ const Shape& outfeed_shape() const;
+
// Gets/sets the while_condition or while_body HloComputation for While. The
// setters should only be called by HloModule or HloComputation methods.
//
@@ -727,6 +732,9 @@ class HloInstruction {
// Returns how this instruction uses elements of its `i`th operand.
UseKind OperandElementUse(int64 i) const;
+ // Shape of the data transferred by an outfeed request.
+ Shape outfeed_shape_;
+
// Result shape of this instruction.
Shape shape_;
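A sketch of constructing the HLO directly with the new signature (`operand`
is assumed to be an existing HloInstruction*):

  auto outfeed = HloInstruction::CreateOutfeed(
      ShapeUtil::MakeShapeWithLayout(F32, {3}, {0}), operand,
      /*outfeed_config=*/"");
  // The instruction's own result shape stays nil; outfeed_shape() reports
  // the shape of the transferred data.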
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index a350acc4da..6270960b34 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -298,8 +298,8 @@ string LayoutConstraints::ToString() const {
for (int64 i = 0; i < instruction->operand_count(); ++i) {
if (OperandLayout(instruction, i) != nullptr) {
tensorflow::strings::StrAppend(
- &output, " operand (", i, "): ",
- OperandLayout(instruction, i)->ToString(), "\n");
+ &output, " operand (", i,
+ "): ", OperandLayout(instruction, i)->ToString(), "\n");
}
}
for (const LogicalBuffer* buffer :
@@ -338,6 +338,11 @@ Status LayoutAssignment::AddMandatoryConstraints(
// TODO(b/31425034): Change infeeds to be more like parameters, with
// shapes in the ComputationLayout.
shape_with_layout = &instruction->shape();
+ } else if (instruction->opcode() == HloOpcode::kOutfeed) {
+ // Constrain the operand of the Outfeed instruction to match the expected
+ // layout of the Outfeed.
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ instruction->outfeed_shape(), instruction.get(), 0));
} else if (instruction->opcode() == HloOpcode::kParameter) {
// Parameter layouts must match the respective layout in
// ComputationLayout.
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 5b6a4b1e15..249e31cf26 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -587,9 +587,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
repeated_arguments(backend->Replicas().size(), arguments);
- TF_ASSIGN_OR_RETURN(
- auto results,
- executable->ExecuteOnStreams(run_options, repeated_arguments));
+ TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams(
+ run_options, repeated_arguments));
TF_RET_CHECK(!results.empty());
result = results[0];
}
@@ -927,9 +926,8 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
se::StreamExecutor* stream_executor;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(
- stream_executor,
- execute_backend_->stream_executor(arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
+ arg->device_handle().handle()));
} else {
stream_executor = execute_backend_->default_stream_executor();
}
@@ -948,9 +946,8 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
execute_backend_.get(), stream_executor->device_ordinal(), allocation,
shape, StrCat("TransferToServer literal of size ", allocation_size));
- TF_ASSIGN_OR_RETURN(
- auto replicas,
- execute_backend_->Replicas(stream_executor->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
+ stream_executor->device_ordinal()));
for (se::StreamExecutor* executor : replicas) {
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralToDevice(
@@ -973,9 +970,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
se::StreamExecutor* executor;
if (arg->has_device_handle()) {
- TF_ASSIGN_OR_RETURN(
- auto replicas,
- execute_backend_->Replicas(arg->device_handle().handle()));
+ TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
+ arg->device_handle().handle()));
executor = replicas[arg->replica_id()];
} else {
executor = execute_backend_->Replicas()[arg->replica_id()];
@@ -985,6 +981,30 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
executor, arg->literal());
}
+tensorflow::Status Service::TransferFromOutfeed(
+ const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) {
+ const int64 replica_count = execute_backend_->Replicas().size();
+ if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
+ return FailedPrecondition(
+ "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, "
+ "%lld)",
+ arg->replica_id(), replica_count);
+ }
+
+ se::StreamExecutor* executor;
+ if (arg->has_device_handle()) {
+ TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas(
+ arg->device_handle().handle()));
+ executor = replicas[arg->replica_id()];
+ } else {
+ executor = execute_backend_->Replicas()[arg->replica_id()];
+ }
+
+ return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
+ executor, arg->shape_with_layout(), result->mutable_literal());
+}
+
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) {
return execute_backend_->ResetDevices();
@@ -1151,9 +1171,8 @@ tensorflow::Status Service::GetComputationShape(
VersionedComputationHandle versioned_handle =
computation->GetVersionedHandle();
- TF_ASSIGN_OR_RETURN(
- auto program_shape,
- computation->ComputeProgramShape(versioned_handle.version));
+ TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape(
+ versioned_handle.version));
*result->mutable_program_shape() = *program_shape;
return tensorflow::Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index ba609fe881..ce07489fe0 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -162,6 +162,12 @@ class Service : public ServiceInterface {
const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) override;
+ // Transfers data from the Outfeed of the device to the literal provided by
+ // the client.
+ tensorflow::Status TransferFromOutfeed(
+ const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) override;
+
// Resets devices, clearing all existing state on all the devices associated
// with this service (including memory allocated on the devices).
//
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 7ffce45213..83e893a14a 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -64,6 +64,12 @@ class TransferManager {
perftools::gputools::StreamExecutor* executor,
const Literal& literal) = 0;
+ // Transfers a literal of the given shape from the Outfeed interface of the
+ // device, using the given executor.
+ virtual Status TransferLiteralFromOutfeed(
+ perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
+ Literal* literal) = 0;
+
// Resets the devices associated with this transfer manager.
virtual Status ResetDevices(
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 7fde1945a5..79e44cb67a 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -891,6 +891,14 @@ Status UserComputation::AddOutfeedInstruction(
const OutfeedRequest& outfeed_request) {
tensorflow::mutex_lock lock(mutex_);
+ const Shape& shape = outfeed_request.shape();
+ if (ShapeUtil::IsNestedTuple(shape)) {
+ return InvalidArgument("Outfeed does not support nested tuple shapes");
+ }
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Given shape to Outfeed must have a layout");
+ }
+
// Verify that operand is valid.
TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status());
@@ -900,7 +908,7 @@ Status UserComputation::AddOutfeedInstruction(
OperationRequest& request =
(*session_computation_.mutable_requests())[handle.handle()];
*request.mutable_output_handle() = handle;
- *request.mutable_output_shape() = ShapeUtil::MakeNil();
+ *request.mutable_output_shape() = shape;
*request.mutable_request()->mutable_outfeed_request() = outfeed_request;
VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal()
@@ -1991,9 +1999,9 @@ HloInstruction* ComputationLowerer::Visit(
const OutfeedRequest& outfeed_request =
request.request().outfeed_request();
HloInstruction* operand = Visit(outfeed_request.operand(), visited);
- hlo_instruction =
- hlo_builder_.AddInstruction(HloInstruction::CreateOutfeed(
- operand, outfeed_request.outfeed_config()));
+ hlo_instruction = hlo_builder_.AddInstruction(
+ HloInstruction::CreateOutfeed(outfeed_request.shape(), operand,
+ outfeed_request.outfeed_config()));
break;
}
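Under the new validation, for example, an outfeed of shape
(F32[2], (F32[3], U32[])) is rejected as a nested tuple, and F32[2,3] with no
layout field is rejected for missing a layout, while F32[2,3]{1,0} is
accepted.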
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
index fc107480f7..2159386152 100644
--- a/tensorflow/compiler/xla/service_interface.h
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -41,6 +41,10 @@ class ServiceInterface {
virtual tensorflow::Status TransferToInfeed(
const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0;
+ virtual tensorflow::Status TransferFromOutfeed(
+ const TransferFromOutfeedRequest* arg,
+ TransferFromOutfeedResponse* result) = 0;
+
virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) = 0;
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index b9d82c557b..57f557c458 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -76,7 +76,7 @@ message TransferToClientRequest {
GlobalDataHandle data = 1;
// This optional field directs the service to return the literal in this
- // layout. A shape is used to hold the layout to accomodate tuples.
+ // layout. A shape is used to hold the layout to accommodate tuples.
Shape shape_with_layout = 2;
}
@@ -119,6 +119,19 @@ message TransferToInfeedRequest {
message TransferToInfeedResponse {
}
+message TransferFromOutfeedRequest {
+ // This optional field directs the service to return the literal in this
+ // layout. A shape is used to hold the layout to accommodate tuples.
+ Shape shape_with_layout = 1;
+
+ int64 replica_id = 2;
+ DeviceHandle device_handle = 3;
+}
+
+message TransferFromOutfeedResponse {
+ Literal literal = 1;
+}
+
message ResetDeviceRequest {
DeviceHandle device_handle = 1;
}
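An illustrative TransferFromOutfeedRequest in proto text format (values are
examples only):

  shape_with_layout {
    element_type: F32
    dimensions: 3
    layout { minor_to_major: 0 }
  }
  replica_id: 0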
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 99a9ba3ee0..7786c46f4c 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -386,6 +386,9 @@ message InfeedRequest {
}
message OutfeedRequest {
+ // The shape of the data returned by reading the device's outfeed buffer.
+ Shape shape = 1;
+
// Operand to the Outfeed. Supports tuple.
ComputationDataHandle operand = 2;