diff options
author | 2016-05-31 01:49:17 -0800 | |
---|---|---|
committer | 2016-05-31 03:03:01 -0700 | |
commit | 8100421994fd3e89b31d3605cf9468cca0217ac5 (patch) | |
tree | cc0fb77d0949f2c201beff6a222ac09020c0e691 | |
parent | 7c10d9a66fbe923e475eff6fc818feffd41db574 (diff) |
Add support for fetching dictionaries in the Python Session class.
For example:
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.constant(7)
res = sess.run({'x': x, 'y': [x, tf.square(x)]}
# res['x'] = 7
# res['y'] = [7, 49]
Change: 123623373
-rw-r--r-- | tensorflow/contrib/learn/python/learn/graph_actions.py | 29 | ||||
-rw-r--r-- | tensorflow/python/client/session.py | 99 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 25 |
3 files changed, 105 insertions, 48 deletions
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index f9d37a3c10..b3fefd34a4 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -104,27 +104,6 @@ def _restore_from_checkpoint(session, graph, checkpoint_path, saver=None): logging.info('No variables found in graph, not creating Saver() object.') -def _run_dict(session, run_dict, feed_dict=None): - """Convenience function to run a session on each item in a dict of tensors. - - Args: - session: The session to evaluate. - run_dict: A dict of tensors to be run in the session. - feed_dict: Feed dict to be used in running the session. - - Returns: - A dict containing the result of evaluating the tensors. - Raises: - ValueError: if `run_dict` is missing or empty. - """ - if run_dict is None: - raise ValueError('Invalid run_dict %s.', run_dict) - keys = run_dict.keys() - tensors = [run_dict[key] for key in keys] - values = session.run(tensors, feed_dict=feed_dict) - return dict(zip(keys, values)) - - def _run_with_monitors(session, step, tensors, feed_dict, monitors): """Runs session for given tensors with monitor callbacks.""" for monitor in monitors: @@ -506,13 +485,13 @@ def evaluate(graph, if update_op is not None: session.run(update_op, feed_dict=feed_dict) else: - eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict) + eval_results = session.run(eval_dict, feed_dict=feed_dict) eval_step = step # TODO(wicke): We should assert that the global step hasn't changed. if step % log_every_steps == 0: if eval_step is None or step != eval_step: - eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict) + eval_results = session.run(eval_dict, feed_dict=feed_dict) eval_step = step duration = time.time() - start_time logging.info('Results after %d steps (%.3f sec/batch): %s.', @@ -521,7 +500,7 @@ def evaluate(graph, for k, v in eval_results.items())) finally: if eval_results is None or step != eval_step: - eval_results = _run_dict(session, eval_dict, feed_dict=feed_dict) + eval_results = session.run(eval_dict, feed_dict=feed_dict) eval_step = step # Stop queue runners. coord.request_stop() @@ -620,7 +599,7 @@ def run_feeds(output_dict, feed_dicts, restore_checkpoint_path=None): coord = Coordinator() try: queue_runner.start_queue_runners(session, coord=coord) - return [_run_dict(session, output_dict, f) for f in feed_dicts] + return [session.run(output_dict, f) for f in feed_dicts] finally: coord.request_stop() diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 32f0e3f6ce..943fb70b13 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -69,6 +69,42 @@ def _get_feeds_for_indexed_slices(feed, feed_val): [feed.values, feed.indices, feed.dense_shape], feed_val)) +def _flatten1(seq): + """Flattens one level of nested sequences.""" + ret = [] + for el in seq: + if isinstance(el, (list, tuple)): + ret.extend(el) + else: + ret.append(el) + return ret + + +def _unflatten_fetches(fetches, flat_values): + """Creates a dictionary mapping fetched keys to values. + + Args: + fetches: A heterogeneous list of either graph elements or lists/tuples + of graph elements. + flat_values: A flat list of fetched values. + + Returns: + A dictionary with the same keys as `fetches`, mapping to the fetched value + (or list of values) in `flat_values`. + """ + used = 0 + ret = {} + for key, fetch in fetches.items(): + if isinstance(fetch, (list, tuple)): + start, used = used, used + len(fetch) + ret[key] = flat_values[start : used] + else: + ret[key] = flat_values[used] + used += 1 + assert used == len(flat_values) + return ret + + class BaseSession(SessionInterface): """A class for interacting with a TensorFlow computation. @@ -256,24 +292,25 @@ class BaseSession(SessionInterface): and evaluate every `Tensor` in `fetches`, substituting the values in `feed_dict` for the corresponding input values. - The `fetches` argument may be a list of graph elements or a single - graph element, and these determine the return value of this + The `fetches` argument may be a single graph element, a list of + graph elements, or a dictionary whose values are the above. The type of + `fetches` determines the return value of this method. A graph element can be one of the following types: - * If the *i*th element of `fetches` is an - [`Operation`](../../api_docs/python/framework.md#Operation), the *i*th - return value will be `None`. - * If the *i*th element of `fetches` is a - [`Tensor`](../../api_docs/python/framework.md#Tensor), the *i*th return - value will be a numpy ndarray containing the value of that tensor. - * If the *i*th element of `fetches` is a + * If an element of `fetches` is an + [`Operation`](../../api_docs/python/framework.md#Operation), the + corresponding fetched value will be `None`. + * If an element of `fetches` is a + [`Tensor`](../../api_docs/python/framework.md#Tensor), the corresponding + fetched value will be a numpy ndarray containing the value of that tensor. + * If an element of `fetches` is a [`SparseTensor`](../../api_docs/python/sparse_ops.md#SparseTensor), - the *i*th return value will be a + the corresponding fetched value will be a [`SparseTensorValue`](../../api_docs/python/sparse_ops.md#SparseTensorValue) containing the value of that sparse tensor. - * If the *i*th element of `fetches` is produced by a `get_tensor_handle` op, - the *i*th return value will be a numpy ndarray containing the handle of - that tensor. + * If an element of `fetches` is produced by a `get_tensor_handle` op, + the corresponding fetched value will be a numpy ndarray containing the + handle of that tensor. The optional `feed_dict` argument allows the caller to override the value of tensors in the graph. Each key in `feed_dict` can be @@ -303,8 +340,9 @@ class BaseSession(SessionInterface): collected into this argument and passed back. Args: - fetches: A single graph element, or a list of graph elements - (described above). + fetches: A single graph element, a list of graph elements, + or a dictionary whose values are graph elements or lists of graph + elements (described above). feed_dict: A dictionary that maps graph elements to values (described above). options: A [`RunOptions`] protocol buffer @@ -312,7 +350,8 @@ class BaseSession(SessionInterface): Returns: Either a single value if `fetches` is a single graph element, or - a list of values if `fetches` is a list (described above). + a list of values if `fetches` is a list, or a dictionary with the + same keys as `fetches` if that is a dictionary (described above). Raises: RuntimeError: If this `Session` is in an invalid state (e.g. has been @@ -369,14 +408,17 @@ class BaseSession(SessionInterface): Args: handle: A handle for a sequence of partial runs. - fetches: A single graph element, or a list of graph elements - (described above). + fetches: A single graph element, a list of graph elements, + or a dictionary whose values are graph elements or lists of graph + elements (see documentation for `run`). feed_dict: A dictionary that maps graph elements to values (described above). Returns: Either a single value if `fetches` is a single graph element, or - a list of values if `fetches` is a list (described above). + a list of values if `fetches` is a list, or a dictionary with the + same keys as `fetches` if that is a dictionary + (see documentation for `run`). Raises: tf.errors.OpError: Or one of its subclasses on error. @@ -517,6 +559,20 @@ class BaseSession(SessionInterface): raise RuntimeError('The Session graph is empty. Add operations to the ' 'graph before calling run().') + # Flatten/unflatten fetched values. + if isinstance(fetches, (list, tuple)): + # fetches is already a list or tuple; nothing to do. + unflatten = lambda fetched: fetched + elif isinstance(fetches, dict): + # fetches is a dictionary; flatten the values and map fetched + # values back into to a dictionary. + orig_fetches, fetches = fetches, _flatten1(fetches.values()) + unflatten = lambda fetched: _unflatten_fetches(orig_fetches, fetched) + else: + # fetches is a singleton. + fetches = [fetches] + unflatten = lambda fetched: fetched[0] + # Validate and process fetches. processed_fetches = self._process_fetches(fetches) unique_fetches = processed_fetches[0] @@ -595,10 +651,7 @@ class BaseSession(SessionInterface): else: ret.append(None) - if isinstance(fetches, (list, tuple)): - return ret - else: - return ret[0] + return unflatten(ret) # Captures the name of a node in an error status. _NODEDEF_NAME_RE = re.compile(r'\[\[Node: ([^ ]*?) =') diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index c3cce65a9b..e97acb68be 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -182,6 +182,12 @@ class SessionTest(test_util.TensorFlowTestCase): a_val, b_val = s.run([a, b]) # Test multiple fetches. self.assertAllEqual([[1.0, 1.0]], a_val) self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val) + results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]}) + self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0]) + self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], + results_with_dict['b']) + self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0]) + self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1]) def testFetchScalar(self): with session.Session() as s: @@ -199,6 +205,10 @@ class SessionTest(test_util.TensorFlowTestCase): xy, = s.run([tf_xy]) self.assertEqual(scalar, type(xy)) self.assertEqual(x + y, xy) + # Dict fetch + xy = s.run({'xy': tf_xy})['xy'] + self.assertEqual(scalar, type(xy)) + self.assertEqual(x + y, xy) def testFetchOperationObject(self): with session.Session() as s: @@ -243,6 +253,21 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual(sp_out.indices, indices) self.assertAllEqual(sp_out.values, values) self.assertAllEqual(sp_out.shape, shape) + # Dict fetch (single value), use as tuple + indices_out, values_out, shape_out = s.run({'sp': sp})['sp'] + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(values_out, values) + self.assertAllEqual(shape_out, shape) + # Dict fetch (list value), use as tuple + (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp'] + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(values_out, values) + self.assertAllEqual(shape_out, shape) + # Dict fetch, use as SparseTensorValue + sp_out = s.run({'sp': sp})['sp'] + self.assertAllEqual(sp_out.indices, indices) + self.assertAllEqual(sp_out.values, values) + self.assertAllEqual(sp_out.shape, shape) def testFeedSparseTensor(self): with session.Session() as s: |