diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 00ccfb1c78..114a9241bd 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -58,11 +58,15 @@ xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes( std::vector<intptr_t> temp_sizes; temp_sizes.reserve(allocations.size()); for (const xla::BufferAllocation& allocation : allocations) { - // Callers don't allocate temporary buffers for parameters. Nor for - // thread-local buffers, which are lowered to alloca. - if (allocation.is_entry_computation_parameter() || - allocation.is_thread_local()) { + if (allocation.is_constant() || allocation.is_thread_local()) { + // Constants are lowered to globals. Thread locals are lowered to + // allocas. temp_sizes.push_back(-1); + } else if (allocation.is_entry_computation_parameter()) { + // Entry computation parameters need some preprocessing in + // XlaCompiledCpuFunction::Run. See the comment on + // XlaCompiledCpuFunction::StaticData::temp_sizes. + temp_sizes.push_back(-allocation.parameter_number() - 2); } else { temp_sizes.push_back(allocation.size()); } |