aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/aot/runtime.cc6
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc22
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h29
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc12
-rw-r--r--tensorflow/compiler/xla/service/cpu/compiler_functor.cc21
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc23
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc104
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h27
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc410
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h97
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_function.cc44
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc5
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test.cc7
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc2
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc29
17 files changed, 461 insertions, 390 deletions
diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc
index 5e74079fc1..7606420ded 100644
--- a/tensorflow/compiler/aot/runtime.cc
+++ b/tensorflow/compiler/aot/runtime.cc
@@ -64,7 +64,7 @@ size_t align_to(size_t n, size_t align) {
size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
- if (sizes[i] != -1) {
+ if (sizes[i] > 0) {
total += align_to(sizes[i], kAlign);
}
}
@@ -85,7 +85,9 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
}
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
- if (sizes[i] == -1) {
+ if (sizes[i] < 0) {
+ // bufs[i] is either a constant, an entry parameter or a thread local
+ // allocation.
bufs[i] = nullptr;
} else {
bufs[i] = reinterpret_cast<void*>(pos);
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index 672e19bd93..ed5aa08c6f 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -26,6 +26,8 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
result_index_(static_data.result_index),
args_(new void*[static_data.num_args]),
temps_(new void*[static_data.num_temps]),
+ arg_index_to_temp_index_(new int32[static_data.num_args]),
+ num_args_(static_data.num_args),
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
program_shape_(static_data.program_shape),
@@ -40,6 +42,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
static_data.temp_sizes, static_data.num_temps, temps_,
/*annotate_initialized=*/true);
+ for (int i = 0; i < static_data.num_temps; i++) {
+ if (static_data.temp_sizes[i] < -1) {
+ int32 param_number = -(static_data.temp_sizes[i] + 2);
+ arg_index_to_temp_index_[param_number] = i;
+ }
+ }
+
// If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function
@@ -50,11 +59,24 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
}
}
+bool XlaCompiledCpuFunction::Run() {
+ // Propagate pointers to the argument buffers into the temps array. Code
+ // generated by XLA discovers the incoming argument pointers from the temps
+ // array.
+ for (int32 i = 0; i < num_args_; i++) {
+ temps_[arg_index_to_temp_index_[i]] = args_[i];
+ }
+ raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
+ profile_counters_);
+ return true;
+}
+
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
delete[] args_;
delete[] temps_;
+ delete[] arg_index_to_temp_index_;
delete[] profile_counters_;
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 48a8c083ca..27cfb354bf 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -60,9 +60,19 @@ class XlaCompiledCpuFunction {
// The raw function to call.
RawFunction raw_function;
- // Cardinality and sizes of arg and temp buffers.
+ // Cardinality and size of arg buffers.
const intptr_t* arg_sizes = nullptr;
size_t num_args = 0;
+
+ // Cardinality and size of temp buffers.
+ //
+ // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer.
+ //
+ // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The
+ // corresponding entry in the temp buffer array needs to be set to null.
+ //
+ // If temp_sizes[i] < -1 then the i'th temp is the entry parameter
+ // -(temp_sizes[i] + 2).
const intptr_t* temp_sizes = nullptr;
size_t num_temps = 0;
@@ -113,11 +123,7 @@ class XlaCompiledCpuFunction {
// Runs the computation, with inputs read from arg buffers, and outputs
// written to result buffers. Returns true on success and false on failure.
- bool Run() {
- raw_function_(temps_[result_index_], &run_options_,
- const_cast<const void**>(args_), temps_, profile_counters_);
- return true;
- }
+ bool Run();
// Returns the error message from the previous failed Run call.
//
@@ -224,6 +230,17 @@ class XlaCompiledCpuFunction {
void** args_ = nullptr;
void** temps_ = nullptr;
+ // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for
+ // XLA generated code to be able to find it.
+ //
+ // For now we need to keep around the args_ array because there is code that
+ // depends on args() returning a void**. However, in the future we may remove
+ // args_ in favor of using temps_ as the sole storage for the arguments.
+ int32* arg_index_to_temp_index_;
+
+ // The number of incoming arguments.
+ int32 num_args_;
+
// Backing memory for individual arg and temp buffers.
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 00ccfb1c78..114a9241bd 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -58,11 +58,15 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
std::vector<intptr_t> temp_sizes;
temp_sizes.reserve(allocations.size());
for (const xla::BufferAllocation& allocation : allocations) {
- // Callers don't allocate temporary buffers for parameters. Nor for
- // thread-local buffers, which are lowered to alloca.
- if (allocation.is_entry_computation_parameter() ||
- allocation.is_thread_local()) {
+ if (allocation.is_constant() || allocation.is_thread_local()) {
+ // Constants are lowered to globals. Thread locals are lowered to
+ // allocas.
temp_sizes.push_back(-1);
+ } else if (allocation.is_entry_computation_parameter()) {
+ // Entry computation parameters need some preprocessing in
+ // XlaCompiledCpuFunction::Run. See the comment on
+ // XlaCompiledCpuFunction::StaticData::temp_sizes.
+ temp_sizes.push_back(-allocation.parameter_number() - 2);
} else {
temp_sizes.push_back(allocation.size());
}
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 6a7eb85e3b..128eea4828 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -156,9 +156,26 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
codegen_passes.run(module);
- // Construct ObjectFile from machine code buffer.
- return std::unique_ptr<llvm::MemoryBuffer>(
+ std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
+
+ if (VLOG_IS_ON(2)) {
+ llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
+ llvm::object::ObjectFile::createObjectFile(*memory_buffer);
+ if (obj_file) {
+ StatusOr<DisassemblerResult> disasm_result =
+ disassembler_->DisassembleObjectFile(*obj_file.get());
+ if (disasm_result.ok()) {
+ XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text);
+ } else {
+ LOG(WARNING) << "Could not disassemble object file!";
+ }
+ } else {
+ LOG(WARNING) << "Could convert memory buffer to object file!";
+ }
+ }
+
+ return memory_buffer;
}
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index b49ea89896..8cbe9a1b0d 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -840,18 +840,29 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
- // Callers don't need to allocate temporary buffers for parameters.
- if (allocation.is_entry_computation_parameter() ||
- allocation.is_constant()) {
- buffer_sizes.push_back(-1);
- continue;
- }
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if (allocation.is_thread_local()) {
buffer_sizes.push_back(-1);
continue;
}
+
+ // Callers don't need to allocate anything for constant buffers. They are
+ // lowered to globals.
+ if (allocation.is_constant()) {
+ buffer_sizes.push_back(-1);
+ continue;
+ }
+
+ // Callers don't need to allocate anything for entry computation buffers,
+ // but they do need to stash the pointer to the entry computation buffer
+ // in the temp buffer table. See the comment on
+ // XlaCompiledCpuFunction::StaticData::temp_sizes.
+ if (allocation.is_entry_computation_parameter()) {
+ buffer_sizes.push_back(-allocation.parameter_number() - 2);
+ continue;
+ }
+
buffer_sizes.push_back(allocation.size());
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 81e17a5cd4..946f5124b8 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -69,12 +69,19 @@ CpuExecutable::CpuExecutable(
// guarded by the mutex.
compute_function_ =
reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress()));
+ VLOG(1) << "compute_function_ at address "
+ << reinterpret_cast<void*>(compute_function_);
}
-Status CpuExecutable::AllocateBuffers(
+StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
+ std::vector<OwningDeviceMemory>>>
+CpuExecutable::CreateTempArray(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- std::vector<OwningDeviceMemory>* buffers) {
- CHECK_EQ(buffers->size(), assignment_->Allocations().size());
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ std::vector<se::DeviceMemoryBase> unowning_buffers(
+ assignment_->Allocations().size());
+ std::vector<OwningDeviceMemory> owning_buffers(
+ assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
@@ -84,6 +91,8 @@ Status CpuExecutable::AllocateBuffers(
VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) {
+ unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
+ allocation.param_shape_index());
VLOG(3) << "allocation #" << i << " is a parameter";
continue;
}
@@ -99,34 +108,34 @@ Status CpuExecutable::AllocateBuffers(
}
int64 buffer_size = allocation.size();
- if (!(*buffers)[i].is_null()) {
+ if (!owning_buffers[i].is_null()) {
VLOG(3) << "buffer #" << i
<< " is in the preallocated result ShapedBuffer";
} else {
- TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
- device_ordinal, buffer_size));
+ TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate(
+ device_ordinal, buffer_size));
+ unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase();
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
- << (*buffers)[i].opaque() << "]";
+ << owning_buffers[i].opaque() << "]";
}
// Since the output buffer and all the temporary buffers were written into
// by the JITed code, msan has no way of knowing their memory was
// initialized. Mark them initialized so that msan doesn't flag loads from
// these buffers.
- TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size);
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
VLOG(3) << "result index: " << result_slice.index();
- return Status::OK();
+ return {{std::move(unowning_buffers), std::move(owning_buffers)}};
}
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
@@ -136,17 +145,11 @@ Status CpuExecutable::ExecuteComputeFunction(
//
// result: Points at the result.
// run_options: the ExecutableRunOptions object.
- // args_array: An array of pointers, each of which points to a parameter.
- // The size of this array is determined by the function's arity
- // (ProgramShape).
- // temps_array: An array of pointers, each of which points to a temporary
- // buffer the computation needs. The size of this array is
- // determined by buffer analysis.
+ // args_array: null
+ // temps_array: An array of pointers, containing pointers to temporary buffers
+ // required by the executable adn pointers to entry computation
+ // parameters.
//
- std::vector<const void*> args_array;
- for (const ShapedBuffer* argument : arguments) {
- args_array.push_back(argument->root_buffer().opaque());
- }
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
@@ -169,16 +172,14 @@ Status CpuExecutable::ExecuteComputeFunction(
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
VLOG(3) << tensorflow::strings::Printf(
- " func(void* result, void* params[%zu], void* temps[%zu], "
+ " func(void* result, void* params[null], void* temps[%zu], "
"uint64 profile_counters[%zu])",
- args_array.size(), buffer_pointers.size(), profile_counters_size);
+ buffer_pointers.size(), profile_counters_size);
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
};
- VLOG(3) << tensorflow::strings::Printf(
- " params = [%s]",
- tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str());
+ VLOG(3) << " params = nullptr";
VLOG(3) << tensorflow::strings::Printf(
" temps = [%s]",
tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
@@ -186,8 +187,8 @@ Status CpuExecutable::ExecuteComputeFunction(
profile_counters);
}
- compute_function_(result_buffer, run_options, args_array.data(),
- buffer_pointers.data(), profile_counters);
+ compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(),
+ profile_counters);
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -254,21 +255,18 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
-
- TF_RETURN_IF_ERROR(AllocateBuffers(
- memory_allocator, stream->parent()->device_ordinal(), &buffers));
+ std::vector<OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
- unowning_buffers.reserve(buffers.size());
- for (auto& buffer : buffers) {
- unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
- }
- TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(),
- arguments, unowning_buffers,
- hlo_execution_profile));
+ TF_ASSIGN_OR_RETURN(
+ std::tie(unowning_buffers, owning_buffers),
+ CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
+ arguments));
+
+ TF_RETURN_IF_ERROR(ExecuteComputeFunction(
+ &run_options->run_options(), unowning_buffers, hlo_execution_profile));
- return CreateResultShapedBuffer(run_options, &buffers);
+ return CreateResultShapedBuffer(run_options, &owning_buffers);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@@ -284,17 +282,15 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
- TF_RETURN_IF_ERROR(AllocateBuffers(
- memory_allocator, stream->parent()->device_ordinal(), &buffers));
-
+ std::vector<OwningDeviceMemory> owning_buffers;
std::vector<se::DeviceMemoryBase> unowning_buffers;
- unowning_buffers.reserve(buffers.size());
- for (auto& buffer : buffers) {
- unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
- }
+ TF_ASSIGN_OR_RETURN(
+ std::tie(unowning_buffers, owning_buffers),
+ CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
+ arguments));
+
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
- CreateResultShapedBuffer(run_options, &buffers));
+ CreateResultShapedBuffer(run_options, &owning_buffers));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
@@ -312,7 +308,6 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
struct AsyncRunTask {
CpuExecutable* executable;
ServiceExecutableRunOptions run_options;
- std::vector<const ShapedBuffer*> arguments;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
@@ -320,15 +315,14 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
- &run_options.run_options(), arguments, unowning_buffers,
+ &run_options.run_options(), unowning_buffers,
/*hlo_execution_profile=*/nullptr));
}
};
- host_stream->EnqueueTask(AsyncRunTask{
- this, *run_options,
- std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()),
- unowning_buffers,
- std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))});
+ host_stream->EnqueueTask(
+ AsyncRunTask{this, *run_options, std::move(unowning_buffers),
+ std::make_shared<std::vector<OwningDeviceMemory>>(
+ std::move(owning_buffers))});
return std::move(result);
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 8dd47bfb86..8af8a5dfec 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -85,20 +85,29 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
- // Allocate buffers required for execution and assign them to the elements of
- // "buffers". "buffers" should be sized to the number of buffers in buffer
- // assignment. Each vector element corresponds to a particular Index. If
- // a vector element already contains a non-null DeviceMemoryBase, then no
- // buffer is assigned for this element.
- Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
- int device_ordinal,
- std::vector<OwningDeviceMemory>* buffers);
+ // Creates an array suitable for passing as the "temps" argument to the JIT
+ // compiled function pointer.
+ //
+ // Returns (unowning_buffers, owning_buffers) where:
+ //
+ // - unowning_buffers.data() can be passed as the temps argument as-is and
+ // includes pointers to the scratch storage required by the computation,
+ // the live-out buffer into which the result will be written and entry
+ // computation parameters.
+ //
+ // - owning_buffers contains owning pointers to the buffers that were
+ // allocated by this routine. This routine allocates buffers for temporary
+ // storage and the live-out buffer into which the computation writes it
+ // result.
+ StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
+ std::vector<OwningDeviceMemory>>>
+ CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
Status ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index cf955a8add..c13d36776f 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -19,6 +19,8 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -117,9 +119,8 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
ElementwiseSourceIndex(index, *hlo, i)));
operands.push_back(operand_value);
}
- return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
- hlo->to_apply(), operands,
- llvm_ir::IrName(hlo));
+ return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
+ operands, llvm_ir::IrName(hlo));
};
}
return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a6d8551841..60f9cd1121 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -116,6 +116,19 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
computation->root_instruction()->outer_dimension_partitions().size();
}
+ if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
+ TF_ASSIGN_OR_RETURN(
+ computation_root_allocation_,
+ assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
+ }
+
+ for (const HloInstruction* param : computation->parameter_instructions()) {
+ TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
+ assignment_.GetUniqueTopLevelSlice(param));
+ computation_parameter_allocations_[param_slice.allocation()->index()] =
+ param->parameter_number();
+ }
+
InitializeIrFunction(function_name);
// The rdtscp instruction is x86 specific. We will fallback to LLVM's generic
// readcyclecounter if it is unavailable.
@@ -132,6 +145,8 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
// IR insert point.
compute_function_.reset();
+ computation_root_allocation_ = BufferAllocation::Slice();
+ computation_parameter_allocations_.clear();
return ir_function;
}
@@ -484,23 +499,11 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap(
- HloMapInstruction* map, const llvm_ir::IrArray::Index& index) {
- llvm::Function* mapped_ir_function =
- FindOrDie(emitted_functions_, map->to_apply());
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : map->operands()) {
- const llvm_ir::IrArray& array = GetIrArrayFor(operand);
- parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_));
- }
- return EmitElementFunctionCall(mapped_ir_function, map->shape(),
- parameter_addresses, "map_function");
-}
-
-Status IrEmitter::HandleMap(HloInstruction* map) {
- return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) {
- return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index);
- });
+llvm::Value* IrEmitter::EmitElementalMap(
+ const HloMapInstruction& map_instr,
+ tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
+ tensorflow::StringPiece name) {
+ return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
}
StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
@@ -508,9 +511,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
const llvm_ir::IrArray::Index& index) {
const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
- HloComputation* function = reduce_window->to_apply();
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// We fold inputs into the accumulator and initialize it to
// the initial value on the reduce_window.
@@ -563,11 +563,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
// We are not in the padding, so carry out the computation.
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
- llvm::Value* input_value_address =
- input_array.EmitArrayElementAddress(input_index, &b_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce_window->shape(),
- {accumulator_address, input_value_address}, "reducer_function");
+ llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
+ llvm::Value* result = EmitThreadLocalCall(
+ *reduce_window->to_apply(),
+ {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
b_.CreateStore(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -623,12 +622,6 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
"Dilation for SelectAndScatter is not implemented on CPU. ");
}
- // The select and scatter computations should have been emitted previously.
- llvm::Function* select_function =
- FindOrDie(emitted_functions_, select_and_scatter->select());
- llvm::Function* scatter_function =
- FindOrDie(emitted_functions_, select_and_scatter->scatter());
-
// Pseudo code for select-and-scatter:
//
// initialized_flag is initially off for every window, and is turned on after
@@ -733,11 +726,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
- const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* result = EmitElementFunctionCall(
- select_function, output_shape, {selected_value_address, operand_address},
+ llvm::Value* operand_element = b_.CreateLoad(operand_address);
+ llvm::Value* result = EmitThreadLocalCall(
+ *select_and_scatter->select(),
+ {b_.CreateLoad(selected_value_address), operand_element},
"select_function");
// If the 'select' function returns false, update the selected value and the
@@ -764,14 +758,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
- llvm::Value* source_value_address =
- source_array.EmitArrayElementAddress(source_index, &b_);
+ llvm::Value* source_value =
+ source_array.EmitReadArrayElement(source_index, &b_);
llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
- llvm::Value* output_value_address =
- output_array.EmitArrayElementAddress(selected_index, &b_);
- llvm::Value* scatter_value = EmitElementFunctionCall(
- scatter_function, source->shape(),
- {output_value_address, source_value_address}, "scatter_function");
+ llvm::Value* output_value =
+ output_array.EmitReadArrayElement(selected_index, &b_);
+ llvm::Value* scatter_value =
+ EmitThreadLocalCall(*select_and_scatter->scatter(),
+ {output_value, source_value}, "scatter_function");
output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -1248,46 +1242,7 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
- auto param_number = parameter->parameter_number();
- auto param_shape = parameter->shape();
-
- // We have to access the parameter at offset param_number in the params
- // array. The code generated here is equivalent to this C code:
- //
- // i8* param_address_untyped = params[param_number];
- // Param* param_address_typed = (Param*)param_address_untyped;
- //
- // Where Param is the actual element type of the underlying buffer (for
- // example, float for an XLA F32 element type).
- llvm::Value* params = compute_function_->parameters_arg();
- llvm::Value* param_address_offset =
- llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset);
- param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped")));
- if (is_top_level_computation_ &&
- hlo_module_config_.debug_options()
- .xla_llvm_enable_invariant_load_metadata()) {
- // In the entry computation the parameter slots in the %params argument are
- // invariant through program execution. In computations that are called
- // from the entry computation (via kWhile, kCall and kConditional) the
- // parameter slots are *not* invariant since they're written to by their
- // callers.
- param_address_untyped->setMetadata(
- llvm::LLVMContext::MD_invariant_load,
- llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{}));
- }
-
- llvm::Value* param_address_typed = b_.CreateBitCast(
- param_address_untyped, IrShapeType(param_shape)->getPointerTo());
- emitted_value_[parameter] = param_address_typed;
-
- if (!ShapeUtil::IsOpaque(param_shape)) {
- AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
- AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
- }
-
- VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed);
- return Status::OK();
+ return EmitTargetAddressForOp(parameter);
}
// Returns true if the relative order of the unreduced dimensions stays the same
@@ -1751,9 +1706,6 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
gtl::ArraySlice<int64> dimensions(reduce->dimensions());
- HloComputation* function = reduce->to_apply();
- // The called computation should have been emitted previously.
- llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@@ -1793,10 +1745,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
CHECK(index.end() == it);
// Apply the reduction function to the loaded value.
- llvm::Value* input_address =
- arg_array.EmitArrayElementAddress(input_index, &b_);
- llvm::Value* result = EmitElementFunctionCall(
- reducer_function, reduce->shape(), {accumulator_addr, input_address},
+ llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
+ llvm::Value* result = EmitThreadLocalCall(
+ *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
"reduce_function");
b_.CreateStore(result, accumulator_addr);
@@ -2134,18 +2085,13 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
HloComputation* computation = call->to_apply();
llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
- std::vector<llvm::Value*> parameter_addresses;
- for (const HloInstruction* operand : call->operands()) {
- parameter_addresses.push_back(GetEmittedValueFor(operand));
- }
-
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
- parameter_addresses, &b_, computation->name(),
+ {}, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
@@ -2156,8 +2102,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
call_args, root->shape(), root->outer_dimension_partitions(), &b_,
call_ir_function, computation->name()));
} else {
- EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
- emitted_value_[call], computation->name());
+ EmitGlobalCall(*computation, computation->name());
}
return Status::OK();
@@ -2238,12 +2183,6 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
const HloInstruction* init = xla_while->operand(0);
emitted_value_[xla_while] = GetEmittedValueFor(init);
- // The called computation should have been emitted previously.
- llvm::Function* condition_ir_function =
- FindOrDie(emitted_functions_, condition);
- llvm::Function* body_ir_function =
- FindOrDie(emitted_functions_, xla_while->while_body());
-
// Generating:
// while (Condition(while_result)) {
// // CopyInsertion pass inserts copies which enable 'while_result' to
@@ -2260,12 +2199,10 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
- llvm::Value* while_result = GetEmittedValueFor(xla_while);
- llvm::Value* while_condition = EmitElementFunctionCall(
- condition_ir_function, condition->root_instruction()->shape(),
- {while_result}, IrName(xla_while, "cond"));
+ EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
llvm::Value* while_predicate = b_.CreateICmpNE(
- while_condition,
+ b_.CreateLoad(
+ GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2280,8 +2217,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
b_.SetInsertPoint(body_bb);
// Calls the body function.
- EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
- IrName(xla_while, "body"));
+ EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
+
// Finishes with a branch back to the header.
b_.CreateBr(header_bb);
@@ -2449,8 +2386,6 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto pred = conditional->operand(0);
- auto true_arg = conditional->operand(1);
- auto false_arg = conditional->operand(2);
TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
pred->shape().element_type() == PRED)
<< "Predicate on a Conditional must be bool; got: "
@@ -2472,13 +2407,7 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
<< " and "
<< ShapeUtil::HumanString(false_computation->root_instruction()->shape());
- llvm::Function* true_function =
- FindOrDie(emitted_functions_, true_computation);
- llvm::Function* false_function =
- FindOrDie(emitted_functions_, false_computation);
-
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
- llvm::Value* conditional_result = GetEmittedValueFor(conditional);
// Generating:
// if (pred)
@@ -2495,12 +2424,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
SetToFirstInsertPoint(if_data.true_block, &b_);
- EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)},
- conditional_result, IrName(conditional, "_true"));
+ EmitGlobalCall(*conditional->true_computation(),
+ IrName(conditional, "_true"));
SetToFirstInsertPoint(if_data.false_block, &b_);
- EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)},
- conditional_result, IrName(conditional, "_false"));
+ EmitGlobalCall(*conditional->false_computation(),
+ IrName(conditional, "_false"));
SetToFirstInsertPoint(if_data.after_block, &b_);
return Status::OK();
@@ -2701,44 +2630,76 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
+llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
- llvm::Type* element_type = IrShapeType(target_shape);
- // The alignment and number of bytes within the temporary buffer is determined
- // by the maximal shape as determined by buffer assignment.
- const BufferAllocation& allocation = assignment_.GetAllocation(slice.index());
- if (allocation.is_thread_local()) {
+ const BufferAllocation& allocation = *slice.allocation();
+ llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
+ if (slice == computation_root_allocation_) {
+ llvm::Argument* retval = compute_function_->result_arg();
+ llvm::AttrBuilder attr_builder;
+ attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
+ attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
+ retval->addAttrs(attr_builder);
+ return retval;
+ }
+
+ auto param_it =
+ computation_parameter_allocations_.find(slice.allocation()->index());
+ if (param_it != computation_parameter_allocations_.end()) {
+ int64 param_number = param_it->second;
+ // We have to access the parameter at offset param_number in the params
+ // array. The code generated here is equivalent to this C code:
+ //
+ // i8* param_address_untyped = params[param_number];
+ // Param* param_address_typed = (Param*)param_address_untyped;
+ //
+ // Where Param is the actual element type of the underlying buffer (for
+ // example, float for an XLA F32 element type).
+ llvm::Value* params = compute_function_->parameters_arg();
+ llvm::Value* param_address_offset =
+ llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
+ llvm::LoadInst* param_address_untyped =
+ b_.CreateLoad(param_address_offset);
+
+ if (!ShapeUtil::IsOpaque(target_shape)) {
+ AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
+ AttachDereferenceableMetadataForLoad(param_address_untyped,
+ target_shape);
+ }
+ return param_address_untyped;
+ }
+
// Thread-local allocations should only be assigned a single buffer.
const auto& assigned_buffers = allocation.assigned_buffers();
CHECK_EQ(1, assigned_buffers.size());
const Shape& shape = assigned_buffers.begin()->first->shape();
- llvm::AllocaInst*& tempbuf_address =
- thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}];
- if (tempbuf_address == nullptr) {
- tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ std::pair<llvm::Function*, BufferAllocation::Slice> key = {
+ compute_function_->function(), slice};
+ auto buf_it = thread_local_buffers_.find(key);
+ if (buf_it == thread_local_buffers_.end()) {
+ llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
IrShapeType(shape),
tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
MinimumAlignmentForShape(target_shape));
+ auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
+ CHECK(it_inserted_pair.second);
+ buf_it = it_inserted_pair.first;
}
- return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo());
- }
-
- if (allocation.is_constant()) {
- return FindOrDie(constant_buffer_to_global_, allocation.index());
- }
+ return buf_it->second;
+ }();
+ return b_.CreateBitCast(tempbuf_address,
+ IrShapeType(target_shape)->getPointerTo());
+}
+llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape) {
+ const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
- if (is_top_level_computation_ &&
- hlo_module_config_.debug_options()
+ if (hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
- // In the entry computation the parameter slots in the %params argument are
- // invariant through program execution. In computations that are called
- // from the entry computation (via kWhile, kCall and kConditional) the
- // parameter slots are *not* invariant since they're written to by their
- // callers.
tempbuf_address_base->setMetadata(
llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
@@ -2753,85 +2714,25 @@ llvm::Value* IrEmitter::EmitTempBufferPointer(
b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
return b_.CreateBitCast(tempbuf_address_untyped,
- element_type->getPointerTo());
-}
-
-// Emits a function call returning a single array element. Allocates space
-// for a single element_type value, and loads it after call.
-llvm::Value* IrEmitter::EmitElementFunctionCall(
- llvm::Function* function, const Shape& return_shape,
- gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name) {
- llvm::Value* return_value_buffer = EmitArrayFunctionCall(
- function, return_shape, 1, parameter_addresses, name);
- return b_.CreateLoad(
- return_value_buffer,
- AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
+ IrShapeType(target_shape)->getPointerTo());
}
-// Emits a core function call based on the following pseudo-code.
-//
-// char** parameter_addresses_buffer =
-// allocate buffer with a pointer for each parameter to the function
-// for each parameter index, i.e. for i = 0, ..., #parameters:
-// parameter_addresses_buffer[i] = parameter_addresses[i]
-// call function(return_value_buffer,
-// parameter_addresses_buffer,
-// temps)
-// return return_value_buffer -- address of the return value.
-void IrEmitter::EmitArrayFunctionCallInto(
- llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
- b_.CreateCall(function,
- GetArrayFunctionCallArguments(
- parameter_addresses, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
-}
-
-llvm::Value* IrEmitter::EmitArrayFunctionCall(
- llvm::Function* function, const Shape& return_shape, int64 element_count,
- gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name) {
- llvm::Value* elements =
- llvm::ConstantInt::get(b_.getInt64Ty(), element_count);
- PrimitiveType return_type = return_shape.element_type();
- llvm::Value* return_value_buffer =
- llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
- tensorflow::strings::StrCat(name, "_return_value_address"), &b_,
- MinimumAlignmentForPrimitiveType(return_type));
- EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
- name);
- return return_value_buffer;
+llvm::Value* IrEmitter::EmitTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape) {
+ if (slice.allocation()->is_thread_local()) {
+ return EmitThreadLocalTempBufferPointer(slice, target_shape);
+ } else if (slice.allocation()->is_constant()) {
+ return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
+ } else {
+ return EmitGlobalTempBufferPointer(slice, target_shape);
+ }
}
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
- llvm::Value* addr;
const Shape& target_shape = op->shape();
- if (op == op->parent()->root_instruction()) {
- // For the root node, we write directly to the output buffer of the
- // function.
- llvm::Argument* retval = compute_function_->result_arg();
- if ((ShapeUtil::IsArray(target_shape) &&
- !ShapeUtil::IsZeroElementArray(target_shape)) ||
- (ShapeUtil::IsTuple(target_shape) &&
- !ShapeUtil::IsEmptyTuple(target_shape))) {
- llvm::AttrBuilder attr_builder;
- attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
- attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
- retval->addAttrs(attr_builder);
- }
- addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo());
- } else {
- // For other nodes, we need the temporary buffer allocated for this node to
- // write the result into.
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- assignment_.GetUniqueTopLevelSlice(op));
- addr = EmitTempBufferPointer(slice, target_shape);
- }
+ TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+ assignment_.GetUniqueTopLevelSlice(op));
+ llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2936,20 +2837,69 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
-StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
- PrimitiveType return_type, HloComputation* computation,
- const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
- llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
- std::vector<llvm::Value*> argument_addrs;
- for (auto argument : arguments) {
- llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- argument->getType(), "arg_addr", &b_);
- b_.CreateStore(argument, argument_addr);
- argument_addrs.push_back(argument_addr);
+llvm::Value* IrEmitter::EmitThreadLocalCall(
+ const HloComputation& callee,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
+ tensorflow::StringPiece name) {
+ const Shape& return_shape = callee.root_instruction()->shape();
+
+ // Lifting this restriction to allow "small" arrays should be easy. Allowing
+ // larger arrays is difficult because we allocate the buffer for this return
+ // value on the stack.
+ CHECK(ShapeUtil::IsScalar(return_shape));
+
+ PrimitiveType return_type = return_shape.element_type();
+
+ std::vector<llvm::Value*> parameter_addrs;
+ for (llvm::Value* parameter : parameters) {
+ CHECK(!parameter->getType()->isPointerTy());
+ llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ parameter->getType(), "arg_addr", &b_);
+ b_.CreateStore(parameter, parameter_addr);
+ parameter_addrs.push_back(parameter_addr);
}
- return EmitElementFunctionCall(llvm_function,
- ShapeUtil::MakeShape(return_type, {}),
- argument_addrs, name);
+
+ llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(return_type, module_),
+ tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
+ MinimumAlignmentForPrimitiveType(return_type));
+
+ b_.CreateCall(
+ FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ parameter_addrs, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+
+ return b_.CreateLoad(return_value_buffer);
}
+
+void IrEmitter::EmitGlobalCall(const HloComputation& callee,
+ tensorflow::StringPiece name) {
+ b_.CreateCall(FindOrDie(emitted_functions_, &callee),
+ GetArrayFunctionCallArguments(
+ /*parameter_addresses=*/{}, &b_, name,
+ /*return_value_buffer=*/
+ llvm::Constant::getNullValue(b_.getInt8PtrTy()),
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+}
+
+llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
+ const HloComputation& callee) {
+ const HloInstruction* root_inst = callee.root_instruction();
+ if (root_inst->opcode() == HloOpcode::kOutfeed) {
+ return llvm::Constant::getNullValue(b_.getInt8PtrTy());
+ }
+
+ const BufferAllocation::Slice root_buffer =
+ assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
+ return EmitTempBufferPointer(root_buffer, root_inst->shape());
+}
+
} // namespace cpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 03bbb2afb5..372017441f 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -100,14 +100,15 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
- // Emits a call to `computation` with scalar arguments `arguments`.
- StatusOr<llvm::Value*> EmitScalarCall(
- PrimitiveType return_type, HloComputation* computation,
- const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
-
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
+ // Emit code to map one element according to `map_instr`.
+ llvm::Value* EmitElementalMap(
+ const HloMapInstruction& map_instr,
+ tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
+ tensorflow::StringPiece name);
+
protected:
//
// The following methods implement the DfsHloVisitor interface.
@@ -143,7 +144,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override;
- Status HandleMap(HloInstruction* map) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
@@ -218,9 +218,18 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// computation function being emitted by this emitter.
llvm::Value* GetTempBuffersArgument();
- // Emits code that computes the address of the given temporary buffer to the
- // function. target_shape is the shape of this temporary buffer.
- // The returned Value's type is a pointer to element_type.
+ // Helper for EmitTempBufferPointer.
+ llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
+ const Shape& target_shape);
+
+ // Helper for EmitTempBufferPointer.
+ llvm::Value* EmitThreadLocalTempBufferPointer(
+ const BufferAllocation::Slice& slice, const Shape& target_shape);
+
+ // Emits code that computes the address of the given buffer allocation slice.
+ //
+ // TODO(sanjoy): This should be renamed to reflect that it no longer provides
+ // access to just temporaries.
llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
const Shape& target_shape);
@@ -232,44 +241,27 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::StringPiece
function_name_suffix); // Used for LLVM IR register names.
- // Methods that emit a function call.
- // Parameters:
- // function - The LLVM function to call.
- // return_shape - The return shape of the HLO computation that was used to
- // make the function. Not the same as the return type of the function
- // in LLVM, since we use output parameters for the return type.
- // element_count - number of elements to return (array form only).
- // parameter_addresses - pointers to be passed to the function as
- // parameters.
- // name - used for LLVM IR register names.
-
- // Emits a function call, returning a scalar, often an element of a larger
- // array. Returns a Value for the scalar element returned by the function.
- llvm::Value* EmitElementFunctionCall(
- llvm::Function* function, const Shape& return_shape,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ // Emits a call to a thread local function (e.g. to the computation nested
+ // within a reduce or a map). Thread local callees (by definition) only write
+ // to and read from thread local allocations.
+ //
+ // `parameters` holds the *scalar values* that need to be passed to the
+ // callee. The return value is the scalar returned by the callee.
+ llvm::Value* EmitThreadLocalCall(
+ const HloComputation& callee,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
tensorflow::StringPiece name);
- // Array function call emitter. Stores the function's result into a supplied
- // buffer.
- // Parameters:
- // function - The LLVM function to call.
- // parameter_addresses - pointers to be passed to the function as
- // parameters.
- // return_value - pointer to a buffer where the call result is stored.
-
- void EmitArrayFunctionCallInto(
- llvm::Function* function,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- llvm::Value* return_value_buffer, tensorflow::StringPiece name);
-
- // Array function call emitter. Returns a Value for the function's return
- // value buffer address. The return value buffer is alloca'ed by this
- // function.
- llvm::Value* EmitArrayFunctionCall(
- llvm::Function* function, const Shape& return_shape, int64 element_count,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
- tensorflow::StringPiece name);
+ // Emits a call to a "global" function (e.g. to the computation nested within
+ // a kWhile or a kCall). Buffer assignment unabiguously assignes buffers to
+ // the parameters and return values for these computations so there is no need
+ // to explicitly pass parameters or return results.
+ void EmitGlobalCall(const HloComputation& callee,
+ tensorflow::StringPiece name);
+
+ // Returns the buffer to which a global call to `callee` would have written
+ // its result.
+ llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
@@ -408,11 +400,10 @@ class IrEmitter : public DfsHloVisitorWithDefault {
NameUniquer name_uniquer_;
// Map containing all previously emitted computations.
- std::map<HloComputation*, llvm::Function*> emitted_functions_;
+ std::map<const HloComputation*, llvm::Function*> emitted_functions_;
// Map containing all previously emitted thread-local temporary buffers.
- std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
- llvm::AllocaInst*>
+ std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
thread_local_buffers_;
// The following fields track the IR emission state. According to LLVM memory
@@ -422,6 +413,16 @@ class IrEmitter : public DfsHloVisitorWithDefault {
std::unique_ptr<IrFunction> compute_function_;
llvm::IRBuilder<> b_;
+ // The buffer allocation slice for the root of the computation being compiled.
+ // Only relevant for thread local computations.
+ BufferAllocation::Slice computation_root_allocation_;
+
+ // Maps the buffer allocation slices for the parameters to the computation
+ // being compiled to their parameter numbers. Only relevant for thread local
+ // computations.
+ tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
+ computation_parameter_allocations_;
+
// Maps HLO instructions to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, int64>
instruction_to_profile_idx_;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 6aff838462..2db4d000f5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -80,9 +80,16 @@ void IrFunction::Initialize(const string& function_name,
// void function(i8* retval, i8* run_options, i8** params, i8** temps,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
- // retval: points to the returned value.
- // params: address of an array with pointers to parameters.
- // temps: address of an array with pointers to temporary buffers.
+ // For thread local functions:
+ // retval: points to the returned value.
+ // params: address of an array with pointers to parameters.
+ // temps: is null
+ //
+ // For global functions:
+ // retval: is null
+ // params: is null
+ // temps: address of an array with pointers to temporary buffers and entry
+ // computation parameters.
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
@@ -196,18 +203,25 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
llvm::IRBuilder<>* b, tensorflow::StringPiece name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
- llvm::Value* parameter_addresses_buffer =
- llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
- for (size_t i = 0; i < parameter_addresses.size(); ++i) {
- llvm::Value* parameter_as_i8ptr =
- b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(
- name, "_parameter_", i, "_address_as_i8ptr")));
- llvm::Value* slot_in_param_addresses =
- b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
- b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+ llvm::Value* parameter_addresses_buffer;
+
+ if (parameter_addresses.empty()) {
+ parameter_addresses_buffer =
+ llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
+ } else {
+ parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
+ tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
+
+ for (size_t i = 0; i < parameter_addresses.size(); ++i) {
+ llvm::Value* parameter_as_i8ptr =
+ b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
+ AsStringRef(tensorflow::strings::StrCat(
+ name, "_parameter_", i, "_address_as_i8ptr")));
+ llvm::Value* slot_in_param_addresses =
+ b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
+ b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+ }
}
const auto to_int8_ptr = [=](llvm::Value* ptr) {
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
index 7b1f62541d..a5f34908d7 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
@@ -66,6 +66,7 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
VLOG(2) << "ParallelForkJoin ENTRY"
<< " num_partitions: " << num_partitions
<< " num_partitioned_dims: " << num_partitioned_dims;
+ CHECK_EQ(params, nullptr);
CHECK_GT(num_partitions, 1);
CHECK_GT(num_partitioned_dims, 0);
const xla::ExecutableRunOptions* run_options =
@@ -80,9 +81,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ParallelForkJoin(
for (int32 i = 1; i < num_partitions; ++i) {
const int64 offset = i * stride;
run_options->intra_op_thread_pool()->enqueueNoNotification(
- [i, function, result_ptr, run_options_ptr, params, temps, prof_counters,
+ [i, function, result_ptr, run_options_ptr, temps, prof_counters,
partitions, offset, &bc]() {
- function(result_ptr, run_options_ptr, params, temps,
+ function(result_ptr, run_options_ptr, nullptr, temps,
&partitions[offset], prof_counters);
bc.DecrementCount();
VLOG(3) << "ParallelForkJoin partition " << i << " done.";
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
index 941d940684..fe5ec1cc66 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -56,12 +56,12 @@ ENTRY while3 {
)";
CompileAndVerifyIr(hlo_string, R"(
-; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval
+; CHECK-LABEL: @body(i8* %retval
; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
;
-; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params
-; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0
+; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params
+; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0
; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]]
; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float*
; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]]
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
index 47cab79604..115448c908 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -42,13 +42,12 @@ extern "C" void SumStructElements(float* out, void** parameters) {
TEST_F(LocalClientAotTest, Constant) {
xla::ExecutableRunOptions run_options;
OpaqueData opaque_data{100, 20, 3};
- void* parameters[] = {&opaque_data};
float out = 0;
- void* temporary_buffers[] = {nullptr, &out};
- SumAndDouble(&out, &run_options, parameters, temporary_buffers);
+ void* temporary_buffers[] = {&opaque_data, &out};
+ SumAndDouble(&out, &run_options, nullptr, temporary_buffers);
EXPECT_EQ(out, 246.0f);
opaque_data = {1, 2, 3};
- SumAndDouble(&out, &run_options, parameters, temporary_buffers);
+ SumAndDouble(&out, &run_options, nullptr, temporary_buffers);
EXPECT_EQ(out, 12.0f);
}
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 74494e60e8..e310966d8b 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -93,7 +93,7 @@ int main(int argc, char** argv) {
// local_client_aot_test.cc to be able to easily invoke the function.
CHECK_EQ(result->result_buffer_index(), 1);
CHECK_EQ(result->buffer_sizes().size(), 3);
- CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer
+ CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer
CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer
if (triple.isOSBinFormatELF()) {
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index c81c27891c..1bdf1867b9 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -1236,6 +1236,35 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
{param_value.get()}, ErrorSpec(4e-5));
}
+TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
+ auto while_shape = ShapeUtil::MakeShape(S32, {});
+
+ XlaComputation condition;
+ {
+ XlaBuilder builder("condition");
+ Parameter(&builder, 0, while_shape, "state");
+ Infeed(&builder, ShapeUtil::MakeShape(PRED, {}));
+ TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
+ }
+
+ XlaComputation body;
+ {
+ XlaBuilder builder("body");
+ auto indvar = Parameter(&builder, 0, while_shape, "state");
+ Add(indvar, ConstantR0<int32>(&builder, 1));
+ TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
+ }
+
+ XlaBuilder builder(TestName());
+ While(condition, body, ConstantR0<int32>(&builder, 0));
+
+ TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
+ TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
+
+ ComputeAndCompareR0<int32>(&builder, 2, {});
+}
+
void BM_WhileLoop(int num_iters) {
// Benchmark a simple kernel to measure while loop overheads.
tensorflow::testing::StopTiming();