[NVPTX] Fix a typo that makes the output invalid PTX

It's surprisingly tricky to trigger this as it's only used by abs/neg
which expand into and/xor in the integer domain.
This commit is contained in:
Benjamin Kramer 2023-12-08 14:21:11 +01:00
parent e38c29c2b7
commit 06ebe3b237
2 changed files with 11 additions and 1 deletions

View File

@ -561,7 +561,7 @@ multiclass F2_Support_Half<string OpcStr, SDNode OpNode> {
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a)))]>,
Requires<[hasSM<80>, hasPTX<70>]>;
def bf16x2 : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
!strconcat(OpcStr, ".v2bf16 \t$dst, $a;"),
!strconcat(OpcStr, ".bf16x2 \t$dst, $a;"),
[(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a)))]>,
Requires<[hasSM<80>, hasPTX<70>]>;
def f16_ftz : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),

View File

@ -392,6 +392,16 @@ define <2 x bfloat> @test_fabs(<2 x bfloat> %a) #0 {
ret <2 x bfloat> %r
}
; CHECK-LABEL: test_fabs_add(
; CHECK: abs.bf16x2
; CHECK: ret;
define <2 x bfloat> @test_fabs_add(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
%s = fadd <2 x bfloat> %a, %a
%r = call <2 x bfloat> @llvm.fabs.f16(<2 x bfloat> %s)
%d = fadd <2 x bfloat> %r, %b
ret <2 x bfloat> %d
}
; CHECK-LABEL: test_minnum(
; CHECK-DAG: ld.param.b32 [[AF0:%r[0-9]+]], [test_minnum_param_0];