aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/arithmetic.cc
blob: 8c314fa61bbd67774c91e7e34e93730dbe77eb8d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/arithmetic.h"

#include <array>
#include <limits>
#include <numeric>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace xla {
namespace {

using XlaOpGenerator = XlaOp (*)(XlaBuilder*, const XlaOp&, const XlaOp&);

// Builds a two-parameter scalar computation in a sub-builder of `builder`.
//
// The sub-builder is named `name` for PRED, and `name` + "_" + the type's enum
// name (e.g. "add_F32") for every other type. It declares parameters "lhs"
// (0) and "rhs" (1), both scalars of `type`, and lets `generator` emit the
// body. The generator's return value is ignored; the sub-builder presumably
// uses its last-added op as the root — NOTE(review): confirm in XlaBuilder.
// Build errors are handled by BuildAndNoteError (presumably recorded on the
// parent builder).
XlaComputation CreateScalarComputation(const string& name, PrimitiveType type,
                                       XlaBuilder* builder,
                                       XlaOpGenerator generator) {
  const string sub_name =
      type == PRED
          ? name
          : tensorflow::strings::StrCat(name, "_", PrimitiveType_Name(type));
  std::unique_ptr<XlaBuilder> sub_builder =
      builder->CreateSubBuilder(sub_name);

  const Shape scalar_shape = ShapeUtil::MakeShape(type, {});
  const XlaOp lhs_param =
      Parameter(sub_builder.get(), 0, scalar_shape, "lhs");
  const XlaOp rhs_param =
      Parameter(sub_builder.get(), 1, scalar_shape, "rhs");
  generator(sub_builder.get(), lhs_param, rhs_param);
  return sub_builder->BuildAndNoteError();
}

}  // namespace

// Returns the scalar computation (lhs, rhs) -> lhs + rhs over `type`, built in
// a sub-builder of `builder` (named "add_<TYPE>", or "add" for PRED).
XlaComputation CreateScalarAddComputation(PrimitiveType type,
                                          XlaBuilder* builder) {
  auto add_op = [](XlaBuilder* /*unused*/, const XlaOp& a, const XlaOp& c) {
    return Add(a, c);
  };
  return CreateScalarComputation("add", type, builder, add_op);
}

// Returns the scalar computation (lhs, rhs) -> lhs * rhs over `type`, built in
// a sub-builder of `builder` (named "mul_<TYPE>", or "mul" for PRED).
XlaComputation CreateScalarMultiplyComputation(PrimitiveType type,
                                               XlaBuilder* builder) {
  auto multiply_op = [](XlaBuilder* /*unused*/, const XlaOp& a,
                        const XlaOp& c) { return Mul(a, c); };
  return CreateScalarComputation("mul", type, builder, multiply_op);
}

// Returns the scalar comparison computation (lhs, rhs) -> lhs >= rhs whose
// operands are of `type`, built in a sub-builder of `builder` (named
// "ge_<TYPE>", or "ge" for PRED).
XlaComputation CreateScalarGeComputation(PrimitiveType type,
                                         XlaBuilder* builder) {
  auto greater_equal_op = [](XlaBuilder* /*unused*/, const XlaOp& a,
                             const XlaOp& c) { return Ge(a, c); };
  return CreateScalarComputation("ge", type, builder, greater_equal_op);
}

// Returns the scalar computation (lhs, rhs) -> max(lhs, rhs) over `type`,
// built in a sub-builder of `builder` (named "max_<TYPE>", or "max" for PRED).
XlaComputation CreateScalarMaxComputation(PrimitiveType type,
                                          XlaBuilder* builder) {
  auto max_op = [](XlaBuilder* /*unused*/, const XlaOp& a, const XlaOp& c) {
    return Max(a, c);
  };
  return CreateScalarComputation("max", type, builder, max_op);
}

// Returns the scalar computation (lhs, rhs) -> min(lhs, rhs) over `type`,
// built in a sub-builder of `builder` (named "min_<TYPE>", or "min" for PRED).
XlaComputation CreateScalarMinComputation(PrimitiveType type,
                                          XlaBuilder* builder) {
  auto min_op = [](XlaBuilder* /*unused*/, const XlaOp& a, const XlaOp& c) {
    return Min(a, c);
  };
  return CreateScalarComputation("min", type, builder, min_op);
}

// Returns the PRED scalar computation (lhs, rhs) -> lhs AND rhs, built in a
// sub-builder of `builder` named "and".
XlaComputation CreateScalarAndComputation(XlaBuilder* builder) {
  auto logical_and_op = [](XlaBuilder* /*unused*/, const XlaOp& a,
                           const XlaOp& c) { return And(a, c); };
  return CreateScalarComputation("and", PRED, builder, logical_and_op);
}

// Returns the PRED scalar computation (lhs, rhs) -> lhs OR rhs, built in a
// sub-builder of `builder` named "or".
XlaComputation CreateScalarOrComputation(XlaBuilder* builder) {
  auto logical_or_op = [](XlaBuilder* /*unused*/, const XlaOp& a,
                          const XlaOp& c) { return Or(a, c); };
  return CreateScalarComputation("or", PRED, builder, logical_or_op);
}

// Returns a scalar that is true iff any element of `predicates` is true.
//
// This is an OR-reduction over every dimension of `predicates`, starting from
// false. If the shape of `predicates` cannot be determined, the error goes
// through ReportErrorOrReturn (presumably recorded on the builder).
XlaOp Any(XlaOp predicates) {
  XlaBuilder* builder = predicates.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    XlaOp init_false = ConstantR0<bool>(builder, false);  // OR identity.
    XlaComputation or_reducer = CreateScalarOrComputation(builder);
    TF_ASSIGN_OR_RETURN(const Shape& input_shape,
                        builder->GetShape(predicates));

    // Reduce over dimensions 0 .. rank-1, i.e. all of them.
    const int64 rank = ShapeUtil::Rank(input_shape);
    std::vector<int64> dims_to_reduce;
    dims_to_reduce.reserve(rank);
    for (int64 dim = 0; dim < rank; ++dim) {
      dims_to_reduce.push_back(dim);
    }
    return Reduce(predicates, init_false, or_reducer, dims_to_reduce);
  });
}

namespace {
// Returns a scalar constant of element type `data_type` holding `value`.
// The constant is first created from a float and then converted, so `value`
// carries at most float precision even when `data_type` is wider.
XlaOp FloatLiteral(XlaBuilder* b, PrimitiveType data_type, float value) {
  return ConvertElementType(ConstantR0(b, value), data_type);
}

// Polynomial coefficients for computing erf/erfc, originally from cephes.
// Note we use float for compatibility across devices, at the cost of some
// precision for 64 bit computations.
//
// Coefficients are in descending order (highest degree first), which is the
// order EvaluatePolynomial expects. The denominator tables (Q, S, U) store
// their leading 1.0 coefficient explicitly.
//
// The tables are constexpr: they are read-only constants and must never be
// mutated at runtime.

// erfc for |x| < 8:   exp(-x^2) * P(|x|) / Q(|x|).
constexpr std::array<float, 9> kErfcPCoefficient = {
    2.46196981473530512524E-10, 5.64189564831068821977E-1,
    7.46321056442269912687E0,   4.86371970985681366614E1,
    1.96520832956077098242E2,   5.26445194995477358631E2,
    9.34528527171957607540E2,   1.02755188689515710272E3,
    5.57535335369399327526E2};
constexpr std::array<float, 9> kErfcQCoefficient = {
    1.00000000000000000000E0, 1.32281951154744992508E1,
    8.67072140885989742329E1, 3.54937778887819891062E2,
    9.75708501743205489753E2, 1.82390916687909736289E3,
    2.24633760818710981792E3, 1.65666309194161350182E3,
    5.57535340817727675546E2};

// erfc for |x| >= 8:  exp(-x^2) * R(|x|) / S(|x|).
constexpr std::array<float, 6> kErfcRCoefficient = {
    5.64189583547755073984E-1, 1.27536670759978104416E0,
    5.01905042251180477414E0,  6.16021097993053585195E0,
    7.40974269950448939160E0,  2.97886665372100240670E0};
constexpr std::array<float, 7> kErfcSCoefficient = {
    1.00000000000000000000E0, 2.26052863220117276590E0,
    9.39603524938001434673E0, 1.20489539808096656605E1,
    1.70814450747565897222E1, 9.60896809063285878198E0,
    3.36907645100081516050E0};

// erf:                x * T(x^2) / U(x^2).
constexpr std::array<float, 5> kErfTCoefficient = {
    9.60497373987051638749E0, 9.00260197203842689217E1,
    2.23200534594684319226E3, 7.00332514112805075473E3,
    5.55923013010394962768E4};
constexpr std::array<float, 6> kErfUCoefficient = {
    1.00000000000000000000E0, 3.35617141647503099647E1,
    5.21357949780152679795E2, 4.59432382970980127987E3,
    2.26290000613890934246E4, 4.92673942608635921086E4};
}  // namespace

// Evaluates the polynomial with the given `coefficients` at `x` using Horner's
// scheme: ((c[0] * x + c[1]) * x + c[2]) * x + ...
// N.B. Coefficients must be supplied highest-degree first. Each coefficient is
// materialized as a constant of `data_type`; an empty list yields the
// constant 0.
XlaOp EvaluatePolynomial(XlaOp x,
                         tensorflow::gtl::ArraySlice<float> coefficients,
                         PrimitiveType data_type) {
  XlaBuilder* builder = x.builder();
  // Seeding with 0 makes the first step 0 * x + c[0], i.e. c[0] for finite x.
  XlaOp accumulator = FloatLiteral(builder, data_type, 0.0);
  for (size_t i = 0; i < coefficients.size(); ++i) {
    accumulator = Add(Mul(accumulator, x),
                      FloatLiteral(builder, data_type, coefficients[i]));
  }
  return accumulator;
}

// Computes an approximation of the complementary error function
// erfc(x) = 1 - erf(x) using the cephes rational approximations
//   |x| <  8:  exp(-x^2) * P(|x|) / Q(|x|)
//   |x| >= 8:  exp(-x^2) * R(|x|) / S(|x|)
// together with the reflection erfc(x) = 2 - erfc(-x) for x < 0.
//
// `data_type` is the element type the polynomial constants are converted to.
// NOTE(review): presumably it must match x's element type — confirm at call
// sites.
//
// NOTE(review): cephes computes erfc as 1 - erf(x) for |x| < 1 and only uses
// these polynomials beyond that. This function applies them everywhere, so
// callers presumably select Erf for small |x| themselves — confirm.
XlaOp Erfc(XlaOp x, PrimitiveType data_type) {
  XlaBuilder* b = x.builder();
  XlaOp zero = FloatLiteral(b, data_type, 0.0);
  XlaOp two = FloatLiteral(b, data_type, 2.0);
  XlaOp eight = FloatLiteral(b, data_type, 8.0);

  XlaOp abs_x = Abs(x);
  XlaOp z = Exp(Mul(Neg(x), x));  // exp(-x^2)

  XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient, data_type);
  XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient, data_type);
  XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient, data_type);
  XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient, data_type);

  XlaOp y = Select(Lt(abs_x, eight), Div(Mul(z, pp), pq), Div(Mul(z, pr), ps));

  // Once exp(-x^2) underflows to zero, erfc(|x|) is below the smallest
  // representable value. However, the polynomial evaluations may already be
  // inf or NaN there (always for |x| = inf, and by overflow for huge finite
  // |x|), turning z * p / q into NaN. Return exactly 0 in that case, as cephes
  // does on underflow; the reflection below then turns it into 2 for x < 0.
  // Selecting z itself (which is 0 there) keeps the Select operand shapes
  // equal. NaN inputs still propagate because NaN != 0.
  y = Select(Eq(z, zero), z, y);

  return Select(Lt(x, zero), Sub(two, y), y);
}

// Computes the cephes rational approximation of the error function,
//   erf(x) ~= x * T(x^2) / U(x^2).
// `data_type` is the element type the polynomial constants are converted to.
// NOTE(review): presumably it must match x's element type — confirm.
//
// NOTE(review): cephes uses this form only for |x| <= 1 (and 1 - erfc(x)
// beyond). Nothing here restricts the input range, so callers presumably
// choose between Erf and Erfc themselves — confirm at call sites.
XlaOp Erf(XlaOp x, PrimitiveType data_type) {
  const XlaOp x_squared = Mul(x, x);
  const XlaOp numerator_poly =
      EvaluatePolynomial(x_squared, kErfTCoefficient, data_type);
  const XlaOp denominator_poly =
      EvaluatePolynomial(x_squared, kErfUCoefficient, data_type);
  const XlaOp scaled_numerator = Mul(x, numerator_poly);
  return Div(scaled_numerator, denominator_poly);
}

// Approximation for the inverse error function from
//   Giles, M., "Approximating the erfinv function".
// The approximation has the form:
//   w = -log((1 - x) * (1 + x))
//   if (w < 5) {
//     w = w - 2.5
//     p = sum_{i=0}^{8} lq[i] * w^(8 - i)
//   } else {
//     w = sqrt(w) - 3
//     p = sum_{i=0}^{8} gq[i] * w^(8 - i)
//   }
//   return p * x
// where lq/gq are the coefficient tables below (highest degree first),
// evaluated with Horner's scheme.
//
// Edge cases:
//  * x = +/-1 returns +/-inf. There w = +inf and the polynomial evaluates to
//    -inf, which would otherwise give the result the wrong sign.
//  * |x| > 1 gives NaN (log of a negative number).
//
// All constants are built as float and the square root uses SqrtF32.
// NOTE(review): this presumably requires x to be F32 — confirm at call sites.
XlaOp ErfInv(XlaOp x) {
  XlaBuilder* b = x.builder();
  return b->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, b->GetShape(x));
    constexpr int kDegree = 9;
    constexpr std::array<float, 9> w_less_than_5_constants = {
        2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
        -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
        -0.00417768164f,  0.246640727f,    1.50140941f};
    constexpr std::array<float, 9> w_greater_than_5_constants = {
        -0.000200214257f, 0.000100950558f, 0.00134934322f,
        -0.00367342844f,  0.00573950773f,  -0.0076224613f,
        0.00943887047f,   1.00167406f,     2.83297682f};

    auto one = ConstantR0<float>(b, 1.0);
    auto w = Neg(Log(Mul(Sub(one, x), Add(one, x))));

    auto lt = Lt(w, ConstantR0<float>(b, 5.0));
    // Select needs array operands shaped like `lt`, so each scalar
    // coefficient is broadcast to x's dimensions before choosing a branch.
    auto coefficient = [&](int i) {
      return Select(
          lt,
          Broadcast(ConstantR0<float>(b, w_less_than_5_constants[i]),
                    AsInt64Slice(shape.dimensions())),
          Broadcast(ConstantR0<float>(b, w_greater_than_5_constants[i]),
                    AsInt64Slice(shape.dimensions())));
    };
    w = Select(lt, Sub(w, ConstantR0<float>(b, 2.5f)),
               Sub(SqrtF32(w), ConstantR0<float>(b, 3.0f)));
    // Horner's scheme, highest-degree coefficient first.
    auto p = coefficient(0);
    for (int i = 1; i < kDegree; ++i) {
      p = Add(coefficient(i), Mul(p, w));
    }
    XlaOp result = Mul(p, x);

    // erfinv(+/-1) = +/-inf. The generic path above yields -inf * x there,
    // which has the wrong sign, so patch those elements explicitly.
    auto inf = ConstantR0<float>(b, std::numeric_limits<float>::infinity());
    return Select(Eq(Abs(x), one), Mul(x, inf), result);
  });
}

}  // namespace xla