diff options
author | 2017-01-11 15:32:45 -0800 | |
---|---|---|
committer | 2017-01-11 16:01:34 -0800 | |
commit | 99e1b19ceba32b8354dddc2841b81864c9ba96bb (patch) | |
tree | bc94a344d7a67629f521aa0922ca7e9da9a9c869 | |
parent | ca5ba46fb51f25eef39c4b4bfdf0abb75704dba6 (diff) |
Clarify ResetDevice operates on all devices associated with backend.
Change: 144258290
-rw-r--r-- | tensorflow/compiler/xla/service/backend.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/backend.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/generic_transfer_manager.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/generic_transfer_manager.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/service.h | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/transfer_manager.h | 7 |
7 files changed, 28 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 6e76c98c9f..7452a7b696 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -234,4 +234,8 @@ StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a, executor_b->GetDeviceDescription().name()); } +Status Backend::ResetDevices() { + return transfer_manager_->ResetDevices(stream_executors_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 17c53d299e..db482c09ae 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -149,6 +149,9 @@ class Backend { // used for scheduling work. For other platforms, returns NULL. const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const; + // Resets the devices associated with this backend. + Status ResetDevices(); + private: struct EigenThreadPoolWrapper; Backend(int64 replica_count, perftools::gputools::Platform* platform, diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 086306696d..1a6a144bd6 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -160,7 +160,9 @@ Status GenericTransferManager::TransferLiteralToInfeed( return Unimplemented("Infeed is not supported on GPU (b/30467474)"); } -Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) { +Status GenericTransferManager::ResetDevices( + tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> + executors) { return Unimplemented( "Device reset is not yet supported on CPU and GPU (b/30481585)"); } diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index cfa02bf22f..06819d65c7 100644 --- 
a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -55,7 +55,9 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; - Status ResetDevice(perftools::gputools::StreamExecutor* executor) override; + Status ResetDevices( + tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> + executors) override; StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>> ShallowCopyTupleFromDevice( diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 847aea7888..0b3900b3b2 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1019,16 +1019,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) { - int first_device_ordinal = arg->has_device_handle() - ? arg->device_handle().handle() - : execute_backend_->default_device_ordinal(); - TF_ASSIGN_OR_RETURN(auto executors, - execute_backend_->Replicas(first_device_ordinal)); - for (se::StreamExecutor* executor : executors) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->ResetDevice(executor)); - } - return tensorflow::Status::OK(); + return execute_backend_->ResetDevices(); } tensorflow::Status Service::TransferToClientInProcess( diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 1141e99fe3..e8ad61c9d0 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -162,7 +162,15 @@ class Service : public ServiceInterface { const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) override; - // Resets the device, clearing all existing state on the device. 
+ // Resets devices, clearing all existing state on all the devices associated + // with this service (including memory allocated on the devices). + // + // ResetDevice may only be called where no previous Execution state on the + // device is used by the next Execution. + // + // ResetDevice should be called before an Execution that expects the device to + // be in the reset state. For example, if the prior Execution modifies device + // state (e.g., architectural state) that the next Execution depends on. tensorflow::Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 90dc921b7d..7ffce45213 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -63,8 +64,10 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; - // Resets the device that the given executor runs on. - virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0; + // Resets the devices associated with this transfer manager. + virtual Status ResetDevices( + tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> + executor) = 0; // Shallow copy a tuple from the device and create a DeviceMemoryBase object // for each element in the tuple. A DeviceMemoryBase object refers to the |