aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-31 16:28:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 16:41:20 -0700
commit97f4d4016fc4a1eac5313bc9753aaeae118734a4 (patch)
tree98ff08921e18f4f85280aeb609fe2ef07d224006 /tensorflow/compiler/xla/service/cpu
parentb1341d049a74b7054450a9641b1c186f23df501f (diff)
CHECK that the thread locality of the call matches thread locality of the callee
PiperOrigin-RevId: 211162384
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu')
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc9
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h3
2 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 8eaca57680..43f2a034ff 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -100,6 +100,11 @@ IrEmitter::IrEmitter(
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
.xla_cpu_enable_fast_math()));
+ Status s = GatherComputationsByAllocationType(
+ &hlo_module, &thread_local_computations_, &global_computations_);
+ absl::c_sort(thread_local_computations_);
+ absl::c_sort(global_computations_);
+ TF_CHECK_OK(s) << "Should have failed buffer assignment.";
}
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
@@ -2832,6 +2837,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) {
llvm::Value* IrEmitter::EmitThreadLocalCall(
const HloComputation& callee, absl::Span<llvm::Value* const> parameters,
absl::string_view name) {
+ CHECK(absl::c_binary_search(thread_local_computations_, &callee));
+
const Shape& return_shape = callee.root_instruction()->shape();
// Lifting this restriction to allow "small" arrays should be easy. Allowing
@@ -2869,6 +2876,8 @@ llvm::Value* IrEmitter::EmitThreadLocalCall(
void IrEmitter::EmitGlobalCall(const HloComputation& callee,
absl::string_view name) {
+ CHECK(absl::c_binary_search(global_computations_, &callee));
+
Call(FindOrDie(emitted_functions_, &callee),
GetArrayFunctionCallArguments(
/*parameter_addresses=*/{}, &b_, name,
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 9cb8162327..6efd7fd001 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -571,6 +571,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
constant_buffer_to_global_;
+ std::vector<const HloComputation*> thread_local_computations_;
+ std::vector<const HloComputation*> global_computations_;
+
TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
};