blob: 81bf8e3f60e2c11afd38cca098924ce96516c86e [file] [edit]
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-expand-shape-bubbling %s | FileCheck %s
// The collapse folds source dims 1 and 2 together, while the expand splits the
// last dim of the collapsed tensor; the two reassociations touch different
// dimensions of the intermediate value.
func.func @bubble_parallel_reshapes(
    %arg0: tensor<?x?x?x?xf32>,
    %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  // 4-D -> 3-D: dims 1 and 2 are merged.
  %merged = tensor.collapse_shape %arg0 [[0], [1, 2], [3]]
      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  // 3-D -> 4-D: the trailing dim is split in two; sizes come from %s0..%s3.
  %split = tensor.expand_shape %merged [[0], [1], [2, 3]] output_shape [%s0, %s1, %s2, %s3]
      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %split : tensor<?x?x?x?xf32>
}
// CHECK: func @bubble_parallel_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: return %[[COLLAPSE]]
// -----
// The collapse and the expand use the very same reassociation, so the group
// [1, 2] that is merged by the first op is the one split again by the second.
func.func @no_bubble_full_intersecting_reshapes(
    %arg0: tensor<?x?x?x?xf32>,
    %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  // 4-D -> 3-D: dims 1 and 2 are merged.
  %merged = tensor.collapse_shape %arg0 [[0], [1, 2], [3]]
      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  // 3-D -> 4-D: the merged middle dim is split again; sizes come from %s0..%s3.
  %split = tensor.expand_shape %merged [[0], [1, 2], [3]] output_shape [%s0, %s1, %s2, %s3]
      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %split : tensor<?x?x?x?xf32>
}
// CHECK: func @no_bubble_full_intersecting_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
// CHECK: return %[[EXPAND]]
// -----
// The collapse merges source dims 0..2 into one dim, and the expand then splits
// that same merged dim (into two) as well as the remaining dim.
func.func @no_bubble_partial_intersecting_reshapes(
    %arg0: tensor<?x?x?x?xf32>,
    %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  // 4-D -> 2-D: dims 0, 1 and 2 are merged.
  %merged = tensor.collapse_shape %arg0 [[0, 1, 2], [3]]
      : tensor<?x?x?x?xf32> into tensor<?x?xf32>
  // 2-D -> 4-D: each of the two dims is split in two; sizes come from %s0..%s3.
  %split = tensor.expand_shape %merged [[0, 1], [2, 3]] output_shape [%s0, %s1, %s2, %s3]
      : tensor<?x?xf32> into tensor<?x?x?x?xf32>
  return %split : tensor<?x?x?x?xf32>
}
// CHECK: func @no_bubble_partial_intersecting_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
// CHECK: return %[[EXPAND]]
// -----
// Both reshapes go through a rank-0 tensor, so both reassociation lists are
// empty.
func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
  // 1x1 -> scalar tensor.
  %scalar = tensor.collapse_shape %arg0 []
      : tensor<1x1xf32> into tensor<f32>
  // scalar tensor -> 1x1x1.
  %unit3d = tensor.expand_shape %scalar [] output_shape [1, 1, 1]
      : tensor<f32> into tensor<1x1x1xf32>
  return %unit3d : tensor<1x1x1xf32>
}
// CHECK: func @no_bubble_0d_tensor_reshapes
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1xf32>
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
// CHECK: return %[[EXPAND]]
// -----
// Test the case where the reassociation indices in the collapse and expand
// are of the same size.
func.func @bubble_expand_match_non_unit_size_reassocation(
    %arg0: tensor<4x?x4x32x4x?xf16>,
    %arg1: index,
    %arg2: index) -> tensor<4x?x4x128x?x32xf16> {
  // Dims 0..2 are merged into one dynamic dim and dims 3..4 (32x4) into 128.
  %merged = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] :
      tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
  // The first merged dim is split back with the same three-element group,
  // while the last dim is split into ?x32.
  %split = tensor.expand_shape %merged [[0, 1, 2], [3], [4, 5]]
      output_shape [4, %arg1, 4, 128, %arg2, 32] :
      tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
  return %split : tensor<4x?x4x128x?x32xf16>
}
// CHECK: func @bubble_expand_match_non_unit_size_reassocation
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME: {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
// CHECK-SAME: [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
// CHECK-SAME: {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
// CHECK: return %[[COLLAPSED]]
// -----
// Test the case where the trailing collapse isn't needed.
func.func @no_collapse_generated(
    %arg0: tensor<4x?x4x128x?xf16>,
    %arg1: index,
    %arg2: index) -> tensor<4x?x4x128x?x32xf16> {
  // Only dims 0..2 are merged; dims 3 and 4 pass through unchanged.
  %merged = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]] :
      tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
  // The merged dim is split back with the same three-element group and the
  // last dim is split into ?x32.
  %split = tensor.expand_shape %merged [[0, 1, 2], [3], [4, 5]]
      output_shape [4, %arg1, 4, 128, %arg2, 32] :
      tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
  return %split : tensor<4x?x4x128x?x32xf16>
}
// CHECK: func @no_collapse_generated
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape
// CHECK: return %[[EXPANDED]]
// -----
// Test the case where the leading expand isn't needed.
func.func @no_expand_generated(
    %arg0: tensor<4x?x4x128x?x?x?xf16>,
    %arg1: index,
    %arg2: index,
    %arg3: index) -> tensor<4x?x4x128x?x?xf16> {
  // Dims 0..2 are merged into one dim and dims 5..6 into another.
  %merged = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]] :
      tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
  // Only the first merged dim is split back (same three-element group); the
  // remaining dims map one-to-one.
  %split = tensor.expand_shape %merged [[0, 1, 2], [3], [4], [5]]
      output_shape [4, %arg1, 4, 128, %arg2, %arg3] :
      tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
  return %split : tensor<4x?x4x128x?x?xf16>
}
// CHECK: func @no_expand_generated
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape
// CHECK: return %[[COLLAPSED]]