diff options
author | Dandelion Mané <dandelion@google.com> | 2017-12-15 17:12:41 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-15 17:16:29 -0800 |
commit | d55f532867a3670d66460c5ee3b774519542adc1 (patch) | |
tree | 7de4d85bcd61e93401459276b4d371ab0be23c1f /tensorflow/core/kernels/mkl_conv_ops.h | |
parent | 32d5048ae96116202f2aa0fa739ef37514ee8a54 (diff) |
Merge changes from github.
PiperOrigin-RevId: 179258973
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.h | 269 |
1 files changed, 257 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index f0cb37f8a4..47a9b4bfc7 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -41,6 +41,12 @@ limitations under the License. #ifdef INTEL_MKL_DNN #include "mkldnn.hpp" + +using mkldnn::stream; +using mkldnn::prop_kind; + +using mkldnn::convolution_forward; +using mkldnn::convolution_direct; #endif namespace tensorflow { @@ -108,7 +114,13 @@ class MklDnnConvUtil { #undef CHECK_BOUNDS // MKL-DNN always requires input in NCHW format. - *input_dims = {input_batch, input_depth, input_rows, input_cols}; + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; + mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; + mkldnn_sizes[MklDnnDims::Dim_W] = input_cols; + + *input_dims = mkldnn_sizes; } // Calculate Convolution filter size in MKL-DNN order. MKL-DNN @@ -156,7 +168,13 @@ class MklDnnConvUtil { // MKL-DNN always needs filter in OIHW format. // OIHW = (out_depth, in_depth, rows, cols) - *filter_dims = {out_depth, in_depth, filter_rows, filter_cols}; + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_O] = out_depth; + mkldnn_sizes[MklDnnDims::Dim_I] = in_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; + mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols; + + *filter_dims = mkldnn_sizes; } // Calculate Convolution filter size in MKL-DNN order. 
MKL-DNN @@ -167,9 +185,9 @@ class MklDnnConvUtil { GetFilterSizeInMklOrder(size_t src_index, size_t filter_index, memory::dims *filter_dims) { CHECK_NOTNULL(filter_dims); - const Tensor& input = MklGetInput(context_, src_index); - const Tensor& filter = MklGetInput(context_, filter_index); - GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims); + GetFilterSizeInMklOrder(GetTfShape(context_, src_index), + GetTfShape(context_, filter_index), + filter_dims); } // Calculate Bias size for 2D Convolution. Function does not return @@ -238,8 +256,12 @@ class MklDnnConvUtil { *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); // MKL-DNN always needs output in NCHW format. - *output_dims_mkl_order = {out_batch, out_depth, static_cast<int>(out_rows), - static_cast<int>(out_cols)}; + std::vector<int> mkldnn_sizes(4, -1); + mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; + mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; + mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows); + mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols); + *output_dims_mkl_order = mkldnn_sizes; // Now handle padding. MKL-DNN uses asymmetric padding. 
*pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; @@ -261,14 +283,14 @@ class MklDnnConvUtil { CHECK_NOTNULL(pad_l); CHECK_NOTNULL(pad_r); - const Tensor& input = MklGetInput(context_, src_index); - const Tensor& filter = MklGetInput(context_, filter_index); + auto input_tf_shape = GetTfShape(context_, src_index); + auto filter_tf_shape = GetTfShape(context_, filter_index); - OP_REQUIRES(context_, input.dims() == 4, + OP_REQUIRES(context_, input_tf_shape.dims() == 4, errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); + input_tf_shape.DebugString())); - GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), + GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides, output_dims_tf_order, output_dims_mkl_order, pad_l, pad_r); } @@ -309,8 +331,231 @@ class MklDnnConvUtil { } }; +///////////////////////////////////////////////////////////////////// +/// Common class that implements Conv2DBackpropFilter and Input +///////////////////////////////////////////////////////////////////// + +template <typename Device, class T> +class MklConv2DBackpropCommonOp : public OpKernel { + public: + ~MklConv2DBackpropCommonOp() {} + explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context) + : OpKernel(context) { + string data_format_str; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + int stride_n = GetTensorDim(strides_, data_format_, 'N'); + int stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + context, (stride_n == 1 && stride_c == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* 
context) override { + try { + auto cpu_engine = engine(engine::cpu, 0); + + // Prepare common tensors for Conv2DBackpropInput and + // Conv2DBackpropFilter. + MklDnnData<T> input(&cpu_engine); + MklDnnData<T> filter(&cpu_engine); + MklDnnData<T> outbackprop(&cpu_engine); + MklDnnData<T> output(&cpu_engine); + + // Input tensors + const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; + const Tensor& input_tensor = MklGetInput(context, kInputIdx); + const Tensor& filter_tensor = MklGetInput(context, kFilterIdx); + const Tensor& outbprop_tensor = MklGetInput(context, kOutbpropIdx); + + MklDnnShape input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape; + GetMklShape(context, kInputIdx, &input_mkl_shape); + GetMklShape(context, kFilterIdx, &filter_mkl_shape); + GetMklShape(context, kOutbpropIdx, &outbprop_mkl_shape); + // Allow operator-specific sanity checking of shapes. + ValidateMklShapes(input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape); + + // Allow operator-specific generation of shapes. + // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a + // tensor containing shape of filter. So filter.shape() is not + // a correct way to get filter shape. These operator-specific calls + // allow this class to handle this case. + TensorShape input_tf_shape = MakeInputTfShape(context, input_tensor); + TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor); + TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx); + + // By default, all dims are in MKL order. Only dims in TF order + // are those with prefix tf_order. + memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims; + memory::dims padding_l, padding_r, strides, fwd_output_dims; + memory::dims fwd_output_dims_tf_order; + + // Get forward convolution parameters. 
+ MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); + conv_utl.GetConvFwdSizesInMklOrder(input_tf_shape, filter_tf_shape, + &fwd_input_dims, &fwd_filter_dims, + &strides, + &fwd_output_dims_tf_order, + &fwd_output_dims, + &padding_l, &padding_r); + if (!context->status().ok()) return; + + // Create Convolution forward descriptor since Convolution backward + // API needs it. For that, we first need to create input, filter + // and output memory descriptors. + auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_); + // If input is in MKL layout, then simply grab input layout; otherwise, + // construct input TF layout. For TF layout, although input shape + // required is in MKL-DNN order, the layout is Tensorflow's layout + // (NHWC or NCHW depending on data format). + auto fwd_input_md = input_mkl_shape.IsMklTensor() ? + input_mkl_shape.GetMklLayout() : + memory::desc(fwd_input_dims, MklDnnType<T>(), tf_fmt); + // If filter is in MKL layout, then simply grab filter layout; otherwise + // construct filter in TF layout. For TF layout, filter is in HWIO format. + auto fwd_filter_md = filter_mkl_shape.IsMklTensor() ? + filter_mkl_shape.GetMklLayout() : + memory::desc(fwd_filter_dims, MklDnnType<T>(), + memory::format::hwio); + // Tensorflow Output of Conv2D is in data_format order. + auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), tf_fmt); + auto fwd_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, fwd_input_md, fwd_filter_md, fwd_out_md, + strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); + auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); + + // Create memory for user data. Describe how the inputs and outputs of + // Convolution look like. Also specify buffers containing actual input + // and output data. 
+ + // Since this is a common class for both Conv2DBackpropFilter and + // Conv2DBackpropInput, we skip SetUsrMem call for input tensor (for + // Conv2DBackpropInput) and for filter tensor (for + // conv2DBackpropFilter) depending on which tensor is int32 type. + size_t input_with_sizes = GetInputTensorIndexWithSizes(); + if (input_with_sizes != kInputIdx) { + // Shape of Conv2DBackpropFilter's input is same as Conv2D input. + input.SetUsrMem(fwd_input_md, &input_tensor); + } else if (input_with_sizes != kFilterIdx) { + // Shape of Conv2DBackpropInput's filter is same as Conv2D filter. + filter.SetUsrMem(fwd_filter_md, &filter_tensor); + } + + conv_utl.GetInputSizeInMklOrder(outbprop_tf_shape, &outbprop_dims); + if (!context->status().ok()) return; + if (outbprop_mkl_shape.IsMklTensor()) { + // If outbackprop is in Mkl layout, then simply grab it. + auto outbprop_md = outbprop_mkl_shape.GetMklLayout(); + outbackprop.SetUsrMem(outbprop_md, &outbprop_tensor); + } else { + // If outbackprop is in TensorFlow layout, then we need to create memory + // descriptor for it. Outbackprop shape is data format order. + outbackprop.SetUsrMem(outbprop_dims, tf_fmt, &outbprop_tensor); + } + + // Operator specific call to get output shape and data_format. + auto bwd_output_dims = GetOutputDims(fwd_input_dims, fwd_filter_dims); + auto bwd_output_format = GetOutputFormat(tf_fmt); + output.SetUsrMem(bwd_output_dims, bwd_output_format); + + // Create memory descriptors for convolution data w/ no specified format. + input.SetOpMemDesc(fwd_input_dims, memory::format::any); + filter.SetOpMemDesc(fwd_filter_dims, memory::format::any); + outbackprop.SetOpMemDesc(outbprop_dims, memory::format::any); + output.SetOpMemDesc(bwd_output_dims, memory::format::any); + + // Operator-specific call to create and execute primitive. 
+ Tensor* output_tensor = nullptr; + CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter, + &outbackprop, &output, &output_tensor, + strides, padding_l, padding_r, + TFPaddingToMklDnnPadding(padding_), + bwd_output_dims, bwd_output_format); + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", + error_msg)); + } + } + + /// Pure virtual function to allow operator to check for validity of input + /// shapes. Function asserts that input shapes are valid. + virtual void ValidateMklShapes(const MklDnnShape& input_mkl_shape, + const MklDnnShape& filter_mkl_shape, + const MklDnnShape& outbprop_mkl_shape) = 0; + + /// Operator-specific function that returns index of input that is + /// representing input sizes. For Conv2DBackpropFilter it returns 1 since + /// filter for this operator is filter shape. For Conv2DBackpropInput it + /// returns 0 (for input). + virtual size_t GetInputTensorIndexWithSizes() = 0; + + /// Get TensorFlow shape of input tensor. + virtual TensorShape MakeInputTfShape(OpKernelContext* context, + const Tensor& input_tensor) = 0; + + /// Get TensorFlow shape of filter tensor. + virtual TensorShape MakeFilterTfShape(OpKernelContext* context, + const Tensor& filter_tensor) = 0; + + /// Get shape of output in MKL-DNN order. Computes shape of output from + /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims). + virtual + const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, + const memory::dims& fwd_filter_dims) = 0; + + /// Get data_format of output in MKL-DNN order. If output data format is + /// same as input data format, then it simply returns value of data_format + /// parameter as it is. 
+ virtual memory::format GetOutputFormat(const memory::format data_format) = 0; + + /// Create and execute the primitive storing output in the output_tensor. + virtual void CreatePrimitive(OpKernelContext* context, + const engine& cpu_engine, + const convolution_forward::primitive_desc& conv_fwd_pd, + MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop, + MklDnnData<T>* output, Tensor** output_tensor, const memory::dims& strides, + const memory::dims& padding_l, const memory::dims& padding_r, + padding_kind padding, const memory::dims& bwd_output_dims, + memory::format bwd_output_format) = 0; + + // Get the data_format {NCHW, NHWC} + TensorFormat GetTFDataFormat () { return data_format_; } + + private: + std::vector<int32> strides_; + Padding padding_; + TensorFormat data_format_; +}; #endif // INTEL_MKL_DNN +///////////////////////////////////////////////////////////////////// +/// Dummy Mkl op that is just used for operators that are intermediate +/// output of node fusion in the graph +///////////////////////////////////////////////////////////////////// + +template <typename Device, typename T> +class MklDummyOp : public OpKernel { + public: + ~MklDummyOp() {} + + explicit MklDummyOp(OpKernelConstruction* context) : + OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + TF_CHECK_OK(errors::Unimplemented("This is a dummy op." + "It should not have been invoked.")); + } +}; + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ |