[mlir][LLVM] Add fastmath flags support to fpext/fptrunc ops. (#192185)
Add fastmath attributes to llvm fpext/fptrunc ops,
FastmathFlagsInterface op interface support.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 98f7e4e..d7c8cf2 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -588,6 +588,26 @@
}];
}
+class LLVM_CastOpWithFastMathFlag<string mnemonic, string instName, Type type,
+ Type resultType, list<Trait> traits = []> :
+ LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<FastmathFlagsInterface>], traits)>,
+ LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType);"> {
+ let arguments = (
+ ins type:$arg,
+ DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags);
+ let results = (outs resultType:$res);
+ let builders = [LLVM_OneResultOpBuilder];
+ let assemblyFormat = "$arg (`fastmath` `` $fastmathFlags^)? "
+ "attr-dict `:` type($arg) `to` type($res)";
+ string llvmInstName = instName;
+ string mlirBuilder = [{
+ auto op = $_qualCppClassName::create($_builder,
+ $_location, $_resultType, $arg);
+ moduleImport.setFastmathFlagsAttr(inst, op);
+ $res = op;
+ }];
+}
+
class LLVM_DereferenceableCastOp<string mnemonic, string instName, Type type,
Type resultType, list<Trait> traits = []> :
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<DereferenceableOpInterface>], traits)> {
@@ -699,10 +719,10 @@
def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "FPToUI",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
-def LLVM_FPExtOp : LLVM_CastOp<"fpext", "FPExt",
+def LLVM_FPExtOp : LLVM_CastOpWithFastMathFlag<"fpext", "FPExt",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
-def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
+def LLVM_FPTruncOp : LLVM_CastOpWithFastMathFlag<"fptrunc", "FPTrunc",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
diff --git a/mlir/test/Target/LLVMIR/Import/fastmath.ll b/mlir/test/Target/LLVMIR/Import/fastmath.ll
index 0c6a74c..a6f8f10 100644
--- a/mlir/test/Target/LLVMIR/Import/fastmath.ll
+++ b/mlir/test/Target/LLVMIR/Import/fastmath.ll
@@ -19,6 +19,17 @@
; // -----
+; CHECK-LABEL: @fastmath_cast
+define void @fastmath_cast(float %arg1) {
+ ; CHECK: llvm.fpext %{{.*}} fastmath<nnan> : f32 to f64
+ %1 = fpext nnan float %arg1 to double
+ ; CHECK: llvm.fptrunc %{{.*}} fastmath<fast> : f32 to f16
+ %2 = fptrunc fast float %arg1 to half
+ ret void
+}
+
+; // -----
+
; CHECK-LABEL: @fastmath_fcmp
define void @fastmath_fcmp(float %arg1, float %arg2) {
; CHECK: llvm.fcmp "oge" %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath<nsz>} : f32
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index ef4082d..aa4d0cd 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -2208,6 +2208,11 @@
%25 = llvm.mlir.constant(true) : i1
// CHECK: select contract i1
%26 = llvm.select %25, %arg0, %20 {fastmathFlags = #llvm.fastmath<contract>} : i1, f32
+
+// CHECK: {{.*}} = fpext nnan float {{.*}} to double
+// CHECK: {{.*}} = fptrunc fast float {{.*}} to half
+ %27 = llvm.fpext %arg0 fastmath<nnan> : f32 to f64
+ %28 = llvm.fptrunc %arg0 fastmath<fast> : f32 to f16
llvm.return
}