diff options
author | Eugene Brevdo <ebrevdo@gmail.com> | 2016-03-10 15:16:40 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-03-11 11:39:30 -0800 |
commit | 64dd5b58d52d37697d5beb68e2177b966108e0a7 (patch) | |
tree | 2e531a1a9ab3494625a76d0b02977e4d0564f752 | |
parent | 025c0d21a6689f081082b1a51f8812e56b07af77 (diff) |
Add SparseTensor support to tf.batch and friends.
Change: 116914274
-rw-r--r-- | tensorflow/python/framework/ops.py | 19 | ||||
-rw-r--r-- | tensorflow/python/training/input.py | 78 | ||||
-rw-r--r-- | tensorflow/python/training/input_test.py | 128 |
3 files changed, 192 insertions, 33 deletions
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index aab591aa62..e15299e519 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -576,25 +576,25 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None, as_ref=False): """Converts the given object to a `Tensor` or an `IndexedSlices`. - If `value` is an `IndexedSlices` it is returned + If `value` is an `IndexedSlices` or `SparseTensor` it is returned unmodified. Otherwise, it is converted to a `Tensor` using `convert_to_tensor()`. Args: - value: An `IndexedSlices` or an object that can be consumed by - `convert_to_tensor()`. + value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed + by `convert_to_tensor()`. dtype: (Optional.) The required `DType` of the returned `Tensor` or `IndexedSlices`. name: (Optional.) A name to use if a new `Tensor` is created. as_ref: True if the caller wants the results as ref tensors. Returns: - An `Tensor` or an `IndexedSlices` based on `value`. + An `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`. Raises: ValueError: If `dtype` does not match the element type of `value`. """ - if isinstance(value, IndexedSlices): + if isinstance(value, (IndexedSlices, SparseTensor)): if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype): raise ValueError( "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" @@ -608,9 +608,12 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None, as_ref=False): """Converts `values` to a list of `Tensor` or `IndexedSlices` objects. + Any `IndexedSlices` or `SparseTensor` objects in `values` are returned + unmodified. + Args: - values: A list of `None`, `IndexedSlices`, or objects that can be consumed - by `convert_to_tensor()`. + values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that + can be consumed by `convert_to_tensor()`. dtype: (Optional.) The required `DType` of the returned `Tensor` `IndexedSlices`. name: (Optional.) A name prefix to used when a new `Tensor` is @@ -619,7 +622,7 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None, as_ref: True if the caller wants the results as ref tensors. Returns: - A list of `Tensor` and/or `IndexedSlices` objects. + A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects. Raises: TypeError: If no conversion function is registered for an element in diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 533053120c..55ae8adba8 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -23,6 +23,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes @@ -35,6 +37,7 @@ from tensorflow.python.ops import io_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variables from tensorflow.python.training import queue_runner @@ -228,6 +231,54 @@ def _flatten(tensor_list_list): return [tensor for tensor_list in tensor_list_list for tensor in tensor_list] +def _serialize_sparse_tensors(tensor_list, enqueue_many): + """Serialize SparseTensors for feeding into batch, etc.""" + is_sparse_list = [isinstance(t, ops.SparseTensor) for t in tensor_list] + sparse_dtypes_list = [ + t.dtype if isinstance(t, ops.SparseTensor) else None + for t in tensor_list] + + def _maybe_serialize(t, is_sparse): + if not is_sparse: + return t + return (sparse_ops.serialize_many_sparse(t) if enqueue_many + else sparse_ops.serialize_sparse(t)) + serialized_list = [ + _maybe_serialize(t, is_sparse) + for (t, is_sparse) in zip(tensor_list, is_sparse_list)] + return serialized_list, is_sparse_list, sparse_dtypes_list + + +def _serialize_sparse_tensors_join(tensor_list_list, enqueue_many): + """Serialize SparseTensors for feeding into batch_join, etc.""" + (s0, is_sparse_list, sparse_dtypes_list) = _serialize_sparse_tensors( + tensor_list_list[0], enqueue_many) + serialized_list_list = [s0] + for tensor_list in tensor_list_list[1:]: + (s, is_sparse_candidate, sparse_dtypes_candidate) = ( + _serialize_sparse_tensors(tensor_list, enqueue_many)) + if is_sparse_candidate != is_sparse_list: + raise ValueError("Inconsistent SparseTensors list: %s vs. %s" + % (tensor_list_list[0], tensor_list)) + if sparse_dtypes_candidate != sparse_dtypes_list: + raise ValueError("Inconsistent SparseTensor dtypes in list: %s vs. %s" + % (tensor_list_list[0], tensor_list)) + serialized_list_list.append(s) + return (serialized_list_list, is_sparse_list, sparse_dtypes_list) + + +def _deserialize_sparse_tensors(serialized_list, is_sparse_list, sparse_dtypes): + """Deserialize SparseTensors after dequeue in batch, batch_join, etc.""" + received_sequence = isinstance(serialized_list, collections.Sequence) + if not received_sequence: + serialized_list = (serialized_list,) + tensors = [sparse_ops.deserialize_many_sparse(s, sparse_dtype) if is_sparse + else s + for (s, is_sparse, sparse_dtype) + in zip(serialized_list, is_sparse_list, sparse_dtypes)] + return tensors if received_sequence else tensors[0] + + def _validate(tensor_list): tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list) if not tensor_list: @@ -343,6 +394,8 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32, """ with ops.op_scope(tensor_list, name, "batch") as name: tensor_list = _validate(tensor_list) + tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors( + tensor_list, enqueue_many) types = _dtypes([tensor_list]) shapes = _shapes([tensor_list], shapes, enqueue_many) # TODO(josh11b,mrry): Switch to BatchQueue once it is written. @@ -352,7 +405,10 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32, logging_ops.scalar_summary( "queue/%s/fraction_of_%d_full" % (queue.name, capacity), math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity)) - return queue.dequeue_many(batch_size, name=name) + + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes) + return dequeued # TODO(josh11b): Add a thread_multiplier or num_threads (that has to be @@ -422,6 +478,8 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False, """ with ops.op_scope(_flatten(tensor_list_list), name, "batch_join") as name: tensor_list_list = _validate_join(tensor_list_list) + tensor_list_list, is_sparse, sparse_dtypes = ( + _serialize_sparse_tensors_join(tensor_list_list, enqueue_many)) types = _dtypes(tensor_list_list) shapes = _shapes(tensor_list_list, shapes, enqueue_many) # TODO(josh11b,mrry): Switch to BatchQueue once it is written. @@ -431,7 +489,10 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False, logging_ops.scalar_summary( "queue/%s/fraction_of_%d_full" % (queue.name, capacity), math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity)) - return queue.dequeue_many(batch_size, name=name) + + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes) + return dequeued def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, @@ -506,6 +567,8 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, """ with ops.op_scope(tensor_list, name, "shuffle_batch") as name: tensor_list = _validate(tensor_list) + tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors( + tensor_list, enqueue_many) types = _dtypes([tensor_list]) shapes = _shapes([tensor_list], shapes, enqueue_many) queue = data_flow_ops.RandomShuffleQueue( @@ -522,7 +585,9 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue, (name, min_after_dequeue, capacity - min_after_dequeue)) logging_ops.scalar_summary(summary_name, full) - return queue.dequeue_many(batch_size, name=name) + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes) + return dequeued def shuffle_batch_join(tensor_list_list, batch_size, capacity, @@ -587,6 +652,8 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity, with ops.op_scope( _flatten(tensor_list_list), name, "shuffle_batch_join") as name: tensor_list_list = _validate_join(tensor_list_list) + tensor_list_list, is_sparse, sparse_dtypes = ( + _serialize_sparse_tensors_join(tensor_list_list, enqueue_many)) types = _dtypes(tensor_list_list) shapes = _shapes(tensor_list_list, shapes, enqueue_many) queue = data_flow_ops.RandomShuffleQueue( @@ -602,4 +669,7 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity, "queue/%sfraction_over_%d_of_%d_full" % (name, min_after_dequeue, capacity - min_after_dequeue)) logging_ops.scalar_summary(summary_name, full) - return queue.dequeue_many(batch_size, name=name) + + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes) + return dequeued diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index 7057341fa7..14c31442dd 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -318,7 +318,12 @@ class BatchTest(tf.test.TestCase): zero64 = tf.constant(0, dtype=tf.int64) examples = tf.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) - batched = tf.train.batch([counter, "string"], batch_size=batch_size) + sparse_counter = tf.SparseTensor( + indices=tf.reshape(tf.pack([zero64, zero64 + 1]), [2, 1]), + values=tf.cast(tf.pack([counter, -counter]), tf.float32), + shape=[2]) + batched = tf.train.batch( + [counter, sparse_counter, "string"], batch_size=batch_size) tf.initialize_all_variables().run() threads = tf.train.start_queue_runners() @@ -326,7 +331,16 @@ class BatchTest(tf.test.TestCase): results = sess.run(batched) self.assertAllEqual(results[0], np.arange(i * batch_size, (i + 1) * batch_size)) - self.assertAllEqual(results[1], [b"string"] * batch_size) + self.assertAllEqual( + results[1].indices, + np.vstack((np.arange(2 * batch_size) // 2, # 0, 0, 1, 1, ... + [0, 1] * batch_size)).T) + # [x, -x, x+1, -(x+1), ...] + expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2 + expected *= ([1, -1] * batch_size) # mult by [1, -1, 1, -1, ...] + self.assertAllEqual(results[1].values, expected) + self.assertAllEqual(results[1].shape, [batch_size, 2]) + self.assertAllEqual(results[2], [b"string"] * batch_size) # Reached the limit. with self.assertRaises(tf.errors.OutOfRangeError): @@ -341,7 +355,12 @@ class BatchTest(tf.test.TestCase): zero64 = tf.constant(0, dtype=tf.int64) examples = tf.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) - pre_batched = tf.train.batch([counter, "string"], batch_size=2) + sparse_counter = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(counter, tf.float32)]), + shape=[1]) + pre_batched = tf.train.batch( + [counter, sparse_counter, "string"], batch_size=2) batched = tf.train.batch(pre_batched, enqueue_many=True, batch_size=batch_size) tf.initialize_all_variables().run() @@ -351,7 +370,13 @@ class BatchTest(tf.test.TestCase): results = sess.run(batched) self.assertAllEqual(results[0], np.arange(i * batch_size, (i + 1) * batch_size)) - self.assertAllEqual(results[1], [b"string"] * batch_size) + self.assertAllEqual( + results[1].indices, + np.vstack((np.arange(batch_size), np.zeros(batch_size))).T) + self.assertAllEqual( + results[1].values, np.arange(i * batch_size, (i + 1) * batch_size)) + self.assertAllEqual(results[1].shape, [batch_size, 1]) + self.assertAllEqual(results[2], [b"string"] * batch_size) # Reached the limit. with self.assertRaises(tf.errors.OutOfRangeError): @@ -364,10 +389,16 @@ class BatchTest(tf.test.TestCase): batch_size = 10 num_batches = 3 zero64 = tf.constant(0, dtype=tf.int64) + examples = tf.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) - batched = tf.train.batch([counter, "string"], batch_size=batch_size, - num_threads=4) + sparse_counter = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(counter, tf.float32)]), + shape=[1]) + batched = tf.train.batch( + [counter, sparse_counter, "string"], + batch_size=batch_size, num_threads=4) tf.initialize_all_variables().run() threads = tf.train.start_queue_runners() @@ -376,8 +407,13 @@ class BatchTest(tf.test.TestCase): results = sess.run(batched) tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) + self.assertAllEqual(results[0], results[1].values) + self.assertAllEqual( + results[1].indices, + np.vstack((np.arange(batch_size), np.zeros(batch_size))).T) + self.assertAllEqual(results[1].shape, [batch_size, 1]) all_counts.extend(results[0]) - self.assertAllEqual(results[1], [b"string"] * batch_size) + self.assertAllEqual(results[2], [b"string"] * batch_size) self.assertItemsEqual(all_counts, range(num_batches * batch_size)) # Reached the limit. @@ -411,16 +447,26 @@ class BatchJoinTest(tf.test.TestCase): zero64 = tf.constant(0, dtype=tf.int64) examples = tf.Variable(zero64) counter = examples.count_up_to(num_a) + sparse_counter = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(counter, tf.float32)]), + shape=[1]) # The second generates (99, "b") 90 times and then stops. num_b = 90 ninety_nine = tf.train.limit_epochs( tf.constant(99, dtype=tf.int64), num_b) + sparse_ninety_nine = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(ninety_nine, tf.float32)]), + shape=[1]) # These get joined together and grouped into batches of 5. batch_size = 5 - batched = tf.train.batch_join([[counter, "a"], [ninety_nine, "b"]], - batch_size=batch_size) + batched = tf.train.batch_join( + [[counter, sparse_counter, "a"], + [ninety_nine, sparse_ninety_nine, "b"]], + batch_size=batch_size) tf.initialize_all_variables().run() threads = tf.train.start_queue_runners() @@ -433,9 +479,14 @@ class BatchJoinTest(tf.test.TestCase): results = sess.run(batched) tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) - self.assertEqual(len(results[1]), batch_size) - which_a = [i for i, s in enumerate(results[1]) if s == b"a"] - which_b = [i for i, s in enumerate(results[1]) if s == b"b"] + self.assertEqual(len(results[2]), batch_size) + self.assertAllEqual(results[0], results[1].values) + self.assertAllEqual( + results[1].indices, + np.vstack((np.arange(batch_size), np.zeros(batch_size))).T) + self.assertAllEqual(results[1].shape, [batch_size, 1]) + which_a = [i for i, s in enumerate(results[2]) if s == b"a"] + which_b = [i for i, s in enumerate(results[2]) if s == b"b"] self.assertEqual(len(which_a) + len(which_b), batch_size) if len(which_a) > 0 and len(which_b) > 0: saw_both += 1 all_a.extend([results[0][i] for i in which_a]) @@ -481,8 +532,13 @@ class ShuffleBatchTest(tf.test.TestCase): zero64 = tf.constant(0, dtype=tf.int64) examples = tf.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) + sparse_counter = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(counter, tf.float32)]), + shape=[1]) batched = tf.train.shuffle_batch( - [counter, "string"], batch_size=batch_size, capacity=32, + [counter, sparse_counter, "string"], + batch_size=batch_size, capacity=32, min_after_dequeue=16, seed=141421) tf.initialize_all_variables().run() threads = tf.train.start_queue_runners() @@ -492,7 +548,12 @@ class ShuffleBatchTest(tf.test.TestCase): results = sess.run(batched) self.assertEqual(len(results[0]), batch_size) all_counts.extend(results[0]) - self.assertAllEqual(results[1], [b"string"] * batch_size) + self.assertAllEqual( + results[1].indices, + np.vstack((np.arange(batch_size), np.zeros(batch_size))).T) + self.assertAllEqual(results[0], results[1].values) + self.assertAllEqual(results[1].shape, [batch_size, 1]) + self.assertAllEqual(results[2], [b"string"] * batch_size) # Results scrambled, but include all the expected numbers. deltas = [all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)] @@ -512,8 +573,13 @@ class ShuffleBatchTest(tf.test.TestCase): zero64 = tf.constant(0, dtype=tf.int64) examples = tf.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) + sparse_counter = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(counter, tf.float32)]), + shape=[1]) batched = tf.train.shuffle_batch( - [counter, "string"], batch_size=batch_size, capacity=32, + [counter, sparse_counter, "string"], + batch_size=batch_size, capacity=32, min_after_dequeue=16, seed=173205, num_threads=4) tf.initialize_all_variables().run() threads = tf.train.start_queue_runners() @@ -524,7 +590,12 @@ class ShuffleBatchTest(tf.test.TestCase): tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) all_counts.extend(results[0]) - self.assertAllEqual(results[1], [b"string"] * batch_size) + self.assertAllEqual( + results[1].indices, + np.vstack((np.arange(batch_size), np.zeros(batch_size))).T) + self.assertAllEqual(results[0], results[1].values) + self.assertAllEqual(results[1].shape, [batch_size, 1]) + self.assertAllEqual(results[2], [b"string"] * batch_size) # Results scrambled, but include all the expected numbers. deltas = [all_counts[i + 1] - all_counts[i] for i in range(len(all_counts) - 1)] @@ -564,17 +635,27 @@ class ShuffleBatchJoinTest(tf.test.TestCase): zero64 = tf.constant(0, dtype=tf.int64) examples = tf.Variable(zero64) counter = examples.count_up_to(num_a) + sparse_counter = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(counter, tf.float32)]), + shape=[1]) # The second generates (99, "b") 35 times and then stops. num_b = 35 ninety_nine = tf.train.limit_epochs( tf.constant(99, dtype=tf.int64), num_b) + sparse_ninety_nine = tf.SparseTensor( + indices=tf.reshape(zero64, [1, 1]), + values=tf.pack([tf.cast(ninety_nine, tf.float32)]), + shape=[1]) # These get joined together and grouped into batches of 5. batch_size = 5 batched = tf.train.shuffle_batch_join( - [[counter, "a"], [ninety_nine, "b"]], batch_size=batch_size, - capacity=32, min_after_dequeue=16, seed=223607) + [[counter, sparse_counter, "a"], + [ninety_nine, sparse_ninety_nine, "b"]], + batch_size=batch_size, capacity=32, + min_after_dequeue=16, seed=223607) tf.initialize_all_variables().run() threads = tf.train.start_queue_runners() @@ -588,9 +669,14 @@ class ShuffleBatchJoinTest(tf.test.TestCase): results = sess.run(batched) tf.logging.info("Batch %d: %s", i, results[0]) self.assertEqual(len(results[0]), batch_size) - self.assertEqual(len(results[1]), batch_size) - which_a = [i for i, s in enumerate(results[1]) if s == b"a"] - which_b = [i for i, s in enumerate(results[1]) if s == b"b"] + self.assertEqual(len(results[2]), batch_size) + self.assertAllEqual(results[0], results[1].values) + self.assertAllEqual( + results[1].indices, + np.vstack((np.arange(batch_size), np.zeros(batch_size))).T) + self.assertAllEqual(results[1].shape, [batch_size, 1]) + which_a = [i for i, s in enumerate(results[2]) if s == b"a"] + which_b = [i for i, s in enumerate(results[2]) if s == b"b"] self.assertEqual(len(which_a) + len(which_b), batch_size) if len(which_a) > 0 and len(which_b) > 0: saw_both += 1 all_a.extend([results[0][i] for i in which_a]) |