aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-20 17:13:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-20 17:24:09 -0800
commitb7a389d47127b631141f492108ae8f3d124d4a05 (patch)
tree257c4a6231b3d5b177260ba60148e2570c9c3e91 /tensorflow/core/kernels/conv_ops.cc
parentede3c12a11cd6858eef4de52b7697299743d4660 (diff)
Added support for TensorFlow backward convolutions to use libxsmm; weight
updates are not yet supported. There still might be some bugs in the libxsmm code for backward convolutions, so that mode is not recommended for use yet. Since this feature is experimental, "--define tensorflow_xsmm_backward=1" needs to be used on the build command line to enable it. Change: 142614258
Diffstat (limited to 'tensorflow/core/kernels/conv_ops.cc')
-rw-r--r--tensorflow/core/kernels/conv_ops.cc14
1 files changed, 6 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index f6e3b532aa..91cd1c4b9a 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -183,6 +183,8 @@ class LaunchXsmmConvOp<CPUDevice, float> {
int filter_cols, int pad_rows, int pad_cols, int out_rows,
int out_cols, int out_depth, int stride_rows, int stride_cols,
Tensor* output, TensorFormat data_format) {
+ auto num_threads =
+ ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
// See libxsmm_dnn.h for this struct definition.
libxsmm_dnn_conv_desc desc;
desc.N = batch;
@@ -198,7 +200,7 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.pad_w_in = pad_cols; // ignored by libxsmm for now.
desc.pad_h_out = 0;
desc.pad_w_out = 0;
- desc.threads = 0; // Unknown at this point, will be set later.
+ desc.threads = num_threads;
desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_RSCK;
@@ -207,17 +209,13 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
- if (!CanUseXsmmConv2D(desc, data_format)) {
- return false;
- }
-
auto input_ptr = input.template flat<float>().data();
auto filter_ptr = filter.template flat<float>().data();
auto output_ptr = output->template flat<float>().data();
- functor::XsmmConv2D<CPUDevice, float>()(ctx, desc, input_ptr, filter_ptr,
- output_ptr);
- return true;
+ bool success = functor::XsmmFwdConv2D<CPUDevice, float>()(
+ ctx, desc, input_ptr, filter_ptr, output_ptr);
+ return success;
}
};
#endif