| // RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s |
| |
| //===----------------------------------------------------------------------===// |
| // Positive tests |
| //===----------------------------------------------------------------------===// |
| |
| /// The copy target is a 1x1 view at static offset 0 of a 1x108 buffer. The |
| /// expected output is a scalar load from the source and a store into the |
| /// base buffer at [0, 0], with neither the cast nor the copy left behind. |
| // CHECK-LABEL: func.func private @concat_zero_offset( |
| // CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:.*]]: memref<1x108xf32> |
| func.func private @concat_zero_offset(%src : memref<1x1xf32>, |
| %dst : memref<1x108xf32>) { |
| /// reinterpret_cast removed |
| // CHECK-NOT: memref.reinterpret_cast |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [0], sizes: [1, 1], strides: [1, 1] |
| : memref<1x108xf32> to memref<1x1xf32> |
| |
| /// Ensure copy was replaced |
| // CHECK-NOT: memref.copy |
| // CHECK: %[[C0:.*]] = arith.constant 0 : index |
| // CHECK: %[[C0_0:.*]] = arith.constant 0 : index |
| // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32> |
| // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32> |
| /// The negative directives above only guard the text before the first |
| /// constant; this one also rejects a copy that survives after the store. |
| // CHECK-NOT: memref.copy |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1xf32> to memref<1x1xf32> |
| return |
| } |
| |
| /// The 1x1 view starts at static offset 1 of the 1x108 buffer, so the store |
| /// into the base buffer is expected at [0, 1]. |
| // CHECK-LABEL: func.func private @concat_nonzero_offset( |
| // CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:.*]]: memref<1x108xf32> |
| func.func private @concat_nonzero_offset(%src : memref<1x1xf32>, |
| %dst : memref<1x108xf32>) { |
| // CHECK-NOT: memref.reinterpret_cast |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [1], sizes: [1, 1], strides: [1, 1] |
| : memref<1x108xf32> |
| to memref<1x1xf32, strided<[1, 1], offset: 1>> |
| |
| // CHECK-NOT: memref.copy |
| // CHECK: %[[C0:.*]] = arith.constant 0 : index |
| // CHECK: %[[C1:.*]] = arith.constant 1 : index |
| // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32> |
| // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C1]]] : memref<1x108xf32> |
| /// The negative directives above only guard the text before the first |
| /// constant; this one also rejects a copy that survives after the store. |
| // CHECK-NOT: memref.copy |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1xf32> |
| to memref<1x1xf32, strided<[1, 1], offset: 1>> |
| return |
| } |
| |
| /// The view offset is the runtime value %offset (a function argument). The |
| /// store is expected to use that value directly as its last index. |
| // CHECK-LABEL: func.func private @concat_dynamic_offset( |
| // CHECK-SAME: %[[OFF:.*]]: index |
| // CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:.*]]: memref<1x108xf32> |
| func.func private @concat_dynamic_offset(%offset: index, %src : memref<1x1xf32>, |
| %dst : memref<1x108xf32>) { |
| // CHECK-NOT: memref.reinterpret_cast |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [%offset], sizes: [1, 1], strides: [1, 1] |
| : memref<1x108xf32> |
| to memref<1x1xf32, strided<[1, 1], offset: ?>> |
| |
| // CHECK-NOT: memref.copy |
| // CHECK: %[[C0:.*]] = arith.constant 0 : index |
| // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] |
| // CHECK-SAME: : memref<1x1xf32> |
| /// Dynamic offset used in store |
| // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[OFF]]] : memref<1x108xf32> |
| /// The negative directives above only guard the text before the first |
| /// constant; this one also rejects a copy that survives after the store. |
| // CHECK-NOT: memref.copy |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1xf32> |
| to memref<1x1xf32, strided<[1, 1], offset: ?>> |
| return |
| } |
| |
| /// The view has static non-unit strides [107, 2] and offset 0. The store |
| /// into the base buffer is still expected at [0, 0]. |
| // CHECK-LABEL: func.func private @concat_strided( |
| // CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:.*]]: memref<1x108xf32> |
| func.func private @concat_strided(%src : memref<1x1xf32>, |
| %dst : memref<1x108xf32>) { |
| // CHECK-NOT: memref.reinterpret_cast |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [0], sizes: [1, 1], strides: [107, 2] |
| : memref<1x108xf32> to memref<1x1xf32, strided<[107, 2]>> |
| |
| // CHECK-NOT: memref.copy |
| // CHECK: %[[C0:.*]] = arith.constant 0 : index |
| // CHECK: %[[C0_0:.*]] = arith.constant 0 : index |
| // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32> |
| // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32> |
| /// The negative directives above only guard the text before the first |
| /// constant; this one also rejects a copy that survives after the store. |
| // CHECK-NOT: memref.copy |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1xf32> to memref<1x1xf32, strided<[107, 2]>> |
| return |
| } |
| |
| /// Both view strides are runtime values while the offset is a static 0. The |
| /// expected store indices are constants only; the stride arguments do not |
| /// appear in them. |
| // CHECK-LABEL: func.func private @concat_dynamic_stride( |
| // CHECK-SAME: %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index |
| // CHECK-SAME: %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index |
| // CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xf32> |
| func.func private @concat_dynamic_stride(%stride0: index, |
| %stride1: index, %src : memref<1x1xf32>, %dst : memref<1x108xf32>) { |
| // CHECK-NOT: memref.reinterpret_cast |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1] |
| : memref<1x108xf32> |
| to memref<1x1xf32, strided<[?, ?]>> |
| |
| // CHECK-NOT: memref.copy |
| // CHECK: %[[C0:.*]] = arith.constant 0 : index |
| // CHECK: %[[C0_0:.*]] = arith.constant 0 : index |
| // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32> |
| /// Static zero offset used in store; the dynamic strides are not used |
| // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32> |
| /// The negative directives above only guard the text before the first |
| /// constant; this one also rejects a copy that survives after the store. |
| // CHECK-NOT: memref.copy |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1xf32> |
| to memref<1x1xf32, strided<[?, ?]>> |
| return |
| } |
| |
| /// Rank-1 variant: a one-element view at offset 0 of a 108-element buffer. |
| /// The expected output is a single-index load and a single-index store. |
| // CHECK-LABEL: func.func private @concat_rank1( |
| // CHECK-SAME: %[[SRC:.*]]: memref<1xf32> |
| // CHECK-SAME: %[[DST:.*]]: memref<108xf32> |
| func.func private @concat_rank1(%src : memref<1xf32>, %dst : memref<108xf32>) { |
| // CHECK-NOT: memref.reinterpret_cast |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [0], sizes: [1], strides: [1] |
| : memref<108xf32> to memref<1xf32> |
| |
| // CHECK-NOT: memref.copy |
| // CHECK: %[[C0:.*]] = arith.constant 0 : index |
| // CHECK: %[[C0_0:.*]] = arith.constant 0 : index |
| // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xf32> |
| // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0_0]]] : memref<108xf32> |
| /// The negative directives above only guard the text before the first |
| /// constant; this one also rejects a copy that survives after the store. |
| // CHECK-NOT: memref.copy |
| memref.copy %src, %reinterpret_cast |
| : memref<1xf32> to memref<1xf32> |
| return |
| } |
| |
| /// Rank-3 variant: a 1x1x1 view at offset 0 of a 1x1x108 buffer. The |
| /// expected output is a three-index load and a three-index store. |
| // CHECK-LABEL: func.func private @concat_rank3( |
| // CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xf32> |
| // CHECK-SAME: %[[DST:.*]]: memref<1x1x108xf32> |
| func.func private @concat_rank3(%src : memref<1x1x1xf32>, |
| %dst : memref<1x1x108xf32>) { |
| // CHECK-NOT: memref.reinterpret_cast |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1] |
| : memref<1x1x108xf32> to memref<1x1x1xf32> |
| |
| // CHECK-NOT: memref.copy |
| // CHECK: %[[C0:.*]] = arith.constant 0 : index |
| // CHECK: %[[C0_0:.*]] = arith.constant 0 : index |
| // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32> |
| // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x108xf32> |
| /// The negative directives above only guard the text before the first |
| /// constant; this one also rejects a copy that survives after the store. |
| // CHECK-NOT: memref.copy |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1x1xf32> to memref<1x1x1xf32> |
| return |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Negative tests (must NOT rewrite) |
| //===----------------------------------------------------------------------===// |
| |
| /// The base memref already carries a non-identity strided layout (strides |
| /// [10, 2]). The cast and the copy are expected to be left in place, and no |
| /// load/store pair may follow the copy. |
| /// SSA names are captured from the signature instead of being hard-coded, so |
| /// the test does not depend on the printer's naming scheme. |
| // CHECK-LABEL: func.func private @negative_concat_strided_base( |
| // CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<8x1xf32 |
| func.func private @negative_concat_strided_base(%src: memref<1x1xf32>, |
| %dst: memref<8x1xf32, strided<[10, 2]>>) { |
| // CHECK: %[[CAST:[a-zA-Z0-9_]+]] = memref.reinterpret_cast %[[DST]] |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [6], sizes: [1, 1], strides: [11, 80] |
| : memref<8x1xf32, strided<[10, 2]>> |
| to memref<1x1xf32, strided<[11, 80], offset: 6>> |
| |
| // CHECK: memref.copy %[[SRC]], %[[CAST]] |
| // CHECK-NOT: memref.load |
| // CHECK-NOT: memref.store |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1xf32> to memref<1x1xf32, strided<[11, 80], offset: 6>> |
| |
| return |
| } |
| |
| /// The cast changes rank (a memref<6xf32> viewed as 2x3) and the copy moves |
| /// a 2x3 block rather than a single element. The cast and the copy are |
| /// expected to be left in place, and no load/store pair may follow the copy. |
| /// SSA names are captured from the signature instead of being hard-coded, so |
| /// the test does not depend on the printer's naming scheme. |
| // CHECK-LABEL: func.func private @negative_reshape_rank_change( |
| // CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<2x3xf32> |
| // CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<6xf32> |
| func.func private @negative_reshape_rank_change(%src : memref<2x3xf32>, |
| %dst : memref<6xf32>) { |
| // CHECK: %[[CAST:[a-zA-Z0-9_]+]] = memref.reinterpret_cast %[[DST]] |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [0], sizes: [2, 3], strides: [3, 1] |
| : memref<6xf32> to memref<2x3xf32> |
| |
| // CHECK: memref.copy %[[SRC]], %[[CAST]] |
| // CHECK-NOT: memref.load |
| // CHECK-NOT: memref.store |
| memref.copy %src, %reinterpret_cast |
| : memref<2x3xf32> to memref<2x3xf32> |
| return |
| } |
| |
| /// The base memref is 2x108, i.e. it has two non-unit dimensions. The cast |
| /// and the copy are expected to be left in place, and no load/store pair |
| /// may follow the copy. |
| /// SSA names are captured from the signature instead of being hard-coded, so |
| /// the test does not depend on the printer's naming scheme. |
| // CHECK-LABEL: func.func private @negative_concat_multiple_non_unit_dims( |
| // CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<2x108xf32> |
| func.func private @negative_concat_multiple_non_unit_dims( |
| %src : memref<1x1xf32>, %dst : memref<2x108xf32>) { |
| // CHECK: %[[CAST:[a-zA-Z0-9_]+]] = memref.reinterpret_cast %[[DST]] |
| %reinterpret_cast = memref.reinterpret_cast %dst |
| to offset: [0], sizes: [1, 1], strides: [1, 1] |
| : memref<2x108xf32> |
| to memref<1x1xf32> |
| // CHECK: memref.copy %[[SRC]], %[[CAST]] |
| // CHECK-NOT: memref.load |
| // CHECK-NOT: memref.store |
| memref.copy %src, %reinterpret_cast |
| : memref<1x1xf32> to memref<1x1xf32> |
| return |
| } |
| |
| /// A copy whose target is a plain function argument, with no |
| /// reinterpret_cast involved. The copy is expected to be left in place, and |
| /// no load/store pair may follow it. |
| /// Both arguments have the same type, so the captures use a strict |
| /// identifier pattern; a greedy wildcard would span both arguments. |
| // CHECK-LABEL: func.func private @negative_plain_copy( |
| // CHECK-SAME: %[[SRC:[a-zA-Z0-9_]+]]: memref<1x1xf32> |
| // CHECK-SAME: %[[DST:[a-zA-Z0-9_]+]]: memref<1x1xf32> |
| func.func private @negative_plain_copy(%src : memref<1x1xf32>, |
| %dst : memref<1x1xf32>) { |
| // CHECK: memref.copy %[[SRC]], %[[DST]] |
| // CHECK-NOT: memref.load |
| // CHECK-NOT: memref.store |
| memref.copy %src, %dst |
| : memref<1x1xf32> to memref<1x1xf32> |
| return |
| } |