about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Blake Hechtman <blakehechtman@google.com> 2018-01-09 13:56:07 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-01-09 13:59:37 -0800
commit: 6c95dc837a676c5da0f3183382ea54278e896a65 (patch)
tree: 604f5fbf529811673267fc8e1deaed4362bce195
parent: 3e852d462aaba446f62f76007405c0794a6087b9 (diff)
[TF:XLA] Use broadcasts instead of larger constants for image resizing.
PiperOrigin-RevId: 181369272
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 22
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index c0b8f9c179..bedabc78e8 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -113,23 +113,27 @@ xla::ComputationDataHandle MakeBilinearResizeKernel(
auto make_1d_kernel = [](int64 n) {
std::vector<float> kernel(n * 2 - 1);
for (int64 i = 0; i < n; ++i) {
- float v = i + 1;
+ float v = (i + 1.0f) / n;
kernel[i] = v;
kernel[n * 2 - 2 - i] = v;
}
return kernel;
};
- // Form a block diagonal kernel where each channel interacts only with itself.
- xla::Array4D<float> diag(1, 1, channels, channels, 0.0f);
- for (int i = 0; i < channels; ++i) {
- diag(0, 0, i, i) = 1.0f / (kernel_size[0] * kernel_size[1]);
- }
+ xla::ComputationDataHandle channels_iota;
+ // DT_INT32 Iota will always return status::OK().
+ TF_CHECK_OK(
+ XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
+
+ auto diag = builder->ConvertElementType(
+ builder->Eq(builder->Reshape(channels_iota, {1, 1, 1, channels}),
+ channels_iota, /*broadcast_dimensions=*/{2}),
+ xla::PrimitiveType::F32);
return builder->Mul(
- builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])),
- builder->Mul(builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])),
- builder->ConstantR4FromArray4D(diag),
+ builder->Mul(diag,
+ builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])),
/*broadcast_dimensions=*/{1}),
+ builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])),
/*broadcast_dimensions=*/{0});
}