aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/nn_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/nn_ops.cc')
-rw-r--r--tensorflow/core/ops/nn_ops.cc8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 536fc7c0c1..3f72b41569 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1818,7 +1818,11 @@ REGISTER_OP("_MklMaxPool")
.Input("input: T")
.Input("mkl_input: uint8")
.Output("output: T")
+#ifndef INTEL_MKL_DNN
.Output("workspace: T")
+#else
+ .Output("workspace: uint8")
+#endif
.Output("mkl_output: uint8")
.Output("mkl_workspace: uint8")
.SetShapeFn(shape_inference::MaxPoolShape)
@@ -1840,7 +1844,11 @@ REGISTER_OP("_MklMaxPoolGrad")
.Input("orig_input: T")
.Input("orig_output: T")
.Input("grad: T")
+#ifndef INTEL_MKL_DNN
.Input("workspace: T")
+#else
+ .Input("workspace: uint8")
+#endif
.Input("mkl_orig_input: uint8")
.Input("mkl_orig_output: uint8")
.Input("mkl_grad: uint8")