diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 422 |
1 file changed, 422 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc new file mode 100644 index 0000000000..85198d89f5 --- /dev/null +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -0,0 +1,422 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/nn_ops.cc. + +#ifdef INTEL_MKL + +#include <algorithm> +#include <vector> + +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/work_sharder.h" + +#include "third_party/mkl/include/mkl_dnn.h" +#include "third_party/mkl/include/mkl_dnn_types.h" +#include "tensorflow/core/util/mkl_util.h" + +namespace tensorflow { + 
+typedef Eigen::ThreadPoolDevice CPUDevice; + +template <typename Device, class T> +class MklConv2DCustomBackpropFilterOp : public OpKernel { + public: + explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + int stride_n = GetTensorDim(strides_, data_format_, 'N'); + int stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + context, (stride_n == 1 && stride_c == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + MklConv2DGradFilterOpContext mkl_context; + const Tensor& input = MklGetInput(context, 0); + GetMklShape(context, 0, &(mkl_context.input_shape)); + bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); + + const Tensor& filter_sizes = MklGetInput(context, 1); + + const Tensor& out_backprop = MklGetInput(context, 2); + GetMklShape(context, 2, &(mkl_context.out_backprop_shape)); + bool out_backprop_in_mkl_format = + mkl_context.out_backprop_shape.IsMklTensor(); + + TensorShape input_shape, filter_shape, out_backprop_shape; + + OP_REQUIRES( + context, TensorShapeUtils::IsVector(filter_sizes.shape()), + errors::InvalidArgument( + "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, " + "not ", + filter_sizes.dims())); + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( + filter_sizes.vec<int32>(), &filter_shape)); + + ConvBackpropDimensions backprop_dims; + + // Generate shape for input if input is in MKL format. 
+ if (input_in_mkl_format) { + OP_REQUIRES(context, mkl_context.input_shape.GetDimension() == 4, + errors::InvalidArgument( + "Conv2DCustomBackpropFilter: input size must be 4-dim")); + + MklSizesToTFSizes(context, data_format_, mkl_context.input_shape, + &input_shape); + } else { + input_shape = input.shape(); + } + + // Generate shape for outback prop if input is in MKL format. + if (out_backprop_in_mkl_format) { + OP_REQUIRES( + context, mkl_context.out_backprop_shape.GetDimension() == 4, + errors::InvalidArgument( + "Conv2DCustomBackpropFilter: outbackprop size must be 4-dim")); + + MklSizesToTFSizes(context, data_format_, mkl_context.out_backprop_shape, + &out_backprop_shape); + } else { + out_backprop_shape = out_backprop.shape(); + } + + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensions( + "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, + input_shape, filter_shape, out_backprop_shape, strides_, + padding_, data_format_, &backprop_dims)); + + int64 pad_top, pad_bottom; + int64 pad_left, pad_right; + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + backprop_dims.spatial_dims[0].input_size, + backprop_dims.spatial_dims[0].filter_size, + backprop_dims.spatial_dims[0].stride, padding_, + &backprop_dims.spatial_dims[0].output_size, + &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + backprop_dims.spatial_dims[1].input_size, + backprop_dims.spatial_dims[1].filter_size, + backprop_dims.spatial_dims[1].stride, padding_, + &backprop_dims.spatial_dims[1].output_size, + &pad_left, &pad_right)); + + // Create MKL primitives for convolution filter grad + mkl_context.in_dims = input_in_mkl_format + ? mkl_context.input_shape.GetDimension() + : input.dims(); + mkl_context.out_dims = out_backprop_in_mkl_format + ? 
mkl_context.out_backprop_shape.GetDimension() + : out_backprop.dims(); + mkl_context.in_sizes[0] = + static_cast<size_t>(backprop_dims.spatial_dims[1].input_size); + mkl_context.in_sizes[1] = + static_cast<size_t>(backprop_dims.spatial_dims[0].input_size); + mkl_context.in_sizes[2] = static_cast<size_t>(backprop_dims.in_depth); + mkl_context.in_sizes[3] = static_cast<size_t>(backprop_dims.batch_size); + mkl_context.out_sizes[0] = + static_cast<size_t>(backprop_dims.spatial_dims[1].output_size); + mkl_context.out_sizes[1] = + static_cast<size_t>(backprop_dims.spatial_dims[0].output_size); + mkl_context.out_sizes[2] = static_cast<size_t>(backprop_dims.out_depth); + mkl_context.out_sizes[3] = static_cast<size_t>(backprop_dims.batch_size); + mkl_context.input_offsets[0] = static_cast<int>(-pad_left); + mkl_context.input_offsets[1] = static_cast<int>(-pad_top); + mkl_context.conv_strides[0] = + static_cast<size_t>(backprop_dims.spatial_dims[1].stride); + mkl_context.conv_strides[1] = + static_cast<size_t>(backprop_dims.spatial_dims[0].stride); + + GetStridesFromSizes(data_format_, mkl_context.in_strides, + mkl_context.in_sizes); + GetStridesFromSizes(data_format_, mkl_context.out_strides, + mkl_context.out_sizes); + + // MKL understands dimensions in 0, 1, 2, and 3 indices denotes + // filter cols, rows, input channels, and output depth/channels. + mkl_context.filter_dims = 4; + mkl_context.filter_sizes[0] = backprop_dims.spatial_dims[1].filter_size; + mkl_context.filter_sizes[1] = backprop_dims.spatial_dims[0].filter_size; + mkl_context.filter_sizes[2] = backprop_dims.in_depth; + mkl_context.filter_sizes[3] = backprop_dims.out_depth; + + // We want filter grad to be in TF format, so + // make the strides accordingly to reflect this fact. + // Note TF filter layout : (rows, cols, in_depth, out_depth), + // while row is the innermost dimension. 
+ mkl_context.filter_strides[0] = + backprop_dims.out_depth * backprop_dims.in_depth; + mkl_context.filter_strides[1] = backprop_dims.out_depth * + backprop_dims.in_depth * + backprop_dims.spatial_dims[1].filter_size; + mkl_context.filter_strides[2] = backprop_dims.out_depth; + mkl_context.filter_strides[3] = 1; + + mkl_context.conv_strides[0] = backprop_dims.spatial_dims[1].stride; + mkl_context.conv_strides[1] = backprop_dims.spatial_dims[0].stride; + + // Create convolution-grad-filter primitive + CHECK_EQ(dnnConvolutionCreateBackwardFilter_F32( + &mkl_context.prim_conv_bwdfilter, nullptr, + dnnAlgorithmConvolutionDirect, mkl_context.in_dims, + mkl_context.in_sizes, mkl_context.out_sizes, + mkl_context.filter_sizes, mkl_context.conv_strides, + mkl_context.input_offsets, dnnBorderZeros), + E_SUCCESS); + + // Create the layouts for entities in received context. + mkl_context.MklCreateInputLayouts(context); + + // Mkl needs the entities in its native format. + // So create temporary tensors along with buffers to + // convert the received entities. + Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor; + // This preparation sets (1) dnnResourceSrc (2) dnnResourceDiffDst + mkl_context.MklPrepareInputs(context, &mkl_tmp_input_buf_tensor, + &mkl_tmp_out_backprop_buf_tensor); + + // Final conv-grad-filter should be in TF layout. 
+ Tensor* grad_filter; + mkl_context.grad_filter_shape.SetMklTensor(false); + mkl_context.grad_filter_shape.SetTfLayout(mkl_context.filter_dims, + mkl_context.filter_sizes, + mkl_context.filter_strides); + AllocateOutputSetMklshape(context, 0, &grad_filter, filter_shape, + mkl_context.grad_filter_shape); + + // Need to set member variable for TF layout + mkl_context.lt_grad_filter = mkl_context.grad_filter_shape.GetTfLayout(); + + // MKL conv-grad-filter might produce grad in its internal layout + Tensor mkl_tmp_grad_filter_buf_tensor; + // This preparation sets conversion primitive if required + // and allocates temporary tensor and its buffer without doing conversions. + // Also sets (3) dnnResourceDiffFilter accordingly + mkl_context.MklPrepareGradFilter(context, grad_filter, + &mkl_tmp_grad_filter_buf_tensor); + + // After setting all the required dnnResources, ready for execution! + CHECK_EQ( + dnnExecute_F32(mkl_context.prim_conv_bwdfilter, mkl_context.conv_res), + E_SUCCESS); + + // Convert grad-filter to TF layout + if (mkl_context.convert_bwdfilter != nullptr) { + void* mkl_buf_convert_grad_filter = + const_cast<void*>(static_cast<const void*>( + mkl_tmp_grad_filter_buf_tensor.flat<T>().data())); + void* mkl_buf_grad_filter = const_cast<void*>( + static_cast<const void*>(grad_filter->flat<T>().data())); + CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_bwdfilter, + mkl_buf_convert_grad_filter, + mkl_buf_grad_filter), + E_SUCCESS); + } + + mkl_context.MklCleanup(); + } + + private: + typedef struct { + int in_dims; + size_t in_sizes[4]; + size_t in_strides[4]; + int out_dims; + size_t out_sizes[4]; + size_t out_strides[4]; + int filter_dims; + size_t filter_sizes[4]; + size_t filter_strides[4]; + int input_offsets[2]; + size_t conv_strides[2]; + MklShape input_shape, grad_filter_shape, out_backprop_shape; + dnnPrimitive_t prim_conv_bwdfilter, convert_bwdfilter; + dnnLayout_t lt_input, lt_grad_filter, lt_out_backprop; + void* 
conv_res[dnnResourceNumber]; + + void MklCleanup() { + // Cleanup member layouts and primitives except "lt_grad_filter_" + // which points to MklShape's TFLayout + bool input_in_mkl_format = input_shape.IsMklTensor(); + bool out_backprop_in_mkl_format = out_backprop_shape.IsMklTensor(); + if (!input_in_mkl_format) dnnLayoutDelete_F32(lt_input); + if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(lt_out_backprop); + if (convert_bwdfilter != nullptr) dnnDelete_F32(convert_bwdfilter); + dnnDelete_F32(prim_conv_bwdfilter); + } + + // Create MKL dnnLayout_t objects for tensors coming into the layer + void MklCreateInputLayouts(OpKernelContext* context) { + bool input_in_mkl_format = input_shape.IsMklTensor(); + if (input_in_mkl_format) { + lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout()); + } else { + CHECK_EQ(dnnLayoutCreate_F32(<_input, in_dims, in_sizes, in_strides), + E_SUCCESS); + } + + bool out_backprop_in_mkl_format = out_backprop_shape.IsMklTensor(); + if (out_backprop_in_mkl_format) { + lt_out_backprop = + static_cast<dnnLayout_t>(out_backprop_shape.GetCurLayout()); + } else { + CHECK_EQ(dnnLayoutCreate_F32(<_out_backprop, out_dims, out_sizes, + out_strides), + E_SUCCESS); + } + } + + // Compare incoming tensor layouts with MKL preferred layouts and convert + // data to the preferred layout if necessary + void MklPrepareInputs(OpKernelContext* context, + Tensor* mkl_tmp_input_buf_tensor, + Tensor* mkl_tmp_out_backprop_buf_tensor) { + bool mkl_convert_input, mkl_convert_out_backprop; + dnnPrimitive_t mkl_prim_convert_input, mkl_prim_convert_out_backprop; + dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop; + void *mkl_buf_convert_input, *mkl_buf_convert_out_backprop; + + mkl_prim_convert_input = nullptr; + mkl_prim_convert_out_backprop = nullptr; + mkl_lt_internal_input = nullptr; + mkl_lt_internal_out_backprop = nullptr; + mkl_buf_convert_input = nullptr; + mkl_buf_convert_out_backprop = nullptr; + + // Compare with internal 
layouts and convert if needed + const Tensor& input = MklGetInput(context, 0); + void* mkl_buf_input = + const_cast<void*>(static_cast<const void*>(input.flat<T>().data())); + CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( + &mkl_lt_internal_input, prim_conv_bwdfilter, dnnResourceSrc), + E_SUCCESS); + mkl_convert_input = + !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input); + if (mkl_convert_input) { + CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input, + mkl_lt_internal_input), + E_SUCCESS); + AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input, + &mkl_buf_convert_input); + CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input, + mkl_buf_convert_input), + E_SUCCESS); + dnnDelete_F32(mkl_prim_convert_input); + } + dnnLayoutDelete_F32(mkl_lt_internal_input); + + conv_res[dnnResourceSrc] = + (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input; + + const Tensor& out_backprop = MklGetInput(context, 2); + void* mkl_buf_out_backprop = const_cast<void*>( + static_cast<const void*>(out_backprop.flat<T>().data())); + CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop, + prim_conv_bwdfilter, + dnnResourceDiffDst), + E_SUCCESS); + mkl_convert_out_backprop = + !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop); + if (mkl_convert_out_backprop) { + CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop, + lt_out_backprop, + mkl_lt_internal_out_backprop), + E_SUCCESS); + AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor, + lt_out_backprop, &mkl_buf_convert_out_backprop); + CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop, + mkl_buf_out_backprop, + mkl_buf_convert_out_backprop), + E_SUCCESS); + dnnDelete_F32(mkl_prim_convert_out_backprop); + } + dnnLayoutDelete_F32(mkl_lt_internal_out_backprop); + + conv_res[dnnResourceDiffDst] = (mkl_convert_out_backprop) + ? 
mkl_buf_convert_out_backprop + : mkl_buf_out_backprop; + } + + void MklPrepareGradFilter(OpKernelContext* context, Tensor* grad_filter, + Tensor* mkl_tmp_grad_filter_buf_tensor) { + bool mkl_convert_grad_filter; + dnnLayout_t mkl_lt_internal_grad_filter = nullptr; + void* mkl_buf_convert_grad_filter = nullptr; + void* mkl_buf_grad_filter = const_cast<void*>( + static_cast<const void*>(grad_filter->flat<T>().data())); + CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_grad_filter, + prim_conv_bwdfilter, + dnnResourceDiffFilter), + E_SUCCESS); + mkl_convert_grad_filter = + !dnnLayoutCompare_F32(mkl_lt_internal_grad_filter, lt_grad_filter); + if (mkl_convert_grad_filter) { + CHECK_EQ(dnnConversionCreate_F32(&convert_bwdfilter, + mkl_lt_internal_grad_filter, + lt_grad_filter), + E_SUCCESS); + AllocTmpBuffer(context, mkl_tmp_grad_filter_buf_tensor, + mkl_lt_internal_grad_filter, + &mkl_buf_convert_grad_filter); + } + dnnLayoutDelete_F32(mkl_lt_internal_grad_filter); + + conv_res[dnnResourceDiffFilter] = (mkl_convert_grad_filter) + ? mkl_buf_convert_grad_filter + : mkl_buf_grad_filter; + } + } MklConv2DGradFilterOpContext; + + std::vector<int32> strides_; + Padding padding_; + TensorFormat data_format_; +}; + +#define REGISTER_MKL_FILTER_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T") \ + .Label(mkl_layer_registry::kMklLayerLabel), \ + MklConv2DCustomBackpropFilterOp<CPUDevice, T>); + +TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); +#undef REGISTER_MKL_FILTER_KERNELS +} // namespace tensorflow + +#endif // INTEL_MKL |