| // RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s |
| |
| // Identity 2-D indexing map; used for both inputs and the output of the linalg.generic ops below. |
| #map = affine_map<(d0, d1) -> (d0, d1)> |
| // Elementwise f32 addition spelled as a linalg.generic whose body is a single arith.addf. |
| // The checks after the function require the output to contain linalg.add on the same |
| // operands and no remaining generic op. |
| func.func @specialize_add(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> { |
|   %result = linalg.generic { |
|       indexing_maps = [#map, #map, #map], |
|       iterator_types = ["parallel", "parallel"] |
|     } ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) |
|       outs(%init : tensor<?x?xf32>) { |
|   ^bb0(%a: f32, %b: f32, %acc: f32): |
|     %sum = arith.addf %a, %b : f32 |
|     linalg.yield %sum : f32 |
|   } -> tensor<?x?xf32> |
|   return %result : tensor<?x?xf32> |
| } |
| // CHECK-LABEL: specialize_add |
| // CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>, %[[INIT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> |
| // CHECK-NOT: linalg.generic |
| // CHECK: linalg.add ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32> |
| |
| // Elementwise f32 subtraction (first input minus second input) spelled as a linalg.generic. |
| // The checks after the function require the output to contain linalg.sub with the inputs |
| // in their original order and no remaining generic op. |
| func.func @specialize_sub(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> { |
|   %result = linalg.generic { |
|       indexing_maps = [#map, #map, #map], |
|       iterator_types = ["parallel", "parallel"] |
|     } ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) |
|       outs(%init : tensor<?x?xf32>) { |
|   ^bb0(%a: f32, %b: f32, %acc: f32): |
|     %diff = arith.subf %a, %b : f32 |
|     linalg.yield %diff : f32 |
|   } -> tensor<?x?xf32> |
|   return %result : tensor<?x?xf32> |
| } |
| // CHECK-LABEL: specialize_sub |
| // CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>, %[[INIT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> |
| // CHECK-NOT: linalg.generic |
| // CHECK: linalg.sub ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32> |
| |
| // Subtraction whose body reads the block arguments in reverse order: it computes |
| // (second input) - (first input). The checks after the function therefore require |
| // linalg.sub with its ins operands swapped (%arg1 first, then %arg0). |
| func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> { |
| %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) { |
| ^bb0(%in: f32, %in_0: f32, %out: f32): |
| %1 = arith.subf %in_0, %in : f32 |
| linalg.yield %1 : f32 |
| } -> tensor<?x?xf32> |
| return %0 : tensor<?x?xf32> |
| } |
| // The label spells out the full function name so these checks are anchored to this |
| // function rather than to anything merely starting with "specialize_sub". |
| // CHECK-LABEL: specialize_sub_swapped_operands |
| // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> |
| // CHECK-NOT: linalg.generic |
| // CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32> |
| |
| // Elementwise f32 multiplication spelled as a linalg.generic whose body is a single arith.mulf. |
| // The checks after the function require the output to contain linalg.mul on the same |
| // operands and no remaining generic op. |
| func.func @specialize_mul(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> { |
|   %result = linalg.generic { |
|       indexing_maps = [#map, #map, #map], |
|       iterator_types = ["parallel", "parallel"] |
|     } ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) |
|       outs(%init : tensor<?x?xf32>) { |
|   ^bb0(%a: f32, %b: f32, %acc: f32): |
|     %prod = arith.mulf %a, %b : f32 |
|     linalg.yield %prod : f32 |
|   } -> tensor<?x?xf32> |
|   return %result : tensor<?x?xf32> |
| } |
| // CHECK-LABEL: specialize_mul |
| // CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>, %[[INIT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> |
| // CHECK-NOT: linalg.generic |
| // CHECK: linalg.mul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32> |
| |
| // Elementwise f32 division (first input over second input) spelled as a linalg.generic. |
| // The checks after the function require the output to contain linalg.div on the same |
| // operands and no remaining generic op. |
| func.func @specialize_div(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>, %init: tensor<?x?xf32>) -> tensor<?x?xf32> { |
|   %result = linalg.generic { |
|       indexing_maps = [#map, #map, #map], |
|       iterator_types = ["parallel", "parallel"] |
|     } ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>) |
|       outs(%init : tensor<?x?xf32>) { |
|   ^bb0(%a: f32, %b: f32, %acc: f32): |
|     %quot = arith.divf %a, %b : f32 |
|     linalg.yield %quot : f32 |
|   } -> tensor<?x?xf32> |
|   return %result : tensor<?x?xf32> |
| } |
| // CHECK-LABEL: specialize_div |
| // CHECK-SAME: %[[LHS:.+]]: tensor<?x?xf32>, %[[RHS:.+]]: tensor<?x?xf32>, %[[INIT:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> |
| // CHECK-NOT: linalg.generic |
| // CHECK: linalg.div ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32> |
| |
| |
| // Transform script: matches ops implementing the LinalgOp interface under the root handle |
| // and passes them to transform.structured.specialize. The handle returned by the |
| // specialization is not used afterwards. |
| module attributes {transform.with_named_sequence} { |
|   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) { |
|     %linalg_ops = transform.structured.match interface{LinalgOp} in %root |
|       : (!transform.any_op) -> !transform.any_op |
|     %specialized = transform.structured.specialize %linalg_ops |
|       : (!transform.any_op) -> !transform.any_op |
|     transform.yield |
|   } |
| } |