diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_gpu.h')
-rw-r--r-- | tensorflow/core/kernels/conv_ops_gpu.h | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index d2c8020bb6..afc611f277 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -85,13 +85,15 @@ class ConvParameters { public: using SpatialArray = gtl::InlinedVector<int64, 3>; ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, - int64 out_depths, const SpatialArray& filter, - const SpatialArray& dilation, const SpatialArray& stride, - const SpatialArray& padding, DataType dtype, int device_id) + TensorFormat data_format, int64 out_depths, + const SpatialArray& filter, const SpatialArray& dilation, + const SpatialArray& stride, const SpatialArray& padding, + DataType dtype, int device_id) : batch_(batch), in_depths_(in_depths), out_depths_(out_depths), in_(in), + data_format_(data_format), filter_(filter), dilation_(dilation), stride_(stride), @@ -101,6 +103,7 @@ class ConvParameters { hash_code_ = batch; hash_code_ = Hash64Combine(hash_code_, in_depths); for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val); + hash_code_ = Hash64Combine(hash_code_, data_format); hash_code_ = Hash64Combine(hash_code_, out_depths); for (int64 val : filter) hash_code_ = Hash64Combine(hash_code_, val); for (int64 val : dilation) hash_code_ = Hash64Combine(hash_code_, val); @@ -123,6 +126,7 @@ class ConvParameters { return strings::StrCat( batch_, ", ", in_depths_, ", ", "(", str_util::Join(in_, ", "), "), ", + ::tensorflow::ToString(data_format_), ", ", out_depths_, ", ", "(", str_util::Join(filter_, ", "), "), ", "(", str_util::Join(dilation_, ", "), "), ", @@ -148,12 +152,13 @@ class ConvParameters { protected: using ParameterDataType = - std::tuple<int64, int64, SpatialArray, int64, SpatialArray, SpatialArray, - SpatialArray, SpatialArray, DataType, int>; + std::tuple<int64, int64, SpatialArray, TensorFormat, int64, SpatialArray, + SpatialArray, SpatialArray, SpatialArray, DataType, int>; ParameterDataType get_data_as_tuple() const { - return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_, - dilation_, stride_, padding_, dtype_, device_id_); + return std::make_tuple(batch_, in_depths_, in_, data_format_, out_depths_, + filter_, dilation_, stride_, padding_, dtype_, + device_id_); } uint64 hash_code_; @@ -178,6 +183,7 @@ class ConvParameters { int64 in_depths_; int64 out_depths_; SpatialArray in_; + TensorFormat data_format_; SpatialArray filter_; SpatialArray dilation_; SpatialArray stride_; |