aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit/create_xla_launch_op.cc
blob: 6f1ff85f24a4c1fd3e6d54fcff9f8868aee6f750 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/create_xla_launch_op.h"

#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {
namespace {

// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
 public:
  // Creates a SinglePassSearch object that can be used to search in `values`.
  // Does not take ownership of `values`. `values` must outlive this.
  // `values` must be sorted.
  explicit SinglePassSearch(const std::vector<int>* values)
      : current_index_(0), values_(values) {}

  // Scans forward in the vector looking for "value", updating the internal
  // position in to the vector.
  // Returns true iff the vector contains the given value at or after current
  // position.
  // Not thread-safe.
  bool ScanForValue(int value) {
    while (current_index_ < values_->size() &&
           (*values_)[current_index_] <= value) {
      if ((*values_)[current_index_] == value) {
        current_index_++;
        return true;
      }
      current_index_++;
    }
    return false;
  }

 private:
  // size_t (not int) so the comparison against values_->size() above is not a
  // signed/unsigned mismatch.
  size_t current_index_;
  const std::vector<int>* values_;
};

// Returns OK iff `node_def` calls a function for which XLA compilation has
// been requested, either via a kXlaCompileAttr on the node itself or on the
// FunctionDef it calls. Any other case yields a non-OK status.
Status CompilationRequested(const FunctionLibraryRuntime& flr,
                            const NodeDef& node_def) {
  const auto* lib_def = flr.GetFunctionLibraryDefinition();
  if (lib_def->Find(node_def.name()) == nullptr) {
    // The node def is not calling a function. Individual ops can be
    // run directly using on-demand mode, no need to create XlaLaunch
    // kernel for them.
    // TODO(b/110359382): Make custom kernel creation return a bool instead of
    // status.
    // We don't set error messages here to avoid unnecessary string copy.
    // Similarly below.
    return Status(error::INVALID_ARGUMENT, "");
  }

  // An explicit kXlaCompileAttr on the node_def takes precedence.
  const auto attr_iter = node_def.attr().find(kXlaCompileAttr);
  if (attr_iter != node_def.attr().end()) {
    if (attr_iter->second.b()) {
      return Status::OK();
    }
    return Status(error::INVALID_ARGUMENT, "");
  }

  // Fall back to the kXlaCompileAttr on the FunctionDef being called.
  bool xla_compile = false;
  const Status attr_status =
      lib_def->GetAttr(node_def, kXlaCompileAttr, &xla_compile);
  if (attr_status.ok() && xla_compile) {
    return Status::OK();
  }

  if (VLOG_IS_ON(3)) {
    if (!attr_status.ok()) {
      VLOG(3) << "No " << kXlaCompileAttr << " attr defined for "
              << node_def.op() << ". status=" << attr_status.ToString();
    } else {
      VLOG(3) << node_def.op() << " is explicitly marked not to be compiled";
    }
  }
  return Status(error::INVALID_ARGUMENT, "");
}

// Given a FunctionLibraryRuntime and a NodeDef calling a function in the
// runtime, returns this function's body in `fbody` as well as the indices
// of its constant and resource arguments.
// `fbody` is owned by `flr`.
// `constant_arg_indices` and `resource_arg_indices` should be empty vector.
// They are sorted in ascending order on this function's return.
Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr,
                                       const NodeDef& node_def,
                                       const FunctionBody** fbody,
                                       std::vector<int>* constant_arg_indices,
                                       std::vector<int>* resource_arg_indices) {
  FunctionLibraryRuntime::Handle handle;
  // If node_def is not instantiable, e.g., the function does not exist,
  // simply bail out.
  TF_RETURN_IF_ERROR(
      flr->Instantiate(node_def.op(), AttrSlice(&node_def.attr()), &handle));
  *fbody = flr->GetFunctionBody(handle);
  CHECK(*fbody);  // Can't be nullptr since we just instantiated it.
  const DataTypeVector& arg_types = (*fbody)->arg_types;
  std::vector<bool> const_args(arg_types.size());
  // If we can't analyze the const args. Bail out.
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *((*fbody)->graph), &const_args, /*compile_time_const_nodes=*/nullptr));

  // Indices are stored as int, so iterate with an int bound (avoids a
  // signed/unsigned comparison against the containers' size_t size()).
  const int num_args = static_cast<int>(arg_types.size());
  for (int i = 0; i < num_args; ++i) {
    if (const_args[i]) {
      constant_arg_indices->push_back(i);
    }
  }

  // There can be hundreds of resource variables. Reserve the space for them.
  // We don't reserve for constants above as they are usually few.
  resource_arg_indices->reserve(arg_types.size());
  for (int i = 0; i < num_args; ++i) {
    if (arg_types[i] == DT_RESOURCE) {
      resource_arg_indices->push_back(i);
    }
  }

  return Status::OK();
}

}  // namespace

// Creates an XlaLaunch kernel for `node_def` if (a) XLA compilation was
// requested for the function it calls and (b) the function is compilable by
// XLA. On success, `kernel` owns the newly created kernel; otherwise a non-OK
// status is returned (an empty-message INVALID_ARGUMENT simply means "don't
// use a custom kernel here").
Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& node_def,
                         std::unique_ptr<OpKernel>* kernel) {
  TF_RETURN_IF_ERROR(CompilationRequested(*flr, node_def));

  VLOG(3) << "Creating XlaLaunchOp for " << node_def.DebugString();

  // Make sure that kernels have been registered on the JIT device.
  XlaOpRegistry::RegisterCompilationKernels();
  if (!IsCompilable(flr, node_def)) {
    // node_def is calling a function that XLA can't compile.
    return errors::InvalidArgument("Not compilable: ",
                                   node_def.ShortDebugString());
  }

  // Get function body, constant args, and resource args.
  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));

  // Set input and output memory types.
  MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
  // These indices are used only for optimization purposes. They allow us
  // to loop over constant_arg_indices and resource_arg_indices only once
  // while iterating over all the function arguments checking if it is a
  // resource or a constant.
  // The reason we optimized this code is because functions can have a lot of
  // captured arguments. For example, the backward pass of ResNet50 takes in all
  // 214 variables and a similar number of activations.
  SinglePassSearch constants_search(&constant_arg_indices);
  SinglePassSearch resources_search(&resource_arg_indices);
  // int bound: indices are ints (ScanForValue takes int); avoids a
  // signed/unsigned comparison against size().
  const int num_args = static_cast<int>(fbody->arg_types.size());
  for (int i = 0; i < num_args; ++i) {
    if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
      // Compile-time constants and resource handles are expected to be in
      // host memory.
      input_memory_types[i] = HOST_MEMORY;
    }
  }
  // One might wonder, about the case where a compile-time constant argument
  // (which must be in host memory) is also used as an input into an op,
  // e.g. Add, that expects its inputs in device memory. Here is how it
  // works now.
  // First, what do we mean by "op expects an input in XYZ memory"?
  // There are two types of "ops" here: the tf2xla kernel and the HLO
  // computation it builds. The tf2xla kernel needs to retrieve the actual
  // numeric value of the compile-time constant tensors, so it really expects
  // them to be on in host memory. However, for other inputs, it refers to them
  // using xla::ComputationDataHandle, which is just a symbolic handle that
  // xla::ComputationBuilder assigns. How does this handle gets assigned for
  // constant arguments? Even constant arguments get an _Arg node in the graph
  // instatiated for Function compilation. The tf2xla kernel for constant _Arg
  // nodes takes the constant value, converts it to XlaLiteral, and feeds it
  // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
  // constant XlaLiteral is included in the HLO graph, and subsequently, in
  // the actual executable, which is copied to the device before being
  // executed. Thus, when this executable runs, the constant is available in
  // device memory.

  // XlaLaunch kernel keeps all outputs (including constants, which it copies),
  // in device memory except for resources.
  MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
  const int num_rets = static_cast<int>(fbody->ret_types.size());
  for (int i = 0; i < num_rets; ++i) {
    if (fbody->ret_types[i] == DT_RESOURCE) {
      output_memory_types[i] = HOST_MEMORY;
    }
  }

  // Create the kernel.
  NameAttrList function;
  function.set_name(node_def.op());
  *(function.mutable_attr()) = node_def.attr();

  Device* dev = flr->device();
  Status s;
  OpKernelConstruction construction(
      DeviceType(dev->device_type()), dev,
      dev->GetAllocator(AllocatorAttributes()), &node_def,
      &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
      fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);

  *kernel = absl::make_unique<XlaLocalLaunchBase>(
      &construction, constant_arg_indices, resource_arg_indices, function);
  return s;
}

namespace {

// Installs CreateXlaLaunchOp as the default custom-kernel creator. Invoked
// once via the static initializer below.
bool InitializeLaunchOpCreator() {
  RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
  return true;
}

// Force registration at program startup.
static bool launch_op_creator_registered = InitializeLaunchOpCreator();

}  // namespace
}  // namespace tensorflow