diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-11-30 12:19:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-30 12:25:21 -0800 |
commit | dad0c9e85439d3a63ac3168ea5a52f5365e52028 (patch) | |
tree | f3680f56a9d44ba24e2da515eb2aeeacf2aa914d | |
parent | 873a54cae6f102afe52c97470da58741288b5dac (diff) |
Preserves order when calling Session.run on an OrderedDict.
Change: 140635682
-rw-r--r-- | tensorflow/python/client/session.py | 7 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 12 |
2 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 71c931037e..591cc5afbc 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -364,6 +364,7 @@ class _DictFetchMapper(_FetchMapper): Args: fetches: Dict of fetches. """ + self._fetch_type = type(fetches) self._keys = fetches.keys() self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches.values()] @@ -373,7 +374,7 @@ class _DictFetchMapper(_FetchMapper): return self._unique_fetches def build_results(self, values): - results = {} + results = self._fetch_type() for k, m, vi in zip(self._keys, self._mappers, self._value_indices): results[k] = m.build_results([values[j] for j in vi]) return results @@ -661,8 +662,8 @@ class BaseSession(SessionInterface): `feed_dict` for the corresponding input values. The `fetches` argument may be a single graph element, or an arbitrarily - nested list, tuple, namedtuple, or dict containing graph elements at its - leaves. A graph element can be one of the following types: + nested list, tuple, namedtuple, dict, or OrderedDict containing graph + elements at its leaves. A graph element can be one of the following types: * An [`Operation`](../../api_docs/python/framework.md#Operation). The corresponding fetched value will be `None`. diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index a20376b91d..0c602a9014 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -254,6 +254,18 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEqual(None, res['b']) self.assertEqual(44.0, res['c']) + def testFetchOrderedDict(self): + with session.Session() as sess: + a = constant_op.constant(42.0) + b = control_flow_ops.no_op() # An op, not a tensor. + c = constant_op.constant(44.0) + res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)])) + self.assertTrue(isinstance(res, collections.OrderedDict)) + self.assertEqual([3, 2, 1], list(res.keys())) + self.assertEqual(42.0, res[3]) + self.assertEqual(None, res[2]) + self.assertEqual(44.0, res[1]) + def testFetchNestingEmptyOneLevel(self): with session.Session() as sess: a_val = 11.0 |