diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/softmax_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/softmax_op.cc | 33 |
1 file changed, 21 insertions, 12 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index d1c69f08b0..60c6a5d349 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -15,10 +15,13 @@ limitations under the License. // XLA-specific Ops for softmax. +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -42,23 +45,27 @@ class SoftmaxOp : public XlaOpKernel { const int kClassDim = 1; const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); xla::XlaBuilder* const b = ctx->builder(); const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); // Find the max in each batch, resulting in a tensor of shape [batch] - auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type), - max_func, {kClassDim}); + auto logits_max = + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. Broadcasts // along the batch dimension. 
auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim}); auto exp_shifted = xla::Exp(shifted_logits); const DataType accumulation_type = XlaHelpers::SumAccumulationType(type); + xla::PrimitiveType xla_accumulation_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type, + &xla_accumulation_type)); auto converted = - XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type); + xla::ConvertElementType(exp_shifted, xla_accumulation_type); auto reduce = - xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type), + xla::Reduce(converted, xla::Zero(b, xla_accumulation_type), *ctx->GetOrCreateAdd(accumulation_type), {kClassDim}); auto sum = XlaHelpers::ConvertElementType(b, reduce, type); auto softmax = @@ -78,8 +85,8 @@ REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp); REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp); std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits( - XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits, - const xla::XlaOp& labels) { + XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type, + xla::XlaOp logits, xla::XlaOp labels) { const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type); const int kBatchDim = 0; @@ -88,7 +95,7 @@ std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits( xla::XlaBuilder* b = ctx->builder(); // Find the max in each batch, resulting in a tensor of shape [batch] auto logits_max = - xla::Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim}); + xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim}); // Subtract the max in batch b from every element in batch b. // Broadcasts along the batch dimension. @@ -148,12 +155,13 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel { // check that "labels" is a matrix too. 
const DataType type = input_type(0); + const xla::PrimitiveType xla_type = ctx->input_xla_type(0); auto logits = ctx->Input(0); auto labels = ctx->Input(1); xla::XlaOp loss, backprop; std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, type, logits, labels); + CrossEntropyWithLogits(ctx, type, xla_type, logits, labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } @@ -189,8 +197,9 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { int64 batch_size = logits_shape.dim_size(0); int64 depth = logits_shape.dim_size(1); - DataType logits_type = input_type(0); - DataType indices_type = input_type(1); + const DataType logits_type = input_type(0); + const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0); + const DataType indices_type = input_type(1); xla::XlaOp indices = ctx->Input(1); @@ -218,8 +227,8 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { labels = xla::Add(labels, nan_or_zero, {0}); xla::XlaOp loss, backprop; - std::tie(loss, backprop) = - CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels); + std::tie(loss, backprop) = CrossEntropyWithLogits( + ctx, logits_type, xla_logits_type, ctx->Input(0), labels); ctx->SetOutput(0, loss); ctx->SetOutput(1, backprop); } |