diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/categorical_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/categorical_op.cc | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index c137d026bd..1784e712b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -74,16 +74,14 @@ class CategoricalOp : public XlaOpKernel { // See: // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ // TODO(b/68769470): Switch to using a cumulative sum approach. - auto softmax_entries = - xla::Sub(logits, xla::Log(xla::Neg(xla::Log(uniforms))), - /*broadcast_dimensions=*/{0, 2}); - - TensorShape softmax_shape(uniform_shape_array); - xla::XlaOp argmax; - OP_REQUIRES_OK( - ctx, - XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, - input_type(0), output_type(0), /*axis=*/2, &argmax)); + auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)), + /*broadcast_dimensions=*/{0, 2}); + + xla::PrimitiveType xla_output_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(output_type(0), &xla_output_type)); + xla::XlaOp argmax = + XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2); ctx->SetOutput(0, argmax); } |