aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
blob: d1c69f08b0bc85fc47c03015054dd18a65eeedec (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Ops for softmax.

#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace {

class SoftmaxOp : public XlaOpKernel {
 public:
  explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    log_ = str_util::StartsWith(type_string(), "Log");
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const DataType type = input_type(0);
    auto logits = ctx->Input(0);

    xla::XlaBuilder* const b = ctx->builder();
    const xla::XlaComputation& max_func = *ctx->GetOrCreateMax(type);

    // Find the max in each batch, resulting in a tensor of shape [batch]
    auto logits_max = xla::Reduce(logits, XlaHelpers::MinValue(b, type),
                                  max_func, {kClassDim});
    // Subtract the max in batch b from every element in batch b. Broadcasts
    // along the batch dimension.
    auto shifted_logits = xla::Sub(logits, logits_max, {kBatchDim});
    auto exp_shifted = xla::Exp(shifted_logits);
    const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);
    auto converted =
        XlaHelpers::ConvertElementType(b, exp_shifted, accumulation_type);
    auto reduce =
        xla::Reduce(converted, XlaHelpers::Zero(b, accumulation_type),
                    *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
    auto sum = XlaHelpers::ConvertElementType(b, reduce, type);
    auto softmax =
        log_
            // softmax = shifted_logits - log(sum(exp(shifted_logits)))
            ? xla::Sub(shifted_logits, xla::Log(sum), {kBatchDim})
            // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
            : xla::Div(exp_shifted, sum, {kBatchDim});
    ctx->SetOutput(0, softmax);
  }

 private:
  bool log_;
};

REGISTER_XLA_OP(Name("Softmax"), SoftmaxOp);
REGISTER_XLA_OP(Name("LogSoftmax"), SoftmaxOp);

// Builds the XLA graph for softmax cross-entropy.
//
// Both `logits` and `labels` are expected to be laid out as [batch, classes]
// (dimension 0 is broadcast against, dimension 1 is reduced over); the callers
// in this file validate that logits is 2-D before calling.
//
// Returns {loss, backprop} where
//   loss     = sum_classes(-labels * log_softmax(logits))    (one per batch row)
//   backprop = softmax(logits) - labels
// Class-dimension sums are accumulated in XlaHelpers::SumAccumulationType(type)
// and converted back to `type`.
std::pair<xla::XlaOp, xla::XlaOp> CrossEntropyWithLogits(
    XlaOpKernelContext* ctx, DataType type, const xla::XlaOp& logits,
    const xla::XlaOp& labels) {
  const int kBatchDim = 0;
  const int kClassDim = 1;

  xla::XlaBuilder* const builder = ctx->builder();
  const DataType accumulation_type = XlaHelpers::SumAccumulationType(type);

  // Sums `x` over the class dimension: widen to the accumulation type, reduce
  // with addition, then narrow back to `type`.
  auto reduce_sum_over_classes = [&](const xla::XlaOp& x) -> xla::XlaOp {
    xla::XlaOp widened =
        XlaHelpers::ConvertElementType(builder, x, accumulation_type);
    xla::XlaOp total =
        xla::Reduce(widened, XlaHelpers::Zero(builder, accumulation_type),
                    *ctx->GetOrCreateAdd(accumulation_type), {kClassDim});
    return XlaHelpers::ConvertElementType(builder, total, type);
  };

  // Row-wise maximum logit, one value per batch entry.
  xla::XlaOp row_max =
      xla::Reduce(logits, XlaHelpers::MinValue(builder, type),
                  *ctx->GetOrCreateMax(type), {kClassDim});

  // Shift every row so that its largest logit becomes 0 (numerical stability);
  // row_max is broadcast along the batch dimension.
  xla::XlaOp shifted = xla::Sub(logits, row_max, {kBatchDim});
  xla::XlaOp exp_shifted = xla::Exp(shifted);
  xla::XlaOp sum_exp = reduce_sum_over_classes(exp_shifted);
  xla::XlaOp log_sum_exp = xla::Log(sum_exp);

  // log_softmax = shifted - log(sum(exp(shifted))), broadcast along batch.
  xla::XlaOp log_softmax = xla::Sub(shifted, log_sum_exp, {kBatchDim});
  xla::XlaOp loss =
      reduce_sum_over_classes(xla::Mul(xla::Neg(labels), log_softmax));

  // Gradient w.r.t. logits: softmax - labels, where
  // softmax = exp(shifted) / sum(exp(shifted)) (division broadcast along batch).
  xla::XlaOp softmax = xla::Div(exp_shifted, sum_exp, {kBatchDim});
  return std::make_pair(loss, xla::Sub(softmax, labels));
}

class SoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
                errors::InvalidArgument(
                    "logits and labels must be same size: logits_size=",
                    logits_shape.DebugString(),
                    " labels_size=", labels_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-dimensional"));
    // As we already tested that both inputs have the same shape no need to
    // check that "labels" is a matrix too.

    const DataType type = input_type(0);
    auto logits = ctx->Input(0);
    auto labels = ctx->Input(1);

    xla::XlaOp loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, type, logits, labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SoftmaxCrossEntropyWithLogits"), SoftmaxXentWithLogitsOp);

class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel {
 public:
  explicit SparseSoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape logits_shape = ctx->InputShape(0);
    const TensorShape labels_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
                errors::InvalidArgument("logits must be 2-D, but got shape ",
                                        logits_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(labels_shape),
                errors::InvalidArgument("labels must be 1-D, but got shape ",
                                        labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(0) == labels_shape.dim_size(0),
                errors::InvalidArgument(
                    "logits and labels must have the same first dimension, "
                    "got logits shape ",
                    logits_shape.DebugString(), " and labels shape ",
                    labels_shape.DebugString()));
    OP_REQUIRES(ctx, logits_shape.dim_size(1) > 0,
                errors::InvalidArgument(
                    "Must have at least one class, but got logits shape ",
                    logits_shape.DebugString()));

    int64 batch_size = logits_shape.dim_size(0);
    int64 depth = logits_shape.dim_size(1);

    DataType logits_type = input_type(0);
    DataType indices_type = input_type(1);

    xla::XlaOp indices = ctx->Input(1);

    xla::XlaBuilder* builder = ctx->builder();
    xla::XlaOp labels;
    OP_REQUIRES_OK(ctx,
                   XlaHelpers::OneHot(
                       builder, depth, /*axis=*/1, input_type(1), labels_shape,
                       indices, XlaHelpers::One(builder, logits_type),
                       XlaHelpers::Zero(builder, logits_type), &labels));

    // If any of the indices are out of range, we must populate the labels with
    // NaNs to obey the interface contract of
    // tf.nn.sparse_softmax_cross_entropy_with_logits.
    // Builds a vector of {batch_size} that is 0 if the index is in range, or
    // NaN otherwise; then add that vector to the labels to force out-of-range
    // values to NaNs.
    xla::XlaOp nan_or_zero = xla::Select(
        xla::And(xla::Le(XlaHelpers::Zero(builder, indices_type), indices),
                 xla::Lt(indices, XlaHelpers::IntegerLiteral(
                                      builder, indices_type, depth))),
        xla::Broadcast(XlaHelpers::Zero(builder, logits_type), {batch_size}),
        xla::Broadcast(XlaHelpers::FloatLiteral(builder, logits_type, NAN),
                       {batch_size}));
    labels = xla::Add(labels, nan_or_zero, {0});

    xla::XlaOp loss, backprop;
    std::tie(loss, backprop) =
        CrossEntropyWithLogits(ctx, logits_type, ctx->Input(0), labels);
    ctx->SetOutput(0, loss);
    ctx->SetOutput(1, backprop);
  }
};

REGISTER_XLA_OP(Name("SparseSoftmaxCrossEntropyWithLogits"),
                SparseSoftmaxXentWithLogitsOp);

}  // namespace
}  // namespace tensorflow