aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops_gpu.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_gpu.h')
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h20
1 files changed, 13 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index d2c8020bb6..afc611f277 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -85,13 +85,15 @@ class ConvParameters {
public:
using SpatialArray = gtl::InlinedVector<int64, 3>;
ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
- int64 out_depths, const SpatialArray& filter,
- const SpatialArray& dilation, const SpatialArray& stride,
- const SpatialArray& padding, DataType dtype, int device_id)
+ TensorFormat data_format, int64 out_depths,
+ const SpatialArray& filter, const SpatialArray& dilation,
+ const SpatialArray& stride, const SpatialArray& padding,
+ DataType dtype, int device_id)
: batch_(batch),
in_depths_(in_depths),
out_depths_(out_depths),
in_(in),
+ data_format_(data_format),
filter_(filter),
dilation_(dilation),
stride_(stride),
@@ -101,6 +103,7 @@ class ConvParameters {
hash_code_ = batch;
hash_code_ = Hash64Combine(hash_code_, in_depths);
for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val);
+ hash_code_ = Hash64Combine(hash_code_, data_format);
hash_code_ = Hash64Combine(hash_code_, out_depths);
for (int64 val : filter) hash_code_ = Hash64Combine(hash_code_, val);
for (int64 val : dilation) hash_code_ = Hash64Combine(hash_code_, val);
@@ -123,6 +126,7 @@ class ConvParameters {
return strings::StrCat(
batch_, ", ", in_depths_, ", ",
"(", str_util::Join(in_, ", "), "), ",
+ ::tensorflow::ToString(data_format_), ", ",
out_depths_, ", ",
"(", str_util::Join(filter_, ", "), "), ",
"(", str_util::Join(dilation_, ", "), "), ",
@@ -148,12 +152,13 @@ class ConvParameters {
protected:
using ParameterDataType =
- std::tuple<int64, int64, SpatialArray, int64, SpatialArray, SpatialArray,
- SpatialArray, SpatialArray, DataType, int>;
+ std::tuple<int64, int64, SpatialArray, TensorFormat, int64, SpatialArray,
+ SpatialArray, SpatialArray, SpatialArray, DataType, int>;
ParameterDataType get_data_as_tuple() const {
- return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_,
- dilation_, stride_, padding_, dtype_, device_id_);
+ return std::make_tuple(batch_, in_depths_, in_, data_format_, out_depths_,
+ filter_, dilation_, stride_, padding_, dtype_,
+ device_id_);
}
uint64 hash_code_;
@@ -178,6 +183,7 @@ class ConvParameters {
int64 in_depths_;
int64 out_depths_;
SpatialArray in_;
+ TensorFormat data_format_;
SpatialArray filter_;
SpatialArray dilation_;
SpatialArray stride_;