aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/client/session.py')
-rw-r--r--tensorflow/python/client/session.py19
1 files changed, 14 insertions, 5 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index f3b788f931..180bb74d00 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import functools
import re
import threading
@@ -243,7 +244,7 @@ class _FetchMapper(object):
elif isinstance(fetch, (list, tuple)):
# NOTE(touts): This is also the code path for namedtuples.
return _ListFetchMapper(fetch)
- elif isinstance(fetch, dict):
+ elif isinstance(fetch, collections.Mapping):
return _DictFetchMapper(fetch)
else:
# Look for a handler in the registered expansions.
@@ -361,7 +362,7 @@ class _ListFetchMapper(_FetchMapper):
for m, vi in zip(self._mappers, self._value_indices):
results.append(m.build_results([values[j] for j in vi]))
# Return a value of the original type of the fetches.
- if self._fetch_type == list:
+ if issubclass(self._fetch_type, list):
return results
elif self._fetch_type == tuple:
return tuple(results)
@@ -540,10 +541,11 @@ class _DeviceAttributes(object):
(in bytes).
"""
- def __init__(self, name, device_type, memory_limit_bytes):
+ def __init__(self, name, device_type, memory_limit_bytes, incarnation):
self._name = device.canonical_name(name)
self._device_type = device_type
self._memory_limit_bytes = memory_limit_bytes
+ self._incarnation = incarnation
@property
def name(self):
@@ -557,11 +559,16 @@ class _DeviceAttributes(object):
def memory_limit_bytes(self):
return self._memory_limit_bytes
+ @property
+ def incarnation(self):
+ return self._incarnation
+
def __repr__(self):
- return '_DeviceAttributes(%s, %s, %d)' % (
+ return '_DeviceAttributes(%s, %s, %d, %d)' % (
self.name,
self.device_type,
self.memory_limit_bytes,
+ self.incarnation,
)
@@ -658,7 +665,9 @@ class BaseSession(SessionInterface):
name = tf_session.TF_DeviceListName(raw_device_list, i)
device_type = tf_session.TF_DeviceListType(raw_device_list, i)
memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
- device_list.append(_DeviceAttributes(name, device_type, memory))
+ incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
+ device_list.append(
+ _DeviceAttributes(name, device_type, memory, incarnation))
tf_session.TF_DeleteDeviceList(raw_device_list)
return device_list