aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-23 20:18:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-25 04:54:55 -0700
commit418ae5ed77f1353c794f93a4adfbf7db02fa3191 (patch)
tree51ec953fa53da1f197296d31d019ae25ae24d63f
parent759e9f874eb0af7902a586e0efcaf53463816c23 (diff)
A couple of small device-related utilities.
PiperOrigin-RevId: 190312148
-rw-r--r--tensorflow/python/BUILD3
-rw-r--r--tensorflow/python/training/device_util.py68
2 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index acfdcd15f7..e6ad564ede 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2884,9 +2884,11 @@ py_library(
":client",
":control_flow_ops",
":data_flow_ops",
+ ":device",
":errors",
":framework",
":framework_for_generated_wrappers",
+ ":framework_ops",
":gradients",
":init_ops",
":io_ops",
@@ -2911,6 +2913,7 @@ py_library(
":variable_scope",
":variables",
"//tensorflow/python/eager:backprop",
+ "//tensorflow/python/eager:context",
"//third_party/py/numpy",
"@six_archive//:six",
],
diff --git a/tensorflow/python/training/device_util.py b/tensorflow/python/training/device_util.py
new file mode 100644
index 0000000000..f1137e80ab
--- /dev/null
+++ b/tensorflow/python/training/device_util.py
@@ -0,0 +1,68 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Device-related support functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import device as tf_device
+from tensorflow.python.framework import ops
+
+
+def canonicalize(d):
+ d = tf_device.DeviceSpec.from_string(d)
+ assert d.device_type is None or d.device_type == d.device_type.upper(), (
+ "Device type '%s' must be all-caps." % (d.device_type,))
+ # Fill in missing device fields using defaults.
+ result = tf_device.DeviceSpec(
+ job="localhost", replica=0, task=0, device_type="CPU", device_index=0)
+ result.merge_from(d)
+ return result.to_string()
+
+
+class _FakeNodeDef(object):
+ """A fake NodeDef for _FakeOperation."""
+
+ def __init__(self):
+ self.op = ""
+ self.name = ""
+
+
+class _FakeOperation(object):
+ """A fake Operation object to pass to device functions."""
+
+ def __init__(self):
+ self.device = ""
+ self.type = ""
+ self.name = ""
+ self.node_def = _FakeNodeDef()
+
+ def _set_device(self, device):
+ self.device = ops._device_string(device) # pylint: disable=protected-access
+
+
+def current():
+ """Return a string (not canonicalized) for the current device."""
+ # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
+ ctx = context.context()
+ if ctx.executing_eagerly():
+ d = ctx.device_name
+ else:
+ op = _FakeOperation()
+ ops.get_default_graph()._apply_device_functions(op) # pylint: disable=protected-access
+ d = op.device
+ return d