aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Sanjoy Das <sanjoy@google.com> 2018-07-24 08:29:15 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-24 08:33:21 -0700
commit 1b33df1814e35015953c7cba392ba2a7387ce875 (patch)
tree 5f710a16aac0dca38a0b1c5d5a5be5ff0f4f5f69
parent 226831aab92a395a26824a08caa9d43f0c3d604e (diff)
[XLA:GPU] Don't lie about buffer alignment to LLVM
PiperOrigin-RevId: 205832336
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD 1
-rw-r--r-- tensorflow/compiler/xla/service/gpu/buffer_allocations.cc 8
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_constants.cc 13
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_constants.h 9
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/tests/BUILD 19
-rw-r--r-- tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc 54
8 files changed, 99 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 06ff3d9bba..72aff197fc 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -36,6 +36,7 @@ cc_library(
hdrs = ["gpu_constants.h"],
deps = [
"//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
index ab5149dcdb..b095d4cd73 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -49,12 +49,12 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
if (registered_buffers_.count(i)) {
se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i);
if (reinterpret_cast<uintptr_t>(address.opaque()) %
- kCudaMallocAlignBytes !=
+ kEntryParameterAlignBytes !=
0) {
return InternalError(
"Address of registered buffer %lld must be a multiple of %llx, but "
"was %p",
- i, kCudaMallocAlignBytes, address.opaque());
+ i, kEntryParameterAlignBytes, address.opaque());
}
buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i));
continue;
@@ -71,12 +71,12 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
TF_ASSIGN_OR_RETURN(
buffer, memory_allocator->Allocate(device_ordinal, buffer_size));
if (reinterpret_cast<uintptr_t>(buffer.opaque()) %
- kCudaMallocAlignBytes !=
+ kXlaAllocatedBufferAlignBytes !=
0) {
return InternalError(
"Address returned by memory_allocator->Allocate must be a "
"multiple of %llx, but was %p",
- kCudaMallocAlignBytes, buffer.opaque());
+ kXlaAllocatedBufferAlignBytes, buffer.opaque());
}
// We do manual memory management within BufferAllocations. Be sure not
// to do a TF_RETURN_IF_ERROR between this line and the
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
index aa360c7f73..e6ddea6d25 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
@@ -14,12 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
+#include "tensorflow/core/framework/allocator.h"
namespace xla {
namespace gpu {
-// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses
-const int64 kCudaMallocAlignBytes = 256;
+// kEntryParameterAlignBytes is equal to EIGEN_MAX_ALIGN_BYTES, though including
+// Eigen headers here to get that symbol may not be a good idea.
+// EIGEN_MAX_ALIGN_BYTES may differ between CUDA-enabled builds vs CUDA-disabled
+// builds and we don't want the IR generated by XLA:GPU to depend on that.
+//
+// TODO(b/111767313): Consider raising EIGEN_MAX_ALIGN_BYTES if it helps.
+const int64 kEntryParameterAlignBytes = 16;
+
+const int64 kXlaAllocatedBufferAlignBytes =
+ tensorflow::Allocator::kAllocatorAlignment;
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h
index eb1ca4c6c9..925e6927b6 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_constants.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h
@@ -21,9 +21,12 @@ limitations under the License.
namespace xla {
namespace gpu {
-// Minimum alignment of cudaMalloc. We require that buffers created by our
-// DeviceMemoryAllocator, and all input/output buffers, have this alignment.
-extern const int64 kCudaMallocAlignBytes;
+// Minimum alignment for buffers passed as incoming arguments by TensorFlow.
+extern const int64 kEntryParameterAlignBytes;
+
+// Minimum alignment for buffers allocated by XLA: the temp buffers and the live
+// out (result) buffers.
+extern const int64 kXlaAllocatedBufferAlignBytes;
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index b1038a3cc9..1f31a7f36b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -232,7 +232,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
kernel->addParamAttr(
arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment,
- kCudaMallocAlignBytes));
+ alloc->is_entry_computation_parameter()
+ ? kEntryParameterAlignBytes
+ : kXlaAllocatedBufferAlignBytes));
if (alloc->IsPreallocatedTempBuffer()) {
fn_arg->setName("temp_buf");
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index ad29862d83..2eefadebcd 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -543,7 +543,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(),
BufferSizeBytesFunction(),
/*color_alignment=*/[](LogicalBuffer::Color) {
- return kCudaMallocAlignBytes;
+ return kXlaAllocatedBufferAlignBytes;
}));
// BufferAssignment::Stats::ToString() and BufferAssignment::ToString()
// include headers, so no need for us to print them ourselves.
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index 926262e2ad..686c3c16c9 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -202,3 +202,22 @@ tf_cc_test(
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "gpu_alignment_test",
+ testonly = True,
+ srcs = ["gpu_alignment_test.cc"],
+ tags = [
+ "requires-gpu-sm35",
+ ],
+ deps = [
+ ":gpu_codegen_test",
+ "//tensorflow/compiler/xla/service:gpu_plugin",
+ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry",
+ "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
+ "//tensorflow/compiler/xla/tests:filecheck",
+ "//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc
new file mode 100644
index 0000000000..672c68e59b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
+#include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class GpuAlignmentTest : public GpuCodegenTest {};
+
+TEST_F(GpuAlignmentTest, Test) {
+ const char* hlo_string = R"(
+HloModule GpuAlignmentTest
+
+ENTRY main {
+ zero = f32[] constant(0)
+ tok = token[] after-all()
+ a = f32[100] parameter(0)
+ b_tup = (f32[200], token[]) infeed(tok)
+ b = f32[200] get-tuple-element(b_tup), index=0
+ a_padded = f32[150] pad(a, zero), padding=0_50
+ b_sliced = f32[150] slice(b), slice={[0:150]}
+ ROOT c = f32[150] add(a_padded, b_sliced)
+}
+)";
+
+ CompileAndVerifyIr(hlo_string, R"(
+CHECK: @fusion(i8* align 64 dereferenceable(600) %alloc0, i8* align 16 dereferenceable(400) %alloc1, i8* align 64 dereferenceable(864) %temp_buf)
+)");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla