aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/client_library_test_base.h
blob: 7cfc276ec19e3b177f87a08e716cb34b7676dd6b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
#define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_

#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bitmap.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// A client library test establishes an in-process XLA client connection.
//
// Test fixtures derive from this class to get `client_`, the member
// `execution_options_`, and helpers for building, executing and checking XLA
// computations.
class ClientLibraryTestBase : public ::testing::Test {
 protected:
  // Creates a test base for `platform`. `platform` defaults to null; how a
  // null platform is resolved is decided by the out-of-line definition.
  explicit ClientLibraryTestBase(
      perftools::gputools::Platform* platform = nullptr);

  // Creates a new ClientLibraryTestBase with custom client options.
  ClientLibraryTestBase(perftools::gputools::Platform* platform,
                        const LocalClientOptions& client_options);

  // Returns the name of the test currently being run.
  string TestName() const;

  // Turns fast math off (disabled == true) or on (disabled == false) in the
  // debug options held by execution_options_.
  void SetFastMathDisabled(bool disabled) {
    execution_options_.mutable_debug_options()->set_xla_enable_fast_math(
        !disabled);
  }

  // Sets the seed field of execution_options_.
  void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }

  // Provides mutable access to the execution DebugOptions field; this lets
  // tests tweak the options that will be used to compile/run the graph.
  DebugOptions* mutable_debug_options() {
    return execution_options_.mutable_debug_options();
  }

  // TODO(b/25566808): Add helper that populates a literal from a testdata file.

  // Convenience methods for building and running a computation with the member
  // execution options. Modify execution_options_ in your test if you want to
  // customize the options. When non-null, `shape_with_output_layout` is the
  // result layout to request.
  StatusOr<std::unique_ptr<GlobalData>> Execute(
      ComputationBuilder* builder,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
      ComputationBuilder* builder,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
      const Shape* shape_with_output_layout = nullptr);
  StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
      const Computation& computation,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
      const Shape* shape_with_output_layout = nullptr);

  // Convenience OrDie variants of above methods.
  std::unique_ptr<GlobalData> ExecuteOrDie(
      ComputationBuilder* builder,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  std::unique_ptr<Literal> ExecuteAndTransferOrDie(
      ComputationBuilder* builder,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments);

  // Runs a computation and returns its value as a string. If an error occurs,
  // the error is returned as a string instead.
  string ExecuteToString(ComputationBuilder* builder,
                         tensorflow::gtl::ArraySlice<GlobalData*> arguments);

  // Convenience methods for building and running a computation, transferring
  // the result, and comparing it to the expected value(s). Methods are
  // templated on the native host type which maps to specific XLA types (See
  // ComputationBuilder for details). For each rank, two forms are provided:
  // one that takes an ErrorSpec (the comparison tolerance) and is restricted by
  // a static_assert in its definition to floating point (and, where that
  // assert allows it, complex) types, and one without the ErrorSpec parameter.
  template <typename NativeT>
  void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  template <typename NativeT>
  void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                           ErrorSpec error);

  template <typename NativeT>
  void ComputeAndCompareR1(ComputationBuilder* builder,
                           tensorflow::gtl::ArraySlice<NativeT> expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  template <typename NativeT>
  void ComputeAndCompareR1(ComputationBuilder* builder,
                           tensorflow::gtl::ArraySlice<NativeT> expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                           ErrorSpec error);

  // As above, but uses a bitmap to hold the predicate vector to avoid
  // deficiencies of vector<bool>.
  void ComputeAndCompareR1(ComputationBuilder* builder,
                           const tensorflow::core::Bitmap& expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments);

  template <typename NativeT>
  void ComputeAndCompareR2(ComputationBuilder* builder,
                           const Array2D<NativeT>& expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  template <typename NativeT>
  void ComputeAndCompareR2(ComputationBuilder* builder,
                           const Array2D<NativeT>& expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                           ErrorSpec error);

  template <typename NativeT>
  void ComputeAndCompareR3(ComputationBuilder* builder,
                           const Array3D<NativeT>& expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  template <typename NativeT>
  void ComputeAndCompareR3(ComputationBuilder* builder,
                           const Array3D<NativeT>& expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                           ErrorSpec error);

  template <typename NativeT>
  void ComputeAndCompareR4(ComputationBuilder* builder,
                           const Array4D<NativeT>& expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  template <typename NativeT>
  void ComputeAndCompareR4(ComputationBuilder* builder,
                           const Array4D<NativeT>& expected,
                           tensorflow::gtl::ArraySlice<GlobalData*> arguments,
                           ErrorSpec error);

  // Builds and runs the computation and compares the result with the given
  // literal. shape_with_layout indicates the result layout to request when
  // calling Execute.
  void ComputeAndCompareLiteral(
      ComputationBuilder* builder, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
      const Shape* shape_with_layout = nullptr);
  void ComputeAndCompareLiteral(
      ComputationBuilder* builder, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
      const Shape* shape_with_layout = nullptr);

  // ComputeAndCompareLiteral variants which return an error status.
  tensorflow::Status ComputeAndCompareLiteralWithStatus(
      ComputationBuilder* builder, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
      const Shape* shape_with_layout = nullptr);
  tensorflow::Status ComputeAndCompareLiteralWithStatus(
      ComputationBuilder* builder, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
      const Shape* shape_with_layout = nullptr);

  // Compares the result of the computation to a string. In XLA, strings are
  // represented using rank-1 U8 shapes.
  void ComputeAndCompareR1U8(
      ComputationBuilder* builder, tensorflow::StringPiece expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments);

  // Convenience method for running a built computation, transferring the
  // result, and comparing it to the expected tuple literal.
  void ComputeAndCompareTuple(
      ComputationBuilder* builder, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments);
  void ComputeAndCompareTuple(
      ComputationBuilder* builder, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);

  // Creates scalar operations for use in reductions.
  Computation CreateScalarRelu();
  Computation CreateScalarMax();
  Computation CreateScalarReluSensitivity();

  // Special case convenience functions for creating filled arrays.

  // Creates an array of pseudorandom values lying between the given minimum and
  // maximum values. `seed` is passed to the underlying generator.
  template <typename NativeT>
  std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
                                            NativeT max_value, uint32 seed);
  template <typename NativeT>
  std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
                                                         const int cols,
                                                         NativeT min_value,
                                                         NativeT max_value,
                                                         uint32 seed);

  // Creates a (rows x cols) array filled in the following form:
  //
  //  [      0              1 ...                   cols-1]
  //  [  1,000          1,001 ...          1000.0 + cols-1]
  //  [    ...            ... ...                      ...]
  //  [(rows-1)*1000.0    ... ... (rows-1)*1000.0 + cols-1]
  //
  // If provided, offset is added uniformly to every element (e.g. an offset of
  // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
  std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
                                                        const int cols,
                                                        float offset = 0.0);

  // Creates a (rows x cols) array as above, padded out to
  // (rows_padded x cols_padded) with zeroes.  Requires rows_padded >= rows
  // and cols_padded > cols.
  // NOTE(review): the strict '>' for cols is asymmetric with the '>=' for
  // rows; confirm against the checks in the out-of-line definition.
  std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
      const int rows, const int cols, const int rows_padded,
      const int cols_padded);

  // Transfers `value` to the server through `client_` and returns the
  // resulting GlobalData. Also adds a parameter instruction to `builder` that
  // has the shape of `value`'s literal, and writes that parameter's handle
  // into `*data_handle`.
  //
  // "parameter_number" is the parameter number.
  // "name" is the name of the parameter instruction.
  template <typename NativeT>
  std::unique_ptr<GlobalData> CreateR0Parameter(
      NativeT value, int64 parameter_number, const string& name,
      ComputationBuilder* builder, ComputationDataHandle* data_handle);

  // Same as CreateR0Parameter, but for the rank-1 literal built from
  // `values`.
  //
  // "parameter_number" is the parameter number.
  // "name" is the name of the parameter instruction.
  template <typename NativeT>
  std::unique_ptr<GlobalData> CreateR1Parameter(
      tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
      const string& name, ComputationBuilder* builder,
      ComputationDataHandle* data_handle);

  // Same as CreateR0Parameter, but for the rank-2 literal built from the
  // constant array "array_2d".
  //
  // "parameter_number" is the parameter number.
  // "name" is the name of the parameter instruction.
  template <typename NativeT>
  std::unique_ptr<GlobalData> CreateR2Parameter(
      const Array2D<NativeT>& array_2d, int64 parameter_number,
      const string& name, ComputationBuilder* builder,
      ComputationDataHandle* data_handle);

  // Same as CreateR0Parameter, but for the rank-3 literal built from the
  // constant array "array_3d".
  //
  // "parameter_number" is the parameter number.
  // "name" is the name of the parameter instruction.
  template <typename NativeT>
  std::unique_ptr<GlobalData> CreateR3Parameter(
      const Array3D<NativeT>& array_3d, int64 parameter_number,
      const string& name, ComputationBuilder* builder,
      ComputationDataHandle* data_handle);

  // Client used to transfer data and run computations (e.g. TransferToServer
  // in the Create*Parameter helpers). Null-initialized here so it is never
  // indeterminate; the constructors are expected to set it.
  Client* client_ = nullptr;
  // Options used when executing computations; tests may modify them directly
  // or through the setters above.
  ExecutionOptions execution_options_;

 private:
  // Build and run the computation with all permutations of output layouts.
  tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
      const xla::Computation& computation, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
      const std::function<void(const Literal& actual,
                               const string& error_message)>& verify_output);
  // Build and run the computation with all permutations of layouts of all input
  // arguments.
  tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts(
      const xla::Computation& computation, const Literal& expected,
      tensorflow::gtl::ArraySlice<GlobalData*> arguments,
      const std::function<void(const Literal& actual,
                               const string& error_message)>& verify_output,
      const Shape* output_with_layout = nullptr);
};

// Builds and runs the computation held by `builder`, then compares its scalar
// result against `expected` (no ErrorSpec is supplied).
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
    ComputationBuilder* builder, NativeT expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
  // The literal is a temporary destroyed only at the end of the
  // full-expression, so dereferencing it inline is safe.
  ComputeAndCompareLiteral(builder, *Literal::CreateR0<NativeT>(expected),
                           arguments);
}

// Builds and runs the computation held by `builder`, then compares its scalar
// result against `expected` using the tolerance described by `error`.
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
    ComputationBuilder* builder, NativeT expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
  // An ErrorSpec only makes sense for inexact types. complex64 is accepted to
  // match the rank-2/3/4 ErrorSpec overloads of ComputeAndCompareR*.
  static_assert(std::is_same<NativeT, float>::value ||
                    std::is_same<NativeT, double>::value ||
                    std::is_same<NativeT, complex64>::value,
                "Float or complex type required when specifying an ErrorSpec");
  std::unique_ptr<Literal> expected_literal =
      Literal::CreateR0<NativeT>(expected);
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
                                                  arguments, error);
}

// Builds and runs the computation held by `builder`, then compares its rank-1
// result against `expected` (no ErrorSpec is supplied).
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
    ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
  // The literal is a temporary destroyed only at the end of the
  // full-expression, so dereferencing it inline is safe.
  ComputeAndCompareLiteral(builder, *Literal::CreateR1<NativeT>(expected),
                           arguments);
}

// Builds and runs the computation held by `builder`, then compares its rank-1
// result against `expected` using the tolerance described by `error`.
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
    ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
  // An ErrorSpec only makes sense for inexact types. complex64 is accepted to
  // match the rank-2/3/4 ErrorSpec overloads of ComputeAndCompareR*.
  static_assert(std::is_same<NativeT, float>::value ||
                    std::is_same<NativeT, double>::value ||
                    std::is_same<NativeT, complex64>::value,
                "Float or complex type required when specifying an ErrorSpec");
  std::unique_ptr<Literal> expected_literal =
      Literal::CreateR1<NativeT>(expected);
  ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
                                                  arguments, error);
}

// Builds and runs the computation held by `builder`, then compares its rank-2
// result against `expected` (no ErrorSpec is supplied).
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
    ComputationBuilder* builder, const Array2D<NativeT>& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
  auto literal = Literal::CreateR2FromArray2D<NativeT>(expected);
  this->ComputeAndCompareLiteral(builder, *literal, arguments);
}

// Builds and runs the computation held by `builder`, then compares its rank-2
// result against `expected` using the tolerance described by `error`.
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
    ComputationBuilder* builder, const Array2D<NativeT>& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
  constexpr bool kIsFloating = std::is_same<NativeT, float>::value ||
                               std::is_same<NativeT, double>::value;
  constexpr bool kIsComplex = std::is_same<NativeT, complex64>::value;
  static_assert(kIsFloating || kIsComplex,
                "Float or complex type required when specifying an ErrorSpec");
  auto literal = Literal::CreateR2FromArray2D<NativeT>(expected);
  this->ComputeAndCompareLiteral(builder, *literal, arguments, error);
}

// Builds and runs the computation held by `builder`, then compares its rank-3
// result against `expected` (no ErrorSpec is supplied).
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
    ComputationBuilder* builder, const Array3D<NativeT>& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
  auto literal = Literal::CreateR3FromArray3D<NativeT>(expected);
  this->ComputeAndCompareLiteral(builder, *literal, arguments);
}

// Builds and runs the computation held by `builder`, then compares its rank-3
// result against `expected` using the tolerance described by `error`.
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
    ComputationBuilder* builder, const Array3D<NativeT>& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
  constexpr bool kIsFloating = std::is_same<NativeT, float>::value ||
                               std::is_same<NativeT, double>::value;
  constexpr bool kIsComplex = std::is_same<NativeT, complex64>::value;
  static_assert(kIsFloating || kIsComplex,
                "Float or complex type required when specifying an ErrorSpec");
  auto literal = Literal::CreateR3FromArray3D<NativeT>(expected);
  this->ComputeAndCompareLiteral(builder, *literal, arguments, error);
}

// Builds and runs the computation held by `builder`, then compares its rank-4
// result against `expected` (no ErrorSpec is supplied).
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
    ComputationBuilder* builder, const Array4D<NativeT>& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
  auto literal = Literal::CreateR4FromArray4D<NativeT>(expected);
  this->ComputeAndCompareLiteral(builder, *literal, arguments);
}

// Builds and runs the computation held by `builder`, then compares its rank-4
// result against `expected` using the tolerance described by `error`.
template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
    ComputationBuilder* builder, const Array4D<NativeT>& expected,
    tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
  constexpr bool kIsFloating = std::is_same<NativeT, float>::value ||
                               std::is_same<NativeT, double>::value;
  constexpr bool kIsComplex = std::is_same<NativeT, complex64>::value;
  static_assert(kIsFloating || kIsComplex,
                "Float or complex type required when specifying an ErrorSpec");
  auto literal = Literal::CreateR4FromArray4D<NativeT>(expected);
  this->ComputeAndCompareLiteral(builder, *literal, arguments, error);
}

template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
    NativeT value, int64 parameter_number, const string& name,
    ComputationBuilder* builder, ComputationDataHandle* data_handle) {
  std::unique_ptr<Literal> literal = Literal::CreateR0(value);
  std::unique_ptr<GlobalData> data =
      client_->TransferToServer(*literal).ConsumeValueOrDie();
  *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
  return data;
}

// Transfers the rank-1 literal built from `values` to the server and adds a
// parameter of the same shape to `builder`. The parameter's handle is written
// to `*data_handle`; the transferred data is returned (presumably to be passed
// as the matching execution argument).
template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
    tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
    const string& name, ComputationBuilder* builder,
    ComputationDataHandle* data_handle) {
  const auto literal = Literal::CreateR1(values);
  auto transferred = client_->TransferToServer(*literal).ConsumeValueOrDie();
  const Shape& parameter_shape = literal->shape();
  *data_handle = builder->Parameter(parameter_number, parameter_shape, name);
  return transferred;
}

template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
    const Array2D<NativeT>& array_2d, int64 parameter_number,
    const string& name, ComputationBuilder* builder,
    ComputationDataHandle* data_handle) {
  std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
  std::unique_ptr<GlobalData> data =
      client_->TransferToServer(*literal).ConsumeValueOrDie();
  *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
  return data;
}

template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
    const Array3D<NativeT>& array_3d, int64 parameter_number,
    const string& name, ComputationBuilder* builder,
    ComputationDataHandle* data_handle) {
  std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
  std::unique_ptr<GlobalData> data =
      client_->TransferToServer(*literal).ConsumeValueOrDie();
  *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
  return data;
}

template <typename NativeT>
std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
    const int width, NativeT min_value, NativeT max_value, uint32 seed) {
  std::vector<NativeT> result(width);
  test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
                                                       seed);
  for (int i = 0; i < width; ++i) {
    result[i] = generator.get();
  }
  return result;
}

// Returns a (rows x cols) array whose elements are drawn from a
// test_utils::PseudorandomGenerator constructed with (min_value, max_value,
// seed).
template <typename NativeT>
std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
    const int rows, const int cols, NativeT min_value, NativeT max_value,
    uint32 seed) {
  auto matrix = MakeUnique<Array2D<NativeT>>(rows, cols);
  test_utils::PseudorandomGenerator<NativeT> rng(min_value, max_value, seed);
  // Fill row by row so the draw order (and therefore the values produced for a
  // given seed) follows row-major element order.
  for (int row = 0; row < rows; ++row) {
    for (int col = 0; col < cols; ++col) {
      (*matrix)(row, col) = rng.get();
    }
  }
  return matrix;
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_