[mlir][python] Normalize asm-printing IR behavior.

While working on an integration, I found a lot of inconsistencies on IR printing and verification. It turns out that we were:
  * Only doing "soft fail" verification on IR printing of Operation, not of a Module.
  * Failed verification was interacting badly with binary=True IR printing (causing a TypeError trying to pass an `str` to a `bytes` based handle).
  * For systematic integrations, it is often desirable to control verification yourself so that you can explicitly handle errors.

This patch:
  * Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` vs having a shortcut implementation.
  * Fixes soft fail in the presence of binary=True (and adds an additional happy path test case to make sure the binary functionality works).
  * Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it.

It turns out that we had a number of tests which were generating illegal IR but it wasn't being caught because they were doing a print on the `Module` vs operation. All except two were trivially fixed:
  * linalg/ops.py : Had two tests for direct constructing a Matmul incorrectly. Fixing them made them just like the next two tests so just deleted (no need to test the verifier only at this level).
  * linalg/opdsl/emit_structured_generic.py : Hand coded conv and pooling tests appear to be using illegal shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. Will get someone who owns that to fix it properly in a followup (would also be nice to break this file up into multiple test modules as it is hard to tell exactly what is failing).

Notes to downstreams:
  * If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause. To temporarily revert to prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`.

Differential Revision: https://reviews.llvm.org/D114680

GitOrigin-RevId: ace1d0ad3dc43e28715cbe2f3e0a5a76578bda9f
diff --git a/lib/Bindings/Python/IRCore.cpp b/lib/Bindings/Python/IRCore.cpp
index 4c25fd4..c70cfc5 100644
--- a/lib/Bindings/Python/IRCore.cpp
+++ b/lib/Bindings/Python/IRCore.cpp
@@ -93,6 +93,13 @@
   use_local_Scope: Whether to print in a way that is more optimized for
     multi-threaded access but may not be consistent with how the overall
     module prints.
+  assume_verified: By default, if not printing generic form, the verifier
+    will be run and if it fails, generic form will be printed with a comment
+    about failed verification. While a reasonable default for interactive use,
+    for systematic use, it is often better for the caller to verify explicitly
+    and report failures in a more robust fashion. Set this to True if doing this
+    in order to avoid running a redundant verification. If the IR is actually
+    invalid, behavior is undefined.
 )";
 
 static const char kOperationGetAsmDocstring[] =
@@ -828,14 +835,21 @@
 void PyOperationBase::print(py::object fileObject, bool binary,
                             llvm::Optional<int64_t> largeElementsLimit,
                             bool enableDebugInfo, bool prettyDebugInfo,
-                            bool printGenericOpForm, bool useLocalScope) {
+                            bool printGenericOpForm, bool useLocalScope,
+                            bool assumeVerified) {
   PyOperation &operation = getOperation();
   operation.checkValid();
   if (fileObject.is_none())
     fileObject = py::module::import("sys").attr("stdout");
 
-  if (!printGenericOpForm && !mlirOperationVerify(operation)) {
-    fileObject.attr("write")("// Verification failed, printing generic form\n");
+  if (!assumeVerified && !printGenericOpForm &&
+      !mlirOperationVerify(operation)) {
+    std::string message("// Verification failed, printing generic form\n");
+    if (binary) {
+      fileObject.attr("write")(py::bytes(message));
+    } else {
+      fileObject.attr("write")(py::str(message));
+    }
     printGenericOpForm = true;
   }
 
@@ -857,8 +871,8 @@
 py::object PyOperationBase::getAsm(bool binary,
                                    llvm::Optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
-                                   bool printGenericOpForm,
-                                   bool useLocalScope) {
+                                   bool printGenericOpForm, bool useLocalScope,
+                                   bool assumeVerified) {
   py::object fileObject;
   if (binary) {
     fileObject = py::module::import("io").attr("BytesIO")();
@@ -870,7 +884,8 @@
         /*enableDebugInfo=*/enableDebugInfo,
         /*prettyDebugInfo=*/prettyDebugInfo,
         /*printGenericOpForm=*/printGenericOpForm,
-        /*useLocalScope=*/useLocalScope);
+        /*useLocalScope=*/useLocalScope,
+        /*assumeVerified=*/assumeVerified);
 
   return fileObject.attr("getvalue")();
 }
@@ -2149,12 +2164,9 @@
           kDumpDocstring)
       .def(
           "__str__",
-          [](PyModule &self) {
-            MlirOperation operation = mlirModuleGetOperation(self.get());
-            PyPrintAccumulator printAccum;
-            mlirOperationPrint(operation, printAccum.getCallback(),
-                               printAccum.getUserData());
-            return printAccum.join();
+          [](py::object self) {
+            // Defer to the operation's __str__.
+            return self.attr("operation").attr("__str__")();
           },
           kOperationStrDunderDocstring);
 
@@ -2234,7 +2246,8 @@
                                /*enableDebugInfo=*/false,
                                /*prettyDebugInfo=*/false,
                                /*printGenericOpForm=*/false,
-                               /*useLocalScope=*/false);
+                               /*useLocalScope=*/false,
+                               /*assumeVerified=*/false);
           },
           "Returns the assembly form of the operation.")
       .def("print", &PyOperationBase::print,
@@ -2244,7 +2257,8 @@
            py::arg("enable_debug_info") = false,
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
-           py::arg("use_local_scope") = false, kOperationPrintDocstring)
+           py::arg("use_local_scope") = false,
+           py::arg("assume_verified") = false, kOperationPrintDocstring)
       .def("get_asm", &PyOperationBase::getAsm,
            // Careful: Lots of arguments must match up with get_asm method.
            py::arg("binary") = false,
@@ -2252,7 +2266,8 @@
            py::arg("enable_debug_info") = false,
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
-           py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
+           py::arg("use_local_scope") = false,
+           py::arg("assume_verified") = false, kOperationGetAsmDocstring)
       .def(
           "verify",
           [](PyOperationBase &self) {
diff --git a/lib/Bindings/Python/IRModule.h b/lib/Bindings/Python/IRModule.h
index eb5c238..dc024a2 100644
--- a/lib/Bindings/Python/IRModule.h
+++ b/lib/Bindings/Python/IRModule.h
@@ -394,11 +394,13 @@
   /// Implements the bound 'print' method and helps with others.
   void print(pybind11::object fileObject, bool binary,
              llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
-             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
+             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
+             bool assumeVerified);
   pybind11::object getAsm(bool binary,
                           llvm::Optional<int64_t> largeElementsLimit,
                           bool enableDebugInfo, bool prettyDebugInfo,
-                          bool printGenericOpForm, bool useLocalScope);
+                          bool printGenericOpForm, bool useLocalScope,
+                          bool assumeVerified);
 
   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
diff --git a/test/python/dialects/builtin.py b/test/python/dialects/builtin.py
index 8f3a041..7caf5b5 100644
--- a/test/python/dialects/builtin.py
+++ b/test/python/dialects/builtin.py
@@ -175,7 +175,8 @@
 # CHECK-LABEL: TEST: testFuncArgumentAccess
 @run
 def testFuncArgumentAccess():
-  with Context(), Location.unknown():
+  with Context() as ctx, Location.unknown():
+    ctx.allow_unregistered_dialects = True
     module = Module.create()
     f32 = F32Type.get()
     f64 = F64Type.get()
@@ -185,38 +186,38 @@
         std.ReturnOp(func.arguments)
       func.arg_attrs = ArrayAttr.get([
           DictAttr.get({
-              "foo": StringAttr.get("bar"),
-              "baz": UnitAttr.get()
+              "custom_dialect.foo": StringAttr.get("bar"),
+              "custom_dialect.baz": UnitAttr.get()
           }),
-          DictAttr.get({"qux": ArrayAttr.get([])})
+          DictAttr.get({"custom_dialect.qux": ArrayAttr.get([])})
       ])
       func.result_attrs = ArrayAttr.get([
-          DictAttr.get({"res1": FloatAttr.get(f32, 42.0)}),
-          DictAttr.get({"res2": FloatAttr.get(f64, 256.0)})
+          DictAttr.get({"custom_dialect.res1": FloatAttr.get(f32, 42.0)}),
+          DictAttr.get({"custom_dialect.res2": FloatAttr.get(f64, 256.0)})
       ])
 
       other = builtin.FuncOp("other_func", ([f32, f32], []))
       with InsertionPoint(other.add_entry_block()):
         std.ReturnOp([])
       other.arg_attrs = [
-          DictAttr.get({"foo": StringAttr.get("qux")}),
+          DictAttr.get({"custom_dialect.foo": StringAttr.get("qux")}),
           DictAttr.get()
       ]
 
-  # CHECK: [{baz, foo = "bar"}, {qux = []}]
+  # CHECK: [{custom_dialect.baz, custom_dialect.foo = "bar"}, {custom_dialect.qux = []}]
   print(func.arg_attrs)
 
-  # CHECK: [{res1 = 4.200000e+01 : f32}, {res2 = 2.560000e+02 : f64}]
+  # CHECK: [{custom_dialect.res1 = 4.200000e+01 : f32}, {custom_dialect.res2 = 2.560000e+02 : f64}]
   print(func.result_attrs)
 
   # CHECK: func @some_func(
-  # CHECK: %[[ARG0:.*]]: f32 {baz, foo = "bar"},
-  # CHECK: %[[ARG1:.*]]: f32 {qux = []}) ->
-  # CHECK: f32 {res1 = 4.200000e+01 : f32},
-  # CHECK: f32 {res2 = 2.560000e+02 : f64})
+  # CHECK: %[[ARG0:.*]]: f32 {custom_dialect.baz, custom_dialect.foo = "bar"},
+  # CHECK: %[[ARG1:.*]]: f32 {custom_dialect.qux = []}) ->
+  # CHECK: f32 {custom_dialect.res1 = 4.200000e+01 : f32},
+  # CHECK: f32 {custom_dialect.res2 = 2.560000e+02 : f64})
   # CHECK: return %[[ARG0]], %[[ARG1]] : f32, f32
   #
   # CHECK: func @other_func(
-  # CHECK: %{{.*}}: f32 {foo = "qux"},
+  # CHECK: %{{.*}}: f32 {custom_dialect.foo = "qux"},
   # CHECK: %{{.*}}: f32)
   print(module)
diff --git a/test/python/dialects/linalg/opdsl/emit_structured_generic.py b/test/python/dialects/linalg/opdsl/emit_structured_generic.py
index d0c7427..115c227 100644
--- a/test/python/dialects/linalg/opdsl/emit_structured_generic.py
+++ b/test/python/dialects/linalg/opdsl/emit_structured_generic.py
@@ -405,4 +405,7 @@
       return non_default_op_name(input, outs=[init_result])
 
 
-print(module)
+# TODO: Fix me! Conv and pooling ops above do not verify, which was uncovered
+# when switching to more robust module verification. For now, reverting to the
+# old behavior which does not verify on module print.
+print(module.operation.get_asm(assume_verified=True))
diff --git a/test/python/dialects/linalg/ops.py b/test/python/dialects/linalg/ops.py
index e5b96c2..4f9f138 100644
--- a/test/python/dialects/linalg/ops.py
+++ b/test/python/dialects/linalg/ops.py
@@ -83,49 +83,6 @@
   print(module)
 
 
-# CHECK-LABEL: TEST: testStructuredOpOnTensors
-@run
-def testStructuredOpOnTensors():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    tensor_type = RankedTensorType.get((2, 3, 4), f32)
-    with InsertionPoint(module.body):
-      func = builtin.FuncOp(
-          name="matmul_test",
-          type=FunctionType.get(
-              inputs=[tensor_type, tensor_type], results=[tensor_type]))
-      with InsertionPoint(func.add_entry_block()):
-        lhs, rhs = func.entry_block.arguments
-        result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result
-        std.ReturnOp([result])
-
-  # CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-  print(module)
-
-
-# CHECK-LABEL: TEST: testStructuredOpOnBuffers
-@run
-def testStructuredOpOnBuffers():
-  with Context() as ctx, Location.unknown():
-    module = Module.create()
-    f32 = F32Type.get()
-    memref_type = MemRefType.get((2, 3, 4), f32)
-    with InsertionPoint(module.body):
-      func = builtin.FuncOp(
-          name="matmul_test",
-          type=FunctionType.get(
-              inputs=[memref_type, memref_type, memref_type], results=[]))
-      with InsertionPoint(func.add_entry_block()):
-        lhs, rhs, result = func.entry_block.arguments
-        # TODO: prperly hook up the region.
-        linalg.MatmulOp([lhs, rhs], outputs=[result])
-        std.ReturnOp([])
-
-  # CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
-  print(module)
-
-
 # CHECK-LABEL: TEST: testNamedStructuredOpCustomForm
 @run
 def testNamedStructuredOpCustomForm():
diff --git a/test/python/dialects/shape.py b/test/python/dialects/shape.py
index 7c1c5d6..a798b85 100644
--- a/test/python/dialects/shape.py
+++ b/test/python/dialects/shape.py
@@ -22,7 +22,8 @@
       @builtin.FuncOp.from_py_func(
           RankedTensorType.get((12, -1), f32))
       def const_shape_tensor(arg):
-        return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20])))
+        return shape.ConstShapeOp(
+          DenseElementsAttr.get(np.array([10, 20]), type=IndexType.get()))
 
     # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
     # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
diff --git a/test/python/dialects/std.py b/test/python/dialects/std.py
index f6e77ca..2a3b2df 100644
--- a/test/python/dialects/std.py
+++ b/test/python/dialects/std.py
@@ -78,8 +78,11 @@
 @constructAndPrintInModule
 def testFunctionCalls():
   foo = builtin.FuncOp("foo", ([], []))
+  foo.sym_visibility = StringAttr.get("private")
   bar = builtin.FuncOp("bar", ([], [IndexType.get()]))
+  bar.sym_visibility = StringAttr.get("private")
   qux = builtin.FuncOp("qux", ([], [F32Type.get()]))
+  qux.sym_visibility = StringAttr.get("private")
 
   with InsertionPoint(builtin.FuncOp("caller", ([], [])).add_entry_block()):
     std.CallOp(foo, [])
@@ -88,9 +91,9 @@
     std.ReturnOp([])
 
 
-# CHECK: func @foo()
-# CHECK: func @bar() -> index
-# CHECK: func @qux() -> f32
+# CHECK: func private @foo()
+# CHECK: func private @bar() -> index
+# CHECK: func private @qux() -> f32
 # CHECK: func @caller() {
 # CHECK:   call @foo() : () -> ()
 # CHECK:   %0 = call @bar() : () -> index
diff --git a/test/python/ir/module.py b/test/python/ir/module.py
index abddc66..76358eb 100644
--- a/test/python/ir/module.py
+++ b/test/python/ir/module.py
@@ -8,11 +8,13 @@
   f()
   gc.collect()
   assert Context._get_live_count() == 0
+  return f
 
 
 # Verify successful parse.
 # CHECK-LABEL: TEST: testParseSuccess
 # CHECK: module @successfulParse
+@run
 def testParseSuccess():
   ctx = Context()
   module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -23,12 +25,11 @@
   module.dump()  # Just outputs to stderr. Verifies that it functions.
   print(str(module))
 
-run(testParseSuccess)
-
 
 # Verify parse error.
 # CHECK-LABEL: TEST: testParseError
 # CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+@run
 def testParseError():
   ctx = Context()
   try:
@@ -38,12 +39,11 @@
   else:
     print("Exception not produced")
 
-run(testParseError)
-
 
 # Verify successful parse.
 # CHECK-LABEL: TEST: testCreateEmpty
 # CHECK: module {
+@run
 def testCreateEmpty():
   ctx = Context()
   loc = Location.unknown(ctx)
@@ -53,8 +53,6 @@
   gc.collect()
   print(str(module))
 
-run(testCreateEmpty)
-
 
 # Verify round-trip of ASM that contains unicode.
 # Note that this does not test that the print path converts unicode properly
@@ -62,6 +60,7 @@
 # CHECK-LABEL: TEST: testRoundtripUnicode
 # CHECK: func private @roundtripUnicode()
 # CHECK: foo = "\F0\9F\98\8A"
+@run
 def testRoundtripUnicode():
   ctx = Context()
   module = Module.parse(r"""
@@ -69,11 +68,28 @@
   """, ctx)
   print(str(module))
 
-run(testRoundtripUnicode)
+
+# Verify round-trip of ASM that contains unicode.
+# Note that this does not test that the print path converts unicode properly
+# because MLIR asm always normalizes it to the hex encoding.
+# CHECK-LABEL: TEST: testRoundtripBinary
+# CHECK: func private @roundtripUnicode()
+# CHECK: foo = "\F0\9F\98\8A"
+@run
+def testRoundtripBinary():
+  with Context():
+    module = Module.parse(r"""
+      func private @roundtripUnicode() attributes { foo = "😊" }
+    """)
+    binary_asm = module.operation.get_asm(binary=True)
+    assert isinstance(binary_asm, bytes)
+    module = Module.parse(binary_asm)
+    print(module)
 
 
 # Tests that module.operation works and correctly interns instances.
 # CHECK-LABEL: TEST: testModuleOperation
+@run
 def testModuleOperation():
   ctx = Context()
   module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -101,10 +117,9 @@
   assert ctx._get_live_operation_count() == 0
   assert ctx._get_live_module_count() == 0
 
-run(testModuleOperation)
-
 
 # CHECK-LABEL: TEST: testModuleCapsule
+@run
 def testModuleCapsule():
   ctx = Context()
   module = Module.parse(r"""module @successfulParse {}""", ctx)
@@ -122,5 +137,3 @@
   gc.collect()
   assert ctx._get_live_module_count() == 0
 
-
-run(testModuleCapsule)
diff --git a/test/python/ir/operation.py b/test/python/ir/operation.py
index 8771ca0..133edc2 100644
--- a/test/python/ir/operation.py
+++ b/test/python/ir/operation.py
@@ -630,21 +630,50 @@
   print(module.body.operations[2])
 
 
-# CHECK-LABEL: TEST: testPrintInvalidOperation
+def create_invalid_operation():
+  # This module has two region and is invalid verify that we fallback
+  # to the generic printer for safety.
+  op = Operation.create("builtin.module", regions=2)
+  op.regions[0].blocks.append()
+  return op
+
+# CHECK-LABEL: TEST: testInvalidOperationStrSoftFails
 @run
-def testPrintInvalidOperation():
+def testInvalidOperationStrSoftFails():
   ctx = Context()
   with Location.unknown(ctx):
-    module = Operation.create("builtin.module", regions=2)
-    # This module has two region and is invalid verify that we fallback
-    # to the generic printer for safety.
-    block = module.regions[0].blocks.append()
+    invalid_op = create_invalid_operation()
+    # Verify that we fallback to the generic printer for safety.
     # CHECK: // Verification failed, printing generic form
     # CHECK: "builtin.module"() ( {
     # CHECK: }) : () -> ()
-    print(module)
+    print(invalid_op)
     # CHECK: .verify = False
-    print(f".verify = {module.operation.verify()}")
+    print(f".verify = {invalid_op.operation.verify()}")
+
+
+# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
+@run
+def testInvalidModuleStrSoftFails():
+  ctx = Context()
+  with Location.unknown(ctx):
+    module = Module.create()
+    with InsertionPoint(module.body):
+      invalid_op = create_invalid_operation()
+    # Verify that we fallback to the generic printer for safety.
+    # CHECK: // Verification failed, printing generic form
+    print(module)
+
+
+# CHECK-LABEL: TEST: testInvalidOperationGetAsmBinarySoftFails
+@run
+def testInvalidOperationGetAsmBinarySoftFails():
+  ctx = Context()
+  with Location.unknown(ctx):
+    invalid_op = create_invalid_operation()
+    # Verify that we fallback to the generic printer for safety.
+    # CHECK: b'// Verification failed, printing generic form\n
+    print(invalid_op.get_asm(binary=True))
 
 
 # CHECK-LABEL: TEST: testCreateWithInvalidAttributes