[mlir][EmitC] Add pass to wrap a func in class (#141158)

Goal: Enable using C++ classes to AOT compile models for MLGO.

This commit introduces a transformation pass that converts standalone
`emitc.func` operations into `emitc.class` structures to support
class-based C++ code generation for MLGO.

Transformation details:
- Wrap `emitc.func @func_name` into `emitc.class @func_nameClass`
- Converts function arguments to class fields with preserved attributes
- Transforms function body into an `execute()` method with no arguments
- Replaces argument references with `get_field` operations

Before: emitc.func @Model(%arg0, %arg1, %arg2) with direct argument
access
After: emitc.class with fields and execute() method using get_field
operations

This enables generating C++ classes that can be instantiated and
executed as self-contained model objects for AOT compilation workflows.
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 9ecdb74f..91ee899 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1593,4 +1593,90 @@
   let hasVerifier = 1;
 }
 
+def EmitC_ClassOp
+    : EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove,
+                         OpAsmOpInterface, SymbolTable,
+                         Symbol]#GraphRegionNoTerminator.traits> {
+  let summary =
+      "Represents a C++ class definition, encapsulating fields and methods.";
+
+  let description = [{
+    The `emitc.class` operation defines a C++ class, acting as a container
+    for its data fields (`emitc.field`) and methods (`emitc.func`).
+    It creates a distinct scope, isolating its contents from the surrounding
+    MLIR region, similar to how C++ classes encapsulate their internals.
+
+    Example:
+
+    ```mlir
+    emitc.class @modelClass {
+      emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
+      emitc.func @execute() {
+        %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+        %1 = get_field @fieldName0 : !emitc.array<1xf32>
+        %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        return
+      }
+    }
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name);
+
+  let regions = (region AnyRegion:$body);
+
+  let extraClassDeclaration = [{
+    // Returns the body block containing class members and methods.
+    Block &getBlock();
+  }];
+
+  // The declarative format below provides the parser and printer, so no
+  // custom assembly format is requested.
+  let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }];
+}
+
+def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
+  let summary = "A field within a class";
+  let description = [{
+    The `emitc.field` operation declares a named field within an `emitc.class`
+    operation. The field's type must be an EmitC type. 
+
+    Example:
+
+    ```mlir
+    // Example with an attribute:
+    emitc.field @fieldName0 : !emitc.array<1xf32>  {emitc.opaque = "another_feature"}
+    // Example with no attribute:
+    emitc.field @fieldName0 : !emitc.array<1xf32>
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
+      OptionalAttr<AnyAttr>:$attrs);
+
+  let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
+
+  let hasVerifier = 1;
+}
+
+def EmitC_GetFieldOp
+    : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods<
+                                       SymbolUserOpInterface>]> {
+  let summary = "Obtain access to a field within a class instance";
+  let description = [{
+     The `emitc.get_field` operation accesses a named field of the
+     enclosing class; its result type must match the field's type.
+
+     Example:
+
+     ```mlir
+     %0 = get_field @fieldName0 : !emitc.array<1xf32>
+     ```
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$field_name);
+  let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
+  let assemblyFormat = "$field_name `:` type($result) attr-dict";
+}
+
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
index 5a103f1..1af4aa0 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -15,6 +15,7 @@
 namespace emitc {
 
 #define GEN_PASS_DECL_FORMEXPRESSIONSPASS
+#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
 #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index f46b705..74c4913 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,4 +20,42 @@
   let dependentDialects = ["emitc::EmitCDialect"];
 }
 
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+  let summary = "Wrap functions in classes, using arguments as fields.";
+  let description = [{
+    This pass transforms `emitc.func` operations into `emitc.class` operations.
+    Function arguments become fields of the class, and the function body is moved
+    to a new `execute` method within the class.
+    If the corresponding function argument has attributes (accessed via `argAttrs`), 
+    these attributes are attached to the field operation. 
+    Otherwise, the field is created without additional attributes.
+
+    Example:
+    
+    ```mlir
+    emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } {
+      %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+      %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+      return
+    }
+    // becomes 
+    emitc.class @modelClass {
+      emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
+      emitc.func @execute() {
+        %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+        %1 = get_field @fieldName0 : !emitc.array<1xf32>
+        %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        return
+      }
+    }
+    ```
+  }];
+  let dependentDialects = ["emitc::EmitCDialect"];
+  let options = [Option<
+      "namedAttribute", "named-attribute", "std::string",
+      /*default=*/"",
+      "Attribute key used to extract field names from function argument's "
+      "dictionary attributes">];
+}
+
 #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 2574acd..a4e8fe1 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,6 +28,10 @@
 /// Populates `patterns` with expression-related patterns.
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with the pattern wrapping `emitc.func` ops in classes.
+void populateFuncPatterns(RewritePatternSet &patterns,
+                          StringRef namedAttribute);
+
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e602210..d17c4af 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1401,6 +1401,49 @@
 }
 
 //===----------------------------------------------------------------------===//
+// FieldOp
+//===----------------------------------------------------------------------===//
+LogicalResult FieldOp::verify() {
+  // The declared field type must be one that EmitC supports.
+  if (!isSupportedEmitCType(getType()))
+    return emitOpError("expected valid emitc type");
+
+  // A field is only valid as a direct child of an emitc.class.
+  Operation *parentOp = getOperation()->getParentOp();
+  if (!parentOp || !isa<emitc::ClassOp>(parentOp))
+    return emitOpError("field must be nested within an emitc.class operation");
+
+  StringAttr symName = getSymNameAttr();
+  if (!symName || symName.getValue().empty())
+    return emitOpError("field must have a non-empty symbol name");
+
+  // The optional `attrs` attribute is accepted as-is; nothing to check.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetFieldOp
+//===----------------------------------------------------------------------===//
+LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Resolve the referenced symbol to a field reachable from this op.
+  FlatSymbolRefAttr fieldRef = getFieldNameAttr();
+  auto field = symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldRef);
+  if (!field)
+    return emitOpError("field '") << fieldRef << "' not found in the class";
+
+  // The result must carry exactly the type the field was declared with;
+  // anything else is reported with both types spelled out.
+  Type resultType = getResult().getType();
+  Type declaredType = field.getType();
+  if (declaredType == resultType)
+    return success();
+
+  return emitOpError("result type ")
+         << resultType << " does not match field '" << fieldRef << "' type "
+         << declaredType;
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
index 19b80b22..baf67af 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   Transforms.cpp
   FormExpressions.cpp
   TypeConversions.cpp
+  WrapFuncInClass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
new file mode 100644
index 0000000..17d436f
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -0,0 +1,112 @@
+//===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace emitc;
+
+namespace mlir {
+namespace emitc {
+#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+/// Pass driver: walks the root operation applying the func-wrapping pattern.
+struct WrapFuncInClassPass
+    : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
+
+  void runOnOperation() override {
+    // `namedAttribute` is the pass option forwarded to the pattern.
+    RewritePatternSet wrapPatterns(&getContext());
+    populateFuncPatterns(wrapPatterns, namedAttribute);
+    walkAndApplyPatterns(getOperation(), std::move(wrapPatterns));
+  }
+};
+
+} // namespace
+} // namespace emitc
+} // namespace mlir
+
+/// Rewrites an `emitc.func` into an `emitc.class` named `<func>Class` whose
+/// fields mirror the function arguments and whose `execute` method holds the
+/// original body, with every argument use replaced by an `emitc.get_field`.
+class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+public:
+  WrapFuncInClass(MLIRContext *context, StringRef attrName)
+      : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
+
+  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    // A declaration has no body to move; `front()` below would be invalid.
+    if (funcOp.isExternal())
+      return rewriter.notifyMatchFailure(funcOp, "function has no body");
+
+    Location loc = funcOp.getLoc();
+    auto className = funcOp.getSymNameAttr().str() + "Class";
+    ClassOp newClassOp = rewriter.create<ClassOp>(loc, className);
+    rewriter.createBlock(&newClassOp.getBody());
+    rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
+
+    // One field per argument, named by position; the argument's attribute
+    // entry (if present) is attached to the field.
+    SmallVector<std::pair<StringAttr, TypeAttr>> fields;
+    auto argAttrs = funcOp.getArgAttrs();
+    for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
+      Attribute argAttr = nullptr;
+      if (argAttrs && idx < argAttrs->size())
+        argAttr = (*argAttrs)[idx];
+      StringAttr fieldName =
+          rewriter.getStringAttr("fieldName" + std::to_string(idx));
+      TypeAttr typeAttr = TypeAttr::get(val.getType());
+      fields.push_back({fieldName, typeAttr});
+      rewriter.create<emitc::FieldOp>(loc, fieldName, typeAttr, argAttr);
+    }
+
+    // `execute` starts with the original signature; its arguments are
+    // erased below once their uses are redirected to the fields.
+    rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
+    FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
+        loc, "execute", funcOp.getFunctionType());
+    newFuncOp.getBody().takeBody(funcOp.getBody());
+
+    rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
+    SmallVector<Value> newArguments;
+    newArguments.reserve(fields.size());
+    for (auto &[fieldName, attr] : fields)
+      newArguments.push_back(
+          rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName));
+
+    for (auto [oldArg, newArg] :
+         llvm::zip(newFuncOp.getArguments(), newArguments))
+      rewriter.replaceAllUsesWith(oldArg, newArg);
+    llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
+    if (failed(newFuncOp.eraseArguments(argsToErase)))
+      newFuncOp->emitOpError("failed to erase all arguments using BitVector");
+
+    rewriter.replaceOp(funcOp, newClassOp);
+    return success();
+  }
+
+private:
+  StringRef attributeName;
+};
+
+/// Registers WrapFuncInClass; `namedAttribute` is only stored by the pattern.
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
+                                       StringRef namedAttribute) {
+  patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
+}
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
new file mode 100644
index 0000000..c67a0c1
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s
+
+module attributes { } {
+  emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"}, 
+   %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
+   %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
+    %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+    %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %2 = load %1 : <f32>
+    %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %4 = load %3 : <f32>
+    %5 = add %2, %4 : (f32, f32) -> f32
+    %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    assign %5 : f32 to %6 : <f32>
+    return
+  }
+}
+
+
+// CHECK: module {
+// CHECK-NEXT:   emitc.class @modelClass {
+// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
+// CHECK-NEXT:     emitc.field @fieldName1 : !emitc.array<1xf32>  {emitc.name_hint = "some_feature"}
+// CHECK-NEXT:     emitc.field @fieldName2 : !emitc.array<1xf32>  {emitc.name_hint = "output_0"}
+// CHECK-NEXT:     emitc.func @execute() {
+// CHECK-NEXT:       get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @fieldName1 : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @fieldName2 : !emitc.array<1xf32>
+// CHECK-NEXT:       "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       load {{.*}} : <f32>
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       load {{.*}} : <f32>
+// CHECK-NEXT:       add {{.*}}, {{.*}} : (f32, f32) -> f32
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       assign {{.*}} : f32 to {{.*}} : <f32>
+// CHECK-NEXT:       return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
new file mode 100644
index 0000000..92ed20c
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
+
+emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
+  emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
+  emitc.return
+}
+
+// CHECK: module {
+// CHECK-NEXT:   emitc.class @fooClass {
+// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:     emitc.func @execute() {
+// CHECK-NEXT:       %0 = get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:       call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
+// CHECK-NEXT:       return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }