aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_concat_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_concat_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_concat_op.cc4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 82771792d7..d109bb6bcf 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -598,7 +598,6 @@ class MklConcatOp : public OpKernel {
concat_dim_tensor.shape().DebugString()));
int32 concat_dim = internal::SubtleMustCopy(
concat_dim_tensor.scalar<int32>()());
- if (concat_dim < 0) concat_dim = N + concat_dim;
// check that ranks of all tensors match
// and that their shapes match except for concat_dim.
@@ -609,6 +608,9 @@ class MklConcatOp : public OpKernel {
input_shapes[0].GetTfShape() :
input_tensors[0].shape();
size_t expected_dims = expected_shape.dims();
+
+ if (concat_dim < 0) concat_dim = expected_dims + concat_dim;
+
for (auto& s : input_shapes) {
if (s == expected_shape) {++i; continue;}