diff options
author | Benoit Steiner <benoit.steiner.goog@gmail.com> | 2016-05-31 09:51:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-31 11:04:18 -0700 |
commit | 07792c757457e8ecf62c8d37038e458484eab78a (patch) | |
tree | fbfd0c191977e400c9b5accb4641a951f7dfd272 | |
parent | 3ce1d20108cfc190553bac98c17a53b23457f8bd (diff) |
Added support for convolutions of 16bit floats on CPU
Change: 123659102
-rw-r--r-- | eigen.BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/external/eigen.cmake | 4 | ||||
-rw-r--r-- | tensorflow/core/kernels/conv_2d.h | 15 | ||||
-rw-r--r-- | tensorflow/core/kernels/conv_grad_ops.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/kernels/conv_ops.cc | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/conv_ops_test.py | 27 | ||||
-rw-r--r-- | tensorflow/workspace.bzl | 4 | ||||
-rw-r--r-- | third_party/eigen3/Eigen/Cholesky | 2 | ||||
-rw-r--r-- | third_party/eigen3/Eigen/Core | 2 | ||||
-rw-r--r-- | third_party/eigen3/Eigen/Eigenvalues | 2 | ||||
-rw-r--r-- | third_party/eigen3/Eigen/LU | 2 | ||||
-rw-r--r-- | third_party/eigen3/Eigen/QR | 2 | ||||
-rw-r--r-- | third_party/eigen3/unsupported/Eigen/CXX11/Tensor | 2 |
13 files changed, 71 insertions, 25 deletions
diff --git a/eigen.BUILD b/eigen.BUILD index 16dd4f8422..79bafe65b6 100644 --- a/eigen.BUILD +++ b/eigen.BUILD @@ -1,6 +1,6 @@ package(default_visibility = ["//visibility:public"]) -archive_dir = "eigen-eigen-f3a13643ac1f" +archive_dir = "eigen-eigen-d02e6a705c30" cc_library( name = "eigen", diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake index c1929a10f3..db409760fa 100644 --- a/tensorflow/contrib/cmake/external/eigen.cmake +++ b/tensorflow/contrib/cmake/external/eigen.cmake @@ -7,7 +7,7 @@ include (ExternalProject) -set(eigen_archive_hash "f3a13643ac1f") +set(eigen_archive_hash "d02e6a705c30") set(eigen_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR} @@ -16,7 +16,7 @@ set(eigen_INCLUDE_DIRS ${tensorflow_source_dir}/third_party/eigen3 ) set(eigen_URL https://bitbucket.org/eigen/eigen/get/${eigen_archive_hash}.tar.gz) -set(eigen_HASH SHA256=a9266e60366cddb371a23d86b11a297eee86372a89ef4b38a3509012f9cc37ec) +set(eigen_HASH SHA256=532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf) set(eigen_BUILD ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen) set(eigen_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/eigen/install) diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index 40ee4420bb..620a2fb703 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -71,6 +71,21 @@ struct SpatialConvolution { } }; +template <typename Device> +struct SpatialConvolution<Device, Eigen::half> { + void operator()(const Device& d, + typename TTypes<Eigen::half, 4>::Tensor output, + typename TTypes<Eigen::half, 4>::ConstTensor input, + typename TTypes<Eigen::half, 4>::ConstTensor filter, + int row_stride, int col_stride, + const Eigen::PaddingType& padding) { + output.device(d) = + Eigen::SpatialConvolution(input.cast<float>(), filter.cast<float>(), + col_stride, row_stride, padding) + .cast<Eigen::half>(); + } +}; + template <typename Device, typename T> struct SpatialConvolutionBackwardInput { void operator()(const Device& d, typename TTypes<T, 4>::Tensor input_backward, diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index 77f6f40a13..a1fcff578a 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -625,18 +625,32 @@ class Conv2DCustomBackpropInputOp : public OpKernel { REGISTER_KERNEL_BUILDER( Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<float>("T"), Conv2DCustomBackpropInputOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") + .Device(DEVICE_CPU) + .TypeConstraint<Eigen::half>("T"), + Conv2DCustomBackpropInputOp<CPUDevice, Eigen::half>); REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") .Device(DEVICE_CPU) .Label("custom") .TypeConstraint<float>("T"), Conv2DCustomBackpropInputOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") + .Device(DEVICE_CPU) + .Label("custom") + .TypeConstraint<Eigen::half>("T"), + Conv2DCustomBackpropInputOp<CPUDevice, Eigen::half>); REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") .Device(DEVICE_CPU) .Label("eigen_tensor") .TypeConstraint<float>("T"), Conv2DFastBackpropInputOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput") + .Device(DEVICE_CPU) + .Label("eigen_tensor") + .TypeConstraint<Eigen::half>("T"), + Conv2DFastBackpropInputOp<CPUDevice, Eigen::half>); template <typename Device, class T> class Conv2DFastBackpropFilterOp : public OpKernel { @@ -856,18 +870,32 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { REGISTER_KERNEL_BUILDER( Name("Conv2DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<float>("T"), Conv2DCustomBackpropFilterOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") + .Device(DEVICE_CPU) + .TypeConstraint<Eigen::half>("T"), + Conv2DCustomBackpropFilterOp<CPUDevice, Eigen::half>); REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") .Device(DEVICE_CPU) .Label("custom") .TypeConstraint<float>("T"), Conv2DCustomBackpropFilterOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") + .Device(DEVICE_CPU) + .Label("custom") + .TypeConstraint<Eigen::half>("T"), + Conv2DCustomBackpropFilterOp<CPUDevice, Eigen::half>); REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") .Device(DEVICE_CPU) .Label("eigen_tensor") .TypeConstraint<float>("T"), Conv2DFastBackpropFilterOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") + .Device(DEVICE_CPU) + .Label("eigen_tensor") + .TypeConstraint<Eigen::half>("T"), + Conv2DFastBackpropFilterOp<CPUDevice, Eigen::half>); // GPU definitions of both ops. #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 8b982ef432..490b4ea422 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -248,7 +248,9 @@ class Conv2DOp : public BinaryOp<T> { REGISTER_KERNEL_BUILDER( Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<float>("T"), Conv2DOp<CPUDevice, float>); - +REGISTER_KERNEL_BUILDER( + Name("Conv2D").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"), + Conv2DOp<CPUDevice, Eigen::half>); #if GOOGLE_CUDA int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb, diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index de3b8d0691..ec723ae045 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -163,13 +163,13 @@ def GetTestConfigs(): class Conv2DTest(tf.test.TestCase): - def _DtypesToTest(self): - if test_util.CudaSupportsHalfMatMulAndConv(): + def _DtypesToTest(self, use_gpu): + if use_gpu and not test_util.CudaSupportsHalfMatMulAndConv(): + return [tf.float32] + else: # It is important that float32 comes before float16 here, # as we will be using its gradients as reference for fp16 gradients. return [tf.float32, tf.float16] - else: - return [tf.float32] def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, strides, padding, data_format, dtype, use_gpu): @@ -255,10 +255,9 @@ class Conv2DTest(tf.test.TestCase): def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding, expected): - for dtype in self._DtypesToTest(): - print(dtype) - tensors = [] - for (data_format, use_gpu) in GetTestConfigs(): + tensors = [] + for (data_format, use_gpu) in GetTestConfigs(): + for dtype in self._DtypesToTest(use_gpu): result = self._SetupValuesForDevice(tensor_in_sizes, filter_in_sizes, strides, @@ -274,7 +273,10 @@ class Conv2DTest(tf.test.TestCase): value = values[i] print("expected = ", expected) print("actual = ", value) - self.assertAllCloseAccordingToType(expected, np.ravel(value)) + tol = 1e-5 + if value.dtype == np.float16: + tol = 1e-3 + self.assertAllClose(expected, np.ravel(value), atol=tol, rtol=tol) self.assertShapeEqual(value, conv) def testConv2D1x1Filter(self): @@ -360,7 +362,7 @@ class Conv2DTest(tf.test.TestCase): # strides=[4, 4], padding="SAME", # expected=[72, 112, 392, 432]) - # Testing for backprops + # Testing for backprops def _RunAndVerifyBackpropInput(self, input_sizes, filter_sizes, output_sizes, strides, padding, expected, data_format, use_gpu): @@ -506,7 +508,7 @@ class Conv2DTest(tf.test.TestCase): # numbers from 1. x0 = [f * 1.0 for f in range(1, total_input_size + 1)] x2 = [f * 1.0 for f in range(1, total_output_size + 1)] - for dtype in self._DtypesToTest(): + for dtype in self._DtypesToTest(use_gpu=use_gpu): with self.test_session(use_gpu=use_gpu) as sess: t0 = tf.constant(x0, shape=input_sizes, dtype=dtype) t1 = tf.constant(filter_sizes, shape=[len(filter_sizes)]) @@ -635,7 +637,7 @@ class Conv2DTest(tf.test.TestCase): # a problem in the way Eigen's Conv2DGrad works for double. # So we disable the DOUBLE path. We should re-enable this # when double support returns for CPU and/or GPU. - for dtype in self._DtypesToTest(): + for dtype in self._DtypesToTest(use_gpu=use_gpu): with self.test_session(use_gpu=use_gpu): input_tensor = tf.constant(input_data, shape=input_shape, dtype=dtype, name="input") @@ -935,7 +937,6 @@ class Conv2DTest(tf.test.TestCase): strides=[1, 1, 1, 1], padding="SAME") - # This is only a very simple test. More comprehensive tests live in # //learning/dist_belief/experimental/brain_compatibility/conv_nn_test.py # where we compare the numeric results of the depthwise conv op with the diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 8f19d80b0e..c1d27fb6a6 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -6,8 +6,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): native.new_http_archive( name = "eigen_archive", - url = "https://bitbucket.org/eigen/eigen/get/f3a13643ac1f.tar.gz", - sha256 = "a9266e60366cddb371a23d86b11a297eee86372a89ef4b38a3509012f9cc37ec", + url = "https://bitbucket.org/eigen/eigen/get/d02e6a705c30.tar.gz", + sha256 = "532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf", build_file = path_prefix + "eigen.BUILD", ) diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky index 7b196a8904..56059bcc61 100644 --- a/third_party/eigen3/Eigen/Cholesky +++ b/third_party/eigen3/Eigen/Cholesky @@ -1 +1 @@ -#include "eigen-eigen-f3a13643ac1f/Eigen/Cholesky" +#include "eigen-eigen-d02e6a705c30/Eigen/Cholesky" diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core index 97361e5183..c1d4a2e0f8 100644 --- a/third_party/eigen3/Eigen/Core +++ b/third_party/eigen3/Eigen/Core @@ -1 +1 @@ -#include "eigen-eigen-f3a13643ac1f/Eigen/Core" +#include "eigen-eigen-d02e6a705c30/Eigen/Core" diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues index a5f98ed870..0a0731ba19 100644 --- a/third_party/eigen3/Eigen/Eigenvalues +++ b/third_party/eigen3/Eigen/Eigenvalues @@ -1 +1 @@ -#include "eigen-eigen-f3a13643ac1f/Eigen/Eigenvalues" +#include "eigen-eigen-d02e6a705c30/Eigen/Eigenvalues" diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU index 5172aece6c..d6b39b8d23 100644 --- a/third_party/eigen3/Eigen/LU +++ b/third_party/eigen3/Eigen/LU @@ -1 +1 @@ -#include "eigen-eigen-f3a13643ac1f/Eigen/LU" +#include "eigen-eigen-d02e6a705c30/Eigen/LU" diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR index bd59f7adf2..a5406e93bc 100644 --- a/third_party/eigen3/Eigen/QR +++ b/third_party/eigen3/Eigen/QR @@ -1 +1 @@ -#include "eigen-eigen-f3a13643ac1f/Eigen/QR" +#include "eigen-eigen-d02e6a705c30/Eigen/QR" diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor index 8d363c3845..4f730236b7 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor @@ -1 +1 @@ -#include "eigen-eigen-f3a13643ac1f/unsupported/Eigen/CXX11/Tensor" +#include "eigen-eigen-d02e6a705c30/unsupported/Eigen/CXX11/Tensor" |