aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/gpu/gpu_device.h
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-08-08 14:06:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-08 15:17:22 -0700
commit2d0d126749d6d0cf82fb86691362c923a1bfbfe4 (patch)
tree9823261d38fe3439326da4fdbef5ffb2c5adbfc2 /tensorflow/core/common_runtime/gpu/gpu_device.h
parent0204fbd5fec268e2b4d4d4e9185e21725a6c248d (diff)
Change DeviceFactory functions that create devices to propagate
Statuses, so that failures to initialize devices don't crash the program. Changes swig for device_lib to be a lot simpler, thanks to mrry@ and keveman@'s help. Change allocation of eigen scratch memory to go through the allocator. Re-enable test for local devices now that python3 issue is fixed. Change: 129678132
Diffstat (limited to 'tensorflow/core/common_runtime/gpu/gpu_device.h')
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h26
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 4ac9c4021d..03090aa537 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -48,6 +48,9 @@ class BaseGPUDevice : public LocalDevice {
~BaseGPUDevice() override;
+ // Initialize the device and return the status of initialization.
+ Status Init(const SessionOptions& options);
+
// GPU devices require the Op Compute method to save a reference to
// any temporary tensors that are allocated until the Op execution
// completes.
@@ -97,6 +100,7 @@ class BaseGPUDevice : public LocalDevice {
mutex trace_mu_;
int gpu_id_ = -1;
const bool sync_every_op_ = false;
+ const int32 max_streams_;
std::unique_ptr<EventMgr> em_;
void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
@@ -105,19 +109,19 @@ class BaseGPUDevice : public LocalDevice {
class BaseGPUDeviceFactory : public DeviceFactory {
public:
- void CreateDevices(const SessionOptions& options, const string& name_prefix,
- std::vector<Device*>* devices) override;
+ Status CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override;
private:
- LocalDevice* CreateGPUDevice(const SessionOptions& options,
- const string& name, int gpu_id);
-
- virtual LocalDevice* CreateGPUDevice(const SessionOptions& options,
- const string& name, Bytes memory_limit,
- BusAdjacency bus_adjacency, int gpu_id,
- const string& physical_device_desc,
- Allocator* gpu_allocator,
- Allocator* cpu_allocator) = 0;
+ Status CreateGPUDevice(const SessionOptions& options, const string& name,
+ int gpu_id, BaseGPUDevice** out_device);
+
+ virtual BaseGPUDevice* CreateGPUDevice(const SessionOptions& options,
+ const string& name, Bytes memory_limit,
+ BusAdjacency bus_adjacency, int gpu_id,
+ const string& physical_device_desc,
+ Allocator* gpu_allocator,
+ Allocator* cpu_allocator) = 0;
void GetValidDeviceIds(std::vector<int>* ids);
};