| # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| # This file contains the Sparsifier class. |
| |
| from mlir import execution_engine |
| from mlir import ir |
| from mlir import passmanager |
| from typing import Sequence |
| |
| |
class Sparsifier:
    """Builds a sparsifier pass pipeline and applies it to MLIR modules.

    The pipeline string is assembled once at construction time; the
    optimization level and shared libraries are kept for later use when a
    module is wrapped in an execution engine.
    """

    def __init__(
        self,
        extras: str,
        options: str,
        opt_level: int,
        shared_libs: Sequence[str],
    ):
        """Assembles the pipeline string and records the JIT settings.

        Args:
          extras: Text inserted verbatim directly before the `sparsifier`
            pass name inside `builtin.module(...)`.
          options: Text inserted verbatim at the start of the sparsifier
            option group, ahead of the two fixed options.
          opt_level: Forwarded as `opt_level` to the execution engine.
          shared_libs: Forwarded as `shared_libs` to the execution engine.
        """
        # Doubled braces are literal braces in a format template; the two
        # trailing options are always appended after the caller's options.
        template = (
            "builtin.module({extras}sparsifier{{{options}"
            " reassociate-fp-reductions=1 enable-index-optimizations=1}})"
        )
        self.shared_libs = shared_libs
        self.opt_level = opt_level
        self.pipeline = template.format(extras=extras, options=options)

    def __call__(self, module: ir.Module):
        """Shorthand for `compile(module)`."""
        self.compile(module)

    def compile(self, module: ir.Module):
        """Parses the stored pipeline and runs it on the module's operation."""
        pass_manager = passmanager.PassManager.parse(self.pipeline)
        pass_manager.run(module.operation)

    def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Returns an execution engine built around the module."""
        engine = execution_engine.ExecutionEngine(
            module,
            opt_level=self.opt_level,
            shared_libs=self.shared_libs,
        )
        return engine

    def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
        """Runs `compile` on the module, then returns `jit` of it."""
        self.compile(module)
        engine = self.jit(module)
        return engine