diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h | 29 |
1 files changed, 23 insertions, 6 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 48a8c083ca..27cfb354bf 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -60,9 +60,19 @@ class XlaCompiledCpuFunction { // The raw function to call. RawFunction raw_function; - // Cardinality and sizes of arg and temp buffers. + // Cardinality and size of arg buffers. const intptr_t* arg_sizes = nullptr; size_t num_args = 0; + + // Cardinality and size of temp buffers. + // + // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer. + // + // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The + // corresponding entry in the temp buffer array needs to be set to null. + // + // If temp_sizes[i] < -1 then the i'th temp is the entry parameter + // -(temp_sizes[i] + 2). const intptr_t* temp_sizes = nullptr; size_t num_temps = 0; @@ -113,11 +123,7 @@ class XlaCompiledCpuFunction { // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run() { - raw_function_(temps_[result_index_], &run_options_, - const_cast<const void**>(args_), temps_, profile_counters_); - return true; - } + bool Run(); // Returns the error message from the previous failed Run call. // @@ -224,6 +230,17 @@ class XlaCompiledCpuFunction { void** args_ = nullptr; void** temps_ = nullptr; + // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for + // XLA generated code to be able to find it. + // + // For now we need to keep around the args_ array because there is code that + // depends on args() returning a void**. However, in the future we may remove + // args_ in favor of using temps_ as the sole storage for the arguments. + int32* arg_index_to_temp_index_; + + // The number of incoming arguments. + int32 num_args_; + // Backing memory for individual arg and temp buffers. void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; |