From 838c47b173c733f770e2ac799f1bc1faa0a3128e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Apr 2016 15:02:45 -0800 Subject: Add SparseTensor.eval. Change: 119212533 --- tensorflow/python/framework/ops.py | 31 ++++++++++++++++++++++++++++++- tensorflow/python/framework/ops_test.py | 14 +++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f3b4455257..aee976c71e 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import registry from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions -from tensorflow.python.util import compat from tensorflow.python.platform import logging +from tensorflow.python.util import compat def _convert_stack(stack): @@ -955,6 +955,32 @@ class SparseTensor(object): return "SparseTensor(indices=%s, values=%s, shape=%s)" % ( self._indices, self._values, self._shape) + def eval(self, feed_dict=None, session=None): + """Evaluates this sparse tensor in a `Session`. + + Calling this method will execute all preceding operations that + produce the inputs needed for the operation that produces this + tensor. + + *N.B.* Before invoking `SparseTensor.eval()`, its graph must have been + launched in a session, and either a default session must be + available, or `session` must be specified explicitly. + + Args: + feed_dict: A dictionary that maps `Tensor` objects to feed values. + See [`Session.run()`](../../api_docs/python/client.md#Session.run) for a + description of the valid feed values. + session: (Optional.) The `Session` to be used to evaluate this sparse + tensor. If none, the default session will be used. + + Returns: + A `SparseTensorValue` object. + + """ + indices, values, shape = _eval_using_default_session( + [self.indices, self.values, self.shape], feed_dict, self.graph, session) + return SparseTensorValue(indices, values, shape) + SparseTensorValue = collections.namedtuple("SparseTensorValue", ["indices", "values", "shape"]) @@ -2025,6 +2051,9 @@ class Graph(object): grad_function_name: If not None, this specifies the name of a function that shall be used as the gradient function of the function being added. + + Raises: + ValueError: if another function is defined with the same name. """ previous_def = self._functions.get(function_def.signature.name, None) if previous_def: diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 65e92872cf..0df619c16e 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -58,10 +58,22 @@ class TensorTest(test_util.TensorFlowTestCase): class SparseTensorTest(test_util.TensorFlowTestCase): def testPythonConstruction(self): - sp = ops.SparseTensor([[1, 2], [2, 0], [3, 4]], ["a", "b", "c"], [4, 5]) + indices = [[1, 2], [2, 0], [3, 4]] + values = ["a", "b", "c"] + shape = [4, 5] + sp = ops.SparseTensor(indices, values, shape) self.assertEqual(sp.indices.dtype, dtypes.int64) self.assertEqual(sp.values.dtype, dtypes.string) self.assertEqual(sp.shape.dtype, dtypes.int64) + with self.test_session() as sess: + value = sp.eval() + self.assertAllEqual(indices, value.indices) + self.assertAllEqual(values, value.values) + self.assertAllEqual(shape, value.shape) + sess_run_value = sess.run(sp) + self.assertAllEqual(sess_run_value.indices, value.indices) + self.assertAllEqual(sess_run_value.values, value.values) + self.assertAllEqual(sess_run_value.shape, value.shape) class IndexedSlicesTest(test_util.TensorFlowTestCase): -- cgit v1.2.3