aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/gpu_constants.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_constants.cc13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
index aa360c7f73..e6ddea6d25 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc
@@ -14,12 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
+#include "tensorflow/core/framework/allocator.h"
namespace xla {
namespace gpu {
-// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses
-const int64 kCudaMallocAlignBytes = 256;
+// kEntryParameterAlignBytes is equal to EIGEN_MAX_ALIGN_BYTES, though including
+// Eigen headers here to get that symbol may not be a good idea.
+// EIGEN_MAX_ALIGN_BYTES may differ between CUDA-enabled builds vs CUDA-disabled
+// builds and we don't want the IR generated by XLA:GPU to depend on that.
+//
+// TODO(b/111767313): Consider raising EIGEN_MAX_ALIGN_BYTES if it helps.
+const int64 kEntryParameterAlignBytes = 16;
+
+const int64 kXlaAllocatedBufferAlignBytes =
+ tensorflow::Allocator::kAllocatorAlignment;
} // namespace gpu
} // namespace xla