diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/cpu_executable.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_executable.cc | 51 |
1 files changed, 25 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index e6ef9d6314..e956f478b8 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -55,11 +55,12 @@ CpuExecutable::CpuExecutable( std::unique_ptr<const BufferAssignment> assignment, std::unique_ptr<const HloModule> hlo_module, const string& entry_function_name, - std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx) - : Executable(std::move(hlo_module)), + std::unique_ptr<HloProfilePrinter> hlo_profile_printer, + std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), jit_(std::move(jit)), - assignment_(std::move(assignment)), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) { + assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); @@ -183,9 +184,16 @@ Status CpuExecutable::ExecuteComputeFunction( uint64 start_micros = tensorflow::Env::Default()->NowMicros(); // Allocate profiling counters for each hlo instruction that we would like to - // profile. Allocate an additional profile counter for the entire - // computation. - std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1); + // profile. Even when not Hlo profiling, we allocate a counter for the entire + // computation, which we use to update ExecutionProfile below. + std::vector<int64>* profile_counters = nullptr; + std::vector<int64> profile_counter_for_entry_computation; + if (hlo_execution_profile) { + profile_counters = hlo_execution_profile->mutable_profile_counters(); + } else { + profile_counters = &profile_counter_for_entry_computation; + profile_counter_for_entry_computation.push_back(0); + } // Call the computation function following the calling convention. std::vector<void*> buffer_pointers; @@ -200,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction( VLOG(3) << tensorflow::strings::Printf( " func(void* result, void* params[%zu], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters.size()); + args_array.size(), buffer_pointers.size(), profile_counters->size()); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); @@ -212,11 +220,11 @@ Status CpuExecutable::ExecuteComputeFunction( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters.data()); + profile_counters->data()); } compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters.data()); + buffer_pointers.data(), profile_counters->data()); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -225,20 +233,15 @@ Status CpuExecutable::ExecuteComputeFunction( const double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - // The last profile counter is used for the computation as a whole. - execution_profile_.set_compute_cycle_count(profile_counters.back()); - } - - if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed( - *module().entry_computation(), profile_counters.back()); - - for (auto hlo_prof_idx : hlo_to_profile_idx_) { - const HloInstruction* hlo = hlo_prof_idx.first; - uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken); + if (hlo_execution_profile) { + execution_profile_.set_compute_cycle_count( + hlo_execution_profile->total_cycles_executed( + *module().entry_computation())); + } else { + execution_profile_.set_compute_cycle_count(profile_counters->back()); } } + return Status::OK(); } @@ -428,9 +431,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr<HloCostAnalysis> CpuExecutable::CreateCostAnalysis() const { - return MakeUnique<HloCostAnalysis>(ShapeSizeBytes); -} - } // namespace cpu } // namespace xla |