blob: 84ca14fab48a7a47f50c0b6da2fed385eb6e7852 [file] [log] [blame]
//===- ConvertVectorToLLVM.h - Utils to convert from the vector dialect ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
#define MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class LLVMTypeConverter;
class ModuleOp;
template <typename T>
class OperationPass;
/// Options to control Vector to LLVM lowering.
///
/// This should kept in sync with VectorToLLVM options defined for the
/// ConvertVectorToLLVM pass in include/mlir/Conversion/Passes.td
struct LowerVectorToLLVMOptions {
LowerVectorToLLVMOptions()
: reassociateFPReductions(false), indexOptimizations(true),
armNeon(false), armSVE(false), amx(false), x86Vector(false) {}
LowerVectorToLLVMOptions &enableReassociateFPReductions(bool b = true) {
reassociateFPReductions = b;
return *this;
}
LowerVectorToLLVMOptions &enableIndexOptimizations(bool b = true) {
indexOptimizations = b;
return *this;
}
LowerVectorToLLVMOptions &enableArmNeon(bool b = true) {
armNeon = b;
return *this;
}
LowerVectorToLLVMOptions &enableArmSVE(bool b = true) {
armSVE = b;
return *this;
}
LowerVectorToLLVMOptions &enableAMX(bool b = true) {
amx = b;
return *this;
}
LowerVectorToLLVMOptions &enableX86Vector(bool b = true) {
x86Vector = b;
return *this;
}
bool reassociateFPReductions;
bool indexOptimizations;
bool armNeon;
bool armSVE;
bool amx;
bool x86Vector;
};
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
/// will be needed when invoking LLVM.
void populateVectorToLLVMMatrixConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns,
bool reassociateFPReductions = false);
/// Create a pass to convert vector operations to the LLVMIR dialect.
std::unique_ptr<OperationPass<ModuleOp>> createConvertVectorToLLVMPass(
const LowerVectorToLLVMOptions &options = LowerVectorToLLVMOptions());
} // namespace mlir
#endif // MLIR_CONVERSION_VECTORTOLLVM_CONVERTVECTORTOLLVM_H_