[flang][fir] Add array value operations.

We lower expressions with rank > 0 to a set of high-level array operations.
These operations are then analyzed and refined to more primitve
operations in subsequent pass(es).

This patch upstreams these array operations and some other helper ops.

Authors: Eric Schweitz, Rajan Walia, Kiran Chandramohan, et.al.

https://github.com/flang-compiler/f18-llvm-project/pull/565

Differential Revision: https://reviews.llvm.org/D97421

GitOrigin-RevId: 67360decc3d4bda363c2fa2550eb1c2b876c8cf0
diff --git a/include/flang/Optimizer/Dialect/FIROps.td b/include/flang/Optimizer/Dialect/FIROps.td
index c7f4353..32b9d34 100644
--- a/include/flang/Optimizer/Dialect/FIROps.td
+++ b/include/flang/Optimizer/Dialect/FIROps.td
@@ -19,7 +19,6 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
-
 include "flang/Optimizer/Dialect/FIRTypes.td"
 
 // Base class for FIR operations.
@@ -1495,9 +1494,263 @@
   let results = (outs fir_TypeDescType);
 }
 
+//===----------------------------------------------------------------------===//
+// Array value operations
+//===----------------------------------------------------------------------===//
+
+def fir_ArrayLoadOp : fir_Op<"array_load", [AttrSizedOperandSegments]> {
+
+  let summary = "Load an array as a value.";
+
+  let description = [{
+    Load an entire array as a single SSA value.
+
+    ```fortran
+      real :: a(o:n,p:m)
+      ...
+      ... = ... a ...
+    ```
+
+    One can use `fir.array_load` to produce an ssa-value that captures an
+    immutable value of the entire array `a`, as in the Fortran array expression
+    shown above. Subsequent changes to the memory containing the array do not
+    alter its composite value. This operation let's one load an array as a
+    value while applying a runtime shape, shift, or slice to the memory
+    reference, and its semantics guarantee immutability.
+
+    ```mlir
+      %s = fir.shape_shift %o, %n, %p, %m : (index, index, index, index) -> !fir.shape<2>
+      // load the entire array 'a'
+      %v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+      // a fir.store here into array %a does not change %v
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<AnyRefOrBox, "", [MemRead]>:$memref,
+    Optional<AnyShapeOrShiftType>:$shape,
+    Optional<fir_SliceType>:$slice,
+    Variadic<AnyIntegerType>:$lenParams
+  );
+
+  let results = (outs fir_SequenceType);
+
+  let assemblyFormat = [{
+    $memref (`(`$shape^`)`)? (`[`$slice^`]`)? (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+
+  let extraClassDeclaration = [{
+    std::vector<mlir::Value> getExtents();
+  }];
+}
+
+def fir_ArrayFetchOp : fir_Op<"array_fetch", [NoSideEffect]> {
+
+  let summary = "Fetch the value of an element of an array value";
+
+  let description = [{
+    Fetch the value of an element in an array value.
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      ... a ...
+      ... a(r,s+1) ...
+    ```
+
+    One can use `fir.array_fetch` to fetch the (implied) value of `a(i,j)` in
+    an array expression as shown above. It can also be used to extract the
+    element `a(r,s+1)` in the second expression.
+
+    ```mlir
+      %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+      // load the entire array 'a'
+      %v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+      // fetch the value of one of the array value's elements
+      %1 = fir.array_fetch %v, %i, %j : (!fir.array<?x?xf32>, index, index) -> f32
+    ```
+
+    It is only possible to use `array_fetch` on an `array_load` result value.
+  }];
+
+  let arguments = (ins
+    fir_SequenceType:$sequence,
+    Variadic<AnyCoordinateType>:$indices
+  );
+
+  let results = (outs AnyType:$element);
+
+  let assemblyFormat = [{
+    $sequence `,` $indices attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto arrTy = sequence().getType().cast<fir::SequenceType>();
+    if (indices().size() != arrTy.getDimension())
+      return emitOpError("number of indices != dimension of array");
+    if (element().getType() != arrTy.getEleTy())
+      return emitOpError("return type does not match array");
+    if (!isa<fir::ArrayLoadOp>(sequence().getDefiningOp()))
+      return emitOpError("argument #0 must be result of fir.array_load");
+    return mlir::success();
+  }];
+}
+
+def fir_ArrayUpdateOp : fir_Op<"array_update", [NoSideEffect]> {
+
+  let summary = "Update the value of an element of an array value";
+
+  let description = [{
+    Updates the value of an element in an array value. A new array value is
+    returned where all element values of the input array are identical except
+    for the selected element which is the value passed in the update.
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      a = ...
+    ```
+
+    One can use `fir.array_update` to update the (implied) value of `a(i,j)`
+    in an array expression as shown above.
+
+    ```mlir
+      %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+      // load the entire array 'a'
+      %v = fir.array_load %a(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> !fir.array<?x?xf32>
+      // update the value of one of the array value's elements
+      // %r_{ij} = %f  if (i,j) = (%i,%j),   %v_{ij} otherwise
+      %r = fir.array_update %v, %f, %i, %j : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+      fir.array_merge_store %v, %r to %a : !fir.ref<!fir.array<?x?xf32>>
+    ```
+
+    An array value update behaves as if a mapping function from the indices
+    to the new value has been added, replacing the previous mapping. These
+    mappings can be added to the ssa-value, but will not be materialized in
+    memory until the `fir.array_merge_store` is performed.
+  }];
+
+  let arguments = (ins
+    fir_SequenceType:$sequence,
+    AnyType:$merge,
+    Variadic<AnyCoordinateType>:$indices
+  );
+
+  let results = (outs fir_SequenceType);
+
+  let assemblyFormat = [{
+    $sequence `,` $merge `,` $indices attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto arrTy = sequence().getType().cast<fir::SequenceType>();
+    if (merge().getType() != arrTy.getEleTy())
+      return emitOpError("merged value does not have element type");
+    if (indices().size() != arrTy.getDimension())
+      return emitOpError("number of indices != dimension of array");
+    return mlir::success();
+  }];
+}
+
+def fir_ArrayMergeStoreOp : fir_Op<"array_merge_store", [
+    TypesMatchWith<"type of 'original' matches element type of 'memref'",
+                     "memref", "original",
+                     "fir::dyn_cast_ptrOrBoxEleTy($_self)">,
+    TypesMatchWith<"type of 'sequence' matches element type of 'memref'",
+                     "memref", "sequence",
+                     "fir::dyn_cast_ptrOrBoxEleTy($_self)">]> {
+
+  let summary = "Store merged array value to memory.";
+
+  let description = [{
+    Store a merged array value to memory.
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      a = ...
+    ```
+
+    One can use `fir.array_merge_store` to merge/copy the value of `a` in an
+    array expression as shown above.
+
+    ```mlir
+      %v = fir.array_load %a(%shape) : ...
+      %r = fir.array_update %v, %f, %i, %j : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+      fir.array_merge_store %v, %r to %a : !fir.ref<!fir.array<?x?xf32>>
+    ```
+
+    This operation merges the original loaded array value, `%v`, with the
+    chained updates, `%r`, and stores the result to the array at address, `%a`.
+  }];
+
+  let arguments = (ins
+    fir_SequenceType:$original,
+    fir_SequenceType:$sequence,
+    Arg<AnyRefOrBox, "", [MemWrite]>:$memref
+  );
+
+  let assemblyFormat = "$original `,` $sequence `to` $memref attr-dict `:` type($memref)";
+
+  let verifier = [{
+    if (!isa<ArrayLoadOp>(original().getDefiningOp()))
+       return emitOpError("operand #0 must be result of a fir.array_load op");
+    return mlir::success();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // Record and array type operations
+//===----------------------------------------------------------------------===//
+
+def fir_ArrayCoorOp : fir_Op<"array_coor",
+    [NoSideEffect, AttrSizedOperandSegments]> {
+
+  let summary = "Find the coordinate of an element of an array";
+
+  let description = [{
+    Compute the location of an element in an array when the shape of the
+    array is only known at runtime.
+
+    This operation is intended to capture all the runtime values needed to
+    compute the address of an array reference in a single high-level op. Given
+    the following Fortran input:
+
+    ```fortran
+      real :: a(n,m)
+      ...
+      ... a(i,j) ...
+    ```
+
+    One can use `fir.array_coor` to determine the address of `a(i,j)`.
+
+    ```mlir
+      %s = fir.shape %n, %m : (index, index) -> !fir.shape<2>
+      %1 = fir.array_coor %a(%s) %i, %j : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+    ```
+  }];
+
+  let arguments = (ins
+    AnyRefOrBox:$memref,
+    Optional<AnyShapeOrShiftType>:$shape,
+    Optional<fir_SliceType>:$slice,
+    Variadic<AnyCoordinateType>:$indices,
+    Variadic<AnyIntegerType>:$lenParams
+  );
+
+  let results = (outs fir_ReferenceType);
+
+  let assemblyFormat = [{
+    $memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
 
 def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> {
+
   let summary = "Finds the coordinate (location) of a value in memory";
 
   let description = [{
@@ -1674,18 +1927,218 @@
     }
   }];
 
-  let builders = [
-    OpBuilderDAG<(ins "StringRef":$fieldName, "Type":$recTy,
-      CArg<"ValueRange", "{}">:$operands),
+  let builders = [OpBuilderDAG<(ins "llvm::StringRef":$fieldName,
+      "mlir::Type":$recTy, CArg<"mlir::ValueRange","{}">:$operands),
     [{
-      $_state.addAttribute(fieldAttrName(), $_builder.getStringAttr(fieldName));
+      $_state.addAttribute(fieldAttrName(),
+        $_builder.getStringAttr(fieldName));
       $_state.addAttribute(typeAttrName(), TypeAttr::get(recTy));
       $_state.addOperands(operands);
-    }]>];
+    }]
+  >];
 
   let extraClassDeclaration = [{
     static constexpr llvm::StringRef fieldAttrName() { return "field_id"; }
     static constexpr llvm::StringRef typeAttrName() { return "on_type"; }
+    llvm::StringRef getFieldName() { return field_id(); }
+  }];
+}
+
+def fir_ShapeOp : fir_Op<"shape", [NoSideEffect]> {
+
+  let summary = "generate an abstract shape vector of type `!fir.shape`";
+
+  let description = [{
+    The arguments are an ordered list of integral type values that define the
+    runtime extent of each dimension of an array. The shape information is
+    given in the same row-to-column order as Fortran. This abstract shape value
+    must be applied to a reified object, so all shape information must be
+    specified.  The extent must be nonnegative.
+
+    ```mlir
+      %d = fir.shape %row_sz, %col_sz : (index, index) -> !fir.shape<2>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyIntegerType>:$extents);
+
+  let results = (outs fir_ShapeType);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto size = extents().size();
+    auto shapeTy = getType().dyn_cast<fir::ShapeType>();
+    assert(shapeTy && "must be a shape type");
+    if (shapeTy.getRank() != size)
+      return emitOpError("shape type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    std::vector<mlir::Value> getExtents() {
+      return {extents().begin(), extents().end()};
+    }
+  }];
+}
+
+def fir_ShapeShiftOp : fir_Op<"shape_shift", [NoSideEffect]> {
+
+  let summary = [{
+    generate an abstract shape and shift vector of type `!fir.shapeshift`
+  }];
+
+  let description = [{
+    The arguments are an ordered list of integral type values that is a multiple
+    of 2 in length. Each such pair is defined as: the lower bound and the
+    extent for that dimension. The shifted shape information is given in the
+    same row-to-column order as Fortran. This abstract shifted shape value must
+    be applied to a reified object, so all shifted shape information must be
+    specified.  The extent must be nonnegative.
+
+    ```mlir
+      %d = fir.shape_shift %lo, %extent : (index, index) -> !fir.shapeshift<1>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyIntegerType>:$pairs);
+
+  let results = (outs fir_ShapeShiftType);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto size = pairs().size();
+    if (size < 2 || size > 16 * 2)
+      return emitOpError("incorrect number of args");
+    if (size % 2 != 0)
+      return emitOpError("requires a multiple of 2 args");
+    auto shapeTy = getType().dyn_cast<fir::ShapeShiftType>();
+    assert(shapeTy && "must be a shape shift type");
+    if (shapeTy.getRank() * 2 != size)
+      return emitOpError("shape type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    // Logically unzip the origins from the extent values.
+    std::vector<mlir::Value> getOrigins() {
+      std::vector<mlir::Value> result;
+      for (auto i : llvm::enumerate(pairs()))
+        if (!(i.index() & 1))
+          result.push_back(i.value());
+      return result;
+    }
+
+    // Logically unzip the extents from the origin values.
+    std::vector<mlir::Value> getExtents() {
+      std::vector<mlir::Value> result;
+      for (auto i : llvm::enumerate(pairs()))
+        if (i.index() & 1)
+          result.push_back(i.value());
+      return result;
+    }
+  }];
+}
+
+def fir_ShiftOp : fir_Op<"shift", [NoSideEffect]> {
+
+  let summary = "generate an abstract shift vector of type `!fir.shift`";
+
+  let description = [{
+    The arguments are an ordered list of integral type values that define the
+    runtime lower bound of each dimension of an array. The shape information is
+    given in the same row-to-column order as Fortran. This abstract shift value
+    must be applied to a reified object, so all shift information must be
+    specified.
+
+    ```mlir
+      %d = fir.shift %row_lb, %col_lb : (index, index) -> !fir.shift<2>
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyIntegerType>:$origins);
+
+  let results = (outs fir_ShiftType);
+
+  let assemblyFormat = [{
+    operands attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto size = origins().size();
+    auto shiftTy = getType().dyn_cast<fir::ShiftType>();
+    assert(shiftTy && "must be a shift type");
+    if (shiftTy.getRank() != size)
+      return emitOpError("shift type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    std::vector<mlir::Value> getOrigins() {
+      return {origins().begin(), origins().end()};
+    }
+  }];
+}
+
+def fir_SliceOp : fir_Op<"slice", [NoSideEffect, AttrSizedOperandSegments]> {
+
+  let summary = "generate an abstract slice vector of type `!fir.slice`";
+
+  let description = [{
+    The array slicing arguments are an ordered list of integral type values
+    that must be a multiple of 3 in length.  Each such triple is defined as:
+    the lower bound, the upper bound, and the stride for that dimension, as in
+    Fortran syntax. Both bounds are inclusive. The array slice information is
+    given in the same row-to-column order as Fortran. This abstract slice value
+    must be applied to a reified object, so all slice information must be
+    specified.  The extent must be nonnegative and the stride must not be zero.
+
+    ```mlir
+      %d = fir.slice %lo, %hi, %step : (index, index, index) -> !fir.slice<1>
+    ```
+
+    To support generalized slicing of Fortran's dynamic derived types, a slice
+    op can be given a component path (narrowing from the product type of the
+    original array to the specific elemental type of the sliced projection).
+
+    ```mlir
+      %fld = fir.field_index component, !fir.type<t{...component:ct...}>
+      %d = fir.slice %lo, %hi, %step path %fld : (index, index, index, !fir.field) -> !fir.slice<1>
+    ```
+  }];
+
+  let arguments = (ins
+    Variadic<AnyCoordinateType>:$triples,
+    Variadic<AnyComponentType>:$fields
+  );
+
+  let results = (outs fir_SliceType);
+
+  let assemblyFormat = [{
+    $triples (`path` $fields^)? attr-dict `:` functional-type(operands, results)
+  }];
+
+  let verifier = [{
+    auto size = triples().size();
+    if (size < 3 || size > 16 * 3)
+      return emitOpError("incorrect number of args for triple");
+    if (size % 3 != 0)
+      return emitOpError("requires a multiple of 3 args");
+    auto sliceTy = getType().dyn_cast<fir::SliceType>();
+    assert(sliceTy && "must be a slice type");
+    if (sliceTy.getRank() * 3 != size)
+      return emitOpError("slice type rank mismatch");
+    return mlir::success();
+  }];
+
+  let extraClassDeclaration = [{
+    unsigned getOutRank() { return getOutputRank(triples()); }
+    static unsigned getOutputRank(mlir::ValueRange triples);
   }];
 }
 
diff --git a/include/flang/Optimizer/Dialect/FIRType.h b/include/flang/Optimizer/Dialect/FIRType.h
index 2477b07..ca0dddd 100644
--- a/include/flang/Optimizer/Dialect/FIRType.h
+++ b/include/flang/Optimizer/Dialect/FIRType.h
@@ -10,8 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef OPTIMIZER_DIALECT_FIRTYPE_H
-#define OPTIMIZER_DIALECT_FIRTYPE_H
+#ifndef FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
+#define FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
 
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -23,7 +23,8 @@
 namespace llvm {
 class raw_ostream;
 class StringRef;
-template <typename> class ArrayRef;
+template <typename>
+class ArrayRef;
 class hash_code;
 } // namespace llvm
 
@@ -80,6 +81,10 @@
 /// not a memory reference type, then returns a null `Type`.
 mlir::Type dyn_cast_ptrEleTy(mlir::Type t);
 
+/// Extract the `Type` pointed to from a FIR memory reference or box type. If
+/// `t` is not a memory reference or box type, then returns a null `Type`.
+mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t);
+
 /// Is `t` a FIR Real or MLIR Float type?
 inline bool isa_real(mlir::Type t) {
   return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
@@ -125,4 +130,4 @@
 
 } // namespace fir
 
-#endif // OPTIMIZER_DIALECT_FIRTYPE_H
+#endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
diff --git a/lib/Optimizer/Dialect/FIROps.cpp b/lib/Optimizer/Dialect/FIROps.cpp
index ed053fd..80f1a1d 100644
--- a/lib/Optimizer/Dialect/FIROps.cpp
+++ b/lib/Optimizer/Dialect/FIROps.cpp
@@ -5,6 +5,10 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIRAttr.h"
@@ -116,6 +120,90 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ArrayCoorOp
+//===----------------------------------------------------------------------===//
+
+static mlir::LogicalResult verify(fir::ArrayCoorOp op) {
+  auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
+  auto arrTy = eleTy.dyn_cast<fir::SequenceType>();
+  if (!arrTy)
+    return op.emitOpError("must be a reference to an array");
+  auto arrDim = arrTy.getDimension();
+
+  if (auto shapeOp = op.shape()) {
+    auto shapeTy = shapeOp.getType();
+    unsigned shapeTyRank = 0;
+    if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) {
+      shapeTyRank = s.getRank();
+    } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
+      shapeTyRank = ss.getRank();
+    } else {
+      auto s = shapeTy.cast<fir::ShiftType>();
+      shapeTyRank = s.getRank();
+      if (!op.memref().getType().isa<fir::BoxType>())
+        return op.emitOpError("shift can only be provided with fir.box memref");
+    }
+    if (arrDim && arrDim != shapeTyRank)
+      return op.emitOpError("rank of dimension mismatched");
+    if (shapeTyRank != op.indices().size())
+      return op.emitOpError("number of indices do not match dim rank");
+  }
+
+  if (auto sliceOp = op.slice())
+    if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>())
+      if (sliceTy.getRank() != arrDim)
+        return op.emitOpError("rank of dimension in slice mismatched");
+
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// ArrayLoadOp
+//===----------------------------------------------------------------------===//
+
+std::vector<mlir::Value> fir::ArrayLoadOp::getExtents() {
+  if (auto sh = shape())
+    if (auto *op = sh.getDefiningOp()) {
+      if (auto shOp = dyn_cast<fir::ShapeOp>(op))
+        return shOp.getExtents();
+      return cast<fir::ShapeShiftOp>(op).getExtents();
+    }
+  return {};
+}
+
+static mlir::LogicalResult verify(fir::ArrayLoadOp op) {
+  auto eleTy = fir::dyn_cast_ptrOrBoxEleTy(op.memref().getType());
+  auto arrTy = eleTy.dyn_cast<fir::SequenceType>();
+  if (!arrTy)
+    return op.emitOpError("must be a reference to an array");
+  auto arrDim = arrTy.getDimension();
+
+  if (auto shapeOp = op.shape()) {
+    auto shapeTy = shapeOp.getType();
+    unsigned shapeTyRank = 0;
+    if (auto s = shapeTy.dyn_cast<fir::ShapeType>()) {
+      shapeTyRank = s.getRank();
+    } else if (auto ss = shapeTy.dyn_cast<fir::ShapeShiftType>()) {
+      shapeTyRank = ss.getRank();
+    } else {
+      auto s = shapeTy.cast<fir::ShiftType>();
+      shapeTyRank = s.getRank();
+      if (!op.memref().getType().isa<fir::BoxType>())
+        return op.emitOpError("shift can only be provided with fir.box memref");
+    }
+    if (arrDim && arrDim != shapeTyRank)
+      return op.emitOpError("rank of dimension mismatched");
+  }
+
+  if (auto sliceOp = op.slice())
+    if (auto sliceTy = sliceOp.getType().dyn_cast<fir::SliceType>())
+      if (sliceTy.getRank() != arrDim)
+        return op.emitOpError("rank of dimension in slice mismatched");
+
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
 // BoxAddrOp
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/Optimizer/Dialect/FIRType.cpp b/lib/Optimizer/Dialect/FIRType.cpp
index 4cbdf15..d3b9a2f 100644
--- a/lib/Optimizer/Dialect/FIRType.cpp
+++ b/lib/Optimizer/Dialect/FIRType.cpp
@@ -223,6 +223,19 @@
       .Default([](mlir::Type) { return mlir::Type{}; });
 }
 
+mlir::Type dyn_cast_ptrOrBoxEleTy(mlir::Type t) {
+  return llvm::TypeSwitch<mlir::Type, mlir::Type>(t)
+      .Case<fir::ReferenceType, fir::PointerType, fir::HeapType>(
+          [](auto p) { return p.getEleTy(); })
+      .Case<fir::BoxType>([](auto p) {
+        auto eleTy = p.getEleTy();
+        if (auto ty = fir::dyn_cast_ptrEleTy(eleTy))
+          return ty;
+        return eleTy;
+      })
+      .Default([](mlir::Type) { return mlir::Type{}; });
+}
+
 } // namespace fir
 
 namespace {
diff --git a/test/Fir/fir-ops.fir b/test/Fir/fir-ops.fir
index 3e8c81c..cbfe318 100644
--- a/test/Fir/fir-ops.fir
+++ b/test/Fir/fir-ops.fir
@@ -618,5 +618,17 @@
 
   // CHECK: [[ARR2:%.*]] = fir.zero_bits !fir.array<10xi32>
   %arr2 = fir.zero_bits !fir.array<10xi32>
+
+  // CHECK: [[SHAPE:%.*]] = fir.shape_shift [[INDXM:%.*]], [[INDXN:%.*]], [[INDXO:%.*]], [[INDXP:%.*]] : (index, index, index, index) -> !fir.shapeshift<2>
+  // CHECK: [[AV1:%.*]] = fir.array_load [[ARR1]]([[SHAPE]]) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+  // CHECK: [[FVAL:%.*]] = fir.array_fetch [[AV1]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, index, index) -> f32
+  // CHECK: [[AV2:%.*]] = fir.array_update [[AV1]], [[FVAL]], [[I10]], [[J20]] : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+  // CHECK: fir.array_merge_store [[AV1]], [[AV2]] to [[ARR1]] : !fir.ref<!fir.array<?x?xf32>>
+  %s = fir.shape_shift %m, %n, %o, %p : (index, index, index, index) -> !fir.shapeshift<2>
+  %av1 = fir.array_load %arr1(%s) : (!fir.ref<!fir.array<?x?xf32>>, !fir.shapeshift<2>) -> !fir.array<?x?xf32>
+  %f = fir.array_fetch %av1, %i10, %j20 : (!fir.array<?x?xf32>, index, index) -> f32
+  %av2 = fir.array_update %av1, %f, %i10, %j20 : (!fir.array<?x?xf32>, f32, index, index) -> !fir.array<?x?xf32>
+  fir.array_merge_store %av1, %av2 to %arr1 : !fir.ref<!fir.array<?x?xf32>>
+
   return
 }