about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-11-03 09:58:09 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-11-03 10:02:23 -0700
commit: 96c415ad77c20e1cf2da5e61f85e24fd6c36eb28 (patch)
tree: 8f32617e699ca2e99d0b2ca071e6cb9eb34fe12e
parent: d8935f6414e36c6e1da95dbd13c876b7208c019b (diff)
[XLA] Use maps with a deterministic iteration order for HloInstruction*.
Convert a bunch of std::maps with HloInstruction* and const HloInstruction* keys to use a comparator that is based on the unique_id of the instruction rather than the pointer value.

PiperOrigin-RevId: 174474868
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 4
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc | 14
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h | 5
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 19
5 files changed, 31 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index d9b1738c3c..af2bd6d5d7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -610,8 +610,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
&hlo_to_profile_idx, jit->target_machine(),
jit->external_constant_pool());
- std::unique_ptr<std::map<HloInstruction*, string>> function_names(
- new std::map<HloInstruction*, string>());
+ std::unique_ptr<HloInstructionMap<string>> function_names(
+ new HloInstructionMap<string>());
for (auto embedded_computation :
computation->MakeEmbeddedComputationsList()) {
if (embedded_computation->IsFusionComputation()) {
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
index 8c443b1409..aff61296ce 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
@@ -58,7 +58,7 @@ ParallelCpuExecutable::ParallelCpuExecutable(
std::unique_ptr<SimpleOrcJIT> jit,
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
- std::unique_ptr<const std::map<HloInstruction*, string>> function_names,
+ std::unique_ptr<const HloInstructionMap<string>> function_names,
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
aligned_constants)
@@ -102,10 +102,10 @@ namespace {
// in 'pending' on 'thread_pool' (storing resulting data in 'results').
class Executor {
public:
- Executor(const std::map<HloInstruction*, ComputeFunctionType>& functions,
+ Executor(const HloInstructionMap<ComputeFunctionType>& functions,
const ServiceExecutableRunOptions* run_options,
std::list<HloInstruction*>* pending,
- std::map<HloInstruction*, const void*>* results, void** temps_array,
+ HloInstructionMap<const void*>* results, void** temps_array,
uint64* profile_counters_array, const BufferAssignment* assignment)
: functions_(functions),
run_options_(run_options),
@@ -142,10 +142,10 @@ class Executor {
const void** GetOperandBuffers(HloInstruction* instruction);
// Arguments passed into Executor.
- const std::map<HloInstruction*, ComputeFunctionType>& functions_;
+ const HloInstructionMap<ComputeFunctionType>& functions_;
const ServiceExecutableRunOptions* run_options_;
std::list<HloInstruction*>* pending_;
- std::map<HloInstruction*, const void*>* results_;
+ HloInstructionMap<const void*>* results_;
void** temps_array_;
uint64* profile_counters_array_;
tensorflow::thread::ThreadPool* thread_pool_;
@@ -400,7 +400,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
}
// Resolve functions for all the HLO instructions ahead of time.
- std::map<HloInstruction*, ComputeFunctionType> functions;
+ HloInstructionMap<ComputeFunctionType> functions;
for (auto& entry : *function_names_) {
tensorflow::mutex_lock lock(jit_mutex_);
HloInstruction* instruction = entry.first;
@@ -412,7 +412,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions(
}
// Map containing pointers to result buffers for each instruction.
- std::map<HloInstruction*, const void*> results;
+ HloInstructionMap<const void*> results;
uint64 start_micros = tensorflow::Env::Default()->NowMicros();
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
index a75552b7d1..db16aaf48b 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
@@ -51,7 +51,7 @@ class ParallelCpuExecutable : public Executable {
std::unique_ptr<SimpleOrcJIT> jit,
std::unique_ptr<const BufferAssignment> assignment,
std::unique_ptr<const HloModule> hlo_module,
- std::unique_ptr<const std::map<HloInstruction*, string>> function_names,
+ std::unique_ptr<const HloInstructionMap<string>> function_names,
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
std::unordered_map<const HloInstruction*,
std::unique_ptr<unsigned char[]>>
@@ -141,8 +141,7 @@ class ParallelCpuExecutable : public Executable {
string ir_module_string_;
// Map containing the JITted function names for each HLO instruction.
- const std::unique_ptr<const std::map<HloInstruction*, string>>
- function_names_;
+ const std::unique_ptr<const HloInstructionMap<string>> function_names_;
// Maps HLOs to their index into the profile counter array.
const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d82462112e..2c7e735a1c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1233,7 +1233,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
new_instruction->AppendOperand(new_operand);
}
// Clone all the fused instructions for the new fusion instruction.
- std::map<HloInstruction*, HloInstruction*> old_to_new;
+ HloInstructionMap<HloInstruction*> old_to_new;
std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
// Create the list of fused parameters by mapping through the cloned,
// fused instructions.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 524cfe3f26..411f926a87 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1231,6 +1231,25 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
+// Map classes that guarantee a deterministic iteration order when the key is
+// an HloInstruction* or a const HloInstruction*.
+// To make the iteration order over the map deterministic, the comparator
+// should not be using the pointer values, but rather an intrinsic property of
+// the hlo.
+struct HloPtrComparator {
+ bool operator()(const HloInstruction* const& lhs,
+ const HloInstruction* const& rhs) const {
+ return lhs->unique_id() < rhs->unique_id();
+ }
+};
+
+template <typename ValueT>
+using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
+
+template <typename ValueT>
+using ConstHloInstructionMap =
+ std::map<const HloInstruction*, ValueT, HloPtrComparator>;
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_