diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index f752fb3ae2..16b778bca4 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -94,12 +94,14 @@ class BatchMatMulOp : public XlaOpKernel { // Slice off individual matrices and reshape to 2D tensors. auto x_slice = builder->Slice( x_flat, {i, 0, 0}, - {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); + {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}, + {1, 1, 1}); x_slice = builder->Reshape( x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); auto y_slice = builder->Slice( y_flat, {i, 0, 0}, - {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); + {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}, + {1, 1, 1}); y_slice = builder->Reshape( y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); |