aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/lib/cholesky.cc
blob: 397f0e3a7286ac46030ae602a4c059cd8aaa1ae1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/cholesky.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/lib/batch_dot.h"
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

namespace {

// The Cholesky–Banachiewicz algorithm. See
// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
// for a description.
//
// Builds, on `builder`, a loop that fills in the factor `l` one index `i` at a
// time: each iteration computes the diagonal entry l[..., i, i] and the part
// of column i below it. All leading (non-minor) dimensions of `a` are treated
// as batch dimensions.
//
// Args:
//   builder: builder that receives the generated ops.
//   a: operand whose two minor dimensions hold the matrix to factorize. Its
//      rank must be >= 2; this function does not check that itself (the
//      batch-dimension slice below is built with length rank - 2). The size of
//      the last dimension is used as `n` for both matrix dimensions, so the
//      two minor dimensions are presumably required to be equal -- the caller
//      in this file validates both conditions.
//
// Returns the second loop-carried value (`l`), or the first error reported by
// any of the helper calls.
//
// Reference implementation in NumPy notation:
//
// def cholesky_unblocked(a):
//   assert len(a.shape) >= 2 and a.shape[-2] == a.shape[-1]
//   n = a.shape[-2]
//   l = np.zeros_like(a)
//   for j in xrange(n):
//     row = l[..., j, :j]
//     row_t = np.swapaxes(row, -1, -2)
//     l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(row, row_t))
//     l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], row_t)) /
//                       l[..., j, j]
//   return l
xla::StatusOr<xla::XlaOp> CholeskyUnblocked(xla::XlaBuilder* builder,
                                            const xla::XlaOp& a) {
  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
  const int n_dims = xla::ShapeUtil::Rank(a_shape);
  // Matrix size, taken from the minor-most dimension.
  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
  // The leading n_dims - 2 dimensions, i.e. everything except the two matrix
  // dimensions. This is a view into `a_shape`, not a copy.
  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(a_shape.dimensions()),
                                    /*pos=*/0,
                                    /*len=*/n_dims - 2);

  // Initial value of the loop-carried factor: all zeros, same shape as `a`.
  xla::XlaOp l = Zeros(builder, a_shape);

  // Construct the for loop body to iterate over rows.
  // `i` is the loop index and `loop_vars` holds {a, l}. The lambda captures
  // `a_shape`, `n` and `major_dims` by reference, so it must not be invoked
  // after this function returns (presumably XlaForEachIndex only calls it
  // synchronously -- confirm in while_loop.h).
  auto body_fn = [&](xla::XlaOp i, gtl::ArraySlice<xla::XlaOp> loop_vars,
                     xla::XlaBuilder* body_builder)
      -> xla::StatusOr<std::vector<xla::XlaOp>> {
    // Shapes of a single row ([..., 1, n]) and a single column ([..., n, 1]),
    // with the same batch dimensions and element type as `a`. They are only
    // used to create the all-zero operands of the Select ops below.
    xla::Shape col_shape;
    xla::Shape row_shape;
    for (int64 d : major_dims) {
      row_shape.add_dimensions(d);
      col_shape.add_dimensions(d);
    }
    row_shape.add_dimensions(1);
    row_shape.add_dimensions(n);
    row_shape.set_element_type(a_shape.element_type());
    auto mask_zeros_row = Zeros(body_builder, row_shape);

    col_shape.add_dimensions(n);
    col_shape.add_dimensions(1);
    col_shape.set_element_type(a_shape.element_type());
    auto mask_zeros_col = Zeros(body_builder, col_shape);

    // Index ramps 0, 1, ..., n-1 laid out as a row and as a column and
    // broadcast over the batch dimensions. Comparing them against `i` yields
    // the predicates that select which entries of a row/column to keep.
    std::vector<int32> mask_vector(n);
    std::iota(mask_vector.begin(), mask_vector.end(), 0);
    auto mask_range = xla::ConstantR1<int32>(body_builder, mask_vector);
    auto mask_range_row =
        xla::Broadcast(xla::Reshape(mask_range, {0}, {1, n}), major_dims);
    auto mask_range_col =
        xla::Broadcast(xla::Reshape(mask_range, {0}, {n, 1}), major_dims);
    auto body_a = loop_vars[0];
    auto body_l = loop_vars[1];

    // row = l[..., i, :i]
    // select the whole i-th row, then mask out all columns past i-1
    auto zero = xla::ConstantR0<int32>(body_builder, 0);
    TF_ASSIGN_OR_RETURN(auto l_i, DynamicSliceInMinorDims(body_builder, body_l,
                                                          {i, zero}, {1, n}));
    // Entries whose column index is >= i are replaced by zero.
    auto row = xla::Select(xla::Ge(mask_range_row, i), mask_zeros_row, l_i);
    // a[..., i, i]
    TF_ASSIGN_OR_RETURN(auto a_ii, DynamicSliceInMinorDims(body_builder, body_a,
                                                           {i, i}, {1, 1}));
    // np.dot(row, np.swapaxes(row, -1, -2))
    xla::XlaOp diag_dot;
    TF_ASSIGN_OR_RETURN(diag_dot, BatchDot(body_builder, row, row,
                                           /*transpose_x=*/false,
                                           /*transpose_y=*/true));
    // l[..., i, i] = np.sqrt(a[..., i, i] - np.dot(row,
    //                                              np.swapaxes(row, -1, -2)))
    // The square root is expressed as raising to the power 0.5.
    auto l_ii =
        xla::Pow(xla::Sub(a_ii, diag_dot),
                 FloatLiteral(body_builder, a_shape.element_type(), 0.5));

    // a[..., i+1:, i]
    // select the whole i-th column, then mask out all rows above i+1
    TF_ASSIGN_OR_RETURN(
        auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1}));
    // Entries whose row index is <= i are replaced by zero.
    auto a_ip1i = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, a_0i);

    // l[..., i+1:, i] = (a[..., i+1:, i] - np.dot(l[..., i+1:, :i], row.T)) /
    //                   l[..., i, i]
    // The columns in [i, n) are zeroed out in `row`, so we just have to
    // zero out rows above i+1 after the BatchDot. np.dot(l[..., :, :i],
    // row.T)
    TF_ASSIGN_OR_RETURN(auto dot, BatchDot(body_builder, body_l, row,
                                           /*transpose_x=*/false,
                                           /*transpose_y=*/true));
    // np.dot(l[..., i+1:, :i], row.T)
    auto dot_ip1 = xla::Select(xla::Le(mask_range_col, i), mask_zeros_col, dot);

    // `col_update` covers all n rows of column i; for rows <= i both operands
    // of the subtraction were masked to zero above, so the numerator there is
    // zero.
    auto col_update = xla::Div(xla::Sub(a_ip1i, dot_ip1), l_ii);
    TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
                                    body_builder, body_l, col_update, {i}));
    // Assign the diagonal after the rest of the column because otherwise the
    // column assign will wrap around and overwrite the diagonal assign.
    TF_ASSIGN_OR_RETURN(body_l, DynamicUpdateSliceInMinorDims(
                                    body_builder, body_l, l_ii, {i, i}));

    // `a` is passed through unchanged; only `l` is updated.
    return std::vector<xla::XlaOp>{body_a, body_l};
  };

  // Build the loop over n indices (index type S32), carrying {a, l} as state.
  TF_ASSIGN_OR_RETURN(
      auto cholesky_while,
      XlaForEachIndex(n, xla::S32, body_fn, {a, l}, "unblocked", builder));

  // Element 1 of the loop results is the factor `l`.
  return cholesky_while[1];
}

}  // namespace

// Builds the ops that compute the lower-triangular Cholesky factor of the
// matrices stored in the two minor dimensions of `a`.
//
// The factor is assembled in column panels that are at most `block_size`
// columns wide, following the blocked left-looking scheme of Algorithm 1 in
// Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only
// execution." Proceedings of General Purpose GPUs. ACM, 2017.
//
// Returns InvalidArgument when `a` has rank < 2, when its two minor dimensions
// differ, or when `block_size` < 1. Errors from the helper routines are
// propagated unchanged. The input is presumably expected to be symmetric
// positive definite; nothing in this function verifies that.
xla::StatusOr<xla::XlaOp> Cholesky(xla::XlaBuilder* builder, xla::XlaOp a,
                                   int64 block_size) {
  TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(a));

  const int rank = xla::ShapeUtil::Rank(input_shape);
  if (rank < 2) {
    return errors::InvalidArgument(
        "Arguments to Cholesky must have rank >= 2: ", rank);
  }

  const int64 size = xla::ShapeUtil::GetDimension(input_shape, -1);
  if (xla::ShapeUtil::GetDimension(input_shape, -2) != size) {
    return errors::InvalidArgument(
        "Arguments to Cholesky must be square matrices: ",
        xla::ShapeUtil::HumanString(input_shape));
  }

  if (block_size < 1) {
    return errors::InvalidArgument(
        "block_size argument to Cholesky must be >= 1; got ", block_size);
  }

  // The factor starts out as all zeros and is filled in panel by panel.
  xla::XlaOp factor = Zeros(builder, input_shape);
  for (int64 start = 0; start < size; start += block_size) {
    // The last panel may be narrower than block_size.
    const int64 width = std::min(block_size, size - start);
    const int64 stop = start + width;

    if (start > 0) {
      // Apply the contribution of the already-factorized columns to the
      // current panel:
      //   a[start:, start:stop] -=
      //       np.dot(l[start:, :start], np.transpose(l[start:stop, :start]))
      // TODO(phawkins): consider implementing SYRK for the diagonal part of
      // the panel.
      TF_ASSIGN_OR_RETURN(
          auto left_rows,
          SliceInMinorDims(builder, factor, {start, 0}, {size, start}));
      TF_ASSIGN_OR_RETURN(
          auto left_block,
          SliceInMinorDims(builder, factor, {start, 0}, {stop, start}));
      TF_ASSIGN_OR_RETURN(
          auto correction,
          BatchDot(builder, left_rows, left_block, /*transpose_x=*/false,
                   /*transpose_y=*/true, /*conjugate_x=*/false,
                   /*conjugate_y=*/false));
      TF_ASSIGN_OR_RETURN(
          auto current,
          SliceInMinorDims(builder, a, {start, start}, {size, stop}));
      TF_ASSIGN_OR_RETURN(
          a, UpdateSliceInMinorDims(builder, a, xla::Sub(current, correction),
                                    {start, start}));
    }

    // Factorize the diagonal block with the unblocked algorithm:
    //   l[start:stop, start:stop] = cholesky_unblocked(a[start:stop, start:stop])
    TF_ASSIGN_OR_RETURN(
        auto diag_block,
        SliceInMinorDims(builder, a, {start, start}, {stop, stop}));
    TF_ASSIGN_OR_RETURN(auto diag_factor,
                        CholeskyUnblocked(builder, diag_block));
    TF_ASSIGN_OR_RETURN(
        factor,
        UpdateSliceInMinorDims(builder, factor, diag_factor, {start, start}));

    if (stop < size) {
      // Solve for the rows below the diagonal block:
      //   l[stop:, start:stop] =
      //       trsm_right_transpose(l[start:stop, start:stop],
      //                            a[stop:, start:stop])
      TF_ASSIGN_OR_RETURN(
          auto below, SliceInMinorDims(builder, a, {stop, start}, {size, stop}));
      TF_ASSIGN_OR_RETURN(auto solved,
                          TriangularSolve(builder, diag_factor, below,
                                          /*left_side=*/false,
                                          /*lower=*/true,
                                          /*transpose_a=*/true,
                                          /*conjugate_a=*/false,
                                          /*block_size=*/block_size));
      TF_ASSIGN_OR_RETURN(
          factor,
          UpdateSliceInMinorDims(builder, factor, solved, {stop, start}));
    }
  }
  return factor;
}

}  // namespace tensorflow