Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/softmax_op.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/softmax_op.cc | 33 +++++++++++++++++++++------------
1 file changed, 21 insertions(+), 12 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
index d1c69f08b0..60c6a5d349 100644
--- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -15,10 +15,13 @@ limitations under the License.
// XLA-specific Ops for softmax.
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -42,23 +45,27 @@ class SoftmaxOp : public XlaOpKernel {
const int kClassDim = 1;
const DataType type = input_type(0);
+ const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
xla::XlaBuilder* const b = ctx->builder();
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
// Find the max in each batch, resulting in a tensor of shape [batch]
- auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type),
- max_func, {kClassDim});
+ auto logits_max =
+ xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b. Broadcasts
// along the batch dimension.
auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
auto exp_shifted = xla::Exp(shifted_logits);
const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
+ xla::PrimitiveType xla_accumulation_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(accumulation_type,
+ &xla_accumulation_type));
auto converted =
- XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type);
+ xla::ConvertElementType(exp_shifted, xla_accumulation_type);
auto reduce =
- xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
+ xla::Reduce(converted, xla::Zero(b, xla_accumulation_type),
*ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
auto sum = XlaHelpers::ConvertElementType(b, reduce, type);
auto softmax =
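
For reference, the block above is the standard numerically stable softmax: subtracting the per-batch max before exponentiating leaves the result unchanged but avoids overflow in exp. The identity being used, with m the per-batch max:

    \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}},
    \qquad m = \max_j x_j

The detour through SumAccumulationType is there because low-precision element types (half/bfloat16, which SumAccumulationType widens to float) are accumulated in float32 and converted back after the reduce, making the sum markedly more accurate at the cost of two casts.
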
@@ -78,8 +85,8 @@ REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);
std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
- XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits,
- const xla::XlaOp& labels) {
+ XlaOpKernelContext* ctx, DataType type, xla::PrimitiveType xla_type,
+ xla::XlaOp logits, xla::XlaOp labels) {
const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);
const int kBatchDim = 0;
@@ -88,7 +95,7 @@ std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
xla::XlaBuilder* b = ctx->builder();
// Find the max in each batch, resulting in a tensor of shape [batch]
auto logits_max =
- xla::Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ xla::Reduce(logits, xla::MinValue(b, xla_type), max_func, {kClassDim});
// Subtract the max in batch b from every element in batch b.
// Broadcasts along the batch dimension.
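
As in SoftmaxOp above, the subtraction that follows uses XLA's broadcast_dimensions argument: the trailing {kBatchDim} maps the single dimension of the rank-1 logits_max onto dimension 0 of the rank-2 logits. A minimal sketch, assuming f32 operands of shape [batch, classes] and [batch]:

    // logits:     f32[batch, classes]
    // logits_max: f32[batch]
    // {0} lines logits_max's only dimension up with dimension 0 (batch) of
    // logits; the subtraction then broadcasts along dimension 1 (classes).
    auto shifted = xla::Sub(logits, logits_max, /*broadcast_dimensions=*/{0});
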
@@ -148,12 +155,13 @@ class SoftmaxXentWithLogitsOp : public XlaOpKernel {
// check that "labels" is a matrix too.
const DataType type = input_type(0);
+ const xla::PrimitiveType xla_type = ctx->input_xla_type(0);
auto logits = ctx->Input(0);
auto labels = ctx->Input(1);
xla::XlaOp loss, backprop;
std::tie(loss, backprop) =
- CrossEntropyWithLogits(ctx, type, logits, labels);
+ CrossEntropyWithLogits(ctx, type, xla_type, logits, labels);
ctx->SetOutput(0, loss);
ctx->SetOutput(1, backprop);
}
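
The plumbing change here is that CrossEntropyWithLogits now takes both type representations: the TF DataType is still needed for ctx->GetOrCreateMax / GetOrCreateAdd, which are keyed on DataType, while the xla::PrimitiveType feeds the constant helpers (xla::MinValue, xla::Zero) from constants.h. The two must describe the same element type; a purely illustrative sanity sketch, not part of this change:

    // Illustrative only: xla_type must be the XLA counterpart of type.
    xla::PrimitiveType expected;
    TF_CHECK_OK(DataTypeToPrimitiveType(type, &expected));
    CHECK_EQ(expected, xla_type);
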
@@ -189,8 +197,9 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
int64 batch_size = logits_shape.dim_size(0);
int64 depth = logits_shape.dim_size(1);
- DataType logits_type = input_type(0);
- DataType indices_type = input_type(1);
+ const DataType logits_type = input_type(0);
+ const xla::PrimitiveType xla_logits_type = ctx->input_xla_type(0);
+ const DataType indices_type = input_type(1);
xla::XlaOp indices = ctx->Input(1);
@@ -218,8 +227,8 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
labels = xla::Add(labels, nan_or_zero, {0});
xla::XlaOp loss, backprop;
- std::tie(loss, backprop) =
- CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
+ std::tie(loss, backprop) = CrossEntropyWithLogits(
+ ctx, logits_type, xla_logits_type, ctx->Input(0), labels);
ctx->SetOutput(0, loss);
ctx->SetOutput(1, backprop);
}
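
The nan_or_zero term (constructed outside this hunk) appears to implement the documented contract of SparseSoftmaxCrossEntropyWithLogits: a row whose label index falls outside [0, depth) should produce NaN loss and backprop rather than a silently wrong result, and adding a per-row NaN/0 vector to the one-hot labels, broadcast along the batch dimension, poisons exactly the invalid rows. A hypothetical sketch of such a mask; names like depth_bound and xla_indices_type are assumptions, not code from this change:

    // Hypothetical: NaN where the label index is out of range, 0 elsewhere.
    // indices: <int>[batch], result: f32[batch].
    auto in_range = xla::And(xla::Ge(indices, xla::Zero(b, xla_indices_type)),
                             xla::Lt(indices, depth_bound));
    auto nan_or_zero = xla::Select(
        in_range, xla::Broadcast(xla::Zero(b, xla_logits_type), {batch_size}),
        xla::Broadcast(xla::NanValue(b, xla_logits_type), {batch_size}));
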