Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/gpu_executable.cc')
 tensorflow/compiler/xla/service/gpu/gpu_executable.cc | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 9767836cd6..52c8595ffb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -181,6 +181,51 @@ Status GpuExecutable::ExecuteThunks(
return Status::OK();
}
+StatusOr<const GpuExecutable::BufferAllocToDeviceMemoryMap*>
+GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
+ tensorflow::mutex_lock lock(module_handle_mutex_);
+ auto it = module_globals_.find(executor);
+ if (it != module_globals_.end()) {
+ return &it->second;
+ }
+
+ se::MultiModuleLoaderSpec module_spec;
+ module_spec.AddCudaCubinInMemory(cubin());
+ module_spec.AddCudaPtxInMemory(ptx().c_str());
+
+ tensorflow::gtl::FlatMap<int64, se::DeviceMemoryBase> globals;
+ se::ModuleHandle module_handle;
+ executor->LoadModule(module_spec, &module_handle);
+
+ for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
+ ++i) {
+ const BufferAllocation& allocation = assignment_->GetAllocation(i);
+ if (allocation.is_constant()) {
+ TF_ASSIGN_OR_RETURN(
+ se::DeviceMemoryBase global,
+ executor->GetUntypedSymbol(
+ ConstantBufferAllocationToGlobalName(allocation), module_handle));
+ VLOG(3) << "Resolved global "
+ << ConstantBufferAllocationToGlobalName(allocation) << " to "
+ << global.opaque();
+ InsertOrDie(&globals, i, global);
+
+ const Literal& literal = LiteralForConstantAllocation(allocation);
+ CHECK(ShapeUtil::IsArray(literal.shape()));
+ if (!ShouldEmitLiteralInLlvmIr(literal)) {
+ VLOG(3) << "H2D memcpy for constant with shape "
+ << ShapeUtil::HumanString(literal.shape());
+ TF_RETURN_IF_ERROR(executor->SynchronousMemcpyH2D(
+ literal.untyped_data(), allocation.size(), &global));
+ }
+ }
+ }
+
+ module_handles_.emplace(executor,
+ se::ScopedModuleHandle(executor, module_handle));
+ return &module_globals_.emplace(executor, std::move(globals)).first->second;
+}
+
StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
@@ -192,6 +237,10 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
}
BufferAllocations::Builder buffer_allocations_builder;
+ se::StreamExecutor* executor = run_options->stream()->parent();
+
+ TF_ASSIGN_OR_RETURN(auto* const globals, ResolveConstantGlobals(executor));
+
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
++i) {
const BufferAllocation& allocation = assignment_->GetAllocation(i);
@@ -213,8 +262,12 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
buffer_allocations_builder.RegisterBuffer(i, buffer);
}
+
+ if (allocation.is_constant()) {
+ buffer_allocations_builder.RegisterBuffer(i, FindOrDie(*globals, i));
+ }
}
- se::StreamExecutor* executor = run_options->stream()->parent();
+
TF_ASSIGN_OR_RETURN(
auto buffer_allocations,
buffer_allocations_builder.Build(