aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2017-01-11 15:32:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-11 16:01:34 -0800
commit99e1b19ceba32b8354dddc2841b81864c9ba96bb (patch)
treebc94a344d7a67629f521aa0922ca7e9da9a9c869
parentca5ba46fb51f25eef39c4b4bfdf0abb75704dba6 (diff)
Clarify ResetDevice operates on all devices associated with backend.
Change: 144258290
-rw-r--r--tensorflow/compiler/xla/service/backend.cc4
-rw-r--r--tensorflow/compiler/xla/service/backend.h3
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc4
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h4
-rw-r--r--tensorflow/compiler/xla/service/service.cc11
-rw-r--r--tensorflow/compiler/xla/service/service.h10
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h7
7 files changed, 28 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 6e76c98c9f..7452a7b696 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -234,4 +234,8 @@ StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
executor_b->GetDeviceDescription().name());
}
+Status Backend::ResetDevices() {
+ return transfer_manager_->ResetDevices(stream_executors_);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index 17c53d299e..db482c09ae 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -149,6 +149,9 @@ class Backend {
// used for scheduling work. For other platforms, returns NULL.
const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
+ // Resets the devices associated with this backend.
+ Status ResetDevices();
+
private:
struct EigenThreadPoolWrapper;
Backend(int64 replica_count, perftools::gputools::Platform* platform,
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 086306696d..1a6a144bd6 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -160,7 +160,9 @@ Status GenericTransferManager::TransferLiteralToInfeed(
return Unimplemented("Infeed is not supported on GPU (b/30467474)");
}
-Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) {
+Status GenericTransferManager::ResetDevices(
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ executors) {
return Unimplemented(
"Device reset is not yet supported on CPU and GPU (b/30481585)");
}
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index cfa02bf22f..06819d65c7 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -55,7 +55,9 @@ class GenericTransferManager : public TransferManager {
Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
const Literal& literal) override;
- Status ResetDevice(perftools::gputools::StreamExecutor* executor) override;
+ Status ResetDevices(
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ executors) override;
StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
ShallowCopyTupleFromDevice(
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 847aea7888..0b3900b3b2 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1019,16 +1019,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) {
- int first_device_ordinal = arg->has_device_handle()
- ? arg->device_handle().handle()
- : execute_backend_->default_device_ordinal();
- TF_ASSIGN_OR_RETURN(auto executors,
- execute_backend_->Replicas(first_device_ordinal));
- for (se::StreamExecutor* executor : executors) {
- TF_RETURN_IF_ERROR(
- execute_backend_->transfer_manager()->ResetDevice(executor));
- }
- return tensorflow::Status::OK();
+ return execute_backend_->ResetDevices();
}
tensorflow::Status Service::TransferToClientInProcess(
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 1141e99fe3..e8ad61c9d0 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -162,7 +162,15 @@ class Service : public ServiceInterface {
const TransferToInfeedRequest* arg,
TransferToInfeedResponse* result) override;
- // Resets the device, clearing all existing state on the device.
+ // Resets devices, clearing all existing state on all the devices associated
+ // with this service (including memory allocated on the devices).
+ //
+ // ResetDevice may only be called where no previous Execution state on the
+ // device is used by the next Execution.
+ //
+ // ResetDevice should be called before an Execution that expect the device to
+ // be in the reset state. For example, if the prior Execution modifies device
+ // state (e.g., architectural state) that the next Execution depends on.
tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
ResetDeviceResponse* result) override;
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 90dc921b7d..7ffce45213 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -63,8 +64,10 @@ class TransferManager {
perftools::gputools::StreamExecutor* executor,
const Literal& literal) = 0;
- // Resets the device that the given executor runs on.
- virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0;
+ // Resets the devices associated with this transfer manager.
+ virtual Status ResetDevices(
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ executor) = 0;
// Shallow copy a tuple from the device and create a DeviceMemoryBase object
// for each element in the tuple. A DeviceMemoryBase object refers to the