blob: 7dabc781cd59526fd0c7248bb9486ee12914ba05 [file] [log] [blame]
//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file declares various common operation folders. These folders
// are intended to be used by dialects to support common folding behavior
// without requiring each dialect to provide its own implementation.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
#define MLIR_DIALECT_COMMONFOLDERS_H
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include <optional>
namespace mlir {
namespace ub {
// Forward declaration only. Several folders below use this type as the default
// for their `PoisonAttr` template parameter and static_assert that the type
// they are instantiated with is either `void` or complete.
class PoisonAttr;
}
/// Constant-folds a binary operation by applying `calculate` element-wise to
/// the two attributes in `operands`.
///
/// The resulting attribute is created with `resultType`. Three operand shapes
/// are handled, tried in this order: a pair of `AttrElementT` attributes, a
/// pair of splat elements attributes, and a pair of generic elements
/// attributes. A null attribute is returned whenever folding is not possible:
/// `resultType` or an operand is null, the operands match none of the shapes
/// above, the two operand types differ, `try_value_begin<ElementValueT>()`
/// fails for an operand, or `calculate` returns an empty result.
///
/// The optional `PoisonAttr` template argument names the 'poison' attribute.
/// Unless it is `void`, an operand of that kind is returned as the result
/// directly, with the first operand taking precedence over the second.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class PoisonAttr = ub::PoisonAttr,
          class CalculationT = function_ref<
              std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                       Type resultType,
                                       CalculationT &&calculate) {
  assert(operands.size() == 2 && "binary op takes two operands");
  static_assert(
      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
      "void as template argument to opt-out from poison semantics.");

  const Attribute lhsAttr = operands[0];
  const Attribute rhsAttr = operands[1];

  // Poison wins over everything else, including the null checks below.
  if constexpr (!std::is_void_v<PoisonAttr>) {
    if (isa_and_nonnull<PoisonAttr>(lhsAttr))
      return lhsAttr;
    if (isa_and_nonnull<PoisonAttr>(rhsAttr))
      return rhsAttr;
  }

  // Nothing can be folded without a result type and two known operands.
  if (!resultType || !lhsAttr || !rhsAttr)
    return {};

  // Case 1: two single-element attributes.
  if (isa<AttrElementT>(lhsAttr) && isa<AttrElementT>(rhsAttr)) {
    auto lhsScalar = cast<AttrElementT>(lhsAttr);
    auto rhsScalar = cast<AttrElementT>(rhsAttr);
    if (lhsScalar.getType() != rhsScalar.getType())
      return {};

    auto folded = calculate(lhsScalar.getValue(), rhsScalar.getValue());
    if (!folded)
      return {};
    return AttrElementT::get(resultType, *folded);
  }

  // Case 2: two splats. Folding the single splat value is enough; there is no
  // need to expand the elements.
  if (isa<SplatElementsAttr>(lhsAttr) && isa<SplatElementsAttr>(rhsAttr)) {
    auto lhsSplat = cast<SplatElementsAttr>(lhsAttr);
    auto rhsSplat = cast<SplatElementsAttr>(rhsAttr);
    if (lhsSplat.getType() != rhsSplat.getType())
      return {};

    auto folded = calculate(lhsSplat.getSplatValue<ElementValueT>(),
                            rhsSplat.getSplatValue<ElementValueT>());
    if (!folded)
      return {};
    return DenseElementsAttr::get(cast<ShapedType>(resultType), *folded);
  }

  // Case 3: two generic elements attributes. Walk both element sequences in
  // lock-step and fold each pair of values.
  if (isa<ElementsAttr>(lhsAttr) && isa<ElementsAttr>(rhsAttr)) {
    auto lhsElements = cast<ElementsAttr>(lhsAttr);
    auto rhsElements = cast<ElementsAttr>(rhsAttr);
    if (lhsElements.getType() != rhsElements.getType())
      return {};

    auto lhsBegin = lhsElements.try_value_begin<ElementValueT>();
    auto rhsBegin = rhsElements.try_value_begin<ElementValueT>();
    if (!lhsBegin || !rhsBegin)
      return {};
    auto lhsCursor = *lhsBegin;
    auto rhsCursor = *rhsBegin;

    // The operand types are equal at this point, so the element count of the
    // left-hand side is used to bound both iterators.
    const size_t numElements = lhsElements.getNumElements();
    SmallVector<ElementValueT, 4> foldedValues;
    foldedValues.reserve(numElements);
    for (size_t idx = 0; idx != numElements; ++idx) {
      auto folded = calculate(*lhsCursor, *rhsCursor);
      if (!folded)
        return {};
      foldedValues.push_back(*folded);
      ++lhsCursor;
      ++rhsCursor;
    }
    return DenseElementsAttr::get(cast<ShapedType>(resultType), foldedValues);
  }

  // Unsupported combination of operand kinds.
  return {};
}
/// Constant-folds a binary operation by applying `calculate` element-wise to
/// the two attributes in `operands`, deriving the result type from the
/// operands themselves.
///
/// Both operands must be typed attributes of the same type; that common type
/// is forwarded as the result type to the overload taking an explicit
/// `resultType`. A null attribute is returned if either operand is null or
/// untyped, or if the two types differ.
///
/// The optional `PoisonAttr` template argument names the 'poison' attribute.
/// Unless it is `void`, an operand of that kind is returned as the result
/// directly, with the first operand taking precedence over the second.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class PoisonAttr = ub::PoisonAttr,
          class CalculationT = function_ref<
              std::optional<ElementValueT>(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOpConditional(ArrayRef<Attribute> operands,
                                       CalculationT &&calculate) {
  assert(operands.size() == 2 && "binary op takes two operands");
  static_assert(
      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
      "void as template argument to opt-out from poison semantics.");

  // Poison is propagated before any type inspection takes place.
  if constexpr (!std::is_void_v<PoisonAttr>) {
    if (isa_and_nonnull<PoisonAttr>(operands[0]))
      return operands[0];
    if (isa_and_nonnull<PoisonAttr>(operands[1]))
      return operands[1];
  }

  // Yields the type of a typed attribute, or a null type for a null or
  // untyped attribute.
  auto typeOf = [](Attribute attr) -> Type {
    auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
    return typedAttr ? typedAttr.getType() : Type();
  };

  const Type lhsType = typeOf(operands[0]);
  const Type rhsType = typeOf(operands[1]);
  if (!lhsType || !rhsType || lhsType != rhsType)
    return {};

  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr,
                                      CalculationT>(
      operands, lhsType, std::forward<CalculationT>(calculate));
}
/// Constant-folds a binary operation whose element computation always
/// produces a value.
///
/// `calculate` is wrapped into a callable returning `std::optional` and handed
/// to `constFoldBinaryOpConditional` together with `operands` and
/// `resultType`. Note that `PoisonAttr` defaults to `void` for this overload,
/// i.e. poison operands are not propagated unless a poison attribute type is
/// passed explicitly.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class PoisonAttr = void,
          class CalculationT =
              function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands, Type resultType,
                            CalculationT &&calculate) {
  // Adapter: lift the unconditional computation into the optional-returning
  // form expected by the conditional folder.
  auto alwaysFold = [&](ElementValueT lhs,
                        ElementValueT rhs) -> std::optional<ElementValueT> {
    return calculate(lhs, rhs);
  };
  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
      operands, resultType, alwaysFold);
}
/// Constant-folds a binary operation whose element computation always
/// produces a value, deriving the result type from the operands.
///
/// `calculate` is wrapped into a callable returning `std::optional` and handed
/// to the two-argument `constFoldBinaryOpConditional`. `PoisonAttr` defaults to
/// `ub::PoisonAttr` for this overload.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class PoisonAttr = ub::PoisonAttr,
          class CalculationT =
              function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
                            CalculationT &&calculate) {
  // Adapter: lift the unconditional computation into the optional-returning
  // form expected by the conditional folder.
  auto alwaysFold = [&](ElementValueT lhs,
                        ElementValueT rhs) -> std::optional<ElementValueT> {
    return calculate(lhs, rhs);
  };
  return constFoldBinaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
      operands, alwaysFold);
}
/// Constant-folds a unary operation by applying `calculate` element-wise to
/// the single attribute in `operands`.
///
/// The result reuses the type of the operand. Three operand shapes are
/// handled, tried in this order: an `AttrElementT` attribute, a splat elements
/// attribute, and a generic elements attribute. A null attribute is returned
/// whenever folding is not possible: the operand is null or of an unsupported
/// kind, `try_value_begin<ElementValueT>()` fails, or `calculate` returns an
/// empty result.
///
/// The optional `PoisonAttr` template argument names the 'poison' attribute.
/// Unless it is `void`, an operand of that kind is returned as the result
/// directly.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class PoisonAttr = ub::PoisonAttr,
          class CalculationT =
              function_ref<std::optional<ElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
                                      CalculationT &&calculate) {
  assert(operands.size() == 1 && "unary op takes one operands");
  static_assert(
      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
      "void as template argument to opt-out from poison semantics.");

  const Attribute operand = operands[0];
  if (!operand)
    return {};

  if constexpr (!std::is_void_v<PoisonAttr>) {
    if (isa<PoisonAttr>(operand))
      return operand;
  }

  // Case 1: a single-element attribute.
  if (isa<AttrElementT>(operand)) {
    auto scalar = cast<AttrElementT>(operand);
    auto folded = calculate(scalar.getValue());
    if (!folded)
      return {};
    return AttrElementT::get(scalar.getType(), *folded);
  }

  // Case 2: the operand is a splat, so folding the single splat value is
  // enough; there is no need to expand the elements.
  if (isa<SplatElementsAttr>(operand)) {
    auto splat = cast<SplatElementsAttr>(operand);
    auto folded = calculate(splat.getSplatValue<ElementValueT>());
    if (!folded)
      return {};
    return DenseElementsAttr::get(splat.getType(), *folded);
  }

  // Case 3: a generic elements attribute. Expand the values and fold them one
  // by one.
  if (isa<ElementsAttr>(operand)) {
    auto elements = cast<ElementsAttr>(operand);
    auto maybeBegin = elements.try_value_begin<ElementValueT>();
    if (!maybeBegin)
      return {};
    auto cursor = *maybeBegin;

    const size_t numElements = elements.getNumElements();
    SmallVector<ElementValueT> foldedValues;
    foldedValues.reserve(numElements);
    for (size_t idx = 0; idx != numElements; ++idx) {
      auto folded = calculate(*cursor);
      if (!folded)
        return {};
      foldedValues.push_back(*folded);
      ++cursor;
    }
    return DenseElementsAttr::get(elements.getShapedType(), foldedValues);
  }

  // Unsupported operand kind.
  return {};
}
/// Constant-folds a unary operation whose element computation always produces
/// a value.
///
/// `calculate` is wrapped into a callable returning `std::optional` and handed
/// to `constFoldUnaryOpConditional`. `PoisonAttr` defaults to
/// `ub::PoisonAttr`.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class PoisonAttr = ub::PoisonAttr,
          class CalculationT = function_ref<ElementValueT(ElementValueT)>>
Attribute constFoldUnaryOp(ArrayRef<Attribute> operands,
                           CalculationT &&calculate) {
  // Adapter: lift the unconditional computation into the optional-returning
  // form expected by the conditional folder.
  auto alwaysFold = [&](ElementValueT value) -> std::optional<ElementValueT> {
    return calculate(value);
  };
  return constFoldUnaryOpConditional<AttrElementT, ElementValueT, PoisonAttr>(
      operands, alwaysFold);
}
/// Constant-folds a cast operation by applying `calculate` element-wise to the
/// single attribute in `operands`, producing an attribute of type `resType`.
///
/// `calculate` is invoked with the element value and a status flag that is
/// initialised to `true`; if the flag is `false` after a call, folding is
/// abandoned and a null attribute is returned. Three operand shapes are
/// handled, tried in this order: an `AttrElementT` attribute (folded into a
/// `TargetAttrElementT`), a splat elements attribute, and a generic elements
/// attribute. A null attribute is also returned when the operand is null or
/// of an unsupported kind, or when `try_value_begin<ElementValueT>()` fails.
///
/// The optional `PoisonAttr` template argument names the 'poison' attribute.
/// Unless it is `void`, an operand of that kind is returned as the result
/// directly.
///
/// NOTE(review): the default `CalculationT` takes the status flag by value
/// (`bool`); a callable has to take it by reference to be able to clear it.
template <
    class AttrElementT, class TargetAttrElementT,
    class ElementValueT = typename AttrElementT::ValueType,
    class TargetElementValueT = typename TargetAttrElementT::ValueType,
    class PoisonAttr = ub::PoisonAttr,
    class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
                          CalculationT &&calculate) {
  assert(operands.size() == 1 && "Cast op takes one operand");
  static_assert(
      std::is_void_v<PoisonAttr> || !llvm::is_incomplete_v<PoisonAttr>,
      "PoisonAttr is undefined, either add a dependency on UB dialect or pass "
      "void as template argument to opt-out from poison semantics.");

  const Attribute operand = operands[0];
  if (!operand)
    return {};

  if constexpr (!std::is_void_v<PoisonAttr>) {
    if (isa<PoisonAttr>(operand))
      return operand;
  }

  // Case 1: a single-element attribute.
  if (auto scalar = dyn_cast<AttrElementT>(operand)) {
    bool castOk = true;
    auto converted = calculate(scalar.getValue(), castOk);
    if (!castOk)
      return {};
    return TargetAttrElementT::get(resType, converted);
  }

  // Case 2: the operand is a splat, so converting the single splat value is
  // enough; there is no need to expand the elements.
  if (auto splat = dyn_cast<SplatElementsAttr>(operand)) {
    bool castOk = true;
    auto converted = calculate(splat.getSplatValue<ElementValueT>(), castOk);
    if (!castOk)
      return {};
    return DenseElementsAttr::get(cast<ShapedType>(resType), converted);
  }

  // Case 3: a generic elements attribute. Expand the values and convert them
  // one by one, giving up on the first failed conversion.
  if (auto elements = dyn_cast<ElementsAttr>(operand)) {
    bool castOk = true;
    auto maybeBegin = elements.try_value_begin<ElementValueT>();
    if (!maybeBegin)
      return {};
    auto cursor = *maybeBegin;

    const size_t numElements = elements.getNumElements();
    SmallVector<TargetElementValueT> convertedValues;
    convertedValues.reserve(numElements);
    for (size_t idx = 0; idx != numElements; ++idx) {
      auto converted = calculate(*cursor, castOk);
      if (!castOk)
        return {};
      convertedValues.push_back(converted);
      ++cursor;
    }
    return DenseElementsAttr::get(cast<ShapedType>(resType), convertedValues);
  }

  // Unsupported operand kind.
  return {};
}
} // namespace mlir
#endif // MLIR_DIALECT_COMMONFOLDERS_H