diff options
-rw-r--r-- | tensorflow/compiler/xla/client/client.cc | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/client.h | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/computation_builder.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/computation_builder.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/generic_transfer_manager.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/generic_transfer_manager.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/layout_assignment.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.cc | 49 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/transfer_manager.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/user_computation.cc | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service_interface.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/xla.proto | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/xla_data.proto | 3 |
16 files changed, 160 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 341c02f1a1..c4430dab65 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -132,6 +132,38 @@ Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, return Status::OK(); } +StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed( + const Shape* shape_with_layout, int64 replica_id, + const DeviceHandle* device_handle) { + TransferFromOutfeedRequest request; + if (device_handle) { + *request.mutable_device_handle() = *device_handle; + } + request.set_replica_id(replica_id); + if (shape_with_layout != nullptr) { + *request.mutable_shape_with_layout() = *shape_with_layout; + } + TransferFromOutfeedResponse response; + + VLOG(1) << "making transfer from outfeed request"; + VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}"; + Status s = stub_->TransferFromOutfeed(&request, &response); + VLOG(1) << "done with request"; + + if (!s.ok()) { + return s; + } + VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}"; + + if (!response.has_literal()) { + return FailedPrecondition( + "server provided response without a literal in " + "TransferToClient request"); + } + + return WrapUnique(response.release_literal()); +} + Status Client::ResetDevice() { ResetDeviceRequest request; ResetDeviceResponse response; diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index f261de9d0d..ea166acc91 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -119,6 +119,15 @@ class Client { Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); + // Transfers from the Outfeed of the device. + // + // device_handle and replica_id together specify a particular device; a device + // assigned for the given replica_id among the replicas that the given device + // handle belongs to. + StatusOr<std::unique_ptr<Literal>> TransferFromOutfeed( + const Shape* shape_with_layout, int64 replica_id = 0, + const DeviceHandle* device_handle = nullptr); + // Resets the device, clearing all existing state on the device. Status ResetDevice(); diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index c4c91b7ea8..ae7695ade5 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -786,6 +786,7 @@ ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, } void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, + const Shape& shape, const string& outfeed_config) { if (!first_error_.ok() || !PrepareComputation().ok()) { return; @@ -794,6 +795,7 @@ void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, OutfeedRequest request; request.set_outfeed_config(outfeed_config); *request.mutable_operand() = operand; + *request.mutable_shape() = shape; OpRequest op_request; *op_request.mutable_outfeed_request() = request; *op_request.mutable_computation() = computation_.handle(); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index b1a68e3687..a49e5a8843 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -352,13 +352,13 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice<int64> rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); - // Enqueues an infeed instruction onto the computation, which reads data of - // the given shape from the infeed buffer of the device. + // Enqueues an infeed instruction onto the computation, which writes data of + // the given shape to the infeed buffer of the device. ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. - void Outfeed(const ComputationDataHandle& operand, + void Outfeed(const ComputationDataHandle& operand, const Shape& shape, const string& outfeed_config); // Enqueues a call instruction onto the computation. diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 8f39ba8b1d..aa512f242a 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -162,6 +162,12 @@ Status GenericTransferManager::TransferLiteralToInfeed( return Unimplemented("Infeed is not supported on GPU (b/30467474)"); } +Status GenericTransferManager::TransferLiteralFromOutfeed( + perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) { + return Unimplemented("Outfeed is not supported on CPU/GPU (b/30467474)"); +} + Status GenericTransferManager::ResetDevices( tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors) { diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 06819d65c7..2fbdb94f06 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -55,6 +55,10 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferLiteralFromOutfeed( + perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) override; + Status ResetDevices( tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors) override; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b5438865cb..9951f59911 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -236,11 +236,13 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed( - HloInstruction* operand, tensorflow::StringPiece outfeed_config) { + const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config) { std::unique_ptr<HloInstruction> instruction = WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); instruction->AppendOperand(operand); instruction->outfeed_config_ = outfeed_config.ToString(); + instruction->outfeed_shape_ = shape; return instruction; } @@ -1852,6 +1854,12 @@ Status HloInstruction::AcceptOrdered( return visitor->FinishVisit(this); } +const Shape& HloInstruction::outfeed_shape() const { + DCHECK_EQ(opcode_, HloOpcode::kOutfeed); + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); + return outfeed_shape_; +} + const Shape& HloInstruction::shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); return shape_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index bafb402e9d..fbde4681af 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -137,7 +137,8 @@ class HloInstruction { // Creates an outfeed instruction, which outputs data. static std::unique_ptr<HloInstruction> CreateOutfeed( - HloInstruction* operand, tensorflow::StringPiece outfeed_config); + const Shape& shape, HloInstruction* operand, + tensorflow::StringPiece outfeed_config); // Creates a send instruction with the given channel id, which sends the // operand data to a unique receive instruction in another computation that @@ -428,6 +429,10 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kOutfeed const string& outfeed_config() const; + // Returns the shape for the Outfeed instruction. + // Precondition: opcode() == HloOpcode::kOutfeed + const Shape& outfeed_shape() const; + // Gets/sets the while_condition or while_body HloComputation for While. The // setters should only be called by HloModule or HloComputation methods. // @@ -727,6 +732,9 @@ class HloInstruction { // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Shape of outfeed request. + Shape outfeed_shape_; + // Result shape of this instruction. Shape shape_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index a350acc4da..6270960b34 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -298,8 +298,8 @@ string LayoutConstraints::ToString() const { for (int64 i = 0; i < instruction->operand_count(); ++i) { if (OperandLayout(instruction, i) != nullptr) { tensorflow::strings::StrAppend( - &output, " operand (", i, "): ", - OperandLayout(instruction, i)->ToString(), "\n"); + &output, " operand (", i, + "): ", OperandLayout(instruction, i)->ToString(), "\n"); } } for (const LogicalBuffer* buffer : @@ -338,6 +338,11 @@ Status LayoutAssignment::AddMandatoryConstraints( // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. shape_with_layout = &instruction->shape(); + } else if (instruction->opcode() == HloOpcode::kOutfeed) { + // Constrain the input to the Outfeed instruction to be the expected + // layout of the Outfeed. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + instruction->outfeed_shape(), instruction.get(), 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in // ComputationLayout. diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 5b6a4b1e15..249e31cf26 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -587,9 +587,8 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult( tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>> repeated_arguments(backend->Replicas().size(), arguments); - TF_ASSIGN_OR_RETURN( - auto results, - executable->ExecuteOnStreams(run_options, repeated_arguments)); + TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( + run_options, repeated_arguments)); TF_RET_CHECK(!results.empty()); result = results[0]; } @@ -927,9 +926,8 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, se::StreamExecutor* stream_executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN( - stream_executor, - execute_backend_->stream_executor(arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor( + arg->device_handle().handle())); } else { stream_executor = execute_backend_->default_stream_executor(); } @@ -948,9 +946,8 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, execute_backend_.get(), stream_executor->device_ordinal(), allocation, shape, StrCat("TransferToServer literal of size ", allocation_size)); - TF_ASSIGN_OR_RETURN( - auto replicas, - execute_backend_->Replicas(stream_executor->device_ordinal())); + TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( + stream_executor->device_ordinal())); for (se::StreamExecutor* executor : replicas) { TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( @@ -973,9 +970,8 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, se::StreamExecutor* executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN( - auto replicas, - execute_backend_->Replicas(arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( + arg->device_handle().handle())); executor = replicas[arg->replica_id()]; } else { executor = execute_backend_->Replicas()[arg->replica_id()]; @@ -985,6 +981,30 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor, arg->literal()); } +tensorflow::Status Service::TransferFromOutfeed( + const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) { + const int64 replica_count = execute_backend_->Replicas().size(); + if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { + return FailedPrecondition( + "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " + "%lld)", + arg->replica_id(), replica_count); + } + + se::StreamExecutor* executor; + if (arg->has_device_handle()) { + TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( + arg->device_handle().handle())); + executor = replicas[arg->replica_id()]; + } else { + executor = execute_backend_->Replicas()[arg->replica_id()]; + } + + return execute_backend_->transfer_manager()->TransferLiteralFromOutfeed( + executor, arg->shape_with_layout(), result->mutable_literal()); +} + tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) { return execute_backend_->ResetDevices(); @@ -1151,9 +1171,8 @@ tensorflow::Status Service::GetComputationShape( VersionedComputationHandle versioned_handle = computation->GetVersionedHandle(); - TF_ASSIGN_OR_RETURN( - auto program_shape, - computation->ComputeProgramShape(versioned_handle.version)); + TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( + versioned_handle.version)); *result->mutable_program_shape() = *program_shape; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index ba609fe881..ce07489fe0 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -162,6 +162,12 @@ class Service : public ServiceInterface { const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) override; + // Transfers data from the Outfeed othe device to the literal provided by the + // client. + tensorflow::Status TransferFromOutfeed( + const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) override; + // Resets devices, clearing all existing state on all the devices associated // with this service (including memory allocated on the devices). // diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 7ffce45213..83e893a14a 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -64,6 +64,12 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; + // Transfers the given literal from the Outfeed interface of the device, + // using the given executor. + virtual Status TransferLiteralFromOutfeed( + perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, + Literal* literal) = 0; + // Resets the devices associated with this transfer manager. virtual Status ResetDevices( tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 7fde1945a5..79e44cb67a 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -891,6 +891,14 @@ Status UserComputation::AddOutfeedInstruction( const OutfeedRequest& outfeed_request) { tensorflow::mutex_lock lock(mutex_); + const Shape& shape = outfeed_request.shape(); + if (ShapeUtil::IsNestedTuple(shape)) { + return InvalidArgument("Outfeed does not support nested tuple shapes"); + } + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Given shape to Outfeed must have a layout"); + } + // Verify that operand is valid. TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); @@ -900,7 +908,7 @@ Status UserComputation::AddOutfeedInstruction( OperationRequest& request = (*session_computation_.mutable_requests())[handle.handle()]; *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); + *request.mutable_output_shape() = shape; *request.mutable_request()->mutable_outfeed_request() = outfeed_request; VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() @@ -1991,9 +1999,9 @@ HloInstruction* ComputationLowerer::Visit( const OutfeedRequest& outfeed_request = request.request().outfeed_request(); HloInstruction* operand = Visit(outfeed_request.operand(), visited); - hlo_instruction = - hlo_builder_.AddInstruction(HloInstruction::CreateOutfeed( - operand, outfeed_request.outfeed_config())); + hlo_instruction = hlo_builder_.AddInstruction( + HloInstruction::CreateOutfeed(outfeed_request.shape(), operand, + outfeed_request.outfeed_config())); break; } diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index fc107480f7..2159386152 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -41,6 +41,10 @@ class ServiceInterface { virtual tensorflow::Status TransferToInfeed( const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0; + virtual tensorflow::Status TransferFromOutfeed( + const TransferFromOutfeedRequest* arg, + TransferFromOutfeedResponse* result) = 0; + virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index b9d82c557b..57f557c458 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -76,7 +76,7 @@ message TransferToClientRequest { GlobalDataHandle data = 1; // This optional field directs the service to return the literal in this - // layout. A shape is used to hold the layout to accomodate tuples. + // layout. A shape is used to hold the layout to accommodate tuples. Shape shape_with_layout = 2; } @@ -119,6 +119,19 @@ message TransferToInfeedRequest { message TransferToInfeedResponse { } +message TransferFromOutfeedRequest { + // This optional field directs the service to return the literal in this + // layout. A shape is used to hold the layout to accommodate tuples. + Shape shape_with_layout = 1; + + int64 replica_id = 2; + DeviceHandle device_handle = 3; +} + +message TransferFromOutfeedResponse { + Literal literal = 1; +} + message ResetDeviceRequest { DeviceHandle device_handle = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 99a9ba3ee0..7786c46f4c 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -386,6 +386,9 @@ message InfeedRequest { } message OutfeedRequest { + // The shape of the data returned by reading the device's outfeed buffer. + Shape shape = 1; + // Operand to the Outfeed. Supports tuple. ComputationDataHandle operand = 2; |