aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/categorical_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/categorical_op.cc18
1 files changed, 8 insertions, 10 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
index c137d026bd..1784e712b5 100644
--- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc
@@ -74,16 +74,14 @@ class CategoricalOp : public XlaOpKernel {
// See:
// https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/
// TODO(b/68769470): Switch to using a cumulative sum approach.
- auto softmax_entries =
- xla::Sub(logits, xla::Log(xla::Neg(xla::Log(uniforms))),
- /*broadcast_dimensions=*/{0, 2});
-
- TensorShape softmax_shape(uniform_shape_array);
- xla::XlaOp argmax;
- OP_REQUIRES_OK(
- ctx,
- XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape,
- input_type(0), output_type(0), /*axis=*/2, &argmax));
+ auto softmax_entries = xla::Sub(logits, xla::Log(-xla::Log(uniforms)),
+ /*broadcast_dimensions=*/{0, 2});
+
+ xla::PrimitiveType xla_output_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(output_type(0), &xla_output_type));
+ xla::XlaOp argmax =
+ XlaHelpers::ArgMax(softmax_entries, xla_output_type, /*axis=*/2);
ctx->SetOutput(0, argmax);
}