author     A. Unique TensorFlower <nobody@tensorflow.org>    2016-05-17 13:42:26 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-05-17 14:52:08 -0700
commit     4d55cb5c55fcc85009171b6a4657cbd966fd85e5 (patch)
tree       303d69d5be3054475262d49ceefc9c116ec5da40 /tensorflow
parent     2709b2303b39b848b9c6a86cd89811820d390580 (diff)
Enable fp16 support for MatMul via cuBLAS, gated on compilation with
CUDA 7.5 or higher. (If not, the GPU tests will also not be run, so that
the tests as a whole will keep passing on CUDA 7.0.)

Change: 122566230
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/core/kernels/matmul_op.cc                  5
-rw-r--r--  tensorflow/core/kernels/transpose_functor_gpu.cu.cc   1
-rw-r--r--  tensorflow/core/ops/math_grad.cc                      2
-rw-r--r--  tensorflow/core/ops/math_ops.cc                       4
-rw-r--r--  tensorflow/core/util/port.cc                         14
-rw-r--r--  tensorflow/core/util/port.h                           4
-rw-r--r--  tensorflow/python/framework/test_util.py              4
-rw-r--r--  tensorflow/python/kernel_tests/matmul_op_test.py     58
-rw-r--r--  tensorflow/python/util/port.i                         1
9 files changed, 88 insertions, 5 deletions
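
As an illustrative sketch (not part of the commit itself), this is roughly what the change enables from the Python API on a build of this revision: float16 numpy arrays can be fed straight to tf.matmul, the half CPU kernel is always registered, and the half GPU kernel is compiled in only with CUDA 7.5 or higher.

    # Illustrative sketch, not taken from the commit: fp16 MatMul from Python.
    import numpy as np
    import tensorflow as tf

    x = np.arange(1., 5.).reshape([4, 1]).astype(np.float16)
    y = np.arange(1., 3.).reshape([1, 2]).astype(np.float16)

    with tf.Session() as sess:
      # The half CPU kernel is always available; the half GPU kernel is used
      # only when TensorFlow was compiled against CUDA 7.5 or higher.
      product = sess.run(tf.matmul(x, y))
      print(product.dtype)  # float16
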
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index d90cf550f0..08b699c8fa 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/kernels/fill_functor.h"
#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@@ -205,9 +206,13 @@ REGISTER_CPU(float);
REGISTER_CPU(double);
REGISTER_CPU(int32);
REGISTER_CPU(complex64);
+REGISTER_CPU(Eigen::half);
#if GOOGLE_CUDA
REGISTER_GPU(float);
// REGISTER_GPU(double);
+#if CUDA_VERSION >= 7050
+REGISTER_GPU(Eigen::half);
+#endif
#endif // GOOGLE_CUDA
} // namespace tensorflow
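
Note that REGISTER_CPU(Eigen::half) above is not gated on CUDA_VERSION, so a float16 matmul can always be pinned to the CPU, even on builds where the GPU half kernel is compiled out. A hypothetical sketch, not taken from the commit:

    import numpy as np
    import tensorflow as tf

    x = np.random.rand(4, 3).astype(np.float16)
    y = np.random.rand(3, 2).astype(np.float16)

    # The half CPU kernel is registered unconditionally, so this placement
    # works regardless of the CUDA version used at build time.
    with tf.device("/cpu:0"):
      product = tf.matmul(x, y)

    with tf.Session() as sess:
      print(sess.run(product))
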
diff --git a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
index 18d675443e..3febba0441 100644
--- a/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
+++ b/tensorflow/core/kernels/transpose_functor_gpu.cu.cc
@@ -110,6 +110,7 @@ Status DoTranspose<Device>(const Device& d, const Tensor& in,
break;
case DT_BFLOAT16:
+ case DT_HALF:
case DT_INT16:
case DT_QINT16:
case DT_QUINT16:
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index e44d4426dc..d290580077 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -523,7 +523,7 @@ static Status MatMulGradHelper(FunctionDef* g, const string& opname,
// Ret val defs
{"dx: T", "dy: T"},
// Attr defs
- {{"T: {float, double}"}},
+ {{"T: {half, float, double}"}},
// Nodes
{
{{"dx"},
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index ba0b5e4bbb..861ed74b1f 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -37,7 +37,7 @@ REGISTER_OP("BatchMatMul")
.Input("x: T")
.Input("y: T")
.Output("output: T")
- .Attr("T: {float, double, int32, complex64}")
+ .Attr("T: {half, float, double, int32, complex64}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.Doc(R"doc(
@@ -577,7 +577,7 @@ REGISTER_OP("MatMul")
.Output("product: T")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
- .Attr("T: {float, double, int32, complex64}")
+ .Attr("T: {half, float, double, int32, complex64}")
.Doc(R"doc(
Multiply the matrix "a" by the matrix "b".
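
With half added to the T attr of both BatchMatMul and MatMul, a float16 batched multiply now type-checks as well. A small sketch, assuming the contemporaneous tf.batch_matmul Python wrapper:

    import numpy as np
    import tensorflow as tf

    a = np.random.rand(8, 4, 3).astype(np.float16)  # a batch of 8 matrices
    b = np.random.rand(8, 3, 5).astype(np.float16)

    out = tf.batch_matmul(a, b)  # BatchMatMul now accepts T = half
    print(out.dtype)             # tf.float16, per the updated attr
    print(out.get_shape())       # (8, 4, 5)
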
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index e43cba9203..5d61ecbbae 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -15,6 +15,10 @@ limitations under the License.
#include "tensorflow/core/util/port.h"
+#if GOOGLE_CUDA
+#include "third_party/gpus/cuda/include/cuda.h"
+#endif
+
namespace tensorflow {
bool IsGoogleCudaEnabled() {
@@ -25,4 +29,14 @@ bool IsGoogleCudaEnabled() {
#endif
}
+bool CudaSupportsHalfMatMulAndConv() {
+#if GOOGLE_CUDA
+ // NOTE: We check compile-time and not runtime, since the check for
+ // whether we include the fp16 kernels or not is compile-time.
+ return CUDA_VERSION >= 7050;
+#else
+ return false;
+#endif
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h
index 009d8968da..65d628c841 100644
--- a/tensorflow/core/util/port.h
+++ b/tensorflow/core/util/port.h
@@ -21,6 +21,10 @@ namespace tensorflow {
// Returns true if GOOGLE_CUDA is defined.
bool IsGoogleCudaEnabled();
+// Returns true if GOOGLE_CUDA is defined, and the given CUDA version supports
+// half-precision matrix multiplications and convolution operations.
+bool CudaSupportsHalfMatMulAndConv();
+
} // end namespace tensorflow
#endif // TENSORFLOW_UTIL_PORT_H_
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 15c342f38e..c262a00985 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -103,6 +103,10 @@ def IsGoogleCudaEnabled():
return pywrap_tensorflow.IsGoogleCudaEnabled()
+def CudaSupportsHalfMatMulAndConv():
+ return pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
+
+
class TensorFlowTestCase(googletest.TestCase):
"""Base class for tests that need to test TensorFlow.
"""
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index aa291cdbdd..3d39af8d85 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -22,17 +22,40 @@ import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
class MatMulTest(tf.test.TestCase):
+ def assertAllCloseAccordingToType(self, a, b, rtol=1e-6, atol=1e-6):
+ """Like test_util.assertAllCloseToType, but with looser fp16 limits.
+
+ With matrix multiplication, many values are summed, compounding
+ accuracy issues. Thus, we set fp16 tolerance to 1e-2 instead of 1e-6.
+ (This primarily affects the CPU versions, which accumulate in fp16;
+ the CUDA versions currently use fp32 math internally.)
+
+ Args:
+ a: a numpy ndarray or anything that can be converted to one.
+ b: a numpy ndarray or anything that can be converted to one.
+ rtol: relative tolerance
+ atol: absolute tolerance
+ """
+ a = self._GetNdArray(a)
+ b = self._GetNdArray(b)
+ if a.dtype == np.float16 or b.dtype == np.float16:
+ rtol = max(rtol, 1e-2)
+ atol = max(atol, 1e-2)
+
+ self.assertAllClose(a, b, rtol=rtol, atol=atol)
+
def _testCpuMatmul(self, x, y, transpose_x=False, transpose_y=False):
x_mat = np.matrix(x).T if transpose_x else np.matrix(x)
y_mat = np.matrix(y).T if transpose_y else np.matrix(y)
np_ans = x_mat * y_mat
with self.test_session(use_gpu=False):
tf_ans = tf.matmul(x, y, transpose_x, transpose_y).eval()
- self.assertAllClose(np_ans, tf_ans)
+ self.assertAllCloseAccordingToType(np_ans, tf_ans)
self.assertAllEqual(np_ans.shape, tf_ans.shape)
def _testGpuMatmul(self, x, y, transpose_x=False, transpose_y=False):
@@ -41,7 +64,7 @@ class MatMulTest(tf.test.TestCase):
np_ans = x_mat * y_mat
with self.test_session(use_gpu=True):
tf_ans = tf.matmul(x, y, transpose_x, transpose_y).eval()
- self.assertAllClose(np_ans, tf_ans)
+ self.assertAllCloseAccordingToType(np_ans, tf_ans)
self.assertAllEqual(np_ans.shape, tf_ans.shape)
def _randMatrix(self, rows, cols, dtype):
@@ -69,6 +92,15 @@ class MatMulTest(tf.test.TestCase):
y = np.arange(1., 3.).reshape([1, 2]).astype(np.float64)
self._testCpuMatmul(x, y)
+ def testHalfBasic(self):
+ x = np.arange(1., 5.).reshape([4, 1]).astype(np.float16)
+ y = np.arange(1., 3.).reshape([1, 2]).astype(np.float16)
+ self._testCpuMatmul(x, y)
+ if test_util.CudaSupportsHalfMatMulAndConv():
+ self._testGpuMatmul(x, y)
+ else:
+ print("Built without fp16 matmul support, skipping GPU test.")
+
def testInt32Basic(self):
x = np.arange(1., 5.).reshape([4, 1]).astype(np.int32)
y = np.arange(1., 3.).reshape([1, 2]).astype(np.int32)
@@ -95,6 +127,17 @@ class MatMulTest(tf.test.TestCase):
y = self._randMatrix(k, m, np.float64)
self._testCpuMatmul(x, y)
+ def testHalfRandom(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 10, size=3) # Smaller range than float.
+ x = self._randMatrix(n, k, np.float16)
+ y = self._randMatrix(k, m, np.float16)
+ self._testCpuMatmul(x, y)
+ if test_util.CudaSupportsHalfMatMulAndConv():
+ self._testGpuMatmul(x, y)
+ else:
+ print("Built without fp16 matmul support, skipping GPU test.")
+
def testInt32Random(self):
for _ in range(10):
n, k, m = np.random.randint(1, 100, size=3)
@@ -127,6 +170,17 @@ class MatMulTest(tf.test.TestCase):
y = self._randMatrix(m, k, np.float64)
self._testCpuMatmul(x, y, True, True)
+ def testHalfRandomTransposeBoth(self):
+ for _ in range(10):
+ n, k, m = np.random.randint(1, 10, size=3) # Smaller range than float.
+ x = self._randMatrix(k, n, np.float16)
+ y = self._randMatrix(m, k, np.float16)
+ self._testCpuMatmul(x, y, True, True)
+ if test_util.CudaSupportsHalfMatMulAndConv():
+ self._testGpuMatmul(x, y, True, True)
+ else:
+ print("Built without fp16 matmul support, skipping GPU test.")
+
def testMatMul_OutEmpty_A(self):
n, k, m = 0, 8, 3
x = self._randMatrix(n, k, np.float32)
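
The looser fp16 tolerance chosen above can be motivated with a small numpy-only sketch (illustrative, not from the commit): merely representing the operands and the result in float16 already yields relative errors well above the default 1e-6.

    import numpy as np

    np.random.seed(0)
    k = 64
    x = np.random.rand(1, k).astype(np.float16)
    y = np.random.rand(k, 1).astype(np.float16)

    # float16 product vs. a float64 reference computed from the same
    # (already rounded) float16 inputs.
    prod_fp16 = np.dot(x, y).astype(np.float64)
    prod_ref = np.dot(x.astype(np.float64), y.astype(np.float64))

    rel_err = np.abs(prod_fp16 - prod_ref) / np.abs(prod_ref)
    print(rel_err)  # typically around 1e-4 to 1e-3, far above 1e-6
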
diff --git a/tensorflow/python/util/port.i b/tensorflow/python/util/port.i
index 568658cd7f..c3c833faef 100644
--- a/tensorflow/python/util/port.i
+++ b/tensorflow/python/util/port.i
@@ -22,5 +22,6 @@ limitations under the License.
%ignoreall
%unignore tensorflow;
%unignore tensorflow::IsGoogleCudaEnabled;
+%unignore tensorflow::CudaSupportsHalfMatMulAndConv;
%include "tensorflow/core/util/port.h"
%unignoreall
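
After the %unignore above, the new check is reachable from Python through the SWIG-generated pywrap module, which is how the test_util.py helper earlier in this change is wired up. A brief sketch:

    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.framework import test_util

    # Both return False on builds without GOOGLE_CUDA; the second additionally
    # requires compilation against CUDA 7.5 or higher.
    print(pywrap_tensorflow.IsGoogleCudaEnabled())
    print(pywrap_tensorflow.CudaSupportsHalfMatMulAndConv())

    # The helper added to test_util in this change calls the same function:
    print(test_util.CudaSupportsHalfMatMulAndConv())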