aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-02-23 07:52:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-23 09:57:59 -0800
commit3972b66fa5e3bd89e50fc01b43c16b81b52e9319 (patch)
tree46d0590b200086741c67dc9782e67c929a21032b
parent93866d1407f31011d307bf3868ad72516419fea0 (diff)
Add license header and future imports to sparse_tensor_dense_matmul_grad_test.
This should fix Python 3 compatibility for this test. Change: 115339521
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py79
-rw-r--r--tensorflow/python/ops/sparse_grad.py37
-rw-r--r--tensorflow/python/ops/standard_ops.py1
3 files changed, 116 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
new file mode 100644
index 0000000000..19aa2dcc9b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
@@ -0,0 +1,79 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for the gradient of `tf.sparse_tensor_dense_matmul()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+
+class SparseTensorDenseMatMulGradientTest(tf.test.TestCase):
+
+ def _sparsify(self, x):
+ x[x < 0.5] = 0
+
+ non_zero = np.where(x)
+ x_indices = np.vstack(non_zero).astype(np.int64).T
+ x_values = x[non_zero]
+ x_shape = x.shape
+
+ return tf.SparseTensor(indices=x_indices, values=x_values, shape=x_shape)
+
+ def _randomTensor(self, size, np_dtype, adjoint=False, sparse=False):
+ n, m = size
+ x = np.random.randn(n, m).astype(np_dtype)
+
+ if adjoint:
+ x = x.transpose()
+
+ if sparse:
+ return self._sparsify(x)
+ else:
+ return tf.constant(x, dtype=np_dtype)
+
+ def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype, use_gpu=False):
+ n, k, m = np.random.randint(1, 10, size=3)
+ sp_t = self._randomTensor([n, k], np_dtype, adjoint=adjoint_a, sparse=True)
+ dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b)
+
+ matmul = tf.sparse_tensor_dense_matmul(
+ sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)
+
+ with self.test_session(use_gpu=use_gpu):
+ dense_t_shape = [m, k] if adjoint_b else [k, m]
+ err = tf.test.compute_gradient_error(dense_t, dense_t_shape, matmul,
+ [n, m])
+ print("%s gradient err = %s" % (name, err))
+ self.assertLess(err, 1e-3)
+
+ def _testGradientsType(self, np_dtype, use_gpu=False):
+ for adjoint_a in [True, False]:
+ for adjoint_b in [True, False]:
+ name = "sparse_tensor_dense_matmul_%s_%s_%s" % (adjoint_a, adjoint_b,
+ np_dtype.__name__)
+ self._testGradients(adjoint_a, adjoint_b, name, np_dtype, use_gpu)
+
+ def testGradients(self):
+ np.random.seed(5) # Fix seed to avoid flakiness
+ for use_gpu in [True, False]:
+ self._testGradientsType(np.float32, use_gpu)
+ self._testGradientsType(np.float64, use_gpu)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index ccf54d6b56..cecdde3fdd 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
@@ -31,4 +32,38 @@ ops.NoGradient("SparseConcat")
ops.NoGradient("SparseReorder")
-ops.NoGradient("SparseTensorDenseMatMul")
+@ops.RegisterGradient("SparseTensorDenseMatMul")
+def _SparseTensorDenseMatMulGrad(op, grad):
+ """Gradients for the dense tensor in the SparseTensorDenseMatMul op.
+
+ Gradients are only provided for the dense tensor.
+
+ If either input is complex, no gradient is provided.
+
+ Args:
+ op: the SparseTensorDenseMatMul op
+ grad: the incoming gradient
+
+ Returns:
+ Gradient for each of the 4 input tensors:
+ (sparse_indices, sparse_values, sparse_shape, dense_tensor)
+ The sparse tensor gradients are always None.
+ """
+ sp_t = ops.SparseTensor(*op.inputs[:3])
+ adj_a = op.get_attr("adjoint_a")
+ adj_b = op.get_attr("adjoint_b")
+
+ a_type = sp_t.values.dtype
+ b_type = op.inputs[3].dtype
+ assert a_type == b_type
+ is_complex = a_type == ops.dtypes.complex64
+ if is_complex:
+ raise NotImplementedError("SparseTensorDenseMatMul op does not support "
+ "complex gradients.")
+
+ b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
+ adjoint_a=not adj_a)
+ if adj_b:
+ b_grad = array_ops.transpose(b_grad)
+
+ return (None, None, None, b_grad)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 9abd59441f..57ff78cdde 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -24,6 +24,7 @@ from __future__ import print_function
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import data_flow_grad
from tensorflow.python.ops import math_grad
+from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import state_grad
from tensorflow.python.ops import tensor_array_grad