diff options
author | 2017-12-15 17:32:50 -0800 | |
---|---|---|
committer | 2017-12-15 17:39:26 -0800 | |
commit | 9648f8040a559f6cf9bbe0501ba96f2b2c2864b1 (patch) | |
tree | 57dc6e959e0a534622eaf392ee43b7691378b10e /tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | |
parent | 5b5445b9a7aa2664a90c4fc946ecf268c971425b (diff) |
Automated g4 rollback of changelist 179258973
PiperOrigin-RevId: 179260538
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_input_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 244 |
1 files changed, 143 insertions, 101 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index df51df9638..4a47d0463e 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -49,6 +49,9 @@ limitations under the License. using mkldnn::stream; using mkldnn::prop_kind; + +using mkldnn::convolution_forward; +using mkldnn::convolution_direct; using mkldnn::convolution_backward_data; #endif @@ -359,117 +362,143 @@ class MklConv2DCustomBackpropInputOp : public OpKernel { #else template <typename Device, class T> -class MklConv2DCustomBackpropInputOp : - public MklConv2DBackpropCommonOp<Device, T> { +class MklConv2DCustomBackpropInputOp : public OpKernel { public: - explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context) - : MklConv2DBackpropCommonOp<Device, T>(context) { } ~MklConv2DCustomBackpropInputOp() {} + explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context) { + string data_format_str; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + int stride_n = GetTensorDim(strides_, data_format_, 'N'); + int stride_c = GetTensorDim(strides_, data_format_, 'C'); + OP_REQUIRES( + context, (stride_n == 1 && stride_c == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); - private: - void ValidateMklShapes(const MklDnnShape& input_mkl_shape, - const MklDnnShape& filter_mkl_shape, - const MklDnnShape& obp_mkl_shape) { - // Tensor that feeds to 'Input' slot of BackpropInput is always just a shape - // of the Tensor and never an actual tensor. So it will never be in MKL - // layout. - CHECK(!input_mkl_shape.IsMklTensor()) - << "Conv2DBackpropInput: input should not be in MKL Layout"; - } - - size_t GetInputTensorIndexWithSizes() { return 0; /* input index */ } - - TensorShape MakeInputTfShape(OpKernelContext* context, - const Tensor& input_tensor) { - TensorShape input_tf_shape; - CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true); - CHECK_EQ(TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), - &input_tf_shape).ok(), true); - return input_tf_shape; + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } - TensorShape MakeFilterTfShape(OpKernelContext* context, - const Tensor& filter_tensor) { - size_t filter_idx = 1; - return GetTfShape(context, filter_idx); - } + void Compute(OpKernelContext* context) override { + try { + auto cpu_engine = engine(engine::cpu, 0); - const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, - const memory::dims& fwd_filter_dims) { - // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'. - return fwd_input_dims; - } + MklDnnData<T> filter(&cpu_engine); + MklDnnData<T> outbackprop(&cpu_engine); + MklDnnData<T> output(&cpu_engine); - memory::format GetOutputFormat(const memory::format data_format) { - // Output layout is Tensorflow's layout in data format order. - return data_format; - } + // Input tensors + const Tensor& input_tensor = MklGetInput(context, 0); + const Tensor& filter_tensor = MklGetInput(context, 1); + const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop - void CreatePrimitive(OpKernelContext* context, - const engine& cpu_engine, - const convolution_forward::primitive_desc& conv_fwd_pd, - MklDnnData<T>* input, MklDnnData<T>* filter, - MklDnnData<T>* outbackprop, MklDnnData<T>* output, - Tensor** output_tensor, - const memory::dims& strides, - const memory::dims& padding_l, - const memory::dims& padding_r, - padding_kind padding, - const memory::dims& bwd_output_dims, - memory::format bwd_output_format) { - CHECK_NOTNULL(context); - CHECK_NOTNULL(input); - CHECK_NOTNULL(filter); - CHECK_NOTNULL(outbackprop); - CHECK_NOTNULL(output); - CHECK_NOTNULL(output_tensor); - - // Create convolution backward data primitive. - auto bwd_desc = convolution_backward_data::desc(convolution_direct, - output->GetOpMemDesc(), filter->GetOpMemDesc(), - outbackprop->GetOpMemDesc(), strides, padding_l, - padding_r, padding); - - auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc, - cpu_engine, - conv_fwd_pd); - - - // Allocate output tensor in TensorFlow and MKL layout. - AllocateOutputTensor(context, bwd_pd, bwd_output_dims, - bwd_output_format, output_tensor); - CHECK_NOTNULL(*output_tensor); - // Set buffer handle using allocated output tensor. - output->SetUsrMemDataHandle(*output_tensor); - - PrepareAndExecutePrimitive(bwd_pd, filter, outbackprop, output); + // Generate input shape. + TensorShape input_shape; + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()), + errors::InvalidArgument( + "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", + input_tensor.dims())); + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( + input_tensor.vec<int32>(), &input_shape)); + TensorShape filter_shape = filter_tensor.shape(); + TensorShape obp_shape = obp_tensor.shape(); + + // By default, all dims are in MKL order. Only dims in TF order + // are those with prefix tf_order. + memory::dims obp_dims, fwd_input_dims, fwd_filter_dims; + memory::dims padding_l, padding_r, strides, fwd_output_dims; + memory::dims fwd_output_dims_tf_order; + + // Get forward convolution parameters. + MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_); + conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape, + &fwd_input_dims, &fwd_filter_dims, + &strides, + &fwd_output_dims_tf_order, + &fwd_output_dims, + &padding_l, &padding_r); + if (!context->status().ok()) return; + + // Create Convolution forward descriptor since Convolution backward + // API needs it. For that, we first need to create input, filter + // and output memory descriptors. + auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_); + auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(), + mkl_data_format); + auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(), + memory::format::hwio); + auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), + mkl_data_format); + auto fwd_desc = convolution_forward::desc(prop_kind::forward, + convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md, + strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_)); + auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine); + + // Allocate output tensor and shape + // TODO(nhasabni): Update this when support for MKL layout is added. + // Shape of output of Conv2DBackpropInput is same as 'input' of Conv2D. + TensorShape tf_output_shape(input_shape); + MklShape mkl_output_mkl_shape; + mkl_output_mkl_shape.SetMklTensor(false); + Tensor* output_tensor = nullptr; + AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape, + mkl_output_mkl_shape); + + // Create memory for user data. + // Describe how the inputs and outputs of Convolution look like. Also + // specify buffers containing actual input and output data. + // Although input shape required is in MKL-DNN order, the layout is + // Tensorflow's layout (NHWC or NCHW depending on data format). + // Although filter shape (filter_dims) required is in MKL-DNN order, + // the layout is Tensorflow's layout (HWIO). + // Shape of Conv2DBackpropInput's filter is same as that of Conv2D filter. + filter.SetUsrMem(fwd_filter_dims, memory::format::hwio, &filter_tensor); + // Outbackprop shape is NHWC or NCHW depending on data format. Since + // GetInputSizeInMklOrder function returns size in that order we just use + // use that function directly. + conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims); + if (!context->status().ok()) return; + outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor); + // Although output shape required is in MKL-DNN order, + // layout is Tensorflow's layout (NHWC or NCHW depending on data format). + // Shape of output of Conv2DBackpropInput is same as shape of 'input' + // of Conv2D. + memory::dims bwd_output_dims = fwd_input_dims; + output.SetUsrMem(bwd_output_dims, mkl_data_format, output_tensor); + + // Create memory descriptors for convolution data w/ no specified format. + filter.SetOpMemDesc(fwd_filter_dims, memory::format::any); + outbackprop.SetOpMemDesc(obp_dims, memory::format::any); + output.SetOpMemDesc(bwd_output_dims, memory::format::any); + + // Create convolution backward data primitive. + auto bwd_desc = convolution_backward_data::desc(convolution_direct, + output.GetOpMemDesc(), filter.GetOpMemDesc(), + outbackprop.GetOpMemDesc(), strides, padding_l, + padding_r, TFPaddingToMklDnnPadding(padding_)); + + auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc, + cpu_engine, + fwd_pd); + + PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output); + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + + ", in file " + string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:", + error_msg)); + } } - // Allocate output tensor. - void AllocateOutputTensor(OpKernelContext* context, - const convolution_backward_data::primitive_desc& conv_pd, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); - - // Output primitive descriptor for backward data is diff_src. - auto dst_pd = conv_pd.diff_src_primitive_desc(); - - // Allocate shape of Mkl tensor. - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_pd); - output_mkl_shape.SetElemType(MklDnnType<T>()); - output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, output_tf_format); - - // Allocate shape of TF tensor. - TensorShape output_tf_shape; - output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T)); - - AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, - output_mkl_shape); - } + private: + std::vector<int32> strides_; + Padding padding_; + TensorFormat data_format_; // Prepare and execute net - checks for input and output reorders. void PrepareAndExecutePrimitive( @@ -482,9 +511,22 @@ class MklConv2DCustomBackpropInputOp : filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net); obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net); + // Memory for output of convolution. Since we may need reorder on the + // output side, we will prepare reorder primitive in case output + // reorder to user memory is required. + bool output_reorder_required = output->PrepareReorderToUserMemIfReq( + conv_pd.diff_src_primitive_desc()); + net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem())); + // Insert reorder primitive in the net for output reorder if reorder is + // required. + if (output_reorder_required) { + output->InsertReorderToUserMem(&net); + } + + // Handle output reorder stream(stream::kind::eager).submit(net).wait(); } }; |