diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-21 03:54:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 03:59:07 -0700 |
commit | 200b89761a4665e3de6d0efc4e3e10ab287ad81b (patch) | |
tree | 01f5e4d091c227c7de8c813e55f5ed3e5f63bf85 /tensorflow/python/client | |
parent | d1e9a1ed54cae9b0b10ab89c06d6d7f9b53af3a1 (diff) |
Added fetch support for attrs classes.
Given a class
@attr.s()
class SampleAttr(object):
field_1 = attr.ib()
field_2 = attr.ib()
we will be able to run
obj = SampleAttr(tensor_1, tensor_2)
session.run(obj) # equivalent with session.run([obj.field_1, obj.field_2])
Please note, this does not need nest flatten support (which is only relevant to the feed_dict argument).
Also, the information in __attrs_attrs__ is provided for extensions (as per the docs: http://www.attrs.org/en/stable/extending.html#extending-metadata) like this and is not an "implementation detail".
PiperOrigin-RevId: 213963978
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/session.py | 46 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 82 |
2 files changed, 126 insertions, 2 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index ae0ad27f15..c963cfd334 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -178,16 +178,30 @@ def register_session_run_conversion_functions( feed_function_for_partial_run: A callable for specifying tensor values to feed when setting up a partial run, which takes a `tensor_type` type object as input, and returns a list of Tensors. + + Raises: + ValueError: If `tensor_type` has already been registered. """ for conversion_function in _REGISTERED_EXPANSIONS: if issubclass(conversion_function[0], tensor_type): - raise ValueError('%s has already been registered so ignore it.', + raise ValueError('%s has already been registered so ignore it.' % tensor_type) - return + _REGISTERED_EXPANSIONS.insert(0, (tensor_type, fetch_function, feed_function, feed_function_for_partial_run)) +def _is_attrs_instance(obj): + """Returns True if the given obj is an instance of attrs-decorated class.""" + return getattr(obj.__class__, '__attrs_attrs__', None) is not None + + +def _get_attrs_values(obj): + """Returns the list of values from an attrs instance.""" + attrs = getattr(obj.__class__, '__attrs_attrs__') + return [getattr(obj, a.name) for a in attrs] + + class _FetchMapper(object): """Definition of the interface provided by fetch mappers. @@ -247,6 +261,8 @@ class _FetchMapper(object): return _ListFetchMapper(fetch) elif isinstance(fetch, collections.Mapping): return _DictFetchMapper(fetch) + elif _is_attrs_instance(fetch): + return _AttrsFetchMapper(fetch) else: # Look for a handler in the registered expansions. for tensor_type, fetch_fn, _, _ in _REGISTERED_EXPANSIONS: @@ -398,6 +414,32 @@ class _DictFetchMapper(_FetchMapper): return results +class _AttrsFetchMapper(_FetchMapper): + """Fetch mapper for attrs decorated classes.""" + + def __init__(self, fetches): + """Creates a _AttrsFetchMapper. + + Args: + fetches: An instance of an attrs decorated class. + """ + values = _get_attrs_values(fetches) + self._fetch_type = type(fetches) + self._mappers = [ + _FetchMapper.for_fetch(fetch) for fetch in values + ] + self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers) + + def unique_fetches(self): + return self._unique_fetches + + def build_results(self, values): + results = [] + for m, vi in zip(self._mappers, self._value_indices): + results.append(m.build_results([values[j] for j in vi])) + return self._fetch_type(*results) + + class _FetchHandler(object): """Handler for structured fetches. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 4afc6399d5..f576435136 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -61,6 +61,12 @@ from tensorflow.python.platform import googletest from tensorflow.python.training import server_lib from tensorflow.python.util import compat +try: + import attr # pylint:disable=g-import-not-at-top +except ImportError: + attr = None + + # NOTE(mrry): Dummy shape registration for ops used in the tests, since they # don't have C++ op registrations on which to attach C++ shape fns. ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) @@ -300,6 +306,82 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(None, res[2]) self.assertEqual(44.0, res[1]) + def testFetchAttrs(self): + if attr is None: + self.skipTest('attr module is unavailable.') + + @attr.s + class SampleAttr(object): + field1 = attr.ib() + field2 = attr.ib() + + val1 = np.array([1.2, 3.4, 5.6]) + val2 = np.array([[1, 2], [4, 3]]) + val3 = np.array([10, 20, 30]) + + t1 = constant_op.constant(val1) + t2 = constant_op.constant(val2) + + sample = SampleAttr(t1, t2) + with session.Session() as sess: + result = sess.run(sample) + self.assertIsInstance(result, SampleAttr) + self.assertAllEqual(val1, result.field1) + self.assertAllEqual(val2, result.field2) + + result = sess.run(sample, feed_dict={sample.field1: val3}) + self.assertIsInstance(result, SampleAttr) + self.assertAllEqual(val3, result.field1) + self.assertAllEqual(val2, result.field2) + + def testFetchNestedAttrs(self): + if attr is None: + self.skipTest('attr module is unavailable.') + + @attr.s + class SampleAttr(object): + field0 = attr.ib() + field1 = attr.ib() + + v1 = 10 + v2 = 20 + v3 = np.float32(1.2) + v4 = np.float32(3.4) + v5 = np.float64(100.001) + v6 = np.float64(-23.451) + arr1 = np.array([1.2, 6.7, 3.4]) + arr2 = np.array([7, 11, 3]) + sample = SampleAttr( + SampleAttr( + SampleAttr(constant_op.constant(v1), constant_op.constant(v2)), + SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))), + {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)), + 'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]}) + + with session.Session() as sess: + result = sess.run(sample) + self.assertIsInstance(result, SampleAttr) + self.assertIsInstance(result.field0, SampleAttr) + self.assertIsInstance(result.field0.field0, SampleAttr) + self.assertIsInstance(result.field0.field1, SampleAttr) + self.assertIsInstance(result.field0.field1.field0, np.ndarray) + self.assertAllEqual(arr1, result.field0.field1.field0) + self.assertIsInstance(result.field0.field1.field1, np.ndarray) + self.assertAllEqual(arr2, result.field0.field1.field1) + self.assertIsInstance(result.field1, dict) + self.assertIn('A', result.field1) + self.assertIn('B', result.field1) + self.assertIsInstance(result.field1['A'], SampleAttr) + self.assertAllEqual( + [v3, v4], + [result.field1['A'].field0, result.field1['A'].field1]) + self.assertIsInstance(result.field1['B'], list) + self.assertEqual(1, len(result.field1['B'])) + self.assertIsInstance(result.field1['B'][0], SampleAttr) + self.assertAllEqual( + [v5, v6], + [result.field1['B'][0].field0, result.field1['B'][0].field1]) + def testFetchNestingEmptyOneLevel(self): with session.Session() as sess: a_val = 11.0 |