aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-31 01:49:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-31 03:03:01 -0700
commit8100421994fd3e89b31d3605cf9468cca0217ac5 (patch)
treecc0fb77d0949f2c201beff6a222ac09020c0e691
parent7c10d9a66fbe923e475eff6fc818feffd41db574 (diff)
Add support for fetching dictionaries in the Python Session class.
For example: with tf.Graph().as_default(), tf.Session() as sess: x = tf.constant(7) res = sess.run({'x': x, 'y': [x, tf.square(x)]} # res['x'] = 7 # res['y'] = [7, 49] Change: 123623373
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions.py29
-rw-r--r--tensorflow/python/client/session.py99
-rw-r--r--tensorflow/python/client/session_test.py25
3 files changed, 105 insertions, 48 deletions
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py
index f9d37a3c10..b3fefd34a4 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions.py
@@ -104,27 +104,6 @@ def _restore_from_checkpoint(session, graph, checkpoint_path, saver=None):
logging.info('No variables found in graph, not creating Saver() object.')
-def _run_dict(session, run_dict, feed_dict=None):
- """Convenience function to run a session on each item in a dict of tensors.
-
- Args:
- session: The session to evaluate.
- run_dict: A dict of tensors to be run in the session.
- feed_dict: Feed dict to be used in running the session.
-
- Returns:
- A dict containing the result of evaluating the tensors.
- Raises:
- ValueError: if `run_dict` is missing or empty.
- """
- if run_dict is None:
- raise ValueError('Invalid run_dict %s.', run_dict)
- keys = run_dict.keys()
- tensors = [run_dict[key] for key in keys]
- values = session.run(tensors, feed_dict=feed_dict)
- return dict(zip(keys, values))
-
-
def _run_with_monitors(session, step, tensors, feed_dict, monitors):
"""Runs session for given tensors with monitor callbacks."""
for monitor in monitors:
@@ -506,13 +485,13 @@ def evaluate(graph,
if update_op is not None:
session.run(update_op, feed_dict=feed_dict)
else:
- eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+ eval_results = session.run(eval_dict, feed_dict=feed_dict)
eval_step = step
# TODO(wicke): We should assert that the global step hasn't changed.
if step % log_every_steps == 0:
if eval_step is None or step != eval_step:
- eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+ eval_results = session.run(eval_dict, feed_dict=feed_dict)
eval_step = step
duration = time.time() - start_time
logging.info('Results after %d steps (%.3f sec/batch): %s.',
@@ -521,7 +500,7 @@ def evaluate(graph,
for k, v in eval_results.items()))
finally:
if eval_results is None or step != eval_step:
- eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict)
+ eval_results = session.run(eval_dict, feed_dict=feed_dict)
eval_step = step
# Stop queue runners.
coord.request_stop()
@@ -620,7 +599,7 @@ def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None):
coord = Coordinator()
try:
queue_runner.start_queue_runners(session, coord=coord)
- return [_run_dict(session, output_dict, f) for f in feed_dicts]
+ return [session.run(output_dict, f) for f in feed_dicts]
finally:
coord.request_stop()
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 32f0e3f6ce..943fb70b13 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -69,6 +69,42 @@ def _get_feeds_for_indexed_slices(feed, feed_val):
[feed.values, feed.indices, feed.dense_shape], feed_val))
+def _flatten1(seq):
+ """Flattens one level of nested sequences."""
+ ret = []
+ for el in seq:
+ if isinstance(el, (list, tuple)):
+ ret.extend(el)
+ else:
+ ret.append(el)
+ return ret
+
+
+def _unflatten_fetches(fetches, flat_values):
+ """Creates a dictionary mapping fetched keys to values.
+
+ Args:
+ fetches: A heterogeneous list of either graph elements or lists/tuples
+ of graph elements.
+ flat_values: A flat list of fetched values.
+
+ Returns:
+ A dictionary with the same keys as `fetches`, mapping to the fetched value
+ (or list of values) in `flat_values`.
+ """
+ used = 0
+ ret = {}
+ for key, fetch in fetches.items():
+ if isinstance(fetch, (list, tuple)):
+ start, used = used, used + len(fetch)
+ ret[key] = flat_values[start : used]
+ else:
+ ret[key] = flat_values[used]
+ used += 1
+ assert used == len(flat_values)
+ return ret
+
+
class BaseSession(SessionInterface):
"""A class for interacting with a TensorFlow computation.
@@ -256,24 +292,25 @@ class BaseSession(SessionInterface):
and evaluate every `Tensor` in `fetches`, substituting the values in
`feed_dict` for the corresponding input values.
- The `fetches` argument may be a list of graph elements or a single
- graph element, and these determine the return value of this
+ The `fetches` argument may be a single graph element, a list of
+ graph elements, or a dictionary whose values are the above. The type of
+ `fetches` determines the return value of this
method. A graph element can be one of the following types:
- * If the *i*th element of `fetches` is an
- [`Operation`](../../api_docs/python/framework.md#Operation), the *i*th
- return value will be `None`.
- * If the *i*th element of `fetches` is a
- [`Tensor`](../../api_docs/python/framework.md#Tensor), the *i*th return
- value will be a numpy ndarray containing the value of that tensor.
- * If the *i*th element of `fetches` is a
+ * If an element of `fetches` is an
+ [`Operation`](../../api_docs/python/framework.md#Operation), the
+ corresponding fetched value will be `None`.
+ * If an element of `fetches` is a
+ [`Tensor`](../../api_docs/python/framework.md#Tensor), the corresponding
+ fetched value will be a numpy ndarray containing the value of that tensor.
+ * If an element of `fetches` is a
[`SparseTensor`](../../api_docs/python/sparse_ops.md#SparseTensor),
- the *i*th return value will be a
+ the corresponding fetched value will be a
[`SparseTensorValue`](../../api_docs/python/sparse_ops.md#SparseTensorValue)
containing the value of that sparse tensor.
- * If the *i*th element of `fetches` is produced by a `get_tensor_handle` op,
- the *i*th return value will be a numpy ndarray containing the handle of
- that tensor.
+ * If an element of `fetches` is produced by a `get_tensor_handle` op,
+ the corresponding fetched value will be a numpy ndarray containing the
+ handle of that tensor.
The optional `feed_dict` argument allows the caller to override
the value of tensors in the graph. Each key in `feed_dict` can be
@@ -303,8 +340,9 @@ class BaseSession(SessionInterface):
collected into this argument and passed back.
Args:
- fetches: A single graph element, or a list of graph elements
- (described above).
+ fetches: A single graph element, a list of graph elements,
+ or a dictionary whose values are graph elements or lists of graph
+ elements (described above).
feed_dict: A dictionary that maps graph elements to values
(described above).
options: A [`RunOptions`] protocol buffer
@@ -312,7 +350,8 @@ class BaseSession(SessionInterface):
Returns:
Either a single value if `fetches` is a single graph element, or
- a list of values if `fetches` is a list (described above).
+ a list of values if `fetches` is a list, or a dictionary with the
+ same keys as `fetches` if that is a dictionary (described above).
Raises:
RuntimeError: If this `Session` is in an invalid state (e.g. has been
@@ -369,14 +408,17 @@ class BaseSession(SessionInterface):
Args:
handle: A handle for a sequence of partial runs.
- fetches: A single graph element, or a list of graph elements
- (described above).
+ fetches: A single graph element, a list of graph elements,
+ or a dictionary whose values are graph elements or lists of graph
+ elements (see documentation for `run`).
feed_dict: A dictionary that maps graph elements to values
(described above).
Returns:
Either a single value if `fetches` is a single graph element, or
- a list of values if `fetches` is a list (described above).
+ a list of values if `fetches` is a list, or a dictionary with the
+ same keys as `fetches` if that is a dictionary
+ (see documentation for `run`).
Raises:
tf.errors.OpError: Or one of its subclasses on error.
@@ -517,6 +559,20 @@ class BaseSession(SessionInterface):
raise RuntimeError('The Session graph is empty. Add operations to the '
'graph before calling run().')
+ # Flatten/unflatten fetched values.
+ if isinstance(fetches, (list, tuple)):
+ # fetches is already a list or tuple; nothing to do.
+ unflatten = lambda fetched: fetched
+ elif isinstance(fetches, dict):
+ # fetches is a dictionary; flatten the values and map fetched
+ # values back into to a dictionary.
+ orig_fetches, fetches = fetches, _flatten1(fetches.values())
+ unflatten = lambda fetched: _unflatten_fetches(orig_fetches, fetched)
+ else:
+ # fetches is a singleton.
+ fetches = [fetches]
+ unflatten = lambda fetched: fetched[0]
+
# Validate and process fetches.
processed_fetches = self._process_fetches(fetches)
unique_fetches = processed_fetches[0]
@@ -595,10 +651,7 @@ class BaseSession(SessionInterface):
else:
ret.append(None)
- if isinstance(fetches, (list, tuple)):
- return ret
- else:
- return ret[0]
+ return unflatten(ret)
# Captures the name of a node in an error status.
_NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =')
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index c3cce65a9b..e97acb68be 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -182,6 +182,12 @@ class SessionTest(test_util.TensorFlowTestCase):
a_val, b_val = s.run([a, b]) # Test multiple fetches.
self.assertAllEqual([[1.0, 1.0]], a_val)
self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
+ results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]})
+ self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0])
+ self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
+ results_with_dict['b'])
+ self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0])
+ self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1])
def testFetchScalar(self):
with session.Session() as s:
@@ -199,6 +205,10 @@ class SessionTest(test_util.TensorFlowTestCase):
xy, = s.run([tf_xy])
self.assertEqual(scalar, type(xy))
self.assertEqual(x + y, xy)
+ # Dict fetch
+ xy = s.run({'xy': tf_xy})['xy']
+ self.assertEqual(scalar, type(xy))
+ self.assertEqual(x + y, xy)
def testFetchOperationObject(self):
with session.Session() as s:
@@ -243,6 +253,21 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(sp_out.indices, indices)
self.assertAllEqual(sp_out.values, values)
self.assertAllEqual(sp_out.shape, shape)
+ # Dict fetch (single value), use as tuple
+ indices_out, values_out, shape_out = s.run({'sp': sp})['sp']
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Dict fetch (list value), use as tuple
+ (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp']
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Dict fetch, use as SparseTensorValue
+ sp_out = s.run({'sp': sp})['sp']
+ self.assertAllEqual(sp_out.indices, indices)
+ self.assertAllEqual(sp_out.values, values)
+ self.assertAllEqual(sp_out.shape, shape)
def testFeedSparseTensor(self):
with session.Session() as s: