aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-06-01 18:42:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-01 19:48:11 -0700
commit1236e3b7cbe8931a1c6d8e1417b30b639e9ad1e5 (patch)
tree369fbbc909bc1f4d77bb108b2244d111204bb4fe
parent71a8205a34971ece61461d494d14e29f7d44ff07 (diff)
Add job_name to VariableDeviceChooser
Change: 123823213
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 5f7b27c5ee..34e6d5411d 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -522,6 +522,7 @@ class VariableDeviceChooser(object):
def __init__(self,
num_tasks=0,
+ job_name='ps',
device_type='CPU',
device_index=0):
"""Initialize VariableDeviceChooser.
@@ -536,22 +537,23 @@ class VariableDeviceChooser(object):
Args:
num_tasks: number of tasks.
+ job_name: String, a name for the parameter server job.
device_type: Optional device type string (e.g. "CPU" or "GPU")
device_index: int. Optional device index. If left
unspecified, device represents 'any' device_index.
"""
- self._job_name = 'ps' if num_tasks > 0 else None
+ self._job_name = job_name
self._device_type = device_type
self._device_index = device_index
self._num_tasks = num_tasks
self._next_task_id = 0
def __call__(self, op):
- device_spec = tf_device.DeviceSpec(job=self._job_name,
- device_type=self._device_type,
+ device_spec = tf_device.DeviceSpec(device_type=self._device_type,
device_index=self._device_index)
if self._num_tasks > 0:
task_id = self._next_task_id
self._next_task_id = (self._next_task_id + 1) % self._num_tasks
+ device_spec.job = self._job_name
device_spec.task = task_id
return device_spec.to_string()