aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc19
1 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
index 92b371cc4e..852d2a966e 100644
--- a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
@@ -172,15 +172,14 @@ class DepthwiseConv2dNativeOp : public XlaOpKernel {
} else {
// These will be used to define the bounds of each slice.
// Within the loop, the input_channel index will be modified.
- gtl::InlinedVector<int64, 4> filter_begin;
- gtl::InlinedVector<int64, 4> filter_limits;
- gtl::InlinedVector<int64, 4> input_begin;
- gtl::InlinedVector<int64, 4> input_limits;
+ gtl::InlinedVector<int64, 4> filter_begin(4, 0);
+ gtl::InlinedVector<int64, 4> filter_limits(4);
+ gtl::InlinedVector<int64, 4> input_begin(4, 0);
+ gtl::InlinedVector<int64, 4> input_limits(4);
+ gtl::InlinedVector<int64, 4> strides(4, 1);
for (int i = 0; i < 4; ++i) {
- filter_begin.push_back(0);
- filter_limits.push_back(filter_shape.dim_size(i));
- input_begin.push_back(0);
- input_limits.push_back(input_shape.dim_size(i));
+ filter_limits[i] = filter_shape.dim_size(i);
+ input_limits[i] = input_shape.dim_size(i);
}
std::vector<int64> strides_for_tla{strides_[1], strides_[2]};
@@ -209,9 +208,9 @@ class DepthwiseConv2dNativeOp : public XlaOpKernel {
input_limits[3] = i + 1;
xla::ComputationDataHandle filter_slice =
- b.Slice(filter, filter_begin, filter_limits);
+ b.Slice(filter, filter_begin, filter_limits, strides);
xla::ComputationDataHandle input_slice =
- b.Slice(input, input_begin, input_limits);
+ b.Slice(input, input_begin, input_limits, strides);
convs.push_back(b.ConvWithGeneralDimensions(
input_slice, filter_slice, strides_for_tla, xla_padding, dims));
}