aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/c_api.cc1
-rw-r--r--tensorflow/c/c_api.h7
-rw-r--r--tensorflow/python/client/session.py14
-rw-r--r--tensorflow/python/client/session_list_devices_test.py8
-rw-r--r--tensorflow/python/client/session_test.py21
-rw-r--r--tensorflow/python/client/tf_session.i5
6 files changed, 44 insertions, 12 deletions
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 5c218d3f25..a3003953a3 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -963,6 +963,7 @@ TF_DEVICELIST_METHOD(const char*, TF_DeviceListName, name().c_str(), nullptr);
TF_DEVICELIST_METHOD(const char*, TF_DeviceListType, device_type().c_str(),
nullptr);
TF_DEVICELIST_METHOD(int64_t, TF_DeviceListMemoryBytes, memory_limit(), -1);
+TF_DEVICELIST_METHOD(uint64_t, TF_DeviceListIncarnation, incarnation(), 0);
#undef TF_DEVICELIST_METHOD
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 1eb75ef11f..fddc09d45e 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -1521,6 +1521,13 @@ TF_CAPI_EXPORT extern const char* TF_DeviceListType(const TF_DeviceList* list,
TF_CAPI_EXPORT extern int64_t TF_DeviceListMemoryBytes(
const TF_DeviceList* list, int index, TF_Status* status);
+// Retrieve the incarnation number of a given device.
+//
+// If index is out of bounds, an error code will be set in the status object,
+// and 0 will be returned.
+TF_CAPI_EXPORT extern uint64_t TF_DeviceListIncarnation(
+ const TF_DeviceList* list, int index, TF_Status* status);
+
// --------------------------------------------------------------------------
// Load plugins containing custom ops and kernels
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index e037925961..8ede6ab54c 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -540,10 +540,11 @@ class _DeviceAttributes(object):
(in bytes).
"""
- def __init__(self, name, device_type, memory_limit_bytes):
+ def __init__(self, name, device_type, memory_limit_bytes, incarnation):
self._name = device.canonical_name(name)
self._device_type = device_type
self._memory_limit_bytes = memory_limit_bytes
+ self._incarnation = incarnation
@property
def name(self):
@@ -557,11 +558,16 @@ class _DeviceAttributes(object):
def memory_limit_bytes(self):
return self._memory_limit_bytes
+ @property
+ def incarnation(self):
+ return self._incarnation
+
def __repr__(self):
- return '_DeviceAttributes(%s, %s, %d)' % (
+ return '_DeviceAttributes(%s, %s, %d, %d)' % (
self.name,
self.device_type,
self.memory_limit_bytes,
+ self.incarnation,
)
@@ -658,7 +664,9 @@ class BaseSession(SessionInterface):
name = tf_session.TF_DeviceListName(raw_device_list, i)
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
- device_list.append(_DeviceAttributes(name, device_type, memory))
+ incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
+ device_list.append(
+ _DeviceAttributes(name, device_type, memory, incarnation))
tf_session.TF_DeleteDeviceList(raw_device_list)
return device_list
diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py
index c5d82c213a..dd381c689f 100644
--- a/tensorflow/python/client/session_list_devices_test.py
+++ b/tensorflow/python/client/session_list_devices_test.py
@@ -37,6 +37,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
devices = sess.list_devices()
self.assertTrue('/job:localhost/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
+ # All valid device incarnations must be non-zero.
+ self.assertTrue(all(d.incarnation != 0 for d in devices))
def testInvalidDeviceNumber(self):
opts = tf_session.TF_NewSessionOptions()
@@ -54,6 +56,8 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
devices = sess.list_devices()
self.assertTrue('/job:local/replica:0/task:0/device:CPU:0' in set(
[d.name for d in devices]), devices)
+ # All valid device incarnations must be non-zero.
+ self.assertTrue(all(d.incarnation != 0 for d in devices))
def testListDevicesClusterSpecPropagation(self):
server1 = server_lib.Server.create_local_server()
@@ -67,11 +71,13 @@ class SessionListDevicesTest(test_util.TensorFlowTestCase):
config = config_pb2.ConfigProto(cluster_def=cluster_def)
with session.Session(server1.target, config=config) as sess:
devices = sess.list_devices()
- device_names = set([d.name for d in devices])
+ device_names = set(d.name for d in devices)
self.assertTrue(
'/job:worker/replica:0/task:0/device:CPU:0' in device_names)
self.assertTrue(
'/job:worker/replica:0/task:1/device:CPU:0' in device_names)
+ # All valid device incarnations must be non-zero.
+ self.assertTrue(all(d.incarnation != 0 for d in devices))
if __name__ == '__main__':
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index b72e029d1c..052be68385 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -35,6 +35,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as framework_device_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
@@ -104,18 +105,20 @@ class SessionTest(test_util.TensorFlowTestCase):
copy_val)
def testManyCPUs(self):
- # TODO(keveman): Implement ListDevices and test for the number of
- # devices returned by ListDevices.
with session.Session(
config=config_pb2.ConfigProto(device_count={
- 'CPU': 2
- })):
+ 'CPU': 2, 'GPU': 0
+ })) as sess:
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
+ devices = sess.list_devices()
+ self.assertEqual(2, len(devices))
+ for device in devices:
+ self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
+ device.name).device_type)
+
def testPerSessionThreads(self):
- # TODO(keveman): Implement ListDevices and test for the number of
- # devices returned by ListDevices.
with session.Session(
config=config_pb2.ConfigProto(use_per_session_threads=True)):
inp = constant_op.constant(10.0, name='W1')
@@ -1868,19 +1871,21 @@ class SessionTest(test_util.TensorFlowTestCase):
def testDeviceAttributes(self):
attrs = session._DeviceAttributes(
- '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337)
+ '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
self.assertEqual(1337, attrs.memory_limit_bytes)
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
self.assertEqual('TYPE', attrs.device_type)
+ self.assertEqual(1000000, attrs.incarnation)
str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
def testDeviceAttributesCanonicalization(self):
attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
- 'TYPE', 1337)
+ 'TYPE', 1337, 1000000)
self.assertEqual(1337, attrs.memory_limit_bytes)
self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
self.assertEqual('TYPE', attrs.device_type)
+ self.assertEqual(1000000, attrs.incarnation)
str_repr = '%s' % attrs
self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 985cb90436..1cdd8e0b6a 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -138,6 +138,11 @@ tensorflow::ImportNumpy();
$result = PyLong_FromLongLong($1);
}
+// Convert TF_DeviceListIncarnation uint64_t output to Python integer
+%typemap(out) uint64_t {
+ $result = PyLong_FromUnsignedLongLong($1);
+}
+
// We use TF_OperationGetControlInputs_wrapper instead of
// TF_OperationGetControlInputs
%ignore TF_OperationGetControlInputs;