aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 00ccfb1c78..114a9241bd 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -58,11 +58,15 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
std::vector<intptr_t> temp_sizes;
temp_sizes.reserve(allocations.size());
for (const xla::BufferAllocation& allocation : allocations) {
- // Callers don't allocate temporary buffers for parameters. Nor for
- // thread-local buffers, which are lowered to alloca.
- if (allocation.is_entry_computation_parameter() ||
- allocation.is_thread_local()) {
+ if (allocation.is_constant() || allocation.is_thread_local()) {
+ // Constants are lowered to globals. Thread locals are lowered to
+ // allocas.
temp_sizes.push_back(-1);
+ } else if (allocation.is_entry_computation_parameter()) {
+ // Entry computation parameters need some preprocessing in
+ // XlaCompiledCpuFunction::Run. See the comment on
+ // XlaCompiledCpuFunction::StaticData::temp_sizes.
+ temp_sizes.push_back(-allocation.parameter_number() - 2);
} else {
temp_sizes.push_back(allocation.size());
}