aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-30 12:19:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-30 12:25:21 -0800
commitdad0c9e85439d3a63ac3168ea5a52f5365e52028 (patch)
treef3680f56a9d44ba24e2da515eb2aeeacf2aa914d
parent873a54cae6f102afe52c97470da58741288b5dac (diff)
Preserves order when calling Session.run on an OrderedDict.
Change: 140635682
-rw-r--r--tensorflow/python/client/session.py7
-rw-r--r--tensorflow/python/client/session_test.py12
2 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 71c931037e..591cc5afbc 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -364,6 +364,7 @@ class _DictFetchMapper(_FetchMapper):
Args:
fetches: Dict of fetches.
"""
+ self._fetch_type = type(fetches)
self._keys = fetches.keys()
self._mappers = [_FetchMapper.for_fetch(fetch)
for fetch in fetches.values()]
@@ -373,7 +374,7 @@ class _DictFetchMapper(_FetchMapper):
return self._unique_fetches
def build_results(self, values):
- results = {}
+ results = self._fetch_type()
for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
results[k] = m.build_results([values[j] for j in vi])
return results
@@ -661,8 +662,8 @@ class BaseSession(SessionInterface):
`feed_dict` for the corresponding input values.
The `fetches` argument may be a single graph element, or an arbitrarily
- nested list, tuple, namedtuple, or dict containing graph elements at its
- leaves. A graph element can be one of the following types:
+ nested list, tuple, namedtuple, dict, or OrderedDict containing graph
+ elements at its leaves. A graph element can be one of the following types:
* An [`Operation`](../../api_docs/python/framework.md#Operation).
The corresponding fetched value will be `None`.
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index a20376b91d..0c602a9014 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -254,6 +254,18 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(None, res['b'])
self.assertEqual(44.0, res['c'])
+ def testFetchOrderedDict(self):
+ with session.Session() as sess:
+ a = constant_op.constant(42.0)
+ b = control_flow_ops.no_op() # An op, not a tensor.
+ c = constant_op.constant(44.0)
+ res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
+ self.assertTrue(isinstance(res, collections.OrderedDict))
+ self.assertEqual([3, 2, 1], list(res.keys()))
+ self.assertEqual(42.0, res[3])
+ self.assertEqual(None, res[2])
+ self.assertEqual(44.0, res[1])
+
def testFetchNestingEmptyOneLevel(self):
with session.Session() as sess:
a_val = 11.0