aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/cpu_executable.cc')
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc51
1 files changed, 25 insertions, 26 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index e6ef9d6314..e956f478b8 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -55,11 +55,12 @@ CpuExecutable::CpuExecutable(
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
const string& entry_function_name,
- std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx)
- : Executable(std::move(hlo_module)),
+ std::unique_ptr<HloProfilePrinter> hlo_profile_printer,
+ std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
+ : Executable(std::move(hlo_module), std::move(hlo_profile_printer),
+ std::move(hlo_profile_index_map)),
jit_(std::move(jit)),
- assignment_(std::move(assignment)),
- hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) {
+ assignment_(std::move(assignment)) {
// Resolve symbols in the constructor rather than at execution time to avoid
// races because FindSymbol is not thread safe.
llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name);
@@ -183,9 +184,16 @@ Status CpuExecutable::ExecuteComputeFunction(
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
// Allocate profiling counters for each hlo instruction that we would like to
- // profile. Allocate an additional profile counter for the entire
- // computation.
- std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
+ // profile. Even when not Hlo profiling, we allocate a counter for the entire
+ // computation, which we use to update ExecutionProfile below.
+ std::vector<int64>* profile_counters = nullptr;
+ std::vector<int64> profile_counter_for_entry_computation;
+ if (hlo_execution_profile) {
+ profile_counters = hlo_execution_profile->mutable_profile_counters();
+ } else {
+ profile_counters = &profile_counter_for_entry_computation;
+ profile_counter_for_entry_computation.push_back(0);
+ }
// Call the computation function following the calling convention.
std::vector<void*> buffer_pointers;
@@ -200,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction(
VLOG(3) << tensorflow::strings::Printf(
" func(void* result, void* params[%zu], void* temps[%zu], "
"uint64 profile_counters[%zu])",
- args_array.size(), buffer_pointers.size(), profile_counters.size());
+ args_array.size(), buffer_pointers.size(), profile_counters->size());
VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
auto ptr_printer = [](string* out, const void* p) {
tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
@@ -212,11 +220,11 @@ Status CpuExecutable::ExecuteComputeFunction(
" temps = [%s]",
tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p",
- profile_counters.data());
+ profile_counters->data());
}
compute_function_(result_buffer, run_options, args_array.data(),
- buffer_pointers.data(), profile_counters.data());
+ buffer_pointers.data(), profile_counters->data());
uint64 end_micros = tensorflow::Env::Default()->NowMicros();
@@ -225,20 +233,15 @@ Status CpuExecutable::ExecuteComputeFunction(
const double nanoseconds = (end_micros - start_micros) * 1000.0;
execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
- // The last profile counter is used for the computation as a whole.
- execution_profile_.set_compute_cycle_count(profile_counters.back());
- }
-
- if (hlo_execution_profile != nullptr) {
- hlo_execution_profile->set_total_cycles_executed(
- *module().entry_computation(), profile_counters.back());
-
- for (auto hlo_prof_idx : hlo_to_profile_idx_) {
- const HloInstruction* hlo = hlo_prof_idx.first;
- uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
- hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken);
+ if (hlo_execution_profile) {
+ execution_profile_.set_compute_cycle_count(
+ hlo_execution_profile->total_cycles_executed(
+ *module().entry_computation()));
+ } else {
+ execution_profile_.set_compute_cycle_count(profile_counters->back());
}
}
+
return Status::OK();
}
@@ -428,9 +431,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const {
module().entry_computation()->root_instruction());
}
-std::unique_ptr<HloCostAnalysis> CpuExecutable::CreateCostAnalysis() const {
- return MakeUnique<HloCostAnalysis>(ShapeSizeBytes);
-}
-
} // namespace cpu
} // namespace xla