aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/fused_conv
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/fused_conv')
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc3
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h12
2 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 4554a3d89a..0ccb4583ab 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -443,6 +443,8 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
: dnn::DataLayout::kBatchDepthYX;
constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4
: dnn::FilterLayout::kOutputInputYX;
+ constexpr auto compute_data_format =
+ is_int8x4 ? FORMAT_NCHW_VECT_C : FORMAT_NCHW;
dnn::BatchDescriptor conv_input_desc;
conv_input_desc.set_count(batch_size)
@@ -529,6 +531,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
batch_size,
conv_input_depth,
{{conv_input_rows, conv_input_cols}},
+ compute_data_format,
output_depth,
{{filter_rows, filter_cols}},
// TODO(yangzihao): Add support for arbitrary dilations for fused conv.
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
index ba52697679..b9c131a2e9 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
@@ -29,13 +29,13 @@ namespace tensorflow {
class FusedConvParameters : public ConvParameters {
public:
FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
- int64 out_depths, const SpatialArray& filter,
- const SpatialArray& dilation, const SpatialArray& stride,
- const SpatialArray& padding, DataType dtype,
- int device_id, bool has_side_input,
+ TensorFormat data_format, int64 out_depths,
+ const SpatialArray& filter, const SpatialArray& dilation,
+ const SpatialArray& stride, const SpatialArray& padding,
+ DataType dtype, int device_id, bool has_side_input,
ActivationMode activation_mode)
- : ConvParameters(batch, in_depths, in, out_depths, filter, dilation,
- stride, padding, dtype, device_id),
+ : ConvParameters(batch, in_depths, in, data_format, out_depths, filter,
+ dilation, stride, padding, dtype, device_id),
activation_mode_(activation_mode),
has_side_input_(has_side_input) {
hash_code_ = Hash64Combine(hash_code_, has_side_input);