blob: 3669cae87408df76c3f46fe30c55f20aed6d9c41 [file] [log] [blame]
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH
func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
%slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32>
return %slice : tensor<20x11xf32>
}
// CHECK: func.func @extract_slice_static(%[[arg0:.+]]:
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
// CHECK-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: return %[[tile]]
// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]:
// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index
// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index
// FOREACH-DAG: %[[init:.+]] = tensor.empty() : tensor<20x11xf32>
// FOREACH: %[[tile:.+]] = scf.forall (%[[iv:.+]]) in (20) shared_outs(%[[dest:.+]] = %[[init]])
// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]]
// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// FOREACH: in_parallel
// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[iv]], 0] [1, 11] [1, 1] :
// FOREACH: return %[[tile]]
// -----
func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
%slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32>
return %slice : tensor<10x5xf32>
}
// CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 13)>
// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]:
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
// CHECK: %[[init:.+]] = tensor.empty() : tensor<10x5xf32>
// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
// CHECK: %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]])
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]]
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: return %[[tile]]
// -----
func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor<?x5xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor<?x11xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor<?x11xf32> to tensor<?x5xf32>
return %slice : tensor<?x5xf32>
}
// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index)
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK: %[[init:.+]] = tensor.empty(%[[sz]]) : tensor<?x5xf32>
// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32>
// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32>
// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
// CHECK: %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]]
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) :
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: return %[[tile]] :
// -----
func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor<?x?xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
%slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
return %slice : tensor<?x?xf32>
}
// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
// CHECK: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor<?x?xf32>
// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
// CHECK: %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]]
// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
// CHECK: %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]]
// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) :
// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] :
// CHECK: scf.yield %[[update]] :
// CHECK: scf.yield %[[tile2]] :
// CHECK: return %[[tile1]] :
// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index
// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index
// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index
// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index
// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index
// FOREACH: %[[init:.+]] = tensor.empty(%[[sz1]], %[[sz2]]) : tensor<?x?xf32>
// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
// FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
// FOREACH: %[[tile1:.+]] = scf.forall (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]])
// FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]]
// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) :
// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]]
// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) :
// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
// FOREACH: in_parallel
//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] :
// -----
// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this
// only works for static shapes.
// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>,
func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor<?x22xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor<?x22xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor<?x22xf32> to tensor<?x22xf32>
// CHECK: scf.for
// CHECK-NOT: scf.for
// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index
// CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
return %slice : tensor<?x22xf32>
}
// -----
// CHECK: @no_sliced_linearized_dims(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index, %size: index) -> tensor<330x?xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<30x11x100xf32> into tensor<330x100xf32>
%slice = tensor.extract_slice %collapsed [0, %offt] [330, %size] [1, 1] : tensor<330x100xf32> to tensor<330x?xf32>
// CHECK-NOT: scf.for
// CHECK: %[[init:.+]] = tensor.empty(%[[arg2]])
// CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, %[[arg1]]] [30, 11, %[[arg2]]] [1, 1, 1]
// CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1], [2]]
// CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]]
// CHECK: return %[[res]]
return %slice : tensor<330x?xf32>
}
// -----
// The below tests verify that a dimension which is the result of collapsing at
// most one non-unit dim is handled properly.
// CHECK: @collapse_and_slice_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @collapse_and_slice_unit_dim(%input: tensor<1x11x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<1x11x100xf32> into tensor<11x100xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<11x100xf32> to tensor<?x100xf32>
// CHECK-NOT: scf.for
// CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0] [1, 11, 100] [1, 1, 1]
// CHECK-SAME: tensor<1x11x100xf32> to tensor<11x100xf32>
// CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
// CHECK-SAME: tensor<11x100xf32> to tensor<?x100xf32>
return %slice : tensor<?x100xf32>
}
// CHECK: @collapse_and_slice_multiple_unit_dim_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @collapse_and_slice_multiple_unit_dim_dynamic(%input: tensor<1x?x1x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x?x1x100xf32> into tensor<?x100xf32>
%slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32>
// CHECK-NOT: scf.for
// CHECK: %[[c1:.+]] = arith.constant 1 : index
// CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]] :
// CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0] [1, %[[dim]], 1, 100] [1, 1, 1, 1]
// CHECK-SAME: tensor<1x?x1x100xf32> to tensor<?x100xf32>
// CHECK: %[[e1:.+]] = tensor.extract_slice %[[e]][%[[arg1]], 0] [%[[arg2]], 100] [1, 1]
// CHECK-SAME: tensor<?x100xf32> to tensor<?x100xf32>
return %slice : tensor<?x100xf32>
}
// CHECK: @collapse_and_slice_multiple_unit_dim_mixed(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index
func.func @collapse_and_slice_multiple_unit_dim_mixed(%input: tensor<1x?x1x100x10xf32>, %offt: index, %size: index) -> tensor<?x?xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<1x?x1x100x10xf32> into tensor<?x1000xf32>
%slice = tensor.extract_slice %collapsed [%offt, %offt] [%size, %size] [1, 1] : tensor<?x1000xf32> to tensor<?x?xf32>
// CHECK-DAG: %[[c0]] = arith.constant 0 : index
// CHECK-DAG: %[[c1]] = arith.constant 1 : index
// CHECK: %[[dim:.+]] = tensor.dim %[[arg0]], %[[c1]]
// CHECK: %[[rank_reduced:.+]] = tensor.extract_slice %[[arg0]][0, 0, 0, 0, 0] [1, %[[dim]], 1, 100, 10] [1, 1, 1, 1, 1]
// CHECK: %[[empty:.+]] = tensor.empty
// CHECK: %[[result:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[arg2]] step %[[c1]] iter_args(%[[ia:.+]] = %[[empty]])
// CHECK: %[[idx:.+]] = affine.apply
// CHECK: %[[multi_index:.+]] = affine.delinearize_index %[[idx]] into
// CHECK: %[[collapsed:.+]] = tensor.collapse_shape
// CHECK: %[[updated:.+]] = tensor.insert_slice
// CHECK: scf.yield %[[updated]]
// CHECK: return %[[result]]
return %slice : tensor<?x?xf32>
}
// Edge case where all collapsed dims are unit dims. This pattern can't eliminate the collapse shape,
// that should be handled by `linalg-fold-unit-extent-dims`.
// CHECK: @collapse_and_slice_multiple_all_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>)
func.func @collapse_and_slice_multiple_all_unit_dim(%input: tensor<1x1x1x100xf32>) -> tensor<1x100xf32> {
%collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
%slice = tensor.extract_slice %collapsed [0, 0] [1, 100] [1, 1] : tensor<1x100xf32> to tensor<1x100xf32>
return %slice : tensor<1x100xf32>
// CHECK: %[[collapse:.+]] = tensor.collapse_shape %[[arg0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32>
// CHECK: return %[[collapse]]
}