author    Sanjoy Das <sanjoy@google.com>    2018-07-31 17:21:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-07-31 17:26:19 -0700
commit 182a00ee781017932443bacb475af7acc4a56d5a (patch)
tree   ec24c5cc7e424642bf6772b4d46f7a08ac31dc88
parent 64f191cdc0121bbcb322c3b11b160d638c2f4af9 (diff)
Automated rollback of commit fba2d773f45f10882aa475ac75cbf9884995d626
PiperOrigin-RevId: 206855848
-rw-r--r--  tensorflow/compiler/aot/runtime.cc                                4
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc          22
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h           29
-rw-r--r--  tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc      12
-rw-r--r--  tensorflow/compiler/xla/service/cpu/compiler_functor.cc          21
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc              23
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.cc           104
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.h             27
-rw-r--r--  tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc       7
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc               410
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h                 97
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.cc               44
-rw-r--r--  tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc          5
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc    6
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_aot_test.cc            7
-rw-r--r--  tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc     2
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc                      29
17 files changed, 389 insertions(+), 460 deletions(-)
diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc
index 475eebaa35..5e74079fc1 100644
--- a/tensorflow/compiler/aot/runtime.cc
+++ b/tensorflow/compiler/aot/runtime.cc
@@ -85,9 +85,7 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
}
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
- if (sizes[i] < 0) {
- // bufs[i] is either a constant, an entry parameter or a thread local
- // allocation.
+ if (sizes[i] == -1) {
bufs[i] = nullptr;
} else {
bufs[i] = reinterpret_cast<void*>(pos);
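
The restored check above collapses every "caller does not allocate" case into the single sentinel -1. A minimal standalone sketch of this convention follows (hypothetical helper; it omits the alignment handling the real MallocContiguousBuffers performs):

    #include <cstdint>
    #include <cstdlib>

    // sizes[i] == -1: the slot is filled elsewhere (entry parameter, constant,
    // or thread-local allocation), so leave it null. Any other value is a byte
    // count carved out of one contiguous allocation.
    void* AssignContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs) {
      size_t total = 0;
      for (size_t i = 0; i < n; ++i) {
        if (sizes[i] != -1) total += static_cast<size_t>(sizes[i]);
      }
      char* contiguous = static_cast<char*>(std::malloc(total));
      char* pos = contiguous;
      for (size_t i = 0; i < n; ++i) {
        if (sizes[i] == -1) {
          bufs[i] = nullptr;
        } else {
          bufs[i] = pos;
          pos += sizes[i];
        }
      }
      return contiguous;  // the caller frees this single pointer later
    }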
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index ed5aa08c6f..672e19bd93 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -26,8 +26,6 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
result_index_(static_data.result_index),
args_(new void*[static_data.num_args]),
temps_(new void*[static_data.num_temps]),
- arg_index_to_temp_index_(new int32[static_data.num_args]),
- num_args_(static_data.num_args),
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
program_shape_(static_data.program_shape),
@@ -42,13 +40,6 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
static_data.temp_sizes, static_data.num_temps, temps_,
/*annotate_initialized=*/true);
- for (int i = 0; i < static_data.num_temps; i++) {
- if (static_data.temp_sizes[i] < -1) {
- int32 param_number = -(static_data.temp_sizes[i] + 2);
- arg_index_to_temp_index_[param_number] = i;
- }
- }
-
// If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function
@@ -59,24 +50,11 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
}
}
-bool XlaCompiledCpuFunction::Run() {
- // Propagate pointers to the argument buffers into the temps array. Code
- // generated by XLA discovers the incoming argument pointers from the temps
- // array.
- for (int32 i = 0; i < num_args_; i++) {
- temps_[arg_index_to_temp_index_[i]] = args_[i];
- }
- raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
- profile_counters_);
- return true;
-}
-
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
delete[] args_;
delete[] temps_;
- delete[] arg_index_to_temp_index_;
delete[] profile_counters_;
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 27cfb354bf..48a8c083ca 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -60,19 +60,9 @@ class XlaCompiledCpuFunction {
// The raw function to call.
RawFunction raw_function;
- // Cardinality and size of arg buffers.
+ // Cardinality and sizes of arg and temp buffers.
const intptr_t* arg_sizes = nullptr;
size_t num_args = 0;
-
- // Cardinality and size of temp buffers.
- //
- // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer.
- //
- // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The
- // corresponding entry in the temp buffer array needs to be set to null.
- //
- // If temp_sizes[i] < -1 then the i'th temp is the entry parameter
- // -(temp_sizes[i] + 2).
const intptr_t* temp_sizes = nullptr;
size_t num_temps = 0;
@@ -123,7 +113,11 @@ class XlaCompiledCpuFunction {
// Runs the computation, with inputs read from arg buffers, and outputs
// written to result buffers. Returns true on success and false on failure.
- bool Run();
+ bool Run() {
+ raw_function_(temps_[result_index_], &run_options_,
+ const_cast<const void**>(args_), temps_, profile_counters_);
+ return true;
+ }
// Returns the error message from the previous failed Run call.
//
@@ -230,17 +224,6 @@ class XlaCompiledCpuFunction {
void** args_ = nullptr;
void** temps_ = nullptr;
- // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for
- // XLA generated code to be able to find it.
- //
- // For now we need to keep around the args_ array because there is code that
- // depends on args() returning a void**. However, in the future we may remove
- // args_ in favor of using temps_ as the sole storage for the arguments.
- int32* arg_index_to_temp_index_;
-
- // The number of incoming arguments.
- int32 num_args_;
-
// Backing memory for individual arg and temp buffers.
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;
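
Taken together, these header changes restore a direct calling convention: Run() passes args_ as its own array rather than splicing argument pointers into the temps table. A hedged sketch, with the RawFunction signature assumed from the surrounding diff and all other names illustrative:

    #include <cstddef>
    #include <cstdint>

    namespace xla { class ExecutableRunOptions; }

    // Assumed signature of the code generated by XLA's CPU backend after this
    // rollback: arguments travel in a dedicated array.
    using RawFunction = void (*)(void* result,
                                 const xla::ExecutableRunOptions* run_options,
                                 const void** args, void** temps,
                                 int64_t* profile_counters);

    // Mirrors the inlined Run() above: the result slot comes from the temps
    // table, and the argument array is passed through untouched.
    bool RunSketch(RawFunction fn, void** temps, size_t result_index,
                   const void** args,
                   const xla::ExecutableRunOptions* run_options,
                   int64_t* profile_counters) {
      fn(temps[result_index], run_options, args, temps, profile_counters);
      return true;
    }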
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 114a9241bd..00ccfb1c78 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -58,15 +58,11 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
std::vector<intptr_t> temp_sizes;
temp_sizes.reserve(allocations.size());
for (const xla::BufferAllocation& allocation : allocations) {
- if (allocation.is_constant() || allocation.is_thread_local()) {
- // Constants are lowered to globals. Thread locals are lowered to
- // allocas.
+ // Callers don't allocate temporary buffers for parameters. Nor for
+ // thread-local buffers, which are lowered to alloca.
+ if (allocation.is_entry_computation_parameter() ||
+ allocation.is_thread_local()) {
temp_sizes.push_back(-1);
- } else if (allocation.is_entry_computation_parameter()) {
- // Entry computation parameters need some preprocessing in
- // XlaCompiledCpuFunction::Run. See the comment on
- // XlaCompiledCpuFunction::StaticData::temp_sizes.
- temp_sizes.push_back(-allocation.parameter_number() - 2);
} else {
temp_sizes.push_back(allocation.size());
}
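
For contrast, the encoding removed here packed the entry-parameter number into sizes below -1. A small sketch of decoding that retired convention (a hypothetical helper, shown only to document what -parameter_number - 2 meant):

    #include <cstdint>

    // Decode of the pre-rollback temp_sizes encoding deleted above:
    //   size >= 0  : regular temporary buffer of `size` bytes
    //   size == -1 : constant or thread-local buffer; the slot stays null
    //   size <  -1 : entry parameter number -(size + 2)
    struct OldTempSlot {
      enum { kTemp, kNullSlot, kParameter } kind;
      int32_t parameter_number;  // meaningful only for kParameter
    };

    OldTempSlot DecodeOldTempSize(intptr_t size) {
      if (size >= 0) return {OldTempSlot::kTemp, -1};
      if (size == -1) return {OldTempSlot::kNullSlot, -1};
      return {OldTempSlot::kParameter, static_cast<int32_t>(-(size + 2))};
    }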
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
index 128eea4828..6a7eb85e3b 100644
--- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -156,26 +156,9 @@ std::unique_ptr<llvm::MemoryBuffer> CompilerFunctor::operator()(
target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
codegen_passes.run(module);
- std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
+ // Wrap the machine code in an llvm::MemoryBuffer and return it.
+ return std::unique_ptr<llvm::MemoryBuffer>(
new llvm::SmallVectorMemoryBuffer(std::move(stream_buffer)));
-
- if (VLOG_IS_ON(2)) {
- llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>> obj_file =
- llvm::object::ObjectFile::createObjectFile(*memory_buffer);
- if (obj_file) {
- StatusOr<DisassemblerResult> disasm_result =
- disassembler_->DisassembleObjectFile(*obj_file.get());
- if (disasm_result.ok()) {
- XLA_VLOG_LINES(2, disasm_result.ValueOrDie().text);
- } else {
- LOG(WARNING) << "Could not disassemble object file!";
- }
- } else {
- LOG(WARNING) << "Could not convert memory buffer to object file!";
- }
- }
-
- return memory_buffer;
}
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 8cbe9a1b0d..b49ea89896 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -840,29 +840,18 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
BufferSizes buffer_sizes;
for (const BufferAllocation& allocation : assignment->Allocations()) {
- // Callers don't need to allocate anything for thread-local temporary
- // buffers. They are lowered to allocas.
- if (allocation.is_thread_local()) {
+ // Callers don't need to allocate temporary buffers for parameters.
+ if (allocation.is_entry_computation_parameter() ||
+ allocation.is_constant()) {
buffer_sizes.push_back(-1);
continue;
}
-
- // Callers don't need to allocate anything for constant buffers. They are
- // lowered to globals.
- if (allocation.is_constant()) {
+ // Callers don't need to allocate anything for thread-local temporary
+ // buffers. They are lowered to allocas.
+ if (allocation.is_thread_local()) {
buffer_sizes.push_back(-1);
continue;
}
-
- // Callers don't need to allocate anything for entry computation buffers,
- // but they do need to stash the pointer to the entry computation buffer
- // in the temp buffer table. See the comment on
- // XlaCompiledCpuFunction::StaticData::temp_sizes.
- if (allocation.is_entry_computation_parameter()) {
- buffer_sizes.push_back(-allocation.parameter_number() - 2);
- continue;
- }
-
buffer_sizes.push_back(allocation.size());
}
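
With this change the AOT size table reduces to one rule. A minimal sketch matching the loop above (a hypothetical free function over xla::BufferAllocation; note that the JIT path in xla_jit_compiled_cpu_function.cc differs in that constants keep their real size):

    #include <cstdint>
    #include "tensorflow/compiler/xla/service/buffer_assignment.h"

    // Post-rollback AOT rule: everything the caller must not allocate (entry
    // parameters, constants, thread-local buffers) reports -1; everything
    // else reports its byte size.
    intptr_t BufferSizeForCaller(const xla::BufferAllocation& allocation) {
      if (allocation.is_entry_computation_parameter() ||
          allocation.is_constant() || allocation.is_thread_local()) {
        return -1;
      }
      return allocation.size();
    }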
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 946f5124b8..81e17a5cd4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -69,19 +69,12 @@ CpuExecutable::CpuExecutable(
// guarded by the mutex.
compute_function_ =
reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress()));
- VLOG(1) << "compute_function_ at address "
- << reinterpret_cast<void*>(compute_function_);
}
-StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
- std::vector<OwningDeviceMemory>>>
-CpuExecutable::CreateTempArray(
+Status CpuExecutable::AllocateBuffers(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
- std::vector<se::DeviceMemoryBase> unowning_buffers(
- assignment_->Allocations().size());
- std::vector<OwningDeviceMemory> owning_buffers(
- assignment_->Allocations().size());
+ std::vector<OwningDeviceMemory>* buffers) {
+ CHECK_EQ(buffers->size(), assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
@@ -91,8 +84,6 @@ CpuExecutable::CreateTempArray(
VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) {
- unowning_buffers[i] = arguments[allocation.parameter_number()]->buffer(
- allocation.param_shape_index());
VLOG(3) << "allocation #" << i << " is a parameter";
continue;
}
@@ -108,34 +99,34 @@ CpuExecutable::CreateTempArray(
}
int64 buffer_size = allocation.size();
- if (!owning_buffers[i].is_null()) {
+ if (!(*buffers)[i].is_null()) {
VLOG(3) << "buffer #" << i
<< " is in the preallocated result ShapedBuffer";
} else {
- TF_ASSIGN_OR_RETURN(owning_buffers[i], memory_allocator->Allocate(
- device_ordinal, buffer_size));
- unowning_buffers[i] = owning_buffers[i].AsDeviceMemoryBase();
+ TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
+ device_ordinal, buffer_size));
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
- << owning_buffers[i].opaque() << "]";
+ << (*buffers)[i].opaque() << "]";
}
// Since the output buffer and all the temporary buffers were written into
// by the JITed code, msan has no way of knowing their memory was
// initialized. Mark them initialized so that msan doesn't flag loads from
// these buffers.
- TF_ANNOTATE_MEMORY_IS_INITIALIZED(owning_buffers[i].opaque(), buffer_size);
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
VLOG(3) << "result index: " << result_slice.index();
- return {{std::move(unowning_buffers), std::move(owning_buffers)}};
+ return Status::OK();
}
Status CpuExecutable::ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// The calling convention for JITed functions is:
@@ -145,11 +136,17 @@ Status CpuExecutable::ExecuteComputeFunction(
//
// result: Points at the result.
// run_options: the ExecutableRunOptions object.
- // args_array: null
- // temps_array: An array of pointers, containing pointers to temporary buffers
- // required by the executable and pointers to entry computation
- // parameters.
+ // args_array: An array of pointers, each of which points to a parameter.
+ // The size of this array is determined by the function's arity
+ // (ProgramShape).
+ // temps_array: An array of pointers, each of which points to a temporary
+ // buffer the computation needs. The size of this array is
+ // determined by buffer analysis.
//
+ std::vector<const void*> args_array;
+ for (const ShapedBuffer* argument : arguments) {
+ args_array.push_back(argument->root_buffer().opaque());
+ }
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
@@ -172,14 +169,16 @@ Status CpuExecutable::ExecuteComputeFunction(
if (VLOG_IS_ON(3)) {
VLOG(3) << "Executing compute function:";
VLOG(3) << tensorflow::strings::Printf(
- " func(void* result, void* params[null], void* temps[%zu], "
+ " func(void* result, void* params[%zu], void* temps[%zu], "
"uint64 profile_counters[%zu])",
- buffer_pointers.size(), profile_counters_size);
+ args_array.size(), buffer_pointers.size(), profile_counters_size);
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
};
- VLOG(3) << " params = nullptr";
+ VLOG(3) << tensorflow::strings::Printf(
+ " params = [%s]",
+ tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str());
VLOG(3) << tensorflow::strings::Printf(
" temps = [%s]",
tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
@@ -187,8 +186,8 @@ Status CpuExecutable::ExecuteComputeFunction(
profile_counters);
}
- compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(),
- profile_counters);
+ compute_function_(result_buffer, run_options, args_array.data(),
+ buffer_pointers.data(), profile_counters);
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -255,18 +254,21 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
- std::vector<OwningDeviceMemory> owning_buffers;
- std::vector<se::DeviceMemoryBase> unowning_buffers;
- TF_ASSIGN_OR_RETURN(
- std::tie(unowning_buffers, owning_buffers),
- CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
- arguments));
+ TF_RETURN_IF_ERROR(AllocateBuffers(
+ memory_allocator, stream->parent()->device_ordinal(), &buffers));
- TF_RETURN_IF_ERROR(ExecuteComputeFunction(
- &run_options->run_options(), unowning_buffers, hlo_execution_profile));
+ std::vector<se::DeviceMemoryBase> unowning_buffers;
+ unowning_buffers.reserve(buffers.size());
+ for (auto& buffer : buffers) {
+ unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
+ }
+ TF_RETURN_IF_ERROR(ExecuteComputeFunction(&run_options->run_options(),
+ arguments, unowning_buffers,
+ hlo_execution_profile));
- return CreateResultShapedBuffer(run_options, &owning_buffers);
+ return CreateResultShapedBuffer(run_options, &buffers);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@@ -282,15 +284,17 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
run_options->stream()->implementation());
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<OwningDeviceMemory> owning_buffers;
- std::vector<se::DeviceMemoryBase> unowning_buffers;
- TF_ASSIGN_OR_RETURN(
- std::tie(unowning_buffers, owning_buffers),
- CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
- arguments));
+ std::vector<OwningDeviceMemory> buffers(assignment_->Allocations().size());
+ TF_RETURN_IF_ERROR(AllocateBuffers(
+ memory_allocator, stream->parent()->device_ordinal(), &buffers));
+ std::vector<se::DeviceMemoryBase> unowning_buffers;
+ unowning_buffers.reserve(buffers.size());
+ for (auto& buffer : buffers) {
+ unowning_buffers.push_back(buffer.AsDeviceMemoryBase());
+ }
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
- CreateResultShapedBuffer(run_options, &owning_buffers));
+ CreateResultShapedBuffer(run_options, &buffers));
// At this point, `unowning_buffers` contains unowning pointers to all of our
// buffers, and `buffers` contains owning pointers to the non-live-out
@@ -308,6 +312,7 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
struct AsyncRunTask {
CpuExecutable* executable;
ServiceExecutableRunOptions run_options;
+ std::vector<const ShapedBuffer*> arguments;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
@@ -315,14 +320,15 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
- &run_options.run_options(), unowning_buffers,
+ &run_options.run_options(), arguments, unowning_buffers,
/*hlo_execution_profile=*/nullptr));
}
};
- host_stream->EnqueueTask(
- AsyncRunTask{this, *run_options, std::move(unowning_buffers),
- std::make_shared<std::vector<OwningDeviceMemory>>(
- std::move(owning_buffers))});
+ host_stream->EnqueueTask(AsyncRunTask{
+ this, *run_options,
+ std::vector<const ShapedBuffer*>(arguments.begin(), arguments.end()),
+ unowning_buffers,
+ std::make_shared<std::vector<OwningDeviceMemory>>(std::move(buffers))});
return std::move(result);
}
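
A recurring pattern in both execute paths is the owning/unowning split: owning handles keep the allocations alive (beyond the synchronous call in the async case, via the shared_ptr in AsyncRunTask), while the compute function only ever sees raw views. A toy sketch of that split, with std::unique_ptr standing in for OwningDeviceMemory and a plain struct standing in for se::DeviceMemoryBase:

    #include <memory>
    #include <vector>

    // Stand-in for se::DeviceMemoryBase: a non-owning view of a buffer.
    struct UnowningView { void* opaque = nullptr; };

    std::vector<UnowningView> MakeUnowningViews(
        const std::vector<std::unique_ptr<char[]>>& owning) {
      std::vector<UnowningView> unowning;
      unowning.reserve(owning.size());
      for (const auto& buffer : owning) {
        unowning.push_back(UnowningView{buffer.get()});
      }
      return unowning;  // valid only while `owning` (or its move target) lives
    }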
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 8af8a5dfec..8dd47bfb86 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -85,29 +85,20 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
- // Creates an array suitable for passing as the "temps" argument to the JIT
- // compiled function pointer.
- //
- // Returns (unowning_buffers, owning_buffers) where:
- //
- // - unowning_buffers.data() can be passed as the temps argument as-is and
- // includes pointers to the scratch storage required by the computation,
- // the live-out buffer into which the result will be written and entry
- // computation parameters.
- //
- // - owning_buffers contains owning pointers to the buffers that were
- // allocated by this routine. This routine allocates buffers for temporary
- // storage and the live-out buffer into which the computation writes it
- // result.
- StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
- std::vector<OwningDeviceMemory>>>
- CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ // Allocates the buffers required for execution and assigns them to the
+ // elements of "buffers". "buffers" must be sized to the number of
+ // allocations in the buffer assignment; each vector element corresponds to
+ // a particular BufferAllocation::Index. If a vector element already
+ // contains a non-null DeviceMemoryBase, then no buffer is assigned for
+ // that element.
+ Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
+ int device_ordinal,
+ std::vector<OwningDeviceMemory>* buffers);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
Status ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index c13d36776f..cf955a8add 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -19,8 +19,6 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
-#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
-#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
@@ -119,8 +117,9 @@ llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator(
ElementwiseSourceIndex(index, *hlo, i)));
operands.push_back(operand_value);
}
- return ir_emitter_->EmitElementalMap(*Cast<HloMapInstruction>(hlo),
- operands, llvm_ir::IrName(hlo));
+ return ir_emitter_->EmitScalarCall(hlo->shape().element_type(),
+ hlo->to_apply(), operands,
+ llvm_ir::IrName(hlo));
};
}
return ElementalIrEmitter::MakeElementGenerator(hlo, operand_to_generator);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 60f9cd1121..a6d8551841 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -116,19 +116,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
computation->root_instruction()->outer_dimension_partitions().size();
}
- if (computation->root_instruction()->opcode() != HloOpcode::kOutfeed) {
- TF_ASSIGN_OR_RETURN(
- computation_root_allocation_,
- assignment_.GetUniqueTopLevelSlice(computation->root_instruction()));
- }
-
- for (const HloInstruction* param : computation->parameter_instructions()) {
- TF_ASSIGN_OR_RETURN(BufferAllocation::Slice param_slice,
- assignment_.GetUniqueTopLevelSlice(param));
- computation_parameter_allocations_[param_slice.allocation()->index()] =
- param->parameter_number();
- }
-
InitializeIrFunction(function_name);
// The rdtscp instruction is x86 specific. We will fallback to LLVM's generic
// readcyclecounter if it is unavailable.
@@ -145,8 +132,6 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
// Delete 'compute_function', finalizing 'ir_function' and restoring caller
// IR insert point.
compute_function_.reset();
- computation_root_allocation_ = BufferAllocation::Slice();
- computation_parameter_allocations_.clear();
return ir_function;
}
@@ -499,11 +484,23 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) {
return Status::OK();
}
-llvm::Value* IrEmitter::EmitElementalMap(
- const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name) {
- return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name);
+StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForMap(
+ HloMapInstruction* map, const llvm_ir::IrArray::Index& index) {
+ llvm::Function* mapped_ir_function =
+ FindOrDie(emitted_functions_, map->to_apply());
+ std::vector<llvm::Value*> parameter_addresses;
+ for (const HloInstruction* operand : map->operands()) {
+ const llvm_ir::IrArray& array = GetIrArrayFor(operand);
+ parameter_addresses.push_back(array.EmitArrayElementAddress(index, &b_));
+ }
+ return EmitElementFunctionCall(mapped_ir_function, map->shape(),
+ parameter_addresses, "map_function");
+}
+
+Status IrEmitter::HandleMap(HloInstruction* map) {
+ return EmitTargetElementLoop(map, [&](const llvm_ir::IrArray::Index& index) {
+ return EmitTargetElementLoopBodyForMap(Cast<HloMapInstruction>(map), index);
+ });
}
StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
@@ -511,6 +508,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
const llvm_ir::IrArray::Index& index) {
const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
+ HloComputation* function = reduce_window->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// We fold inputs into the accumulator and initialize it to
// the initial value on the reduce_window.
@@ -563,10 +563,11 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
// We are not in the padding, so carry out the computation.
llvm_ir::IrArray input_array(GetIrArrayFor(operand));
- llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_);
- llvm::Value* result = EmitThreadLocalCall(
- *reduce_window->to_apply(),
- {b_.CreateLoad(accumulator_address), input_value}, "reducer_function");
+ llvm::Value* input_value_address =
+ input_array.EmitArrayElementAddress(input_index, &b_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce_window->shape(),
+ {accumulator_address, input_value_address}, "reducer_function");
b_.CreateStore(result, accumulator_address);
SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -622,6 +623,12 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
"Dilation for SelectAndScatter is not implemented on CPU. ");
}
+ // The select and scatter computations should have been emitted previously.
+ llvm::Function* select_function =
+ FindOrDie(emitted_functions_, select_and_scatter->select());
+ llvm::Function* scatter_function =
+ FindOrDie(emitted_functions_, select_and_scatter->scatter());
+
// Pseudo code for select-and-scatter:
//
// initialized_flag is initially off for every window, and is turned on after
@@ -726,12 +733,11 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &b_);
+ const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &b_);
- llvm::Value* operand_element = b_.CreateLoad(operand_address);
- llvm::Value* result = EmitThreadLocalCall(
- *select_and_scatter->select(),
- {b_.CreateLoad(selected_value_address), operand_element},
+ llvm::Value* result = EmitElementFunctionCall(
+ select_function, output_shape, {selected_value_address, operand_address},
"select_function");
// If the 'select' function returns false, update the selected value and the
@@ -758,14 +764,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
selected_index.push_back(b_.CreateLoad(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayFor(source));
- llvm::Value* source_value =
- source_array.EmitReadArrayElement(source_index, &b_);
+ llvm::Value* source_value_address =
+ source_array.EmitArrayElementAddress(source_index, &b_);
llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter));
- llvm::Value* output_value =
- output_array.EmitReadArrayElement(selected_index, &b_);
- llvm::Value* scatter_value =
- EmitThreadLocalCall(*select_and_scatter->scatter(),
- {output_value, source_value}, "scatter_function");
+ llvm::Value* output_value_address =
+ output_array.EmitArrayElementAddress(selected_index, &b_);
+ llvm::Value* scatter_value = EmitElementFunctionCall(
+ scatter_function, source->shape(),
+ {output_value_address, source_value_address}, "scatter_function");
output_array.EmitWriteArrayElement(selected_index, scatter_value, &b_);
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(), &b_);
@@ -1242,7 +1248,46 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex(
Status IrEmitter::HandleParameter(HloInstruction* parameter) {
VLOG(2) << "HandleParameter: " << parameter->ToString();
- return EmitTargetAddressForOp(parameter);
+ auto param_number = parameter->parameter_number();
+ auto param_shape = parameter->shape();
+
+ // We have to access the parameter at offset param_number in the params
+ // array. The code generated here is equivalent to this C code:
+ //
+ // i8* param_address_untyped = params[param_number];
+ // Param* param_address_typed = (Param*)param_address_untyped;
+ //
+ // Where Param is the actual element type of the underlying buffer (for
+ // example, float for an XLA F32 element type).
+ llvm::Value* params = compute_function_->parameters_arg();
+ llvm::Value* param_address_offset =
+ llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
+ llvm::LoadInst* param_address_untyped = b_.CreateLoad(param_address_offset);
+ param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped")));
+ if (is_top_level_computation_ &&
+ hlo_module_config_.debug_options()
+ .xla_llvm_enable_invariant_load_metadata()) {
+ // In the entry computation the parameter slots in the %params argument are
+ // invariant through program execution. In computations that are called
+ // from the entry computation (via kWhile, kCall and kConditional) the
+ // parameter slots are *not* invariant since they're written to by their
+ // callers.
+ param_address_untyped->setMetadata(
+ llvm::LLVMContext::MD_invariant_load,
+ llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{}));
+ }
+
+ llvm::Value* param_address_typed = b_.CreateBitCast(
+ param_address_untyped, IrShapeType(param_shape)->getPointerTo());
+ emitted_value_[parameter] = param_address_typed;
+
+ if (!ShapeUtil::IsOpaque(param_shape)) {
+ AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
+ AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
+ }
+
+ VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed);
+ return Status::OK();
}
// Returns true if the relative order of the unreduced dimensions stays the same
@@ -1706,6 +1751,9 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
const HloInstruction* arg = reduce->mutable_operand(0);
const HloInstruction* init_value = reduce->mutable_operand(1);
gtl::ArraySlice<int64> dimensions(reduce->dimensions());
+ HloComputation* function = reduce->to_apply();
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
// Initialize an accumulator with init_value.
PrimitiveType accumulator_type = reduce->shape().element_type();
@@ -1745,9 +1793,10 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce(
CHECK(index.end() == it);
// Apply the reduction function to the loaded value.
- llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_);
- llvm::Value* result = EmitThreadLocalCall(
- *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element},
+ llvm::Value* input_address =
+ arg_array.EmitArrayElementAddress(input_index, &b_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce->shape(), {accumulator_addr, input_address},
"reduce_function");
b_.CreateStore(result, accumulator_addr);
@@ -2085,13 +2134,18 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
HloComputation* computation = call->to_apply();
llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
+ std::vector<llvm::Value*> parameter_addresses;
+ for (const HloInstruction* operand : call->operands()) {
+ parameter_addresses.push_back(GetEmittedValueFor(operand));
+ }
+
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
- {}, &b_, computation->name(),
+ parameter_addresses, &b_, computation->name(),
/*return_value_buffer=*/emitted_value_[call],
/*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
/*temp_buffers_arg=*/GetTempBuffersArgument(),
@@ -2102,7 +2156,8 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
call_args, root->shape(), root->outer_dimension_partitions(), &b_,
call_ir_function, computation->name()));
} else {
- EmitGlobalCall(*computation, computation->name());
+ EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
+ emitted_value_[call], computation->name());
}
return Status::OK();
@@ -2183,6 +2238,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
const HloInstruction* init = xla_while->operand(0);
emitted_value_[xla_while] = GetEmittedValueFor(init);
+ // The called computation should have been emitted previously.
+ llvm::Function* condition_ir_function =
+ FindOrDie(emitted_functions_, condition);
+ llvm::Function* body_ir_function =
+ FindOrDie(emitted_functions_, xla_while->while_body());
+
// Generating:
// while (Condition(while_result)) {
// // CopyInsertion pass inserts copies which enable 'while_result' to
@@ -2199,10 +2260,12 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Calls the condition function to determine whether to proceed with the
// body. It must return a bool, so use the scalar call form.
- EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond"));
+ llvm::Value* while_result = GetEmittedValueFor(xla_while);
+ llvm::Value* while_condition = EmitElementFunctionCall(
+ condition_ir_function, condition->root_instruction()->shape(),
+ {while_result}, IrName(xla_while, "cond"));
llvm::Value* while_predicate = b_.CreateICmpNE(
- b_.CreateLoad(
- GetBufferForGlobalCallReturnValue(*xla_while->while_condition())),
+ while_condition,
llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0));
// Branches to the body or to the while exit depending on the condition.
@@ -2217,8 +2280,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
b_.SetInsertPoint(body_bb);
// Calls the body function.
- EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body"));
-
+ EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
+ IrName(xla_while, "body"));
// Finishes with a branch back to the header.
b_.CreateBr(header_bb);
@@ -2386,6 +2449,8 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) {
Status IrEmitter::HandleConditional(HloInstruction* conditional) {
auto pred = conditional->operand(0);
+ auto true_arg = conditional->operand(1);
+ auto false_arg = conditional->operand(2);
TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) &&
pred->shape().element_type() == PRED)
<< "Predicate on a Conditional must be bool; got: "
@@ -2407,7 +2472,13 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
<< " and "
<< ShapeUtil::HumanString(false_computation->root_instruction()->shape());
+ llvm::Function* true_function =
+ FindOrDie(emitted_functions_, true_computation);
+ llvm::Function* false_function =
+ FindOrDie(emitted_functions_, false_computation);
+
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional));
+ llvm::Value* conditional_result = GetEmittedValueFor(conditional);
// Generating:
// if (pred)
@@ -2424,12 +2495,12 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) {
llvm_ir::EmitIfThenElse(pred_cond, "conditional", &b_);
SetToFirstInsertPoint(if_data.true_block, &b_);
- EmitGlobalCall(*conditional->true_computation(),
- IrName(conditional, "_true"));
+ EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)},
+ conditional_result, IrName(conditional, "_true"));
SetToFirstInsertPoint(if_data.false_block, &b_);
- EmitGlobalCall(*conditional->false_computation(),
- IrName(conditional, "_false"));
+ EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)},
+ conditional_result, IrName(conditional, "_false"));
SetToFirstInsertPoint(if_data.after_block, &b_);
return Status::OK();
@@ -2630,76 +2701,44 @@ llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
return compute_function_->exec_run_options_arg();
}
-llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer(
+llvm::Value* IrEmitter::EmitTempBufferPointer(
const BufferAllocation::Slice& slice, const Shape& target_shape) {
- const BufferAllocation& allocation = *slice.allocation();
- llvm::Value* tempbuf_address = [&]() -> llvm::Value* {
- if (slice == computation_root_allocation_) {
- llvm::Argument* retval = compute_function_->result_arg();
- llvm::AttrBuilder attr_builder;
- attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
- attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
- retval->addAttrs(attr_builder);
- return retval;
- }
-
- auto param_it =
- computation_parameter_allocations_.find(slice.allocation()->index());
- if (param_it != computation_parameter_allocations_.end()) {
- int64 param_number = param_it->second;
- // We have to access the parameter at offset param_number in the params
- // array. The code generated here is equivalent to this C code:
- //
- // i8* param_address_untyped = params[param_number];
- // Param* param_address_typed = (Param*)param_address_untyped;
- //
- // Where Param is the actual element type of the underlying buffer (for
- // example, float for an XLA F32 element type).
- llvm::Value* params = compute_function_->parameters_arg();
- llvm::Value* param_address_offset =
- llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_);
- llvm::LoadInst* param_address_untyped =
- b_.CreateLoad(param_address_offset);
-
- if (!ShapeUtil::IsOpaque(target_shape)) {
- AttachAlignmentMetadataForLoad(param_address_untyped, target_shape);
- AttachDereferenceableMetadataForLoad(param_address_untyped,
- target_shape);
- }
- return param_address_untyped;
- }
-
+ llvm::Type* element_type = IrShapeType(target_shape);
+ // The alignment and number of bytes within the temporary buffer are
+ // determined by the maximal shape, as computed by buffer assignment.
+ const BufferAllocation& allocation = assignment_.GetAllocation(slice.index());
+ if (allocation.is_thread_local()) {
// Thread-local allocations should only be assigned a single buffer.
const auto& assigned_buffers = allocation.assigned_buffers();
CHECK_EQ(1, assigned_buffers.size());
const Shape& shape = assigned_buffers.begin()->first->shape();
- std::pair<llvm::Function*, BufferAllocation::Slice> key = {
- compute_function_->function(), slice};
- auto buf_it = thread_local_buffers_.find(key);
- if (buf_it == thread_local_buffers_.end()) {
- llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm::AllocaInst*& tempbuf_address =
+ thread_local_buffers_[{b_.GetInsertBlock()->getParent(), slice}];
+ if (tempbuf_address == nullptr) {
+ tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
IrShapeType(shape),
tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_,
MinimumAlignmentForShape(target_shape));
- auto it_inserted_pair = thread_local_buffers_.insert({key, buffer});
- CHECK(it_inserted_pair.second);
- buf_it = it_inserted_pair.first;
}
- return buf_it->second;
- }();
- return b_.CreateBitCast(tempbuf_address,
- IrShapeType(target_shape)->getPointerTo());
-}
+ return b_.CreateBitCast(tempbuf_address, element_type->getPointerTo());
+ }
+
+ if (allocation.is_constant()) {
+ return FindOrDie(constant_buffer_to_global_, allocation.index());
+ }
-llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
- const BufferAllocation& allocation = *slice.allocation();
llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP(
GetTempBuffersArgument(), slice.index(), &b_);
llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr);
- if (hlo_module_config_.debug_options()
+ if (is_top_level_computation_ &&
+ hlo_module_config_.debug_options()
.xla_llvm_enable_invariant_load_metadata()) {
+ // In the entry computation the parameter slots in the %params argument are
+ // invariant through program execution. In computations that are called
+ // from the entry computation (via kWhile, kCall and kConditional) the
+ // parameter slots are *not* invariant since they're written to by their
+ // callers.
tempbuf_address_base->setMetadata(
llvm::LLVMContext::MD_invariant_load,
llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{}));
@@ -2714,25 +2753,85 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer(
b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset()));
}
return b_.CreateBitCast(tempbuf_address_untyped,
- IrShapeType(target_shape)->getPointerTo());
+ element_type->getPointerTo());
}
-llvm::Value* IrEmitter::EmitTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape) {
- if (slice.allocation()->is_thread_local()) {
- return EmitThreadLocalTempBufferPointer(slice, target_shape);
- } else if (slice.allocation()->is_constant()) {
- return FindOrDie(constant_buffer_to_global_, slice.allocation()->index());
- } else {
- return EmitGlobalTempBufferPointer(slice, target_shape);
- }
+// Emits a function call returning a single array element. Allocates space
+// for a single element_type value, and loads it after call.
+llvm::Value* IrEmitter::EmitElementFunctionCall(
+ llvm::Function* function, const Shape& return_shape,
+ gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name) {
+ llvm::Value* return_value_buffer = EmitArrayFunctionCall(
+ function, return_shape, 1, parameter_addresses, name);
+ return b_.CreateLoad(
+ return_value_buffer,
+ AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
+}
+
+// Emits a core function call based on the following pseudo-code.
+//
+// char** parameter_addresses_buffer =
+// allocate buffer with a pointer for each parameter to the function
+// for each parameter index, i.e. for i = 0, ..., #parameters:
+// parameter_addresses_buffer[i] = parameter_addresses[i]
+// call function(return_value_buffer,
+// parameter_addresses_buffer,
+// temps)
+// return return_value_buffer -- address of the return value.
+void IrEmitter::EmitArrayFunctionCallInto(
+ llvm::Function* function, gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
+ b_.CreateCall(function,
+ GetArrayFunctionCallArguments(
+ parameter_addresses, &b_, name,
+ /*return_value_buffer=*/return_value_buffer,
+ /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
+ /*temp_buffers_arg=*/GetTempBuffersArgument(),
+ /*profile_counters_arg=*/GetProfileCountersArgument()));
+}
+
+llvm::Value* IrEmitter::EmitArrayFunctionCall(
+ llvm::Function* function, const Shape& return_shape, int64 element_count,
+ gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name) {
+ llvm::Value* elements =
+ llvm::ConstantInt::get(b_.getInt64Ty(), element_count);
+ PrimitiveType return_type = return_shape.element_type();
+ llvm::Value* return_value_buffer =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements,
+ tensorflow::strings::StrCat(name, "_return_value_address"), &b_,
+ MinimumAlignmentForPrimitiveType(return_type));
+ EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
+ name);
+ return return_value_buffer;
}
Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
+ llvm::Value* addr;
const Shape& target_shape = op->shape();
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- assignment_.GetUniqueTopLevelSlice(op));
- llvm::Value* addr = EmitTempBufferPointer(slice, target_shape);
+ if (op == op->parent()->root_instruction()) {
+ // For the root node, we write directly to the output buffer of the
+ // function.
+ llvm::Argument* retval = compute_function_->result_arg();
+ if ((ShapeUtil::IsArray(target_shape) &&
+ !ShapeUtil::IsZeroElementArray(target_shape)) ||
+ (ShapeUtil::IsTuple(target_shape) &&
+ !ShapeUtil::IsEmptyTuple(target_shape))) {
+ llvm::AttrBuilder attr_builder;
+ attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
+ attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
+ retval->addAttrs(attr_builder);
+ }
+ addr = b_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo());
+ } else {
+ // For other nodes, we need the temporary buffer allocated for this node to
+ // write the result into.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+ assignment_.GetUniqueTopLevelSlice(op));
+ addr = EmitTempBufferPointer(slice, target_shape);
+ }
addr->setName(AsStringRef(IrName(op)));
emitted_value_[op] = addr;
return Status::OK();
@@ -2837,69 +2936,20 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
}
-llvm::Value* IrEmitter::EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
- tensorflow::StringPiece name) {
- const Shape& return_shape = callee.root_instruction()->shape();
-
- // Lifting this restriction to allow "small" arrays should be easy. Allowing
- // larger arrays is difficult because we allocate the buffer for this return
- // value on the stack.
- CHECK(ShapeUtil::IsScalar(return_shape));
-
- PrimitiveType return_type = return_shape.element_type();
-
- std::vector<llvm::Value*> parameter_addrs;
- for (llvm::Value* parameter : parameters) {
- CHECK(!parameter->getType()->isPointerTy());
- llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- parameter->getType(), "arg_addr", &b_);
- b_.CreateStore(parameter, parameter_addr);
- parameter_addrs.push_back(parameter_addr);
+StatusOr<llvm::Value*> IrEmitter::EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name) {
+ llvm::Function* llvm_function = FindOrDie(emitted_functions_, computation);
+ std::vector<llvm::Value*> argument_addrs;
+ for (auto argument : arguments) {
+ llvm::Value* argument_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ argument->getType(), "arg_addr", &b_);
+ b_.CreateStore(argument, argument_addr);
+ argument_addrs.push_back(argument_addr);
}
-
- llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
- llvm_ir::PrimitiveTypeToIrType(return_type, module_),
- tensorflow::strings::StrCat(name, "_retval_addr"), &b_,
- MinimumAlignmentForPrimitiveType(return_type));
-
- b_.CreateCall(
- FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- parameter_addrs, &b_, name,
- /*return_value_buffer=*/return_value_buffer,
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
-
- return b_.CreateLoad(return_value_buffer);
+ return EmitElementFunctionCall(llvm_function,
+ ShapeUtil::MakeShape(return_type, {}),
+ argument_addrs, name);
}
-
-void IrEmitter::EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name) {
- b_.CreateCall(FindOrDie(emitted_functions_, &callee),
- GetArrayFunctionCallArguments(
- /*parameter_addresses=*/{}, &b_, name,
- /*return_value_buffer=*/
- llvm::Constant::getNullValue(b_.getInt8PtrTy()),
- /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(),
- /*temp_buffers_arg=*/GetTempBuffersArgument(),
- /*profile_counters_arg=*/GetProfileCountersArgument()));
-}
-
-llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue(
- const HloComputation& callee) {
- const HloInstruction* root_inst = callee.root_instruction();
- if (root_inst->opcode() == HloOpcode::kOutfeed) {
- return llvm::Constant::getNullValue(b_.getInt8PtrTy());
- }
-
- const BufferAllocation::Slice root_buffer =
- assignment_.GetUniqueTopLevelSlice(root_inst).ValueOrDie();
- return EmitTempBufferPointer(root_buffer, root_inst->shape());
-}
-
} // namespace cpu
} // namespace xla
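
The pseudo-code above EmitArrayFunctionCallInto describes what the emitted IR does at runtime. The same behavior, sketched at the C++ level (names hypothetical; the real code allocas the parameter buffer at function entry rather than using a vector):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Runtime shape of a subcomputation emitted by the CPU backend, per the
    // pseudo-code above (signature assumed from this diff).
    typedef void (*EmittedComputation)(void* retval, const void* run_options,
                                       void** params, void** temps,
                                       int64_t* prof_counters);

    void CallSubcomputation(EmittedComputation fn, void* retval,
                            const void* run_options, void** temps,
                            int64_t* prof_counters,
                            void** parameter_addresses, size_t n) {
      // parameter_addresses_buffer from the pseudo-code: one slot per
      // parameter, filled with the caller-provided addresses.
      std::vector<void*> buffer(parameter_addresses, parameter_addresses + n);
      fn(retval, run_options, buffer.data(), temps, prof_counters);
    }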
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 372017441f..03bbb2afb5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -100,15 +100,14 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm::IRBuilder<>* b() { return &b_; }
+ // Emits a call to `computation` with scalar arguments `arguments`.
+ StatusOr<llvm::Value*> EmitScalarCall(
+ PrimitiveType return_type, HloComputation* computation,
+ const std::vector<llvm::Value*>& arguments, tensorflow::StringPiece name);
+
// Emit an LLVM global variable for every constant buffer allocation.
Status EmitConstantGlobals();
- // Emit code to map one element according to `map_instr`.
- llvm::Value* EmitElementalMap(
- const HloMapInstruction& map_instr,
- tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands,
- tensorflow::StringPiece name);
-
protected:
//
// The following methods implement the DfsHloVisitor interface.
@@ -144,6 +143,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleRecvDone(HloInstruction* recv_done) override;
Status HandlePad(HloInstruction* pad) override;
Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleMap(HloInstruction* map) override;
Status HandleFusion(HloInstruction* fusion) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
@@ -218,18 +218,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// computation function being emitted by this emitter.
llvm::Value* GetTempBuffersArgument();
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitGlobalTempBufferPointer(const BufferAllocation::Slice& slice,
- const Shape& target_shape);
-
- // Helper for EmitTempBufferPointer.
- llvm::Value* EmitThreadLocalTempBufferPointer(
- const BufferAllocation::Slice& slice, const Shape& target_shape);
-
- // Emits code that computes the address of the given buffer allocation slice.
- //
- // TODO(sanjoy): This should be renamed to reflect that it no longer provides
- // access to just temporaries.
+ // Emits code that computes the address of the given temporary buffer within
+ // the current function. target_shape is the shape of this temporary buffer.
+ // The returned Value's type is a pointer to the buffer's element type.
llvm::Value* EmitTempBufferPointer(const BufferAllocation::Slice& slice,
const Shape& target_shape);
@@ -241,27 +232,44 @@ class IrEmitter : public DfsHloVisitorWithDefault {
tensorflow::StringPiece
function_name_suffix); // Used for LLVM IR register names.
- // Emits a call to a thread local function (e.g. to the computation nested
- // within a reduce or a map). Thread local callees (by definition) only write
- // to and read from thread local allocations.
- //
- // `parameters` holds the *scalar values* that need to be passed to the
- // callee. The return value is the scalar returned by the callee.
- llvm::Value* EmitThreadLocalCall(
- const HloComputation& callee,
- tensorflow::gtl::ArraySlice<llvm::Value*> parameters,
+ // Methods that emit a function call.
+ // Parameters:
+ // function - The LLVM function to call.
+ // return_shape - The return shape of the HLO computation that was used to
+ // make the function. Not the same as the return type of the function
+ // in LLVM, since we use output parameters for the return type.
+ // element_count - number of elements to return (array form only).
+ // parameter_addresses - pointers to be passed to the function as
+ // parameters.
+ // name - used for LLVM IR register names.
+
+ // Emits a function call, returning a scalar, often an element of a larger
+ // array. Returns a Value for the scalar element returned by the function.
+ llvm::Value* EmitElementFunctionCall(
+ llvm::Function* function, const Shape& return_shape,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
tensorflow::StringPiece name);
- // Emits a call to a "global" function (e.g. to the computation nested within
- // a kWhile or a kCall). Buffer assignment unambiguously assigns buffers to
- // the parameters and return values for these computations so there is no need
- // to explicitly pass parameters or return results.
- void EmitGlobalCall(const HloComputation& callee,
- tensorflow::StringPiece name);
-
- // Returns the buffer to which a global call to `callee` would have written
- // its result.
- llvm::Value* GetBufferForGlobalCallReturnValue(const HloComputation& callee);
+ // Array function call emitter. Stores the function's result into a supplied
+ // buffer.
+ // Parameters:
+ // function - The LLVM function to call.
+ // parameter_addresses - pointers to be passed to the function as
+ // parameters.
+ // return_value - pointer to a buffer where the call result is stored.
+
+ void EmitArrayFunctionCallInto(
+ llvm::Function* function,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value_buffer, tensorflow::StringPiece name);
+
+ // Array function call emitter. Returns a Value for the function's return
+ // value buffer address. The return value buffer is alloca'ed by this
+ // function.
+ llvm::Value* EmitArrayFunctionCall(
+ llvm::Function* function, const Shape& return_shape, int64 element_count,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name);
// Verifies that the element types of all of the given operand instructions
// match and are of one of the given supported types.
@@ -400,10 +408,11 @@ class IrEmitter : public DfsHloVisitorWithDefault {
NameUniquer name_uniquer_;
// Map containing all previously emitted computations.
- std::map<const HloComputation*, llvm::Function*> emitted_functions_;
+ std::map<HloComputation*, llvm::Function*> emitted_functions_;
// Map containing all previously emitted thread-local temporary buffers.
- std::map<std::pair<llvm::Function*, BufferAllocation::Slice>, llvm::Value*>
+ std::map<std::pair<llvm::Function*, BufferAllocation::Slice>,
+ llvm::AllocaInst*>
thread_local_buffers_;
// The following fields track the IR emission state. According to LLVM memory
@@ -413,16 +422,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
std::unique_ptr<IrFunction> compute_function_;
llvm::IRBuilder<> b_;
- // The buffer allocation slice for the root of the computation being compiled.
- // Only relevant for thread local computations.
- BufferAllocation::Slice computation_root_allocation_;
-
- // Maps the buffer allocation slices for the parameters to the computation
- // being compiled to their parameter numbers. Only relevant for thread local
- // computations.
- tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
- computation_parameter_allocations_;
-
// Maps HLO instructions to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, int64>
instruction_to_profile_idx_;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index 2db4d000f5..6aff838462 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -80,16 +80,9 @@ void IrFunction::Initialize(const string& function_name,
// void function(i8* retval, i8* run_options, i8** params, i8** temps,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
- // For thread local functions:
- // retval: points to the returned value.
- // params: address of an array with pointers to parameters.
- // temps: is null
- //
- // For global functions:
- // retval: is null
- // params: is null
- // temps: address of an array with pointers to temporary buffers and entry
- // computation parameters.
+ // retval: points to the returned value.
+ // params: address of an array with pointers to parameters.
+ // temps: address of an array with pointers to temporary buffers.
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
@@ -203,25 +196,18 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments(
llvm::IRBuilder<>* b, tensorflow::StringPiece name,
llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg,
llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) {
- llvm::Value* parameter_addresses_buffer;
-
- if (parameter_addresses.empty()) {
- parameter_addresses_buffer =
- llvm::Constant::getNullValue(b->getInt8PtrTy()->getPointerTo());
- } else {
- parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
- b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
- tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
-
- for (size_t i = 0; i < parameter_addresses.size(); ++i) {
- llvm::Value* parameter_as_i8ptr =
- b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
- AsStringRef(tensorflow::strings::StrCat(
- name, "_parameter_", i, "_address_as_i8ptr")));
- llvm::Value* slot_in_param_addresses =
- b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
- b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
- }
+ llvm::Value* parameter_addresses_buffer =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()),
+ tensorflow::strings::StrCat(name, "_parameter_addresses"), b);
+ for (size_t i = 0; i < parameter_addresses.size(); ++i) {
+ llvm::Value* parameter_as_i8ptr =
+ b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(),
+ AsStringRef(tensorflow::strings::StrCat(
+ name, "_parameter_", i, "_address_as_i8ptr")));
+ llvm::Value* slot_in_param_addresses =
+ b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)});
+ b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
}
const auto to_int8_ptr = [=](llvm::Value* ptr) {
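
The restored comment above pins down a single calling convention for every generated function. A minimal, self-contained sketch of what a caller looks like at the C++ level; only the six-argument shape comes from the comment, while DoubleF32 and its behavior are hypothetical:

#include <cstdint>

// Mirrors: void function(i8* retval, i8* run_options, i8** params,
//                        i8** temps, i64* dynamic_loop_bounds,
//                        i64* prof_counters)
using ComputeFn = void (*)(void* retval, const void* run_options,
                           void** params, void** temps,
                           int64_t* dynamic_loop_bounds,
                           int64_t* prof_counters);

// Hypothetical compiled function: doubles its single f32 parameter.
void DoubleF32(void* retval, const void* /*run_options*/, void** params,
               void** /*temps*/, int64_t* /*dynamic_loop_bounds*/,
               int64_t* /*prof_counters*/) {
  *static_cast<float*>(retval) = 2.0f * *static_cast<const float*>(params[0]);
}

int main() {
  float in = 21.0f, out = 0.0f;
  void* params[] = {&in};
  ComputeFn fn = &DoubleF32;
  fn(&out, /*run_options=*/nullptr, params, /*temps=*/nullptr,
     /*dynamic_loop_bounds=*/nullptr, /*prof_counters=*/nullptr);
  return out == 42.0f ? 0 : 1;  // 2 * 21
}
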
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
index 52b6c8eb80..d03da46575 100644
--- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
+++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc
@@ -65,7 +65,6 @@ void __xla_cpu_runtime_ParallelForkJoin(
VLOG(2) << "ParallelForkJoin ENTRY"
<< " num_partitions: " << num_partitions
<< " num_partitioned_dims: " << num_partitioned_dims;
- CHECK_EQ(params, nullptr);
CHECK_GT(num_partitions, 1);
CHECK_GT(num_partitioned_dims, 0);
const xla::ExecutableRunOptions* run_options =
@@ -80,9 +79,9 @@ void __xla_cpu_runtime_ParallelForkJoin(
for (int32 i = 1; i < num_partitions; ++i) {
const int64 offset = i * stride;
run_options->intra_op_thread_pool()->enqueueNoNotification(
- [i, function, result_ptr, run_options_ptr, temps, prof_counters,
+ [i, function, result_ptr, run_options_ptr, params, temps, prof_counters,
partitions, offset, &bc]() {
- function(result_ptr, run_options_ptr, nullptr, temps,
+ function(result_ptr, run_options_ptr, params, temps,
&partitions[offset], prof_counters);
bc.DecrementCount();
VLOG(3) << "ParallelForkJoin partition " << i << " done.";
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
index fe5ec1cc66..941d940684 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc
@@ -56,12 +56,12 @@ ENTRY while3 {
)";
CompileAndVerifyIr(hlo_string, R"(
-; CHECK-LABEL: @body(i8* %retval
+; CHECK-LABEL: @body(i8* align 4 dereferenceable(4) %retval
; CHECK: %[[add_result:.*]] = fadd fast float %[[fadd_lhs:.*]], %[[fadd_rhs:.*]]
; CHECK: store float %[[add_result]], float* %[[store_dest:.*]], !alias.scope ![[alias_scope_md_for_store:[0-9]+]]
;
-; CHECK-LABEL: @condition(i8* %retval, i8* noalias %run_options, i8** noalias %params
-; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %temps, i64 0
+; CHECK-LABEL: @condition(i8* align 1 dereferenceable(1) %fusion, i8* noalias %run_options, i8** noalias %params
+; CHECK: %[[cond_state_buf_ptr:.*]] = getelementptr inbounds i8*, i8** %params, i64 0
; CHECK: %[[cond_state_buf_untyped:.*]] = load i8*, i8** %[[cond_state_buf_ptr]]
; CHECK: %[[cond_state_buf_typed:.*]] = bitcast i8* %[[cond_state_buf_untyped]] to float*
; CHECK: load float, float* %[[cond_state_buf_typed]], !alias.scope ![[alias_scope_md_for_store]], !noalias ![[noalias_md_for_load:.*]]
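
The updated CHECK lines expect `align 4 dereferenceable(4)` on the retval argument, i.e. the emitter now asserts the pointer's alignment and the number of bytes known to be dereferenceable. A sketch of how such argument attributes are attached with the LLVM C++ API of this era (AnnotateRetvalArg is an illustrative helper, not XLA code):

#include <cstdint>

#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"

// Attaches `align <alignment> dereferenceable(<bytes>)` to a function's
// first argument, as seen on %retval in the CHECK lines above.
void AnnotateRetvalArg(llvm::Function* function, uint64_t bytes,
                       unsigned alignment) {
  llvm::AttrBuilder attrs;
  attrs.addAlignmentAttr(alignment);
  attrs.addDereferenceableAttr(bytes);
  function->arg_begin()->addAttrs(attrs);
}
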
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
index 115448c908..47cab79604 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -42,12 +42,13 @@ extern "C" void SumStructElements(float* out, void** parameters) {
TEST_F(LocalClientAotTest, Constant) {
xla::ExecutableRunOptions run_options;
OpaqueData opaque_data{100, 20, 3};
+ void* parameters[] = {&opaque_data};
float out = 0;
- void* temporary_buffers[] = {&opaque_data, &out};
- SumAndDouble(&out, &run_options, nullptr, temporary_buffers);
+ void* temporary_buffers[] = {nullptr, &out};
+ SumAndDouble(&out, &run_options, parameters, temporary_buffers);
EXPECT_EQ(out, 246.0f);
opaque_data = {1, 2, 3};
- SumAndDouble(&out, &run_options, nullptr, temporary_buffers);
+ SumAndDouble(&out, &run_options, parameters, temporary_buffers);
EXPECT_EQ(out, 12.0f);
}
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index e310966d8b..74494e60e8 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -93,7 +93,7 @@ int main(int argc, char** argv) {
// local_client_aot_test.cc to be able to easily invoke the function.
CHECK_EQ(result->result_buffer_index(), 1);
CHECK_EQ(result->buffer_sizes().size(), 3);
- CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer
+ CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer
CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer
if (triple.isOSBinFormatELF()) {
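
After the rollback, a size of -1 in buffer_sizes() marks both the parameter and the constant buffer, as the two CHECKs above show; only non-negative entries are real temp allocations. A self-contained sketch of how a caller could build the temps array from such a table (AllocateTemps is illustrative, not the tfcompile runtime):

#include <cstdlib>
#include <vector>

// Builds a temps array from a buffer-size table in which -1 marks slots the
// caller is responsible for (parameters, constants) and non-negative entries
// are allocated here.
std::vector<void*> AllocateTemps(const std::vector<long long>& sizes) {
  std::vector<void*> bufs(sizes.size());
  for (size_t i = 0; i < sizes.size(); ++i) {
    bufs[i] = (sizes[i] == -1) ? nullptr : std::malloc(sizes[i]);
  }
  return bufs;
}

int main() {
  // Mirrors the layout checked above: {param, result, const}.
  std::vector<void*> temps = AllocateTemps({-1, sizeof(float), -1});
  bool ok = temps[0] == nullptr && temps[1] != nullptr && temps[2] == nullptr;
  std::free(temps[1]);
  return ok ? 0 : 1;
}
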
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1bdf1867b9..c81c27891c 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -1236,35 +1236,6 @@ TEST_F(WhileTest, WhileWithLoopInvariantOperation) {
{param_value.get()}, ErrorSpec(4e-5));
}
-TEST_F(WhileTest, DISABLED_ON_INTERPRETER(WhileInfeedCondition)) {
- auto while_shape = ShapeUtil::MakeShape(S32, {});
-
- XlaComputation condition;
- {
- XlaBuilder builder("condition");
- Parameter(&builder, 0, while_shape, "state");
- Infeed(&builder, ShapeUtil::MakeShape(PRED, {}));
- TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
- }
-
- XlaComputation body;
- {
- XlaBuilder builder("body");
- auto indvar = Parameter(&builder, 0, while_shape, "state");
- Add(indvar, ConstantR0<int32>(&builder, 1));
- TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
- }
-
- XlaBuilder builder(TestName());
- While(condition, body, ConstantR0<int32>(&builder, 0));
-
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(true)));
- TF_ASSERT_OK(client_->TransferToInfeed(*LiteralUtil::CreateR0<bool>(false)));
-
- ComputeAndCompareR0<int32>(&builder, 2, {});
-}
-
void BM_WhileLoop(int num_iters) {
// Benchmark a simple kernel to measure while loop overheads.
tensorflow::testing::StopTiming();