blob: f1ec2be72a33abadb8c0a690c459749d4d8570a9 [file] [log] [blame]
//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
namespace mlir {
namespace polynomial {
/// Print this attribute as the wrapped polynomial enclosed in angle brackets,
/// i.e. `<` followed by the polynomial followed by `>`.
void PolynomialAttr::print(AsmPrinter &p) const {
  p << '<' << getPolynomial() << '>';
}
/// Try to parse a single monomial of the form `[coeff] [var [** exponent]]`,
/// optionally followed by a `+`.
///
/// On success the outparams are populated as follows:
///   - `monomial`: the parsed coefficient (the coefficient is seeded with 1
///     before the optional integer literal is parsed) and the exponent, which
///     is 0 for a constant term, 1 for a bare variable, and the parsed integer
///     after `**` otherwise.
///   - `variable`: the parsed variable name. Only written when a variable
///     keyword is actually present.
///   - `isConstantTerm`: true if the monomial has no variable.
///   - `shouldParseMore`: true if the monomial is followed by a `+`.
///
/// Returns failure if neither a coefficient nor a variable is present, if the
/// coefficient literal is present but could not be parsed, if a single `*` is
/// not followed by a second `*`, or if `**` is not followed by an integer.
ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
                          llvm::StringRef &variable, bool &isConstantTerm,
                          bool &shouldParseMore) {
  // Initialize the flag outparams up front so they hold defined values on
  // every return path.
  isConstantTerm = false;
  shouldParseMore = false;

  APInt parsedCoeff(apintBitWidth, 1);
  auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);

  // An integer token was present but could not be parsed. Propagate the
  // failure instead of continuing with an unreliable coefficient.
  if (parsedCoeffResult.has_value() && failed(*parsedCoeffResult))
    return failure();

  monomial.coefficient = parsedCoeff;
  const bool hasCoefficient = parsedCoeffResult.has_value();

  // A + indicates it's a constant term with more to go, as in `1 + x`.
  if (succeeded(parser.parseOptionalPlus())) {
    // If no coefficient was parsed, and there's a +, then it's effectively
    // parsing an empty string.
    if (!hasCoefficient)
      return failure();

    monomial.exponent = APInt(apintBitWidth, 0);
    isConstantTerm = true;
    shouldParseMore = true;
    return success();
  }

  // A monomial can be a trailing constant term, as in `x + 1`.
  if (failed(parser.parseOptionalKeyword(&variable))) {
    // If neither a coefficient nor a variable was found, then it's effectively
    // parsing an empty string.
    if (!hasCoefficient)
      return failure();

    monomial.exponent = APInt(apintBitWidth, 0);
    isConstantTerm = true;
    return success();
  }

  // Parse exponentiation symbol as `**`. We can't use caret because it's
  // reserved for basic block identifiers. If no star is present, it's treated
  // as a polynomial with exponent 1.
  if (succeeded(parser.parseOptionalStar())) {
    // If there's one * there must be two.
    if (failed(parser.parseStar()))
      return failure();

    // If there's a **, then the integer exponent is required.
    APInt parsedExponent(apintBitWidth, 0);
    if (failed(parser.parseInteger(parsedExponent))) {
      parser.emitError(parser.getCurrentLocation(),
                       "found invalid integer exponent");
      return failure();
    }

    monomial.exponent = parsedExponent;
  } else {
    monomial.exponent = APInt(apintBitWidth, 1);
  }

  // A trailing + means the caller should parse another monomial.
  if (succeeded(parser.parseOptionalPlus()))
    shouldParseMore = true;

  return success();
}
/// Parse a polynomial attribute body of the form `<m0 + m1 + ... + mk>`, where
/// each `mi` is a monomial accepted by `parseMonomial`.
///
/// Returns a null attribute (after emitting a diagnostic) if:
///   - the opening `<` is missing,
///   - a monomial fails to parse,
///   - a monomial is followed by neither `+` nor `>`,
///   - more than one distinct variable name appears, or
///   - `Polynomial::fromMonomials` rejects the parsed monomials.
///
/// The `type` parameter is not used.
Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  llvm::SmallVector<Monomial> monomials;
  // Distinct variable names seen so far; used to enforce a single
  // indeterminate.
  llvm::StringSet<> variables;

  while (true) {
    Monomial parsedMonomial;
    llvm::StringRef parsedVariableRef;
    bool isConstantTerm = false;
    bool shouldParseMore = false;
    if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
                             isConstantTerm, shouldParseMore))) {
      parser.emitError(parser.getCurrentLocation(), "expected a monomial");
      return {};
    }

    // Constant terms carry no variable name, so only record names for
    // non-constant monomials.
    if (!isConstantTerm)
      variables.insert(parsedVariableRef);

    monomials.push_back(parsedMonomial);

    // A trailing `+` was already consumed by parseMonomial.
    if (shouldParseMore)
      continue;

    if (succeeded(parser.parseOptionalGreater()))
      break;

    parser.emitError(
        parser.getCurrentLocation(),
        "expected + and more monomials, or > to end polynomial attribute");
    return {};
  }

  if (variables.size() > 1) {
    std::string vars = llvm::join(variables.keys(), ", ");
    parser.emitError(
        parser.getCurrentLocation(),
        "polynomials must have one indeterminate, but there were multiple: " +
            vars);
    // Signal failure to the caller; without this the attribute would still be
    // constructed and returned despite the diagnostic above.
    return {};
  }

  auto result = Polynomial::fromMonomials(monomials);
  if (failed(result)) {
    parser.emitError(parser.getCurrentLocation())
        << "parsed polynomial must have unique exponents among monomials";
    return {};
  }
  return PolynomialAttr::get(parser.getContext(), result.value());
}
/// Print the ring attribute as
/// `#polynomial.ring<coefficientType=..., coefficientModulus=...,
/// polynomialModulus=...>`, streaming each field in that order.
///
/// NOTE(review): RingAttr::parse in this file begins at `<` and does not
/// consume a `#polynomial.ring` prefix -- confirm the prefix emitted here is
/// intended for however the dialect dispatches attribute printing.
void RingAttr::print(AsmPrinter &p) const {
  p << "#polynomial.ring<coefficientType=";
  p << getCoefficientType();
  p << ", coefficientModulus=";
  p << getCoefficientModulus();
  p << ", polynomialModulus=";
  p << getPolynomialModulus();
  p << '>';
}
/// Parse a ring attribute body of the form
///
///   `<` `coefficientType` `=` type `,`
///       [`coefficientModulus` `=` integer `,`]
///       [`polynomialModulus` `=` polynomial-attr]
///   `>`
///
/// `coefficientModulus` and `polynomialModulus` are probed with optional
/// keyword parsing, so omitting either one leaves the corresponding attribute
/// null without emitting a diagnostic. When `coefficientModulus` is present,
/// `coefficientType` must be an integer type; the modulus is parsed at that
/// type's bit width.
///
/// Returns a null attribute on any parse failure. The `type` parameter is not
/// used.
Attribute RingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};

  // Mandatory `coefficientType=<type>,` prefix.
  if (failed(parser.parseKeyword("coefficientType")))
    return {};
  if (failed(parser.parseEqual()))
    return {};
  Type ty;
  if (failed(parser.parseType(ty)))
    return {};
  if (failed(parser.parseComma()))
    return {};

  // Use the optional-keyword probe here: the mandatory parseKeyword emits an
  // error diagnostic when the keyword is missing, which would be reported even
  // though parsing goes on to succeed.
  IntegerAttr coefficientModulusAttr = nullptr;
  if (succeeded(parser.parseOptionalKeyword("coefficientModulus"))) {
    if (failed(parser.parseEqual()))
      return {};

    IntegerType iType = mlir::dyn_cast<IntegerType>(ty);
    if (!iType) {
      parser.emitError(parser.getCurrentLocation(),
                       "coefficientType must specify an integer type");
      return {};
    }

    APInt coefficientModulus(iType.getWidth(), 0);
    auto result = parser.parseInteger(coefficientModulus);
    if (failed(result)) {
      parser.emitError(parser.getCurrentLocation(),
                       "invalid coefficient modulus");
      return {};
    }
    coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);

    if (failed(parser.parseComma()))
      return {};
  }

  PolynomialAttr polyAttr = nullptr;
  if (succeeded(parser.parseOptionalKeyword("polynomialModulus"))) {
    if (failed(parser.parseEqual()))
      return {};

    PolynomialAttr attr;
    if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
      return {};
    polyAttr = attr;
  }

  if (failed(parser.parseGreater()))
    return {};

  return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
                       polyAttr);
}
} // namespace polynomial
} // namespace mlir