blob: e93c5a08bdc7c91a7d613de85fbeba4cadbd0a2e [file] [log] [blame]
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
// CHECK-LABEL: func @parallel_contract_lowering
// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// CHECK: return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
%0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @parallel_contract_lowering_broadcast
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// CHECK: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
// CHECK: return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
%0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @parallel_contract_lowering
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
// CHECK: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
// CHECK: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
// CHECK: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<4xf32> from vector<1x1x4xf32>
// CHECK: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32>
// CHECK: return %[[F]] : vector<4xf32>
func.func @parallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
%0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
return %0 : vector<4xf32>
}
// CHECK-LABEL: func @parallel_contract_lowering_scalar
// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f32 from vector<1x1xf32>
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f32 from vector<1x1xf32>
// CHECK: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32
// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
// CHECK: return %[[A]] : f32
func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>],
iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
%arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
return %0 : f32
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
: (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith"
} : !transform.any_op
transform.yield
}
}