blob: 3686568244e2fa9c8bfeeffed99166fe879d78c2 [file] [log] [blame]
//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ValueDecomposer
//===----------------------------------------------------------------------===//
/// Decompose `value` (of type `type`) into zero or more values appended to
/// `results`.
///
/// The registered conversions in `decomposeValueConversions` are consulted in
/// registration order. The first one whose return value converts to `true`
/// is considered to have handled the value, and no further conversions are
/// tried. If none of them handles it, `value` itself is appended unchanged.
void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
                                     Type type, Value value,
                                     SmallVectorImpl<Value> &results) {
  // any_of stops at the first conversion that reports it handled the value,
  // mirroring a "first match wins" dispatch.
  const bool handled =
      llvm::any_of(decomposeValueConversions, [&](auto &decomposeFn) {
        return static_cast<bool>(
            decomposeFn(builder, loc, type, value, results));
      });
  if (!handled)
    results.push_back(value);
}
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesOpConversionPattern
//===----------------------------------------------------------------------===//
namespace {
/// Shared base for the call-graph decomposition patterns in this file.
///
/// It is an ordinary OpConversionPattern for `SourceOp` that additionally
/// exposes a ValueDecomposer to subclasses through the protected `decomposer`
/// member.
template <typename SourceOp>
class DecomposeCallGraphTypesOpConversionPattern
    : public OpConversionPattern<SourceOp> {
public:
  DecomposeCallGraphTypesOpConversionPattern(TypeConverter &converter,
                                             MLIRContext *ctx,
                                             ValueDecomposer &valueDecomposer,
                                             PatternBenefit patternBenefit = 1)
      : OpConversionPattern<SourceOp>(converter, ctx, patternBenefit),
        decomposer{valueDecomposer} {}

protected:
  /// The decomposer used by subclasses. It is held by reference, so the
  /// referenced ValueDecomposer must outlive this pattern.
  ValueDecomposer &decomposer;
};
} // namespace
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
namespace {
/// Expand function arguments according to the provided TypeConverter and
/// ValueDecomposer.
///
/// Each input type is converted with `TypeConverter::convertType`, which may
/// produce zero, one, or several types. The function's entry region is
/// rewritten via `convertRegionTypes` to match, and the function type is
/// replaced with the converted inputs and the converted result types.
struct DecomposeCallGraphTypesForFuncArgs
    : public DecomposeCallGraphTypesOpConversionPattern<FuncOp> {
  using DecomposeCallGraphTypesOpConversionPattern::
      DecomposeCallGraphTypesOpConversionPattern;

  LogicalResult
  matchAndRewrite(FuncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto functionType = op.getType();

    // Convert the result types up front. This step does not modify the IR,
    // so doing it before `convertRegionTypes` guarantees that every
    // `failure()` returned by this pattern happens before the function body
    // has been touched.
    SmallVector<Type, 2> newResultTypes;
    if (failed(typeConverter->convertTypes(functionType.getResults(),
                                           newResultTypes)))
      return failure();

    // Convert function arguments using the provided TypeConverter.
    TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
    for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
      SmallVector<Type, 2> decomposedTypes;
      if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
        return failure();
      // An argument that converts to no types is given no mapping in
      // `conversion`.
      // NOTE(review): presumably this drops the argument from the new
      // signature; confirm against SignatureConversion semantics.
      if (!decomposedTypes.empty())
        conversion.addInputs(argType.index(), decomposedTypes);
    }

    // If the SignatureConversion doesn't apply, bail out.
    if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
                                           &conversion)))
      return failure();

    // Update the signature of the function.
    rewriter.updateRootInPlace(op, [&] {
      op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
                                          newResultTypes));
    });
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForReturnOp
//===----------------------------------------------------------------------===//
namespace {
/// Rewrite a `ReturnOp` so that each of its (already type-converted) operands
/// is run through the ValueDecomposer. The replacement `ReturnOp` returns the
/// concatenation of all decomposed values, in operand order.
struct DecomposeCallGraphTypesForReturnOp
    : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
  using DecomposeCallGraphTypesOpConversionPattern::
      DecomposeCallGraphTypesOpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    // Flatten every operand into its decomposed pieces.
    SmallVector<Value, 2> flattened;
    for (Value returned : adaptor.getOperands())
      decomposer.decomposeValue(rewriter, loc, returned.getType(), returned,
                                flattened);

    rewriter.replaceOpWithNewOp<ReturnOp>(op, flattened);
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForCallOp
//===----------------------------------------------------------------------===//
namespace {
/// Expand call op operands and results according to the provided TypeConverter
/// and ValueDecomposer.
///
/// The operands are decomposed with the ValueDecomposer. Each result type is
/// converted with `TypeConverter::convertType`, which may yield zero, one, or
/// several types. A new `CallOp` to the same callee is created with the
/// expanded operands and result types. Each original result is then replaced
/// as follows:
///   - zero new results: a null Value;
///   - one new result: that result directly;
///   - several new results: the value produced by
///     `materializeArgumentConversion`.
struct DecomposeCallGraphTypesForCallOp
    : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
  using DecomposeCallGraphTypesOpConversionPattern::
      DecomposeCallGraphTypesOpConversionPattern;

  LogicalResult
  matchAndRewrite(CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    // Create the operands list of the new `CallOp`.
    SmallVector<Value, 2> newOperands;
    for (Value operand : adaptor.getOperands())
      decomposer.decomposeValue(rewriter, loc, operand.getType(), operand,
                                newOperands);

    // Create the new result types for the new `CallOp` and track the indices in
    // the new call op's results that correspond to the old call op's results.
    //
    // expandedResultIndices[oldIdx] = "list of new result indices that old
    // result oldIdx expanded to".
    SmallVector<Type, 2> newResultTypes;
    SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
    expandedResultIndices.reserve(op.getNumResults());
    for (Type resultType : op.getResultTypes()) {
      unsigned oldSize = newResultTypes.size();
      // convertType appends the converted type(s) to newResultTypes.
      if (failed(typeConverter->convertType(resultType, newResultTypes)))
        return failure();
      auto &resultMapping = expandedResultIndices.emplace_back();
      for (unsigned newIdx = oldSize, e = newResultTypes.size(); newIdx < e;
           ++newIdx)
        resultMapping.push_back(newIdx);
    }

    CallOp newCallOp = rewriter.create<CallOp>(loc, op.getCalleeAttr(),
                                               newResultTypes, newOperands);

    // Build a replacement value for each result to replace its uses. If a
    // result has multiple mapping values, it needs to be materialized as a
    // single value.
    SmallVector<Value, 2> replacedValues;
    replacedValues.reserve(op.getNumResults());
    for (unsigned oldIdx = 0, e = op.getNumResults(); oldIdx < e; ++oldIdx) {
      auto decomposedValues = llvm::to_vector<6>(llvm::map_range(
          expandedResultIndices[oldIdx],
          [&](unsigned newIdx) { return newCallOp.getResult(newIdx); }));

      if (decomposedValues.empty()) {
        // No replacement is required.
        replacedValues.push_back(nullptr);
        continue;
      }
      if (decomposedValues.size() == 1) {
        replacedValues.push_back(decomposedValues.front());
        continue;
      }

      // Materialize a single Value to replace the original Value.
      // NOTE(review): if no materialization is registered this presumably
      // yields a null Value, which is forwarded to replaceOp as-is; confirm
      // that is the intended behavior.
      Value materialized = getTypeConverter()->materializeArgumentConversion(
          rewriter, loc, op.getType(oldIdx), decomposedValues);
      replacedValues.push_back(materialized);
    }
    rewriter.replaceOp(op, replacedValues);
    return success();
  }
};
} // namespace
/// Register the call-graph decomposition patterns (for `CallOp`, `FuncOp`
/// arguments, and `ReturnOp`) into `patterns`.
///
/// All three patterns are constructed with the same `typeConverter`,
/// `context`, and `decomposer`. The decomposer is captured by reference, so it
/// must outlive the returned patterns.
void mlir::populateDecomposeCallGraphTypesPatterns(
    MLIRContext *context, TypeConverter &typeConverter,
    ValueDecomposer &decomposer, RewritePatternSet &patterns) {
  // Registered one at a time, in the same order as a single add<A, B, C>.
  patterns.add<DecomposeCallGraphTypesForCallOp>(typeConverter, context,
                                                 decomposer);
  patterns.add<DecomposeCallGraphTypesForFuncArgs>(typeConverter, context,
                                                   decomposer);
  patterns.add<DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
                                                   decomposer);
}