aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-21 03:54:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 03:59:07 -0700
commit200b89761a4665e3de6d0efc4e3e10ab287ad81b (patch)
tree01f5e4d091c227c7de8c813e55f5ed3e5f63bf85 /tensorflow/python/client
parentd1e9a1ed54cae9b0b10ab89c06d6d7f9b53af3a1 (diff)
Added fetch support for attrs classes.
Given a class @attr.s() class SampleAttr(object): field_1 = attr.ib() field_2 = attr.ib() we will be able to run obj = SampleAttr(tensor_1, tensor_2) session.run(obj) # equivalent with session.run([obj.field_1, obj.field_2]) Please note, this does not need nest flatten support (which is only relevant to the feed_dict argument). Also, the information in __attrs_attrs__ is provided for extensions (as per the docs: http://www.attrs.org/en/stable/extending.html#extending-metadata) like this and is not an "implementation detail". PiperOrigin-RevId: 213963978
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session.py46
-rw-r--r--tensorflow/python/client/session_test.py82
2 files changed, 126 insertions, 2 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index ae0ad27f15..c963cfd334 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -178,16 +178,30 @@ def register_session_run_conversion_functions(
feed_function_for_partial_run: A callable for specifying tensor values to
feed when setting up a partial run, which takes a `tensor_type` type
object as input, and returns a list of Tensors.
+
+ Raises:
+ ValueError: If `tensor_type` has already been registered.
"""
for conversion_function in _REGISTERED_EXPANSIONS:
if issubclass(conversion_function[0], tensor_type):
- raise ValueError('%s has already been registered so ignore it.',
+ raise ValueError('%s has already been registered so ignore it.' %
tensor_type)
- return
+
_REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function,
feed_function_for_partial_run))
+def _is_attrs_instance(obj):
+ """Returns True if the given obj is an instance of attrs-decorated class."""
+ return getattr(obj.__class__, '__attrs_attrs__', None) is not None
+
+
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, '__attrs_attrs__')
+ return [getattr(obj, a.name) for a in attrs]
+
+
class _FetchMapper(object):
"""Definition of the interface provided by fetch mappers.
@@ -247,6 +261,8 @@ class _FetchMapper(object):
return _ListFetchMapper(fetch)
elif isinstance(fetch, collections.Mapping):
return _DictFetchMapper(fetch)
+ elif _is_attrs_instance(fetch):
+ return _AttrsFetchMapper(fetch)
else:
# Look for a handler in the registered expansions.
for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS:
@@ -398,6 +414,32 @@ class _DictFetchMapper(_FetchMapper):
return results
+class _AttrsFetchMapper(_FetchMapper):
+ """Fetch mapper for attrs decorated classes."""
+
+ def __init__(self, fetches):
+ """Creates a _AttrsFetchMapper.
+
+ Args:
+ fetches: An instance of an attrs decorated class.
+ """
+ values = _get_attrs_values(fetches)
+ self._fetch_type = type(fetches)
+ self._mappers = [
+ _FetchMapper.for_fetch(fetch) for fetch in values
+ ]
+ self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)
+
+ def unique_fetches(self):
+ return self._unique_fetches
+
+ def build_results(self, values):
+ results = []
+ for m, vi in zip(self._mappers, self._value_indices):
+ results.append(m.build_results([values[j] for j in vi]))
+ return self._fetch_type(*results)
+
+
class _FetchHandler(object):
"""Handler for structured fetches.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 4afc6399d5..f576435136 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -61,6 +61,12 @@ from tensorflow.python.platform import googletest
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
+
# NOTE(mrry): Dummy shape registration for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
@@ -300,6 +306,82 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(None, res[2])
self.assertEqual(44.0, res[1])
+ def testFetchAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ val1 = np.array([1.2, 3.4, 5.6])
+ val2 = np.array([[1, 2], [4, 3]])
+ val3 = np.array([10, 20, 30])
+
+ t1 = constant_op.constant(val1)
+ t2 = constant_op.constant(val2)
+
+ sample = SampleAttr(t1, t2)
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val1, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ result = sess.run(sample, feed_dict={sample.field1: val3})
+ self.assertIsInstance(result, SampleAttr)
+ self.assertAllEqual(val3, result.field1)
+ self.assertAllEqual(val2, result.field2)
+
+ def testFetchNestedAttrs(self):
+ if attr is None:
+ self.skipTest('attr module is unavailable.')
+
+ @attr.s
+ class SampleAttr(object):
+ field0 = attr.ib()
+ field1 = attr.ib()
+
+ v1 = 10
+ v2 = 20
+ v3 = np.float32(1.2)
+ v4 = np.float32(3.4)
+ v5 = np.float64(100.001)
+ v6 = np.float64(-23.451)
+ arr1 = np.array([1.2, 6.7, 3.4])
+ arr2 = np.array([7, 11, 3])
+ sample = SampleAttr(
+ SampleAttr(
+ SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
+ SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
+ {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
+ 'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
+
+ with session.Session() as sess:
+ result = sess.run(sample)
+ self.assertIsInstance(result, SampleAttr)
+ self.assertIsInstance(result.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field0, SampleAttr)
+ self.assertIsInstance(result.field0.field1, SampleAttr)
+ self.assertIsInstance(result.field0.field1.field0, np.ndarray)
+ self.assertAllEqual(arr1, result.field0.field1.field0)
+ self.assertIsInstance(result.field0.field1.field1, np.ndarray)
+ self.assertAllEqual(arr2, result.field0.field1.field1)
+ self.assertIsInstance(result.field1, dict)
+ self.assertIn('A', result.field1)
+ self.assertIn('B', result.field1)
+ self.assertIsInstance(result.field1['A'], SampleAttr)
+ self.assertAllEqual(
+ [v3, v4],
+ [result.field1['A'].field0, result.field1['A'].field1])
+ self.assertIsInstance(result.field1['B'], list)
+ self.assertEqual(1, len(result.field1['B']))
+ self.assertIsInstance(result.field1['B'][0], SampleAttr)
+ self.assertAllEqual(
+ [v5, v6],
+ [result.field1['B'][0].field0, result.field1['B'][0].field1])
+
def testFetchNestingEmptyOneLevel(self):
with session.Session() as sess:
a_val = 11.0