[NVPTX] Unforce minimum alignment of 4 for byval arguments of device-side functions.

Minimum alignment of 4 for byval arguments was forced to workaround
a bug in old versions of ptxas. Details: https://reviews.llvm.org/D22428.
Recent ptxas versions (> 9.0) do not seem to have this bug, so alignment
requirement was relaxed. To force again minimum alignment of 4, use
'-force-min-byval-param-align' option.
This commit is contained in:
Pavel Kopyl 2023-04-22 02:52:04 +02:00
parent 99cfaf0d5e
commit 62439d54fe
3 changed files with 34 additions and 20 deletions

View File

@ -89,6 +89,12 @@ static cl::opt<bool> UsePrecSqrtF32(
cl::desc("NVPTX Specific: 0 use sqrt.approx, 1 use sqrt.rn."),
cl::init(true));
static cl::opt<bool> ForceMinByValParamAlign(
"nvptx-force-min-byval-param-align", cl::Hidden,
cl::desc("NVPTX Specific: force 4-byte minimal alignment for byval"
" params of device functions."),
cl::init(false));
int NVPTXTargetLowering::getDivF32Level() const {
if (UsePrecDivF32.getNumOccurrences() > 0) {
// If nvptx-prec-div32=N is used on the command-line, always honor it
@ -4502,16 +4508,17 @@ Align NVPTXTargetLowering::getFunctionByValParamAlign(
if (F)
ArgAlign = std::max(ArgAlign, getFunctionParamOptimizedAlign(F, ArgTy, DL));
// Work around a bug in ptxas. When PTX code takes address of
// Old ptx versions have a bug. When PTX code takes address of
// byval parameter with alignment < 4, ptxas generates code to
// spill argument into memory. Alas on sm_50+ ptxas generates
// SASS code that fails with misaligned access. To work around
// the problem, make sure that we align byval parameters by at
// least 4.
// TODO: this will need to be undone when we get to support multi-TU
// device-side compilation as it breaks ABI compatibility with nvcc.
// Hopefully ptxas bug is fixed by then.
ArgAlign = std::max(ArgAlign, Align(4));
// least 4. This bug seems to be fixed at least starting from
// ptxas > 9.0.
// TODO: remove this after verifying the bug is not reproduced
// on non-deprecated ptxas versions.
if (ForceMinByValParamAlign)
ArgAlign = std::max(ArgAlign, Align(4));
return ArgAlign;
}

View File

@ -13,8 +13,9 @@ target triple = "nvptx64-nvidia-cuda"
%"class.sycl::_V1::detail::half_impl::half" = type { half }
%complex_half = type { half, half }
; CHECK: .param .align 4 .b8 param2[4];
; CHECK: st.param.v2.b16 [param2+0], {%h2, %h1};
; CHECK: .param .align 2 .b8 param2[4];
; CHECK: st.param.b16 [param2+0], %h1;
; CHECK: st.param.b16 [param2+2], %h2;
; CHECK: .param .align 2 .b8 retval0[4];
; CHECK: call.uni (retval0),
; CHECK-NEXT: _Z20__spirv_GroupCMulKHRjjN5__spv12complex_halfE,
@ -29,15 +30,16 @@ entry:
;;
declare ptr @usefp(ptr %fp)
; CHECK: .func callee(
; CHECK-NEXT: .param .align 4 .b8 callee_param_0[4]
; CHECK-NEXT: .param .align 2 .b8 callee_param_0[4]
define internal void @callee(ptr byval(%"class.complex") %byval_arg) {
ret void
}
define void @boom() {
%fp = call ptr @usefp(ptr @callee)
; CHECK: .param .align 4 .b8 param0[4];
; CHECK: st.param.v2.b16 [param0+0]
; CHECK: .callprototype ()_ (.param .align 4 .b8 _[4]);
; CHECK: .param .align 2 .b8 param0[4];
; CHECK: st.param.b16 [param0+0], %h1;
; CHECK: st.param.b16 [param0+2], %h2;
; CHECK: .callprototype ()_ (.param .align 2 .b8 _[4]);
call void %fp(ptr byval(%"class.complex") null)
ret void
}

View File

@ -1,5 +1,7 @@
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s --check-prefixes=CHECK,NOALIGN4
; RUN: llc < %s -march=nvptx -mcpu=sm_20 -nvptx-force-min-byval-param-align | FileCheck %s --check-prefixes=CHECK,ALIGN4
; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 | %ptxas-verify %}
; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 -nvptx-force-min-byval-param-align | %ptxas-verify %}
;;; Need 4-byte alignment on ptr passed byval
define ptx_device void @t1(ptr byval(float) %x) {
@ -25,20 +27,21 @@ define ptx_device void @t3(ptr byval(%struct.float2) %x) {
ret void
}
;;; Need at least 4-byte alignment in order to avoid miscompilation by
;;; ptxas for sm_50+
define ptx_device void @t4(ptr byval(i8) %x) {
; CHECK: .func t4
; CHECK: .param .align 4 .b8 t4_param_0[1]
; NOALIGN4: .param .align 1 .b8 t4_param_0[1]
; ALIGN4: .param .align 4 .b8 t4_param_0[1]
ret void
}
;;; Make sure we adjust alignment at the call site as well.
define ptx_device void @t5(ptr align 2 byval(i8) %x) {
; CHECK: .func t5
; CHECK: .param .align 4 .b8 t5_param_0[1]
; NOALIGN4: .param .align 2 .b8 t5_param_0[1]
; ALIGN4: .param .align 4 .b8 t5_param_0[1]
; CHECK: {
; CHECK: .param .align 4 .b8 param0[1];
; NOALIGN4: .param .align 1 .b8 param0[1];
; ALIGN4: .param .align 4 .b8 param0[1];
; CHECK: call.uni
call void @t4(ptr byval(i8) %x)
ret void
@ -56,11 +59,13 @@ define ptx_device void @t6() {
call void %fp(ptr byval(double) null);
%fp2 = call ptr @getfp(i32 1)
; CHECK: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
; NOALIGN4: prototype_4 : .callprototype ()_ (.param .align 2 .b8 _[4]);
; ALIGN4: prototype_4 : .callprototype ()_ (.param .align 4 .b8 _[4]);
call void %fp(ptr byval(%struct.half2) null);
%fp3 = call ptr @getfp(i32 2)
; CHECK: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
; NOALIGN4: prototype_6 : .callprototype ()_ (.param .align 1 .b8 _[1]);
; ALIGN4: prototype_6 : .callprototype ()_ (.param .align 4 .b8 _[1]);
call void %fp(ptr byval(i8) null);
ret void
}