Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/softmax_op.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/softmax_op.cc  152
1 file changed, 152 insertions(+), 0 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
new file mode 100644
index 0000000000..06ee520163
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -0,0 +1,152 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for softmax.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
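+// Lowers the Softmax and LogSoftmax operators to XLA reduce and element-wise
+// operations over a [batch_size, num_classes] logits matrix.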
+class SoftmaxOp : public XlaOpKernel {
+ public:
+ explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ log_ = StringPiece(type_string()).starts_with("Log");
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape logits_shape = ctx->InputShape(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
+ errors::InvalidArgument("logits must be 2-dimensional"));
+
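+    // The logits matrix is laid out as [batch_size, num_classes].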
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const DataType type = input_type(0);
+ auto logits = ctx->Input(0);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+ const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
+
+ // Find the max in each batch, resulting in a tensor of shape [batch]
+ auto logits_max =
+ b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ // Subtract the max in batch b from every element in batch b. Broadcasts
+ // along the batch dimension.
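+    // Shifting by the max leaves the softmax unchanged but keeps exp() from
+    // overflowing, giving the standard numerically stable formulation.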
+ auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+ xla::ComputationDataHandle softmax;
+ if (log_) {
+ // softmax = shifted_logits - log(sum(exp(shifted_logits)))
+ auto log_sum_exp =
+ b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type),
+ add_func, {kClassDim}));
+ softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim});
+ } else {
+ // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
+ auto exp_shifted = b->Exp(shifted_logits);
+ auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func,
+ {kClassDim});
+ softmax = b->Div(exp_shifted, sum_exp, {kBatchDim});
+ }
+
+ ctx->SetOutput(0, softmax);
+ }
+
+ private:
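+  // True when compiling the "LogSoftmax" op; false for "Softmax".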
+ bool log_;
+};
+
+REGISTER_XLA_OP("Softmax", SoftmaxOp);
+REGISTER_XLA_OP("LogSoftmax", SoftmaxOp);
+
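+// Lowers SoftmaxCrossEntropyWithLogits, producing both the per-example
+// cross-entropy loss and its gradient with respect to the logits (backprop).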
+class SoftmaxXentWithLogitsOp : public XlaOpKernel {
+ public:
+ explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape logits_shape = ctx->InputShape(0);
+ const TensorShape labels_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
+ errors::InvalidArgument(
+ "logits and labels must be same size: logits_size=",
+ logits_shape.DebugString(), " labels_size=",
+ labels_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
+ errors::InvalidArgument("logits must be 2-dimensional"));
+    // Since we have already checked that both inputs have the same shape,
+    // there is no need to also verify that "labels" is a matrix.
+
+    // The loss output is 1-D (one element per example) with size batch_size.
+
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const DataType type = input_type(0);
+ xla::ComputationBuilder* b = ctx->builder();
+ auto logits = ctx->Input(0);
+ auto labels = ctx->Input(1);
+
+ const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+ const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
+
+ // Find the max in each batch, resulting in a tensor of shape [batch]
+ auto logits_max =
+ b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+
+ // Subtract the max in batch b from every element in batch b.
+ // Broadcasts along the batch dimension.
+ auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+
+ // exp(logits - max_logits)
+ auto exp_shifted_logits = b->Exp(shifted_logits);
+
+ // sum_{class} (exp(logits - max_logits))
+ auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
+ add_func, {kClassDim});
+
+ // log(sum(exp(logits - max_logits)))
+ auto log_sum_exp = b->Log(sum_exp);
+
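+    // The per-example cross-entropy loss is
+    //   loss[b] = -sum_c labels[b,c] * log(softmax(logits)[b,c]),
+    // and log(softmax(logits)) is exactly shifted_logits - log_sum_exp, so the
+    // loss reduces to the expression below: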
+ // sum(-labels *
+ // ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
+ // along classes
+ // (The subtraction broadcasts along the batch dimension.)
+ xla::ComputationDataHandle loss =
+ b->Reduce(b->Mul(b->Neg(labels),
+ b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
+ XlaHelpers::Zero(b, type), add_func, {kClassDim});
+
+ // backprop: prob - labels, where
+ // prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
+ // (where the division broadcasts along the batch dimension)
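+    // Equivalently, d(loss)/d(logits): the gradient of softmax cross-entropy
+    // with respect to the logits is softmax(logits) - labels.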
+ xla::ComputationDataHandle backprop =
+ b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+
+ ctx->SetOutput(0, loss);
+ ctx->SetOutput(1, backprop);
+ }
+};
+
+REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp);
+
+} // namespace
+} // namespace tensorflow