 tensorflow/python/framework/ops.py       |  19
 tensorflow/python/training/input.py      |  78
 tensorflow/python/training/input_test.py | 128
 3 files changed, 192 insertions(+), 33 deletions(-)
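(The patch below adds `SparseTensor` support to the `tf.train.batch*` input pipeline: sparse inputs are serialized to dense string tensors before enqueue and deserialized back after dequeue.)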
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index aab591aa62..e15299e519 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -576,25 +576,25 @@ def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None,
as_ref=False):
"""Converts the given object to a `Tensor` or an `IndexedSlices`.
- If `value` is an `IndexedSlices` it is returned
+ If `value` is an `IndexedSlices` or `SparseTensor` it is returned
unmodified. Otherwise, it is converted to a `Tensor` using
`convert_to_tensor()`.
Args:
- value: An `IndexedSlices` or an object that can be consumed by
- `convert_to_tensor()`.
+ value: An `IndexedSlices`, `SparseTensor`, or an object that can be consumed
+ by `convert_to_tensor()`.
dtype: (Optional.) The required `DType` of the returned `Tensor` or
`IndexedSlices`.
name: (Optional.) A name to use if a new `Tensor` is created.
as_ref: True if the caller wants the results as ref tensors.
Returns:
- An `Tensor` or an `IndexedSlices` based on `value`.
+ A `Tensor`, `IndexedSlices`, or `SparseTensor` based on `value`.
Raises:
ValueError: If `dtype` does not match the element type of `value`.
"""
- if isinstance(value, IndexedSlices):
+ if isinstance(value, (IndexedSlices, SparseTensor)):
if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value.dtype):
raise ValueError(
"Tensor conversion requested dtype %s for Tensor with dtype %s: %r"
@@ -608,9 +608,12 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
as_ref=False):
"""Converts `values` to a list of `Tensor` or `IndexedSlices` objects.
+ Any `IndexedSlices` or `SparseTensor` objects in `values` are returned
+ unmodified.
+
Args:
- values: A list of `None`, `IndexedSlices`, or objects that can be consumed
- by `convert_to_tensor()`.
+ values: A list of `None`, `IndexedSlices`, `SparseTensor`, or objects that
+ can be consumed by `convert_to_tensor()`.
dtype: (Optional.) The required `DType` of the returned `Tensor`
`IndexedSlices`.
name: (Optional.) A name prefix to used when a new `Tensor` is
@@ -619,7 +622,7 @@ def convert_n_to_tensor_or_indexed_slices(values, dtype=None, name=None,
as_ref: True if the caller wants the results as ref tensors.
Returns:
- A list of `Tensor` and/or `IndexedSlices` objects.
+ A list of `Tensor`, `IndexedSlices`, and/or `SparseTensor` objects.
Raises:
TypeError: If no conversion function is registered for an element in
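
For illustration, a minimal sketch of the widened pass-through (not part of the patch; it assumes the function is exported at the top level as `tf.convert_to_tensor_or_indexed_slices`, as in contemporary releases):

    import tensorflow as tf

    sp = tf.SparseTensor(indices=tf.constant([[0, 0]], dtype=tf.int64),
                         values=[1.0], shape=[2, 2])
    out = tf.convert_to_tensor_or_indexed_slices(sp)
    assert out is sp  # SparseTensor is now returned unmodified
    # A dtype mismatch still raises ValueError, e.g.:
    # tf.convert_to_tensor_or_indexed_slices(sp, dtype=tf.int32)
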
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 533053120c..55ae8adba8 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -23,6 +23,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
@@ -35,6 +37,7 @@ from tensorflow.python.ops import io_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import queue_runner
@@ -228,6 +231,54 @@ def _flatten(tensor_list_list):
return [tensor for tensor_list in tensor_list_list for tensor in tensor_list]
+def _serialize_sparse_tensors(tensor_list, enqueue_many):
+ """Serialize SparseTensors for feeding into batch, etc."""
+ is_sparse_list = [isinstance(t, ops.SparseTensor) for t in tensor_list]
+ sparse_dtypes_list = [
+ t.dtype if isinstance(t, ops.SparseTensor) else None
+ for t in tensor_list]
+
+ def _maybe_serialize(t, is_sparse):
+ if not is_sparse:
+ return t
+ return (sparse_ops.serialize_many_sparse(t) if enqueue_many
+ else sparse_ops.serialize_sparse(t))
+ serialized_list = [
+ _maybe_serialize(t, is_sparse)
+ for (t, is_sparse) in zip(tensor_list, is_sparse_list)]
+ return serialized_list, is_sparse_list, sparse_dtypes_list
+
+
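The serialized form is what lets plain queues carry sparse data: each `SparseTensor` becomes a dense string tensor. A sketch of the shapes involved (`sp` is a hypothetical `SparseTensor`; `sparse_ops` as imported above):

    serialized = sparse_ops.serialize_sparse(sp)
    # serialized is a dense 1-D string tensor of shape [3] (serialized
    # indices, values, and shape), so an ordinary FIFOQueue can hold it.
    # With enqueue_many=True, serialize_many_sparse instead yields shape
    # [N, 3], one row per example along sp's first dimension.
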
+def _serialize_sparse_tensors_join(tensor_list_list, enqueue_many):
+ """Serialize SparseTensors for feeding into batch_join, etc."""
+ (s0, is_sparse_list, sparse_dtypes_list) = _serialize_sparse_tensors(
+ tensor_list_list[0], enqueue_many)
+ serialized_list_list = [s0]
+ for tensor_list in tensor_list_list[1:]:
+ (s, is_sparse_candidate, sparse_dtypes_candidate) = (
+ _serialize_sparse_tensors(tensor_list, enqueue_many))
+ if is_sparse_candidate != is_sparse_list:
+ raise ValueError("Inconsistent SparseTensors list: %s vs. %s"
+ % (tensor_list_list[0], tensor_list))
+ if sparse_dtypes_candidate != sparse_dtypes_list:
+ raise ValueError("Inconsistent SparseTensor dtypes in list: %s vs. %s"
+ % (tensor_list_list[0], tensor_list))
+ serialized_list_list.append(s)
+ return (serialized_list_list, is_sparse_list, sparse_dtypes_list)
+
+
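To make the consistency check concrete: every inner list must carry `SparseTensor`s at the same positions with the same dtypes (names below are hypothetical):

    # Consistent: the sparse value sits at index 1 in both lists.
    _serialize_sparse_tensors_join(
        [[dense_a, sparse_a], [dense_b, sparse_b]], enqueue_many=False)
    # Inconsistent positions (or mismatched sparse dtypes) raise ValueError:
    # _serialize_sparse_tensors_join(
    #     [[dense_a, sparse_a], [sparse_b, dense_b]], enqueue_many=False)
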
+def _deserialize_sparse_tensors(serialized_list, is_sparse_list, sparse_dtypes):
+ """Deserialize SparseTensors after dequeue in batch, batch_join, etc."""
+ received_sequence = isinstance(serialized_list, collections.Sequence)
+ if not received_sequence:
+ serialized_list = (serialized_list,)
+ tensors = [sparse_ops.deserialize_many_sparse(s, sparse_dtype) if is_sparse
+ else s
+ for (s, is_sparse, sparse_dtype)
+ in zip(serialized_list, is_sparse_list, sparse_dtypes)]
+ return tensors if received_sequence else tensors[0]
+
+
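After `dequeue_many`, each serialized sparse component comes back as a `[batch_size, 3]` string tensor, and `deserialize_many_sparse` reassembles it into a single `SparseTensor` whose first dimension is the batch. A sketch of the round trip (`some_sparse_tensor` and `dequeued_batch` are hypothetical):

    serialized, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
        [some_sparse_tensor], enqueue_many=False)
    # ... enqueue serialized[0]; dequeue_many(batch_size) -> shape [batch, 3]
    restored = _deserialize_sparse_tensors(
        dequeued_batch, is_sparse, sparse_dtypes)
    # restored[0] is a SparseTensor whose first dense dimension is the batch.
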
def _validate(tensor_list):
tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
if not tensor_list:
@@ -343,6 +394,8 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
"""
with ops.op_scope(tensor_list, name, "batch") as name:
tensor_list = _validate(tensor_list)
+ tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
+ tensor_list, enqueue_many)
types = _dtypes([tensor_list])
shapes = _shapes([tensor_list], shapes, enqueue_many)
# TODO(josh11b,mrry): Switch to BatchQueue once it is written.
@@ -352,7 +405,10 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
logging_ops.scalar_summary(
"queue/%s/fraction_of_%d_full" % (queue.name, capacity),
math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
- return queue.dequeue_many(batch_size, name=name)
+
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ return dequeued
# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be
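
Taken together, `batch` now accepts a mix of dense and sparse inputs. A usage sketch (not from the patch; `counter` stands in for an int64 scalar tensor, mirroring the tests below):

    sp = tf.SparseTensor(
        indices=tf.constant([[0]], dtype=tf.int64),
        values=tf.pack([tf.cast(counter, tf.float32)]),
        shape=[1])
    dense_batch, sparse_batch = tf.train.batch([counter, sp], batch_size=32)
    # sparse_batch is a SparseTensor with dense shape [32, 1]; the
    # serialize/enqueue/dequeue/deserialize round trip is hidden from callers.
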
@@ -422,6 +478,8 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
"""
with ops.op_scope(_flatten(tensor_list_list), name, "batch_join") as name:
tensor_list_list = _validate_join(tensor_list_list)
+ tensor_list_list, is_sparse, sparse_dtypes = (
+ _serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
types = _dtypes(tensor_list_list)
shapes = _shapes(tensor_list_list, shapes, enqueue_many)
# TODO(josh11b,mrry): Switch to BatchQueue once it is written.
@@ -431,7 +489,10 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
logging_ops.scalar_summary(
"queue/%s/fraction_of_%d_full" % (queue.name, capacity),
math_ops.cast(queue.size(), dtypes.float32) * (1. / capacity))
- return queue.dequeue_many(batch_size, name=name)
+
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ return dequeued
def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
@@ -506,6 +567,8 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
"""
with ops.op_scope(tensor_list, name, "shuffle_batch") as name:
tensor_list = _validate(tensor_list)
+ tensor_list, is_sparse, sparse_dtypes = _serialize_sparse_tensors(
+ tensor_list, enqueue_many)
types = _dtypes([tensor_list])
shapes = _shapes([tensor_list], shapes, enqueue_many)
queue = data_flow_ops.RandomShuffleQueue(
@@ -522,7 +585,9 @@ def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
(name, min_after_dequeue, capacity - min_after_dequeue))
logging_ops.scalar_summary(summary_name, full)
- return queue.dequeue_many(batch_size, name=name)
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ return dequeued
def shuffle_batch_join(tensor_list_list, batch_size, capacity,
@@ -587,6 +652,8 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
with ops.op_scope(
_flatten(tensor_list_list), name, "shuffle_batch_join") as name:
tensor_list_list = _validate_join(tensor_list_list)
+ tensor_list_list, is_sparse, sparse_dtypes = (
+ _serialize_sparse_tensors_join(tensor_list_list, enqueue_many))
types = _dtypes(tensor_list_list)
shapes = _shapes(tensor_list_list, shapes, enqueue_many)
queue = data_flow_ops.RandomShuffleQueue(
@@ -602,4 +669,7 @@ def shuffle_batch_join(tensor_list_list, batch_size, capacity,
"queue/%sfraction_over_%d_of_%d_full" %
(name, min_after_dequeue, capacity - min_after_dequeue))
logging_ops.scalar_summary(summary_name, full)
- return queue.dequeue_many(batch_size, name=name)
+
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _deserialize_sparse_tensors(dequeued, is_sparse, sparse_dtypes)
+ return dequeued
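
`batch_join`, `shuffle_batch`, and `shuffle_batch_join` gain the same support through the identical serialize-before-enqueue, deserialize-after-dequeue pattern, e.g. (sketch, reusing the hypothetical `counter` and `sp` from above):

    dense_b, sparse_b = tf.train.shuffle_batch(
        [counter, sp], batch_size=32, capacity=320, min_after_dequeue=64)
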
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 7057341fa7..14c31442dd 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -318,7 +318,12 @@ class BatchTest(tf.test.TestCase):
zero64 = tf.constant(0, dtype=tf.int64)
examples = tf.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
- batched = tf.train.batch([counter, "string"], batch_size=batch_size)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.reshape(tf.pack([zero64, zero64 + 1]), [2, 1]),
+ values=tf.cast(tf.pack([counter, -counter]), tf.float32),
+ shape=[2])
+ batched = tf.train.batch(
+ [counter, sparse_counter, "string"], batch_size=batch_size)
tf.initialize_all_variables().run()
threads = tf.train.start_queue_runners()
@@ -326,7 +331,16 @@ class BatchTest(tf.test.TestCase):
results = sess.run(batched)
self.assertAllEqual(results[0], np.arange(i * batch_size,
(i + 1) * batch_size))
- self.assertAllEqual(results[1], [b"string"] * batch_size)
+ self.assertAllEqual(
+ results[1].indices,
+ np.vstack((np.arange(2 * batch_size) // 2, # 0, 0, 1, 1, ...
+ [0, 1] * batch_size)).T)
+ # [x, -x, x+1, -(x+1), ...]
+ expected = np.arange(2 * i * batch_size, 2 * (i + 1) * batch_size) // 2
+ expected *= ([1, -1] * batch_size) # mult by [1, -1, 1, -1, ...]
+ self.assertAllEqual(results[1].values, expected)
+ self.assertAllEqual(results[1].shape, [batch_size, 2])
+ self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the limit.
with self.assertRaises(tf.errors.OutOfRangeError):
@@ -341,7 +355,12 @@ class BatchTest(tf.test.TestCase):
zero64 = tf.constant(0, dtype=tf.int64)
examples = tf.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
- pre_batched = tf.train.batch([counter, "string"], batch_size=2)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(counter, tf.float32)]),
+ shape=[1])
+ pre_batched = tf.train.batch(
+ [counter, sparse_counter, "string"], batch_size=2)
batched = tf.train.batch(pre_batched, enqueue_many=True,
batch_size=batch_size)
tf.initialize_all_variables().run()
@@ -351,7 +370,13 @@ class BatchTest(tf.test.TestCase):
results = sess.run(batched)
self.assertAllEqual(results[0], np.arange(i * batch_size,
(i + 1) * batch_size))
- self.assertAllEqual(results[1], [b"string"] * batch_size)
+ self.assertAllEqual(
+ results[1].indices,
+ np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+ self.assertAllEqual(
+ results[1].values, np.arange(i * batch_size, (i + 1) * batch_size))
+ self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[2], [b"string"] * batch_size)
# Reached the limit.
with self.assertRaises(tf.errors.OutOfRangeError):
@@ -364,10 +389,16 @@ class BatchTest(tf.test.TestCase):
batch_size = 10
num_batches = 3
zero64 = tf.constant(0, dtype=tf.int64)
+
examples = tf.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
- batched = tf.train.batch([counter, "string"], batch_size=batch_size,
- num_threads=4)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(counter, tf.float32)]),
+ shape=[1])
+ batched = tf.train.batch(
+ [counter, sparse_counter, "string"],
+ batch_size=batch_size, num_threads=4)
tf.initialize_all_variables().run()
threads = tf.train.start_queue_runners()
@@ -376,8 +407,13 @@ class BatchTest(tf.test.TestCase):
results = sess.run(batched)
tf.logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
+ self.assertAllEqual(results[0], results[1].values)
+ self.assertAllEqual(
+ results[1].indices,
+ np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+ self.assertAllEqual(results[1].shape, [batch_size, 1])
all_counts.extend(results[0])
- self.assertAllEqual(results[1], [b"string"] * batch_size)
+ self.assertAllEqual(results[2], [b"string"] * batch_size)
self.assertItemsEqual(all_counts, range(num_batches * batch_size))
# Reached the limit.
@@ -411,16 +447,26 @@ class BatchJoinTest(tf.test.TestCase):
zero64 = tf.constant(0, dtype=tf.int64)
examples = tf.Variable(zero64)
counter = examples.count_up_to(num_a)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(counter, tf.float32)]),
+ shape=[1])
# The second generates (99, "b") 90 times and then stops.
num_b = 90
ninety_nine = tf.train.limit_epochs(
tf.constant(99, dtype=tf.int64), num_b)
+ sparse_ninety_nine = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(ninety_nine, tf.float32)]),
+ shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
- batched = tf.train.batch_join([[counter, "a"], [ninety_nine, "b"]],
- batch_size=batch_size)
+ batched = tf.train.batch_join(
+ [[counter, sparse_counter, "a"],
+ [ninety_nine, sparse_ninety_nine, "b"]],
+ batch_size=batch_size)
tf.initialize_all_variables().run()
threads = tf.train.start_queue_runners()
@@ -433,9 +479,14 @@ class BatchJoinTest(tf.test.TestCase):
results = sess.run(batched)
tf.logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
- self.assertEqual(len(results[1]), batch_size)
- which_a = [i for i, s in enumerate(results[1]) if s == b"a"]
- which_b = [i for i, s in enumerate(results[1]) if s == b"b"]
+ self.assertEqual(len(results[2]), batch_size)
+ self.assertAllEqual(results[0], results[1].values)
+ self.assertAllEqual(
+ results[1].indices,
+ np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+ self.assertAllEqual(results[1].shape, [batch_size, 1])
+ which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
+ which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
all_a.extend([results[0][i] for i in which_a])
@@ -481,8 +532,13 @@ class ShuffleBatchTest(tf.test.TestCase):
zero64 = tf.constant(0, dtype=tf.int64)
examples = tf.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(counter, tf.float32)]),
+ shape=[1])
batched = tf.train.shuffle_batch(
- [counter, "string"], batch_size=batch_size, capacity=32,
+ [counter, sparse_counter, "string"],
+ batch_size=batch_size, capacity=32,
min_after_dequeue=16, seed=141421)
tf.initialize_all_variables().run()
threads = tf.train.start_queue_runners()
@@ -492,7 +548,12 @@ class ShuffleBatchTest(tf.test.TestCase):
results = sess.run(batched)
self.assertEqual(len(results[0]), batch_size)
all_counts.extend(results[0])
- self.assertAllEqual(results[1], [b"string"] * batch_size)
+ self.assertAllEqual(
+ results[1].indices,
+ np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+ self.assertAllEqual(results[0], results[1].values)
+ self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[2], [b"string"] * batch_size)
# Results scrambled, but include all the expected numbers.
deltas = [all_counts[i + 1] - all_counts[i]
for i in range(len(all_counts) - 1)]
@@ -512,8 +573,13 @@ class ShuffleBatchTest(tf.test.TestCase):
zero64 = tf.constant(0, dtype=tf.int64)
examples = tf.Variable(zero64)
counter = examples.count_up_to(num_batches * batch_size)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(counter, tf.float32)]),
+ shape=[1])
batched = tf.train.shuffle_batch(
- [counter, "string"], batch_size=batch_size, capacity=32,
+ [counter, sparse_counter, "string"],
+ batch_size=batch_size, capacity=32,
min_after_dequeue=16, seed=173205, num_threads=4)
tf.initialize_all_variables().run()
threads = tf.train.start_queue_runners()
@@ -524,7 +590,12 @@ class ShuffleBatchTest(tf.test.TestCase):
tf.logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
all_counts.extend(results[0])
- self.assertAllEqual(results[1], [b"string"] * batch_size)
+ self.assertAllEqual(
+ results[1].indices,
+ np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+ self.assertAllEqual(results[0], results[1].values)
+ self.assertAllEqual(results[1].shape, [batch_size, 1])
+ self.assertAllEqual(results[2], [b"string"] * batch_size)
# Results scrambled, but include all the expected numbers.
deltas = [all_counts[i + 1] - all_counts[i]
for i in range(len(all_counts) - 1)]
@@ -564,17 +635,27 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
zero64 = tf.constant(0, dtype=tf.int64)
examples = tf.Variable(zero64)
counter = examples.count_up_to(num_a)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(counter, tf.float32)]),
+ shape=[1])
# The second generates (99, "b") 35 times and then stops.
num_b = 35
ninety_nine = tf.train.limit_epochs(
tf.constant(99, dtype=tf.int64), num_b)
+ sparse_ninety_nine = tf.SparseTensor(
+ indices=tf.reshape(zero64, [1, 1]),
+ values=tf.pack([tf.cast(ninety_nine, tf.float32)]),
+ shape=[1])
# These get joined together and grouped into batches of 5.
batch_size = 5
batched = tf.train.shuffle_batch_join(
- [[counter, "a"], [ninety_nine, "b"]], batch_size=batch_size,
- capacity=32, min_after_dequeue=16, seed=223607)
+ [[counter, sparse_counter, "a"],
+ [ninety_nine, sparse_ninety_nine, "b"]],
+ batch_size=batch_size, capacity=32,
+ min_after_dequeue=16, seed=223607)
tf.initialize_all_variables().run()
threads = tf.train.start_queue_runners()
@@ -588,9 +669,14 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
results = sess.run(batched)
tf.logging.info("Batch %d: %s", i, results[0])
self.assertEqual(len(results[0]), batch_size)
- self.assertEqual(len(results[1]), batch_size)
- which_a = [i for i, s in enumerate(results[1]) if s == b"a"]
- which_b = [i for i, s in enumerate(results[1]) if s == b"b"]
+ self.assertEqual(len(results[2]), batch_size)
+ self.assertAllEqual(results[0], results[1].values)
+ self.assertAllEqual(
+ results[1].indices,
+ np.vstack((np.arange(batch_size), np.zeros(batch_size))).T)
+ self.assertAllEqual(results[1].shape, [batch_size, 1])
+ which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
+ which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
all_a.extend([results[0][i] for i in which_a])