aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/python/local_computation_builder.h
blob: d5c4c5804060abdf968c85643910d397e7c646e6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace xla {

namespace swig {

// Configures how many replicas the local XLA service is created with. The
// setting is consumed when a handle to the local XLA service is first
// obtained; once that handle exists, calling this returns an error Status
// rather than changing the count.
Status InitializeReplicaCount(int replica_count);

// Reports the replica count currently configured. The value is available
// both before and after the local XLA service has been instantiated.
int GetReplicaCount();

// Sends `literal` to the infeed through the local client's infeed-transfer
// function, targeting the default device ordinal (0).
Status TransferToInfeedLocal(const Literal& literal);

// Sends `literal` to the infeed of replica `replica_number`; the replica
// number is first mapped to the matching device ordinal.
Status TransferToInfeedLocalReplica(const Literal& literal,
                                    int replica_number);

// Reads a literal of shape `shape` from the outfeed of replica
// `replica_number` (mapped to the matching device ordinal). On success the
// caller receives ownership of the returned Literal.
StatusOr<std::unique_ptr<Literal> >
TransferFromOutfeedLocalReplica(const Shape& shape, int replica_number);

// Wraps a ScopedShapedBuffer produced by copying a literal "to
// device," i.e. copying a literal to a scoped buffer via the local
// client.
class LocalShapedBuffer {
 public:
  // Copies `argument` to a device buffer via the local client and returns a
  // new wrapper around it.
  // NOTE(review): the returned raw pointer is presumably owned by the caller
  // and released with DeleteLocalShapedBuffer; confirm against the
  // implementation.
  static LocalShapedBuffer* FromLiteral(const Literal& argument);

  // Takes ownership of `shaped_buffer`. Marked explicit so that a
  // std::unique_ptr<ScopedShapedBuffer> is never implicitly converted into a
  // LocalShapedBuffer.
  explicit LocalShapedBuffer(std::unique_ptr<ScopedShapedBuffer> shaped_buffer);

  // Returns a reference to the owned buffer; this wrapper keeps ownership.
  const std::unique_ptr<ScopedShapedBuffer>& shaped_buffer() const;

  // Produces a Literal from the wrapped buffer. The caller owns the result.
  std::unique_ptr<Literal> ToLiteral() const;

 private:
  std::unique_ptr<ScopedShapedBuffer> shaped_buffer_;
};

// Wraps a LocalExecutable produced by compiling a
// LocalComputation. The Execute method forwards to that of the
// underlying LocalExecutable, and additionally handles transferring
// arguments and return values in and back out of the client library's
// local client. This class is intended to be made available to Python
// via SWIG.
class CompiledLocalComputation {
 public:
  // Takes ownership of `executable`. Marked explicit so that a
  // std::unique_ptr<LocalExecutable> is never implicitly converted into a
  // CompiledLocalComputation.
  explicit CompiledLocalComputation(
      std::unique_ptr<LocalExecutable> executable);

  // Runs the executable on host-side `arguments`, transferring them in and
  // the result back out through the local client. On success the caller owns
  // the returned Literal.
  StatusOr<std::unique_ptr<Literal> > Execute(
      const std::vector<Literal>& arguments);

  // Runs the executable on arguments that are already held in
  // LocalShapedBuffers (i.e. already copied to device).
  // NOTE(review): the returned raw pointer is presumably owned by the caller
  // and released with DeleteLocalShapedBuffer; confirm. Unlike Execute, this
  // signature has no Status channel, so how failures surface should be
  // checked in the implementation.
  LocalShapedBuffer* ExecuteWithShapedBuffers(
      tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);

 private:
  std::unique_ptr<LocalExecutable> executable_;
};

// Wraps a Computation produced by a LocalComputationBuilder. The
// Compile method compiles the computation to a (local) executable via
// the client library's local client. This class is intended to be
// made available to Python via SWIG.
class LocalComputation {
 public:
  // Wraps `computation`, which is taken by value. Marked explicit so that a
  // Computation is never implicitly converted into a LocalComputation.
  explicit LocalComputation(Computation computation);

  // Compiles the wrapped computation for arguments of `argument_shapes`.
  // NOTE(review): on success the returned raw pointer is presumably owned by
  // the caller and released with DeleteCompiledLocalComputation; confirm.
  // `build_options` is passed by pointer -- presumably null selects default
  // options; TODO confirm against the implementation.
  StatusOr<CompiledLocalComputation*> Compile(
      const std::vector<Shape>& argument_shapes,
      const ExecutableBuildOptions* build_options);

  // Returns the wrapped Computation; this object retains it.
  const Computation& computation() const;

 private:
  Computation computation_;
};

// Wraps the ComputationBuilder API in order to:
// - Support consumption by SWIG in order to be made available to
//   Python.
// - Set up the underlying builder to use the client library's
//   LocalClient.
// - Wrap Computations in LocalComputations for Python access.
// - Correspondingly unwrap incoming LocalComputations.
class LocalComputationBuilder {
 public:
  // Creates a builder for a computation named `computation_name`. Marked
  // explicit so that a string is never implicitly converted into a builder.
  explicit LocalComputationBuilder(const string& computation_name);

  void SetOpMetadata(const OpMetadata& metadata);
  void ClearOpMetadata();

  // Returns an owned LocalComputation to the caller on success.
  StatusOr<LocalComputation*> Build();

  ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
                                  const string& name);

  // Returns the shape of `operand`; the caller owns the returned Shape.
  std::unique_ptr<Shape> GetShape(const ComputationDataHandle& operand);

  ComputationDataHandle Infeed(const Shape& shape);

  void Outfeed(const ComputationDataHandle& operand, const Shape& shape,
               const string& outfeed_config);

  ComputationDataHandle ConstantLiteral(const Literal& literal);

  ComputationDataHandle Broadcast(
      const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> broadcast_sizes);

  ComputationDataHandle Pad(const ComputationDataHandle& operand,
                            const ComputationDataHandle& padding_value,
                            const PaddingConfig& padding_config);

  ComputationDataHandle Reshape(const ComputationDataHandle& operand,
                                tensorflow::gtl::ArraySlice<int64> dimensions,
                                tensorflow::gtl::ArraySlice<int64> new_sizes);

  ComputationDataHandle Collapse(const ComputationDataHandle& operand,
                                 tensorflow::gtl::ArraySlice<int64> dimensions);

  ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);

  ComputationDataHandle Slice(const ComputationDataHandle& operand,
                              tensorflow::gtl::ArraySlice<int64> start_indices,
                              tensorflow::gtl::ArraySlice<int64> limit_indices,
                              tensorflow::gtl::ArraySlice<int64> strides);

  ComputationDataHandle DynamicSlice(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& start_indices,
      tensorflow::gtl::ArraySlice<int64> slice_sizes);

  ComputationDataHandle DynamicUpdateSlice(
      const ComputationDataHandle& operand, const ComputationDataHandle& update,
      const ComputationDataHandle& start_indices);

  ComputationDataHandle ConcatInDim(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      int64 dimension);

  ComputationDataHandle SelectAndScatterWithGeneralPadding(
      const ComputationDataHandle& operand, const LocalComputation& select,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
      const ComputationDataHandle& source,
      const ComputationDataHandle& init_value, const LocalComputation& scatter);

  ComputationDataHandle Tuple(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);

  ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
                                        int64 index);

  ComputationDataHandle Dot(const ComputationDataHandle& lhs,
                            const ComputationDataHandle& rhs);

  ComputationDataHandle ConvGeneralDilated(
      const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding,
      tensorflow::gtl::ArraySlice<int64> lhs_dilation,
      tensorflow::gtl::ArraySlice<int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers);

  ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
                                           PrimitiveType new_element_type);

  ComputationDataHandle Call(
      const LocalComputation& local_computation,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);

  ComputationDataHandle Transpose(
      const ComputationDataHandle& operand,
      tensorflow::gtl::ArraySlice<int64> permutation);

  ComputationDataHandle Rev(const ComputationDataHandle& operand,
                            tensorflow::gtl::ArraySlice<int64> dimensions);

  ComputationDataHandle Map(
      tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
      const LocalComputation& local_computation,
      tensorflow::gtl::ArraySlice<int64> dimensions,
      tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands);

  ComputationDataHandle Reduce(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& init_value,
      const LocalComputation& local_computation,
      tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);

  ComputationDataHandle ReduceWindowWithGeneralPadding(
      const ComputationDataHandle& operand,
      const ComputationDataHandle& init_value,
      const LocalComputation& local_computation,
      tensorflow::gtl::ArraySlice<int64> window_dimensions,
      tensorflow::gtl::ArraySlice<int64> window_strides,
      tensorflow::gtl::ArraySlice<std::pair<int64, int64> > padding);

  ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
                                  const ComputationDataHandle& sigma,
                                  const Shape& shape);

  ComputationDataHandle RngUniform(const ComputationDataHandle& a,
                                   const ComputationDataHandle& b,
                                   const Shape& shape);

  ComputationDataHandle While(const LocalComputation& condition,
                              const LocalComputation& body,
                              const ComputationDataHandle& init);

// The ops below share a few common signatures, so they are declared through
// these helper macros, which are #undef'd immediately after use. The names
// deliberately avoid a leading underscore followed by an uppercase letter:
// such identifiers are reserved to the implementation in C++.
#define XLA_SWIG_FORWARD(method_name, return_sig, args_sig) \
  return_sig method_name args_sig;

#define XLA_SWIG_FORWARD_UNOP(method_name)             \
  XLA_SWIG_FORWARD(method_name, ComputationDataHandle, \
                   (const ComputationDataHandle& operand))

#define XLA_SWIG_FORWARD_BINOP(method_name)                                \
  XLA_SWIG_FORWARD(                                                        \
      method_name, ComputationDataHandle,                                  \
      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))

#define XLA_SWIG_FORWARD_TRIOP(method_name)                                \
  XLA_SWIG_FORWARD(                                                        \
      method_name, ComputationDataHandle,                                  \
      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
       const ComputationDataHandle& ehs))

  XLA_SWIG_FORWARD_TRIOP(Select)
  XLA_SWIG_FORWARD_TRIOP(Clamp)
  XLA_SWIG_FORWARD_BINOP(Eq)
  XLA_SWIG_FORWARD_BINOP(Ne)
  XLA_SWIG_FORWARD_BINOP(Ge)
  XLA_SWIG_FORWARD_BINOP(Gt)
  XLA_SWIG_FORWARD_BINOP(Lt)
  XLA_SWIG_FORWARD_BINOP(Le)
  XLA_SWIG_FORWARD_BINOP(Add)
  XLA_SWIG_FORWARD_BINOP(Sub)
  XLA_SWIG_FORWARD_BINOP(Mul)
  XLA_SWIG_FORWARD_BINOP(Div)
  XLA_SWIG_FORWARD_BINOP(Rem)
  XLA_SWIG_FORWARD_BINOP(Max)
  XLA_SWIG_FORWARD_BINOP(Min)
  XLA_SWIG_FORWARD_BINOP(And)
  XLA_SWIG_FORWARD_BINOP(Or)
  XLA_SWIG_FORWARD_UNOP(Not)
  XLA_SWIG_FORWARD_UNOP(Abs)
  XLA_SWIG_FORWARD_UNOP(Exp)
  XLA_SWIG_FORWARD_UNOP(Floor)
  XLA_SWIG_FORWARD_UNOP(Ceil)
  XLA_SWIG_FORWARD_UNOP(Round)
  XLA_SWIG_FORWARD_UNOP(Log)
  XLA_SWIG_FORWARD_UNOP(Sign)
  XLA_SWIG_FORWARD_UNOP(Cos)
  XLA_SWIG_FORWARD_UNOP(Sin)
  XLA_SWIG_FORWARD_UNOP(Tanh)
  XLA_SWIG_FORWARD_UNOP(SqrtF32)
  XLA_SWIG_FORWARD_UNOP(SquareF32)
  XLA_SWIG_FORWARD_BINOP(Pow)
  XLA_SWIG_FORWARD_UNOP(IsFinite)
  XLA_SWIG_FORWARD_UNOP(ReciprocalF32)
  XLA_SWIG_FORWARD_UNOP(Neg)
  XLA_SWIG_FORWARD_UNOP(Sort)

#undef XLA_SWIG_FORWARD
#undef XLA_SWIG_FORWARD_UNOP
#undef XLA_SWIG_FORWARD_BINOP
#undef XLA_SWIG_FORWARD_TRIOP

 private:
  ComputationBuilder builder_;
};

// Release entry points invoked from Python. Each frees the object it is
// given; the pointer must not be used again after the call.
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer);
void DeleteCompiledLocalComputation(CompiledLocalComputation* computation);
void DeleteLocalComputation(LocalComputation* computation);

}  // namespace swig

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_