aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/cpu_executable.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/cpu_executable.h')
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h27
1 files changed, 18 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 8dd47bfb86..8af8a5dfec 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -85,20 +85,29 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
- // Allocate buffers required for execution and assign them to the elements of
- // "buffers". "buffers" should be sized to the number of buffers in buffer
- // assignment. Each vector element corresponds to a particular Index. If
- // a vector element already contains a non-null DeviceMemoryBase, then no
- // buffer is assigned for this element.
- Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
- int device_ordinal,
- std::vector<OwningDeviceMemory>* buffers);
+ // Creates an array suitable for passing as the "temps" argument to the JIT
+ // compiled function pointer.
+ //
+ // Returns (unowning_buffers, owning_buffers) where:
+ //
+ // - unowning_buffers.data() can be passed as the temps argument as-is and
+ // includes pointers to the scratch storage required by the computation,
+ // the live-out buffer into which the result will be written and entry
+ // computation parameters.
+ //
+ // - owning_buffers contains owning pointers to the buffers that were
+ // allocated by this routine. This routine allocates buffers for temporary
+ // storage and the live-out buffer into which the computation writes it
+ // result.
+ StatusOr<std::pair<std::vector<se::DeviceMemoryBase>,
+ std::vector<OwningDeviceMemory>>>
+ CreateTempArray(DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
// Calls the generated function performing the computation with the given
// arguments using the supplied buffers.
Status ExecuteComputeFunction(
const ExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);