[fir] Add fir.convert canonicalization patterns
Add rewrite patterns for fir.convert op canonicalization.
This patch is part of the upstreaming effort from fir-dev branch.
Reviewed By: kiranchandramohan
Differential Revision: https://reviews.llvm.org/D111537
Co-authored-by: Valentin Clement <clementval@gmail.com>
GitOrigin-RevId: c3abfe4207d35c221f1667d5b0c79a6511be5ea3
diff --git a/include/flang/Optimizer/Transforms/CMakeLists.txt b/include/flang/Optimizer/Transforms/CMakeLists.txt
index 37096bf..47fcdb9 100644
--- a/include/flang/Optimizer/Transforms/CMakeLists.txt
+++ b/include/flang/Optimizer/Transforms/CMakeLists.txt
@@ -1,5 +1,8 @@
+set(LLVM_TARGET_DEFINITIONS RewritePatterns.td)
+mlir_tablegen(RewritePatterns.inc -gen-rewriters)
+add_public_tablegen_target(RewritePatternsIncGen)
+
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name OptTransform)
add_public_tablegen_target(FIROptTransformsPassIncGen)
-
diff --git a/include/flang/Optimizer/Transforms/RewritePatterns.td b/include/flang/Optimizer/Transforms/RewritePatterns.td
new file mode 100644
index 0000000..5ececcf
--- /dev/null
+++ b/include/flang/Optimizer/Transforms/RewritePatterns.td
@@ -0,0 +1,59 @@
+//===-- RewritePatterns.td - FIR Rewrite Patterns -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Defines pattern rewrites for fir optimizations
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_FIR_REWRITE_PATTERNS
+#define FORTRAN_FIR_REWRITE_PATTERNS
+
+include "mlir/IR/OpBase.td"
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "flang/Optimizer/Dialect/FIROps.td"
+
+def IdenticalTypePred : Constraint<CPred<"$0.getType() == $1.getType()">>;
+def IntegerTypePred : Constraint<CPred<"fir::isa_integer($0.getType())">>;
+def IndexTypePred : Constraint<CPred<"$0.getType().isa<mlir::IndexType>()">>;
+
+def SmallerWidthPred
+ : Constraint<CPred<"$0.getType().getIntOrFloatBitWidth() "
+ "<= $1.getType().getIntOrFloatBitWidth()">>;
+
+def ConvertConvertOptPattern
+ : Pat<(fir_ConvertOp (fir_ConvertOp $arg)),
+ (fir_ConvertOp $arg),
+ [(IntegerTypePred $arg)]>;
+
+def RedundantConvertOptPattern
+ : Pat<(fir_ConvertOp:$res $arg),
+ (replaceWithValue $arg),
+ [(IdenticalTypePred $res, $arg)
+ ,(IntegerTypePred $arg)]>;
+
+def CombineConvertOptPattern
+ : Pat<(fir_ConvertOp:$res(fir_ConvertOp:$irm $arg)),
+ (replaceWithValue $arg),
+ [(IdenticalTypePred $res, $arg)
+ ,(IntegerTypePred $arg)
+ ,(IntegerTypePred $irm)
+ ,(SmallerWidthPred $arg, $irm)]>;
+
+def createConstantOp
+ : NativeCodeCall<"$_builder.create<mlir::ConstantOp>"
+ "($_loc, $_builder.getIndexType(), "
+ "rewriter.getIndexAttr($1.dyn_cast<IntegerAttr>().getInt()))">;
+
+def ForwardConstantConvertPattern
+ : Pat<(fir_ConvertOp:$res (ConstantOp:$cnt $attr)),
+ (createConstantOp $res, $attr),
+ [(IndexTypePred $res)
+ ,(IntegerTypePred $cnt)]>;
+
+#endif // FORTRAN_FIR_REWRITE_PATTERNS
diff --git a/lib/Optimizer/Dialect/FIROps.cpp b/lib/Optimizer/Dialect/FIROps.cpp
index 33db64c..294ade9 100644
--- a/lib/Optimizer/Dialect/FIROps.cpp
+++ b/lib/Optimizer/Dialect/FIROps.cpp
@@ -24,6 +24,9 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
+namespace {
+#include "flang/Optimizer/Transforms/RewritePatterns.inc"
+} // namespace
using namespace fir;
/// Return true if a sequence type is of some incomplete size or a record type
@@ -773,7 +776,11 @@
//===----------------------------------------------------------------------===//
void fir::ConvertOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {}
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<ConvertConvertOptPattern, RedundantConvertOptPattern,
+ CombineConvertOptPattern, ForwardConstantConvertPattern>(
+ context);
+}
mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef<mlir::Attribute> opnds) {
if (value().getType() == getType())
diff --git a/lib/Optimizer/Transforms/CMakeLists.txt b/lib/Optimizer/Transforms/CMakeLists.txt
index 99b022e..9d3448a 100644
--- a/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/lib/Optimizer/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@
FIRDialect
FIRSupport
FIROptTransformsPassIncGen
+ RewritePatternsIncGen
LINK_LIBS
FIRDialect
diff --git a/test/Fir/convert-fold.fir b/test/Fir/convert-fold.fir
new file mode 100644
index 0000000..79959bb
--- /dev/null
+++ b/test/Fir/convert-fold.fir
@@ -0,0 +1,37 @@
+// RUN: fir-opt --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: @ftest
+func @ftest(%x : i1) -> i1 {
+ // this pair of converts should be folded and DCEd
+ %1 = fir.convert %x : (i1) -> !fir.logical<1>
+ %2 = fir.convert %1 : (!fir.logical<1>) -> i1
+ // CHECK-NEXT: return %{{.*}} : i1
+ return %2 : i1
+}
+
+// CHECK-LABEL: @gtest
+func @gtest(%x : !fir.logical<2>) -> !fir.logical<2> {
+ // this pair of converts should be folded and DCEd
+ %1 = fir.convert %x : (!fir.logical<2>) -> i1
+ %2 = fir.convert %1 : (i1) -> !fir.logical<2>
+ // CHECK-NEXT: return %{{.*}} : !fir.logical<2>
+ return %2 : !fir.logical<2>
+}
+
+// CHECK-LABEL: @htest
+func @htest(%x : !fir.int<4>) -> !fir.int<4> {
+ // these converts are NOPs and should be folded away
+ %1 = fir.convert %x : (!fir.int<4>) -> !fir.int<4>
+ %2 = fir.convert %1 : (!fir.int<4>) -> !fir.int<4>
+ // CHECK-NEXT: return %{{.*}} : !fir.int<4>
+ return %2 : !fir.int<4>
+}
+
+// CHECK-LABEL: @ctest
+func @ctest() -> index {
+ %1 = constant 10 : i32
+ %2 = fir.convert %1 : (i32) -> index
+ // CHECK-NEXT: %{{.*}} = constant 10 : index
+ // CHECK-NEXT: return %{{.*}} : index
+ return %2 : index
+}