aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-07 09:21:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 09:25:09 -0700
commitabfbafac6409cee3bd1a959096b07abcac875a9d (patch)
treece0a688eb304e65b543071638da483fde0b2b053
parent81110ff2beb38a2cbfbefb69a9b640bf67a8558a (diff)
mkl_pooling_ops_common.cc: convert asserts to DCHECKs.
DCHECK is more idiomatic in the Tensorflow code base. Also, some of the "not-reached" asserts were actually inverted, asserting an always-true rather than an always-false expression. PiperOrigin-RevId: 211986533
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.cc28
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index ec6d241e17..5398e6113f 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -34,11 +34,11 @@ using mkldnn::prop_kind;
template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
- if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg &&
- fwdParams.alg_kind != pooling_avg_include_padding &&
- fwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(fwdParams.alg_kind == pooling_max ||
+ fwdParams.alg_kind == pooling_avg ||
+ fwdParams.alg_kind == pooling_avg_include_padding ||
+ fwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = fwdParams.alg_kind;
// create memory desc
@@ -102,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data);
}
context_.fwd_stream->submit(context_.fwd_primitives);
@@ -111,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
@@ -120,11 +120,11 @@ template class MklPoolingFwdPrimitive<float>;
template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
- if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg &&
- bwdParams.alg_kind != pooling_avg_include_padding &&
- bwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(bwdParams.alg_kind == pooling_max ||
+ bwdParams.alg_kind == pooling_avg ||
+ bwdParams.alg_kind == pooling_avg_include_padding ||
+ bwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = bwdParams.alg_kind;
// check whether it is 2d or 3d
@@ -190,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
}
@@ -199,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
context_.diff_dst_mem->set_data_handle(DummyData);
context_.diff_src_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}