blob: b966003f24fdaed33fa78ccf8992a287425895a6 [file] [log] [blame]
//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/EDSC/Builders.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "llvm/ADT/Optional.h"
using namespace mlir;
using namespace mlir::edsc;
mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, Location location)
    : builder(builder),
      location(location),
      enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
      nestedBuilder(nullptr) {
  // Become the innermost context for this thread. The previously innermost
  // context was remembered in `enclosingScopedContext` and is reinstated by
  // the destructor.
  ScopedContext::getCurrentScopedContext() = this;
}
/// Moves the insertion point of `builder` to `newInsertPt` for the lifetime of
/// this context. The builder's previous insertion point is saved in
/// `prevBuilderInsertPoint` and restored by the destructor.
mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder,
                                         OpBuilder::InsertPoint newInsertPt,
                                         Location location)
    : builder(builder),
      prevBuilderInsertPoint(builder.saveInsertionPoint()),
      location(location),
      enclosingScopedContext(ScopedContext::getCurrentScopedContext()),
      nestedBuilder(nullptr) {
  // The old insertion point was already captured by the initializer list, so
  // the builder can be repositioned right away.
  builder.restoreInsertionPoint(newInsertPt);
  ScopedContext::getCurrentScopedContext() = this;
}
/// Pops this context. The enclosing context becomes current again, and the
/// builder's insertion point is restored if this context saved one.
mlir::edsc::ScopedContext::~ScopedContext() {
  assert(!nestedBuilder &&
         "Active NestedBuilder must have been exited at this point!");
  getCurrentScopedContext() = enclosingScopedContext;
  // Only the insertion-point constructor records a previous insertion point.
  if (prevBuilderInsertPoint)
    builder.restoreInsertionPoint(*prevBuilderInsertPoint);
}
/// Returns a reference to the thread-local slot that holds the innermost
/// active ScopedContext (nullptr while no context is active).
ScopedContext *&mlir::edsc::ScopedContext::getCurrentScopedContext() {
  thread_local ScopedContext *innermostContext = nullptr;
  return innermostContext;
}

/// Returns the builder of the innermost active ScopedContext; asserts that
/// such a context exists.
OpBuilder &mlir::edsc::ScopedContext::getBuilder() {
  ScopedContext *innermost = ScopedContext::getCurrentScopedContext();
  assert(innermost && "Unexpected Null ScopedContext");
  return innermost->builder;
}

/// Returns the location of the innermost active ScopedContext; asserts that
/// such a context exists.
Location mlir::edsc::ScopedContext::getLocation() {
  ScopedContext *innermost = ScopedContext::getCurrentScopedContext();
  assert(innermost && "Unexpected Null ScopedContext");
  return innermost->location;
}

/// Returns the MLIRContext owned by the current builder.
MLIRContext *mlir::edsc::ScopedContext::getContext() {
  return ScopedContext::getBuilder().getContext();
}
/// Emits a ConstantIndexOp holding `cst.v` at the current insertion point and
/// captures its result (and that result's type) in this handle.
mlir::edsc::ValueHandle::ValueHandle(index_t cst) {
  auto &builder = ScopedContext::getBuilder();
  auto loc = ScopedContext::getLocation();
  auto constantOp = builder.create<ConstantIndexOp>(loc, cst.v);
  v = constantOp.getResult();
  t = v.getType();
}
/// Captures the value held by `other` into this (so far uncaptured) handle.
/// The assertions require matching types and forbid capturing twice.
ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) {
  assert(t == other.t && "Wrong type capture");
  assert(!v && "ValueHandle has already been captured, use a new name!");
  this->v = other.v;
  return *this;
}
/// Emits (via makeComposedAffineApply) an affine apply of `map` to `operands`
/// at the current insertion point and returns a handle to its single result.
ValueHandle
mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map,
                                                   ArrayRef<Value> operands) {
  auto &builder = ScopedContext::getBuilder();
  auto loc = ScopedContext::getLocation();
  auto applyOp = makeComposedAffineApply(builder, loc, map, operands);
  Operation *op = applyOp.getOperation();
  assert(op->getNumResults() == 1 && "Not a single result AffineApply");
  return ValueHandle(op->getResult(0));
}
/// Creates the generic operation `name` and wraps it as a ValueHandle.
/// A single-result operation is captured by its result. Otherwise the
/// operation must be an AffineForOp, which is captured by its induction
/// variable. Anything else is a fatal error.
ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands,
                                ArrayRef<Type> resultTypes,
                                ArrayRef<NamedAttribute> attributes) {
  Operation *op =
      OperationHandle::create(name, operands, resultTypes, attributes);
  if (op->getNumResults() == 1)
    return ValueHandle(op->getResult(0));

  auto forOp = dyn_cast<AffineForOp>(op);
  if (!forOp)
    llvm_unreachable("unsupported operation, use an OperationHandle instead");
  return ValueHandle(forOp.getInductionVar());
}
/// Builds an OperationState for the operation `name` at the current location,
/// fills in its operands, result types and attributes, and creates it with the
/// current builder.
OperationHandle OperationHandle::create(StringRef name,
                                        ArrayRef<ValueHandle> operands,
                                        ArrayRef<Type> resultTypes,
                                        ArrayRef<NamedAttribute> attributes) {
  OperationState state(ScopedContext::getLocation(), name);

  // OperationState expects raw Values; unwrap each handle.
  SmallVector<Value, 4> operandValues;
  operandValues.reserve(operands.size());
  for (const ValueHandle &handle : operands)
    operandValues.push_back(handle.getValue());
  state.addOperands(operandValues);

  state.addTypes(resultTypes);
  for (const NamedAttribute &namedAttr : attributes)
    state.addAttribute(namedAttr.first, namedAttr.second);

  auto *created = ScopedContext::getBuilder().createOperation(state);
  return OperationHandle(created);
}
/// Creates a new block in the region that contains the current builder's
/// insertion block and appends one block argument per entry of `argTypes`.
/// Returns a handle to the new block.
/// Unlike a bare `createBlock`, this leaves the builder's insertion point
/// unchanged.
BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) {
  auto &currentB = ScopedContext::getBuilder();
  auto *ib = currentB.getInsertionBlock();
  // The new block is placed in the region of the current insertion block, so
  // there must be one; fail loudly rather than dereference null.
  assert(ib && "BlockHandle::create requires a builder with an insertion block");
  auto ip = currentB.getInsertionPoint();
  BlockHandle res;
  res.block = currentB.createBlock(ib->getParent());
  // createBlock sets the insertion point inside the new block. Declarative
  // builders with nesting do not want that, so put it back.
  currentB.setInsertionPoint(ib, ip);
  for (auto t : argTypes)
    res.block->addArgument(t);
  return res;
}
/// Tries to emit an AffineForOp with constant bounds.
/// This succeeds only when there is exactly one lower bound and one upper
/// bound, and both are results of ConstantIndexOp. On success, returns the
/// handle created by `ValueHandle::create<AffineForOp>`. Otherwise, emits
/// nothing and returns an empty Optional.
static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs,
                                           ArrayRef<ValueHandle> ubs,
                                           int64_t step) {
  bool hasSingleBounds = lbs.size() == 1 && ubs.size() == 1;
  if (!hasSingleBounds)
    return {};

  // dyn_cast_or_null yields a null op both for block arguments (no defining
  // op) and for non-constant producers.
  auto lbConst =
      dyn_cast_or_null<ConstantIndexOp>(lbs.front().getValue().getDefiningOp());
  auto ubConst =
      dyn_cast_or_null<ConstantIndexOp>(ubs.front().getValue().getDefiningOp());
  if (!lbConst || !ubConst)
    return {};

  return ValueHandle::create<AffineForOp>(lbConst.getValue(),
                                          ubConst.getValue(), step);
}
/// Creates an AffineForOp at the current insertion point and captures its
/// induction variable into `*iv`. Returns a LoopBuilder that has entered the
/// loop body.
/// Constant single bounds produce a loop with static bounds. Otherwise the
/// bounds are `lbHandles`/`ubHandles` under multi-dimensional identity maps.
mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine(
    ValueHandle *iv, ArrayRef<ValueHandle> lbHandles,
    ArrayRef<ValueHandle> ubHandles, int64_t step) {
  mlir::edsc::LoopBuilder result;
  auto staticFor = emitStaticFor(lbHandles, ubHandles, step);
  if (staticFor) {
    *iv = *staticFor;
  } else {
    auto &builder = ScopedContext::getBuilder();
    SmallVector<Value, 4> lbs(lbHandles.begin(), lbHandles.end());
    SmallVector<Value, 4> ubs(ubHandles.begin(), ubHandles.end());
    auto lbMap = builder.getMultiDimIdentityMap(lbs.size());
    auto ubMap = builder.getMultiDimIdentityMap(ubs.size());
    *iv = ValueHandle::create<AffineForOp>(lbs, lbMap, ubs, ubMap, step);
  }
  auto forOp = getForInductionVarOwner(iv->getValue());
  result.enter(forOp.getBody(), /*prev=*/1);
  return result;
}
/// Creates a loop::ForOp with bounds `lbHandle`/`ubHandle` and step
/// `stepHandle` at the current insertion point. Captures its induction
/// variable into `*iv` and returns a LoopBuilder that has entered the loop
/// body.
mlir::edsc::LoopBuilder
mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle,
                                  ValueHandle ubHandle,
                                  ValueHandle stepHandle) {
  mlir::edsc::LoopBuilder result;
  auto forOp =
      OperationHandle::createOp<loop::ForOp>(lbHandle, ubHandle, stepHandle);
  *iv = ValueHandle(forOp.getInductionVar());
  // `forOp` is the owner of the induction variable just captured, so its body
  // is the block to enter.
  auto *body = forOp.getBody();
  result.enter(body, /*prev=*/1);
  return result;
}
/// Runs `fun` (if any) inside the loop body entered at construction, then
/// exits the body.
void mlir::edsc::LoopBuilder::operator()(function_ref<void(void)> fun) {
  // `exit` is called explicitly here, not from the destructor (enter/exit are
  // deliberately asymmetric). Destructor timing is not ordered with respect
  // to the comma operator, which matters for nested blocks:
  //
  // ```c++
  //    For (&i, lb, ub, 1)({
  //      /--- destructor for this `For` is not always called before ...
  //      V
  //      For (&j1, lb, ub, 1)({
  //        some_op_1,
  //      }),
  //      /--- ... this scope is entered, resulting in improperly nested IR.
  //      V
  //      For (&j2, lb, ub, 1)({
  //        some_op_2,
  //      }),
  //    });
  // ```
  if (fun)
    fun();
  exit();
}
/// Builds a nest consisting of a single affine loop (see
/// LoopBuilder::makeAffine). `lbs` and `ubs` may hold several bounds each.
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
    ValueHandle *iv, ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs,
    int64_t step) {
  loops.push_back(LoopBuilder::makeAffine(iv, lbs, ubs, step));
}
/// Builds a nest of affine loops, outermost first. Loop `i` uses induction
/// variable `ivs[i]`, lower bound `lbs[i]`, upper bound `ubs[i]` and step
/// `steps[i]`. Each LoopBuilder::makeAffine call enters the body of the loop
/// it creates, so later loops are created inside the earlier ones.
mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(
    ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs,
    ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps) {
  assert(ivs.size() == lbs.size() && "Mismatch in number of arguments");
  assert(ivs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(ivs.size() == steps.size() && "Mismatch in number of arguments");
  // Size the storage up front (as LoopNestBuilder does) so appending the
  // LoopBuilders never reallocates.
  loops.reserve(ivs.size());
  for (auto it : llvm::zip(ivs, lbs, ubs, steps))
    loops.emplace_back(LoopBuilder::makeAffine(
        std::get<0>(it), std::get<1>(it), std::get<2>(it), std::get<3>(it)));
}
/// Runs `fun` (if any) in the innermost loop body, then closes every loop of
/// the nest.
void mlir::edsc::AffineLoopNestBuilder::operator()(
    function_ref<void(void)> fun) {
  if (fun)
    fun();
  // Call operator() on each LoopBuilder, from innermost to outermost. enter()
  // ran when each LoopBuilder was constructed and exit() runs in its
  // operator(). This asymmetry is what allows imperfectly nested regions to
  // be built correctly (see LoopBuilder::operator()).
  for (size_t remaining = loops.size(); remaining != 0; --remaining)
    loops[remaining - 1]();
}
/// Builds a nest of loop::ForOp loops, outermost first. Loop `i` uses induction
/// variable `ivs[i]`, bounds `lbs[i]`/`ubs[i]` and step `steps[i]`.
mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs,
                                             ArrayRef<ValueHandle> lbs,
                                             ArrayRef<ValueHandle> ubs,
                                             ArrayRef<ValueHandle> steps) {
  assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match");
  assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match");
  assert(ivs.size() == steps.size() &&
         "expected size of ivs and steps to match");
  loops.reserve(ivs.size());
  for (auto tuple : llvm::zip(ivs, lbs, ubs, steps)) {
    ValueHandle *iv = std::get<0>(tuple);
    ValueHandle lb = std::get<1>(tuple);
    ValueHandle ub = std::get<2>(tuple);
    ValueHandle step = std::get<3>(tuple);
    loops.emplace_back(LoopBuilder::makeLoop(iv, lb, ub, step));
  }
  assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size");
}
void LoopNestBuilder::LoopNestBuilder::operator()(
std::function<void(void)> fun) {
if (fun)
fun();
for (auto &lit : reverse(loops))
lit({});
}
/// Re-enters the already captured block `bh`, so that newly built operations
/// are appended to it.
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) {
  assert(bh && "Expected already captured BlockHandle");
  auto *block = bh.getBlock();
  enter(block);
}
/// Creates a new block via BlockHandle::create and stores it in `*bh`.
/// The block gets one argument per handle in `args`, typed after that
/// not-yet-captured handle. Each block argument is then captured into its
/// handle, and the block is entered.
mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh,
                                       ArrayRef<ValueHandle *> args) {
  assert(!*bh && "BlockHandle already captures a block, use "
                 "the explicit BlockBuilder(bh, Append())({}) syntax instead.");
  SmallVector<Type, 8> types;
  types.reserve(args.size());
  for (auto *a : args) {
    assert(!a->hasValue() &&
           "Expected delayed ValueHandle that has not yet been captured.");
    types.push_back(a->getType());
  }
  *bh = BlockHandle::create(types);
  for (auto it : llvm::zip(args, bh->getBlock()->getArguments()))
    *(std::get<0>(it)) = ValueHandle(std::get<1>(it));
  enter(bh->getBlock());
}
/// Runs `fun` (if any) inside the block entered by the constructor, then
/// exits that block. This call is the ordering point between entering the
/// nested block and creating its statements.
void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) {
  // `exit` is called explicitly rather than from the destructor, because
  // destructor timing is not ordered with respect to the comma operator.
  if (fun)
    fun();
  exit();
}
/// Emits an `Op` taking the values held by `lhs` and `rhs`, and returns a
/// handle built by ValueHandle::create.
template <typename Op>
static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) {
  auto lhsValue = lhs.getValue();
  auto rhsValue = rhs.getValue();
  return ValueHandle::create<Op>(lhsValue, rhsValue);
}
/// Maps `val` to an affine expression for use in an affine map:
///  - a ConstantIndexOp result becomes an affine constant, with no operand
///    (a null Value);
///  - a valid symbol that is not a valid dimension becomes the next symbol
///    (post-incrementing `numSymbols`);
///  - anything else becomes the next dimension (post-incrementing `numDims`).
/// Returns the expression paired with the operand to bind to it, if any.
static std::pair<AffineExpr, Value>
categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
                            unsigned &numSymbols) {
  if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val.getDefiningOp())) {
    Value noOperand = nullptr;
    return std::make_pair(getAffineConstantExpr(constant.getValue(), context),
                          noOperand);
  }
  if (isValidSymbol(val) && !isValidDim(val))
    return std::make_pair(getAffineSymbolExpr(numSymbols++, context), val);
  return std::make_pair(getAffineDimExpr(numDims++, context), val);
}
/// Builds the single-result affine map `affCombiner(lhs, rhs)`. Each operand
/// is classified with categorizeValueByAffineType, and the map is emitted via
/// ValueHandle::createComposedAffineApply.
static ValueHandle createBinaryIndexHandle(
    ValueHandle lhs, ValueHandle rhs,
    function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
  MLIRContext *context = ScopedContext::getContext();
  unsigned numDims = 0;
  unsigned numSymbols = 0;
  auto lhsInfo =
      categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols);
  auto rhsInfo =
      categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols);

  // Constants are folded into the expression and contribute no operand.
  SmallVector<Value, 2> operands;
  for (Value operand : {lhsInfo.second, rhsInfo.second})
    if (operand)
      operands.push_back(operand);

  AffineExpr combined = affCombiner(lhsInfo.first, rhsInfo.first);
  auto map = AffineMap::get(numDims, numSymbols, {combined});
  // TODO: createOrFold when available.
  return ValueHandle::createComposedAffineApply(map, operands);
}
/// Emits the binary operation that matches the (shared) operand type:
///  - index: an affine apply built from `affCombiner`;
///  - integer, or vector/tensor of integer: `IOp`;
///  - float, or vector/tensor of float: `FOp`.
/// Any other type is a fatal error.
template <typename IOp, typename FOp>
static ValueHandle createBinaryHandle(
    ValueHandle lhs, ValueHandle rhs,
    function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) {
  auto thisType = lhs.getValue().getType();
  auto thatType = rhs.getValue().getType();
  assert(thisType == thatType && "cannot mix types in operators");
  (void)thatType;

  if (thisType.isIndex())
    return createBinaryIndexHandle(lhs, rhs, affCombiner);

  // Vectors and tensors dispatch on their element type; scalars on themselves.
  auto dispatchType = thisType;
  if (thisType.isa<VectorType>() || thisType.isa<TensorType>())
    dispatchType = thisType.cast<ShapedType>().getElementType();

  if (dispatchType.isa<IntegerType>())
    return createBinaryHandle<IOp>(lhs, rhs);
  if (dispatchType.isa<FloatType>())
    return createBinaryHandle<FOp>(lhs, rhs);
  llvm_unreachable("failed to create a ValueHandle");
}
/// Addition: AddIOp / AddFOp, or `d0 + d1` on index operands.
ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) {
  auto affineSum = [](AffineExpr d0, AffineExpr d1) { return d0 + d1; };
  return createBinaryHandle<AddIOp, AddFOp>(lhs, rhs, affineSum);
}

/// Subtraction: SubIOp / SubFOp, or `d0 - d1` on index operands.
ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) {
  auto affineDifference = [](AffineExpr d0, AffineExpr d1) { return d0 - d1; };
  return createBinaryHandle<SubIOp, SubFOp>(lhs, rhs, affineDifference);
}

/// Multiplication: MulIOp / MulFOp, or `d0 * d1` on index operands.
ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) {
  auto affineProduct = [](AffineExpr d0, AffineExpr d1) { return d0 * d1; };
  return createBinaryHandle<MulIOp, MulFOp>(lhs, rhs, affineProduct);
}

/// Division: SignedDivIOp / DivFOp. Index operands are rejected at runtime;
/// floorDiv and ceilDiv below are the index-typed alternatives.
ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) {
  auto noAffineDivision = [](AffineExpr d0, AffineExpr d1) -> AffineExpr {
    llvm_unreachable("only exprs of non-index type support operator/");
  };
  return createBinaryHandle<SignedDivIOp, DivFOp>(lhs, rhs, noAffineDivision);
}

/// Remainder: SignedRemIOp / RemFOp, or `d0 % d1` on index operands.
ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) {
  auto affineRemainder = [](AffineExpr d0, AffineExpr d1) { return d0 % d1; };
  return createBinaryHandle<SignedRemIOp, RemFOp>(lhs, rhs, affineRemainder);
}

/// Emits `lhs floordiv rhs` as an affine apply via createBinaryIndexHandle.
/// No operand type check is done here; presumably intended for index
/// operands.
ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) {
  auto affineFloorDiv = [](AffineExpr d0, AffineExpr d1) {
    return d0.floorDiv(d1);
  };
  return createBinaryIndexHandle(lhs, rhs, affineFloorDiv);
}

/// Emits `lhs ceildiv rhs` as an affine apply via createBinaryIndexHandle.
/// No operand type check is done here; presumably intended for index
/// operands.
ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) {
  auto affineCeilDiv = [](AffineExpr d0, AffineExpr d1) {
    return d0.ceilDiv(d1);
  };
  return createBinaryIndexHandle(lhs, rhs, affineCeilDiv);
}
/// Logical negation of an i1 value, computed as `1 - value` (on i1 this maps
/// 0 to 1 and 1 to 0).
ValueHandle mlir::edsc::op::operator!(ValueHandle value) {
  assert(value.getType().isInteger(1) && "expected boolean expression");
  ValueHandle one = ValueHandle::create<ConstantIntOp>(1, 1);
  return one - value;
}

/// Logical conjunction of two i1 values, computed as their product.
ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) {
  assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS");
  assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS");
  return lhs * rhs;
}

/// Logical disjunction of two i1 values, via De Morgan: a || b == !(!a && !b).
/// The operand checks happen in operator! and operator&&.
ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) {
  return !(!lhs && !rhs);
}
/// Emits a CmpIOp with `predicate` at the current location and returns a
/// handle to its result. Both operands must share one index or integer type.
static ValueHandle createIComparisonExpr(CmpIPredicate predicate,
                                         ValueHandle lhs, ValueHandle rhs) {
  auto operandType = lhs.getType();
  assert(operandType == rhs.getType() && "cannot mix types in operators");
  assert((operandType.isa<IndexType>() || operandType.isa<IntegerType>()) &&
         "only integer comparisons are supported");
  (void)operandType;
  auto &builder = ScopedContext::getBuilder();
  auto cmpOp = builder.create<CmpIOp>(ScopedContext::getLocation(), predicate,
                                      lhs.getValue(), rhs.getValue());
  return ValueHandle(cmpOp.getResult());
}
/// Emits a CmpFOp with `predicate` at the current location and returns a
/// handle to its result. Both operands must share one float type.
static ValueHandle createFComparisonExpr(CmpFPredicate predicate,
                                         ValueHandle lhs, ValueHandle rhs) {
  auto operandType = lhs.getType();
  assert(operandType == rhs.getType() && "cannot mix types in operators");
  assert(operandType.isa<FloatType>() && "only float comparisons are supported");
  (void)operandType;
  auto &builder = ScopedContext::getBuilder();
  auto cmpOp = builder.create<CmpFOp>(ScopedContext::getLocation(), predicate,
                                      lhs.getValue(), rhs.getValue());
  return ValueHandle(cmpOp.getResult());
}
// Comparison operators. Floating-point comparisons built here always use the
// ordered predicates (OEQ, ONE, OLT, OLE, OGT, OGE). Integer and index
// comparisons use the signed predicates.
// TODO(ntv,zinenko): signed by default, how about unsigned?

/// Emits a CmpFOp with `floatPred` when `lhs` is a float; otherwise emits a
/// CmpIOp with `intPred`.
static ValueHandle createComparisonByType(CmpFPredicate floatPred,
                                          CmpIPredicate intPred,
                                          ValueHandle lhs, ValueHandle rhs) {
  if (lhs.getType().isa<FloatType>())
    return createFComparisonExpr(floatPred, lhs, rhs);
  return createIComparisonExpr(intPred, lhs, rhs);
}

ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) {
  return createComparisonByType(CmpFPredicate::OEQ, CmpIPredicate::eq, lhs,
                                rhs);
}

ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) {
  return createComparisonByType(CmpFPredicate::ONE, CmpIPredicate::ne, lhs,
                                rhs);
}

ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) {
  return createComparisonByType(CmpFPredicate::OLT, CmpIPredicate::slt, lhs,
                                rhs);
}

ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) {
  return createComparisonByType(CmpFPredicate::OLE, CmpIPredicate::sle, lhs,
                                rhs);
}

ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) {
  return createComparisonByType(CmpFPredicate::OGT, CmpIPredicate::sgt, lhs,
                                rhs);
}

ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) {
  return createComparisonByType(CmpFPredicate::OGE, CmpIPredicate::sge, lhs,
                                rhs);
}