blob: 9f6632164b04e1c9d0e4df82269e1d4a7467bb95 [file] [log] [blame]
//===- ArmSVEDialect.cpp - MLIR ArmSVE dialect implementation -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the ArmSVE dialect and its operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace arm_sve;
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.cpp.inc"
// Forward declarations of file-local helpers whose definitions appear further
// down in this file. They are declared ahead of the GET_OP_CLASSES include
// below.
// NOTE(review): presumably the generated op definitions reference these
// helpers (e.g. from custom builders) — confirm against the generated .inc.

// Maps a scalable vector type to the scalable vector type with the same shape
// and i1 elements.
static Type getI1SameShape(Type type);
// Populates `result` for an integer comparison op on `lhs` and `rhs` with the
// given predicate.
static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
arith::CmpIPredicate predicate, Value lhs,
Value rhs);
// Populates `result` for a floating-point comparison op on `lhs` and `rhs`
// with the given predicate.
static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
arith::CmpFPredicate predicate, Value lhs,
Value rhs);
#define GET_OP_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
/// Registers the operations and types of the ArmSVE dialect.
///
/// The template argument lists passed to addOperations<> and addTypes<> are
/// not spelled out here: they are expanded from the generated .inc files,
/// selected by defining GET_OP_LIST / GET_TYPEDEF_LIST immediately before each
/// include.
void ArmSVEDialect::initialize() {
// Operation list comes from the GET_OP_LIST section of ArmSVE.cpp.inc.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc"
>();
// Type list comes from the GET_TYPEDEF_LIST section of ArmSVETypes.cpp.inc.
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc"
>();
}
//===----------------------------------------------------------------------===//
// ScalableVectorType
//===----------------------------------------------------------------------===//
/// Parses a type that belongs to the ArmSVE dialect.
///
/// The work is delegated to the generated type parser, which is always invoked
/// with the fixed "vector" mnemonic. If the generated parser reports any
/// result, the type it produced is returned unchanged; otherwise an
/// "unknown type" diagnostic is emitted at the position where parsing began
/// and a null type is returned.
///
/// NOTE(review): a reported-but-failed parse also takes the first path, so the
/// returned type may be null in that case — presumably the generated parser
/// has already emitted its own diagnostic; confirm.
Type ArmSVEDialect::parseType(DialectAsmParser &parser) const {
  // Capture the start position up front so the fallback diagnostic points at
  // the beginning of the type rather than wherever parsing stopped.
  llvm::SMLoc startLoc = parser.getCurrentLocation();

  Type parsedType;
  const bool handled =
      generatedTypeParser(parser, "vector", parsedType).hasValue();
  if (!handled) {
    parser.emitError(startLoc, "unknown type in ArmSVE dialect");
    return Type();
  }
  return parsedType;
}
/// Prints a type that belongs to the ArmSVE dialect.
///
/// Printing is delegated to the generated type printer. A failure from that
/// printer is treated as an unreachable state rather than a recoverable
/// error.
void ArmSVEDialect::printType(Type type, DialectAsmPrinter &os) const {
  const auto printStatus = generatedTypePrinter(type, os);
  if (!failed(printStatus))
    return;
  llvm_unreachable("unexpected 'arm_sve' type kind");
}
//===----------------------------------------------------------------------===//
// ScalableVector versions of general helpers for comparison ops
//===----------------------------------------------------------------------===//
/// Returns the scalable vector type that has the same shape as `type` but
/// whose elements are 1-bit integers. When `type` is not a
/// ScalableVectorType, a null type is returned instead.
static Type getI1SameShape(Type type) {
  auto *context = type.getContext();
  auto boolType = IntegerType::get(context, 1);

  auto scalableType = type.dyn_cast<ScalableVectorType>();
  if (!scalableType)
    return Type();
  return ScalableVectorType::get(context, scalableType.getShape(), boolType);
}
//===----------------------------------------------------------------------===//
// CmpFOp
//===----------------------------------------------------------------------===//
/// Fills in `result` for a scalable floating-point comparison.
///
/// `lhs` and `rhs` become the two operands, the single result type is the i1
/// scalable vector derived from the type of `lhs`, and `predicate` is stored
/// as a 64-bit integer attribute under the op's predicate attribute name.
static void buildScalableCmpFOp(OpBuilder &build, OperationState &result,
                                arith::CmpFPredicate predicate, Value lhs,
                                Value rhs) {
  result.addOperands({lhs, rhs});

  // The result type is derived from the left-hand operand only.
  Type resultType = getI1SameShape(lhs.getType());
  result.types.push_back(resultType);

  const auto predicateValue = static_cast<int64_t>(predicate);
  result.addAttribute(ScalableCmpFOp::getPredicateAttrName(),
                      build.getI64IntegerAttr(predicateValue));
}
//===----------------------------------------------------------------------===//
// CmpIOp
//===----------------------------------------------------------------------===//
/// Fills in `result` for a scalable integer comparison.
///
/// `lhs` and `rhs` become the two operands, the single result type is the i1
/// scalable vector derived from the type of `lhs`, and `predicate` is stored
/// as a 64-bit integer attribute under the op's predicate attribute name.
static void buildScalableCmpIOp(OpBuilder &build, OperationState &result,
                                arith::CmpIPredicate predicate, Value lhs,
                                Value rhs) {
  result.addOperands({lhs, rhs});

  // The result type is derived from the left-hand operand only.
  Type resultType = getI1SameShape(lhs.getType());
  result.types.push_back(resultType);

  const auto predicateValue = static_cast<int64_t>(predicate);
  result.addAttribute(ScalableCmpIOp::getPredicateAttrName(),
                      build.getI64IntegerAttr(predicateValue));
}