diff options
Diffstat (limited to 'tensorflow/python/client/session.py')
-rw-r--r-- | tensorflow/python/client/session.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index c629e7e34a..918e7e4da6 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -53,6 +53,17 @@ class SessionInterface(object): raise NotImplementedError('Run') +def _get_indexed_slices_value_from_fetches(fetched_vals): + return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1], + fetched_vals[2] + if len(fetched_vals) == 3 else None) + + +def _get_feeds_for_indexed_slices(feed, feed_val): + return list(zip([feed.values, feed.indices] if feed.dense_shape is None else + [feed.values, feed.indices, feed.dense_shape], feed_val)) + + class BaseSession(SessionInterface): """A class for interacting with a TensorFlow computation. @@ -221,6 +232,14 @@ class BaseSession(SessionInterface): lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)), lambda feed, feed_val: list(zip( [feed.indices, feed.values, feed.shape], feed_val))), + # IndexedSlices are fetched as IndexedSlicesValues. They can be fed + # IndexedSlicesValues or normal tuples. + (ops.IndexedSlices, + lambda fetch: ( + [fetch.values, fetch.indices] if fetch.dense_shape is None + else [fetch.values, fetch.indices, fetch.dense_shape], + _get_indexed_slices_value_from_fetches), + _get_feeds_for_indexed_slices), # The default catches all types and performs no expansions. (object, lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]), |