[Matrix] Support #pragma clang fp

From https://bugs.llvm.org/show_bug.cgi?id=49739:

Currently, `#pragma clang fp` are ignored for matrix types.

For the code below, the `contract` fast-math flag should be added to the generated call to `llvm.matrix.multiply` and `fadd`

```
typedef float fx2x2_t __attribute__((matrix_type(2, 2)));

void foo(fx2x2_t &A, fx2x2_t &C, fx2x2_t &B) {
  #pragma clang fp contract(fast)
  C = A*B + C;
}
```

Reviewed By: fhahn, mibintc

Differential Revision: https://reviews.llvm.org/D100834
This commit is contained in:
Hamza Mahfooz 2021-04-22 09:15:48 +01:00 committed by Florian Hahn
parent 439366817b
commit be2277fbf2
No known key found for this signature in database
GPG Key ID: 61D7554B5CECDC0D
2 changed files with 54 additions and 0 deletions

View File

@ -732,6 +732,7 @@ public:
BO->getLHS()->getType().getCanonicalType());
auto *RHSMatTy = dyn_cast<ConstantMatrixType>(
BO->getRHS()->getType().getCanonicalType());
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
if (LHSMatTy && RHSMatTy)
return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
LHSMatTy->getNumColumns(),
@ -3206,6 +3207,7 @@ Value *ScalarExprEmitter::EmitDiv(const BinOpInfo &Ops) {
"first operand must be a matrix");
assert(BO->getRHS()->getType().getCanonicalType()->isArithmeticType() &&
"second operand must be an arithmetic type");
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
return MB.CreateScalarDiv(Ops.LHS, Ops.RHS,
Ops.Ty->hasUnsignedIntegerRepresentation());
}
@ -3585,6 +3587,7 @@ Value *ScalarExprEmitter::EmitAdd(const BinOpInfo &op) {
if (op.Ty->isConstantMatrixType()) {
llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
return MB.CreateAdd(op.LHS, op.RHS);
}
@ -3734,6 +3737,7 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
if (op.Ty->isConstantMatrixType()) {
llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
return MB.CreateSub(op.LHS, op.RHS);
}

View File

@ -0,0 +1,50 @@
// RUN: %clang -emit-llvm -S -fenable-matrix -mllvm -disable-llvm-optzns %s -o - | FileCheck %s
typedef float fx2x2_t __attribute__((matrix_type(2, 2)));
typedef int ix2x2_t __attribute__((matrix_type(2, 2)));
fx2x2_t fp_matrix_contract(fx2x2_t a, fx2x2_t b, float c, float d) {
// CHECK: call contract <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32
// CHECK: fdiv contract <4 x float>
// CHECK: fmul contract <4 x float>
#pragma clang fp contract(fast)
return (a * b / c) * d;
}
fx2x2_t fp_matrix_reassoc(fx2x2_t a, fx2x2_t b, fx2x2_t c) {
// CHECK: fadd reassoc <4 x float>
// CHECK: fsub reassoc <4 x float>
#pragma clang fp reassociate(on)
return a + b - c;
}
fx2x2_t fp_matrix_ops(fx2x2_t a, fx2x2_t b, fx2x2_t c) {
// CHECK: call reassoc contract <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32
// CHECK: fadd reassoc contract <4 x float>
#pragma clang fp contract(fast) reassociate(on)
return a * b + c;
}
fx2x2_t fp_matrix_compound_ops(fx2x2_t a, fx2x2_t b, fx2x2_t c, fx2x2_t d,
float e, float f) {
// CHECK: call reassoc contract <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32
// CHECK: fadd reassoc contract <4 x float>
// CHECK: fsub reassoc contract <4 x float>
// CHECK: fmul reassoc contract <4 x float>
// CHECK: fdiv reassoc contract <4 x float>
#pragma clang fp contract(fast) reassociate(on)
a *= b;
a += c;
a -= d;
a *= e;
a /= f;
return a;
}
ix2x2_t int_matrix_ops(ix2x2_t a, ix2x2_t b, ix2x2_t c) {
// CHECK: call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32
// CHECK: add <4 x i32>
#pragma clang fp contract(fast) reassociate(on)
return a * b + c;
}