aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc42
1 files changed, 19 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index b818819b02..76b9f1798d 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -36,9 +36,9 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/mkl_util.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
@@ -98,19 +98,18 @@ class MklConv2DOp : public OpKernel {
filter.shape().DebugString()));
for (int i = 0; i < 3; i++) {
- OP_REQUIRES(
- context,
- FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
+ OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
}
const int64 input_depth =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
: GetTensorDim(input, data_format_, 'C');
- OP_REQUIRES(context, input_depth == filter.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", input_depth,
- " vs ", filter.dim_size(2)));
+ OP_REQUIRES(
+ context, input_depth == filter.dim_size(2),
+ errors::InvalidArgument("input and filter must have the same depth: ",
+ input_depth, " vs ", filter.dim_size(2)));
// The last dimension for filter is out_depth.
const int out_depth = static_cast<int>(filter.dim_size(3));
@@ -119,10 +118,9 @@ class MklConv2DOp : public OpKernel {
const int64 input_rows_raw =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
: GetTensorDim(input, data_format_, 'H');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input rows too large"));
+ OP_REQUIRES(context, FastBoundsCheck(input_rows_raw,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input rows too large"));
const int input_rows = static_cast<int>(input_rows_raw);
const int filter_rows = static_cast<int>(filter.dim_size(0));
@@ -131,10 +129,9 @@ class MklConv2DOp : public OpKernel {
const int64 input_cols_raw =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
: GetTensorDim(input, data_format_, 'W');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input cols too large"));
+ OP_REQUIRES(context, FastBoundsCheck(input_cols_raw,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input cols too large"));
const int input_cols = static_cast<int>(input_cols_raw);
const int filter_cols = static_cast<int>(filter.dim_size(1));
@@ -142,10 +139,9 @@ class MklConv2DOp : public OpKernel {
const int64 input_batch_raw =
input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
: GetTensorDim(input, data_format_, 'N');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("batch is too large"));
+ OP_REQUIRES(context, FastBoundsCheck(input_batch_raw,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("batch is too large"));
const int batch = static_cast<int>(input_batch_raw);
// For now we take the stride from the second and third dimensions only (we
@@ -438,12 +434,12 @@ class MklConv2DOp : public OpKernel {
};
#define REGISTER_MKL_CPU(T) \
- REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklConv2DOp<CPUDevice, T, false>); \
- REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \