diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_concat_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_concat_op.cc | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index 82771792d7..d109bb6bcf 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -598,7 +598,6 @@ class MklConcatOp : public OpKernel { concat_dim_tensor.shape().DebugString())); int32 concat_dim = internal::SubtleMustCopy( concat_dim_tensor.scalar<int32>()()); - if (concat_dim < 0) concat_dim = N + concat_dim; // check that ranks of all tensors match // and that their shapes match except for concat_dim. @@ -609,6 +608,9 @@ class MklConcatOp : public OpKernel { input_shapes[0].GetTfShape() : input_tensors[0].shape(); size_t expected_dims = expected_shape.dims(); + + if (concat_dim < 0) concat_dim = expected_dims + concat_dim; + for (auto& s : input_shapes) { if (s == expected_shape) {++i; continue;} |