aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-05-25 15:19:11 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-05-25 15:22:39 -0700
commit: 8f0d0bdca81f9926bf6cf51eb7bf72e04fe43509 (patch)
tree: b289fee9e094fd098950a6e8d7b9b1ce3df18211
parent: 4659e562c95622c47668f31fc42b8f5f2cfb4596 (diff)
Simplify `dense_to_sparse_tensor` and `indicators_to_sparse_ids`, and fix them to work with inputs of undefined rank.
Add test for `indicators_to_sparse_ids` `dtype` arg. Small update to `unstack` pydoc. PiperOrigin-RevId: 157160634
-rw-r--r-- tensorflow/contrib/layers/python/ops/sparse_ops.py 125
-rw-r--r-- tensorflow/contrib/layers/python/ops/sparse_ops_test.py 47
-rw-r--r-- tensorflow/python/ops/array_ops.py 3
3 files changed, 107 insertions, 68 deletions
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops.py b/tensorflow/contrib/layers/python/ops/sparse_ops.py
index 114f312d27..7e79630c5e 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops.py
@@ -38,21 +38,25 @@ def _multiplier_helper(shape):
return multipliers
-def _ignore_value(dtype):
- if dtype == dtypes.string:
- # Exception due to TF strings are converted to numpy objects by default.
- return ""
- # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
- # constructing a new numpy object of the given type, which yields the
- # default value for that type.
- return dtype.as_numpy_dtype()
+def _ignore_value_tensor(dtype, ignore_value=None):
+ """Create `Tensor` from provided `ignore_value` and `dtype`."""
+ if ignore_value is None:
+ if dtype == dtypes.string:
+ # Exception due to TF strings are converted to numpy objects by default.
+ ignore_value = ""
+ else:
+ # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
+ # constructing a new numpy object of the given type, which yields the
+ # default value for that type.
+ ignore_value = dtype.as_numpy_dtype()
+ return math_ops.cast(ignore_value, dtype, name="ignore_value")
def dense_to_sparse_tensor(dense_tensor, ignore_value=None):
"""Converts dense `Tensor` to `SparseTensor`, dropping `ignore_value` cells.
Args:
- dense_tensor: A `Tensor`. This must have a statically defined rank.
+ dense_tensor: A `Tensor`.
ignore_value: Entries in `dense_tensor` equal to this value will be
absent from the return `SparseTensor`. If `None`, default value of
`dense_tensor` dtype will be used (e.g. '' for `str`, 0 for `int`).
@@ -64,33 +68,17 @@ def dense_to_sparse_tensor(dense_tensor, ignore_value=None):
ValueError: when `dense_tensor`'s rank is `None`.
"""
with ops.name_scope("DenseToSparseTensor"):
- dense_t = ops.convert_to_tensor(dense_tensor)
- if dense_t.get_shape().ndims is None:
- # TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank.
- raise ValueError("dense_tensor.get_shape() should be defined, got None.")
- if ignore_value is None:
- ignore_value = _ignore_value(dense_t.dtype)
- dense_shape = math_ops.cast(array_ops.shape(dense_t), dtypes.int64)
+ dense_tensor = ops.convert_to_tensor(dense_tensor)
+ ignore_value = _ignore_value_tensor(dense_tensor.dtype, ignore_value)
indices = array_ops.where(
- math_ops.not_equal(dense_t, math_ops.cast(ignore_value, dense_t.dtype)))
- index_dims = len(dense_t.get_shape())
- # Flattens the tensor and indices for use with gather.
- flat_tensor = array_ops.reshape(dense_t, [-1])
- flat_indices = indices[:, index_dims - 1]
- # Computes the correct flattened indices for 2d (or higher) tensors.
- if index_dims > 1:
- higher_dims = indices[:, :index_dims - 1]
- shape_multipliers = array_ops.stack(
- _multiplier_helper(array_ops.unstack(dense_shape)[1:]))
- offsets = math_ops.reduce_sum(
- math_ops.multiply(higher_dims, shape_multipliers),
- reduction_indices=[1])
- flat_indices = math_ops.add(flat_indices, offsets)
- values = array_ops.gather(flat_tensor, flat_indices)
- return sparse_tensor.SparseTensor(indices, values, dense_shape)
-
-
-# TODO(ptucker): Support integer dtype arg, and cast values back to that.
+ math_ops.not_equal(dense_tensor, ignore_value), name="indices")
+ return sparse_tensor.SparseTensor(
+ indices=indices,
+ values=array_ops.gather_nd(dense_tensor, indices, name="values"),
+ dense_shape=array_ops.shape(
+ dense_tensor, out_type=dtypes.int64, name="dense_shape"))
+
+
def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64):
"""Convert a dense indicator tensor to sparse IDs.
@@ -98,31 +86,54 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64):
In the following example, we have an input of shape (2, 2, num_classes),
where num_classes=4.
+ ```python
indicators = [
- [[0, 0, 1, 0], [0, 0, 0, 0]],
- [[1, 0, 1, 1], [0, 0, 1, 0]],
+ [
+ [0, 0, 1, 0],
+ [0, 0, 0, 0]
+ ], [
+ [1, 0, 1, 1],
+ [0, 0, 1, 0]
+ ]
]
- indicator_to_sparse_ids(indicators) => [
- [[2], []],
- [[0, 2, 3], [2]],
+ sparse_ids = indicators_to_sparse_ids(indicators)
+ ```
+
+ `sparse_ids` in "jagged" format:
+ [
+ [
+ [2],
+ []
+ ], [
+ [0, 2, 3],
+ [2]
+ ]
]
+ `sparse_ids` in `SparseTensor` format:
+ ```python
+ {
+ indices: [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 1, 0]],
+ values: [2, 0, 2, 3, 2],
+ dense_shape: [2, 2, 3]
+ }
+ ```
+
Args:
- indicators: Dense `Tensor` of shape `(d0, ..., dn, num_classes)`. This must
- have a statically defined rank. `ignore_value` values are ignored. For
- other values (typically, ones), the index along the last dimension is
- returned.
+ indicators: Dense `Tensor` of shape `(d0, ..., dn, num_classes)`.
+ `ignore_value` values are ignored. For other values (typically, ones), the
+ index along the last dimension is returned.
ignore_value: Entries in `indicators` equal to this value will be
absent from the returned `SparseTensor`. If `None`, default value of
`indicators` dtype will be used (e.g. '' for `str`, 0 for `int`).
dtype: Type of result, must be integer type.
Returns:
- `tf.int64` `SparseTensor` of shape `(d0, ..., dn, max_num_labels)`,
+ `SparseTensor` of type `dtype` and shape `(d0, ..., dn, max_num_labels)`,
where `max_num_labels` is the maximum number of non-zero values in any
row (in the example above, row (1, 1) has 3 non-zero values, so the result
shape is (2, 2, 3)). The values of this `SparseTensor` are in the range
- `[0, num_classes)` and correspond to the index of non-empty values along
+ `[0, num_classes)` and correspond to the index of non-ignore values along
the last dimension of `indicators`.
Raises:
@@ -135,10 +146,9 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64):
# Convert indicators to binary ones and zeros. We use int64 since
# SparseTensor requires int64 indices.
indicators = ops.convert_to_tensor(indicators, name="indicators")
- if ignore_value is None:
- ignore_value = _ignore_value(indicators.dtype)
missing_indicators = math_ops.equal(
- indicators, ignore_value, name="missing")
+ indicators, _ignore_value_tensor(indicators.dtype, ignore_value),
+ name="missing")
zeros_like_indicators = array_ops.zeros_like(
indicators, dtype=dtypes.int64, name="zeros")
binary_indicators = array_ops.where(
@@ -149,7 +159,7 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64):
# Use cumsum along the last dimension to generate per-row indexes.
# Note that these are 1-based (since 0 indicates missing values), so they're
# off-by-1 from the actual indices. We'll subtract 1 below. Since they're
- # off-by-one, the max value is the size of last dimension (i.e.,
+ # off-by-one, the max value is the size of the last dimension (i.e.,
# last_index + 1).
row_index_indicators = array_ops.where(
missing_indicators, zeros_like_indicators,
@@ -161,16 +171,17 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64):
# Convert to a SparseTensor. The values of this SparseTensor are the last
# indices of our result, and the last indices of this SparseTensor (i.e.,
# the class IDs indicated by `indicators`) are the values of our result, so
- # we use unstack/stack to swap them.
+ # we use tensor slicing and concat to swap them.
sparse_row_index_indicators = dense_to_sparse_tensor(
row_index_indicators, ignore_value=0)
- index_columns = array_ops.unstack(
- sparse_row_index_indicators.indices, axis=1)
return sparse_tensor.SparseTensor(
- indices=array_ops.stack(
- index_columns[0:-1] + [sparse_row_index_indicators.values - 1],
- axis=1, name="indices"),
- values=math_ops.cast(index_columns[-1], dtype=dtype, name="values"),
+ indices=array_ops.concat((
+ sparse_row_index_indicators.indices[:, :-1],
+ array_ops.reshape(sparse_row_index_indicators.values - 1, (-1, 1))
+ ), axis=1, name="indices"),
+ values=math_ops.cast(
+ sparse_row_index_indicators.indices[:, -1], dtype=dtype,
+ name="values"),
dense_shape=array_ops.concat(
(sparse_row_index_indicators.dense_shape[0:-1], result_last_dim),
axis=0, name="dense_shape"))
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
index 9a9582dcad..d50750001e 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
@@ -116,7 +116,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([1, 2, 3, 4, 5, 7, 8, 9], result.values)
self.assertAllEqual([2, 2, 4], result.dense_shape)
- def test_dense_to_sparse_tensor_1d_no_shape(self):
+ def test_dense_to_sparse_tensor_unknown_1d_shape(self):
with self.test_session() as sess:
tensor = array_ops.placeholder(shape=[None], dtype=dtypes.int32)
st = sparse_ops.dense_to_sparse_tensor(tensor)
@@ -125,7 +125,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([100, 3], result.values)
self.assertAllEqual([4], result.dense_shape)
- def test_dense_to_sparse_tensor_3d_no_shape(self):
+ def test_dense_to_sparse_tensor_unknown_3d_shape(self):
with self.test_session() as sess:
tensor = array_ops.placeholder(
shape=[None, None, None], dtype=dtypes.int32)
@@ -140,11 +140,15 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([1, 2, 3, 4, 5, 7, 8, 9], result.values)
self.assertAllEqual([2, 2, 4], result.dense_shape)
- def test_convert_to_sparse_undef_shape(self):
- with self.test_session():
- with self.assertRaises(ValueError):
- tensor = array_ops.placeholder(dtype=dtypes.int32)
- sparse_ops.dense_to_sparse_tensor(tensor)
+ def test_dense_to_sparse_unknown_rank(self):
+ ph = array_ops.placeholder(dtype=dtypes.int32)
+ with self.test_session() as sess:
+ st = sparse_ops.dense_to_sparse_tensor(ph)
+ result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]})
+ self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
+ result.indices)
+ self.assertAllEqual([1, 2, 3, 4, 5], result.values)
+ self.assertAllEqual([2, 4], result.dense_shape)
class SparseRowEnvelopeTest(test.TestCase):
@@ -244,6 +248,20 @@ class IndicatorToSparseIdsTest(test.TestCase):
), dense_shape=(4, 2, 3),
), sparse_ids.eval())
+ def test_int16_to_sparse_ids_2d(self):
+ indicators = (
+ (0, 0, 1, 0),
+ (1, 0, 0, 1),
+ )
+ sparse_ids = sparse_ops.indicators_to_sparse_ids(
+ indicators, dtype=dtypes.int16)
+ with self.test_session():
+ _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((2, 0, 3), dtype=np.int16),
+ dense_shape=(2, 2),
+ ), sparse_ids.eval())
+
def test_indicators_to_sparse_ids_ignore_value(self):
indicators = (
((-1, -1, 10, -1), (-1, -1, -1, -1)),
@@ -285,7 +303,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
dense_shape=(2, 2, 2),
), sparse_ids.eval())
- def test_indicators_to_sparse_ids_unknown_dims(self):
+ def test_indicators_to_sparse_ids_unknown_3d_shape(self):
indicators_values = (
((0, 0, 1, 0), (0, 0, 0, 0)),
((1, 0, 0, 1), (0, 0, 1, 0)),
@@ -301,9 +319,18 @@ class IndicatorToSparseIdsTest(test.TestCase):
), sparse_ids.eval(feed_dict={indicators: indicators_values}))
def test_indicators_to_sparse_ids_unknown_rank(self):
+ indicators_values = (
+ ((0, 0, 1, 0), (0, 0, 0, 0)),
+ ((1, 0, 0, 1), (0, 0, 1, 0)),
+ )
indicators = array_ops.placeholder(dtype=dtypes.int32)
- with self.assertRaisesRegexp(ValueError, r'shape.*should be defined'):
- sparse_ops.indicators_to_sparse_ids(indicators)
+ sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
+ with self.test_session():
+ _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 3, 2),
+ dense_shape=(2, 2, 2),
+ ), sparse_ids.eval(feed_dict={indicators: indicators_values}))
if __name__ == '__main__':
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index f0dba04e44..62b29ce306 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -783,6 +783,7 @@ def parallel_stack(values, name="parallel_stack"):
return gen_array_ops._parallel_concat(
[expand_dims(value, 0) for value in values], shape=output_shape)
+
def stack(values, axis=0, name="stack"):
"""Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.
@@ -944,7 +945,7 @@ def unstack(value, num=None, axis=0, name="unstack"):
`value[:, i, :, :]` and each tensor in `output` will have shape `(A, C, D)`.
Etc.
- This is the opposite of pack. The numpy equivalent is
+ This is the opposite of stack. The numpy equivalent is
tf.unstack(x, n) = list(x)