diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-08-31 16:28:11 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 16:41:20 -0700 |
commit | 97f4d4016fc4a1eac5313bc9753aaeae118734a4 (patch) | |
tree | 98ff08921e18f4f85280aeb609fe2ef07d224006 /tensorflow/compiler/xla/service/cpu | |
parent | b1341d049a74b7054450a9641b1c186f23df501f (diff) |
CHECK that the thread locality of the call matches the thread locality of the callee
PiperOrigin-RevId: 211162384
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/ir_emitter.h | 3 |
2 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 8eaca57680..43f2a034ff 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -100,6 +100,11 @@ IrEmitter::IrEmitter( b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_cpu_enable_fast_math())); + Status s = GatherComputationsByAllocationType( + &hlo_module, &thread_local_computations_, &global_computations_); + absl::c_sort(thread_local_computations_); + absl::c_sort(global_computations_); + TF_CHECK_OK(s) << "Should have failed buffer assignment."; } StatusOr<llvm::Function*> IrEmitter::EmitComputation( @@ -2832,6 +2837,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { llvm::Value* IrEmitter::EmitThreadLocalCall( const HloComputation& callee, absl::Span<llvm::Value* const> parameters, absl::string_view name) { + CHECK(absl::c_binary_search(thread_local_computations_, &callee)); + const Shape& return_shape = callee.root_instruction()->shape(); // Lifting this restriction to allow "small" arrays should be easy. 
Allowing @@ -2869,6 +2876,8 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( void IrEmitter::EmitGlobalCall(const HloComputation& callee, absl::string_view name) { + CHECK(absl::c_binary_search(global_computations_, &callee)); + Call(FindOrDie(emitted_functions_, &callee), GetArrayFunctionCallArguments( /*parameter_addresses=*/{}, &b_, name, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 9cb8162327..6efd7fd001 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -571,6 +571,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*> constant_buffer_to_global_; + std::vector<const HloComputation*> thread_local_computations_; + std::vector<const HloComputation*> global_computations_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; |