diff options
author | 2016-06-01 18:42:22 -0800 | |
---|---|---|
committer | 2016-06-01 19:48:11 -0700 | |
commit | 1236e3b7cbe8931a1c6d8e1417b30b639e9ad1e5 (patch) | |
tree | 369fbbc909bc1f4d77bb108b2244d111204bb4fe | |
parent | 71a8205a34971ece61461d494d14e29f7d44ff07 (diff) |
Add job_name to VariableDeviceChooser
Change: 123823213
-rw-r--r-- | tensorflow/contrib/framework/python/ops/variables.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 5f7b27c5ee..34e6d5411d 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -522,6 +522,7 @@ class VariableDeviceChooser(object): def __init__(self, num_tasks=0, + job_name='ps', device_type='CPU', device_index=0): """Initialize VariableDeviceChooser. @@ -536,22 +537,23 @@ class VariableDeviceChooser(object): Args: num_tasks: number of tasks. + job_name: String, a name for the parameter server job. device_type: Optional device type string (e.g. "CPU" or "GPU") device_index: int. Optional device index. If left unspecified, device represents 'any' device_index. """ - self._job_name = 'ps' if num_tasks > 0 else None + self._job_name = job_name self._device_type = device_type self._device_index = device_index self._num_tasks = num_tasks self._next_task_id = 0 def __call__(self, op): - device_spec = tf_device.DeviceSpec(job=self._job_name, - device_type=self._device_type, + device_spec = tf_device.DeviceSpec(device_type=self._device_type, device_index=self._device_index) if self._num_tasks > 0: task_id = self._next_task_id self._next_task_id = (self._next_task_id + 1) % self._num_tasks + device_spec.job = self._job_name device_spec.task = task_id return device_spec.to_string() |