blob: a52b89027c5b659e781daeb4cf1bb4baab1e93c8 [file] [log] [blame]
//===- CommonFolders.h - Common Operation Folders----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file declares various common operation folders. These folders
// are intended to be used by dialects to support common folding behavior
// without requiring each dialect to provide its own implementation.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_COMMONFOLDERS_H
#define MLIR_DIALECT_COMMONFOLDERS_H
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
/// Attempts to constant-fold a binary operation element-wise.
///
/// `operands` must contain exactly two attributes (this is asserted). A fold
/// is produced only when both operands are non-null, share the same type, and
/// fall into one of these cases (checked in this order):
///   * both are `AttrElementT`: `calculate` is applied once and the result is
///     wrapped via `AttrElementT::get` using the lhs type;
///   * both are `SplatElementsAttr`: `calculate` is applied once to the two
///     splat values and the result is splatted into a `DenseElementsAttr` of
///     the lhs type;
///   * both are `ElementsAttr`: `calculate` is applied to every pair of
///     corresponding elements and the results are collected into a
///     `DenseElementsAttr` of the lhs type.
/// In every other situation a null `Attribute` is returned, meaning "could not
/// fold".
///
/// `ElementValueT` is the type each element is read as (by default
/// `AttrElementT::ValueType`); `calculate` combines two such values into one.
template <class AttrElementT,
          class ElementValueT = typename AttrElementT::ValueType,
          class CalculationT =
              function_ref<ElementValueT(ElementValueT, ElementValueT)>>
Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
                            const CalculationT &calculate) {
  assert(operands.size() == 2 && "binary op takes two operands");
  Attribute lhsAttr = operands[0];
  Attribute rhsAttr = operands[1];

  // Both operands must be known constants of one and the same type; the
  // null check short-circuits before getType() is ever queried.
  if (!lhsAttr || !rhsAttr || lhsAttr.getType() != rhsAttr.getType())
    return {};

  // Scalar operands: combine the two values directly.
  auto lhsScalar = lhsAttr.dyn_cast<AttrElementT>();
  auto rhsScalar = rhsAttr.dyn_cast<AttrElementT>();
  if (lhsScalar && rhsScalar)
    return AttrElementT::get(lhsScalar.getType(),
                             calculate(lhsScalar.getValue(),
                                       rhsScalar.getValue()));

  // Splat operands: fold just the single splat value instead of expanding
  // every element of both operands.
  auto lhsSplat = lhsAttr.dyn_cast<SplatElementsAttr>();
  auto rhsSplat = rhsAttr.dyn_cast<SplatElementsAttr>();
  if (lhsSplat && rhsSplat) {
    auto splatResult = calculate(lhsSplat.getSplatValue<ElementValueT>(),
                                 rhsSplat.getSplatValue<ElementValueT>());
    return DenseElementsAttr::get(lhsSplat.getType(), splatResult);
  }

  // Anything else must be a pair of ElementsAttr to be foldable.
  auto lhsElements = lhsAttr.dyn_cast<ElementsAttr>();
  auto rhsElements = rhsAttr.dyn_cast<ElementsAttr>();
  if (!lhsElements || !rhsElements)
    return {};

  // Walk both element sequences in lockstep. The type-equality check above
  // means both operands have the same shape, hence the same element count.
  auto lhsIt = lhsElements.value_begin<ElementValueT>();
  auto rhsIt = rhsElements.value_begin<ElementValueT>();
  const size_t count = lhsElements.getNumElements();
  SmallVector<ElementValueT, 4> results;
  results.reserve(count);
  for (size_t idx = 0; idx != count; ++idx) {
    results.push_back(calculate(*lhsIt, *rhsIt));
    ++lhsIt;
    ++rhsIt;
  }
  return DenseElementsAttr::get(lhsElements.getType(), results);
}
} // namespace mlir
#endif // MLIR_DIALECT_COMMONFOLDERS_H