[mlir][Transforms] Detect mapping overwrites during block signature conversion
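Add an assertion to the conversion value mapping's `map` function so that it fails when an existing entry in the mapping would be overwritten. To avoid tripping this assertion, `applySignatureConversion` now maps and materializes the value that each original block argument is currently mapped to (`mapping.lookupOrDefault(origArg)`), rather than the block argument itself.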
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4904d3c..94e61a2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -176,6 +176,8 @@
template <typename OldVal, typename NewVal>
std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
map(OldVal &&oldVal, NewVal &&newVal) {
+ assert(!mapping.contains(oldVal) &&
+ "attempting to overwrite existing mapping");
LLVM_DEBUG({
ValueVector next(newVal);
while (true) {
@@ -1412,6 +1414,7 @@
for (unsigned i = 0; i != origArgCount; ++i) {
BlockArgument origArg = block->getArgument(i);
Type origArgType = origArg.getType();
+ ValueVector currentMapping = mapping.lookupOrDefault(origArg);
std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
signatureConversion.getInputMapping(i);
@@ -1421,7 +1424,7 @@
buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
+ /*valuesToMap=*/currentMapping, /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
@@ -1432,7 +1435,7 @@
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- mapping.map(origArg, repl);
+ mapping.map(currentMapping, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
@@ -1441,7 +1444,7 @@
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
- mapping.map(origArg, std::move(replArgVals));
+ mapping.map(currentMapping, std::move(replArgVals));
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1757,6 +1760,8 @@
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
+ llvm::errs() << "replaceUsesOfBlockArgument: " << from.getOwner() << "\n";
+
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);