Update fir.insert_on_range syntax to make the range more explicit (NFC)

Also replace ArrayAttr with IndexElementsAttr to model subscript dimensions.
An array of attribute is a sparse inefficient storage, with an API that
requires to unpack/repack integers at every call site.
Instead we can store dense array of integer as IndexElementsAttr.

Reviewed By: clementval, kiranchandramohan

Differential Revision: https://reviews.llvm.org/D112899

GitOrigin-RevId: 8ec0f221843c51096cf3e7a479e780be371388a8
diff --git a/include/flang/Optimizer/Dialect/FIROps.td b/include/flang/Optimizer/Dialect/FIROps.td
index 031b244..1368009 100644
--- a/include/flang/Optimizer/Dialect/FIROps.td
+++ b/include/flang/Optimizer/Dialect/FIROps.td
@@ -1971,18 +1971,18 @@
     ```mlir
       %a = fir.undefined !fir.array<10x10xf32>
       %c = arith.constant 3.0 : f32
-      %1 = fir.insert_on_range %a, %c, [0 : index, 7 : index, 0 : index, 2 : index] : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
+      %1 = fir.insert_on_range %a, %c from (0, 0) to (7, 2) : (!fir.array<10x10xf32>, f32) -> !fir.array<10x10xf32>
     ```
 
     The first 28 elements of %1, with coordinates from (0,0) to (7,2), have
     the value 3.0.
   }];
 
-  let arguments = (ins fir_SequenceType:$seq, AnyType:$val, ArrayAttr:$coor);
+  let arguments = (ins fir_SequenceType:$seq, AnyType:$val, IndexElementsAttr:$coor);
   let results = (outs fir_SequenceType);
 
   let assemblyFormat = [{
-    $seq `,` $val `,` $coor attr-dict `:` functional-type(operands, results)
+    $seq `,` $val custom<CustomRangeSubscript>($coor) attr-dict `:` functional-type(operands, results)
   }];
 
   let verifier = "return ::verify(*this);";
diff --git a/lib/Optimizer/CodeGen/CodeGen.cpp b/lib/Optimizer/CodeGen/CodeGen.cpp
index 1e1f3ee..7583d5d 100644
--- a/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -929,14 +929,16 @@
     return success();
   }
 
-  bool isFullRange(mlir::ArrayAttr indexes, fir::SequenceType seqTy) const {
+  bool isFullRange(mlir::DenseIntElementsAttr indexes,
+                   fir::SequenceType seqTy) const {
     auto extents = seqTy.getShape();
-    if (indexes.size() / 2 != extents.size())
+    if (indexes.size() / 2 != static_cast<int64_t>(extents.size()))
       return false;
+    auto cur_index = indexes.value_begin<int64_t>();
     for (unsigned i = 0; i < indexes.size(); i += 2) {
-      if (indexes[i].cast<IntegerAttr>().getInt() != 0)
+      if (*(cur_index++) != 0)
         return false;
-      if (indexes[i + 1].cast<IntegerAttr>().getInt() != extents[i / 2] - 1)
+      if (*(cur_index++) != extents[i / 2] - 1)
         return false;
     }
     return true;
@@ -1728,14 +1730,10 @@
     SmallVector<uint64_t> lBounds;
     SmallVector<uint64_t> uBounds;
 
-    // Extract integer value from the attribute
-    SmallVector<int64_t> coordinates = llvm::to_vector<4>(
-        llvm::map_range(range.coor(), [](Attribute a) -> int64_t {
-          return a.cast<IntegerAttr>().getInt();
-        }));
-
     // Unzip the upper and lower bound and convert to a row major format.
-    for (auto i = coordinates.rbegin(), e = coordinates.rend(); i != e; ++i) {
+    mlir::DenseIntElementsAttr coor = range.coor();
+    auto reversedCoor = llvm::reverse(coor.getValues<int64_t>());
+    for (auto i = reversedCoor.begin(), e = reversedCoor.end(); i != e; ++i) {
       uBounds.push_back(*i++);
       lBounds.push_back(*i);
     }
diff --git a/lib/Optimizer/Dialect/FIROps.cpp b/lib/Optimizer/Dialect/FIROps.cpp
index 262df2a..9ec3bc5 100644
--- a/lib/Optimizer/Dialect/FIROps.cpp
+++ b/lib/Optimizer/Dialect/FIROps.cpp
@@ -17,10 +17,14 @@
 #include "flang/Optimizer/Support/Utils.h"
 #include "mlir/Dialect/CommonFolders.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/ADT/TypeSwitch.h"
 
@@ -1374,16 +1378,62 @@
 // InsertOnRangeOp
 //===----------------------------------------------------------------------===//
 
+static ParseResult
+parseCustomRangeSubscript(mlir::OpAsmParser &parser,
+                          mlir::DenseIntElementsAttr &coord) {
+  llvm::SmallVector<int64_t> lbounds;
+  llvm::SmallVector<int64_t> ubounds;
+  if (parser.parseKeyword("from") ||
+      parser.parseCommaSeparatedList(
+          AsmParser::Delimiter::Paren,
+          [&] { return parser.parseInteger(lbounds.emplace_back(0)); }) ||
+      parser.parseKeyword("to") ||
+      parser.parseCommaSeparatedList(AsmParser::Delimiter::Paren, [&] {
+        return parser.parseInteger(ubounds.emplace_back(0));
+      }))
+    return failure();
+  llvm::SmallVector<int64_t> zippedBounds;
+  for (auto zip : llvm::zip(lbounds, ubounds)) {
+    zippedBounds.push_back(std::get<0>(zip));
+    zippedBounds.push_back(std::get<1>(zip));
+  }
+  coord = mlir::Builder(parser.getContext()).getIndexTensorAttr(zippedBounds);
+  return success();
+}
+
+void printCustomRangeSubscript(mlir::OpAsmPrinter &printer, InsertOnRangeOp op,
+                               mlir::DenseIntElementsAttr coord) {
+  printer << "from (";
+  auto enumerate = llvm::enumerate(coord.getValues<int64_t>());
+  // Even entries are the lower bounds.
+  llvm::interleaveComma(
+      make_filter_range(
+          enumerate,
+          [](auto indexed_value) { return indexed_value.index() % 2 == 0; }),
+      printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+  printer << ") to (";
+  // Odd entries are the upper bounds.
+  llvm::interleaveComma(
+      make_filter_range(
+          enumerate,
+          [](auto indexed_value) { return indexed_value.index() % 2 != 0; }),
+      printer, [&](auto indexed_value) { printer << indexed_value.value(); });
+  printer << ")";
+}
+
 /// Range bounds must be nonnegative, and the range must not be empty.
 static mlir::LogicalResult verify(fir::InsertOnRangeOp op) {
   if (fir::hasDynamicSize(op.seq().getType()))
     return op.emitOpError("must have constant shape and size");
-  if (op.coor().size() < 2 || op.coor().size() % 2 != 0)
+  mlir::DenseIntElementsAttr coor = op.coor();
+  if (coor.size() < 2 || coor.size() % 2 != 0)
     return op.emitOpError("has uneven number of values in ranges");
   bool rangeIsKnownToBeNonempty = false;
-  for (auto i = op.coor().end(), b = op.coor().begin(); i != b;) {
-    int64_t ub = (*--i).cast<IntegerAttr>().getInt();
-    int64_t lb = (*--i).cast<IntegerAttr>().getInt();
+  for (auto i = coor.getValues<int64_t>().end(),
+            b = coor.getValues<int64_t>().begin();
+       i != b;) {
+    int64_t ub = (*--i);
+    int64_t lb = (*--i);
     if (lb < 0 || ub < 0)
       return op.emitOpError("negative range bound");
     if (rangeIsKnownToBeNonempty)
diff --git a/test/Fir/convert-to-llvm.fir b/test/Fir/convert-to-llvm.fir
index 33b7941..1ba4e54 100644
--- a/test/Fir/convert-to-llvm.fir
+++ b/test/Fir/convert-to-llvm.fir
@@ -80,7 +80,7 @@
 fir.global internal @_QEmultiarray : !fir.array<32x32xi32> {
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index, 31 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (0, 0) to (31, 31) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -97,7 +97,7 @@
 fir.global internal @_QEmultiarray : !fir.array<32xi32> {
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32xi32>
-  %2 = fir.insert_on_range %0, %c0_i32, [5 : index, 31 : index] : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (5) to (31) : (!fir.array<32xi32>, i32) -> !fir.array<32xi32>
   fir.has_value %2 : !fir.array<32xi32>
 }
 
diff --git a/test/Fir/fir-ops.fir b/test/Fir/fir-ops.fir
index a6ae02a..631200f 100644
--- a/test/Fir/fir-ops.fir
+++ b/test/Fir/fir-ops.fir
@@ -617,10 +617,10 @@
   %c1_i32 = arith.constant 9 : i32
 
   // CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32>
-  // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]], [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+  // CHECK: [[ARR3:%.*]] = fir.insert_on_range [[ARR2]], [[C1_I32]] from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
   // CHECK: fir.call @noret1([[ARR3]]) : (!fir.array<10xi32>) -> ()
   %arr2 = fir.zero_bits !fir.array<10xi32>
-    %arr3 = fir.insert_on_range %arr2, %c1_i32, [2 : index, 9 : index] : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
+  %arr3 = fir.insert_on_range %arr2, %c1_i32 from (2) to (9) : (!fir.array<10xi32>, i32) -> !fir.array<10xi32>
   fir.call @noret1(%arr3) : (!fir.array<10xi32>) -> ()
 
   // CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2>
@@ -664,6 +664,14 @@
   return
 }
 
+// CHECK-LABEL: @insert_on_range_multi_dim
+// CHECK-SAME: %[[ARR:.*]]: !fir.array<10x20xi32>, %[[CST:.*]]: i32
+func @insert_on_range_multi_dim(%arr : !fir.array<10x20xi32>, %cst : i32) {
+  // CHECK: fir.insert_on_range %[[ARR]], %[[CST]] from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+  %arr3 = fir.insert_on_range %arr, %cst from (2, 3) to (5, 6) : (!fir.array<10x20xi32>, i32) -> !fir.array<10x20xi32>
+  return
+}
+
 // CHECK-LABEL: @test_shift
 func @test_shift(%arg0: !fir.box<!fir.array<?xf32>>) -> !fir.ref<f32> {
   %c4 = arith.constant 4 : index
diff --git a/test/Fir/invalid.fir b/test/Fir/invalid.fir
index 8bc2ac6..98ee4a4 100644
--- a/test/Fir/invalid.fir
+++ b/test/Fir/invalid.fir
@@ -428,7 +428,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 31 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0, 31, 0]> : tensor<3xindex> } : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -438,7 +438,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op has uneven number of values in ranges}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = "fir.insert_on_range"(%0, %c0_i32) { coor = dense<[0]> : tensor<1xindex> }  : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -448,7 +448,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op negative range bound}}
-  %2 = fir.insert_on_range %0, %c0_i32, [-1 : index, 0 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (-1) to (0) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -458,7 +458,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<32x32xi32>
   // expected-error@+1 {{'fir.insert_on_range' op empty range}}
-  %2 = fir.insert_on_range %0, %c0_i32, [10 : index, 9 : index] : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (10) to (9) : (!fir.array<32x32xi32>, i32) -> !fir.array<32x32xi32>
   fir.has_value %2 : !fir.array<32x32xi32>
 }
 
@@ -468,7 +468,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<?xi32>
   // expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<?xi32>, i32) -> !fir.array<?xi32>
   fir.has_value %2 : !fir.array<?xi32>
 }
 
@@ -478,7 +478,7 @@
   %c0_i32 = arith.constant 1 : i32
   %0 = fir.undefined !fir.array<*:i32>
   // expected-error@+1 {{'fir.insert_on_range' op must have constant shape and size}}
-  %2 = fir.insert_on_range %0, %c0_i32, [0 : index, 10 : index] : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
+  %2 = fir.insert_on_range %0, %c0_i32 from (0) to (10) : (!fir.array<*:i32>, i32) -> !fir.array<*:i32>
   fir.has_value %2 : !fir.array<*:i32>
 }