Defines new PDLInterp operations needed for multi-root matching in PDL.

This is commit 1 of 4 for the multi-root matching in PDL, discussed in https://llvm.discourse.group/t/rfc-multi-root-pdl-patterns-for-kernel-matching/4148 (topic flagged for review).

These operations are:
* pdl.get_accepting_ops: Returns a list of operations accepting the given value or a range of values at the specified position. Thus if there are two operations `%op1 = "foo"(%val)` and `%op2 = "bar"(%val)` accepting a value at position 0, `%ops = pdl_interp.get_accepting_ops of %val : !pdl.value at 0` will return both of them. This allows us to traverse upwards from a value to operations accepting the value.
* pdl.choose_op: Iteratively chooses one operation from a range of operations. Therefore, writing `%op = pdl_interp.choose_op from %ops` in the example above will select either `%op1`or `%op2`.

Testing: Added the corresponding test cases to mlir/test/Dialect/PDLInterp/ops.mlir.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D108543

GitOrigin-RevId: 842b6861c01cc6961f170d58332ecf0fb0232441
diff --git a/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
index ff9b3dd..87033ed 100644
--- a/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
+++ b/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td
@@ -371,6 +371,29 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::ContinueOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ContinueOp
+    : PDLInterp_Op<"continue", [NoSideEffect, HasParent<"ForEachOp">,
+                               Terminator]> {
+  let summary = "Breaks the current iteration";
+  let description = [{
+    `pdl_interp.continue` operation breaks the current iteration within the
+    `pdl_interp.foreach` region and continues with the next iteration from
+    the beginning of the region.
+
+    Example:
+
+    ```mlir
+    pdl_interp.continue
+    ```
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::CreateAttributeOp
 //===----------------------------------------------------------------------===//
 
@@ -514,6 +537,42 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::ExtractOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ExtractOp
+    : PDLInterp_Op<"extract", [NoSideEffect,
+     TypesMatchWith<
+        "`range` is a PDL range whose element type matches type of `result`",
+        "result", "range", "pdl::RangeType::get($_self)">]> {
+  let summary = "Extract the item at the specified index in a range";
+  let description = [{
+    `pdl_interp.extract` operations are used to extract an item from a range
+    at the specified index. If the index is out of range, returns null.
+
+    Example:
+
+    ```mlir
+    // Extract the value at index 1 from a range of values.
+    %ops = pdl_interp.extract 1 of %values : !pdl.value
+    ```
+  }];
+
+  let arguments = (ins PDL_RangeOf<PDL_AnyType>:$range,
+                       Confined<I32Attr, [IntNonNegative]>:$index);
+  let results = (outs PDL_AnyType:$result);
+  let assemblyFormat = "$index `of` $range `:` type($result) attr-dict";
+
+  let builders = [
+    OpBuilder<(ins "Value":$range, "unsigned":$index), [{
+      build($_builder, $_state,
+            range.getType().cast<pdl::RangeType>().getElementType(),
+            range, index);
+    }]>,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::FinalizeOp
 //===----------------------------------------------------------------------===//
 
@@ -534,6 +593,48 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::ForEachOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_ForEachOp
+    : PDLInterp_Op<"foreach", [Terminator]> {
+  let summary = "Iterates over a range of values or ranges";
+  let description = [{
+    `pdl_interp.foreach` iteratively selects an element from a range of values
+    and executes the region until pdl.continue is reached.
+
+    In the bytecode interpreter, this operation is implemented by looping over
+    the values and, for each selection, running the bytecode until we reach
+    pdl.continue. This may result in multiple matches being reported. Note
+    that the input range is mutated (popped from).
+
+    Example:
+
+    ```mlir
+    pdl_interp.foreach %op : !pdl.operation in %ops {
+      pdl_interp.continue
+    } -> ^next
+    ```
+  }];
+
+  let arguments = (ins PDL_RangeOf<PDL_AnyType>:$values);
+  let regions = (region AnyRegion:$region);
+  let successors = (successor AnySuccessor:$successor);
+
+  let builders = [
+    OpBuilder<(ins "Value":$range, "Block *":$successor, "bool":$initLoop)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Returns the loop variable.
+    BlockArgument getLoopVariable() { return region().getArgument(0); }
+  }];
+  let parser = [{ return ::parseForEachOp(parser, result); }];
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::GetAttributeOp
 //===----------------------------------------------------------------------===//
 
@@ -751,6 +852,42 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::GetUsersOp
+//===----------------------------------------------------------------------===//
+
+def PDLInterp_GetUsersOp
+    : PDLInterp_Op<"get_users", [NoSideEffect]> {
+  let summary = "Get the users of a `Value`";
+  let description = [{
+    `pdl_interp.get_users` extracts the users that accept this value. In the
+    case of a range, the union of users of the all the values are returned,
+    similarly to ResultRange::getUsers.
+
+    Example:
+
+    ```mlir
+    // Get all the users of a single value.
+    %ops = pdl_interp.get_users of %value : !pdl.value
+
+    // Get all the users of the first value in a range.
+    %ops = pdl_interp.get_users of %values : !pdl.range<value>
+    ```
+  }];
+
+  let arguments = (ins PDL_InstOrRangeOf<PDL_Value>:$value);
+  let results = (outs PDL_RangeOf<PDL_Operation>:$operations);
+  let assemblyFormat = "`of` $value `:` type($value) attr-dict";
+
+  let builders = [
+    OpBuilder<(ins "Value":$value), [{
+      build($_builder, $_state,
+            pdl::RangeType::get($_builder.getType<pdl::OperationType>()),
+            value);
+    }]>,
+  ];
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::GetValueTypeOp
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/Dialect/PDLInterp/IR/PDLInterp.cpp b/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
index d149ef5..5a14fff 100644
--- a/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
+++ b/lib/Dialect/PDLInterp/IR/PDLInterp.cpp
@@ -66,6 +66,85 @@
 }
 
 //===----------------------------------------------------------------------===//
+// pdl_interp::ForEachOp
+//===----------------------------------------------------------------------===//
+
+void ForEachOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
+                      Value range, Block *successor, bool initLoop) {
+  build(builder, state, range, successor);
+  if (initLoop) {
+    // Create the block and the loop variable.
+    auto range_type = range.getType().cast<pdl::RangeType>();
+    state.regions.front()->emplaceBlock();
+    state.regions.front()->addArgument(range_type.getElementType());
+  }
+}
+
+static ParseResult parseForEachOp(OpAsmParser &parser, OperationState &result) {
+  // Parse the loop variable followed by type.
+  OpAsmParser::OperandType loopVariable;
+  Type loopVariableType;
+  if (parser.parseRegionArgument(loopVariable) ||
+      parser.parseColonType(loopVariableType))
+    return failure();
+
+  // Parse the "in" keyword.
+  if (parser.parseKeyword("in", " after loop variable"))
+    return failure();
+
+  // Parse the operand (value range).
+  OpAsmParser::OperandType operandInfo;
+  if (parser.parseOperand(operandInfo))
+    return failure();
+
+  // Resolve the operand.
+  Type rangeType = pdl::RangeType::get(loopVariableType);
+  if (parser.resolveOperand(operandInfo, rangeType, result.operands))
+    return failure();
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, {loopVariable}, {loopVariableType}))
+    return failure();
+
+  // Parse the attribute dictionary.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Parse the successor.
+  Block *successor;
+  if (parser.parseArrow() || parser.parseSuccessor(successor))
+    return failure();
+  result.addSuccessors(successor);
+
+  return success();
+}
+
+static void print(OpAsmPrinter &p, ForEachOp op) {
+  BlockArgument arg = op.getLoopVariable();
+  p << ' ' << arg << " : " << arg.getType() << " in " << op.values();
+  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
+  p.printOptionalAttrDict(op->getAttrs());
+  p << " -> ";
+  p.printSuccessor(op.successor());
+}
+
+static LogicalResult verify(ForEachOp op) {
+  // Verify that the operation has exactly one argument.
+  if (op.region().getNumArguments() != 1)
+    return op.emitOpError("requires exactly one argument");
+
+  // Verify that the loop variable and the operand (value range)
+  // have compatible types.
+  BlockArgument arg = op.getLoopVariable();
+  Type rangeType = pdl::RangeType::get(arg.getType());
+  if (rangeType != op.values().getType())
+    return op.emitOpError("operand must be a range of loop variable type");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // pdl_interp::GetValueTypeOp
 //===----------------------------------------------------------------------===//
 
diff --git a/test/Dialect/PDLInterp/ops.mlir b/test/Dialect/PDLInterp/ops.mlir
index 072dfad..8bf559b 100644
--- a/test/Dialect/PDLInterp/ops.mlir
+++ b/test/Dialect/PDLInterp/ops.mlir
@@ -23,3 +23,46 @@
 
   pdl_interp.finalize
 }
+
+// -----
+
+func @extract(%attrs : !pdl.range<attribute>, %ops : !pdl.range<operation>, %types : !pdl.range<type>, %vals: !pdl.range<value>) {
+  // attribute at index 0
+  %attr = pdl_interp.extract 0 of %attrs : !pdl.attribute
+
+  // operation at index 1
+  %op = pdl_interp.extract 1 of %ops : !pdl.operation
+
+  // type at index 2
+  %type = pdl_interp.extract 2 of %types : !pdl.type
+
+  // value at index 3
+  %val = pdl_interp.extract 3 of %vals : !pdl.value
+
+  pdl_interp.finalize
+}
+
+// -----
+
+func @foreach(%ops: !pdl.range<operation>) {
+  // iterate over a range of operations
+  pdl_interp.foreach %op : !pdl.operation in %ops {
+    %val = pdl_interp.get_result 0 of %op
+    pdl_interp.continue
+  } -> ^end
+
+  ^end:
+    pdl_interp.finalize
+}
+
+// -----
+
+func @users(%value: !pdl.value, %values: !pdl.range<value>) {
+  // all the users of a single value
+  %ops1 = pdl_interp.get_users of %value : !pdl.value
+
+  // all the users of all the values in a range
+  %ops2 = pdl_interp.get_users of %values : !pdl.range<value>
+
+  pdl_interp.finalize
+}