[flang] Support arith::FastMathFlagsAttr for fir::CallOp.

The main purpose of this patch is to propagate fastmath attribute
to SimplifyIntrinsicsPass, so that the inline code can inherit
the call operation's attributes. Even though I added translation
of fastmath from fir::CallOp to LLVM::CallOp, there are no fastmath
attributes in LLVM IR. It looks like the translation drops it.
This will need additional commits.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D137602
This commit is contained in:
Slava Zakharin 2022-11-09 15:18:50 -08:00
parent a11cd0d94e
commit bc955cae35
5 changed files with 61 additions and 7 deletions

View File

@ -26,6 +26,11 @@ def fir_Dialect : Dialect {
let cppNamespace = "::fir";
let useDefaultTypePrinterParser = 0;
let useDefaultAttributePrinterParser = 0;
let dependentDialects = [
// Arith dialect provides FastMathFlagsAttr
// supported by some FIR operations.
"arith::ArithDialect"
];
}
#endif // FORTRAN_DIALECT_FIR_DIALECT

View File

@ -14,6 +14,8 @@
#ifndef FORTRAN_DIALECT_FIR_OPS
#define FORTRAN_DIALECT_FIR_OPS
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "flang/Optimizer/Dialect/FIRDialect.td"
include "flang/Optimizer/Dialect/FIRTypes.td"
include "flang/Optimizer/Dialect/FIRAttr.td"
@ -2266,7 +2268,8 @@ def fir_IterWhileOp : region_Op<"iterate_while",
// Procedure call operations
//===----------------------------------------------------------------------===//
def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
def fir_CallOp : fir_Op<"call",
[CallOpInterface, DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
let summary = "call a procedure";
let description = [{
@ -2283,7 +2286,9 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
let arguments = (ins
OptionalAttr<SymbolRefAttr>:$callee,
Variadic<AnyType>:$args
Variadic<AnyType>:$args,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath
);
let results = (outs Variadic<AnyType>);

View File

@ -19,6 +19,7 @@
#include "flang/Optimizer/Support/InternalNames.h"
#include "flang/Optimizer/Support/TypeCode.h"
#include "flang/Semantics/runtime-type-info.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
@ -699,8 +700,11 @@ struct CallOpConversion : public FIROpConversion<fir::CallOp> {
llvm::SmallVector<mlir::Type> resultTys;
for (auto r : call.getResults())
resultTys.push_back(convertType(r.getType()));
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
call, resultTys, adaptor.getOperands(), call->getAttrs());
call, resultTys, adaptor.getOperands(), attrConvert.getAttrs());
return mlir::success();
}
};

View File

@ -655,8 +655,18 @@ void fir::CallOp::print(mlir::OpAsmPrinter &p) {
else
p << getOperand(0);
p << '(' << (*this)->getOperands().drop_front(isDirect ? 0 : 1) << ')';
p.printOptionalAttrDict((*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr()});
// Print 'fastmath<...>' (if it has non-default value) before
// any other attributes.
mlir::arith::FastMathFlagsAttr fmfAttr = getFastmathAttr();
if (fmfAttr.getValue() != mlir::arith::FastMathFlags::none) {
p << ' ' << mlir::arith::FastMathFlagsAttr::getMnemonic();
p.printStrippedAttrOrType(fmfAttr);
}
p.printOptionalAttrDict(
(*this)->getAttrs(),
{fir::CallOp::getCalleeAttrNameStr(), getFastmathAttrName()});
auto resultTypes{getResultTypes()};
llvm::SmallVector<mlir::Type> argTypes(
llvm::drop_begin(getOperandTypes(), isDirect ? 0 : 1));
@ -678,8 +688,18 @@ mlir::ParseResult fir::CallOp::parse(mlir::OpAsmParser &parser,
return mlir::failure();
mlir::Type type;
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::Paren))
return mlir::failure();
// Parse 'fastmath<...>', if present.
mlir::arith::FastMathFlagsAttr fmfAttr;
llvm::StringRef fmfAttrName = getFastmathAttrName(result.name);
if (mlir::succeeded(parser.parseOptionalKeyword(fmfAttrName)))
if (parser.parseCustomAttributeWithFallback(fmfAttr, mlir::Type{},
fmfAttrName, attrs))
return mlir::failure();
if (parser.parseOptionalAttrDict(attrs) || parser.parseColon() ||
parser.parseType(type))
return mlir::failure();

View File

@ -0,0 +1,20 @@
// RUN: fir-opt %s | fir-opt | FileCheck %s
// CHECK-LABEL: @test_callop
func.func @test_callop(%arg0 : f32) {
// CHECK: fir.call @callee() : () -> ()
fir.call @callee() fastmath<none> : () -> ()
// CHECK: fir.call @callee() : () -> ()
fir.call @callee() {fastmath = #arith.fastmath<none>} : () -> ()
// CHECK: fir.call @callee() fastmath<ninf,contract> : () -> ()
fir.call @callee() fastmath<ninf,contract> : () -> ()
// CHECK: fir.call @callee() fastmath<nnan,afn> : () -> ()
fir.call @callee() {fastmath = #arith.fastmath<nnan,afn>} : () -> ()
// CHECK: fir.call @callee() fastmath<fast> : () -> ()
fir.call @callee() fastmath<fast> : () -> ()
// CHECK: fir.call @callee() fastmath<fast> : () -> ()
fir.call @callee() {fastmath = #arith.fastmath<fast>} : () -> ()
return
}
func.func private @callee()