aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/device.py
blob: 676e5f779ac196647c004246d126d82668322efb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Class to represent a device."""
import copy


class Device(object):
  """Represents a (possibly partially specified) device.

  A device has five optional components: job, replica, task, device type
  and device index.  A component that is None is unspecified.  `job`,
  `replica` and `task` are properties whose setters coerce non-None values
  to str, int and int respectively; `device_type` and `device_index` are
  plain attributes.
  """

  def __init__(self, job=None, replica=None, task=None, device_type=None,
               device_index=None):
    """Create a new device object.

    Args:
      job: string.  Optional device job name.
      replica: int.  Optional replica index.
      task: int.  Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU").
        The exact lowercase strings "cpu" and "gpu" are converted to
        uppercase; any other value is stored unchanged.
      device_index: int.  Optional device index.  If left
        unspecified, device represents 'any' device_index.
    """
    self.job = job
    self.replica = replica
    self.task = task
    if device_type == "cpu" or device_type == "gpu":
      # For backwards compatibility only, we support lowercase variants of
      # cpu and gpu but turn them into uppercase here.
      self.device_type = device_type.upper()
    else:
      self.device_type = device_type
    self.device_index = device_index

  def _clear(self):
    """Reset every component of this device to None (unspecified)."""
    self._job = None
    self._replica = None
    self._task = None
    self.device_type = None
    self.device_index = None

  @property
  def job(self):
    """The job name as a str, or None if unspecified."""
    return self._job

  @job.setter
  def job(self, job):
    self._job = None if job is None else str(job)

  @property
  def replica(self):
    """The replica index as an int, or None if unspecified."""
    return self._replica

  @replica.setter
  def replica(self, replica):
    self._replica = None if replica is None else int(replica)

  @property
  def task(self):
    """The task index as an int, or None if unspecified."""
    return self._task

  @task.setter
  def task(self, task):
    self._task = None if task is None else int(task)

  def parse_from_string(self, spec):
    """Parse a Device name into its components.

    On success every component of this Device is replaced by what `spec`
    contains (components absent from `spec` become None).  If `spec` is
    invalid, ValueError is raised and this Device is left unmodified.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
      as cpu and gpu are mutually exclusive.
      All entries are optional.  The shorthand components /cpu:<id> and
      /gpu:<id> (any letter case, index optional) are also accepted and
      stored as type "CPU" / "GPU".  A device index of "*" leaves the
      index unspecified.  In the /device:<type>:<id> form the type is
      stored exactly as written.

    Returns:
      The Device, for convenience.

    Raises:
      ValueError: if the spec was not valid.
    """
    # Parse into locals first so that a failure part-way through the spec
    # cannot leave this object half-updated.
    job = None
    replica = None
    task = None
    device_type = None
    device_index = None

    for component in spec.split("/"):
      fields = component.split(":")
      count = len(fields)
      key = fields[0]
      if count == 2 and key == "job":
        job = fields[1]
      elif count == 2 and key == "replica":
        replica = int(fields[1])
      elif count == 2 and key == "task":
        task = int(fields[1])
      elif (count == 1 or count == 2) and key.upper() in ("GPU", "CPU"):
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        device_type = key.upper()
        if count == 2 and fields[1] != "*":
          device_index = int(fields[1])
      elif count == 3 and key == "device":
        if device_type is not None:
          raise ValueError("Cannot specify multiple device types: %s" % spec)
        device_type = fields[1]
        if fields[2] != "*":
          device_index = int(fields[2])
      elif key != "":
        # Components with an empty key (e.g. the text before a leading "/")
        # are skipped; anything else unrecognized is an error.
        raise ValueError("Unknown attribute: '%s' in '%s'" % (key, spec))

    # The whole spec is valid: commit.  job/replica/task go through the
    # property setters.
    self.job = job
    self.replica = replica
    self.task = task
    self.device_type = device_type
    self.device_index = device_index
    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this Device.

    Every component of `dev` that is not None overwrites the corresponding
    component of this Device; components that are None in `dev` leave this
    Device's value untouched.

    Args:
      dev: a Device.
    """
    for name in ("job", "replica", "task", "device_type", "device_index"):
      value = getattr(dev, name)
      if value is not None:
        setattr(self, name, value)

  def to_string(self):
    """Return a Device specification string.

    Unspecified components are omitted.  When the device type is set but
    the device index is not, the index is rendered as "*".  When the device
    type is not set, no device component is emitted.

    Returns:
      a string of the form /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
      or /job:<name>/replica:<id>/task:<id>/device:GPU:<id>.
    """
    parts = []
    if self.job is not None:
      parts.append("/job:" + self.job)
    if self.replica is not None:
      parts.append("/replica:" + str(self.replica))
    if self.task is not None:
      parts.append("/task:" + str(self.task))
    if self.device_type is not None:
      if self.device_index is None:
        index_string = "*"
      else:
        index_string = str(self.device_index)
      parts.append("/device:%s:%s" % (self.device_type, index_string))
    return "".join(parts)


def from_string(spec):
  """Build a new Device by parsing a device specification string.

  Args:
    spec: a string of the form
     /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
    or
     /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
    (a device is either a CPU or a GPU, never both).
    Every entry may be omitted.

  Returns:
    A freshly constructed Device populated from `spec`.
  """
  device = Device()
  # parse_from_string returns the Device it was called on.
  return device.parse_from_string(spec)


def check_valid(spec):
  """Verify that a device spec string can be parsed.

  Nothing is returned; the check passes if no exception is raised.

  Args:
    spec: a string.

  Raises:
    ValueError: if the spec is invalid.
  """
  # Parsing into a throwaway Device is the validation: an invalid spec
  # makes the parser raise.
  Device().parse_from_string(spec)

def merge_device(spec):
  """Returns a device function that merges device specifications.

  The returned function lets partial device specifications be combined.
  When scopes are nested, the innermost setting for each device field
  wins.  For example:

    with tf.device(merge_device("/device:GPU:0")):
      # Nodes created here have device "/device:GPU:0"
      with tf.device(merge_device("/job:worker")):
        # Nodes created here have device "/job:worker/device:GPU:0"
        with tf.device(merge_device("/device:CPU:0")):
          # Nodes created here have device "/job:worker/device:CPU:0"
          with tf.device(merge_device("/job:ps")):
            # Nodes created here have device "/job:ps/device:CPU:0"

  Args:
    spec: A Device, or a device spec string, (partially) describing the
      device to use for all nodes created inside the with block of the
      returned device function.  A falsy non-Device value (e.g. None or
      "") is treated as the empty spec.

  Returns:
    A function taking a node_def and returning a Device: a copy of `spec`
    overlaid with every field already set in `node_def.device`.

  Raises:
    ValueError: if the spec was not valid.
  """
  if isinstance(spec, Device):
    base_device = spec
  else:
    base_device = from_string(spec or "")

  def _device_function(node_def):
    # Fields the node already carries override those of the base spec.
    node_device = from_string(node_def.device or "")
    # Merge into a copy so the captured base spec is never mutated.
    merged = copy.copy(base_device)
    merged.merge_from(node_device)
    return merged

  return _device_function