aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/sparse_slice_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/sparse_slice_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/sparse_slice_op_test.py22
1 files changed, 20 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_slice_op_test.py b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
index da116601f8..97f30daf4a 100644
--- a/tensorflow/python/kernel_tests/sparse_slice_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_slice_op_test.py
@@ -21,13 +21,15 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import sparse_ops
+import tensorflow.python.ops.sparse_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
class SparseSliceOpTest(test.TestCase):
- def _SparseTensor_4x6(self):
+ def _SparseTensor_4x6(self, val_dtype=np.int64):
# [0 | |2 | |4 |5 ]
# [ |11| |13|14| ]
# [20| | |23| |25]
@@ -37,7 +39,7 @@ class SparseSliceOpTest(test.TestCase):
[2, 3], [2, 5], [3, 0], [3, 2], [3, 3], [3, 5]]).astype(
np.int64)
val = np.array([0, 2, 4, 5, 11, 13, 14, 20, 23, 25, 30, 32, 33, 35]).astype(
- np.int64)
+ val_dtype)
shape = np.array([4, 6]).astype(np.int64)
return sparse_tensor.SparseTensor(ind, val, shape)
@@ -244,6 +246,22 @@ class SparseSliceOpTest(test.TestCase):
self.assertAllEqual(sparse_tensor5.values.eval(), [5, 25, 35])
self.assertAllEqual(sparse_tensor5.dense_shape.eval(), [4, 1])
+ def testGradients(self):
+ sp_input = self._SparseTensor_4x6(val_dtype=np.float32)
+ start_and_size = [([0, 0], [4, 2]),
+ ([0, 2], [5, 2]),
+ ([0, 4], [5, 3])]
+
+ with self.test_session(use_gpu=False):
+ for start, size in start_and_size:
+ sp_output = sparse_ops.sparse_slice(sp_input, start, size)
+ nnz_in = len(sp_input.values.eval())
+ nnz_out = len(sp_output.values.eval())
+
+ err = gradient_checker.compute_gradient_error(
+ [sp_input.values], [(nnz_in,)], sp_output.values, (nnz_out,))
+ self.assertLess(err, 1e-3)
+
if __name__ == '__main__':
test.main()