aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Illia Polosukhin <ilblackdragon@gmail.com>2016-05-02 17:09:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-02 18:11:48 -0700
commitd5b0e88e0fa8bdd6d953a5cb73576e123285eca0 (patch)
treeaf416d4e3599209a3ddbd22a3691a2e91fa189f2
parentc9d658fbdae07738ed9dfb7a5a1606448743757a (diff)
Added a sparse_placeholder function to array_ops, to simplify feeding sparse data.
Usage: sp = tf.sparse_placeholder(dtype=tf.float32) ... session.run(output, feed_dict={sp: SparseTensorValue(indices=..., values=..., shape=...)}) or session.run(output, feed_dict={sp: (some_indices, some_values, some_shape)}) or for fixed shape: sp = tf.sparse_placeholder(dtype=tf.float32, shape=[2, 2]) can retrieve shape without session: tf.constant_value(sp.shape) ... session.run(output, feed_dict={sp: (some_indices, some_values)}) Change: 121332788
-rw-r--r--tensorflow/python/client/session_test.py49
-rw-r--r--tensorflow/python/ops/array_ops.py55
-rw-r--r--tensorflow/python/ops/io_ops.py5
3 files changed, 109 insertions, 0 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index c081fe0300..14f11e2eb6 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -261,6 +261,55 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.shape, shape)
+ def testFeedSparsePlaceholder(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
+ sp_indices = array_ops.identity(sp.indices)
+ sp_values = array_ops.identity(sp.values)
+ sp_shape = array_ops.identity(sp.shape)
+ sp2 = ops.SparseTensor(sp_indices, sp_values, sp_shape)
+ # Feed with tuple
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape],
+ {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue, fetch SparseTensorValue
+ sp2_out = s.run(sp2, {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(sp2_out.indices, indices)
+ self.assertAllEqual(sp2_out.values, values)
+ self.assertAllEqual(sp2_out.shape, shape)
+
+ def testFeedSparePlaceholderConstantShape(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = array_ops.sparse_placeholder(dtype=np.float32,
+ shape=shape,
+ name='placeholder1')
+ self.assertAllEqual(sp.shape.eval(session=s), shape)
+ self.assertAllEqual(tensor_util.constant_value(sp.shape), shape)
+ sp_indices = array_ops.identity(sp.indices)
+ sp_values = array_ops.identity(sp.values)
+ sp_shape = array_ops.identity(sp.shape)
+ # Feed with tuple
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape], {sp: (indices, values)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+
def testFetchIndexedSlices(self):
with session.Session() as s:
indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 83b04bcbce..4afca58d89 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -766,6 +766,61 @@ def placeholder(dtype, shape=None, name=None):
return ret
+def sparse_placeholder(dtype, shape=None, name=None):
+ """Inserts a placeholder for a sparse tensor that will be always fed.
+
+ **Important**: This sparse tensor will produce an error if evaluated.
+ Its value must be fed using the `feed_dict` optional argument to
+ `Session.run()`, `Tensor.eval()`, or `Operation.run()`.
+
+ For example:
+
+ ```python
+ x = tf.sparse_placeholder(tf.float32)
+ y = tf.sparse_reduce_sum(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y)) # ERROR: will fail because x was not fed.
+
+ indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64)
+ values = np.array([1.0, 2.0], dtype=np.float32)
+ shape = np.array([7, 9, 2], dtype=np.int64)
+ print(sess.run(y, feed_dict={
+ x: tf.SparseTensorValue(indices, values, shape)})) # Will succeed.
+ print(sess.run(y, feed_dict={
+ x: (indices, values, shape)})) # Will succeed.
+
+ sp = tf.SparseTensor(indices=indices, values=values, shape=shape)
+ sp_value = sp.eval(session)
+ print(sess.run(y, feed_dict={x: sp_value})) # Will succeed.
+ ```
+
+ Args:
+ dtype: The type of `values` elements in the tensor to be fed.
+ shape: The shape of the tensor to be fed (optional). If the shape is not
+ specified, you can feed a sparse tensor of any shape.
+ name: A name for prefixing the operations (optional).
+
+ Returns:
+ A `SparseTensor` that may be used as a handle for feeding a value, but not
+ evaluated directly.
+ """
+ if shape is None:
+ shape = placeholder(
+ dtypes.int64, name=(name + "/shape") if name is not None else None)
+ else:
+ shape = ops.convert_to_tensor(
+ shape, name=(name + "/shape") if name is not None else None)
+ return ops.SparseTensor(
+ values=placeholder(
+ dtype, name=(name + "/values") if name is not None else None),
+ indices=placeholder(
+ dtypes.int64,
+ name=(name + "/indices") if name is not None else None),
+ shape=shape
+ )
+
+
def pad(tensor, paddings, mode="CONSTANT", name=None): # pylint: disable=invalid-name
"""Pads a tensor.
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 78069295dc..3317d1c045 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -23,6 +23,11 @@ data](../../how_tos/reading_data/index.md#feeding).
@@placeholder
@@placeholder_with_default
+For feeding `SparseTensor`s which are composite type,
+there is a convenience function:
+
+@@sparse_placeholder
+
## Readers
TensorFlow provides a set of Reader classes for reading data formats.