diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.h | 140 |
1 files changed, 74 insertions, 66 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index e29af19ca9..f0cb37f8a4 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ #define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ -#include <limits> #include <vector> +#include <limits> #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -26,8 +26,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/bounds_check.h" -#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -49,15 +49,15 @@ namespace tensorflow { class MklDnnConvUtil { protected: - OpKernelContext *context_; // We don't own this. + OpKernelContext* context_; // We don't own this. std::vector<int32> strides_; Padding padding_; TensorFormat data_format_; public: - MklDnnConvUtil(OpKernelContext *context, const std::vector<int32> &strides, - Padding pad, TensorFormat fm) - : context_(context), strides_(strides), padding_(pad), data_format_(fm) {} + MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides, + Padding pad, TensorFormat fm) : context_(context), + strides_(strides), padding_(pad), data_format_(fm) {} virtual ~MklDnnConvUtil() { context_ = nullptr; } @@ -75,14 +75,14 @@ class MklDnnConvUtil { // requires input in NCHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. - virtual inline void GetInputSizeInMklOrder(const TensorShape &input_shape, - memory::dims *input_dims) { -#define CHECK_BOUNDS(val, err_msg) \ - do { \ - OP_REQUIRES(context_, \ - FastBoundsCheck(val, std::numeric_limits<int>::max()), \ - errors::InvalidArgument(err_msg)); \ - } while (0) + virtual inline void + GetInputSizeInMklOrder(const TensorShape& input_shape, + memory::dims *input_dims) { + #define CHECK_BOUNDS(val, err_msg) do { \ + OP_REQUIRES(context_, FastBoundsCheck(val, \ + std::numeric_limits<int>::max()), \ + errors::InvalidArgument(err_msg)); \ + }while(0) CHECK_NOTNULL(input_dims); @@ -105,7 +105,7 @@ class MklDnnConvUtil { CHECK_BOUNDS(input_batch_raw, "Input batch too large"); int input_batch = static_cast<int>(input_batch_raw); -#undef CHECK_BOUNDS + #undef CHECK_BOUNDS // MKL-DNN always requires input in NCHW format. *input_dims = {input_batch, input_depth, input_rows, input_cols}; @@ -125,9 +125,10 @@ class MklDnnConvUtil { // forward gets actual tensor as input). // // TODO(nhasabni): Add similar function for input and filter in MklShape. - virtual inline void GetFilterSizeInMklOrder(const TensorShape &input_shape, - const TensorShape &filter_shape, - memory::dims *filter_dims) { + virtual inline void + GetFilterSizeInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + memory::dims *filter_dims) { CHECK_NOTNULL(filter_dims); OP_REQUIRES(context_, filter_shape.dims() == 4, @@ -135,18 +136,17 @@ class MklDnnConvUtil { filter_shape.DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES(context_, - FastBoundsCheck(filter_shape.dim_size(i), - std::numeric_limits<int>::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i), + std::numeric_limits<int>::max()), + errors::InvalidArgument("filter too large")); } int input_depth = GetTensorDim(input_shape, data_format_, 'C'); - OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter_shape.dim_size(2))); + OP_REQUIRES( + context_, input_depth == filter_shape.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + input_depth, " vs ", filter_shape.dim_size(2))); // TF filter is always in (rows, cols, in_depth, out_depth) order. int filter_rows = static_cast<int>(filter_shape.dim_size(0)); @@ -163,25 +163,25 @@ class MklDnnConvUtil { // requires filter in OIHW format. Function does not return anything. // But errors arising from sanity checks are returned in context's // status. - virtual inline void GetFilterSizeInMklOrder(size_t src_index, - size_t filter_index, - memory::dims *filter_dims) { + virtual inline void + GetFilterSizeInMklOrder(size_t src_index, size_t filter_index, + memory::dims *filter_dims) { CHECK_NOTNULL(filter_dims); - const Tensor &input = MklGetInput(context_, src_index); - const Tensor &filter = MklGetInput(context_, filter_index); + const Tensor& input = MklGetInput(context_, src_index); + const Tensor& filter = MklGetInput(context_, filter_index); GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims); } // Calculate Bias size for 2D Convolution. Function does not return // anything, but sets error in context status. - virtual inline void GetBiasSizeInMklOrder(size_t bias_index, - memory::dims *bias_dims) { - const Tensor &bias = MklGetInput(context_, bias_index); + virtual inline void + GetBiasSizeInMklOrder(size_t bias_index, memory::dims *bias_dims) { + const Tensor& bias = MklGetInput(context_, bias_index); OP_REQUIRES(context_, bias.dims() == 1, errors::InvalidArgument("bias must be 1-dimensional: ", bias.shape().DebugString())); - *bias_dims = {static_cast<int>(bias.dim_size(0))}; + *bias_dims = { static_cast<int>(bias.dim_size(0)) }; } // Function to calculate output and padding size for 2D convolution. @@ -193,11 +193,13 @@ class MklDnnConvUtil { // status is returned via context status. // // TODO(nhasabni): Add similar function for input and filter in MklShape. - virtual inline void GetOutputAndPadSizeInMklOrder( - const TensorShape &input_shape, const TensorShape &filter_shape, - const memory::dims &strides, memory::dims *output_dims_tf_order, - memory::dims *output_dims_mkl_order, memory::dims *pad_l, - memory::dims *pad_r) { + virtual inline void + GetOutputAndPadSizeInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + const memory::dims& strides, + memory::dims *output_dims_tf_order, + memory::dims *output_dims_mkl_order, + memory::dims *pad_l, memory::dims *pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); @@ -223,21 +225,21 @@ class MklDnnConvUtil { int64 out_rows = 0, out_cols = 0; int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right; - OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( - input_rows, filter_rows, stride_rows, padding_, - &out_rows, &pad_top, &pad_bottom)); - OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose( - input_cols, filter_cols, stride_cols, padding_, - &out_cols, &pad_left, &pad_right)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose(input_rows, filter_rows, stride_rows, + padding_, &out_rows, &pad_top, &pad_bottom)); + OP_REQUIRES_OK(context_, + GetWindowedOutputSizeVerbose(input_cols, filter_cols, stride_cols, + padding_, &out_cols, &pad_left, &pad_right)); // Tensorflow output is in data_format order. (NHWC or NCHW) - TensorShape out_shape = - ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth); + TensorShape out_shape = ShapeFromFormat(data_format_, out_batch, + out_rows, out_cols, out_depth); *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); // MKL-DNN always needs output in NCHW format. *output_dims_mkl_order = {out_batch, out_depth, static_cast<int>(out_rows), - static_cast<int>(out_cols)}; + static_cast<int>(out_cols)}; // Now handle padding. MKL-DNN uses asymetric padding. *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)}; @@ -248,25 +250,27 @@ class MklDnnConvUtil { // See comment on GetConvOutputAndPadSizeInMklOrder for parameters. // // Function does not return anything, but sets error in context status. - inline void GetOutputAndPadSizeInMklOrder( - size_t src_index, size_t filter_index, const memory::dims &strides, - memory::dims *output_dims_tf_order, memory::dims *output_dims_mkl_order, - memory::dims *pad_l, memory::dims *pad_r) { + inline void + GetOutputAndPadSizeInMklOrder(size_t src_index, size_t filter_index, + const memory::dims& strides, + memory::dims *output_dims_tf_order, + memory::dims *output_dims_mkl_order, + memory::dims *pad_l, memory::dims *pad_r) { CHECK_NOTNULL(output_dims_tf_order); CHECK_NOTNULL(output_dims_mkl_order); CHECK_NOTNULL(pad_l); CHECK_NOTNULL(pad_r); - const Tensor &input = MklGetInput(context_, src_index); - const Tensor &filter = MklGetInput(context_, filter_index); + const Tensor& input = MklGetInput(context_, src_index); + const Tensor& filter = MklGetInput(context_, filter_index); OP_REQUIRES(context_, input.dims() == 4, errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); + input.shape().DebugString())); - GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), strides, - output_dims_tf_order, output_dims_mkl_order, - pad_l, pad_r); + GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), + strides, output_dims_tf_order, + output_dims_mkl_order, pad_l, pad_r); } // Wrapper function to calculate input, filter, and output sizes of @@ -275,12 +279,15 @@ class MklDnnConvUtil { // also calculates strides and paddings for 2D Convolution. // // Function does not return anything, but sets error in context status. - inline void GetConvFwdSizesInMklOrder( - const TensorShape &input_shape, const TensorShape &filter_shape, - memory::dims *input_dims, memory::dims *filter_dims, - memory::dims *strides, memory::dims *output_dims_tf_order, - memory::dims *output_dims_mkl_order, memory::dims *pad_l, - memory::dims *pad_r) { + inline void GetConvFwdSizesInMklOrder(const TensorShape& input_shape, + const TensorShape& filter_shape, + memory::dims *input_dims, + memory::dims *filter_dims, + memory::dims *strides, + memory::dims *output_dims_tf_order, + memory::dims *output_dims_mkl_order, + memory::dims *pad_l, + memory::dims *pad_r) { CHECK_NOTNULL(input_dims); CHECK_NOTNULL(filter_dims); CHECK_NOTNULL(strides); @@ -295,7 +302,8 @@ class MklDnnConvUtil { if (!context_->status().ok()) return; GetStridesInMklOrder(strides); GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides, - output_dims_tf_order, output_dims_mkl_order, + output_dims_tf_order, + output_dims_mkl_order, pad_l, pad_r); if (!context_->status().ok()) return; } |