diff options
-rw-r--r-- | tensorflow/contrib/metrics/python/kernel_tests/set_ops_test.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/metrics/python/ops/set_ops.py | 6 |
2 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/metrics/python/kernel_tests/set_ops_test.py b/tensorflow/contrib/metrics/python/kernel_tests/set_ops_test.py index 4ba55a99eb..0feb1103be 100644 --- a/tensorflow/contrib/metrics/python/kernel_tests/set_ops_test.py +++ b/tensorflow/contrib/metrics/python/kernel_tests/set_ops_test.py @@ -153,6 +153,7 @@ class SetOpsTest(test_util.TensorFlowTestCase): ] for op in ops: self.assertEqual(None, op.get_shape().dims) + self.assertEqual(tf.int32, op.dtype) with self.test_session() as sess: results = sess.run(ops) self.assertAllEqual(results[0], results[1]) diff --git a/tensorflow/contrib/metrics/python/ops/set_ops.py b/tensorflow/contrib/metrics/python/ops/set_ops.py index e78e60ca19..01b624593b 100644 --- a/tensorflow/contrib/metrics/python/ops/set_ops.py +++ b/tensorflow/contrib/metrics/python/ops/set_ops.py @@ -49,9 +49,9 @@ def set_size(a, validate_indices=True): in `a`. Returns: - For `a` ranked `n`, this is a `Tensor` with rank `n-1`, and the same 1st - `n-1` dimensions as `a`. Each value is the number of unique elements in - the corresponding `[0...n-1]` dimension of `a`. + `int32` `Tensor` of set sizes. For `a` ranked `n`, this is a `Tensor` with + rank `n-1`, and the same 1st `n-1` dimensions as `a`. Each value is the + number of unique elements in the corresponding `[0...n-1]` dimension of `a`. Raises: TypeError: If `a` is an invalid types. |