blob: f9f32c22f8d4d335a83d531787f401b6c5c471f0 [file] [log] [blame]
//===- ShapeInferencePass.cpp - Shape Inference ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a Function level pass performing intra-procedural
// propagation of array shapes: the result shapes of each operation in a
// function are inferred from the types of its operands.
//
//===----------------------------------------------------------------------===//
#include "mlir/Pass/Pass.h"
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "toy/ShapeInferenceInterface.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "shape-inference"
using namespace mlir;
using namespace toy;
/// Include the auto-generated definitions for the shape inference interfaces.
#include "toy/ShapeInferenceOpInterfaces.cpp.inc"
namespace {
/// The ShapeInferencePass is a FunctionPass that performs intra-procedural
/// shape inference.
///
/// Algorithm:
///
///   1) Build a worklist containing all the operations that return a
///      dynamically shaped tensor: these are the operations that need shape
///      inference. The worklist keeps the order in which `walk` visited the
///      operations, so the processing order is deterministic.
///   2) Iterate on the worklist:
///      a) find an operation to process: the next ready operation in the
///         worklist has all of its arguments non-generic,
///      b) if no operation is found, break out of the loop,
///      c) remove the operation from the worklist,
///      d) infer the shape of its output from the argument types.
///   3) If the worklist is empty, the algorithm succeeded.
///
class ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> {
public:
  void runOnFunction() override {
    auto f = getFunction();

    // Populate the worklist with the operations that need shape inference:
    // these are operations that return a dynamic shape.
    //
    // An ordered vector is used instead of a pointer-keyed hash set: iterating
    // a SmallPtrSet visits elements in an order that depends on pointer
    // values, which made the inference order (and which operation gets
    // diagnosed first on failure) vary from one run to the next.
    llvm::SmallVector<mlir::Operation *, 16> opWorklist;
    f.walk([&](mlir::Operation *op) {
      if (returnsDynamicShape(op))
        opWorklist.push_back(op);
    });

    // Iterate on the operations in the worklist until all operations have been
    // inferred or no change happened (fix point).
    while (!opWorklist.empty()) {
      // Find the next operation ready for inference, that is an operation
      // with all operands already resolved (non-generic). Searching from the
      // front picks the earliest ready operation in walk order.
      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
      if (nextop == opWorklist.end())
        break;

      // Take the operation out before erasing: erase invalidates `nextop`.
      Operation *op = *nextop;
      opWorklist.erase(nextop);

      // Ask the operation to infer its output shapes.
      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
      if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
        shapeOp.inferShapes();
      } else {
        op->emitError("unable to infer shape of operation without shape "
                      "inference interface");
        return signalPassFailure();
      }
    }

    // If the operation worklist isn't empty, this indicates a failure: the
    // remaining operations never had all of their operands resolved.
    if (!opWorklist.empty()) {
      f.emitError("Shape inference failed, ")
          << opWorklist.size() << " operations couldn't be inferred\n";
      signalPassFailure();
    }
  }

  /// A utility method that returns true if every operand of the given
  /// operation has a RankedTensorType (i.e. its shape is already inferred).
  /// Operations with no operands are trivially ready.
  static bool allOperandsInferred(Operation *op) {
    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
      return operandType.isa<RankedTensorType>();
    });
  }

  /// A utility method that returns true if at least one result of the given
  /// operation is not a RankedTensorType, i.e. still needs shape inference.
  /// Operations with no results never need inference.
  static bool returnsDynamicShape(Operation *op) {
    return llvm::any_of(op->getResultTypes(), [](Type resultType) {
      return !resultType.isa<RankedTensorType>();
    });
  }
};
} // end anonymous namespace
/// Factory for the Toy shape inference pass.
///
/// Returns a freshly constructed ShapeInferencePass; the caller receives sole
/// ownership of the new pass instance.
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<ShapeInferencePass>();
  return pass;
}