aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_aggregate_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_aggregate_ops.cc')
-rw-r--r--tensorflow/core/kernels/mkl_aggregate_ops.cc20
1 files changed, 16 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index 28edf51546..20aa1f7ea1 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -392,16 +392,28 @@ class MklAddNOp : public OpKernel {
memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
auto src1_tf_data_format =
MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
- auto src2_dims =
- TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
+ memory::dims src2_dims;
+ if (src2_tensor.dims() == 4) {
+ src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ } else {
+ src2_dims = TFShapeToMklDnnDimsInNCDHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ }
md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
} else if (input2_in_mkl_format && !input1_in_mkl_format) {
// Same comment as above.
memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
auto src2_tf_data_format =
MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
- auto src1_dims =
- TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
+ memory::dims src1_dims;
+ if (src1_tensor.dims() == 4) {
+ src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ } else {
+ src1_dims = TFShapeToMklDnnDimsInNCDHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ }
md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);
md2 = src2_mkl_shape.GetMklLayout();