From 8f0d0bdca81f9926bf6cf51eb7bf72e04fe43509 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 25 May 2017 15:19:11 -0700 Subject: Simplify `dense_to_sparse_tensor` and `indicators_to_sparse_ids`, and fix them to work with inputs of undefined rank. Add test for `indicators_to_sparse_ids` `dtype` arg. Small update to `unstack` pydoc. PiperOrigin-RevId: 157160634 --- tensorflow/contrib/layers/python/ops/sparse_ops.py | 125 +++++++++++---------- .../contrib/layers/python/ops/sparse_ops_test.py | 47 ++++++-- tensorflow/python/ops/array_ops.py | 3 +- 3 files changed, 107 insertions(+), 68 deletions(-) diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops.py b/tensorflow/contrib/layers/python/ops/sparse_ops.py index 114f312d27..7e79630c5e 100644 --- a/tensorflow/contrib/layers/python/ops/sparse_ops.py +++ b/tensorflow/contrib/layers/python/ops/sparse_ops.py @@ -38,21 +38,25 @@ def _multiplier_helper(shape): return multipliers -def _ignore_value(dtype): - if dtype == dtypes.string: - # Exception due to TF strings are converted to numpy objects by default. - return "" - # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is - # constructing a new numpy object of the given type, which yields the - # default value for that type. - return dtype.as_numpy_dtype() +def _ignore_value_tensor(dtype, ignore_value=None): + """Create `Tensor` from provided `ignore_value` and `dtype`.""" + if ignore_value is None: + if dtype == dtypes.string: + # Exception due to TF strings are converted to numpy objects by default. + ignore_value = "" + else: + # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is + # constructing a new numpy object of the given type, which yields the + # default value for that type. + ignore_value = dtype.as_numpy_dtype() + return math_ops.cast(ignore_value, dtype, name="ignore_value") def dense_to_sparse_tensor(dense_tensor, ignore_value=None): """Converts dense `Tensor` to `SparseTensor`, dropping `ignore_value` cells. Args: - dense_tensor: A `Tensor`. This must have a statically defined rank. + dense_tensor: A `Tensor`. ignore_value: Entries in `dense_tensor` equal to this value will be absent from the return `SparseTensor`. If `None`, default value of `dense_tensor` dtype will be used (e.g. '' for `str`, 0 for `int`). @@ -64,33 +68,17 @@ def dense_to_sparse_tensor(dense_tensor, ignore_value=None): ValueError: when `dense_tensor`'s rank is `None`. """ with ops.name_scope("DenseToSparseTensor"): - dense_t = ops.convert_to_tensor(dense_tensor) - if dense_t.get_shape().ndims is None: - # TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank. - raise ValueError("dense_tensor.get_shape() should be defined, got None.") - if ignore_value is None: - ignore_value = _ignore_value(dense_t.dtype) - dense_shape = math_ops.cast(array_ops.shape(dense_t), dtypes.int64) + dense_tensor = ops.convert_to_tensor(dense_tensor) + ignore_value = _ignore_value_tensor(dense_tensor.dtype, ignore_value) indices = array_ops.where( - math_ops.not_equal(dense_t, math_ops.cast(ignore_value, dense_t.dtype))) - index_dims = len(dense_t.get_shape()) - # Flattens the tensor and indices for use with gather. - flat_tensor = array_ops.reshape(dense_t, [-1]) - flat_indices = indices[:, index_dims - 1] - # Computes the correct flattened indices for 2d (or higher) tensors. 
- if index_dims > 1: - higher_dims = indices[:, :index_dims - 1] - shape_multipliers = array_ops.stack( - _multiplier_helper(array_ops.unstack(dense_shape)[1:])) - offsets = math_ops.reduce_sum( - math_ops.multiply(higher_dims, shape_multipliers), - reduction_indices=[1]) - flat_indices = math_ops.add(flat_indices, offsets) - values = array_ops.gather(flat_tensor, flat_indices) - return sparse_tensor.SparseTensor(indices, values, dense_shape) - - -# TODO(ptucker): Support integer dtype arg, and cast values back to that. + math_ops.not_equal(dense_tensor, ignore_value), name="indices") + return sparse_tensor.SparseTensor( + indices=indices, + values=array_ops.gather_nd(dense_tensor, indices, name="values"), + dense_shape=array_ops.shape( + dense_tensor, out_type=dtypes.int64, name="dense_shape")) + + def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64): """Convert a dense indicator tensor to sparse IDs. @@ -98,31 +86,54 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64): In the following example, we have an input of shape (2, 2, num_classes), where num_classes=4. + ```python indicators = [ - [[0, 0, 1, 0], [0, 0, 0, 0]], - [[1, 0, 1, 1], [0, 0, 1, 0]], + [ + [0, 0, 1, 0], + [0, 0, 0, 0] + ], [ + [1, 0, 1, 1], + [0, 0, 1, 0] + ] ] - indicator_to_sparse_ids(indicators) => [ - [[2], []], - [[0, 2, 3], [2]], + sparse_ids = indicator_to_sparse_ids(indicators) + ``` + + `sparse_ids` in "jagged" format: + [ + [ + [2], + [] + ], [ + [0, 2, 3], + [2] + ] ] + `sparse_ids` in `SparseTensor` format: + ```python + { + indices: [[0, 0, 1], [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 1, 0]], + values: [2, 0, 2, 3, 2], + dense_shape: [2, 2, 3] + } + ``` + Args: - indicators: Dense `Tensor` of shape `(d0, ..., dn, num_classes)`. This must - have a statically defined rank. `ignore_value` values are ignored. For - other values (typically, ones), the index along the last dimension is - returned. + indicators: Dense `Tensor` of shape `(d0, ..., dn, num_classes)`. + `ignore_value` values are ignored. For other values (typically, ones), the + index along the last dimension is returned. ignore_value: Entries in `indicators` equal to this value will be absent from the returned `SparseTensor`. If `None`, default value of `indicators` dtype will be used (e.g. '' for `str`, 0 for `int`). dtype: Type of result, must be integer type. Returns: - `tf.int64` `SparseTensor` of shape `(d0, ..., dn, max_num_labels)`, + `SparseTensor` of type `dtype` and shape `(d0, ..., dn, max_num_labels)`, where `max_num_labels` is the maximum number of non-zero values in any row (in the example above, row (1, 1) has 3 non-zero values, so the result shape is (2, 2, 3)). The values of this `SparseTensor` are in the range - `[0, num_classes)` and correspond to the index of non-empty values along + `[0, num_classes)` and correspond to the index of non-ignore values along the last dimension of `indicators`. Raises: @@ -135,10 +146,9 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64): # Convert indicators to binary ones and zeros. We use int64 since # SparseTensor requires int64 indices. 
indicators = ops.convert_to_tensor(indicators, name="indicators") - if ignore_value is None: - ignore_value = _ignore_value(indicators.dtype) missing_indicators = math_ops.equal( - indicators, ignore_value, name="missing") + indicators, _ignore_value_tensor(indicators.dtype, ignore_value), + name="missing") zeros_like_indicators = array_ops.zeros_like( indicators, dtype=dtypes.int64, name="zeros") binary_indicators = array_ops.where( @@ -149,7 +159,7 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64): # Use cumsum along the last dimension to generate per-row indexes. # Note that these are 1-based (since 0 indicates missing values), so they're # off-by-1 from the actual indices. We'll subtract 1 below. Since they're - # off-by-one, the max value is the size of last dimension (i.e., + # off-by-one, the max value is the size of the last dimension (i.e., # last_index + 1). row_index_indicators = array_ops.where( missing_indicators, zeros_like_indicators, @@ -161,16 +171,17 @@ def indicators_to_sparse_ids(indicators, ignore_value=None, dtype=dtypes.int64): # Convert to a SparseTensor. The values of this SparseTensor are the last # indices of our result, and the last indices of this SparseTensor (i.e., # the class IDs indicated by `indicators`) are the values of our result, so - # we use unstack/stack to swap them. + # we use tensor slicing and concat to swap them. sparse_row_index_indicators = dense_to_sparse_tensor( row_index_indicators, ignore_value=0) - index_columns = array_ops.unstack( - sparse_row_index_indicators.indices, axis=1) return sparse_tensor.SparseTensor( - indices=array_ops.stack( - index_columns[0:-1] + [sparse_row_index_indicators.values - 1], - axis=1, name="indices"), - values=math_ops.cast(index_columns[-1], dtype=dtype, name="values"), + indices=array_ops.concat(( + sparse_row_index_indicators.indices[:, :-1], + array_ops.reshape(sparse_row_index_indicators.values - 1, (-1, 1)) + ), axis=1, name="indices"), + values=math_ops.cast( + sparse_row_index_indicators.indices[:, -1], dtype=dtype, + name="values"), dense_shape=array_ops.concat( (sparse_row_index_indicators.dense_shape[0:-1], result_last_dim), axis=0, name="dense_shape")) diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py index 9a9582dcad..d50750001e 100644 --- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py +++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py @@ -116,7 +116,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([1, 2, 3, 4, 5, 7, 8, 9], result.values) self.assertAllEqual([2, 2, 4], result.dense_shape) - def test_dense_to_sparse_tensor_1d_no_shape(self): + def test_dense_to_sparse_tensor_unknown_1d_shape(self): with self.test_session() as sess: tensor = array_ops.placeholder(shape=[None], dtype=dtypes.int32) st = sparse_ops.dense_to_sparse_tensor(tensor) @@ -125,7 +125,7 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([100, 3], result.values) self.assertAllEqual([4], result.dense_shape) - def test_dense_to_sparse_tensor_3d_no_shape(self): + def test_dense_to_sparse_tensor_unknown_3d_shape(self): with self.test_session() as sess: tensor = array_ops.placeholder( shape=[None, None, None], dtype=dtypes.int32) @@ -140,11 +140,15 @@ class DenseToSparseTensorTest(test.TestCase): self.assertAllEqual([1, 2, 3, 4, 5, 7, 8, 9], result.values) self.assertAllEqual([2, 2, 4], result.dense_shape) - def test_convert_to_sparse_undef_shape(self): - with 
self.test_session(): - with self.assertRaises(ValueError): - tensor = array_ops.placeholder(dtype=dtypes.int32) - sparse_ops.dense_to_sparse_tensor(tensor) + def test_dense_to_sparse_unknown_rank(self): + ph = array_ops.placeholder(dtype=dtypes.int32) + with self.test_session() as sess: + st = sparse_ops.dense_to_sparse_tensor(ph) + result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]}) + self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], + result.indices) + self.assertAllEqual([1, 2, 3, 4, 5], result.values) + self.assertAllEqual([2, 4], result.dense_shape) class SparseRowEnvelopeTest(test.TestCase): @@ -244,6 +248,20 @@ class IndicatorToSparseIdsTest(test.TestCase): ), dense_shape=(4, 2, 3), ), sparse_ids.eval()) + def test_int16_to_sparse_ids_2d(self): + indicators = ( + (0, 0, 1, 0), + (1, 0, 0, 1), + ) + sparse_ids = sparse_ops.indicators_to_sparse_ids( + indicators, dtype=dtypes.int16) + with self.test_session(): + _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((2, 0, 3), dtype=np.int16), + dense_shape=(2, 2), + ), sparse_ids.eval()) + def test_indicators_to_sparse_ids_ignore_value(self): indicators = ( ((-1, -1, 10, -1), (-1, -1, -1, -1)), @@ -285,7 +303,7 @@ class IndicatorToSparseIdsTest(test.TestCase): dense_shape=(2, 2, 2), ), sparse_ids.eval()) - def test_indicators_to_sparse_ids_unknown_dims(self): + def test_indicators_to_sparse_ids_unknown_3d_shape(self): indicators_values = ( ((0, 0, 1, 0), (0, 0, 0, 0)), ((1, 0, 0, 1), (0, 0, 1, 0)), @@ -301,9 +319,18 @@ class IndicatorToSparseIdsTest(test.TestCase): ), sparse_ids.eval(feed_dict={indicators: indicators_values})) def test_indicators_to_sparse_ids_unknown_rank(self): + indicators_values = ( + ((0, 0, 1, 0), (0, 0, 0, 0)), + ((1, 0, 0, 1), (0, 0, 1, 0)), + ) indicators = array_ops.placeholder(dtype=dtypes.int32) - with self.assertRaisesRegexp(ValueError, r'shape.*should be defined'): - sparse_ops.indicators_to_sparse_ids(indicators) + sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators) + with self.test_session(): + _assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)), + values=(2, 0, 3, 2), + dense_shape=(2, 2, 2), + ), sparse_ids.eval(feed_dict={indicators: indicators_values})) if __name__ == '__main__': diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index f0dba04e44..62b29ce306 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -783,6 +783,7 @@ def parallel_stack(values, name="parallel_stack"): return gen_array_ops._parallel_concat( [expand_dims(value, 0) for value in values], shape=output_shape) + def stack(values, axis=0, name="stack"): """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. @@ -944,7 +945,7 @@ def unstack(value, num=None, axis=0, name="unstack"): `value[:, i, :, :]` and each tensor in `output` will have shape `(A, C, D)`. Etc. - This is the opposite of pack. The numpy equivalent is + This is the opposite of stack. The numpy equivalent is tf.unstack(x, n) = list(x) -- cgit v1.2.3
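
Note (not part of the patch): the core of the simplification is that `tf.where` on a
boolean tensor returns the full `[n, rank]` coordinate matrix of the matching cells, and
`tf.gather_nd` looks values up by those full coordinates, so neither op needs the input's
rank to be statically known when the graph is built. Below is a minimal standalone sketch
of that idea, assuming the TF 1.x graph/session API used elsewhere in this patch; the
`dense_to_sparse_sketch` helper name is ours, for illustration only.

    import tensorflow as tf

    def dense_to_sparse_sketch(dense_tensor, ignore_value=0):
      # Coordinates of every cell that differs from ignore_value; rank-agnostic.
      indices = tf.where(tf.not_equal(dense_tensor, ignore_value))
      # Look the kept values up by their full coordinates.
      values = tf.gather_nd(dense_tensor, indices)
      # Dynamic shape as int64, as required by SparseTensor.
      dense_shape = tf.shape(dense_tensor, out_type=tf.int64)
      return tf.SparseTensor(indices, values, dense_shape)

    # Mirrors the new test_dense_to_sparse_unknown_rank test: no shape is given
    # for the placeholder, so its rank is undefined at graph-construction time.
    ph = tf.placeholder(dtype=tf.int32)
    st = dense_to_sparse_sketch(ph)
    with tf.Session() as sess:
      result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]})
      # result.indices     == [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]]
      # result.values      == [1, 2, 3, 4, 5]
      # result.dense_shape == [2, 4]

The same rank-agnostic idea replaces the `unstack`/`stack` pair in
`indicators_to_sparse_ids`: slicing `sparse_row_index_indicators.indices[:, :-1]` and
`sparse_row_index_indicators.indices[:, -1]` splits off the leading and trailing index
columns without needing to know how many leading columns there are.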