blob: 81c70983cded1be3a11a2c66c573decd09fc311d [file] [log] [blame]
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-axpy=1 | FileCheck %s
// Matrix * vector: indexing maps for the lhs, rhs and accumulator operands of
// vector.contract. The lhs matrix is indexed (d0, d1), the rhs vector by the
// reduced dim d1, and the accumulator/result by the parallel dim d0.
#matvec_accesses = [
  affine_map<(d0, d1) -> (d0, d1)>,
  affine_map<(d0, d1) -> (d1)>,
  affine_map<(d0, d1) -> (d0)>
]
// d0 is a parallel iterator, d1 is reduced.
#matvec_trait = {
  indexing_maps = #matvec_accesses,
  iterator_types = ["parallel", "reduction"]
}
// Transposed matrix * vector: same as the matvec maps, except the lhs matrix
// is indexed (d1, d0), i.e. the reduced dim selects the row.
#mattransvec_accesses = [
  affine_map<(d0, d1) -> (d1, d0)>,
  affine_map<(d0, d1) -> (d1)>,
  affine_map<(d0, d1) -> (d0)>
]
// d0 is a parallel iterator, d1 is reduced.
#mattransvec_trait = {
  indexing_maps = #mattransvec_accesses,
  iterator_types = ["parallel", "reduction"]
}
// Vector * matrix: the operands are swapped relative to matvec. The lhs
// vector is indexed by the reduced dim d1, the rhs matrix by (d0, d1), and the
// accumulator/result by the parallel dim d0.
#vecmat_accesses = [
  affine_map<(d0, d1) -> (d1)>,
  affine_map<(d0, d1) -> (d0, d1)>,
  affine_map<(d0, d1) -> (d0)>
]
// d0 is a parallel iterator, d1 is reduced.
#vecmat_trait = {
  indexing_maps = #vecmat_accesses,
  iterator_types = ["parallel", "reduction"]
}
// Vector * transposed matrix: like the vecmat maps, but the rhs matrix is
// indexed (d1, d0), i.e. the reduced dim selects the row.
#vecmattrans_accesses = [
  affine_map<(d0, d1) -> (d1)>,
  affine_map<(d0, d1) -> (d1, d0)>,
  affine_map<(d0, d1) -> (d0)>
]
// d0 is a parallel iterator, d1 is reduced.
#vecmattrans_trait = {
  indexing_maps = #vecmattrans_accesses,
  iterator_types = ["parallel", "reduction"]
}
// The matrix operand is indexed (i, j), so each reduction step needs one of
// its columns. The expected output first transposes the loaded 2x2 matrix
// element by element (extract/insert with swapped positions). It then folds
// each transposed row, scaled by a splat of the matching vector element, into
// the accumulator with vector.fma.
// CHECK-LABEL: func @matvec2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
func @matvec2x2(%arg0: memref<vector<2x2xf32>>,
                %arg1: memref<vector<2xf32>>,
                %arg2: memref<vector<2xf32>>) {
  %mat = load %arg0[] : memref<vector<2x2xf32>>
  %vec = load %arg1[] : memref<vector<2xf32>>
  %acc = load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #matvec_trait %mat, %vec, %acc
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  store %res, %arg2[] : memref<vector<2xf32>>
  return
}
// The matrix operand is indexed (j, i), so row j of the loaded matrix is
// already the slice needed for reduction step j. No transpose is expected,
// only extract + splat + vector.fma per step.
// CHECK-LABEL: func @mattransvec2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>,
                     %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %mat = load %arg0[] : memref<vector<2x2xf32>>
  %vec = load %arg1[] : memref<vector<2xf32>>
  %acc = load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #mattransvec_trait %mat, %vec, %acc
    : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
  store %res, %arg2[] : memref<vector<2xf32>>
  return
}
// The vector is the lhs and the matrix is the rhs, indexed (i, j). As in the
// matvec case, the expected output transposes the loaded 2x2 matrix element
// by element. It then accumulates one scaled row of the transpose per
// reduction step with vector.fma.
// CHECK-LABEL: func @vecmat2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
func @vecmat2x2(%arg0: memref<vector<2x2xf32>>,
                %arg1: memref<vector<2xf32>>,
                %arg2: memref<vector<2xf32>>) {
  %mat = load %arg0[] : memref<vector<2x2xf32>>
  %vec = load %arg1[] : memref<vector<2xf32>>
  %acc = load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #vecmat_trait %vec, %mat, %acc
    : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  store %res, %arg2[] : memref<vector<2xf32>>
  return
}
// The vector is the lhs and the matrix is the rhs, indexed (j, i). Row j of
// the loaded matrix is therefore used directly for reduction step j. No
// transpose is expected, only extract + splat + vector.fma per step.
// CHECK-LABEL: func @vecmattrans2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>,
                     %arg1: memref<vector<2xf32>>,
                     %arg2: memref<vector<2xf32>>) {
  %mat = load %arg0[] : memref<vector<2x2xf32>>
  %vec = load %arg1[] : memref<vector<2xf32>>
  %acc = load %arg2[] : memref<vector<2xf32>>
  %res = vector.contract #vecmattrans_trait %vec, %mat, %acc
    : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
  store %res, %arg2[] : memref<vector<2xf32>>
  return
}