blob: 539f3e833b12d6ff06c276a702e6caecb2f7b381 [file]
//===- TosaReduceTransposes.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// ----------
// Motivation:
// ----------
// Some legalization pathways introduce redundant tosa.TRANSPOSE
// operations that result in avoidable data movement. For example,
// PyTorch -> TOSA contains a lot of unnecessary transposes due
// to conversions between NCHW and NHWC.
// We wish to remove all the ones that we can, since in general
// it is possible to remove the overwhelming majority.
// -------------------
// High-Level Overview:
// -------------------
// The pass works through the transpose operators in the program. It begins at
// some transpose operator with an associated permutations tensor. It traverses
// upwards through the dependencies of this transpose and verifies that we
// encounter only operators with the TosaElementwiseOperator trait and terminate
// in either constants, reshapes, or transposes.
// We then evaluate whether there are any additional restrictions (the
// transposes it terminates in must invert the one we began at, and the reshapes
// must be ones in which we can fold the transpose into), and then we hoist the
// transpose through the intervening operators, folding it at the constants,
// reshapes, and transposes.
// Finally, we ensure that we do not need both the transposed form (the form
// that had the transpose hoisted through it) and the untransposed form (which
// it was prior), by analyzing the usages of those dependent operators of a
// given transpose we are attempting to hoist and replace.
// If they are such that it would require both forms to be necessary, then we do
// not replace the hoisted transpose, causing the new chain to be dead.
// Otherwise, we do and the old chain (untransposed form) becomes dead. Only one
// chain will ever then be live, resulting in no duplication.
// We then perform a simple one-pass DCE, so no canonicalization is necessary.
// -----------
// Future Work:
// -----------
// (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
// hoisted transposes with different permutation tensors.
// (2) Expand the class of foldable upstream ReshapeOp we permit beyond
// N -> 1x1x...x1xNx1x...x1x1.
// (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
// those that form the identity.
// (4) Add support for more instructions besides TosaElementwiseOperator as
// the intervening ones (for example, the reduce_* operators).
// (5) Support hoisting transposes up to an input parameter.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/TypeSwitch.h"
#include <memory>
#include <set>
#include <stack>
namespace mlir {
namespace tosa {
// Defining GEN_PASS_DEF_TOSAREDUCETRANSPOSES before including Passes.h.inc
// requests the generated definitions for this pass only. The
// tosa::impl::TosaReduceTransposesBase class the pass derives from below
// presumably comes from this include -- NOTE(review): confirm against the
// tablegen output.
#define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir
using namespace mlir;
using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// TOSA Reduce Transposes Pass.
//===----------------------------------------------------------------------===//
namespace {
// Pass that hoists tosa.transpose ops upwards through their fan-in of
// elementwise ops and folds them into the ConstOp / ReshapeOp / TransposeOp
// found at the roots of that fan-in.
struct TosaReduceTransposes final
    : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
  void runOnOperation() override;

private:
  // Gathers into `collected` every operation `op` depends on, walking up
  // through elementwise ops and stopping at (but still including) ConstOp,
  // ReshapeOp, and TransposeOp. Returns false if an unsupported dependency is
  // encountered.
  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);

  // Creates the transposed counterpart of each op in `dependentOps` that has
  // no entry in `valuesMap` yet and records the mapping from the original
  // result to the new value. Returns false as soon as one op cannot be
  // converted.
  bool convertDependentOps(SetVector<Operation *> &dependentOps,
                           DenseMap<Value, Value> &valuesMap,
                           IRRewriter &rewriter,
                           ArrayRef<int32_t> hoistedPerms);

  // Returns true if applying the two permutations one after the other yields
  // the identity.
  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
                               ArrayRef<int32_t> perms2);

  // Overload meant for operations carrying the TosaElementwiseOperator trait:
  // re-creates `op` on top of the already-mapped operands with a permuted
  // result type.
  std::optional<Value>
  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Overload for a TransposeOp that is itself a dependency of the hoisted
  // one, e.g.
  //   %0 = tosa.transpose %arg0   <- the op handled here
  //   %1 = tosa.transpose %0      <- the transpose being hoisted
  // Yields the input of the upstream transpose when the two cancel out.
  std::optional<Value>
  buildMappedToValue(TransposeOp transposeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Overload for ReshapeOp: checks whether the hoisted TransposeOp can be
  // folded into the reshape and, if so, creates the new ReshapeOp containing
  // that fold.
  std::optional<Value>
  buildMappedToValue(ReshapeOp reshapeOp,
                     const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Overload for ConstOp. Covers patterns such as
  //   %0 = tosa.const
  //   %1 = tosa.transpose
  //   %2 = tosa.add %0, %1
  //   %3 = tosa.transpose %2
  // which --tosa-layerwise-const-fold wouldn't handle. This shape shows up in
  // MobilenetV3.
  std::optional<Value>
  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
                     IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);

  // Selects the TransposeOps that should actually be replaced, which makes
  // their converted chains "live" and the old code "dead". Avoids replacing a
  // transpose when the old code would have to stay "live" as well, since that
  // would duplicate operations.
  std::set<TransposeOp> getGoodReplacements(
      ArrayRef<int32_t> perms,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper for dependenciesAreValid: true when `user` is in none of the
  // fan-ins belonging to the transposes in `validTransposes`.
  bool userNotContainedInValidTransposeDependencies(
      Operation *user, std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Helper for getGoodReplacements: checks whether the dependencies of some
  // TransposeOp are acceptable for replacement.
  bool dependenciesAreValid(
      ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
      std::set<TransposeOp> &validTransposes,
      std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
          &transposeInfo);

  // Applies `perms` to the DenseElementsAttr.
  // A std::nullopt result also triggers pass failure, since it means the
  // guarantees expected from the TOSA verifier are not in place.
  // This is a basic API and may benefit from a refactor into the core MLIR
  // APIs.
  std::optional<DenseElementsAttr>
  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
};
// Returns a copy of `input` with its elements rearranged according to `perms`;
// output dimension i corresponds to input dimension perms[i].
//
// Returns std::nullopt AND signals pass failure when the basic invariants do
// not hold: a rank-0 tensor, a tensor without elements, a `perms` whose length
// differs from the rank, or a `perms` entry outside [0, rank). These are
// expected to be guaranteed by the TransposeOp verifier and by TOSA
// disallowing tensors with a dimension of 0, so hitting them means something
// is very wrong.
std::optional<DenseElementsAttr>
TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
                                              ArrayRef<int32_t> perms) {
  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
  size_t rank = oldType.getRank();

  // Validate before `perms` is used for anything: applyTOSAPermutation and
  // the stride permutation below index with its entries without any bounds
  // checking of their own.
  const bool permsInRange = llvm::all_of(perms, [rank](int32_t dim) {
    return dim >= 0 && static_cast<size_t>(dim) < rank;
  });
  if (rank == 0 || oldType.getNumElements() <= 0 || perms.size() != rank ||
      !permsInRange) {
    signalPassFailure();
    return std::nullopt;
  }

  RankedTensorType newType =
      RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
                            oldType.getElementType());

  // Every element of a splat is identical, so only the type changes.
  if (input.isSplat())
    return input.reshape(newType);

  // The algorithm is approximately as follows:
  //   (1/2) Determine the row-major strides of the input. The output is
  //         written in row-major order, so we can simply push_back into it.
  //   (3)   Permute the input strides so that they are indexed in the same
  //         order (i, j, k, ...) as the output is written.
  //   (4)   Walk the output dimension by dimension, accumulating the flat
  //         input index along the way.
  // Example: perms = [2, 0, 1], input 2x3x4, output 4x2x3.
  //   for i in 0..4, j in 0..2, k in 0..3:
  //     output[i][j][k] = input[j][k][i]
  //     output[6i + 3j + k] = input[12j + 4k + i]
  //   so the permuted input strides are (1, 12, 4).

  // Step 1/2: row-major strides of the input.
  SmallVector<int64_t> originalInputStrides(rank);
  originalInputStrides[rank - 1] = 1;
  // Use a signed index so that the loop terminates for rank == 1.
  for (int64_t i = static_cast<int64_t>(rank) - 2; i >= 0; i--)
    originalInputStrides[i] =
        originalInputStrides[i + 1] * oldType.getDimSize(i + 1);

  // Step 3: input strides reordered to follow the output's dimension order.
  SmallVector<int64_t> newInputStrides;
  newInputStrides.reserve(rank);
  for (int32_t v : perms)
    newInputStrides.push_back(originalInputStrides[v]);

  // Step 4: write out the transposed "flat array" dimension by dimension.
  auto inputArray = input.getValues<Attribute>();
  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
  for (size_t i = 0; i < rank; i++)
    boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});

  SmallVector<Attribute> resultArray;
  resultArray.reserve(inputArray.size());

  // Recurses over the output dimensions from outermost to innermost;
  // `accumulatedIndex` is the flat input index contributed by the dimensions
  // fixed so far.
  std::function<void(int64_t,
                     SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
      processTransposeDim = [&](auto accumulatedIndex, auto it) {
        if (it == boundsAndStrides.end()) {
          resultArray.push_back(inputArray[accumulatedIndex]);
          return;
        }

        for (int64_t i = 0; i < it->first; i++) {
          int64_t j = accumulatedIndex + i * it->second;
          processTransposeDim(j, it + 1);
        }
      };

  processTransposeDim(0, boundsAndStrides.begin());

  return DenseElementsAttr::get(newType, resultArray);
}
// Recursively gathers the fan-in of `op` into `collected`.
// When this returns true, `collected` holds ConstOp, ReshapeOp, and
// TransposeOp as the sources of the data dependencies and ops with the
// TosaElementwiseOperator trait downstream of them, with every op inserted
// after the ops it depends on.
bool TosaReduceTransposes::collectFanIn(Operation *op,
                                        SetVector<Operation *> &collected) {
  // A missing defining op can occur if the value is defined through a
  // parameter to a func.func.
  if (op == nullptr)
    return false;

  // Only operations from the TOSA dialect are supported.
  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
    return false;

  // Already visited: nothing more to do.
  if (collected.contains(op))
    return true;

  // Reject anything but a single ranked-tensor result now, so later stages
  // don't have to deal with it.
  const bool hasSingleRankedResult =
      op->getNumResults() == 1 &&
      llvm::isa<RankedTensorType>(op->getResult(0).getType());
  if (!hasSingleRankedResult)
    return false;

  // TransposeOp, ReshapeOp, and ConstOp have no in-edges in the dependency
  // graph built for the downstream TransposeOp. In particular we do not walk
  // above a ReshapeOp, since generally a TransposeOp cannot be propagated
  // through it.
  const bool isRoot =
      llvm::isa<tosa::TransposeOp, tosa::ReshapeOp, tosa::ConstOp>(op);
  if (!isRoot) {
    if (!op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>())
      return false;

    // If the recursion depth becomes a problem in the future, think about
    // alternatives to recursion.
    const bool operandsCollected =
        llvm::all_of(op->getOperands(), [&](Value operand) {
          return collectFanIn(operand.getDefiningOp(), collected);
        });
    if (!operandsCollected)
      return false;
  }

  // Operands were inserted above, which keeps the set in topological order.
  collected.insert(op);
  return true;
}
// Returns true when perms2[perms1[i]] == i for every index i, i.e. when the
// two permutations cancel each other out.
// Relies on the TransposeOp verifier having ensured that both arrays are
// permutations of 0 .. perms.size() - 1; entries are not range-checked here.
bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
                                                   ArrayRef<int32_t> perms2) {
  if (perms1.size() != perms2.size())
    return false;

  const int32_t rank = perms1.size();
  return llvm::all_of(llvm::seq<int32_t>(0, rank), [&](int32_t dim) {
    return perms2[perms1[dim]] == dim;
  });
}
// Primary overload, used for ops with the TosaElementwiseOperator trait.
// The remaining overloads deal with the operations found at the roots of the
// data dependency graph (ConstOp, ReshapeOp, TransposeOp).
//
// Returns the result of a newly created copy of `op` that consumes the mapped
// operands and produces the permuted result type, or std::nullopt if `op` is
// not a single-result elementwise op or one of its operands has no mapping.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    Operation *op, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  const bool isSingleResultElementwise =
      op->getNumResults() == 1 &&
      op->hasTrait<OpTrait::tosa::TosaElementwiseOperator>();
  if (!isSingleResultElementwise)
    return std::nullopt;

  auto oldResultType =
      llvm::cast<RankedTensorType>(op->getResult(0).getType());

  // Every operand must already have a transposed counterpart.
  SmallVector<Value, 3> mappedOperands;
  for (Value operand : op->getOperands()) {
    auto mapping = valuesMap.find(operand);
    if (mapping == valuesMap.end())
      return std::nullopt;
    mappedOperands.push_back(mapping->second);
  }

  // Conceptually, we propagate the hoisted TransposeOp through
  // these intervening operations. For example,
  //   %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
  //   %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //        tensor<3x2xi32>
  // becomes:
  //   %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
  //        tensor<3x2xi32>
  //   %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>
  // The new tosa.clamp is constructed here, but it does not turn "live"
  // until the transpose being hoisted through this chain is replaced with
  // the proper value from the new chain.
  Operation *hoistedOp = rewriter.create(
      op->getLoc(), op->getName().getIdentifier(), mappedOperands,
      RankedTensorType::get(
          applyTOSAPermutation(oldResultType.getShape(), hoistedPerms),
          oldResultType.getElementType()),
      op->getAttrs());
  return hoistedOp->getResult(0);
}
// Overload for a TransposeOp sitting at a root of the fan-in. If its constant
// permutation cancels `hoistedPerms`, the pair folds away and the input of
// the upstream transpose is the mapped value; otherwise std::nullopt.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  SmallVector<int32_t> upstreamPerms;
  if (failed(transposeOp.getConstantPerms(upstreamPerms)))
    return std::nullopt;
  if (!areInvolutionTransposes(hoistedPerms, upstreamPerms))
    return std::nullopt;
  return transposeOp.getInput1();
}
// Overload for a ReshapeOp sitting at a root of the fan-in. Instead of
// inserting a TransposeOp after the reshape, the hoisted transpose is folded
// into a new ReshapeOp whose output shape is the permuted one.
//
// Only reshapes of the form N -> 1x1x...x1xNx1x...x1x1 are supported; any
// other reshape (or an unranked reshape input) yields std::nullopt. There are
// more complex cases where folding is possible, and the check can be
// extended.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto reshapeOutput = reshapeOp.getOutput();
  auto reshapeInputType =
      llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
  // The dyn_cast yields a null type for an unranked input, so this must be
  // checked before the shape is queried.
  if (!reshapeInputType)
    return std::nullopt;

  // Want reshape N -> 1x1x...x1xNx1x...x1x1, so the input must be rank 1.
  auto reshapeInputShape = reshapeInputType.getShape();
  if (reshapeInputShape.size() != 1)
    return std::nullopt;

  auto reshapeOutputType =
      llvm::cast<RankedTensorType>(reshapeOutput.getType());
  auto shape = reshapeOutputType.getShape();

  // applyTOSAPermutation indexes `hoistedPerms` once per output dimension,
  // so refuse to fold if the ranks do not line up.
  if (hoistedPerms.size() != shape.size())
    return std::nullopt;

  // Accept either exactly one non-1 output dimension (N != 1), or an all-ones
  // output shape when N == 1.
  size_t ones = llvm::count(shape, 1);
  if (ones != shape.size() - 1 &&
      !(ones == shape.size() && reshapeInputShape[0] == 1))
    return std::nullopt;

  // Do not insert a TransposeOp; instead fold it into the reshape by
  // permuting both the result type and the new-shape attribute.
  auto foldedReshape = rewriter.create<ReshapeOp>(
      reshapeOp.getLoc(),
      RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
                            reshapeOutputType.getElementType()),
      reshapeOp.getInput1(),
      rewriter.getDenseI64ArrayAttr(
          applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
  return foldedReshape->getResult(0);
}
// Overload for a ConstOp sitting at a root of the fan-in. Creates a new
// ConstOp holding the constant's data transposed by `hoistedPerms`, or yields
// std::nullopt if the value is not a DenseElementsAttr or cannot be
// transposed.
std::optional<Value> TosaReduceTransposes::buildMappedToValue(
    ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  auto originalAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
  if (!originalAttr)
    return std::nullopt;

  std::optional<DenseElementsAttr> transposedAttr =
      transposeDenseAttribute(originalAttr, hoistedPerms);
  if (!transposedAttr)
    return std::nullopt;

  auto transposedConst = rewriter.create<ConstOp>(
      constOp.getLoc(), transposedAttr->getType(), *transposedAttr);
  return transposedConst->getResult(0);
}
// Walks `dependentOps` in order and, for each op whose result has no entry in
// `valuesMap` yet, builds its transposed counterpart and records the mapping
// from the original result to the new value.
// Returns false on the first op that is null, does not have exactly one
// result, or cannot be converted; ops created up to that point stay in the
// IR.
bool TosaReduceTransposes::convertDependentOps(
    SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
    IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
  for (Operation *dependentOp : dependentOps) {
    if (dependentOp == nullptr || dependentOp->getNumResults() != 1)
      return false;

    Value originalResult = dependentOp->getResult(0);

    // A prior transposeOp may have had the same dependency, in which case it
    // has already been resolved.
    if (valuesMap.contains(originalResult))
      continue;

    // Keep converted ops close to the original.
    rewriter.setInsertionPointAfter(dependentOp);

    // Root ops (TransposeOp, ReshapeOp, ConstOp) go to their dedicated
    // overloads; everything else goes to the elementwise overload.
    std::optional<Value> converted =
        llvm::TypeSwitch<Operation *, std::optional<Value>>(dependentOp)
            .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto rootOp) {
              return buildMappedToValue(rootOp, valuesMap, rewriter,
                                        hoistedPerms);
            })
            .Default([&](Operation *elementwiseOp) {
              return buildMappedToValue(elementwiseOp, valuesMap, rewriter,
                                        hoistedPerms);
            });

    if (!converted)
      return false;

    valuesMap[originalResult] = *converted;
  }

  return true;
}
// Returns true when `user` does not appear in the dependent ops of any
// transpose that is still in `validTransposes`.
bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
    Operation *user, std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  for (const auto &[transposeOp, dependentOps] : transposeInfo) {
    // Fan-ins of already invalidated transposes do not count.
    if (validTransposes.count(transposeOp) == 0)
      continue;
    if (dependentOps.contains(user))
      return false;
  }
  return true;
}
// Dependencies are valid for a TransposeOp if none of them is used outside of
// the proper fan-in cones of the hoisted TransposeOps with the same perms
// that we can replace. Details are given inline.
bool TosaReduceTransposes::dependenciesAreValid(
    ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
    std::set<TransposeOp> &validTransposes,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  for (Operation *dependentOp : dependentOps) {
    // A ConstOp may be used anywhere -- in the worst case, we duplicate it.
    // This can be changed later if the memory impact turns out too high.
    if (llvm::isa<ConstOp>(dependentOp))
      continue;

    // Each user must either (1) be contained in the dependentOps of some
    // valid transpose, or (2) be a TransposeOp with the same perms. In case
    // (2) its fan-in is a subset of our dependentOps, so it is also a valid
    // transpose that will eventually be replaced.
    for (Operation *user : dependentOp->getUsers()) {
      auto userTranspose = llvm::dyn_cast<TransposeOp>(user);
      if (!userTranspose) {
        if (userNotContainedInValidTransposeDependencies(
                user, validTransposes, transposeInfo))
          return false;
        continue;
      }

      // A more general transform could later permit transpose -> transpose
      // or reshape -> transpose chains where the perms differ from the
      // hoisted ones.
      SmallVector<int32_t> userPerms;
      if (failed(userTranspose.getConstantPerms(userPerms)))
        return false;
      if (!llvm::equal(perms, userPerms))
        return false;
    }
  }

  return true;
}
// Computes the set of TransposeOps that can be replaced without leaving the
// old fan-in cone of any of them "live", i.e. not dead code. The set is
// refined until it converges: an op used outside of its own fan-in cone may
// still be used inside the cone of another TransposeOp that is being
// replaced -- unless that one turns out to have an outside use as well.
std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
    ArrayRef<int32_t> perms,
    std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
        &transposeInfo) {
  // Start by assuming every transpose is replaceable, then whittle the set
  // down.
  std::set<TransposeOp> ableToReplace;
  for (const auto &info : transposeInfo)
    ableToReplace.insert(info.first);

  // A candidate is offending if it is still in the set but its dependencies
  // are not valid with respect to the current set.
  auto isOffendingCandidate =
      [&](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
        return ableToReplace.count(info.first) != 0 &&
               !dependenciesAreValid(perms, info.second, ableToReplace,
                                     transposeInfo);
      };

  // Drop one offending candidate at a time and rescan from the start, since
  // removing it can invalidate candidates that were checked earlier.
  while (true) {
    auto offender = llvm::find_if(transposeInfo, isOffendingCandidate);
    if (offender == transposeInfo.end())
      break;
    ableToReplace.erase(offender->first);
  }

  return ableToReplace;
}
void TosaReduceTransposes::runOnOperation() {
// We want to operate only within a single block.
if (!getOperation().getRegion().hasOneBlock())
return;
IRRewriter rewriter(&getContext());
// For each perms, maintain a mapping for converted ops, avoid duplication.
DenseMap<ArrayRef<int32_t>, DenseMap<Value, Value>> permsToValues;
// For each perms, we keep track of which TransposeOp are eligible
// for replacement alongside their dependentOps.
DenseMap<ArrayRef<int32_t>,
std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
permsToTransposeInfo;
// Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef.
// Use SmallVector for perms (common-case is <= 4) but std::vector otherwise
// since no guarantee of smallness.
std::vector<SmallVector<int32_t>> collectedPerms;
// This keeps track of the order across all eligible-for-replacement
// TransposeOp and their perms, a necessity for the final replacements.
std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
// We want to reserve the space up front, since SmallVector stores some data
// internally and the ArrayRef can reference that, which we don't want to get
// invalidated.
size_t expectedMaxPerms = 0;
getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
collectedPerms.reserve(expectedMaxPerms);
getOperation().walk([&](TransposeOp transposeOp) {
SetVector<Operation *> dependentOps;
collectedPerms.emplace_back();
SmallVector<int32_t> &perms = collectedPerms.back();
// Dynamic shapes are OK, but the incompatible ones will be rejected later.
auto input = transposeOp.getInput1();
auto output = transposeOp.getOutput();
// However, we don't support unranked tensors.
if (!llvm::isa<RankedTensorType>(input.getType()) ||
!llvm::isa<RankedTensorType>(output.getType()))
return;
// No transformation when transpose permutation non-constant.
if (failed(transposeOp.getConstantPerms(perms)))
return;
// We let --canonicalize deal with identity transpose.
if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
return;
// Can fail if some set of basic invariants is not met that we want to
// perform our conversions.
if (!collectFanIn(input.getDefiningOp(), dependentOps))
return;
// Want to associate valuesMap for already converted of the same perms,
// since it's possible multiple hoisted transposes w/ different perms
// converge on an op, which would result in different transformations.
DenseMap<Value, Value> &valuesMap = permsToValues[perms];
// Attempt to perform the conversions and placements into IR
// without turning inserted code "live". Also fills out valuesMap.
// Fails if there is an intermediary we do not support.
if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
// Some additional operations may have been inserted, but will be
// removed by dead code elimination.
return;
// This should not happen. If it does -- it's unexpected,
// so we fail the pass.
if (!valuesMap.contains(input))
return signalPassFailure();
// It's possible the types are not compatible (because of dynamic shapes),
// and in these cases, want to resolve dynamic shapes before running the
// pass.
if (output.getType() != valuesMap.at(input).getType())
return;
auto &transposeInfo = permsToTransposeInfo[perms];
// In general, we might also want to introduce "newDependentOps"
// if there are new usages that don't fall inside the original fan-ins
// (like the TransposeOp we insert for ReshapeOp),
// but in this case, that is specialized enough and overlaps
// with another direct-use TransposeOp case we need to cover anyway.
transposeInfo.push_back({transposeOp, dependentOps});
// This is for the final replacement across all transposes.
totalTransposeOrder.push({transposeOp, perms});
});
// We want to do a full fan-in analysis on a perms-level,
// since if we do it on a multi-perms level, and they share (due to a shared
// dependency on a Reshape) then we would also get duplicate ops.
// Const is special cased.
std::set<TransposeOp> ableToReplace;
for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
// Gives us back replacements that would never result in any duplicate
// operations being inserted by us in the IR (i.e, our goal is only to
// remove transposes, and not create a "new chain" to do so, but replace
// the existing chains).
// Ideally, --canonicalize is run before this pass, since it helps this
// analysis by removing dead code to allow more potentially acceptable
// transformations.
auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
ableToReplace.insert(goodReplacementsForPerms.begin(),
goodReplacementsForPerms.end());
}
// We want to do replacement across all transposes
// in reverse order, due to invalidation of valuesMap mappings
// if we did it otherwise.
while (!totalTransposeOrder.empty()) {
auto [transposeOp, perms] = totalTransposeOrder.top();
totalTransposeOrder.pop();
if (ableToReplace.count(transposeOp) == 0)
continue;
auto &valuesMap = permsToValues[perms];
auto input = transposeOp.getInput1();
// The purpose of this reverse iteration
// is to avoid valuesMap invalidation. If it happens,
// something is wrong.
if (!valuesMap.contains(input))
return signalPassFailure();
rewriter.replaceOp(transposeOp, valuesMap.at(input));
}
// We can remove all dead code by going in reverse.
// This is because we would remove usages before we
// see the users.
getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
[&](Operation *op) {
if (isOpTriviallyDead(op))
rewriter.eraseOp(op);
});
}
} // namespace