about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
diff options
context:
space:
mode:
authorGravatar Dandelion Mané <dandelion@google.com>2017-12-15 18:15:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 18:19:09 -0800
commit90e42f3ac8c43474633136af4242dca04b6a1e09 (patch)
tree64dbb44252c89c847bee86db07cea5aa94072e7c /tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
parent713d45278491d792c525344de6038a61ebcb2136 (diff)
Automated g4 rollback of changelist 179260538
PiperOrigin-RevId: 179263865
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc')
-rw-r--r-- tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc 317
1 file changed, 180 insertions(+), 137 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index f291281108..793fa24d99 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -47,11 +47,8 @@ limitations under the License.
using mkldnn::stream;
using mkldnn::prop_kind;
-
-using mkldnn::convolution_forward;
using mkldnn::convolution_backward_weights;
-using mkldnn::convolution_direct;
-
+using mkldnn::memory;
#endif
namespace tensorflow {
@@ -426,183 +423,229 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
TensorFormat data_format_;
};
+#define REGISTER_MKL_FILTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DCustomBackpropFilterOp<CPUDevice, T>);
+TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
+#undef REGISTER_MKL_FILTER_KERNELS
+
#else
-template <typename Device, class T>
-class MklConv2DCustomBackpropFilterOp : public OpKernel {
+template <typename Device, class T, bool biasEnabled>
+class MklConv2DCustomBackpropFilterOp :
+ public MklConv2DBackpropCommonOp<Device, T> {
public:
explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
- : OpKernel(context) {
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ : MklConv2DBackpropCommonOp<Device, T>(context) { }
+ ~MklConv2DCustomBackpropFilterOp() {}
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- int stride_n = GetTensorDim(strides_, data_format_, 'N');
- int stride_c = GetTensorDim(strides_, data_format_, 'C');
- OP_REQUIRES(
- context, (stride_n == 1 && stride_c == 1),
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ private:
+ void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
+ const MklDnnShape& filter_mkl_shape,
+ const MklDnnShape& obp_mkl_shape) {
+ CHECK(!filter_mkl_shape.IsMklTensor())
+ << "Conv2DBackpropFilter: filter should not be in MKL Layout";
}
- void Compute(OpKernelContext* context) override {
- try {
- auto cpu_engine = engine(engine::cpu, 0);
+ size_t GetInputTensorIndexWithSizes() { return 1; /* filter index */ }
- MklDnnData<T> input(&cpu_engine);
- MklDnnData<T> outbackprop(&cpu_engine);
- MklDnnData<T> output(&cpu_engine);
+ TensorShape MakeInputTfShape(OpKernelContext* context,
+ const Tensor& input_tensor) {
+ size_t input_idx = 0;
+ return GetTfShape(context, input_idx);
+ }
- // Input tensors
- const Tensor& input_tensor = MklGetInput(context, 0);
- const Tensor& filter_tensor = MklGetInput(context, 1);
- const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop
+ TensorShape MakeFilterTfShape(OpKernelContext* context,
+ const Tensor& filter_tensor) {
+ TensorShape filter_tf_shape;
+ CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true);
+ CHECK_EQ(TensorShapeUtils::MakeShape(
+ filter_tensor.vec<int32>(), &filter_tf_shape).ok(), true);
+ return filter_tf_shape;
+ }
- // Generate input shapes.
- TensorShape filter_shape;
- OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
- errors::InvalidArgument(
- "Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
- filter_tensor.dims()));
- OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
- filter_tensor.vec<int32>(), &filter_shape));
- TensorShape input_shape = input_tensor.shape();
- TensorShape obp_shape = obp_tensor.shape();
-
- // By default, all dims are in MKL order. Only dims in TF order
- // are those with prefix tf_order.
- memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
- memory::dims padding_l, padding_r, strides, fwd_output_dims;
- memory::dims fwd_output_dims_tf_order;
-
- // Get forward convolution parameters.
- MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
- conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
- &fwd_input_dims, &fwd_filter_dims,
- &strides,
- &fwd_output_dims_tf_order,
- &fwd_output_dims,
- &padding_l, &padding_r);
- if (!context->status().ok()) return;
-
- // Create Convolution forward descriptor since Convolution backward
- // API needs it. For that, we first need to create input, filter
- // and output memory descriptors.
- auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
- auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
- mkl_data_format);
- auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
- auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
- mkl_data_format);
- auto fwd_desc = convolution_forward::desc(prop_kind::forward,
- convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
- strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
- auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
-
- // Allocate output tensor and shape
- // TODO(nhasabni): Update this when support for MKL layout is added.
- // Shape of output of Conv2DBackpropInput is same as 'input' of Conv2D.
- TensorShape tf_output_shape(filter_shape);
- MklShape mkl_output_mkl_shape;
- mkl_output_mkl_shape.SetMklTensor(false);
- Tensor* output_tensor = nullptr;
- AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
- mkl_output_mkl_shape);
-
- // Create memory for user data.
- // Describe how the inputs and outputs of Convolution look like. Also
- // specify buffers containing actual input and output data.
- // Although input shape required is in MKL-DNN order, the layout is
- // Tensorflow's layout (NHWC or NCHW depending on data format).
- input.SetUsrMem(fwd_input_dims, mkl_data_format, &input_tensor);
- // Outbackprop shape is NHWC or NCHW depending on data format. Since
- // GetInputSizeInMklOrder function returns size in that order we just use
- // use that function directly.
- conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
- if (!context->status().ok()) return;
- outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
- // Although output shape required is in MKL-DNN order,
- // layout is Tensorflow's filter layout (HWIO)
- // Shape of output of Conv2DBackpropInput is same as shape of filter.
- memory::dims bwd_output_dims = fwd_filter_dims;
- output.SetUsrMem(bwd_output_dims, memory::format::hwio, output_tensor);
-
- // Create memory descriptors for convolution data w/ no specified format.
- input.SetOpMemDesc(fwd_input_dims, memory::format::any);
- outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
- output.SetOpMemDesc(bwd_output_dims, memory::format::any);
-
- // Create convolution backward weights primitive.
- auto bwd_desc = convolution_backward_weights::desc(convolution_direct,
- input.GetOpMemDesc(), output.GetOpMemDesc(),
- outbackprop.GetOpMemDesc(), strides, padding_l,
- padding_r, TFPaddingToMklDnnPadding(padding_));
-
- auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
- cpu_engine,
- fwd_pd);
-
- PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output);
- } catch (mkldnn::error &e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) +
- ", in file " + string(__FILE__) + ":" +
- std::to_string(__LINE__);
- OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:",
- error_msg));
+ const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
+ const memory::dims& fwd_filter_dims) {
+ // Shape of output of Conv2DBackpropFilter is same as shape of filter.
+ return fwd_filter_dims;
+ }
+
+ memory::format GetOutputFormat(const memory::format data_format) {
+ // Output layout is Tensorflow's filter layout (HWIO).
+ return memory::format::hwio;
+ }
+
+ void CreatePrimitive(OpKernelContext* context,
+ const engine& cpu_engine,
+ const convolution_forward::primitive_desc& conv_fwd_pd,
+ MklDnnData<T>* input, MklDnnData<T>* filter,
+ MklDnnData<T>* outbackprop, MklDnnData<T>* output,
+ Tensor** output_tensor,
+ const memory::dims& strides,
+ const memory::dims& padding_l,
+ const memory::dims& padding_r,
+ padding_kind padding,
+ const memory::dims& bwd_output_dims,
+ memory::format bwd_output_format) {
+ CHECK_NOTNULL(context);
+ CHECK_NOTNULL(input);
+ CHECK_NOTNULL(filter);
+ CHECK_NOTNULL(outbackprop);
+ CHECK_NOTNULL(output);
+ CHECK_NOTNULL(output_tensor);
+
+ MklDnnData<T>* bias_grad = nullptr;
+ int depth = 0;
+ if (biasEnabled) {
+ // Data structure for bias_grad
+ bias_grad = new MklDnnData<T> (&cpu_engine);
+ TensorShape obp_tf_shape = GetTfShape(context, 2);
+ depth = (MklConv2DBackpropCommonOp<Device, T>::GetTFDataFormat()
+ == FORMAT_NCHW) ?
+ obp_tf_shape.dim_size(1) : obp_tf_shape.dim_size(3);
+ memory::dims bias_grad_dims = {depth};
+ bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x);
+ }
+
+ // Create convolution backward weights primitive.
+ auto bwd_desc = (biasEnabled && (bias_grad != nullptr))?
+ convolution_backward_weights::desc(convolution_direct,
+ input->GetOpMemDesc(), output->GetOpMemDesc(),
+ bias_grad->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(), strides, padding_l,
+ padding_r, padding) :
+ convolution_backward_weights::desc(convolution_direct,
+ input->GetOpMemDesc(), output->GetOpMemDesc(),
+ outbackprop->GetOpMemDesc(), strides, padding_l,
+ padding_r, padding);
+
+ auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
+ cpu_engine,
+ conv_fwd_pd);
+
+ // Allocate output tensor.
+ AllocateOutputTensor(context, bwd_pd, bwd_output_dims,
+ bwd_output_format, output_tensor);
+
+ CHECK_NOTNULL(*output_tensor);
+ // Set buffer handle using allocated output tensor.
+ output->SetUsrMemDataHandle(*output_tensor);
+
+ if (biasEnabled && (bias_grad != nullptr)) {
+ // Allocate bias_grad tensor
+ TensorShape bias_grad_shape({depth});
+ Tensor* bias_grad_tensor = nullptr;
+ AllocateBiasGradTensor(context, bias_grad_shape, &bias_grad_tensor);
+ memory::dims bias_grad_dims = {depth};
+ // Since Bias is 1D, we use format::x from MKLDNN to represent it.
+ auto bias_grad_md = memory::desc({bias_grad_dims}, MklDnnType<T>(),
+ memory::format::x);
+ bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor);
+ bias_grad->SetUsrMemDataHandle(bias_grad_tensor);
+ }
+
+ if (biasEnabled && (bias_grad != nullptr)) {
+ PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output, bias_grad);
+ } else {
+ PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output);
}
}
- private:
- std::vector<int32> strides_;
- Padding padding_;
- TensorFormat data_format_;
+ // Allocate output tensor.
+ void AllocateOutputTensor(OpKernelContext* context,
+ const convolution_backward_weights::primitive_desc& conv_pd,
+ const memory::dims& output_dims_mkl_order,
+ memory::format output_tf_format, Tensor** output_tensor) {
+ CHECK_NOTNULL(output_tensor);
+
+ // For BackpropFilter, we convert the output tensor back in Tensorflow
+ // layout. Because typically, BackpropFilter is the last operator in the
+ // graph that emit filter gradient that is provided to ApplyGradient
+ // method to update the filter. But it may be possible to eliminate this
+ // by forwarding filter in MKL layout if we support ApplyGradient method
+ // for MKL layout propagation.
+ MklDnnShape output_mkl_shape;
+ output_mkl_shape.SetMklTensor(false);
+ // output_dims_mkl_order is in OIHW format.
+ // Allocate shape of TF tensor in HWIO format.
+ TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H],
+ output_dims_mkl_order[MklDnnDims::Dim_W],
+ output_dims_mkl_order[MklDnnDims::Dim_I],
+ output_dims_mkl_order[MklDnnDims::Dim_O]});
+ AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
+ output_mkl_shape);
+ }
+
+ // Allocate tensor for bias grad
+ void AllocateBiasGradTensor(OpKernelContext* context,
+ const TensorShape& bias_grad_shape,
+ Tensor** bias_grad_tensor) {
+ CHECK_NOTNULL(bias_grad_tensor);
+
+ MklDnnShape bias_grad_mkl_shape;
+ bias_grad_mkl_shape.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape,
+ bias_grad_mkl_shape);
+ }
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecutePrimitive(
const convolution_backward_weights::primitive_desc& conv_pd,
MklDnnData<T>* input, MklDnnData<T>* obp,
- MklDnnData<T>* output) {
+ MklDnnData<T>* output, MklDnnData<T>* bias_grad = nullptr) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
- // Memory for output of convolution. Since we may need reorder on the
- // output side, we will prepare reorder primitive in case output
- // reorder to user memory is required.
+ // For BackpropFilter, we convert the output tensor back in Tensorflow
+ // layout.
bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
conv_pd.diff_weights_primitive_desc());
- net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(),
- obp->GetOpMem(), output->GetOpMem()));
+ if (biasEnabled && (bias_grad != nullptr)) {
+ net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(),
+ obp->GetOpMem(), output->GetOpMem(),
+ bias_grad->GetOpMem()));
+ } else {
+ net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(),
+ obp->GetOpMem(), output->GetOpMem()));
+ }
- // Insert reorder primitive in the net for output reorder if reorder is
- // required.
if (output_reorder_required) {
output->InsertReorderToUserMem(&net);
}
- // Handle output reorder
stream(stream::kind::eager).submit(net).wait();
}
};
-#endif
#define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T>);
+ MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>);\
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>); \
+ REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDummyOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
+
+#endif // INTEL_MKL_DNN
+
} // namespace tensorflow
#endif // INTEL_MKL