diff options
author | 2017-07-26 16:51:45 -0700 | |
---|---|---|
committer | 2017-07-26 16:55:49 -0700 | |
commit | 423c1eea0e47eb71d3bf3ec7e99e7a4a63c3e433 (patch) | |
tree | d8863742bb909a57be96416e27f4a3f243cd788e /tensorflow | |
parent | 6028c071b592330fabf107df08d880ec443d6844 (diff) |
BREAKING CHANGE: Fix semantic error in how maybe_batch* handles sparse tensors.
PiperOrigin-RevId: 163276613
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/training/input.py | 9 | ||||
-rw-r--r-- | tensorflow/python/training/input_test.py | 46 |
2 files changed, 54 insertions, 1 deletion
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 396ec11a02..94c5df619f 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -492,8 +492,15 @@ def _store_sparse_tensors(tensor_list, enqueue_many, keep_input, lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64)) out_tensor.set_shape([None]) # necessary when t.ndims is unknown return out_tensor + def _sparse_values_to_keep(t, keep_input): + """Convert a per-row `keep_input` vector to a per-value one.""" + # Get the rows of every value in the sparse Tensor. + row_values = array_ops.reshape( + t.indices, [array_ops.shape(t.indices)[0], -1])[:, 0] + # The value should be kept iff the row should be kept. + return array_ops.gather(keep_input, row_values) if keep_input.shape.ndims == 1: - t = sparse_ops.sparse_retain(t, keep_input) + t = sparse_ops.sparse_retain(t, _sparse_values_to_keep(t, keep_input)) store_f = lambda t, name, _: _store_many_sparse(t, shared_name=name) elif enqueue_many: store_f = _maybe_store_many_sparse diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index d32768f5da..3a25bfe343 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -903,6 +903,29 @@ class BatchTest(test_lib.TestCase): [sparse], keep_input=[True, False], batch_size=2, enqueue_many=True) self.assertIs(None, batched.dense_shape.get_shape().num_elements()) + def testMaybeBatchCorrectValues(self): + sparse_t = sparse_tensor.SparseTensor( + indices=[[0, 1], [0, 2], [1, 0], [1, 3]], + dense_shape=[2, 4], + values=[5, 4, 7, 2]) + keep = constant_op.constant([True, False]) + batched = inp.maybe_batch( + [sparse_t], keep_input=keep, batch_size=1, enqueue_many=True) + + with self.test_session(): + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(coord=coord) + + batched_np = batched.eval() + + coord.request_stop() + 
for thread in threads: + thread.join() + + self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices) + self.assertAllEqual([5, 4], batched_np.values) + self.assertAllEqual([1, 4], batched_np.dense_shape) + class BatchJoinTest(test_lib.TestCase): @@ -1457,6 +1480,29 @@ class BatchJoinTest(test_lib.TestCase): [[sparse]], keep_input=[True, False], batch_size=2, enqueue_many=True) self.assertIs(None, batched.dense_shape.get_shape().num_elements()) + def testMaybeBatchCorrectValues(self): + sparse = sparse_tensor.SparseTensor( + indices=[[0, 1], [0, 2], [1, 0], [1, 3]], + dense_shape=[2, 4], + values=[5, 4, 7, 2]) + keep = constant_op.constant([True, False]) + batched = inp.maybe_batch_join( + [[sparse]], keep_input=keep, batch_size=1, enqueue_many=True) + + with self.test_session(): + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(coord=coord) + + batched_np = batched.eval() + + coord.request_stop() + for thread in threads: + thread.join() + + self.assertAllEqual([[0, 1], [0, 2]], batched_np.indices) + self.assertAllEqual([5, 4], batched_np.values) + self.assertAllEqual([1, 4], batched_np.dense_shape) + class ShuffleBatchTest(test_lib.TestCase): |