aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-26 16:51:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-26 16:55:49 -0700
commit423c1eea0e47eb71d3bf3ec7e99e7a4a63c3e433 (patch)
treed8863742bb909a57be96416e27f4a3f243cd788e /tensorflow
parent6028c071b592330fabf107df08d880ec443d6844 (diff)
BREAKING CHANGE: Fix semantic error in how maybe_batch* handles sparse tensors.
PiperOrigin-RevId: 163276613
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/training/input.py9
-rw-r--r--tensorflow/python/training/input_test.py46
2 files changed, 54 insertions, 1 deletions
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 396ec11a02..94c5df619f 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -492,8 +492,15 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
out_tensor.set_shape([None]) # necessary when t.ndims is unknown
return out_tensor
+ def _sparse_values_to_keep(t, keep_input):
+ """Convert a per-row `keep_input` vector to a per-value one."""
+ # Get the rows of every value in the sparse Tensor.
+ row_values = array_ops.reshape(
+ t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0]
+ # The value should be kept iff the row should be kept.
+ return array_ops.gather(keep_input, row_values)
if keep_input.shape.ndims == 1:
- t = sparse_ops.sparse_retain(t, keep_input)
+ t = sparse_ops.sparse_retain(t, _sparse_values_to_keep(t, keep_input))
store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name)
elif enqueue_many:
store_f = _maybe_store_many_sparse
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index d32768f5da..3a25bfe343 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -903,6 +903,29 @@ class BatchTest(test_lib.TestCase):
[sparse], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
+ def testMaybeBatchCorrectValues(self):
+ sparse_t = sparse_tensor.SparseTensor(
+ indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
+ dense_shape=[2, 4],
+ values=[5, 4, 7, 2])
+ keep = constant_op.constant([True, False])
+ batched = inp.maybe_batch(
+ [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True)
+
+ with self.test_session():
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(coord=coord)
+
+ batched_np = batched.eval()
+
+ coord.request_stop()
+ for thread in threads:
+ thread.join()
+
+ self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
+ self.assertAllEqual([5, 4], batched_np.values)
+ self.assertAllEqual([1, 4], batched_np.dense_shape)
+
class BatchJoinTest(test_lib.TestCase):
@@ -1457,6 +1480,29 @@ class BatchJoinTest(test_lib.TestCase):
[[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True)
self.assertIs(None, batched.dense_shape.get_shape().num_elements())
+ def testMaybeBatchCorrectValues(self):
+ sparse = sparse_tensor.SparseTensor(
+ indices=[[0, 1], [0, 2], [1, 0], [1, 3]],
+ dense_shape=[2, 4],
+ values=[5, 4, 7, 2])
+ keep = constant_op.constant([True, False])
+ batched = inp.maybe_batch_join(
+ [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True)
+
+ with self.test_session():
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(coord=coord)
+
+ batched_np = batched.eval()
+
+ coord.request_stop()
+ for thread in threads:
+ thread.join()
+
+ self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices)
+ self.assertAllEqual([5, 4], batched_np.values)
+ self.assertAllEqual([1, 4], batched_np.dense_shape)
+
class ShuffleBatchTest(test_lib.TestCase):