aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.h')
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h140
1 files changed, 74 insertions, 66 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index e29af19ca9..f0cb37f8a4 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
-#include <limits>
#include <vector>
+#include <limits>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -26,8 +26,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
-#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
@@ -49,15 +49,15 @@ namespace tensorflow {
class MklDnnConvUtil {
protected:
- OpKernelContext *context_; // We don't own this.
+ OpKernelContext* context_; // We don't own this.
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
public:
- MklDnnConvUtil(OpKernelContext *context, const std::vector<int32> &strides,
- Padding pad, TensorFormat fm)
- : context_(context), strides_(strides), padding_(pad), data_format_(fm) {}
+ MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
+ Padding pad, TensorFormat fm) : context_(context),
+ strides_(strides), padding_(pad), data_format_(fm) {}
virtual ~MklDnnConvUtil() { context_ = nullptr; }
@@ -75,14 +75,14 @@ class MklDnnConvUtil {
// requires input in NCHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
- virtual inline void GetInputSizeInMklOrder(const TensorShape &input_shape,
- memory::dims *input_dims) {
-#define CHECK_BOUNDS(val, err_msg) \
- do { \
- OP_REQUIRES(context_, \
- FastBoundsCheck(val, std::numeric_limits<int>::max()), \
- errors::InvalidArgument(err_msg)); \
- } while (0)
+ virtual inline void
+ GetInputSizeInMklOrder(const TensorShape& input_shape,
+ memory::dims *input_dims) {
+ #define CHECK_BOUNDS(val, err_msg) do { \
+ OP_REQUIRES(context_, FastBoundsCheck(val, \
+ std::numeric_limits<int>::max()), \
+ errors::InvalidArgument(err_msg)); \
+ }while(0)
CHECK_NOTNULL(input_dims);
@@ -105,7 +105,7 @@ class MklDnnConvUtil {
CHECK_BOUNDS(input_batch_raw, "Input batch too large");
int input_batch = static_cast<int>(input_batch_raw);
-#undef CHECK_BOUNDS
+ #undef CHECK_BOUNDS
// MKL-DNN always requires input in NCHW format.
*input_dims = {input_batch, input_depth, input_rows, input_cols};
@@ -125,9 +125,10 @@ class MklDnnConvUtil {
// forward gets actual tensor as input).
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
- virtual inline void GetFilterSizeInMklOrder(const TensorShape &input_shape,
- const TensorShape &filter_shape,
- memory::dims *filter_dims) {
+ virtual inline void
+ GetFilterSizeInMklOrder(const TensorShape& input_shape,
+ const TensorShape& filter_shape,
+ memory::dims *filter_dims) {
CHECK_NOTNULL(filter_dims);
OP_REQUIRES(context_, filter_shape.dims() == 4,
@@ -135,18 +136,17 @@ class MklDnnConvUtil {
filter_shape.DebugString()));
for (int i = 0; i < 3; i++) {
- OP_REQUIRES(context_,
- FastBoundsCheck(filter_shape.dim_size(i),
- std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
+ OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
}
int input_depth = GetTensorDim(input_shape, data_format_, 'C');
- OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", input_depth,
- " vs ", filter_shape.dim_size(2)));
+ OP_REQUIRES(
+ context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument("input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(2)));
// TF filter is always in (rows, cols, in_depth, out_depth) order.
int filter_rows = static_cast<int>(filter_shape.dim_size(0));
@@ -163,25 +163,25 @@ class MklDnnConvUtil {
// requires filter in OIHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
- virtual inline void GetFilterSizeInMklOrder(size_t src_index,
- size_t filter_index,
- memory::dims *filter_dims) {
+ virtual inline void
+ GetFilterSizeInMklOrder(size_t src_index, size_t filter_index,
+ memory::dims *filter_dims) {
CHECK_NOTNULL(filter_dims);
- const Tensor &input = MklGetInput(context_, src_index);
- const Tensor &filter = MklGetInput(context_, filter_index);
+ const Tensor& input = MklGetInput(context_, src_index);
+ const Tensor& filter = MklGetInput(context_, filter_index);
GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims);
}
// Calculate Bias size for 2D Convolution. Function does not return
// anything, but sets error in context status.
- virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
- memory::dims *bias_dims) {
- const Tensor &bias = MklGetInput(context_, bias_index);
+ virtual inline void
+ GetBiasSizeInMklOrder(size_t bias_index, memory::dims *bias_dims) {
+ const Tensor& bias = MklGetInput(context_, bias_index);
OP_REQUIRES(context_, bias.dims() == 1,
errors::InvalidArgument("bias must be 1-dimensional: ",
bias.shape().DebugString()));
- *bias_dims = {static_cast<int>(bias.dim_size(0))};
+ *bias_dims = { static_cast<int>(bias.dim_size(0)) };
}
// Function to calculate output and padding size for 2D convolution.
@@ -193,11 +193,13 @@ class MklDnnConvUtil {
// status is returned via context status.
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
- virtual inline void GetOutputAndPadSizeInMklOrder(
- const TensorShape &input_shape, const TensorShape &filter_shape,
- const memory::dims &strides, memory::dims *output_dims_tf_order,
- memory::dims *output_dims_mkl_order, memory::dims *pad_l,
- memory::dims *pad_r) {
+ virtual inline void
+ GetOutputAndPadSizeInMklOrder(const TensorShape& input_shape,
+ const TensorShape& filter_shape,
+ const memory::dims& strides,
+ memory::dims *output_dims_tf_order,
+ memory::dims *output_dims_mkl_order,
+ memory::dims *pad_l, memory::dims *pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
@@ -223,21 +225,21 @@ class MklDnnConvUtil {
int64 out_rows = 0, out_cols = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
- OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
- input_rows, filter_rows, stride_rows, padding_,
- &out_rows, &pad_top, &pad_bottom));
- OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
- input_cols, filter_cols, stride_cols, padding_,
- &out_cols, &pad_left, &pad_right));
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerbose(input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerbose(input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
// Tensorflow output is in data_format order. (NHWC or NCHW)
- TensorShape out_shape =
- ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ TensorShape out_shape = ShapeFromFormat(data_format_, out_batch,
+ out_rows, out_cols, out_depth);
*output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
// MKL-DNN always needs output in NCHW format.
*output_dims_mkl_order = {out_batch, out_depth, static_cast<int>(out_rows),
- static_cast<int>(out_cols)};
+ static_cast<int>(out_cols)};
// Now handle padding. MKL-DNN uses asymetric padding.
*pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
@@ -248,25 +250,27 @@ class MklDnnConvUtil {
// See comment on GetConvOutputAndPadSizeInMklOrder for parameters.
//
// Function does not return anything, but sets error in context status.
- inline void GetOutputAndPadSizeInMklOrder(
- size_t src_index, size_t filter_index, const memory::dims &strides,
- memory::dims *output_dims_tf_order, memory::dims *output_dims_mkl_order,
- memory::dims *pad_l, memory::dims *pad_r) {
+ inline void
+ GetOutputAndPadSizeInMklOrder(size_t src_index, size_t filter_index,
+ const memory::dims& strides,
+ memory::dims *output_dims_tf_order,
+ memory::dims *output_dims_mkl_order,
+ memory::dims *pad_l, memory::dims *pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
- const Tensor &input = MklGetInput(context_, src_index);
- const Tensor &filter = MklGetInput(context_, filter_index);
+ const Tensor& input = MklGetInput(context_, src_index);
+ const Tensor& filter = MklGetInput(context_, filter_index);
OP_REQUIRES(context_, input.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
+ input.shape().DebugString()));
- GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), strides,
- output_dims_tf_order, output_dims_mkl_order,
- pad_l, pad_r);
+ GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(),
+ strides, output_dims_tf_order,
+ output_dims_mkl_order, pad_l, pad_r);
}
// Wrapper function to calculate input, filter, and output sizes of
@@ -275,12 +279,15 @@ class MklDnnConvUtil {
// also calculates strides and paddings for 2D Convolution.
//
// Function does not return anything, but sets error in context status.
- inline void GetConvFwdSizesInMklOrder(
- const TensorShape &input_shape, const TensorShape &filter_shape,
- memory::dims *input_dims, memory::dims *filter_dims,
- memory::dims *strides, memory::dims *output_dims_tf_order,
- memory::dims *output_dims_mkl_order, memory::dims *pad_l,
- memory::dims *pad_r) {
+ inline void GetConvFwdSizesInMklOrder(const TensorShape& input_shape,
+ const TensorShape& filter_shape,
+ memory::dims *input_dims,
+ memory::dims *filter_dims,
+ memory::dims *strides,
+ memory::dims *output_dims_tf_order,
+ memory::dims *output_dims_mkl_order,
+ memory::dims *pad_l,
+ memory::dims *pad_r) {
CHECK_NOTNULL(input_dims);
CHECK_NOTNULL(filter_dims);
CHECK_NOTNULL(strides);
@@ -295,7 +302,8 @@ class MklDnnConvUtil {
if (!context_->status().ok()) return;
GetStridesInMklOrder(strides);
GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
- output_dims_tf_order, output_dims_mkl_order,
+ output_dims_tf_order,
+ output_dims_mkl_order,
pad_l, pad_r);
if (!context_->status().ok()) return;
}