diff options
author | 2018-09-07 09:21:07 -0700 | |
---|---|---|
committer | 2018-09-07 09:25:09 -0700 | |
commit | abfbafac6409cee3bd1a959096b07abcac875a9d (patch) | |
tree | ce0a688eb304e65b543071638da483fde0b2b053 | |
parent | 81110ff2beb38a2cbfbefb69a9b640bf67a8558a (diff) |
mkl_pooling_ops_common.cc: convert asserts to DCHECKs.
DCHECK is more idiomatic in the Tensorflow code base. Also, some of the
"not-reached" asserts were actually inverted, asserting an always-true rather
than an always-false expression.
PiperOrigin-RevId: 211986533
-rw-r--r-- | tensorflow/core/kernels/mkl_pooling_ops_common.cc | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index ec6d241e17..5398e6113f 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -34,11 +34,11 @@ using mkldnn::prop_kind; template <typename T> void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { - if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg && - fwdParams.alg_kind != pooling_avg_include_padding && - fwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(fwdParams.alg_kind == pooling_max || + fwdParams.alg_kind == pooling_avg || + fwdParams.alg_kind == pooling_avg_include_padding || + fwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = fwdParams.alg_kind; // create memory desc @@ -102,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, static_cast<void*>(const_cast<T*>(src_data))); context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); if (context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(ws_data); } context_.fwd_stream->submit(context_.fwd_primitives); @@ -111,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); if (context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } @@ -120,11 +120,11 @@ template class MklPoolingFwdPrimitive<float>; template <typename T> void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { - if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg && - bwdParams.alg_kind != pooling_avg_include_padding && - bwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(bwdParams.alg_kind == pooling_max || + bwdParams.alg_kind == pooling_avg || + bwdParams.alg_kind == pooling_avg_include_padding || + bwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = bwdParams.alg_kind; // check whether it is 2d or 3d @@ -190,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, static_cast<void*>(const_cast<T*>(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); if (context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(const_cast<void*>(ws_data)); } @@ -199,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, context_.diff_dst_mem->set_data_handle(DummyData); context_.diff_src_mem->set_data_handle(DummyData); if (context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } |