aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/python/local_computation_builder.cc
blob: b5ba4e2d429e465649fc1b7acaf19fcb75f6d1ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/local_computation_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {
namespace swig {

// Process-wide state backing the single local XLA client used by this module.
// g_replica_count and g_local_client are both GUARDED_BY
// g_local_client_mutex.
//
// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of
// device handles instead of needing to set the number of replicas at XLA
// service initialization time.
tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED);
// Number of replicas the client is created with. Starts at 1; this file only
// changes it in InitializeReplicaCount(), which refuses once g_local_client
// has been created.
int g_replica_count GUARDED_BY(g_local_client_mutex) = 1;
// Lazily created by GetOrCreateLocalClient(); nullptr until then. Nothing in
// this file deletes it.
LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr;

// Sets the number of replicas the local XLA client will be created with.
//
// Returns InvalidArgument for a count below 1, and FailedPrecondition if the
// client already exists (the count is fixed at client creation time).
Status InitializeReplicaCount(int replica_count) {
  const bool count_is_valid = replica_count >= 1;
  if (!count_is_valid) {
    return InvalidArgument("Replica count must be >= 1; got %d.",
                           replica_count);
  }

  tensorflow::mutex_lock guard(g_local_client_mutex);
  const bool client_already_built = (g_local_client != nullptr);
  if (client_already_built) {
    return FailedPrecondition(
        "Attempted to set the replica count to %d, but a local XLA service was "
        "previously created with a replica count of %d.",
        replica_count, g_replica_count);
  }
  g_replica_count = replica_count;
  return Status::OK();
}

// Returns the currently configured replica count, read under the global lock.
int GetReplicaCount() {
  int count;
  {
    tensorflow::mutex_lock guard(g_local_client_mutex);
    count = g_replica_count;
  }
  return count;
}

// Returns the process-wide LocalClient, creating it on first use with the
// configured replica count. Creation failure aborts (ValueOrDie / CHECK).
LocalClient* GetOrCreateLocalClient() {
  tensorflow::mutex_lock guard(g_local_client_mutex);
  if (g_local_client == nullptr) {
    LocalClientOptions client_options;
    client_options.set_number_of_replicas(g_replica_count);
    g_local_client =
        ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie();
    CHECK(g_local_client != nullptr);
  }
  return g_local_client;
}

// Transfers `literal` to the infeed of device ordinal 0, without going through
// a replica-number-to-device mapping.
Status TransferToInfeedLocal(const Literal& literal) {
  VLOG(1) << "Infeeding literal without replica number; shape: "
          << literal.shape();
  constexpr int kDefaultDeviceOrdinal = 0;
  return GetOrCreateLocalClient()->TransferToInfeedLocal(
      literal, /*device_ordinal=*/kDefaultDeviceOrdinal);
}

// Transfers `literal` to the infeed of the device that `replica_number` maps
// to. Fails if the client cannot map the replica number to a device ordinal.
Status TransferToInfeedLocalReplica(const Literal& literal,
                                    int replica_number) {
  VLOG(1) << "Infeeding shape " << literal.shape()
          << " to replica number: " << replica_number;
  LocalClient* const local_client = GetOrCreateLocalClient();
  TF_ASSIGN_OR_RETURN(
      const int ordinal,
      local_client->ReplicaNumberToDeviceOrdinal(replica_number));
  return local_client->TransferToInfeedLocal(literal, ordinal);
}

// Reads a literal of `shape` from the outfeed of the device that
// `replica_number` maps to.
StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
    const Shape& shape, int replica_number) {
  VLOG(1) << "Outfeeding literal from replica number: " << replica_number
          << " shape: " << shape;
  LocalClient* const local_client = GetOrCreateLocalClient();
  TF_ASSIGN_OR_RETURN(
      const int ordinal,
      local_client->ReplicaNumberToDeviceOrdinal(replica_number));
  return local_client->TransferFromOutfeedLocal(shape, ordinal);
}

LocalShapedBuffer::LocalShapedBuffer(ScopedShapedBuffer shaped_buffer)
    : shaped_buffer_(std::move(shaped_buffer)) {}

// Returns a non-owning pointer to the wrapped buffer; it stays valid only as
// long as this object does.
const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const {
  const ScopedShapedBuffer& wrapped = shaped_buffer_;
  return &wrapped;
}

// Calls release() on the wrapped ScopedShapedBuffer and returns the resulting
// unscoped ShapedBuffer. The caller becomes responsible for its device memory
// (DestructureLocalShapedBufferTuple, for example, deallocates the root
// buffer itself).
ShapedBuffer LocalShapedBuffer::Release() {
  return shaped_buffer_.release();
}

// Takes ownership of `elements`. Each must be non-null; this is checked in
// debug builds only (DCHECK).
LocalShapedBufferTuple::LocalShapedBufferTuple(
    std::vector<LocalShapedBuffer*> elements)
    : elements_(std::move(elements)) {
  for (size_t idx = 0; idx < elements_.size(); ++idx) {
    DCHECK(elements_[idx] != nullptr);
  }
}

// Deletes every element that has not been handed out via Release().
LocalShapedBufferTuple::~LocalShapedBufferTuple() {
  for (auto it = elements_.begin(); it != elements_.end(); ++it) {
    // Released slots hold nullptr, and deleting nullptr is a no-op.
    delete *it;
  }
}

// Transfers ownership of element `i` to the caller and clears its slot so the
// destructor will not delete it.
//
// Returns InvalidArgument if `i` is outside [0, size()) or if the element was
// already released. `i` typically comes straight from Python, so it is
// range-checked rather than trusted.
StatusOr<LocalShapedBuffer*> LocalShapedBufferTuple::Release(int i) {
  if (i < 0 || i >= size()) {
    return InvalidArgument(
        "Attempted to release out-of-range element %d; tuple has %d elements.",
        i, size());
  }
  LocalShapedBuffer* element = elements_[i];
  if (element == nullptr) {
    return InvalidArgument("Attempted to release already-released element %d.",
                           i);
  }
  elements_[i] = nullptr;
  return element;
}

// Number of slots in the tuple, including slots already released (which hold
// nullptr).
int LocalShapedBufferTuple::size() const {
  return static_cast<int>(elements_.size());
}

// Copies `arg` to device `device_ordinal` using the client's backend memory
// allocator and returns the resulting device buffer.
static StatusOr<ScopedShapedBuffer> ToBuffer(LocalClient* client,
                                             int device_ordinal,
                                             const Literal& arg) {
  auto* const allocator = client->backend().memory_allocator();
  return client->LiteralToShapedBuffer(arg, device_ordinal, allocator);
}

// Copies `argument` to device 0 and wraps the resulting device buffer. When
// `shape_with_layout` is set, the literal is first relaid out to that shape.
// Returns a heap-allocated LocalShapedBuffer that the caller owns.
/* static */
StatusOr<LocalShapedBuffer*> LocalShapedBuffer::FromLiteral(
    const Literal& argument,
    const tensorflow::gtl::optional<Shape>& shape_with_layout) {
  LocalClient* client = GetOrCreateLocalClient();

  // Point either at the caller's literal or at a relaid-out copy of it.
  std::unique_ptr<Literal> relaid;
  const Literal* source = &argument;
  if (shape_with_layout) {
    relaid = argument.Relayout(shape_with_layout.value());
    source = relaid.get();
  }

  StatusOr<ScopedShapedBuffer> buf =
      ToBuffer(client, /*device_ordinal=*/0, *source);
  TF_RETURN_IF_ERROR(buf.status());
  return new LocalShapedBuffer(std::move(buf).ValueOrDie());
}

// Copies the wrapped device buffer back to the host as a Literal.
StatusOr<std::unique_ptr<Literal>> LocalShapedBuffer::ToLiteral() const {
  const ScopedShapedBuffer& device_buffer = *shaped_buffer();
  return GetOrCreateLocalClient()->ShapedBufferToLiteral(device_buffer);
}

CompiledLocalComputation::CompiledLocalComputation(
    std::unique_ptr<LocalExecutable> executable)
    : executable_(std::move(executable)) {}

// Runs the compiled executable once per replica, in parallel, and returns the
// result literal of replica 0.
//
// Each replica gets its own device copy of `arguments`. Argument i is first
// relaid out to `shapes_with_layout[i]` when that optional is set, so
// `shapes_with_layout` must have exactly one entry per argument (otherwise
// InvalidArgument is returned). If any replica fails, an InternalError naming
// the lowest-numbered failing replica is returned.
StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
    const std::vector<Literal>& arguments,
    const std::vector<tensorflow::gtl::optional<Shape>>& shapes_with_layout) {
  LocalClient* client = GetOrCreateLocalClient();

  // The client now exists, so InitializeReplicaCount() can no longer change
  // the count. Read it once instead of re-taking the global lock everywhere.
  const int replica_count = GetReplicaCount();

  VLOG(1) << "Execution requested with " << replica_count << " replicas.";

  // shapes_with_layout[i] is read for every argument i below, so the two
  // vectors must be parallel.
  if (shapes_with_layout.size() != arguments.size()) {
    return InvalidArgument(
        "Execute received %d arguments but %d argument layouts; the counts "
        "must match.",
        static_cast<int>(arguments.size()),
        static_cast<int>(shapes_with_layout.size()));
  }

  // Every replica runs with the same device assignment. Computing it once here
  // means a placement failure is returned as an error instead of aborting the
  // process inside a worker thread. The workers only read it.
  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      client->backend().computation_placer()->AssignDevices(
                          replica_count, /*computation_count=*/1));

  // Each replica populates a StatusOr result, but only replica zero's literal
  // is returned to the caller.
  std::vector<StatusOr<std::unique_ptr<Literal>>> results(replica_count);
  {
    // ThreadPool's destructor waits for all scheduled work to finish, so
    // `results` is fully populated once this scope ends.
    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
                                        replica_count);

    for (int replica = 0; replica < replica_count; ++replica) {
      pool.Schedule([this, client, replica, &arguments, &shapes_with_layout,
                     &device_assignment, &results] {
        StatusOr<int> device_ordinal_status =
            client->ReplicaNumberToDeviceOrdinal(replica);
        if (!device_ordinal_status.ok()) {
          results[replica] = device_ordinal_status.status();
          return;
        }
        const int device_ordinal = device_ordinal_status.ValueOrDie();
        VLOG(3) << "Replica " << replica
                << " mapped to device ordinal for execution: "
                << device_ordinal;

        // Transfer arguments in.
        std::vector<ScopedShapedBuffer> scoped_buffers;
        scoped_buffers.reserve(arguments.size());
        for (size_t i = 0; i < arguments.size(); ++i) {
          const Literal& argument = arguments[i];
          const tensorflow::gtl::optional<Shape>& shape_with_layout =
              shapes_with_layout[i];

          StatusOr<ScopedShapedBuffer> pushed;
          if (shape_with_layout) {
            std::unique_ptr<Literal> relaid =
                argument.Relayout(shape_with_layout.value());
            pushed = ToBuffer(client, device_ordinal, *relaid);
          } else {
            pushed = ToBuffer(client, device_ordinal, argument);
          }
          if (!pushed.ok()) {
            results[replica] = pushed.status();
            return;
          }

          scoped_buffers.push_back(std::move(pushed).ValueOrDie());
        }

        // Execute.
        std::vector<const ShapedBuffer*> argument_buffers;
        argument_buffers.reserve(scoped_buffers.size());
        for (auto& buffer : scoped_buffers) {
          argument_buffers.push_back(&buffer);
        }

        ExecutableRunOptions options;
        options.set_device_ordinal(device_ordinal);
        options.set_allocator(client->backend().memory_allocator());
        options.set_intra_op_thread_pool(
            client->backend().eigen_intra_op_thread_pool_device());
        options.set_device_assignment(&device_assignment);
        StatusOr<ScopedShapedBuffer> result_buffer_status =
            executable_->Run(argument_buffers, options);
        if (!result_buffer_status.ok()) {
          results[replica] = result_buffer_status.status();
          return;
        }

        // Transfer the result out.
        results[replica] = client->ShapedBufferToLiteral(
            std::move(result_buffer_status).ValueOrDie());
      });
    }
  }

  for (int replica = 0; replica < replica_count; ++replica) {
    const auto& statusor = results[replica];
    if (!statusor.ok()) {
      return InternalError(
          "Failed running replica %d (other replicas may have failed as well): "
          "%s.",
          replica, statusor.status().ToString().c_str());
    }
  }

  return std::move(results[0]);
}

// Executes once on the device buffers already held by `argument_handles`, so
// no host transfers are involved, and returns a heap-allocated
// LocalShapedBuffer owning the result. Unlike Execute(), the run options set
// no device ordinal or device assignment. Execution failure aborts the
// process (ConsumeValueOrDie).
LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
    tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles) {
  LocalClient* client = GetOrCreateLocalClient();

  std::vector<const ShapedBuffer*> argument_buffers(argument_handles.size());
  for (size_t idx = 0; idx < argument_handles.size(); ++idx) {
    argument_buffers[idx] = argument_handles[idx]->shaped_buffer();
  }

  ExecutableRunOptions run_options;
  run_options.set_allocator(client->backend().memory_allocator());
  run_options.set_intra_op_thread_pool(
      client->backend().eigen_intra_op_thread_pool_device());

  return new LocalShapedBuffer(
      executable_->Run(argument_buffers, run_options).ConsumeValueOrDie());
}

LocalComputation::LocalComputation(XlaComputation computation)
    : computation_(std::move(computation)) {}

// Compiles the computation for `argument_shapes`. Uses `build_options` when it
// is non-null and default-constructed ExecutableBuildOptions otherwise.
// Returns a heap-allocated CompiledLocalComputation that the caller owns.
StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
    const std::vector<Shape>& argument_shapes,
    const ExecutableBuildOptions* build_options) {
  // LocalClient::Compile takes the shapes by pointer.
  std::vector<const Shape*> argument_shape_pointers;
  argument_shape_pointers.reserve(argument_shapes.size());
  for (const Shape& shape : argument_shapes) {
    argument_shape_pointers.push_back(&shape);
  }

  LocalClient* client = GetOrCreateLocalClient();
  const ExecutableBuildOptions options =
      (build_options == nullptr) ? ExecutableBuildOptions() : *build_options;
  TF_ASSIGN_OR_RETURN(
      auto compiled,
      client->Compile(computation_, argument_shape_pointers, options));
  return new CompiledLocalComputation(std::move(compiled));
}

// Read-only access to the wrapped XlaComputation.
const XlaComputation& LocalComputation::computation() const {
  return this->computation_;
}

// Serializes the computation's proto to a binary string. On failure, logs an
// error and returns an empty string.
string LocalComputation::GetSerializedProto() const {
  string serialized;
  const bool ok = computation_.proto().SerializeToString(&serialized);
  if (ok) {
    return serialized;
  }
  LOG(ERROR) << "Failed to serialize the HloModuleProto.";
  return "";
}

// Returns the result shape of the computation's program shape, moved out of
// the temporary ProgramShape.
StatusOr<Shape> LocalComputation::GetReturnValueShape() const {
  TF_ASSIGN_OR_RETURN(ProgramShape signature, computation_.GetProgramShape());
  Shape* result_shape = signature.mutable_result();
  return std::move(*result_shape);
}

LocalOp::LocalOp(const XlaOp& op) : op_(op) {}

// Returns the wrapped XlaOp handle.
const XlaOp& LocalOp::op() const {
  return this->op_;
}

LocalComputationBuilder::LocalComputationBuilder(const string& computation_name)
    : builder_(computation_name) {}

// Sets the OpMetadata on the underlying builder (XlaBuilder::SetOpMetadata).
void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
  auto& b = builder_;
  b.SetOpMetadata(metadata);
}

// Clears the OpMetadata on the underlying builder
// (XlaBuilder::ClearOpMetadata).
void LocalComputationBuilder::ClearOpMetadata() {
  auto& b = builder_;
  b.ClearOpMetadata();
}

// Finishes building and returns the computation in a heap-allocated
// LocalComputation that the caller owns.
StatusOr<LocalComputation*> LocalComputationBuilder::Build() {
  TF_ASSIGN_OR_RETURN(XlaComputation built, builder_.Build());
  return new LocalComputation(std::move(built));
}

// Declares parameter `parameter_number` of the computation with the given
// shape and name; forwards to xla::Parameter.
LocalOp LocalComputationBuilder::Parameter(int64 parameter_number,
                                           const Shape& shape,
                                           const string& name) {
  auto* const b = &builder_;
  return xla::Parameter(b, parameter_number, shape, name);
}

// Returns the shape the builder has recorded for `operand`.
StatusOr<Shape> LocalComputationBuilder::GetShape(const LocalOp& operand) {
  const auto& op = operand.op();
  return builder_.GetShape(op);
}

// Returns the result shape of the program built so far. The shape is moved out
// of the temporary ProgramShape rather than copied, matching
// LocalComputation::GetReturnValueShape.
StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
  return std::move(*program_shape.mutable_result());
}

// Adds an infeed op producing a value of `shape`; forwards to xla::Infeed.
LocalOp LocalComputationBuilder::Infeed(const Shape& shape) {
  auto* const b = &builder_;
  return xla::Infeed(b, shape);
}

// Adds an outfeed of `operand` (declared with `shape`) using
// `outfeed_config`; forwards to xla::Outfeed and returns nothing.
void LocalComputationBuilder::Outfeed(const LocalOp& operand,
                                      const Shape& shape,
                                      const string& outfeed_config) {
  const auto& value = operand.op();
  xla::Outfeed(value, shape, outfeed_config);
}

// Adds a constant holding `literal`; forwards to xla::ConstantLiteral.
LocalOp LocalComputationBuilder::ConstantLiteral(const Literal& literal) {
  auto* const b = &builder_;
  return xla::ConstantLiteral(b, literal);
}

// Broadcasts `operand` with the given `broadcast_sizes`; forwards to
// xla::Broadcast.
LocalOp LocalComputationBuilder::Broadcast(
    const LocalOp& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  const auto& input = operand.op();
  return xla::Broadcast(input, broadcast_sizes);
}

// Pads `operand` with `padding_value` according to `padding_config`; forwards
// to xla::Pad.
LocalOp LocalComputationBuilder::Pad(const LocalOp& operand,
                                     const LocalOp& padding_value,
                                     const PaddingConfig& padding_config) {
  const auto& input = operand.op();
  const auto& fill = padding_value.op();
  return xla::Pad(input, fill, padding_config);
}

// Reshapes `operand` to `new_sizes`, passing `dimensions` through as the
// dimension-order argument of xla::Reshape.
LocalOp LocalComputationBuilder::Reshape(
    const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  const auto& input = operand.op();
  return xla::Reshape(input, dimensions, new_sizes);
}

// Collapses the listed `dimensions` of `operand`; forwards to xla::Collapse.
LocalOp LocalComputationBuilder::Collapse(
    const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
  const auto& input = operand.op();
  return xla::Collapse(input, dimensions);
}

// Sums `operand` across replicas; forwards to xla::CrossReplicaSum.
LocalOp LocalComputationBuilder::CrossReplicaSum(const LocalOp& operand) {
  const auto& input = operand.op();
  return xla::CrossReplicaSum(input);
}

// Takes a static slice of `operand` from `start_indices` to `limit_indices`
// with `strides`; forwards to xla::Slice.
LocalOp LocalComputationBuilder::Slice(
    const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  const auto& input = operand.op();
  return xla::Slice(input, start_indices, limit_indices, strides);
}

// Slices `operand` along the single dimension `dimno`; forwards to
// xla::SliceInDim.
LocalOp LocalComputationBuilder::SliceInDim(const LocalOp& operand,
                                            int64 start_index,
                                            int64 limit_index, int64 stride,
                                            int64 dimno) {
  const auto& input = operand.op();
  return xla::SliceInDim(input, start_index, limit_index, stride, dimno);
}

// Slices `operand` at runtime-computed `start_indices` with static
// `slice_sizes`; forwards to xla::DynamicSlice.
LocalOp LocalComputationBuilder::DynamicSlice(
    const LocalOp& operand, const LocalOp& start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  const auto& input = operand.op();
  const auto& starts = start_indices.op();
  return xla::DynamicSlice(input, starts, slice_sizes);
}

// Writes `update` into `operand` at runtime-computed `start_indices`; forwards
// to xla::DynamicUpdateSlice.
LocalOp LocalComputationBuilder::DynamicUpdateSlice(
    const LocalOp& operand, const LocalOp& update,
    const LocalOp& start_indices) {
  const auto& input = operand.op();
  const auto& patch = update.op();
  const auto& starts = start_indices.op();
  return xla::DynamicUpdateSlice(input, patch, starts);
}

// Concatenates `operands` along `dimension`; forwards to xla::ConcatInDim.
LocalOp LocalComputationBuilder::ConcatInDim(
    tensorflow::gtl::ArraySlice<LocalOp> operands, int64 dimension) {
  std::vector<XlaOp> pieces;
  pieces.reserve(operands.size());
  for (size_t k = 0; k < operands.size(); ++k) {
    pieces.push_back(operands[k].op());
  }
  return xla::ConcatInDim(&builder_, pieces, dimension);
}

// Adds a select-and-scatter op with explicit (low, high) padding pairs;
// forwards to xla::SelectAndScatterWithGeneralPadding.
LocalOp LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
    const LocalOp& operand, const LocalComputation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const LocalOp& source, const LocalOp& init_value,
    const LocalComputation& scatter) {
  const auto& select_fn = select.computation();
  const auto& scatter_fn = scatter.computation();
  return xla::SelectAndScatterWithGeneralPadding(
      operand.op(), select_fn, window_dimensions, window_strides, padding,
      source.op(), init_value.op(), scatter_fn);
}

// Packs `elements` into a tuple; forwards to xla::Tuple.
LocalOp LocalComputationBuilder::Tuple(
    tensorflow::gtl::ArraySlice<LocalOp> elements) {
  std::vector<XlaOp> members;
  members.reserve(elements.size());
  for (size_t k = 0; k < elements.size(); ++k) {
    members.push_back(elements[k].op());
  }
  return xla::Tuple(&builder_, members);
}

// Extracts element `index` of `tuple_data`; forwards to xla::GetTupleElement.
LocalOp LocalComputationBuilder::GetTupleElement(const LocalOp& tuple_data,
                                                 int64 index) {
  const auto& tuple = tuple_data.op();
  return xla::GetTupleElement(tuple, index);
}

// Adds a dot product of `lhs` and `rhs`; forwards to xla::Dot.
LocalOp LocalComputationBuilder::Dot(const LocalOp& lhs, const LocalOp& rhs) {
  const auto& left = lhs.op();
  const auto& right = rhs.op();
  return xla::Dot(left, right);
}

// Adds a dot product with explicit contracting/batch dimension numbers;
// forwards to xla::DotGeneral.
LocalOp LocalComputationBuilder::DotGeneral(
    const LocalOp& lhs, const LocalOp& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  const auto& left = lhs.op();
  const auto& right = rhs.op();
  return xla::DotGeneral(left, right, dimension_numbers);
}

// Adds a general convolution with explicit padding and lhs/rhs dilation;
// forwards to xla::ConvGeneralDilated.
LocalOp LocalComputationBuilder::ConvGeneralDilated(
    const LocalOp& lhs, const LocalOp& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  const auto& input = lhs.op();
  const auto& kernel = rhs.op();
  return xla::ConvGeneralDilated(input, kernel, window_strides, padding,
                                 lhs_dilation, rhs_dilation,
                                 dimension_numbers);
}

// Converts `operand` to `new_element_type`; forwards to
// xla::ConvertElementType.
LocalOp LocalComputationBuilder::ConvertElementType(
    const LocalOp& operand, PrimitiveType new_element_type) {
  const auto& input = operand.op();
  return xla::ConvertElementType(input, new_element_type);
}

// Calls `local_computation` with `operands`; forwards to xla::Call.
LocalOp LocalComputationBuilder::Call(
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<LocalOp> operands) {
  std::vector<XlaOp> call_args;
  call_args.reserve(operands.size());
  for (size_t k = 0; k < operands.size(); ++k) {
    call_args.push_back(operands[k].op());
  }
  const auto& callee = local_computation.computation();
  return xla::Call(&builder_, callee, call_args);
}

// Permutes the dimensions of `operand` by `permutation`; forwards to
// xla::Transpose.
LocalOp LocalComputationBuilder::Transpose(
    const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> permutation) {
  const auto& input = operand.op();
  return xla::Transpose(input, permutation);
}

// Reverses `operand` along `dimensions`; forwards to xla::Rev.
LocalOp LocalComputationBuilder::Rev(
    const LocalOp& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
  const auto& input = operand.op();
  return xla::Rev(input, dimensions);
}

// Maps `local_computation` over `operands` across `dimensions`; forwards to
// xla::Map.
LocalOp LocalComputationBuilder::Map(
    tensorflow::gtl::ArraySlice<LocalOp> operands,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  std::vector<XlaOp> inputs;
  inputs.reserve(operands.size());
  for (size_t k = 0; k < operands.size(); ++k) {
    inputs.push_back(operands[k].op());
  }
  const auto& fn = local_computation.computation();
  return xla::Map(&builder_, inputs, fn, dimensions);
}

// Reduces `operand` over `dimensions_to_reduce` with `local_computation`,
// starting from `init_value`; forwards to xla::Reduce.
LocalOp LocalComputationBuilder::Reduce(
    const LocalOp& operand, const LocalOp& init_value,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  const auto& reducer = local_computation.computation();
  return xla::Reduce(operand.op(), init_value.op(), reducer,
                     dimensions_to_reduce);
}

// Adds a windowed reduction with explicit (low, high) padding pairs; forwards
// to xla::ReduceWindowWithGeneralPadding.
LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
    const LocalOp& operand, const LocalOp& init_value,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  const auto& reducer = local_computation.computation();
  return xla::ReduceWindowWithGeneralPadding(operand.op(), init_value.op(),
                                             reducer, window_dimensions,
                                             window_strides, padding);
}

// Samples a normal distribution parameterized by `mu` and `sigma` into
// `shape`; forwards to xla::RngNormal.
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
                                           const LocalOp& sigma,
                                           const Shape& shape) {
  const auto& mean = mu.op();
  const auto& stddev = sigma.op();
  return xla::RngNormal(mean, stddev, shape);
}

// Samples a uniform distribution bounded by `a` and `b` into `shape`; forwards
// to xla::RngUniform.
LocalOp LocalComputationBuilder::RngUniform(const LocalOp& a, const LocalOp& b,
                                            const Shape& shape) {
  const auto& low = a.op();
  const auto& high = b.op();
  return xla::RngUniform(low, high, shape);
}

// Adds a while loop with the given condition and body computations, starting
// from `init`; forwards to xla::While.
LocalOp LocalComputationBuilder::While(const LocalComputation& condition,
                                       const LocalComputation& body,
                                       const LocalOp& init) {
  const auto& cond_fn = condition.computation();
  const auto& body_fn = body.computation();
  return xla::While(cond_fn, body_fn, init.op());
}

// Adds a two-way conditional on `predicate`, with a computation and operand
// for each branch; forwards to xla::Conditional.
LocalOp LocalComputationBuilder::Conditional(
    const LocalOp& predicate, const LocalOp& true_operand,
    const LocalComputation& true_computation, const LocalOp& false_operand,
    const LocalComputation& false_computation) {
  const auto& then_fn = true_computation.computation();
  const auto& else_fn = false_computation.computation();
  return xla::Conditional(predicate.op(), true_operand.op(), then_fn,
                          false_operand.op(), else_fn);
}

// Asks the builder whether `operand` is a compile-time constant.
StatusOr<bool> LocalComputationBuilder::IsConstant(const LocalOp& operand) {
  const auto& op = operand.op();
  return builder_.IsConstant(op);
}

// Builds the constant subgraph rooted at `operand` and returns it as a
// heap-allocated LocalComputation that the caller owns.
StatusOr<LocalComputation*> LocalComputationBuilder::BuildConstantSubGraph(
    const LocalOp& operand) {
  const auto& root = operand.op();
  TF_ASSIGN_OR_RETURN(XlaComputation subgraph,
                      builder_.BuildConstantSubGraph(root));
  return new LocalComputation(std::move(subgraph));
}

// Helper macros that define LocalComputationBuilder methods which simply
// forward to the xla:: free function of the same name, unwrapping each LocalOp
// argument via op(). They are #undef'd right after use.
//
// Identifiers that begin with an underscore followed by an uppercase letter
// are reserved for the C++ implementation, so these macros use an XLA_SWIG_
// prefix instead.
#define XLA_SWIG_FORWARD(method_name, return_sig, args_sig, args) \
  return_sig LocalComputationBuilder::method_name args_sig {      \
    return xla::method_name args;                                 \
  }

// Unary op: method_name(operand).
#define XLA_SWIG_FORWARD_UNOP(method_name) \
  XLA_SWIG_FORWARD(method_name, LocalOp, (const LocalOp& operand), (operand.op()))

// Binary op with explicit broadcast dimensions.
#define XLA_SWIG_FORWARD_BINOP(method_name)                                   \
  XLA_SWIG_FORWARD(method_name, LocalOp,                                      \
                   (const LocalOp& lhs, const LocalOp& rhs,                   \
                    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
                   (lhs.op(), rhs.op(), broadcast_dimensions))

// Ternary op.
#define XLA_SWIG_FORWARD_TRIOP(method_name)                                      \
  XLA_SWIG_FORWARD(method_name, LocalOp,                                         \
                   (const LocalOp& lhs, const LocalOp& rhs, const LocalOp& ehs), \
                   (lhs.op(), rhs.op(), ehs.op()))

XLA_SWIG_FORWARD_TRIOP(Select)
XLA_SWIG_FORWARD_TRIOP(Clamp)
XLA_SWIG_FORWARD_BINOP(Eq)
XLA_SWIG_FORWARD_BINOP(Ne)
XLA_SWIG_FORWARD_BINOP(Ge)
XLA_SWIG_FORWARD_BINOP(Gt)
XLA_SWIG_FORWARD_BINOP(Lt)
XLA_SWIG_FORWARD_BINOP(Le)
XLA_SWIG_FORWARD_BINOP(Add)
XLA_SWIG_FORWARD_BINOP(Sub)
XLA_SWIG_FORWARD_BINOP(Mul)
XLA_SWIG_FORWARD_BINOP(Div)
XLA_SWIG_FORWARD_BINOP(Rem)
XLA_SWIG_FORWARD_BINOP(Max)
XLA_SWIG_FORWARD_BINOP(Min)
XLA_SWIG_FORWARD_BINOP(And)
XLA_SWIG_FORWARD_BINOP(Or)
XLA_SWIG_FORWARD_BINOP(Xor)
XLA_SWIG_FORWARD_UNOP(Not)
XLA_SWIG_FORWARD_UNOP(Abs)
XLA_SWIG_FORWARD_UNOP(Exp)
XLA_SWIG_FORWARD_UNOP(Expm1)
XLA_SWIG_FORWARD_UNOP(Floor)
XLA_SWIG_FORWARD_UNOP(Ceil)
XLA_SWIG_FORWARD_UNOP(Round)
XLA_SWIG_FORWARD_UNOP(Log)
XLA_SWIG_FORWARD_UNOP(Log1p)
XLA_SWIG_FORWARD_UNOP(Sign)
XLA_SWIG_FORWARD_UNOP(Cos)
XLA_SWIG_FORWARD_UNOP(Sin)
XLA_SWIG_FORWARD_UNOP(Tanh)
XLA_SWIG_FORWARD_UNOP(SqrtF32)
XLA_SWIG_FORWARD_UNOP(SquareF32)
XLA_SWIG_FORWARD_BINOP(Pow)
XLA_SWIG_FORWARD_UNOP(IsFinite)
XLA_SWIG_FORWARD_UNOP(ReciprocalF32)
XLA_SWIG_FORWARD_UNOP(Neg)
XLA_SWIG_FORWARD_UNOP(Sort)

#undef XLA_SWIG_FORWARD
#undef XLA_SWIG_FORWARD_UNOP
#undef XLA_SWIG_FORWARD_BINOP
#undef XLA_SWIG_FORWARD_TRIOP

// Destroys a LocalShapedBuffer allocated with `new` (e.g. by
// LocalShapedBuffer::FromLiteral). Passing nullptr is a no-op.
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
  std::unique_ptr<LocalShapedBuffer> owned(local_shaped_buffer);
}

// Destroys a CompiledLocalComputation allocated with `new` (e.g. by
// LocalComputation::Compile). Passing nullptr is a no-op.
void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) {
  std::unique_ptr<CompiledLocalComputation> owned(computation);
}

// Destroys a LocalComputation allocated with `new` (e.g. by
// LocalComputationBuilder::Build). Passing nullptr is a no-op.
void DeleteLocalComputation(LocalComputation* computation) {
  std::unique_ptr<LocalComputation> owned(computation);
}

// Splits a tuple-shaped LocalShapedBuffer into one LocalShapedBuffer per
// top-level tuple element. Ownership of each element's device sub-buffers
// moves to the new objects, and the tuple's root buffer is deallocated.
//
// Returns InvalidArgument, leaving `local_shaped_buffer` untouched, if its
// on-device shape is not a tuple. Once that check passes,
// `local_shaped_buffer` is Release()d and no longer owns device memory; the
// caller still owns the (now empty) object itself. On success, the returned
// LocalShapedBufferTuple is heap-allocated and owned by the caller. If
// deallocating the root buffer fails, the already-built elements are freed
// before the error is returned.
StatusOr<LocalShapedBufferTuple*> DestructureLocalShapedBufferTuple(
    LocalShapedBuffer* local_shaped_buffer) {
  if (!ShapeUtil::IsTuple(
          local_shaped_buffer->shaped_buffer()->on_device_shape())) {
    return InvalidArgument(
        "Attempted to destructure a LocalShapedBuffer that did not have a "
        "tuple shape; shape: %s",
        ShapeUtil::HumanString(
            local_shaped_buffer->shaped_buffer()->on_device_shape())
            .c_str());
  }

  DeviceMemoryAllocator* allocator =
      local_shaped_buffer->shaped_buffer()->memory_allocator();
  ShapedBuffer tuple_buffer = local_shaped_buffer->Release();

  // Extract some metadata we use to construct scoped buffers.
  const se::Platform* platform = tuple_buffer.platform();
  const int device_ordinal = tuple_buffer.device_ordinal();

  ShapeTree<se::DeviceMemoryBase>& shape_tree = tuple_buffer.buffers();
  const Shape& tuple_shape = tuple_buffer.on_device_shape();
  const int64 element_count = ShapeUtil::TupleElementCount(tuple_shape);
  std::vector<LocalShapedBuffer*> results;
  results.reserve(element_count);
  for (int64 i = 0; i < element_count; ++i) {
    // Create a shaped buffer for this destructured tuple element.
    const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i});
    VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape;
    ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal);

    // Move every buffer under index {i} of the tuple into the element,
    // clearing the source entry so it is owned exactly once.
    ShapeUtil::ForEachSubshape(
        subshape, [&](const Shape& s, const ShapeIndex& index) {
          ShapeIndex original(index);
          original.push_front(i);
          se::DeviceMemoryBase* device_memory =
              shape_tree.mutable_element(original);
          shaped_buffer.set_buffer(*device_memory, index);
          *device_memory = se::DeviceMemoryBase();
        });

    VLOG(3) << "Completed tuple element: " << i;
    results.push_back(new LocalShapedBuffer(
        ScopedShapedBuffer(std::move(shaped_buffer), allocator)));
  }

  // Give the elements an owner before the fallible deallocation below, so
  // they (and their device memory) are freed rather than leaked if it fails.
  std::unique_ptr<LocalShapedBufferTuple> tuple(
      new LocalShapedBufferTuple(std::move(results)));

  // Deallocate the root buffer. Every entry under a tuple index was handed to
  // an element and cleared above; only the root allocation at index {} is
  // still owned here.
  se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer();
  TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer));
  return tuple.release();
}

}  // namespace swig
}  // namespace xla