aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Blake Hechtman <blakehechtman@google.com>2018-01-17 16:49:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-17 16:53:45 -0800
commit0304b6630f3fcda6602d24e91ab3dd31e46f495f (patch)
tree3cb76bcf54f526aeab0e5bd04eacdd81a68178e3
parent713ff801c842a279042a3e90cd1b3eaf313ec348 (diff)
Use explicit broadcasts in ResizeBilinear Kernel creation to make fusion easier.
PiperOrigin-RevId: 182291519
-rw-r--r--tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
index bedabc78e8..a4fe7f0e93 100644
--- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc
@@ -126,8 +126,10 @@ xla::ComputationDataHandle MakeBilinearResizeKernel(
XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota));
auto diag = builder->ConvertElementType(
- builder->Eq(builder->Reshape(channels_iota, {1, 1, 1, channels}),
- channels_iota, /*broadcast_dimensions=*/{2}),
+ builder->Eq(
+ builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1,
+ 2 * kernel_size[1] - 1, channels}),
+ channels_iota, /*broadcast_dimensions=*/{2}),
xla::PrimitiveType::F32);
return builder->Mul(
builder->Mul(diag,