[mlir][Transforms] Dialect conversion: add `originalType` param to materializations (#112128)

This commit adds an optional `originalType` parameter to target
materialization functions. Without this parameter, target
materializations are underspecified.

Note: `originalType` is only needed for target materializations.
Source/argument materializations do not have it.

Consider the following example: Let's assume that a conversion pattern
"P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then
a different conversion pattern "P2" matches an op that has "v1" as an
operand. Let's furthermore assume that "P2" determines that the
legalized type of "t1" is "t3", which may be different from "t2". In
this example, the target materialization callback will be invoked with:
outputType = "t3", inputs = "v2", originalType = "t1". Note that the
original type "t1" cannot be recovered from just "t3" and "v2"; that's
why the `originalType` parameter is added.

This change is in preparation of merging the 1:1 and 1:N dialect
conversion drivers. As part of that change, argument materializations
will be removed (as they are no longer needed; they were just a
workaround because of missing 1:N support in the dialect conversion).
The new `originalType` parameter is needed when lowering MemRef to LLVM.
During that lowering, MemRef function block arguments are replaced with
the elements that make up a MemRef descriptor. The type converter is set
up in such a way that the legalized type of a MemRef type is an
`!llvm.struct` that represents the MemRef descriptor. When the bare
pointer calling convention is enabled, the function block arguments
consist of just an LLVM pointer. In such a case, a target
materialization will be invoked to construct a MemRef descriptor (output
type = `!llvm.struct<...>`) from just the bare pointer (inputs =
`!llvm.ptr`). The original MemRef type is required to construct the
MemRef descriptor, as static sizes/strides/offset cannot be inferred
from just the bare pointer.
GitOrigin-RevId: 0d906a425444e0205be8d19e585abe7caa808ba0
diff --git a/include/mlir/Transforms/DialectConversion.h b/include/mlir/Transforms/DialectConversion.h
index 65e279e..45ad6f8 100644
--- a/include/mlir/Transforms/DialectConversion.h
+++ b/include/mlir/Transforms/DialectConversion.h
@@ -138,7 +138,8 @@
   };
 
   /// Register a conversion function. A conversion function must be convertible
-  /// to any of the following forms(where `T` is a class derived from `Type`:
+  /// to any of the following forms (where `T` is a class derived from `Type`):
+  ///
   ///   * std::optional<Type>(T)
   ///     - This form represents a 1-1 type conversion. It should return nullptr
   ///       or `std::nullopt` to signify failure. If `std::nullopt` is returned,
@@ -151,15 +152,7 @@
   ///       existing value are expected to be removed during conversion. If
   ///       `std::nullopt` is returned, the converter is allowed to try another
   ///       conversion function to perform the conversion.
-  ///   * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
-  ///                                  ArrayRef<Type>)
-  ///     - This form represents a 1-N type conversion supporting recursive
-  ///       types. The first two arguments and the return value are the same as
-  ///       for the regular 1-N form. The third argument is contains is the
-  ///       "call stack" of the recursive conversion: it contains the list of
-  ///       types currently being converted, with the current type being the
-  ///       last one. If it is present more than once in the list, the
-  ///       conversion concerns a recursive type.
+  ///
   /// Note: When attempting to convert a type, e.g. via 'convertType', the
   ///       mostly recently added conversions will be invoked first.
   template <typename FnT, typename T = typename llvm::function_traits<
@@ -178,6 +171,9 @@
   /// it failed but other materialization can be attempted, and `nullptr` on
   /// unrecoverable failure. Materialization functions must be provided when a
   /// type conversion may persist after the conversion has finished.
+  ///
+  /// Note: Target materializations may optionally accept an additional Type
+  /// parameter, which is the original type of the SSA value.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
@@ -203,11 +199,22 @@
 
   /// This method registers a materialization that will be called when
   /// converting an illegal (source) value to a legal (target) type.
+  ///
+  /// Note: For target materializations, users can optionally take the original
+  /// type. This type may be different from the type of the input. For example,
+  /// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
+  /// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
+  /// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
+  /// that "P2" determines that the legalized type of "t1" is "t3", which may
+  /// be different from "t2". In this example, the target materialization
+  /// will be invoked with: outputType = "t3", inputs = "v2",
+  // originalType = "t1". Note  that the original type "t1" cannot be recovered
+  /// from just "t3" and "v2"; that's why the originalType parameter exists.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addTargetMaterialization(FnT &&callback) {
     targetMaterializations.emplace_back(
-        wrapMaterialization<T>(std::forward<FnT>(callback)));
+        wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
   }
 
   /// Register a conversion function for attributes within types. Type
@@ -303,21 +310,12 @@
   /// `add*Materialization` for more information on the context for these
   /// methods.
   Value materializeArgumentConversion(OpBuilder &builder, Location loc,
-                                      Type resultType,
-                                      ValueRange inputs) const {
-    return materializeConversion(argumentMaterializations, builder, loc,
-                                 resultType, inputs);
-  }
+                                      Type resultType, ValueRange inputs) const;
   Value materializeSourceConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
-    return materializeConversion(sourceMaterializations, builder, loc,
-                                 resultType, inputs);
-  }
+                                    Type resultType, ValueRange inputs) const;
   Value materializeTargetConversion(OpBuilder &builder, Location loc,
-                                    Type resultType, ValueRange inputs) const {
-    return materializeConversion(targetMaterializations, builder, loc,
-                                 resultType, inputs);
-  }
+                                    Type resultType, ValueRange inputs,
+                                    Type originalType = {}) const;
 
   /// Convert an attribute present `attr` from within the type `type` using
   /// the registered conversion functions. If no applicable conversion has been
@@ -333,21 +331,23 @@
   using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
       Type, SmallVectorImpl<Type> &)>;
 
-  /// The signature of the callback used to materialize a conversion.
+  /// The signature of the callback used to materialize a source/argument
+  /// conversion.
+  ///
+  /// Arguments: builder, result type, inputs, location
   using MaterializationCallbackFn = std::function<std::optional<Value>(
       OpBuilder &, Type, ValueRange, Location)>;
 
+  /// The signature of the callback used to materialize a target conversion.
+  ///
+  /// Arguments: builder, result type, inputs, location, original type
+  using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
+      OpBuilder &, Type, ValueRange, Location, Type)>;
+
   /// The signature of the callback used to convert a type attribute.
   using TypeAttributeConversionCallbackFn =
       std::function<AttributeConversionResult(Type, Attribute)>;
 
-  /// Attempt to materialize a conversion using one of the provided
-  /// materialization functions.
-  Value
-  materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
-                        OpBuilder &builder, Location loc, Type resultType,
-                        ValueRange inputs) const;
-
   /// Generate a wrapper for the given callback. This allows for accepting
   /// different callback forms, that all compose into a single version.
   /// With callback of form: `std::optional<Type>(T)`
@@ -388,9 +388,10 @@
     cachedMultiConversions.clear();
   }
 
-  /// Generate a wrapper for the given materialization callback. The callback
-  /// may take any subclass of `Type` and the wrapper will check for the target
-  /// type to be of the expected class before calling the callback.
+  /// Generate a wrapper for the given argument/source materialization
+  /// callback. The callback may take any subclass of `Type` and the
+  /// wrapper will check for the target type to be of the expected class
+  /// before calling the callback.
   template <typename T, typename FnT>
   MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
     return [callback = std::forward<FnT>(callback)](
@@ -402,6 +403,41 @@
     };
   }
 
+  /// Generate a wrapper for the given target materialization callback.
+  /// The callback may take any subclass of `Type` and the wrapper will check
+  /// for the target type to be of the expected class before calling the
+  /// callback.
+  ///
+  /// With callback of form:
+  /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+  template <typename T, typename FnT>
+  std::enable_if_t<
+      std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
+      TargetMaterializationCallbackFn>
+  wrapTargetMaterialization(FnT &&callback) const {
+    return [callback = std::forward<FnT>(callback)](
+               OpBuilder &builder, Type resultType, ValueRange inputs,
+               Location loc, Type originalType) -> std::optional<Value> {
+      if (T derivedType = dyn_cast<T>(resultType))
+        return callback(builder, derivedType, inputs, loc, originalType);
+      return std::nullopt;
+    };
+  }
+  /// With callback of form:
+  /// `Value(OpBuilder &, T, ValueRange, Location)`
+  template <typename T, typename FnT>
+  std::enable_if_t<
+      std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
+      TargetMaterializationCallbackFn>
+  wrapTargetMaterialization(FnT &&callback) const {
+    return wrapTargetMaterialization<T>(
+        [callback = std::forward<FnT>(callback)](
+            OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
+            Type originalType) -> std::optional<Value> {
+          return callback(builder, resultType, inputs, loc);
+        });
+  }
+
   /// Generate a wrapper for the given memory space conversion callback. The
   /// callback may take any subclass of `Attribute` and the wrapper will check
   /// for the target attribute to be of the expected class before calling the
@@ -434,7 +470,7 @@
   /// The list of registered materialization functions.
   SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
   SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
-  SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
+  SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
 
   /// The list of registered type attribute conversion functions.
   SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
diff --git a/lib/Transforms/Utils/DialectConversion.cpp b/lib/Transforms/Utils/DialectConversion.cpp
index 97dd3ab..1baddd8 100644
--- a/lib/Transforms/Utils/DialectConversion.cpp
+++ b/lib/Transforms/Utils/DialectConversion.cpp
@@ -683,10 +683,10 @@
 /// conversion.
 class UnresolvedMaterializationRewrite : public OperationRewrite {
 public:
-  UnresolvedMaterializationRewrite(
-      ConversionPatternRewriterImpl &rewriterImpl,
-      UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target);
+  UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                                   UnrealizedConversionCastOp op,
+                                   const TypeConverter *converter,
+                                   MaterializationKind kind, Type originalType);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,11 +708,18 @@
     return converterAndKind.getInt();
   }
 
+  /// Return the original type of the SSA value.
+  Type getOriginalType() const { return originalType; }
+
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
   llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
       converterAndKind;
+
+  /// The original type of the SSA value. Only used for target
+  /// materializations.
+  Type originalType;
 };
 } // namespace
 
@@ -808,6 +815,7 @@
   Value buildUnresolvedMaterialization(MaterializationKind kind,
                                        OpBuilder::InsertPoint ip, Location loc,
                                        ValueRange inputs, Type outputType,
+                                       Type originalType,
                                        const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
@@ -1034,9 +1042,12 @@
 
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind)
+    const TypeConverter *converter, MaterializationKind kind, Type originalType)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind) {
+      converterAndKind(converter, kind), originalType(originalType) {
+  assert(!originalType ||
+         kind == MaterializationKind::Target &&
+             "original type is valid only for target materializations");
   rewriterImpl.unresolvedMaterializations[op] = this;
 }
 
@@ -1139,7 +1150,7 @@
       Value castValue = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(newOperand),
           operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
-          currentTypeConverter);
+          /*originalType=*/origType, currentTypeConverter);
       mapping.map(newOperand, castValue);
       newOperand = castValue;
     }
@@ -1255,7 +1266,7 @@
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
           /*inputs=*/ValueRange(),
-          /*outputType=*/origArgType, converter);
+          /*outputType=*/origArgType, /*originalType=*/Type(), converter);
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
@@ -1280,7 +1291,8 @@
     Value argMat = buildUnresolvedMaterialization(
         MaterializationKind::Argument,
         OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*inputs=*/replArgs, origArgType, converter);
+        /*inputs=*/replArgs, /*outputType=*/origArgType,
+        /*originalType=*/Type(), converter);
     mapping.map(origArg, argMat);
 
     Type legalOutputType;
@@ -1299,7 +1311,8 @@
     if (legalOutputType && legalOutputType != origArgType) {
       Value targetMat = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(argMat),
-          origArg.getLoc(), argMat, legalOutputType, converter);
+          origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
+          /*originalType=*/origArgType, converter);
       mapping.map(argMat, targetMat);
     }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1335,12 @@
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, Type outputType, const TypeConverter *converter) {
+    ValueRange inputs, Type outputType, Type originalType,
+    const TypeConverter *converter) {
+  assert(!originalType ||
+         kind == MaterializationKind::Target &&
+             "original type is valid only for target materializations");
+
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
     return inputs.front();
@@ -1333,7 +1351,8 @@
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  originalType);
   return convertOp.getResult(0);
 }
 
@@ -1381,7 +1400,8 @@
       newValue = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
           result.getLoc(), /*inputs=*/ValueRange(),
-          /*outputType=*/result.getType(), currentTypeConverter);
+          /*outputType=*/result.getType(), /*originalType=*/Type(),
+          currentTypeConverter);
     }
 
     // Remap, and check for any result type changes.
@@ -2408,7 +2428,8 @@
       [[fallthrough]];
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands,
+          rewrite->getOriginalType());
       break;
     case MaterializationKind::Source:
       newMaterialization = converter->materializeSourceConversion(
@@ -2565,7 +2586,7 @@
           MaterializationKind::Source, computeInsertPoint(newValue),
           originalValue.getLoc(),
           /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
-          converter);
+          /*originalType=*/Type(), converter);
       rewriterImpl.mapping.map(originalValue, castValue);
       inverseMapping[castValue].push_back(originalValue);
       llvm::erase(inverseMapping[newValue], originalValue);
@@ -2787,15 +2808,39 @@
   return success();
 }
 
-Value TypeConverter::materializeConversion(
-    ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
-    Location loc, Type resultType, ValueRange inputs) const {
-  for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
+Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
+                                                   Location loc,
+                                                   Type resultType,
+                                                   ValueRange inputs) const {
+  for (const MaterializationCallbackFn &fn :
+       llvm::reverse(argumentMaterializations))
     if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
       return *result;
   return nullptr;
 }
 
+Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
+                                                 Location loc, Type resultType,
+                                                 ValueRange inputs) const {
+  for (const MaterializationCallbackFn &fn :
+       llvm::reverse(sourceMaterializations))
+    if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
+      return *result;
+  return nullptr;
+}
+
+Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc, Type resultType,
+                                                 ValueRange inputs,
+                                                 Type originalType) const {
+  for (const TargetMaterializationCallbackFn &fn :
+       llvm::reverse(targetMaterializations))
+    if (std::optional<Value> result =
+            fn(builder, resultType, inputs, loc, originalType))
+      return *result;
+  return nullptr;
+}
+
 std::optional<TypeConverter::SignatureConversion>
 TypeConverter::convertBlockSignature(Block *block) const {
   SignatureConversion conversion(block->getNumArguments());