diff options
author | 2016-05-23 15:51:44 -0800 | |
---|---|---|
committer | 2016-05-23 17:03:46 -0700 | |
commit | bf51a7c7dc1c5bf04f31362208986d69d0c456df (patch) | |
tree | 813125e71752370e521f0c790cb83c6464d41246 /tensorflow/core/kernels/conv_ops_gpu.h | |
parent | d4ef9aa02c3c8297a053176918beaf34c13b73a6 (diff) |
Change the ConvParameters comparison function to use std::tuple to avoid
including the paddings.
Change: 123053215
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_gpu.h')
-rw-r--r-- | tensorflow/core/kernels/conv_ops_gpu.h | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 419ba4dfc6..14e7d033eb 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -18,6 +18,7 @@ limitations under the License. #if GOOGLE_CUDA +#include <tuple> #include "tensorflow/core/platform/stream_executor.h" namespace tensorflow { @@ -95,8 +96,18 @@ struct ConvParameters { int64 padding_cols; int device_id; + typedef std::tuple<int64, int64, int64, int64, int64, int64, int64, int64, + int64, int64, int64, int> + DataType; + + DataType get_data_as_tuple() const { + return std::make_tuple(batch, in_depths, in_rows, in_cols, out_depths, + filter_rows, filter_cols, stride_rows, stride_cols, + padding_rows, padding_cols, device_id); + } + bool operator==(const ConvParameters& other) const { - return memcmp(this, &other, sizeof(ConvParameters)) == 0; + return this->get_data_as_tuple() == other.get_data_as_tuple(); } bool operator!=(const ConvParameters& other) const { @@ -104,7 +115,7 @@ struct ConvParameters { } bool operator<(const ConvParameters& other) const { - return memcmp(this, &other, sizeof(ConvParameters)) < 0; + return this->get_data_as_tuple() < other.get_data_as_tuple(); } }; |