aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/session.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/client/session.py')
-rw-r--r--tensorflow/python/client/session.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index c629e7e34a..918e7e4da6 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -53,6 +53,17 @@ class SessionInterface(object):
raise NotImplementedError('Run')
+def _get_indexed_slices_value_from_fetches(fetched_vals):
+ return ops.IndexedSlicesValue(fetched_vals[0], fetched_vals[1],
+ fetched_vals[2]
+ if len(fetched_vals) == 3 else None)
+
+
+def _get_feeds_for_indexed_slices(feed, feed_val):
+ return list(zip([feed.values, feed.indices] if feed.dense_shape is None else
+ [feed.values, feed.indices, feed.dense_shape], feed_val))
+
+
class BaseSession(SessionInterface):
"""A class for interacting with a TensorFlow computation.
@@ -221,6 +232,14 @@ class BaseSession(SessionInterface):
lambda fetched_vals: ops.SparseTensorValue(*fetched_vals)),
lambda feed, feed_val: list(zip(
[feed.indices, feed.values, feed.shape], feed_val))),
+ # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
+ # IndexedSlicesValues or normal tuples.
+ (ops.IndexedSlices,
+ lambda fetch: (
+ [fetch.values, fetch.indices] if fetch.dense_shape is None
+ else [fetch.values, fetch.indices, fetch.dense_shape],
+ _get_indexed_slices_value_from_fetches),
+ _get_feeds_for_indexed_slices),
# The default catches all types and performs no expansions.
(object,
lambda fetch: ([fetch], lambda fetched_vals: fetched_vals[0]),