author     A. Unique TensorFlower <gardener@tensorflow.org>   2016-09-20 15:01:43 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2016-09-20 16:18:07 -0700
commit a114a85e1254d53e73eae6b6412599b3c1b632d8 (patch)
tree   e700247f4713db8f211299727dbcaa9cd83c07b1
parent 11230eb3b7def33b7cc0a67544876dfb6669d342 (diff)
Support half floats in BatchMatMul.
Clean up tests and extend coverage to all supported types.

Change: 133766358
-rw-r--r--  tensorflow/core/kernels/batch_matmul_op.cc               |  24
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                     |   1
-rw-r--r--  tensorflow/python/kernel_tests/batch_matmul_op_test.py   | 206
3 files changed, 105 insertions(+), 126 deletions(-)
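As a quick orientation (not part of the patch): after this change, BatchMatMul accepts half-precision inputs on CPU, and on GPU when TensorFlow is built against CUDA 7.5 or newer. A minimal usage sketch against the contemporaneous tf.batch_matmul API; shapes and values are illustrative only:

    import numpy as np
    import tensorflow as tf

    # Hypothetical example: half-precision batched matrix multiply.
    x = np.random.normal(size=[7, 2, 3]).astype(np.float16)
    y = np.random.normal(size=[7, 3, 5]).astype(np.float16)
    with tf.Session() as sess:
        z = tf.batch_matmul(x, y)  # per-batch product, shape [7, 2, 5], dtype float16
        print(sess.run(z).shape)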
diff --git a/tensorflow/core/kernels/batch_matmul_op.cc b/tensorflow/core/kernels/batch_matmul_op.cc
index b15891772d..5966ac1b86 100644
--- a/tensorflow/core/kernels/batch_matmul_op.cc
+++ b/tensorflow/core/kernels/batch_matmul_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
@@ -298,16 +299,21 @@ class BatchMatMul : public OpKernel {
Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
BatchMatMul<GPUDevice, TYPE>)
-REGISTER_CPU(float);
-REGISTER_CPU(double);
-REGISTER_CPU(int32);
-REGISTER_CPU(complex64);
-REGISTER_CPU(complex128);
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_int32(REGISTER_CPU);
+TF_CALL_complex64(REGISTER_CPU);
+TF_CALL_complex128(REGISTER_CPU);
-#ifdef GOOGLE_CUDA
-REGISTER_GPU(float);
-REGISTER_GPU(complex64);
-REGISTER_GPU(complex128);
+#if GOOGLE_CUDA
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_complex64(REGISTER_GPU);
+TF_CALL_complex128(REGISTER_GPU);
+#if CUDA_VERSION >= 7050
+TF_CALL_half(REGISTER_GPU);
+#endif
#endif // GOOGLE_CUDA
#undef REGISTER_CPU
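For reference, the kernel registered for each type T computes a per-batch matrix product of the two innermost dimensions, with an optional adjoint (conjugate transpose) of either operand. A rough numpy sketch of those semantics, mirroring the reference implementation used in the tests below (this is an illustration, not the kernel itself):

    import numpy as np

    def np_batch_matmul(x, y, adj_x=False, adj_y=False):
        # Multiply the innermost 2-D matrices of x and y; adj_x / adj_y take
        # the conjugate transpose of the corresponding operand first.
        if adj_x:
            x = np.conj(np.swapaxes(x, -1, -2))
        if adj_y:
            y = np.conj(np.swapaxes(y, -1, -2))
        return np.matmul(x, y)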
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3608915933..bd75be440e 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -445,6 +445,7 @@ cuda_py_test(
additional_deps = [
"//tensorflow:tensorflow_py",
],
+ shard_count = 20,
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/batch_matmul_op_test.py b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
index 0b9338887a..418c761f50 100644
--- a/tensorflow/python/kernel_tests/batch_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_matmul_op_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for tensorflow.ops.tf.BatchMatMul."""
from __future__ import absolute_import
from __future__ import division
@@ -26,8 +25,6 @@ class BatchMatmulOpTest(tf.test.TestCase):
# Uses numpy to compute batch_matmul(x, y, adj_x, adj_y).
def _npBatchMatmul(self, x, y, adj_x, adj_y):
- assert x.ndim >= 3
- assert y.ndim >= 3
# output's shape depends on adj[0] and adj[1]
d0 = x.shape[-2] if not adj_x else x.shape[-1]
d2 = y.shape[-1] if not adj_y else y.shape[-2]
@@ -48,7 +45,7 @@ class BatchMatmulOpTest(tf.test.TestCase):
return z
# Test _npBatchMatMul works.
- def testSimpleNpVersion(self):
+ def testNpVersion(self):
x = np.array([0., 1., 2., 3.]).reshape([1, 2, 2])
y = np.array([1., 2., 3., 4.]).reshape([1, 2, 2])
z0 = self._npBatchMatmul(x, y, False, False)
@@ -62,149 +59,124 @@ class BatchMatmulOpTest(tf.test.TestCase):
self.assertTrue(np.array_equal(z0, z1))
z0 = self._npBatchMatmul(x, y, False, True)
- z1 = np.array([(2.-2.j), (-2.+2.j), (-2.+2.j), (2.-2.j)]).reshape([1, 2, 2])
+ z1 = np.array([(2. - 2.j), (-2. + 2.j), (-2. + 2.j), (2. - 2.j)]).reshape(
+ [1, 2, 2])
self.assertTrue(np.array_equal(z0, z1))
z0 = self._npBatchMatmul(x, y, True, False)
- z1 = np.array([(2.+2.j), (-2.+2.j), (2.-2.j), (2.+2.j)]).reshape([1, 2, 2])
+ z1 = np.array([(2. + 2.j), (-2. + 2.j), (2. - 2.j), (2. + 2.j)]).reshape(
+ [1, 2, 2])
self.assertTrue(np.array_equal(z0, z1))
# Compares _tfpBatchMatmul(x, y, alpha, adj) and _npBatchMatMul(x, y, alpha,
# adj)
- def _compare(self, x, y, adj_x, adj_y):
- with self.test_session(use_gpu=True):
+ def _compare(self, x_in, y_in, adj_x, adj_y):
+ x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
+ y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
+ x = x_in if not adj_x else x_in.reshape(x_t_shape)
+ y = y_in if not adj_y else y_in.reshape(y_t_shape)
+ is_floating = x.dtype != np.int32
+ tol = 100 * np.finfo(x.dtype).eps if is_floating else 0
+
+ with self.test_session(use_gpu=is_floating):
z0 = tf.batch_matmul(x, y, adj_x=adj_x, adj_y=adj_y)
z0_val = z0.eval()
- z1 = self._npBatchMatmul(x, y, adj_x, adj_y)
- self.assertShapeEqual(z1, z0)
- if z0_val.size != 0:
- err = (np.abs(z0_val - z1) / np.maximum(1, np.abs(z0_val))).max()
- tf.logging.info("error = %f", err)
- self.assertTrue(err < 1e-4)
-
- # Returns a random float np of "shape".
- def _randFloat(self, shape):
- vals = np.random.normal(0, 1, np.prod(shape)).reshape(shape)
- return np.array(vals, dtype=np.float32)
-
- def testSimpleFloat(self):
- self._compare(self._randFloat([7, 2, 3]), self._randFloat([7, 3, 5]),
- False, False)
- self._compare(self._randFloat([7, 2, 3]), self._randFloat([7, 5, 3]),
- False, True)
- self._compare(self._randFloat([7, 3, 2]), self._randFloat([7, 3, 5]),
- True, False)
- self._compare(self._randFloat([7, 3, 2]), self._randFloat([7, 5, 3]),
- True, True)
-
- def testLargeFloat(self):
- self._compare(self._randFloat([10, 64, 75]),
- self._randFloat([10, 75, 30]), False, False)
- self._compare(self._randFloat([10, 75, 64]),
- self._randFloat([10, 75, 30]), True, False)
- self._compare(self._randFloat([10, 64, 75]),
- self._randFloat([10, 30, 75]), False, True)
- self._compare(self._randFloat([10, 75, 64]),
- self._randFloat([10, 30, 75]), True, True)
-
- def testHighNDims(self):
- self._compare(self._randFloat([5, 7, 2, 3]),
- self._randFloat([5, 7, 3, 5]), False, False)
- self._compare(self._randFloat([5, 7, 3, 2]),
- self._randFloat([5, 7, 3, 5]), True, False)
- self._compare(self._randFloat([5, 7, 2, 3]),
- self._randFloat([5, 7, 5, 3]), False, True)
- self._compare(self._randFloat([5, 7, 3, 2]),
- self._randFloat([5, 7, 5, 3]), True, True)
-
- # Returns a random complex numpy array of "shape".
- def _randComplex(self, shape):
- real = np.random.normal(0, 1, np.prod(shape))
- imag = np.random.normal(0, 1, np.prod(shape))
- vals = [np.complex(v[0], v[1]) for v in zip(real, imag)]
- return np.array(vals, dtype=np.complex64).reshape(shape)
-
- def testSimpleComplex(self):
- self._compare(self._randComplex([7, 2, 3]),
- self._randComplex([7, 3, 5]), False, False)
- self._compare(self._randComplex([7, 2, 3]),
- self._randComplex([7, 5, 3]), False, True)
- self._compare(self._randComplex([7, 3, 2]),
- self._randComplex([7, 3, 5]), True, False)
- self._compare(self._randComplex([7, 3, 2]),
- self._randComplex([7, 5, 3]), True, True)
-
- def testLargeComplex(self):
- self._compare(self._randComplex([10, 64, 75]),
- self._randComplex([10, 75, 30]), False, False)
- self._compare(self._randComplex([10, 64, 75]),
- self._randComplex([10, 30, 75]), False, True)
- self._compare(self._randComplex([10, 75, 64]),
- self._randComplex([10, 75, 30]), True, False)
- self._compare(self._randComplex([10, 75, 64]),
- self._randComplex([10, 30, 75]), True, True)
-
- def testEmpty(self):
- self._compare(np.zeros([0, 3, 2]).astype(np.float32),
- np.zeros([0, 2, 4]).astype(np.float32), False, False)
- self._compare(np.zeros([3, 2, 0]).astype(np.float32),
- np.zeros([3, 0, 5]).astype(np.float32), False, False)
- self._compare(np.zeros([3, 0, 2]).astype(np.float32),
- np.zeros([3, 2, 5]).astype(np.float32), False, False)
- self._compare(np.zeros([3, 3, 2]).astype(np.float32),
- np.zeros([3, 2, 0]).astype(np.float32), False, False)
+ z1 = self._npBatchMatmul(x, y, adj_x, adj_y)
+ self.assertAllClose(z0_val, z1, rtol=tol, atol=tol)
+
+ def _rand(self, shape, dtype):
+ vals = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
+ if dtype in (np.complex64, np.complex128):
+ imag = np.array(np.random.normal(-10, 10, np.prod(shape)), dtype=dtype)
+ vals += 1j * imag
+ return vals.reshape(shape)
+
+ def _testNonEmpty(self, dtype, adj_x, adj_y):
+ self._compare(
+ self._rand([7, 2, 3], dtype), self._rand([7, 3, 5], dtype), adj_x,
+ adj_y)
+ self._compare(
+ self._rand([10, 64, 75], dtype), self._rand([10, 75, 30], dtype), adj_x,
+ adj_y)
+ self._compare(
+ self._rand([5, 7, 2, 3], dtype), self._rand([5, 7, 3, 5], dtype), adj_x,
+ adj_y)
+
+ def _testEmpty(self, dtype, adj_x, adj_y):
+ self._compare(
+ np.zeros([0, 3, 2]).astype(dtype), np.zeros([0, 2, 4]).astype(dtype),
+ adj_x, adj_y)
+ self._compare(
+ np.zeros([3, 0, 2]).astype(dtype), np.zeros([3, 2, 5]).astype(dtype),
+ adj_x, adj_y)
+ self._compare(
+ np.zeros([3, 3, 2]).astype(dtype), np.zeros([3, 2, 0]).astype(dtype),
+ adj_x, adj_y)
+
+
+def _GetBatchMatmulOpTest(dtype, adj_x, adj_y):
+
+ def Test(self):
+ self._testNonEmpty(dtype, adj_x, adj_y)
+ self._testEmpty(dtype, adj_x, adj_y)
+
+ return Test
class BatchMatmulGradientTest(tf.test.TestCase):
# loss = sum(batch_matmul(x, y)). Verify dl/dx and dl/dy via the
# gradient checker.
- def _checkGrad(self, x, y, adj_x, adj_y):
- assert 3 == x.ndim
- assert 3 == y.ndim
+ def _checkGrad(self, x_in, y_in, adj_x, adj_y):
+ x_t_shape = x_in.shape[:-2] + (x_in.shape[-1], x_in.shape[-2])
+ y_t_shape = y_in.shape[:-2] + (y_in.shape[-1], y_in.shape[-2])
+ x = x_in if not adj_x else x_in.reshape(x_t_shape)
+ y = y_in if not adj_y else y_in.reshape(y_t_shape)
+ epsilon = np.finfo(x.dtype).eps
+ delta = epsilon**(1.0 / 3.0)
with self.test_session(use_gpu=True):
- inx = tf.convert_to_tensor(x)
- iny = tf.convert_to_tensor(y)
+ inx = tf.constant(x)
+ iny = tf.constant(y)
z = tf.batch_matmul(inx, iny, adj_x, adj_y)
loss = tf.reduce_sum(z)
- epsilon = 1e-2
((x_jacob_t, x_jacob_n),
(y_jacob_t, y_jacob_n)) = tf.test.compute_gradient(
- [inx, iny],
- [x.shape, y.shape],
- loss,
- [1],
+ [inx, iny], [x.shape, y.shape],
+ loss, [1],
x_init_value=[x, y],
- delta=epsilon)
-
- tf.logging.info("x_jacob_t = %s", x_jacob_t.reshape(x.shape))
- tf.logging.info("x_jacob_n = %s", x_jacob_n.reshape(x.shape))
- self.assertAllClose(x_jacob_t, x_jacob_n, rtol=1e-2, atol=epsilon)
- tf.logging.info("y_jacob_t = %s", y_jacob_t.reshape(y.shape))
- tf.logging.info("y_jacob_n = %s", y_jacob_n.reshape(y.shape))
- self.assertAllClose(y_jacob_t, y_jacob_n, rtol=1e-2, atol=epsilon)
+ delta=delta)
+ tol = 20 * delta
+ self.assertAllClose(x_jacob_t, x_jacob_n, rtol=tol, atol=tol)
+ self.assertAllClose(y_jacob_t, y_jacob_n, rtol=tol, atol=tol)
# Tests a batched matmul of x, and y: x is a 3D tensor of shape [b,
# n, k] y is a 3D tensor of shape [b, k, m] the batched matmul
# computes z of shape [b, n, m], where z[i, :, :] = x[i, :, :]
# matmul y[i, :, :]
- def _compare(self, b, n, k, m):
- x = np.random.normal(0, 1, b * n * k).astype(np.float32).reshape([b, n, k])
- y = np.random.normal(0, 1, b * k * m).astype(np.float32).reshape([b, k, m])
- self._checkGrad(x, y, False, False)
- self._checkGrad(x.reshape([b, k, n]), y, True, False)
- self._checkGrad(x, y.reshape([b, m, k]), False, True)
- self._checkGrad(x.reshape([b, k, n]), y.reshape([b, m, k]), True, True)
+ def _compare(self, b, n, k, m, dtype, adj_x, adj_y):
+ x = np.random.normal(0, 1, b * n * k).astype(dtype).reshape([b, n, k])
+ y = np.random.normal(0, 1, b * k * m).astype(dtype).reshape([b, k, m])
+ self._checkGrad(x, y, adj_x, adj_y)
+
- def testSmall(self):
- self._compare(1, 2, 3, 5)
+def _GetBatchMatmulGradientTest(dtype, adj_x, adj_y):
- def testMedium(self):
- self._compare(3, 4, 7, 10)
+ def Test(self):
+ self._compare(1, 2, 3, 5, dtype, adj_x, adj_y)
+ self._compare(3, 4, 7, 10, dtype, adj_x, adj_y)
- # Can't do testLarge using very large inputs because gradient
- # checker will take way too long time.
+ return Test
if __name__ == "__main__":
+ for dtype_ in [np.float16, np.float32, np.float64, np.complex64,
+ np.complex128, np.int32]:
+ for adj_x_ in False, True:
+ for adj_y_ in False, True:
+ name = "%s_%s_%s" % (dtype_.__name__, adj_x_, adj_y_)
+ setattr(BatchMatmulOpTest, "testBatchMatmulOp_" + name,
+ _GetBatchMatmulOpTest(dtype_, adj_x_, adj_y_))
+ if dtype_ is not np.int32:
+ setattr(BatchMatmulGradientTest, "testBatchMatmulGradient_" + name,
+ _GetBatchMatmulGradientTest(dtype_, adj_x_, adj_y_))
tf.test.main()
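The rewritten test file builds its methods at import time instead of hand-writing one per dtype: _GetBatchMatmulOpTest and _GetBatchMatmulGradientTest return closures that are attached to the TestCase classes with setattr, one per (dtype, adj_x, adj_y) combination, with int32 skipped for the gradient checks. Tolerances scale with the dtype: the forward comparisons use 100 * np.finfo(dtype).eps, and the gradient checks use a finite-difference step of eps ** (1/3) with tolerance 20 * delta. A minimal standalone sketch of the same test-generation pattern, with hypothetical names:

    import unittest

    class ParamTest(unittest.TestCase):
        pass

    def _make_test(value):
        # Each generated method closes over its own parameter value.
        def test(self):
            self.assertEqual(value * 2, value + value)
        return test

    for value_ in (1, 2, 3):
        setattr(ParamTest, "test_double_%d" % value_, _make_test(value_))

    if __name__ == "__main__":
        unittest.main()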