| // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s |
| |
| // Checks that NOP casts are removed. |
| // Two of the three casts below have identical operand and result types and are |
| // expected to fold away, leaving only the tensor<*xi32> -> tensor<2xi32> cast. |
| // The function argument is captured instead of hard-coding its printed name. |
| // CHECK-LABEL: cast_values |
| // CHECK-SAME: %[[ARG:.*]]: tensor<*xi32> |
| func @cast_values(%arg0: tensor<*xi32>) -> tensor<2xi32> { |
| // NOP cast |
| %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32> |
| // CHECK-NEXT: %[[RET:.*]] = tensor.cast %[[ARG]] : tensor<*xi32> to tensor<2xi32> |
| %2 = tensor.cast %0 : tensor<*xi32> to tensor<2xi32> |
| // NOP cast |
| %4 = tensor.cast %2 : tensor<2xi32> to tensor<2xi32> |
| // CHECK-NEXT: return %[[RET]] : tensor<2xi32> |
| return %4 : tensor<2xi32> |
| } |
| |
| // ----- |
| |
| // A cast chain *xi32 -> 4x?xi32 -> 4x8xi32 is expected to collapse into a single |
| // cast from the input directly to tensor<4x8xi32>. |
| // CHECK-LABEL: @tensor.cast_chain_ok |
| // CHECK-SAME: %[[IN:.*]]: tensor<*xi32> |
| func @tensor.cast_chain_ok(%input: tensor<*xi32>) -> tensor<4x8xi32> { |
| // CHECK-NEXT: %[[RES:.*]] = tensor.cast %[[IN]] : tensor<*xi32> to tensor<4x8xi32> |
| %0 = tensor.cast %input : tensor<*xi32> to tensor<4x?xi32> |
| %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<4x8xi32> |
| // CHECK-NEXT: return %[[RES]] |
| return %1 : tensor<4x8xi32> |
| } |
| |
| // ----- |
| |
| // Casting 4xi32 to ?xi32 and back to 4xi32 is expected to fold away entirely, |
| // so the function returns its argument unchanged. |
| // CHECK-LABEL: @tensor.cast_chain_regain |
| // CHECK-SAME: %[[IN:.*]]: tensor<4xi32> |
| func @tensor.cast_chain_regain(%input: tensor<4xi32>) -> tensor<4xi32> { |
| %0 = tensor.cast %input : tensor<4xi32> to tensor<?xi32> |
| %1 = tensor.cast %0 : tensor<?xi32> to tensor<4xi32> |
| // CHECK-NEXT: return %[[IN]] |
| return %1 : tensor<4xi32> |
| } |
| |
| // ----- |
| |
| // Both casts are expected to survive canonicalization here. |
| // NOTE(review): presumably because the intermediate type pins dimension 0 to 4, |
| // information the final type tensor<?x8xi32> no longer carries -- confirm against |
| // the cast-chain folding rule. |
| // CHECK-LABEL: @tensor.cast_chain_keep |
| // CHECK-SAME: %[[IN:.*]]: tensor<?x?xi32> |
| func @tensor.cast_chain_keep(%input: tensor<?x?xi32>) -> tensor<?x8xi32> { |
| // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]] |
| %0 = tensor.cast %input : tensor<?x?xi32> to tensor<4x?xi32> |
| // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]] |
| %1 = tensor.cast %0 : tensor<4x?xi32> to tensor<?x8xi32> |
| // CHECK-NEXT: return %[[C2]] |
| return %1 : tensor<?x8xi32> |
| } |
| |
| // ----- |
| |
| // The source shape 4x8 and the final shape 8x4 disagree; both casts are expected |
| // to be kept rather than merged into one direct cast. |
| // NOTE(review): a direct 4x8 -> 8x4 cast would presumably not verify -- confirm. |
| // CHECK-LABEL: @tensor.cast_chain_invalid |
| // CHECK-SAME: %[[IN:.*]]: tensor<4x8xi32> |
| func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> { |
| // CHECK-NEXT: %[[C1:.*]] = tensor.cast %[[IN]] |
| %0 = tensor.cast %input : tensor<4x8xi32> to tensor<?x?xi32> |
| // CHECK-NEXT: %[[C2:.*]] = tensor.cast %[[C1]] |
| %1 = tensor.cast %0 : tensor<?x?xi32> to tensor<8x4xi32> |
| // CHECK-NEXT: return %[[C2]] |
| return %1 : tensor<8x4xi32> |
| } |
| |
| // ----- |
| |
| // tensor.extract from constant tensors (splat, sparse, dense) is expected to fold |
| // to scalar constants, so the function ends up returning four constants. |
| // CHECK-LABEL: func @fold_extract |
| func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32) { |
| %const_0 = arith.constant 0 : index |
| %const_1 = arith.constant 1 : index |
| %const_3 = arith.constant 3 : index |
| // The folded constants may be printed in any order, hence the DAG-style matching. |
| // CHECK-DAG: [[C64:%.+]] = arith.constant 64 : i32 |
| // CHECK-DAG: [[C0:%.+]] = arith.constant 0.{{0*}}e+00 : f16 |
| // CHECK-DAG: [[CM2:%.+]] = arith.constant -2.{{0*}}e+00 : f16 |
| |
| // Fold an extract into a splat. |
| // The index is dynamic (%arg0); every element of the splat is 4.0. |
| // CHECK-DAG: [[C4:%.+]] = arith.constant 4.{{0*}}e+00 : f32 |
| %0 = arith.constant dense<4.0> : tensor<4xf32> |
| %ext_1 = tensor.extract %0[%arg0] : tensor<4xf32> |
| |
| // Fold an extract into a sparse with a sparse index. |
| // Index [1, 1, 1] is listed explicitly, with value -2.0. |
| %1 = arith.constant sparse<[[0, 0, 0], [1, 1, 1]], [-5.0, -2.0]> : tensor<4x4x4xf16> |
| %ext_2 = tensor.extract %1[%const_1, %const_1, %const_1] : tensor<4x4x4xf16> |
| |
| // Fold an extract into a sparse with a non sparse index. |
| // Index [0, 0, 0] is not listed, so the expected result is 0.0. |
| %2 = arith.constant sparse<[[1, 1, 1]], [-2.0]> : tensor<2x2x2xf16> |
| %ext_3 = tensor.extract %2[%const_0, %const_0, %const_0] : tensor<2x2x2xf16> |
| |
| // Fold an extract into a dense tensor. |
| // Index [1, 0, 3] selects the value 64. |
| %3 = arith.constant dense<[[[1, -2, 1, 36]], [[0, 2, -1, 64]]]> : tensor<2x1x4xi32> |
| %ext_4 = tensor.extract %3[%const_1, %const_0, %const_3] : tensor<2x1x4xi32> |
| |
| // CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]] |
| return %ext_1, %ext_2, %ext_3, %ext_4 : f32, f16, f16, i32 |
| } |
| |
| // ----- |
| |
| // Inserting the scalar 4.0 into a splat of 4.0 at a dynamic index is expected to |
| // fold back to the splat constant itself. |
| // CHECK-LABEL: func @fold_insert |
| func @fold_insert(%arg0 : index) -> (tensor<4xf32>) { |
| // Fold an insert into a splat. |
| // CHECK-DAG: %[[C4:.+]] = arith.constant dense<4.{{0*}}e+00> : tensor<4xf32> |
| %0 = arith.constant dense<4.0> : tensor<4xf32> |
| %1 = arith.constant 4.0 : f32 |
| %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32> |
| // CHECK-NEXT: return %[[C4]] |
| return %ins_1 : tensor<4xf32> |
| } |
| |
| // ----- |
| |
| // A tensor.extract reading through a tensor.cast is expected to read directly |
| // from the cast's operand, which removes the cast. |
| // CHECK-LABEL: func @extract_from_tensor.cast |
| // CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32> |
| func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 { |
| // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index |
| %c0 = arith.constant 0 : index |
| // CHECK-NOT: tensor.cast |
| %casted = tensor.cast %tensor : tensor<*xf32> to tensor<?xf32> |
| // CHECK-NEXT: tensor.extract %[[TENSOR]][%[[C0]]] |
| %result = tensor.extract %casted[%c0] : tensor<?xf32> |
| return %result : f32 |
| } |
| |
| // ----- |
| |
| // Extracting element 0 of a one-element tensor.from_elements is expected to fold |
| // to the element operand itself. |
| // CHECK-LABEL: func @extract_from_tensor.from_elements |
| func @extract_from_tensor.from_elements(%element : index) -> index { |
| // CHECK-SAME: ([[ARG:%.*]]: index) |
| %c0 = arith.constant 0 : index |
| %tensor = tensor.from_elements %element : tensor<1xindex> |
| %extracted_element = tensor.extract %tensor[%c0] : tensor<1xindex> |
| // CHECK: [[ARG]] : index |
| return %extracted_element : index |
| } |
| |
| // ----- |
| |
| // Ensure the optimization doesn't segfault from bad constants |
| // Index -1 is out of range for tensor<1xindex>; the from_elements and extract |
| // ops are expected to be left in place instead of being folded. |
| // CHECK-LABEL: func @extract_negative_from_tensor.from_elements |
| func @extract_negative_from_tensor.from_elements(%element : index) -> index { |
| // CHECK-SAME: ([[ARG:%.*]]: index) |
| %c-1 = arith.constant -1 : index |
| %tensor = tensor.from_elements %element : tensor<1xindex> |
| %extracted_element = tensor.extract %tensor[%c-1] : tensor<1xindex> |
| // CHECK: tensor.from_elements |
| // CHECK: %[[RESULT:.*]] = tensor.extract |
| // CHECK: return %[[RESULT]] |
| return %extracted_element : index |
| } |
| |
| // ----- |
| |
| // Ensure the optimization doesn't segfault from bad constants |
| // Index 1 is exactly one past the end of tensor<1xindex>; the from_elements and |
| // extract ops are expected to be left in place instead of being folded. |
| // CHECK-LABEL: func @extract_oob_from_tensor.from_elements |
| func @extract_oob_from_tensor.from_elements(%element : index) -> index { |
| // CHECK-SAME: ([[ARG:%.*]]: index) |
| %c1 = arith.constant 1 : index |
| %tensor = tensor.from_elements %element : tensor<1xindex> |
| %extracted_element = tensor.extract %tensor[%c1] : tensor<1xindex> |
| // CHECK: tensor.from_elements |
| // CHECK: %[[RESULT:.*]] = tensor.extract |
| // CHECK: return %[[RESULT]] |
| return %extracted_element : index |
| } |
| |
| // ----- |
| |
| // Ensure the optimization doesn't segfault from bad constants |
| // Index 2 is well past the end of tensor<1xindex>; the from_elements and extract |
| // ops are expected to be left in place instead of being folded. |
| // NOTE(review): this function shares its name with another out-of-bounds test in |
| // this file that uses index 1. That only works because the run line splits the |
| // input on the dashed separators; a distinct name would be clearer. |
| // CHECK-LABEL: func @extract_oob_from_tensor.from_elements |
| func @extract_oob_from_tensor.from_elements(%element : index) -> index { |
| // CHECK-SAME: ([[ARG:%.*]]: index) |
| %c2 = arith.constant 2 : index |
| %tensor = tensor.from_elements %element : tensor<1xindex> |
| %extracted_element = tensor.extract %tensor[%c2] : tensor<1xindex> |
| // CHECK: tensor.from_elements |
| // CHECK: %[[RESULT:.*]] = tensor.extract |
| // CHECK: return %[[RESULT]] |
| return %extracted_element : index |
| } |
| |
| // ----- |
| |
| // Extracting at %idx from a tensor.generate is expected to inline the generator |
| // body with %idx substituted for the block argument, leaving one tensor.dim. |
| // CHECK-LABEL: func @extract_from_tensor.generate |
| // CHECK-SAME: %[[IDX:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> |
| func @extract_from_tensor.generate(%idx: index, %tensor: tensor<*xf32>) -> index { |
| %size = rank %tensor : tensor<*xf32> |
| // CHECK-NEXT: %[[RES:.*]] = tensor.dim %[[TENSOR]], %[[IDX]] |
| %0 = tensor.generate %size { |
| ^bb0(%arg0: index): |
| %1 = tensor.dim %tensor, %arg0 : tensor<*xf32> |
| tensor.yield %1 : index |
| } : tensor<?xindex> |
| %1 = tensor.extract %0[%idx] : tensor<?xindex> |
| // CHECK-NEXT: return %[[RES]] |
| return %1 : index |
| } |
| |
| // ----- |
| |
| // Two-dimensional variant: the generator body is expected to be inlined with |
| // %idx0 and %idx1 substituted for the two block arguments. |
| // CHECK-LABEL: func @extract_from_tensor.generate_2d |
| // CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index, %[[TENSOR:.*]]: tensor<*xf32> |
| func @extract_from_tensor.generate_2d(%idx0: index, %idx1: index, %tensor: tensor<*xf32>) -> index { |
| %size = rank %tensor : tensor<*xf32> |
| // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[TENSOR]], %[[IDX0]] |
| // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[TENSOR]], %[[IDX1]] |
| // CHECK-NEXT: %[[RES:.*]] = arith.addi %[[DIM0]], %[[DIM1]] |
| %0 = tensor.generate %size, %size { |
| ^bb0(%arg0: index, %arg1: index): |
| %1 = tensor.dim %tensor, %arg0 : tensor<*xf32> |
| %2 = tensor.dim %tensor, %arg1 : tensor<*xf32> |
| %3 = arith.addi %1, %2 : index |
| tensor.yield %3 : index |
| } : tensor<?x?xindex> |
| %4 = tensor.extract %0[%idx0, %idx1] : tensor<?x?xindex> |
| // CHECK-NEXT: return %[[RES]] |
| return %4 : index |
| } |
| |
| // ----- |
| |
| // The generator body contains a memref.store, so the tensor.generate and the |
| // tensor.extract reading from it are both expected to be kept. |
| // CHECK-LABEL: func @extract_from_tensor.generate_sideeffects |
| // CHECK-SAME: %[[IDX:.*]]: index |
| func @extract_from_tensor.generate_sideeffects(%idx: index, %tensor: tensor<*xf32>, %mem: memref<?xindex>) -> index { |
| %size = rank %tensor : tensor<*xf32> |
| // CHECK: %[[DTENSOR:.*]] = tensor.generate |
| %0 = tensor.generate %size { |
| ^bb0(%arg0: index): |
| %1 = tensor.dim %tensor, %arg0 : tensor<*xf32> |
| memref.store %1, %mem[%arg0] : memref<?xindex> |
| tensor.yield %1 : index |
| } : tensor<?xindex> |
| // CHECK: %[[RES:.*]] = tensor.extract %[[DTENSOR]][%[[IDX]]] |
| %1 = tensor.extract %0[%idx] : tensor<?xindex> |
| // CHECK-NEXT: return %[[RES]] |
| return %1 : index |
| } |
| |
| // ----- |
| |
| // The constant extent %c5 is expected to move into the result type |
| // (tensor<3x?x5x7x?xindex>) and drop out of the dynamic operand list; a |
| // tensor.cast then restores the original, less static type for the return. |
| // CHECK-LABEL: @static_tensor.generate |
| // CHECK-SAME: %[[SIZE1:.*]]: index, %[[SIZE4:.*]]: index) |
| func @static_tensor.generate(%size1: index, %size4: index) -> tensor<3x?x?x7x?xindex> { |
| %c5 = arith.constant 5 : index |
| // CHECK: tensor.generate %[[SIZE1]], %[[SIZE4]] |
| %0 = tensor.generate %size1, %c5, %size4 { |
| ^bb0(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index): |
| %1 = arith.constant 32 : index |
| tensor.yield %1 : index |
| // CHECK: : tensor<3x?x5x7x?xindex> |
| } : tensor<3x?x?x7x?xindex> |
| // CHECK: tensor.cast %{{.*}} : tensor<3x?x5x7x?xindex> to tensor<3x?x?x7x?xindex> |
| return %0 : tensor<3x?x?x7x?xindex> |
| } |
| |
| // ----- |
| |
| // A tensor.from_elements whose operands are all constants is expected to fold |
| // into a single dense constant. |
| // CHECK-LABEL: @from_elements.constant |
| func @from_elements.constant() -> tensor<3xindex> { |
| // CHECK: %[[CST:.*]] = arith.constant dense<[1, 2, 1]> : tensor<3xindex> |
| // CHECK: return %[[CST]] |
| %c1 = arith.constant 1 : index |
| %c2 = arith.constant 2 : index |
| %tensor = tensor.from_elements %c1, %c2, %c1 : tensor<3xindex> |
| return %tensor : tensor<3xindex> |
| } |
| |
| // ----- |
| |
| // Constant offset/size/stride operands of extract_slice are expected to become |
| // static, refining the result type to tensor<4x1x?xf32>; a tensor.cast restores |
| // the original fully dynamic type for the return. |
| func @slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index, |
| %arg2 : index) -> tensor<?x?x?xf32> |
| { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c4 = arith.constant 4 : index |
| %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32> |
| return %0 : tensor<?x?x?xf32> |
| } |
| // CHECK-LABEL: func @slice_canonicalize |
| // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> |
| // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1] |
| // CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32> |
| // CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]] |
| // CHECK: return %[[RESULT]] |
| |
| // ----- |
| |
| // Rank-reducing variant: constant operands are expected to become static and the |
| // result type to be refined to tensor<4x?xf32>; a tensor.cast restores the |
| // original tensor<?x?xf32> type for the return. |
| func @rank_reducing_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index, |
| %arg2 : index) -> tensor<?x?xf32> |
| { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c4 = arith.constant 4 : index |
| %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32> |
| return %0 : tensor<?x?xf32> |
| } |
| // CHECK-LABEL: func @rank_reducing_slice_canonicalize |
| // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> |
| // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1] |
| // CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32> |
| // CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]] |
| // CHECK: return %[[RESULT]] |
| |
| // ----- |
| |
| // An extract_slice covering the whole statically shaped tensor (zero offsets, |
| // full sizes, unit strides) is expected to fold to its source. |
| // CHECK-LABEL: func @trivial_slice |
| // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8> |
| // CHECK-NOT: tensor.extract_slice |
| // CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8> |
| func @trivial_slice(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> { |
| %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8> |
| return %0 : tensor<4x6x16x32xi8> |
| } |
| |
| // ----- |
| |
| // An insert_slice that overwrites the whole statically shaped destination (zero |
| // offsets, full sizes, unit strides) is expected to fold to its source %arg0. |
| // The negative match names the op under test, tensor.insert_slice. |
| // CHECK-LABEL: func @trivial_insert_slice |
| // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8> |
| // CHECK-NOT: tensor.insert_slice |
| // CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8> |
| func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> { |
| %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8> |
| return %0 : tensor<4x6x16x32xi8> |
| } |
| |
| // ----- |
| |
| // A rank-reducing extract_slice of a tensor.cast result is expected to slice the |
| // cast's operand directly. The captured argument is used in the expectation |
| // instead of a hard-coded printed name. |
| // CHECK-LABEL: func @rank_reducing_tensor_of_cast |
| // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8> |
| // CHECK: %[[S:.+]] = tensor.extract_slice %[[ARG0]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8> |
| // Tensor cast is moved after slice and then gets canonicalized away. |
| // CHECK-NOT: tensor.cast |
| // CHECK: return %[[S]] : tensor<16x32xi8> |
| func @rank_reducing_tensor_of_cast(%arg : tensor<4x6x16x32xi8>) -> tensor<16x32xi8> { |
| %0 = tensor.cast %arg : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8> |
| %1 = tensor.extract_slice %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8> |
| return %1 : tensor<16x32xi8> |
| } |
| |
| // ----- |
| |
| // The tensor.cast feeding the insert_slice source is expected to fold away, so |
| // the insert_slice consumes the original tensor<16x32xi8> value directly. |
| // CHECK-LABEL: func @rank_reducing_insert_slice_of_cast |
| // CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8> |
| // CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8> |
| // CHECK: %[[S:.+]] = tensor.insert_slice %[[A]] into %[[B]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8> |
| // Tensor cast is folded away. |
| // CHECK-NOT: tensor.cast |
| // CHECK: return %[[S]] : tensor<4x6x16x32xi8> |
| func @rank_reducing_insert_slice_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> { |
| %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8> |
| %res = tensor.insert_slice %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8> |
| return %res : tensor<4x6x16x32xi8> |
| } |
| |
| // ----- |
| |
| // Constant offset/size/stride operands of insert_slice are expected to become |
| // static, with the source cast to the refined type tensor<4x1x?xf32>. |
| func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index, |
| %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> |
| { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c4 = arith.constant 4 : index |
| %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32> |
| return %0 : tensor<?x?x?xf32> |
| } |
| // CHECK-LABEL: func @insert_slice_canonicalize |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> |
| // CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32> |
| // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]] |
| // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32> |
| // CHECK: return %[[RESULT]] |
| |
| // ----- |
| |
| // Round trip: a slice extracted from %arg0 is inserted into %arg3 with the same |
| // offsets, sizes and strides. Constant operands are expected to become static on |
| // both ops, and the insert_slice is expected to consume the refined |
| // tensor<4x1x?xf32> slice directly. |
| func @slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index, |
| %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> |
| { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c4 = arith.constant 4 : index |
| %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?x?xf32> |
| %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> into tensor<?x?x?xf32> |
| return %1 : tensor<?x?x?xf32> |
| } |
| // CHECK-LABEL: func @slice_to_insert_slice_canonicalize |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> |
| // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> |
| // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]] |
| // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x1x?xf32> |
| // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG3]] |
| // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32> |
| // CHECK: return %[[RESULT]] |
| |
| // ----- |
| |
| // Rank-reducing insert: constant offset/size/stride operands are expected to |
| // become static while the source keeps its tensor<?x?xf32> type. |
| func @rank_reducing_insert_slice_canonicalize(%arg0 : tensor<?x?xf32>, %arg1 : index, |
| %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> |
| { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c4 = arith.constant 4 : index |
| %0 = tensor.insert_slice %arg0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32> |
| return %0 : tensor<?x?x?xf32> |
| } |
| // CHECK-LABEL: func @rank_reducing_insert_slice_canonicalize |
| // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32> |
| // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] |
| // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?x?xf32> |
| // CHECK: return %[[RESULT]] |
| |
| // ----- |
| |
| // Rank-reducing round trip: a tensor<?x?xf32> slice extracted from %arg0 is |
| // inserted into %arg3 with the same offsets, sizes and strides. The slice type is |
| // expected to be refined to tensor<4x?xf32> and consumed directly by insert_slice. |
| func @rank_reducing_slice_to_insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index, |
| %arg2 : index, %arg3 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> |
| { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c4 = arith.constant 4 : index |
| %0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32> to tensor<?x?xf32> |
| %1 = tensor.insert_slice %0 into %arg3[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : tensor<?x?xf32> into tensor<?x?x?xf32> |
| return %1 : tensor<?x?x?xf32> |
| } |
| // CHECK-LABEL: func @rank_reducing_slice_to_insert_slice_canonicalize |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> |
| // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32> |
| // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]] |
| // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<?x?x?xf32> to tensor<4x?xf32> |
| // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[SLICE]] into %[[ARG3]] |
| // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1] |
| // CHECK-SAME: : tensor<4x?xf32> into tensor<?x?x?xf32> |
| // CHECK: return %[[RESULT]] |
| |
| // ----- |
| |
| // The constant extent 8 of the tensor.generate is expected to refine the |
| // destination type to tensor<?x8xi32>; the insert_slice then operates on that |
| // type and a tensor.cast produces the value that is returned. |
| func @insert_slice_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>, |
| %arg2 : index, %arg3 : index) -> tensor<?x?xi32> { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c2 = arith.constant 2 : index |
| %c8 = arith.constant 8 : index |
| %0 = tensor.dim %arg0, %c1 : tensor<2x?xi32> |
| %1 = tensor.extract %arg1[] : tensor<i32> |
| %2 = tensor.generate %arg2, %c8 { |
| ^bb0(%arg4: index, %arg5: index): |
| tensor.yield %1 : i32 |
| } : tensor<?x?xi32> |
| %3 = tensor.insert_slice %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32> |
| return %3 : tensor<?x?xi32> |
| } |
| // CHECK-LABEL: func @insert_slice_propagate_dest_cast |
| // CHECK: %[[UPDATED:.+]] = tensor.insert_slice %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1] |
| // CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32> |
| // CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]] |
| // CHECK: return %[[CAST]] |
| |
| // ----- |
| |
| // The trailing tensor.cast to tensor<3x9xi32> is expected to be absorbed: the |
| // function returns the insert_slice result directly, with the insert_slice |
| // writing into the tensor.generate result. |
| func @insert_slice_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %c2 = arith.constant 2 : index |
| %c9 = arith.constant 9 : index |
| %c3 = arith.constant 3 : index |
| %2 = tensor.extract %arg1[] : tensor<i32> |
| %4 = tensor.generate %c3, %c9 { |
| ^bb0(%arg2: index, %arg3: index): |
| tensor.yield %2 : i32 |
| } : tensor<?x?xi32> |
| %5 = tensor.insert_slice %arg0 into %4[%c0, %c1] [%c2, %c3] [1, 1] : tensor<2x3xi32> into tensor<?x?xi32> |
| %6 = tensor.cast %5 : tensor<?x?xi32> to tensor<3x9xi32> |
| return %6 : tensor<3x9xi32> |
| } |
| // CHECK-LABEL: func @insert_slice_output_dest_canonicalize |
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x3xi32> |
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<i32> |
| // CHECK: %[[PAD:.+]] = tensor.extract %[[ARG1]] |
| // CHECK: %[[GENERATE:.+]] = tensor.generate |
| // CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]] into %[[GENERATE]] |
| // CHECK: return %[[RESULT]] |
| |
| // ----- |
| |
| // Test case: Folding of tensor.dim(tensor.generate %idx) -> %idx |
| // Dimension 3 of tensor<2x?x4x?x5xindex> is the second dynamic extent, which is |
| // supplied by the second generate operand (%arg1). |
| // CHECK-LABEL: func @dim_of_tensor.generate( |
| // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index |
| // CHECK-NOT: tensor.dim |
| // CHECK: return %[[IDX1]] : index |
| func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index { |
| %c3 = arith.constant 3 : index |
| %0 = tensor.generate %arg0, %arg1 { |
| ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index): |
| tensor.yield %c3 : index |
| } : tensor<2x?x4x?x5xindex> |
| %1 = tensor.dim %0, %c3 : tensor<2x?x4x?x5xindex> |
| return %1 : index |
| } |
| |
| // ----- |
| |
| // Test case: Folding tensor.dim(tensor.cast %0, %idx) -> tensor.dim %0, %idx |
| // Dimension 0 is static in the cast's operand type, so it is expected to become |
| // the constant 4; dimension 1 stays a tensor.dim on the original argument. |
| // CHECK-LABEL: func @fold_dim_of_tensor.cast |
| // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32> |
| // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index |
| // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index |
| // CHECK: %[[T0:.+]] = tensor.dim %[[ARG0]], %[[C1]] |
| // CHECK-NEXT: return %[[C4]], %[[T0]] |
| func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) { |
| %c0 = arith.constant 0 : index |
| %c1 = arith.constant 1 : index |
| %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32> |
| %1 = tensor.dim %0, %c0 : tensor<?x?xf32> |
| %2 = tensor.dim %0, %c1 : tensor<?x?xf32> |
| return %1, %2: index, index |
| } |
| |
| // ----- |
| |
| // The static sizes [64, 5, 64] are more precise than the source type |
| // tensor<?x5x?xf32>; a tensor.cast to tensor<64x5x64xf32> is expected to be |
| // inserted on the insert_slice source. |
| // CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src( |
| // CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32> |
| // CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32> |
| // CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32> |
| // CHECK: return %[[r]] |
| func @insert_tensor_cast_on_insert_slice_src( |
| %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> { |
| %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1] |
| : tensor<?x5x?xf32> into tensor<?x?x?xf32> |
| return %r : tensor<?x?x?xf32> |
| } |
| |
| // ----- |
| |
| // An extract_slice reading back exactly the region just written by an |
| // insert_slice (same offsets, sizes and strides) is expected to fold to the |
| // inserted slice value. |
| // CHECK-LABEL: func @fold_extract_insert |
| // CHECK-SAME: %{{.+}}: tensor<?x?x?xf32>, %[[SLICE:.+]]: tensor<4x?x8xf32> |
| func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<4x?x8xf32>) { |
| %c0 = arith.constant 0: index |
| %c1 = arith.constant 1: index |
| %0 = tensor.insert_slice %slice into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32> |
| %1 = tensor.extract_slice %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<?x?x?xf32> to tensor<4x?x8xf32> |
| // CHECK: return %[[SLICE]] |
| return %1 : tensor<4x?x8xf32> |
| } |
| |
| // ----- |
| |
| // Two insert_slice ops write the same region one after the other; the first is |
| // expected to be bypassed so that the second inserts directly into %input. |
| // CHECK-LABEL: func @fold_overlapping_insert |
| // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32> |
| func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) { |
| %c0 = arith.constant 0: index |
| %c1 = arith.constant 1: index |
| %0 = tensor.insert_slice %slice1 into %input[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32> |
| // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[SLICE2]] into %[[INPUT]] |
| %1 = tensor.insert_slice %slice2 into %0[%c0, %i, 0] [4, %size, 8] [1, 1, %c1] : tensor<4x?x8xf32> into tensor<?x?x?xf32> |
| // CHECK: return %[[INSERT]] |
| return %1 : tensor<?x?x?xf32> |
| } |
| |
| // ----- |
| |
| // Only the function label is matched; no output ops are verified. |
| // The insert_slice below uses static sizes 12345x67890 with a tensor<4x4xf32> |
| // source. |
| // NOTE(review): judging by the name, this guards against the canonicalizer |
| // looping forever on such mismatched IR -- confirm against the originating bug. |
| // CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop |
| func @folding_incorrect_ir_triggers_infinite_loop( |
| %A : tensor<4x4xf32>, %C : tensor<?x?xf32>) -> tensor<?x?xf32> { |
| %rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] : |
| tensor<4x4xf32> into tensor<?x?xf32> |
| return %rC: tensor<?x?xf32> |
| } |