aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-04-06 15:02:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-06 16:11:07 -0700
commit838c47b173c733f770e2ac799f1bc1faa0a3128e (patch)
tree68bae67c37343de3f4014b42f1548feba2114032
parentf3554041edfbadcde875abd3c762f23ff418f15d (diff)
Add SparseTensor.eval.
Change: 119212533
-rw-r--r--tensorflow/python/framework/ops.py31
-rw-r--r--tensorflow/python/framework/ops_test.py14
2 files changed, 43 insertions, 2 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f3b4455257..aee976c71e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
-from tensorflow.python.util import compat
from tensorflow.python.platform import logging
+from tensorflow.python.util import compat
def _convert_stack(stack):
@@ -955,6 +955,32 @@ class SparseTensor(object):
return "SparseTensor(indices=%s, values=%s, shape=%s)" % (
self._indices, self._values, self._shape)
+ def eval(self, feed_dict=None, session=None):
+ """Evaluates this sparse tensor in a `Session`.
+
+ Calling this method will execute all preceding operations that
+ produce the inputs needed for the operation that produces this
+ tensor.
+
+ *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been
+ launched in a session, and either a default session must be
+ available, or `session` must be specified explicitly.
+
+ Args:
+ feed_dict: A dictionary that maps `Tensor` objects to feed values.
+ See [`Session.run()`](../../api_docs/python/client.md#Session.run) for a
+ description of the valid feed values.
+ session: (Optional.) The `Session` to be used to evaluate this sparse
+ tensor. If none, the default session will be used.
+
+ Returns:
+ A `SparseTensorValue` object.
+
+ """
+ indices, values, shape = _eval_using_default_session(
+ [self.indices, self.values, self.shape], feed_dict, self.graph, session)
+ return SparseTensorValue(indices, values, shape)
+
SparseTensorValue = collections.namedtuple("SparseTensorValue",
["indices", "values", "shape"])
@@ -2025,6 +2051,9 @@ class Graph(object):
grad_function_name: If not None, this specifies the name of a function
that shall be used as the gradient function of
the function being added.
+
+ Raises:
+ ValueError: if another function is defined with the same name.
"""
previous_def = self._functions.get(function_def.signature.name, None)
if previous_def:
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 65e92872cf..0df619c16e 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -58,10 +58,22 @@ class TensorTest(test_util.TensorFlowTestCase):
class SparseTensorTest(test_util.TensorFlowTestCase):
def testPythonConstruction(self):
- sp = ops.SparseTensor([[1, 2], [2, 0], [3, 4]], ["a", "b", "c"], [4, 5])
+ indices = [[1, 2], [2, 0], [3, 4]]
+ values = ["a", "b", "c"]
+ shape = [4, 5]
+ sp = ops.SparseTensor(indices, values, shape)
self.assertEqual(sp.indices.dtype, dtypes.int64)
self.assertEqual(sp.values.dtype, dtypes.string)
self.assertEqual(sp.shape.dtype, dtypes.int64)
+ with self.test_session() as sess:
+ value = sp.eval()
+ self.assertAllEqual(indices, value.indices)
+ self.assertAllEqual(values, value.values)
+ self.assertAllEqual(shape, value.shape)
+ sess_run_value = sess.run(sp)
+ self.assertAllEqual(sess_run_value.indices, value.indices)
+ self.assertAllEqual(sess_run_value.values, value.values)
+ self.assertAllEqual(sess_run_value.shape, value.shape)
class IndexedSlicesTest(test_util.TensorFlowTestCase):