diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_conv_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_conv_ops.cc | 42 |
1 files changed, 19 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index b818819b02..76b9f1798d 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -36,9 +36,9 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/mkl_util.h" #include "third_party/mkl/include/mkl_dnn.h" #include "third_party/mkl/include/mkl_dnn_types.h" -#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { @@ -98,19 +98,18 @@ class MklConv2DOp : public OpKernel { filter.shape().DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES( - context, - FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), + std::numeric_limits<int>::max()), + errors::InvalidArgument("filter too large")); } const int64 input_depth = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C') : GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES(context, input_depth == filter.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter.dim_size(2))); + OP_REQUIRES( + context, input_depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + input_depth, " vs ", filter.dim_size(2))); // The last dimension for filter is out_depth. const int out_depth = static_cast<int>(filter.dim_size(3)); @@ -119,10 +118,9 @@ class MklConv2DOp : public OpKernel { const int64 input_rows_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H') : GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES( - context, - FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input rows too large")); + OP_REQUIRES(context, FastBoundsCheck(input_rows_raw, + std::numeric_limits<int>::max()), + errors::InvalidArgument("Input rows too large")); const int input_rows = static_cast<int>(input_rows_raw); const int filter_rows = static_cast<int>(filter.dim_size(0)); @@ -131,10 +129,9 @@ class MklConv2DOp : public OpKernel { const int64 input_cols_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W') : GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES( - context, - FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input cols too large")); + OP_REQUIRES(context, FastBoundsCheck(input_cols_raw, + std::numeric_limits<int>::max()), + errors::InvalidArgument("Input cols too large")); const int input_cols = static_cast<int>(input_cols_raw); const int filter_cols = static_cast<int>(filter.dim_size(1)); @@ -142,10 +139,9 @@ class MklConv2DOp : public OpKernel { const int64 input_batch_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N') : GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES( - context, - FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("batch is too large")); + OP_REQUIRES(context, FastBoundsCheck(input_batch_raw, + std::numeric_limits<int>::max()), + errors::InvalidArgument("batch is too large")); const int batch = static_cast<int>(input_batch_raw); // For now we take the stride from the second and third dimensions only (we @@ -438,12 +434,12 @@ class MklConv2DOp : public OpKernel { }; #define REGISTER_MKL_CPU(T) \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConv2DOp<CPUDevice, T, false>); \ - REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \ + REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklOpLabel), \ |