diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/gather_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/gather_op_test.py | 103 |
1 files changed, 61 insertions, 42 deletions
diff --git a/tensorflow/python/kernel_tests/gather_op_test.py b/tensorflow/python/kernel_tests/gather_op_test.py index dac8d58b35..1f161e59cd 100644 --- a/tensorflow/python/kernel_tests/gather_op_test.py +++ b/tensorflow/python/kernel_tests/gather_op_test.py @@ -31,61 +31,80 @@ from tensorflow.python.platform import test class GatherTest(test.TestCase): use_gpu = False + def _buildParams(self, data, dtype): + data = data.astype(dtype.as_numpy_dtype) + # For complex types, add an index-dependent imaginary component so we can + # tell we got the right value. + if dtype.is_complex: + return data + 10j * data + return data + def testScalar1D(self): with self.test_session(use_gpu=self.use_gpu): - params = constant_op.constant([0, 1, 2, 3, 7, 5]) - indices = constant_op.constant(4) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual(7, gather_val) - self.assertEqual([], gather_t.get_shape()) + data = np.array([0, 1, 2, 3, 7, 5]) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices = constant_op.constant(4) + gather_t = array_ops.gather(params, indices) + gather_val = gather_t.eval() + self.assertAllEqual(params_np[4], gather_val) + self.assertEqual([], gather_t.get_shape()) def testScalar2D(self): with self.test_session(use_gpu=self.use_gpu): - params = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], - [9, 10, 11], [12, 13, 14]]) - indices = constant_op.constant(2) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual([6, 7, 8], gather_val) - self.assertEqual([3], gather_t.get_shape()) + data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], + [9, 10, 11], [12, 13, 14]]) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices = constant_op.constant(2) + gather_t = array_ops.gather(params, indices) + gather_val = gather_t.eval() + self.assertAllEqual(params_np[2], gather_val) + self.assertEqual([3], gather_t.get_shape()) def testSimpleTwoD32(self): with self.test_session(use_gpu=self.use_gpu): - params = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], - [9, 10, 11], [12, 13, 14]]) - indices = constant_op.constant([0, 4, 0, 2]) - gather_t = array_ops.gather(params, indices) - gather_val = gather_t.eval() - self.assertAllEqual([[0, 1, 2], [12, 13, 14], [0, 1, 2], [6, 7, 8]], - gather_val) - self.assertEqual([4, 3], gather_t.get_shape()) + data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], + [9, 10, 11], [12, 13, 14]]) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params_np = self._buildParams(data, dtype) + params = constant_op.constant(params_np) + indices = constant_op.constant([0, 4, 0, 2]) + gather_t = array_ops.gather(params, indices) + gather_val = gather_t.eval() + self.assertAllEqual(params_np[[0, 4, 0, 2]], gather_val) + self.assertEqual([4, 3], gather_t.get_shape()) def testHigherRank(self): np.random.seed(1) # We check that scalar and empty shapes work as well for shape in (7, 0), (4, 3, 2): for indices_shape in (), (0,), (3, 0), (3, 5): - params = np.random.randn(*shape) - indices = np.random.randint(shape[0], size=indices_shape) - with self.test_session(use_gpu=self.use_gpu): - tf_params = constant_op.constant(params) - tf_indices = constant_op.constant(indices) - gather = array_ops.gather(tf_params, tf_indices) - self.assertAllEqual(params[indices], gather.eval()) - self.assertEqual(indices.shape + params.shape[1:], gather.get_shape()) - # Test gradients - gather_grad = np.random.randn(*gather.get_shape().as_list()) - params_grad, indices_grad = gradients_impl.gradients( - gather, [tf_params, tf_indices], gather_grad) - self.assertEqual(indices_grad, None) - self.assertEqual(type(params_grad), ops.IndexedSlices) - params_grad = ops.convert_to_tensor(params_grad) - correct_params_grad = np.zeros(shape) - for i, g in zip(indices.flat, - gather_grad.reshape((indices.size,) + shape[1:])): - correct_params_grad[i] += g - self.assertAllClose(correct_params_grad, params_grad.eval()) + for dtype in (dtypes.float32, dtypes.complex64, dtypes.complex128): + params = self._buildParams(np.random.randn(*shape), dtype) + indices = np.random.randint(shape[0], size=indices_shape) + with self.test_session(use_gpu=self.use_gpu): + tf_params = constant_op.constant(params) + tf_indices = constant_op.constant(indices) + gather = array_ops.gather(tf_params, tf_indices) + self.assertAllEqual(params[indices], gather.eval()) + self.assertEqual(indices.shape + params.shape[1:], + gather.get_shape()) + # Test gradients + gather_grad = np.random.randn(*gather.get_shape().as_list()).astype( + dtype.as_numpy_dtype) + params_grad, indices_grad = gradients_impl.gradients( + gather, [tf_params, tf_indices], gather_grad) + self.assertEqual(indices_grad, None) + self.assertEqual(type(params_grad), ops.IndexedSlices) + params_grad = ops.convert_to_tensor(params_grad) + correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype) + for i, g in zip(indices.flat, + gather_grad.reshape((indices.size,) + shape[1:])): + correct_params_grad[i] += g + self.assertAllClose(correct_params_grad, params_grad.eval()) def testUnknownIndices(self): params = constant_op.constant([[0, 1, 2]]) @@ -103,7 +122,7 @@ class GatherTest(test.TestCase): def testEmptySlices(self): with self.test_session(use_gpu=self.use_gpu): - for dtype in np.float32, np.float64: + for dtype in np.float32, np.float64, np.complex64, np.complex128: for itype in np.int32, np.int64: params = np.zeros((7, 0), dtype=dtype) indices = np.array([3, 4], dtype=itype) |