diff options
author | 2018-01-17 16:49:54 -0800 | |
---|---|---|
committer | 2018-01-17 16:53:45 -0800 | |
commit | 0304b6630f3fcda6602d24e91ab3dd31e46f495f (patch) | |
tree | 3cb76bcf54f526aeab0e5bd04eacdd81a68178e3 | |
parent | 713ff801c842a279042a3e90cd1b3eaf313ec348 (diff) |
Use explicit broadcasts in ResizeBilinear Kernel creation to make fusion easier.
PiperOrigin-RevId: 182291519
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index bedabc78e8..a4fe7f0e93 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -126,8 +126,10 @@ xla::ComputationDataHandle MakeBilinearResizeKernel( XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); auto diag = builder->ConvertElementType( - builder->Eq(builder->Reshape(channels_iota, {1, 1, 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), + builder->Eq( + builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, + 2 * kernel_size[1] - 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), xla::PrimitiveType::F32); return builder->Mul( builder->Mul(diag, |