[mlir][EmitC] Add pass to wrap a func in class (#141158)

Goal: Enable using C++ classes to AOT compile models for MLGO.

This commit introduces a transformation pass that converts standalone `emitc.func` operations into `emitc.class` structures to support class-based C++ code generation for MLGO.

Transformation details:
- Wraps `emitc.func @func_name` into `emitc.class @func_nameClass`
- Converts function arguments to class fields with preserved attributes
- Transforms function body into an `execute()` method with no arguments
- Replaces argument references with `get_field` operations

Before: emitc.func @Model(%arg0, %arg1, %arg2) with direct argument access
After: emitc.class with fields and execute() method using get_field operations

This enables generating C++ classes that can be instantiated and executed as self-contained model objects for AOT compilation workflows.
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 9ecdb74f..91ee899 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1593,4 +1593,89 @@
   let hasVerifier = 1;
 }
 
+def EmitC_ClassOp
+    : EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove,
+                         OpAsmOpInterface, SymbolTable,
+                         Symbol]#GraphRegionNoTerminator.traits> {
+  let summary =
+      "Represents a C++ class definition, encapsulating fields and methods.";
+
+  let description = [{
+    The `emitc.class` operation defines a C++ class, acting as a container
+    for its data fields (`emitc.field`) and methods (`emitc.func`).
+    It creates a distinct scope, isolating its contents from the surrounding
+    MLIR region, similar to how C++ classes encapsulate their internals.
+
+    Example:
+
+    ```mlir
+    emitc.class @modelClass {
+      emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
+      emitc.func @execute() {
+        %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+        %1 = get_field @fieldName0 : !emitc.array<1xf32>
+        %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        return
+      }
+    }
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name);
+
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    // Returns the body block containing class members and methods.
+    Block &getBlock();
+  }];
+
+  let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }];
+}
+
+def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
+  let summary = "A field within a class";
+  let description = [{
+    The `emitc.field` operation declares a named field within an `emitc.class`
+    operation. The field's type must be an EmitC type.
+
+    Example:
+
+    ```mlir
+    // Example with an attribute:
+    emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
+    // Example with no attribute:
+    emitc.field @fieldName0 : !emitc.array<1xf32>
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
+      OptionalAttr<AnyAttr>:$attrs);
+
+  let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
+
+  let hasVerifier = 1;
+}
+
+def EmitC_GetFieldOp
+    : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods<
+                                       SymbolUserOpInterface>]> {
+  let summary = "Obtain access to a field within a class instance";
+  let description = [{
+    The `emitc.get_field` operation yields a value referring to the named
+    `emitc.field` of the enclosing `emitc.class`. The symbol must resolve to
+    an `emitc.field` whose type is identical to the result type.
+
+    Example:
+
+    ```mlir
+    %0 = get_field @fieldName0 : !emitc.array<1xf32>
+    ```
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$field_name);
+  let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
+  let assemblyFormat = "$field_name `:` type($result) attr-dict";
+}
+
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h index 5a103f1..1af4aa0 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -15,6 +15,7 @@ namespace emitc { #define GEN_PASS_DECL_FORMEXPRESSIONSPASS +#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc" //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td index f46b705..74c4913 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,4 +20,46 @@
   let dependentDialects = ["emitc::EmitCDialect"];
 }
 
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+  let summary = "Wrap functions in classes, using arguments as fields.";
+  let description = [{
+    This pass transforms `emitc.func` operations into `emitc.class` operations.
+    Function arguments become fields of the class, and the function body is moved
+    to a new `execute` method within the class.
+    Fields are named `fieldName<N>`, where `<N>` is the position of the
+    corresponding function argument.
+    If the corresponding function argument has attributes (accessed via `argAttrs`),
+    these attributes are attached to the field operation.
+    Otherwise, the field is created without additional attributes.
+    Inside `execute`, every use of an argument is replaced by the result of an
+    `emitc.get_field` operation placed at the start of the body.
+
+    Example:
+
+    ```mlir
+    emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } {
+      %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+      %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+      return
+    }
+    // becomes
+    emitc.class @modelClass {
+      emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
+      emitc.func @execute() {
+        %0 = get_field @fieldName0 : !emitc.array<1xf32>
+        %1 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+        %2 = subscript %0[%1] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        return
+      }
+    }
+    ```
+  }];
+  let dependentDialects = ["emitc::EmitCDialect"];
+  let options = [Option<
+      "namedAttribute", "named-attribute", "std::string",
+      /*default=*/"",
+      "Attribute key used to extract field names from function argument's "
+      "dictionary attributes">];
+}
+
 #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h index 2574acd..a4e8fe1 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,6 +28,13 @@
 /// Populates `patterns` with expression-related patterns.
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with the pattern that wraps `emitc.func` operations
+/// in an `emitc.class`: function arguments become `emitc.field`s and the body
+/// becomes an argument-less `execute` method that reads them through
+/// `emitc.get_field`. `namedAttribute` is forwarded to the pattern.
+void populateFuncPatterns(RewritePatternSet &patterns,
+                          StringRef namedAttribute);
+
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e602210..d17c4af 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1401,6 +1401,52 @@
 }
 
 //===----------------------------------------------------------------------===//
+// FieldOp
+//===----------------------------------------------------------------------===//
+/// Verifies the invariants of an `emitc.field`: its type must be accepted by
+/// `isSupportedEmitCType`, it must be nested directly in an `emitc.class`, and
+/// it must carry a non-empty symbol name. The optional attribute is not
+/// constrained.
+LogicalResult FieldOp::verify() {
+  if (!isSupportedEmitCType(getType()))
+    return emitOpError("expected valid emitc type");
+
+  Operation *parentOp = getOperation()->getParentOp();
+  if (!parentOp || !isa<emitc::ClassOp>(parentOp))
+    return emitOpError("field must be nested within an emitc.class operation");
+
+  StringAttr symName = getSymNameAttr();
+  if (!symName || symName.getValue().empty())
+    return emitOpError("field must have a non-empty symbol name");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetFieldOp
+//===----------------------------------------------------------------------===//
+/// Resolves the referenced symbol to an `emitc.field` through the nearest
+/// enclosing symbol table and checks that the field's type equals the result
+/// type of this operation.
+LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
+  FieldOp fieldOp =
+      symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
+  if (!fieldOp)
+    return emitOpError("field '")
+           << fieldNameAttr << "' not found in the class";
+
+  Type getFieldResultType = getResult().getType();
+  Type fieldType = fieldOp.getType();
+
+  if (fieldType != getFieldResultType)
+    return emitOpError("result type ")
+           << getFieldResultType << " does not match field '" << fieldNameAttr
+           << "' type " << fieldType;
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt index 19b80b22..baf67af 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ Transforms.cpp FormExpressions.cpp TypeConversions.cpp + WrapFuncInClass.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp new file mode 100644 index 0000000..17d436f --- /dev/null +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -0,0 +1,139 @@
+//===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <string>
+#include <utility>
+
+using namespace mlir;
+using namespace emitc;
+
+namespace mlir {
+namespace emitc {
+#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+/// Applies the func-wrapping pattern once to the operations nested under the
+/// pass's root operation.
+struct WrapFuncInClassPass
+    : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+
+    RewritePatternSet patterns(&getContext());
+    populateFuncPatterns(patterns, namedAttribute);
+
+    walkAndApplyPatterns(rootOp, std::move(patterns));
+  }
+};
+
+} // namespace
+} // namespace emitc
+} // namespace mlir
+
+namespace {
+/// Rewrites an `emitc.func` that has a body into an `emitc.class` named
+/// `<func>Class`. Each argument becomes an `emitc.field` named `fieldName<N>`,
+/// and the body moves into an argument-less `execute` method that reads the
+/// fields through `emitc.get_field`.
+class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+public:
+  WrapFuncInClass(MLIRContext *context, StringRef attrName)
+      : OpRewritePattern<emitc::FuncOp>(context),
+        attributeName(attrName.str()) {}
+
+  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    // A declaration has no entry block that could be moved into `execute`.
+    if (funcOp.isExternal())
+      return rewriter.notifyMatchFailure(funcOp, "function has no body");
+    // A function that already lives in a class must not be wrapped again.
+    if (funcOp->getParentOfType<ClassOp>())
+      return rewriter.notifyMatchFailure(funcOp, "already inside a class");
+
+    Location loc = funcOp.getLoc();
+    auto className = funcOp.getSymNameAttr().str() + "Class";
+    ClassOp newClassOp = rewriter.create<ClassOp>(loc, className);
+    Block *classBody = rewriter.createBlock(&newClassOp.getBody());
+
+    // Declare one field per function argument.
+    rewriter.setInsertionPointToStart(classBody);
+    auto argAttrs = funcOp.getArgAttrs();
+    SmallVector<std::pair<StringAttr, Type>> fields;
+    fields.reserve(funcOp.getNumArguments());
+    for (auto [idx, arg] : llvm::enumerate(funcOp.getArguments())) {
+      StringAttr fieldName =
+          rewriter.getStringAttr("fieldName" + std::to_string(idx));
+
+      // Forward the argument's attribute dictionary. An empty dictionary is
+      // treated like a missing one so the field stays attribute-free.
+      Attribute fieldAttr = nullptr;
+      if (argAttrs && idx < argAttrs->size()) {
+        Attribute argAttr = (*argAttrs)[idx];
+        auto dict = dyn_cast<DictionaryAttr>(argAttr);
+        if (!dict || !dict.empty())
+          fieldAttr = argAttr;
+      }
+
+      Type fieldType = arg.getType();
+      rewriter.create<emitc::FieldOp>(loc, fieldName, TypeAttr::get(fieldType),
+                                      fieldAttr);
+      fields.push_back({fieldName, fieldType});
+    }
+
+    // `execute` keeps the result types of the function but takes no arguments.
+    rewriter.setInsertionPointToEnd(classBody);
+    FunctionType executeType = rewriter.getFunctionType(
+        TypeRange(), funcOp.getFunctionType().getResults());
+    FuncOp newFuncOp =
+        rewriter.create<emitc::FuncOp>(loc, "execute", executeType);
+
+    // Move the original body into `execute`; the entry block still carries
+    // the old arguments at this point.
+    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
+                                newFuncOp.getBody().end());
+    Block &entryBlock = newFuncOp.getBody().front();
+
+    // Materialize one get_field per field at the top of the body, redirect the
+    // uses of the matching argument to it, then drop the now unused arguments.
+    rewriter.setInsertionPointToStart(&entryBlock);
+    for (auto [oldArg, field] :
+         llvm::zip(entryBlock.getArguments(), fields)) {
+      auto getFieldOp =
+          rewriter.create<emitc::GetFieldOp>(loc, field.second, field.first);
+      rewriter.replaceAllUsesWith(oldArg, getFieldOp.getResult());
+    }
+    entryBlock.eraseArguments(0, entryBlock.getNumArguments());
+
+    rewriter.replaceOp(funcOp, newClassOp);
+    return success();
+  }
+
+private:
+  /// Value of the pass's `named-attribute` option, stored by value so that it
+  /// cannot dangle. It is not consulted when naming fields.
+  [[maybe_unused]] std::string attributeName;
+};
+} // namespace
+
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
+                                       StringRef namedAttribute) {
+  patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
+}
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir new file mode 100644 index 0000000..c67a0c1 --- /dev/null +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -0,0 +1,46 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s
+
+// Verifies that @model is replaced by class @modelClass: each argument becomes
+// an emitc.field that keeps the argument's attribute dictionary, and the body
+// moves into an argument-less @execute that starts with one get_field per
+// field. The expected field names are fieldName<N> even though
+// `named-attribute=emitc.name_hint` is passed on the RUN line.
+
+module attributes { } {
+  emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"},
+   %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
+   %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
+    %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+    %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %2 = load %1 : <f32>
+    %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %4 = load %3 : <f32>
+    %5 = add %2, %4 : (f32, f32) -> f32
+    %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    assign %5 : f32 to %6 : <f32>
+    return
+  }
+}
+
+
+// CHECK: module {
+// CHECK-NEXT:   emitc.class @modelClass {
+// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
+// CHECK-NEXT:     emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
+// CHECK-NEXT:     emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
+// CHECK-NEXT:     emitc.func @execute() {
+// CHECK-NEXT:       get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @fieldName1 : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @fieldName2 : !emitc.array<1xf32>
+// CHECK-NEXT:       "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       load {{.*}} : <f32>
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       load {{.*}} : <f32>
+// CHECK-NEXT:       add {{.*}}, {{.*}} : (f32, f32) -> f32
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       assign {{.*}} : f32 to {{.*}} : <f32>
+// CHECK-NEXT:       return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir new file mode 100644 index 0000000..92ed20c --- /dev/null +++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
+
+// Verifies the pass on a function whose argument carries no attributes: the
+// expected emitc.field is printed without an attribute dictionary.
+
+emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
+  emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
+  emitc.return
+}
+
+// CHECK: module {
+// CHECK-NEXT:   emitc.class @fooClass {
+// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:     emitc.func @execute() {
+// CHECK-NEXT:       %0 = get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:       call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
+// CHECK-NEXT:       return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }