blob: da47562e9c0d6829933c95554b9334a33e1cdf36 [file]
// RUN: mlir-opt -memref-elide-reinterpret-cast %s | FileCheck %s
//===----------------------------------------------------------------------===//
// Positive tests
//===----------------------------------------------------------------------===//
/// A 1x1 copy into a zero-offset, unit-stride reinterpret_cast view of %dst is
/// expected to become a scalar load from %src and a store into %dst at [0, 0],
/// with neither the cast nor the copy left in the function.
// CHECK-LABEL: func.func private @concat_zero_offset(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
func.func private @concat_zero_offset(%src : memref<1x1xf32>,
    %dst : memref<1x108xf32>) {
  /// reinterpret_cast removed
  // CHECK-NOT: memref.reinterpret_cast
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [0], sizes: [1, 1], strides: [1, 1]
      : memref<1x108xf32> to memref<1x1xf32>
  /// Ensure copy was replaced
  // CHECK-NOT: memref.copy
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
  /// The negative directives above only cover the text before the first
  /// constant; repeat them so a cast or copy that survives later is caught.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
  // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
  /// Nothing after the store (up to the next label) may be a cast or a copy.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  memref.copy %src, %reinterpret_cast
      : memref<1x1xf32> to memref<1x1xf32>
  return
}
/// Same shape as the zero-offset case, but the view starts at static offset 1:
/// the store into %dst is expected at [0, 1], the second index being a
/// materialized constant 1.
// CHECK-LABEL: func.func private @concat_nonzero_offset(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
func.func private @concat_nonzero_offset(%src : memref<1x1xf32>,
    %dst : memref<1x108xf32>) {
  // CHECK-NOT: memref.reinterpret_cast
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [1], sizes: [1, 1], strides: [1, 1]
      : memref<1x108xf32>
      to memref<1x1xf32, strided<[1, 1], offset: 1>>
  // CHECK-NOT: memref.copy
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[C1:.*]] = arith.constant 1 : index
  /// The negative directives above only cover the text before the first
  /// constant; repeat them so a cast or copy that survives later is caught.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
  // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C1]]] : memref<1x108xf32>
  /// Nothing after the store (up to the next label) may be a cast or a copy.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  memref.copy %src, %reinterpret_cast
      : memref<1x1xf32>
      to memref<1x1xf32, strided<[1, 1], offset: 1>>
  return
}
/// The view offset is the runtime value %offset: the store into %dst is
/// expected to use that function argument directly as its second index.
// CHECK-LABEL: func.func private @concat_dynamic_offset(
// CHECK-SAME: %[[OFF:.*]]: index
// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
func.func private @concat_dynamic_offset(%offset: index, %src : memref<1x1xf32>,
    %dst : memref<1x108xf32>) {
  // CHECK-NOT: memref.reinterpret_cast
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [%offset], sizes: [1, 1], strides: [1, 1]
      : memref<1x108xf32>
      to memref<1x1xf32, strided<[1, 1], offset: ?>>
  // CHECK-NOT: memref.copy
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  /// The negative directives above only cover the text before the constant;
  /// repeat them so a cast or copy that survives later is caught.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]]
  // CHECK-SAME: : memref<1x1xf32>
  /// Dynamic offset used in store
  // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[OFF]]] : memref<1x108xf32>
  /// Nothing after the store (up to the next label) may be a cast or a copy.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  memref.copy %src, %reinterpret_cast
      : memref<1x1xf32>
      to memref<1x1xf32, strided<[1, 1], offset: ?>>
  return
}
/// The view uses static non-unit strides [107, 2] at offset 0: the rewrite is
/// still expected, with the store landing at [0, 0] in %dst.
// CHECK-LABEL: func.func private @concat_strided(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:.*]]: memref<1x108xf32>
func.func private @concat_strided(%src : memref<1x1xf32>,
    %dst : memref<1x108xf32>) {
  // CHECK-NOT: memref.reinterpret_cast
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [0], sizes: [1, 1], strides: [107, 2]
      : memref<1x108xf32> to memref<1x1xf32, strided<[107, 2]>>
  // CHECK-NOT: memref.copy
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
  /// The negative directives above only cover the text before the first
  /// constant; repeat them so a cast or copy that survives later is caught.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
  // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
  /// Nothing after the store (up to the next label) may be a cast or a copy.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  memref.copy %src, %reinterpret_cast
      : memref<1x1xf32> to memref<1x1xf32, strided<[107, 2]>>
  return
}
/// Both view strides are runtime values while the offset is the static 0: the
/// rewrite is still expected, with the store landing at [0, 0] in %dst.
///
/// The argument captures below use a restrictive name pattern rather than a
/// greedy wildcard because two consecutive arguments share the type `index`;
/// a greedy match for the first one could swallow the second.
// CHECK-LABEL: func.func private @concat_dynamic_stride(
// CHECK-SAME: %[[STR0:[A-Za-z][A-Za-z0-9-]*]]: index
// CHECK-SAME: %[[STR1:[A-Za-z][A-Za-z0-9-]*]]: index
// CHECK-SAME: %[[SRC:[A-Za-z][A-Za-z0-9-]*]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:[A-Za-z][A-Za-z0-9-]*]]: memref<1x108xf32>
func.func private @concat_dynamic_stride(%stride0: index,
    %stride1: index, %src : memref<1x1xf32>, %dst : memref<1x108xf32>) {
  // CHECK-NOT: memref.reinterpret_cast
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [0], sizes: [1, 1], strides: [%stride0, %stride1]
      : memref<1x108xf32>
      to memref<1x1xf32, strided<[?, ?]>>
  // CHECK-NOT: memref.copy
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
  /// The negative directives above only cover the text before the first
  /// constant; repeat them so a cast or copy that survives later is caught.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]]] : memref<1x1xf32>
  /// Dynamic strides do not appear in the store; both indices are constant 0
  // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0_0]]] : memref<1x108xf32>
  /// Nothing after the store (up to the next label) may be a cast or a copy.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  memref.copy %src, %reinterpret_cast
      : memref<1x1xf32>
      to memref<1x1xf32, strided<[?, ?]>>
  return
}
/// Rank-1 variant: a single-element copy into a zero-offset view of a
/// 108-element memref is expected to become a load from %src[0] and a store
/// into %dst[0].
// CHECK-LABEL: func.func private @concat_rank1(
// CHECK-SAME: %[[SRC:.*]]: memref<1xf32>
// CHECK-SAME: %[[DST:.*]]: memref<108xf32>
func.func private @concat_rank1(%src : memref<1xf32>, %dst : memref<108xf32>) {
  // CHECK-NOT: memref.reinterpret_cast
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [0], sizes: [1], strides: [1]
      : memref<108xf32> to memref<1xf32>
  // CHECK-NOT: memref.copy
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
  /// The negative directives above only cover the text before the first
  /// constant; repeat them so a cast or copy that survives later is caught.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]]] : memref<1xf32>
  // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0_0]]] : memref<108xf32>
  /// Nothing after the store (up to the next label) may be a cast or a copy.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  memref.copy %src, %reinterpret_cast
      : memref<1xf32> to memref<1xf32>
  return
}
/// Rank-3 variant: a 1x1x1 copy into a zero-offset view of a 1x1x108 memref is
/// expected to become a load from %src[0, 0, 0] and a store into
/// %dst[0, 0, 0].
// CHECK-LABEL: func.func private @concat_rank3(
// CHECK-SAME: %[[SRC:.*]]: memref<1x1x1xf32>
// CHECK-SAME: %[[DST:.*]]: memref<1x1x108xf32>
func.func private @concat_rank3(%src : memref<1x1x1xf32>,
    %dst : memref<1x1x108xf32>) {
  // CHECK-NOT: memref.reinterpret_cast
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [0], sizes: [1, 1, 1], strides: [1, 1, 1]
      : memref<1x1x108xf32> to memref<1x1x1xf32>
  // CHECK-NOT: memref.copy
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
  /// The negative directives above only cover the text before the first
  /// constant; repeat them so a cast or copy that survives later is caught.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  // CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[C0]]] : memref<1x1x1xf32>
  // CHECK: memref.store %[[VAL]], %[[DST]][%[[C0]], %[[C0]], %[[C0_0]]] : memref<1x1x108xf32>
  /// Nothing after the store (up to the next label) may be a cast or a copy.
  // CHECK-NOT: memref.reinterpret_cast
  // CHECK-NOT: memref.copy
  memref.copy %src, %reinterpret_cast
      : memref<1x1x1xf32> to memref<1x1x1xf32>
  return
}
//===----------------------------------------------------------------------===//
// Negative tests (must NOT rewrite)
//===----------------------------------------------------------------------===//
/// The base memref itself has a strided layout (strided<[10, 2]>). The cast
/// and the copy must both survive, and no load/store may be introduced.
///
/// SSA names are captured from the signature and the cast result instead of
/// being spelled out, so the test does not depend on the printer's naming.
// CHECK-LABEL: func.func private @negative_concat_strided_base(
// CHECK-SAME: %[[SRC:[A-Za-z0-9_]+]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:[A-Za-z0-9_]+]]: memref<8x1xf32, strided<[10, 2]>>
func.func private @negative_concat_strided_base(%src: memref<1x1xf32>,
    %dst: memref<8x1xf32, strided<[10, 2]>>) {
  // CHECK: %[[RC:[A-Za-z0-9_]+]] = memref.reinterpret_cast %[[DST]]
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [6], sizes: [1, 1], strides: [11, 80]
      : memref<8x1xf32, strided<[10, 2]>>
      to memref<1x1xf32, strided<[11, 80], offset: 6>>
  // CHECK: memref.copy %[[SRC]], %[[RC]]
  // CHECK-NOT: memref.load
  // CHECK-NOT: memref.store
  memref.copy %src, %reinterpret_cast
      : memref<1x1xf32> to memref<1x1xf32, strided<[11, 80], offset: 6>>
  return
}
/// The cast changes rank (memref<6xf32> viewed as memref<2x3xf32>). The cast
/// and the copy must both survive, and no load/store may be introduced.
///
/// SSA names are captured from the signature and the cast result instead of
/// being spelled out, so the test does not depend on the printer's naming.
// CHECK-LABEL: func.func private @negative_reshape_rank_change(
// CHECK-SAME: %[[SRC:[A-Za-z0-9_]+]]: memref<2x3xf32>
// CHECK-SAME: %[[DST:[A-Za-z0-9_]+]]: memref<6xf32>
func.func private @negative_reshape_rank_change(%src : memref<2x3xf32>,
    %dst : memref<6xf32>) {
  // CHECK: %[[RC:[A-Za-z0-9_]+]] = memref.reinterpret_cast %[[DST]]
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [0], sizes: [2, 3], strides: [3, 1]
      : memref<6xf32> to memref<2x3xf32>
  // CHECK: memref.copy %[[SRC]], %[[RC]]
  // CHECK-NOT: memref.load
  // CHECK-NOT: memref.store
  memref.copy %src, %reinterpret_cast
      : memref<2x3xf32> to memref<2x3xf32>
  return
}
/// The base memref has two non-unit dimensions (2x108). The cast and the copy
/// must both survive, and no load/store may be introduced.
///
/// SSA names are captured from the signature and the cast result instead of
/// being spelled out, so the test does not depend on the printer's naming.
// CHECK-LABEL: func.func private @negative_concat_multiple_non_unit_dims(
// CHECK-SAME: %[[SRC:[A-Za-z0-9_]+]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:[A-Za-z0-9_]+]]: memref<2x108xf32>
func.func private @negative_concat_multiple_non_unit_dims(
    %src : memref<1x1xf32>, %dst : memref<2x108xf32>) {
  // CHECK: %[[RC:[A-Za-z0-9_]+]] = memref.reinterpret_cast %[[DST]]
  %reinterpret_cast = memref.reinterpret_cast %dst
      to offset: [0], sizes: [1, 1], strides: [1, 1]
      : memref<2x108xf32>
      to memref<1x1xf32>
  // CHECK: memref.copy %[[SRC]], %[[RC]]
  // CHECK-NOT: memref.load
  // CHECK-NOT: memref.store
  memref.copy %src, %reinterpret_cast
      : memref<1x1xf32> to memref<1x1xf32>
  return
}
/// A copy whose destination is a function argument rather than a
/// reinterpret_cast result: it must be left as a copy, with no load/store
/// introduced.
///
/// Both arguments share the same type, so the captures use a restrictive name
/// pattern; a greedy wildcard for the first could span both arguments.
// CHECK-LABEL: func.func private @negative_plain_copy(
// CHECK-SAME: %[[SRC:[A-Za-z0-9_]+]]: memref<1x1xf32>
// CHECK-SAME: %[[DST:[A-Za-z0-9_]+]]: memref<1x1xf32>
func.func private @negative_plain_copy(%src : memref<1x1xf32>,
    %dst : memref<1x1xf32>) {
  // CHECK: memref.copy %[[SRC]], %[[DST]]
  // CHECK-NOT: memref.load
  // CHECK-NOT: memref.store
  memref.copy %src, %dst
      : memref<1x1xf32> to memref<1x1xf32>
  return
}