diff options
Diffstat (limited to 'tensorflow/python/framework/device.py')
-rw-r--r-- | tensorflow/python/framework/device.py | 220 |
1 files changed, 220 insertions, 0 deletions
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py new file mode 100644 index 0000000000..676e5f779a --- /dev/null +++ b/tensorflow/python/framework/device.py @@ -0,0 +1,220 @@ +"""Class to represent a device.""" +import copy + + +class Device(object): + """Represents a Device.""" + + def __init__(self, job=None, replica=None, task=None, device_type=None, + device_index=None): + """Create a new device object. + + Args: + job: string. Optional device job name. + replica: int. Optional replica index. + task: int. Optional task index. + device_type: Optional device type string (e.g. "CPU" or "GPU") + device_index: int. Optional device index. If left + unspecified, device represents 'any' device_index. + """ + self.job = job + self.replica = replica + self.task = task + if device_type == "cpu" or device_type == "gpu": + # For backwards compatibility only, we support lowercase variants of + # cpu and gpu but turn them into uppercase here. + self.device_type = device_type.upper() + else: + self.device_type = device_type + self.device_index = device_index + + def _clear(self): + self._job = None + self._replica = None + self._task = None + self.device_type = None + self.device_index = None + + @property + def job(self): + return self._job + + @job.setter + def job(self, job): + if job is not None: + self._job = str(job) + else: + self._job = None + + @property + def replica(self): + return self._replica + + @replica.setter + def replica(self, replica): + if replica is not None: + self._replica = int(replica) + else: + self._replica = None + + @property + def task(self): + return self._task + + @task.setter + def task(self, task): + if task is not None: + self._task = int(task) + else: + self._task = None + + def parse_from_string(self, spec): + """Parse a Device name into its components. + + Args: + spec: a string of the form + /job:<name>/replica:<id>/task:<id>/device:CPU:<id> + or + /job:<name>/replica:<id>/task:<id>/device:GPU:<id> + as cpu and gpu are mutually exclusive. + All entries are optional. + + Returns: + The Device, for convenience. + + Raises: + ValueError: if the spec was not valid. + """ + self._clear() + splits = [x.split(":") for x in spec.split("/")] + for y in splits: + ly = len(y) + if y: + # NOTE(mdevin): we use the property getters here. + if ly == 2 and y[0] == "job": + self.job = y[1] + elif ly == 2 and y[0] == "replica": + self.replica = y[1] + elif ly == 2 and y[0] == "task": + self.task = y[1] + elif ((ly == 1 or ly == 2) and + ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))): + if self.device_type is not None: + raise ValueError("Cannot specify multiple device types: %s" % spec) + self.device_type = y[0].upper() + if ly == 2 and y[1] != "*": + self.device_index = int(y[1]) + elif ly == 3 and y[0] == "device": + if self.device_type is not None: + raise ValueError("Cannot specify multiple device types: %s" % spec) + self.device_type = y[1] + if y[2] != "*": + self.device_index = int(y[2]) + elif ly and y[0] != "": # pylint: disable=g-explicit-bool-comparison + raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec)) + + return self + + def merge_from(self, dev): + """Merge the properties of "dev" into this Device. + + Args: + dev: a Device. + """ + if dev.job is not None: + self.job = dev.job + if dev.replica is not None: + self.replica = dev.replica + if dev.task is not None: + self.task = dev.task + if dev.device_type is not None: + self.device_type = dev.device_type + if dev.device_index is not None: + self.device_index = dev.device_index + + def to_string(self): + """Return a Device specification string. + + Returns: + a string of the form /job:<name>/replica:<id>/task:<id>/device:cpu:<id> + or /job:<name>/replica:<id>/task:<id>/device:cpu:<id>. + """ + dev = "" + if self.job is not None: + dev += "/job:" + self.job + if self.replica is not None: + dev += "/replica:" + str(self.replica) + if self.task is not None: + dev += "/task:" + str(self.task) + if self.device_type is not None: + device_index_string = "*" + if self.device_index is not None: + device_index_string = str(self.device_index) + dev += "/device:%s:%s" % (self.device_type, device_index_string) + return dev + + +def from_string(spec): + """Construct a Device from a string. + + Args: + spec: a string of the form + /job:<name>/replica:<id>/task:<id>/device:CPU:<id> + or + /job:<name>/replica:<id>/task:<id>/device:GPU:<id> + as cpu and gpu are mutually exclusive. + All entries are optional. + + Returns: + A Device. + """ + return Device().parse_from_string(spec) + + +def check_valid(spec): + """Check that a device spec is valid. + + Args: + spec: a string. + + Raises: + An exception if the spec is invalid. + """ + # Construct a device. It will assert a failure if spec is invalid. + from_string(spec) + + +def merge_device(spec): + """Returns a device function that merges devices specifications. + + This can be used to merge partial specifications of devices. The + innermost setting for a device field takes precedence. For example: + + with tf.Device(MergeDevice("/device:GPU:0")) + # Nodes created here have device "/device:GPU:0" + with tf.Device(MergeDevice("/job:worker")): + # Nodes created here have device "/job:worker/device:GPU:0" + with tf.Device(MergeDevice("/device:CPU:0")): + # Nodes created here have device "/job:worker/device:CPU:0" + with tf.Device(MergeDevice("/job:ps")): + # Nodes created here have device "/job:ps/device:CPU:0" + + Args: + spec: A device or a device spec string (partially) describing the + device that should be used for all nodes created in the scope of + the returned device function's with block. + + Returns: + A device function with the above-described behavior. + + Raises: + ValueError: if the spec was not valid. + """ + if not isinstance(spec, Device): + spec = from_string(spec or "") + def _device_function(node_def): + current_device = from_string(node_def.device or "") + copy_spec = copy.copy(spec) + copy_spec.merge_from(current_device) # current_device takes precedence. + return copy_spec + return _device_function |