| //===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// |
| // |
| // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/EDSC/Builders.h" |
| #include "mlir/Dialect/StandardOps/Ops.h" |
| #include "mlir/IR/AffineExpr.h" |
| |
| #include "llvm/ADT/Optional.h" |
| |
| using namespace mlir; |
| using namespace mlir::edsc; |
| |
| /// Wraps `builder` and `location` in a new scoped context and makes it the |
| /// current one. The context that was current beforehand is kept in |
| /// `enclosingScopedContext`. The builder's insertion point is not modified. |
| mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, Location location) |
|     : builder(builder), location(location), |
|       enclosingScopedContext(ScopedContext::getCurrentScopedContext()), |
|       nestedBuilder(nullptr) { |
|   ScopedContext *&current = getCurrentScopedContext(); |
|   current = this; |
| } |
| |
| /// Makes a new current context whose builder inserts at `newInsertPt`. The |
| /// insertion point the builder had on entry is saved in |
| /// `prevBuilderInsertPoint` and is put back by the destructor. |
| mlir::edsc::ScopedContext::ScopedContext(OpBuilder &builder, |
|                                          OpBuilder::InsertPoint newInsertPt, |
|                                          Location location) |
|     : builder(builder), prevBuilderInsertPoint(builder.saveInsertionPoint()), |
|       location(location), |
|       enclosingScopedContext(ScopedContext::getCurrentScopedContext()), |
|       nestedBuilder(nullptr) { |
|   // The old insertion point has already been captured by the initializer |
|   // list, so the builder can be redirected now. |
|   builder.restoreInsertionPoint(newInsertPt); |
|   ScopedContext *&current = getCurrentScopedContext(); |
|   current = this; |
| } |
| |
| /// Leaves the scope: puts back the saved insertion point when one was |
| /// recorded, then makes the enclosing context current again. Asserts that no |
| /// nested builder is still active. |
| mlir::edsc::ScopedContext::~ScopedContext() { |
|   assert(!nestedBuilder && |
|          "Active NestedBuilder must have been exited at this point!"); |
|   if (prevBuilderInsertPoint) { |
|     builder.restoreInsertionPoint(*prevBuilderInsertPoint); |
|   } |
|   ScopedContext *&current = getCurrentScopedContext(); |
|   current = enclosingScopedContext; |
| } |
| |
| /// Returns a mutable reference to the thread-local pointer that designates |
| /// the current ScopedContext. The pointer starts out null. |
| ScopedContext *&mlir::edsc::ScopedContext::getCurrentScopedContext() { |
|   thread_local ScopedContext *currentContext = nullptr; |
|   return currentContext; |
| } |
| |
| /// Returns the builder of the current ScopedContext. Asserts that a context |
| /// is active. |
| OpBuilder &mlir::edsc::ScopedContext::getBuilder() { |
|   ScopedContext *current = ScopedContext::getCurrentScopedContext(); |
|   assert(current && "Unexpected Null ScopedContext"); |
|   return current->builder; |
| } |
| |
| /// Returns the location of the current ScopedContext. Asserts that a context |
| /// is active. |
| Location mlir::edsc::ScopedContext::getLocation() { |
|   ScopedContext *current = ScopedContext::getCurrentScopedContext(); |
|   assert(current && "Unexpected Null ScopedContext"); |
|   return current->location; |
| } |
| |
| /// Returns the MLIRContext reported by the current scope's builder. |
| MLIRContext *mlir::edsc::ScopedContext::getContext() { |
|   OpBuilder &currentBuilder = getBuilder(); |
|   return currentBuilder.getContext(); |
| } |
| |
| /// Creates a ConstantIndexOp holding `cst.v` with the current scope's builder |
| /// and location, and captures its result together with the result's type. |
| mlir::edsc::ValueHandle::ValueHandle(index_t cst) { |
|   OpBuilder &builder = ScopedContext::getBuilder(); |
|   Location loc = ScopedContext::getLocation(); |
|   auto constantOp = builder.create<ConstantIndexOp>(loc, cst.v); |
|   v = constantOp.getResult(); |
|   t = v.getType(); |
| } |
| |
| /// Captures the value held by `other` into this handle. Asserts that both |
| /// handles have the same type and that this handle has not captured a value |
| /// yet. Only the value is copied. |
| ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) { |
|   assert(t == other.t && "Wrong type capture"); |
|   assert(!v && "ValueHandle has already been captured, use a new name!"); |
|   this->v = other.v; |
|   return *this; |
| } |
| |
| /// Emits a composed affine apply of `map` on `operands` at the current scope's |
| /// insertion point and wraps its single result in a ValueHandle. |
| ValueHandle |
| mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map, |
|                                                    ArrayRef<Value> operands) { |
|   OpBuilder &builder = ScopedContext::getBuilder(); |
|   Location loc = ScopedContext::getLocation(); |
|   Operation *applyOp = |
|       makeComposedAffineApply(builder, loc, map, operands).getOperation(); |
|   assert(applyOp->getNumResults() == 1 && "Not a single result AffineApply"); |
|   return ValueHandle(applyOp->getResult(0)); |
| } |
| |
| /// Creates the operation called `name` and returns a handle to the value it |
| /// defines: the sole result for single-result operations, or the induction |
| /// variable for an AffineForOp. Any other operation is unsupported here. |
| ValueHandle ValueHandle::create(StringRef name, ArrayRef<ValueHandle> operands, |
|                                 ArrayRef<Type> resultTypes, |
|                                 ArrayRef<NamedAttribute> attributes) { |
|   Operation *op = |
|       OperationHandle::create(name, operands, resultTypes, attributes); |
|   if (op->getNumResults() == 1) |
|     return ValueHandle(op->getResult(0)); |
| |
|   auto forOp = dyn_cast<AffineForOp>(op); |
|   if (!forOp) |
|     llvm_unreachable("unsupported operation, use an OperationHandle instead"); |
|   return ValueHandle(forOp.getInductionVar()); |
| } |
| |
| /// Assembles an OperationState from the operation name, operands, result |
| /// types and attributes, then creates the operation with the current scope's |
| /// builder. |
| OperationHandle OperationHandle::create(StringRef name, |
|                                         ArrayRef<ValueHandle> operands, |
|                                         ArrayRef<Type> resultTypes, |
|                                         ArrayRef<NamedAttribute> attributes) { |
|   OperationState state(ScopedContext::getLocation(), name); |
|   // Convert the handles to plain values before adding them to the state. |
|   SmallVector<Value, 4> operandValues(operands.begin(), operands.end()); |
|   state.addOperands(operandValues); |
|   state.addTypes(resultTypes); |
|   for (const NamedAttribute &attribute : attributes) |
|     state.addAttribute(attribute.first, attribute.second); |
| |
|   Operation *created = ScopedContext::getBuilder().createOperation(state); |
|   return OperationHandle(created); |
| } |
| |
| /// Creates a new block in the region that holds the builder's current |
| /// insertion block, gives it one argument per entry of `argTypes`, and returns |
| /// a handle to it. The builder's insertion point is the same on return as it |
| /// was on entry. |
| BlockHandle mlir::edsc::BlockHandle::create(ArrayRef<Type> argTypes) { |
|   auto &currentB = ScopedContext::getBuilder(); |
|   // Remember where the builder was inserting before the block is created. |
|   auto *ib = currentB.getInsertionBlock(); |
|   auto ip = currentB.getInsertionPoint(); |
|   BlockHandle res; |
|   res.block = currentB.createBlock(ib->getParent()); |
|   // createBlock sets the insertion point inside the block. |
|   // We do not want this behavior when using declarative builders with nesting. |
|   currentB.setInsertionPoint(ib, ip); |
|   for (auto t : argTypes) { |
|     res.block->addArgument(t); |
|   } |
|   return res; |
| } |
| |
| /// Tries to emit an AffineForOp with constant bounds. This succeeds only when |
| /// there is exactly one lower and one upper bound and both are defined by a |
| /// ConstantIndexOp; otherwise an empty Optional is returned and nothing is |
| /// emitted. |
| static Optional<ValueHandle> emitStaticFor(ArrayRef<ValueHandle> lbs, |
|                                            ArrayRef<ValueHandle> ubs, |
|                                            int64_t step) { |
|   if (lbs.size() != 1 || ubs.size() != 1) |
|     return Optional<ValueHandle>(); |
| |
|   // A value without a defining operation yields a null op here, which is |
|   // treated the same as a non-constant bound. |
|   auto lbConst = |
|       dyn_cast_or_null<ConstantIndexOp>(lbs.front().getValue().getDefiningOp()); |
|   auto ubConst = |
|       dyn_cast_or_null<ConstantIndexOp>(ubs.front().getValue().getDefiningOp()); |
|   if (!lbConst || !ubConst) |
|     return Optional<ValueHandle>(); |
| |
|   return ValueHandle::create<AffineForOp>(lbConst.getValue(), |
|                                           ubConst.getValue(), step); |
| } |
| |
| /// Emits an affine for loop over the given bounds and step, captures its |
| /// induction variable in `*iv`, and returns a LoopBuilder that has entered |
| /// the loop body. A constant-bound loop is emitted when the bounds allow it; |
| /// otherwise the bounds are passed through multi-dimensional identity maps. |
| mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine( |
|     ValueHandle *iv, ArrayRef<ValueHandle> lbHandles, |
|     ArrayRef<ValueHandle> ubHandles, int64_t step) { |
|   mlir::edsc::LoopBuilder result; |
|   Optional<ValueHandle> staticFor = emitStaticFor(lbHandles, ubHandles, step); |
|   if (!staticFor) { |
|     SmallVector<Value, 4> lbs(lbHandles.begin(), lbHandles.end()); |
|     SmallVector<Value, 4> ubs(ubHandles.begin(), ubHandles.end()); |
|     OpBuilder &builder = ScopedContext::getBuilder(); |
|     *iv = ValueHandle::create<AffineForOp>( |
|         lbs, builder.getMultiDimIdentityMap(lbs.size()), ubs, |
|         builder.getMultiDimIdentityMap(ubs.size()), step); |
|   } else { |
|     *iv = staticFor.getValue(); |
|   } |
|   auto *loopBody = getForInductionVarOwner(iv->getValue()).getBody(); |
|   result.enter(loopBody, /*prev=*/1); |
|   return result; |
| } |
| |
| /// Emits a loop::ForOp over [`lbHandle`, `ubHandle`) with step `stepHandle`, |
| /// captures its induction variable in `*iv`, and returns a LoopBuilder that |
| /// has entered the loop body. |
| mlir::edsc::LoopBuilder |
| mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle, |
|                                   ValueHandle ubHandle, |
|                                   ValueHandle stepHandle) { |
|   mlir::edsc::LoopBuilder loopBuilder; |
|   auto loopOp = |
|       OperationHandle::createOp<loop::ForOp>(lbHandle, ubHandle, stepHandle); |
|   *iv = ValueHandle(loopOp.getInductionVar()); |
|   auto *loopBody = loop::getForInductionVarOwner(iv->getValue()).getBody(); |
|   loopBuilder.enter(loopBody, /*prev=*/1); |
|   return loopBuilder; |
| } |
| |
| /// Runs `fun`, when one is provided, and then exits the loop scope. |
| void mlir::edsc::LoopBuilder::operator()(function_ref<void(void)> fun) { |
|   // `exit` has to be called explicitly here, asymmetrically to `enter`; it |
|   // cannot live in the destructor because of ordering with respect to the |
|   // comma operator. The use case of concern is nested blocks: |
|   // |
|   // ```c++ |
|   //    For (&i, lb, ub, 1)({ |
|   //      /--- destructor for this `For` is not always called before ... |
|   //      V |
|   //      For (&j1, lb, ub, 1)({ |
|   //        some_op_1, |
|   //      }), |
|   //      /--- ... this scope is entered, resulting in improperly nested IR. |
|   //      V |
|   //      For (&j2, lb, ub, 1)({ |
|   //        some_op_2, |
|   //      }), |
|   //    }); |
|   // ``` |
|   if (fun) { |
|     fun(); |
|   } |
|   exit(); |
| } |
| |
| /// Builds a nest made of a single affine loop whose induction variable is |
| /// captured in `*iv`. |
| mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( |
|     ValueHandle *iv, ArrayRef<ValueHandle> lbs, ArrayRef<ValueHandle> ubs, |
|     int64_t step) { |
|   loops.push_back(LoopBuilder::makeAffine(iv, lbs, ubs, step)); |
| } |
| |
| /// Builds one affine loop per (iv, lb, ub, step) tuple, in the order given. |
| /// All four ranges must have the same length. |
| mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( |
|     ArrayRef<ValueHandle *> ivs, ArrayRef<ValueHandle> lbs, |
|     ArrayRef<ValueHandle> ubs, ArrayRef<int64_t> steps) { |
|   assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); |
|   assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); |
|   assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); |
|   for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { |
|     ValueHandle *iv = std::get<0>(it); |
|     const ValueHandle &lb = std::get<1>(it); |
|     const ValueHandle &ub = std::get<2>(it); |
|     int64_t step = std::get<3>(it); |
|     loops.push_back(LoopBuilder::makeAffine(iv, lb, ub, step)); |
|   } |
| } |
| |
| /// Runs `fun`, when one is provided, and then closes every loop of the nest. |
| void mlir::edsc::AffineLoopNestBuilder::operator()( |
|     function_ref<void(void)> fun) { |
|   if (fun) { |
|     fun(); |
|   } |
|   // Invoke operator() on each loop, walking from the innermost loop to the |
|   // outermost one. enter/exit are asymmetric: enter() happens when the |
|   // LoopBuilder is constructed and exit() when its operator() is called. The |
|   // asymmetry is needed to nest imperfectly nested regions properly (see |
|   // LoopBuilder::operator()). |
|   for (auto &loop : llvm::reverse(loops)) |
|     loop(); |
| } |
| |
| /// Builds one loop per (iv, lb, ub, step) tuple via LoopBuilder::makeLoop, in |
| /// the order given. All four ranges must have the same length. |
| mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef<ValueHandle *> ivs, |
|                                              ArrayRef<ValueHandle> lbs, |
|                                              ArrayRef<ValueHandle> ubs, |
|                                              ArrayRef<ValueHandle> steps) { |
|   assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match"); |
|   assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match"); |
|   assert(ivs.size() == steps.size() && |
|          "expected size of ivs and steps to match"); |
|   loops.reserve(ivs.size()); |
|   for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { |
|     ValueHandle *iv = std::get<0>(it); |
|     loops.push_back(LoopBuilder::makeLoop(iv, std::get<1>(it), std::get<2>(it), |
|                                           std::get<3>(it))); |
|   } |
|   assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); |
| } |
| |
| /// Runs `fun`, when one is set, and then closes the loops of the nest from |
| /// the last one created to the first. |
| void mlir::edsc::LoopNestBuilder::operator()(std::function<void(void)> fun) { |
|   if (fun) { |
|     fun(); |
|   } |
|   // Each loop is closed by calling it with an empty callback. |
|   for (auto it = loops.rbegin(), end = loops.rend(); it != end; ++it) |
|     (*it)({}); |
| } |
| |
| /// Enters the block already captured by `bh`. Asserts that `bh` does hold a |
| /// block. |
| mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) { |
|   assert(bh && "Expected already captured BlockHandle"); |
|   auto *capturedBlock = bh.getBlock(); |
|   enter(capturedBlock); |
| } |
| |
| /// Creates a new block whose argument types are taken from `args`, captures |
| /// the block in `*bh`, binds each entry of `args` to the matching block |
| /// argument, and enters the block. |
| /// |
| /// `*bh` must not hold a block yet, and no entry of `args` may already hold a |
| /// value. |
| mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, |
|                                        ArrayRef<ValueHandle *> args) { |
|   assert(!*bh && "BlockHandle already captures a block, use " |
|                  "the explicit BlockBuilder(bh, Append())({}) syntax instead."); |
|   SmallVector<Type, 8> types; |
|   types.reserve(args.size()); |
|   for (auto *a : args) { |
|     assert(!a->hasValue() && |
|            "Expected delayed ValueHandle that has not yet captured."); |
|     types.push_back(a->getType()); |
|   } |
|   *bh = BlockHandle::create(types); |
|   // Bind each delayed handle to the block argument created for it. |
|   for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) { |
|     *(std::get<0>(it)) = ValueHandle(std::get<1>(it)); |
|   } |
|   enter(bh->getBlock()); |
| } |
| |
| /// Acts purely as an ordering point between entering the nested block and |
| /// creating statements: runs `fun`, when one is provided, then exits the |
| /// block. |
| void mlir::edsc::BlockBuilder::operator()(function_ref<void(void)> fun) { |
|   // `exit` must be called explicitly here, asymmetrically to `enter`; it |
|   // cannot live in the destructor because of ordering with respect to the |
|   // comma operator. |
|   if (fun) { |
|     fun(); |
|   } |
|   exit(); |
| } |
| |
| /// Creates the binary operation `Op` on the values of `lhs` and `rhs` and |
| /// returns a handle to it. |
| template <typename Op> |
| static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) { |
|   Value lhsValue = lhs.getValue(); |
|   Value rhsValue = rhs.getValue(); |
|   return ValueHandle::create<Op>(lhsValue, rhsValue); |
| } |
| |
| /// Picks the affine expression that stands for `val`: |
| ///   - a constant expression when `val` is defined by a ConstantIndexOp; the |
| ///     returned value is then null because no operand is needed; |
| ///   - a fresh symbol (bumping `numSymbols`) when `val` is a valid symbol but |
| ///     not a valid dim; |
| ///   - a fresh dim (bumping `numDims`) otherwise. |
| /// In the last two cases `val` itself is returned alongside the expression. |
| static std::pair<AffineExpr, Value> |
| categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims, |
|                             unsigned &numSymbols) { |
|   AffineExpr expr; |
|   Value operand = nullptr; |
|   auto constant = dyn_cast_or_null<ConstantIndexOp>(val.getDefiningOp()); |
|   if (constant) { |
|     expr = getAffineConstantExpr(constant.getValue(), context); |
|   } else { |
|     const bool symbolOnly = isValidSymbol(val) && !isValidDim(val); |
|     expr = symbolOnly ? getAffineSymbolExpr(numSymbols++, context) |
|                       : getAffineDimExpr(numDims++, context); |
|     operand = val; |
|   } |
|   return std::make_pair(expr, operand); |
| } |
| |
| /// Combines two index-typed handles through a composed affine apply. |
| /// |
| /// Each operand is classified as a constant, a dim or a symbol by |
| /// categorizeValueByAffineType. `affCombiner` receives the two resulting |
| /// affine expressions (lhs first) and returns the single result expression of |
| /// the map. Constant operands are folded into the map and contribute no |
| /// operand to the apply. |
| static ValueHandle createBinaryIndexHandle( |
|     ValueHandle lhs, ValueHandle rhs, |
|     function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { |
|   MLIRContext *context = ScopedContext::getContext(); |
|   unsigned numDims = 0, numSymbols = 0; |
|   AffineExpr d0, d1; |
|   Value v0, v1; |
|   std::tie(d0, v0) = |
|       categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols); |
|   // The helper bumps `numSymbols` only when it hands out a symbol, so a change |
|   // in the counter tells whether the operand was classified as a symbol. |
|   const unsigned numSymbolsAfterLhs = numSymbols; |
|   const bool lhsIsSymbol = numSymbolsAfterLhs != 0; |
|   std::tie(d1, v1) = |
|       categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols); |
|   const bool rhsIsSymbol = numSymbols != numSymbolsAfterLhs; |
| |
|   // The map is declared as (dims)[symbols], so the operand list has to hold |
|   // every dim operand before any symbol operand. Pushing in plain lhs/rhs |
|   // order would bind the operands to the wrong positions when lhs is a symbol |
|   // and rhs is a dim. |
|   SmallVector<Value, 2> operands; |
|   if (v0 && !lhsIsSymbol) |
|     operands.push_back(v0); |
|   if (v1 && !rhsIsSymbol) |
|     operands.push_back(v1); |
|   if (v0 && lhsIsSymbol) |
|     operands.push_back(v0); |
|   if (v1 && rhsIsSymbol) |
|     operands.push_back(v1); |
| |
|   auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)}); |
|   // TODO: createOrFold when available. |
|   return ValueHandle::createComposedAffineApply(map, operands); |
| } |
| |
| /// Dispatches a binary operation on the operand type: index operands go |
| /// through an affine apply built with `affCombiner`, integer operands use |
| /// `IOp`, float operands use `FOp`, and vectors/tensors use `IOp` or `FOp` |
| /// depending on their element type. Both operands must have the same type; |
| /// any other type is unsupported. |
| template <typename IOp, typename FOp> |
| static ValueHandle createBinaryHandle( |
|     ValueHandle lhs, ValueHandle rhs, |
|     function_ref<AffineExpr(AffineExpr, AffineExpr)> affCombiner) { |
|   Type lhsType = lhs.getValue().getType(); |
|   Type rhsType = rhs.getValue().getType(); |
|   assert(lhsType == rhsType && "cannot mix types in operators"); |
|   (void)rhsType; |
| |
|   if (lhsType.isIndex()) |
|     return createBinaryIndexHandle(lhs, rhs, affCombiner); |
|   if (lhsType.isa<IntegerType>()) |
|     return createBinaryHandle<IOp>(lhs, rhs); |
|   if (lhsType.isa<FloatType>()) |
|     return createBinaryHandle<FOp>(lhs, rhs); |
|   if (lhsType.isa<VectorType>() || lhsType.isa<TensorType>()) { |
|     Type elementType = lhsType.cast<ShapedType>().getElementType(); |
|     if (elementType.isa<IntegerType>()) |
|       return createBinaryHandle<IOp>(lhs, rhs); |
|     if (elementType.isa<FloatType>()) |
|       return createBinaryHandle<FOp>(lhs, rhs); |
|   } |
|   llvm_unreachable("failed to create a ValueHandle"); |
| } |
| |
| /// Addition: AddIOp / AddFOp, or an affine sum for index operands. |
| ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) { |
|   auto addExprs = [](AffineExpr a, AffineExpr b) { return a + b; }; |
|   return createBinaryHandle<AddIOp, AddFOp>(lhs, rhs, addExprs); |
| } |
| |
| /// Subtraction: SubIOp / SubFOp, or an affine difference for index operands. |
| ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) { |
|   auto subExprs = [](AffineExpr a, AffineExpr b) { return a - b; }; |
|   return createBinaryHandle<SubIOp, SubFOp>(lhs, rhs, subExprs); |
| } |
| |
| /// Multiplication: MulIOp / MulFOp, or an affine product for index operands. |
| ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) { |
|   auto mulExprs = [](AffineExpr a, AffineExpr b) { return a * b; }; |
|   return createBinaryHandle<MulIOp, MulFOp>(lhs, rhs, mulExprs); |
| } |
| |
| /// Division: SignedDivIOp / DivFOp. Index operands are not supported by this |
| /// operator; the affine combiner is unreachable. |
| ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) { |
|   auto rejectIndex = [](AffineExpr, AffineExpr) -> AffineExpr { |
|     llvm_unreachable("only exprs of non-index type support operator/"); |
|   }; |
|   return createBinaryHandle<SignedDivIOp, DivFOp>(lhs, rhs, rejectIndex); |
| } |
| |
| /// Remainder: SignedRemIOp / RemFOp, or an affine modulo for index operands. |
| ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) { |
|   auto modExprs = [](AffineExpr a, AffineExpr b) { return a % b; }; |
|   return createBinaryHandle<SignedRemIOp, RemFOp>(lhs, rhs, modExprs); |
| } |
| |
| /// Affine floor division of `lhs` by `rhs`, emitted as an affine apply. |
| ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) { |
|   auto floorDivExprs = [](AffineExpr a, AffineExpr b) { |
|     return a.floorDiv(b); |
|   }; |
|   return createBinaryIndexHandle(lhs, rhs, floorDivExprs); |
| } |
| |
| /// Affine ceiling division of `lhs` by `rhs`, emitted as an affine apply. |
| ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) { |
|   auto ceilDivExprs = [](AffineExpr a, AffineExpr b) { return a.ceilDiv(b); }; |
|   return createBinaryIndexHandle(lhs, rhs, ceilDivExprs); |
| } |
| |
| /// Boolean negation of an i1 value, computed as `1 - value`. |
| ValueHandle mlir::edsc::op::operator!(ValueHandle value) { |
|   assert(value.getType().isInteger(1) && "expected boolean expression"); |
|   ValueHandle trueValue = ValueHandle::create<ConstantIntOp>(1, 1); |
|   return trueValue - value; |
| } |
| |
| /// Boolean conjunction of two i1 values, computed as their product. |
| ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) { |
|   assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); |
|   assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); |
|   ValueHandle product = lhs * rhs; |
|   return product; |
| } |
| |
| /// Boolean disjunction, expressed through De Morgan as !(!lhs && !rhs). |
| ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) { |
|   ValueHandle notLhs = !lhs; |
|   ValueHandle notRhs = !rhs; |
|   return !(notLhs && notRhs); |
| } |
| |
| /// Emits a CmpIOp with `predicate` on `lhs` and `rhs`. Both operands must have |
| /// the same type, which must be an index or integer type. |
| static ValueHandle createIComparisonExpr(CmpIPredicate predicate, |
|                                          ValueHandle lhs, ValueHandle rhs) { |
|   Type lhsType = lhs.getType(); |
|   Type rhsType = rhs.getType(); |
|   assert(lhsType == rhsType && "cannot mix types in operators"); |
|   assert((lhsType.isa<IndexType>() || lhsType.isa<IntegerType>()) && |
|          "only integer comparisons are supported"); |
|   // The types are only inspected by the asserts above. |
|   (void)lhsType; |
|   (void)rhsType; |
| |
|   OpBuilder &builder = ScopedContext::getBuilder(); |
|   auto cmpOp = builder.create<CmpIOp>(ScopedContext::getLocation(), predicate, |
|                                       lhs.getValue(), rhs.getValue()); |
|   return ValueHandle(cmpOp.getResult()); |
| } |
| |
| /// Emits a CmpFOp with `predicate` on `lhs` and `rhs`. Both operands must have |
| /// the same type, which must be a float type. |
| static ValueHandle createFComparisonExpr(CmpFPredicate predicate, |
|                                          ValueHandle lhs, ValueHandle rhs) { |
|   Type lhsType = lhs.getType(); |
|   Type rhsType = rhs.getType(); |
|   assert(lhsType == rhsType && "cannot mix types in operators"); |
|   assert(lhsType.isa<FloatType>() && "only float comparisons are supported"); |
|   // The types are only inspected by the asserts above. |
|   (void)lhsType; |
|   (void)rhsType; |
| |
|   OpBuilder &builder = ScopedContext::getBuilder(); |
|   auto cmpOp = builder.create<CmpFOp>(ScopedContext::getLocation(), predicate, |
|                                       lhs.getValue(), rhs.getValue()); |
|   return ValueHandle(cmpOp.getResult()); |
| } |
| |
| // All floating-point comparisons emitted by these operators use the ordered |
| // predicates. |
| /// Equality: ordered float compare (OEQ) for float operands, integer `eq` |
| /// otherwise. |
| ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { |
|   if (lhs.getType().isa<FloatType>()) |
|     return createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs); |
|   return createIComparisonExpr(CmpIPredicate::eq, lhs, rhs); |
| } |
| /// Inequality: ordered float compare (ONE) for float operands, integer `ne` |
| /// otherwise. |
| ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) { |
|   if (lhs.getType().isa<FloatType>()) |
|     return createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs); |
|   return createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); |
| } |
| /// Less-than: ordered float compare (OLT) for float operands, signed integer |
| /// `slt` otherwise. |
| ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { |
|   if (lhs.getType().isa<FloatType>()) |
|     return createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs); |
|   // TODO(ntv,zinenko): signed by default, how about unsigned? |
|   return createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); |
| } |
| /// Less-or-equal: ordered float compare (OLE) for float operands, signed |
| /// integer `sle` otherwise. |
| ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) { |
|   if (lhs.getType().isa<FloatType>()) |
|     return createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs); |
|   return createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); |
| } |
| /// Greater-than: ordered float compare (OGT) for float operands, signed |
| /// integer `sgt` otherwise. |
| ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) { |
|   if (lhs.getType().isa<FloatType>()) |
|     return createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs); |
|   return createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); |
| } |
| /// Greater-or-equal: ordered float compare (OGE) for float operands, signed |
| /// integer `sge` otherwise. |
| ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) { |
|   if (lhs.getType().isa<FloatType>()) |
|     return createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs); |
|   return createIComparisonExpr(CmpIPredicate::sge, lhs, rhs); |
| } |