diff options
author | 2018-01-09 13:56:07 -0800 | |
---|---|---|
committer | 2018-01-09 13:59:37 -0800 | |
commit | 6c95dc837a676c5da0f3183382ea54278e896a65 (patch) | |
tree | 604f5fbf529811673267fc8e1deaed4362bce195 | |
parent | 3e852d462aaba446f62f76007405c0794a6087b9 (diff) |
[TF:XLA] Use broadcasts instead of larger constants for image resizing.
PiperOrigin-RevId: 181369272
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index c0b8f9c179..bedabc78e8 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -113,23 +113,27 @@ xla::ComputationDataHandle MakeBilinearResizeKernel( auto make_1d_kernel = [](int64 n) { std::vector<float> kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { - float v = i + 1; + float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } return kernel; }; - // Form a block diagonal kernel where each channel interacts only with itself. - xla::Array4D<float> diag(1, 1, channels, channels, 0.0f); - for (int i = 0; i < channels; ++i) { - diag(0, 0, i, i) = 1.0f / (kernel_size[0] * kernel_size[1]); - } + xla::ComputationDataHandle channels_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq(builder->Reshape(channels_iota, {1, 1, 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); return builder->Mul( - builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])), - builder->Mul(builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])), - builder->ConstantR4FromArray4D(diag), + builder->Mul(diag, + builder->ConstantR1<float>(make_1d_kernel(kernel_size[1])), /*broadcast_dimensions=*/{1}), + builder->ConstantR1<float>(make_1d_kernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } |