// blob: d0a7bc744d31111b04d8f7369eb48d7c699933a1 [file] [edit]
// RUN: mlir-opt --convert-xevm-to-llvm --split-input-file %s | FileCheck %s
// Lowering of xevm.truncf from f16 to bf8 on a 16-element vector.
// The expected output splits the source into two 8-element halves, bitcasts
// each half to 16 bytes, keeps the bytes at odd indices [1, 3, ..., 15]
// (presumably the upper byte of each f16 on a little-endian layout; confirm
// against the lowering), and concatenates the two 8-byte pieces.
// CHECK-LABEL: llvm.func @truncf_f16_to_bf8
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf16>
llvm.func @truncf_f16_to_bf8(%src: vector<16xf16>) -> vector<16xi8> {
  // Split the 16 x f16 source into its low and high 8-element halves.
  // CHECK: %[[VAR0:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<16xf16>
  // CHECK: %[[VAR1:.*]] = llvm.shufflevector %[[ARG0]], %[[ARG0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>
  // Reinterpret each half as raw bytes.
  // CHECK: %[[VAR2:.*]] = llvm.bitcast %[[VAR0]] : vector<8xf16> to vector<16xi8>
  // CHECK: %[[VAR3:.*]] = llvm.bitcast %[[VAR1]] : vector<8xf16> to vector<16xi8>
  // Keep only the odd-indexed bytes of each half.
  // CHECK: %[[VAR4:.*]] = llvm.shufflevector %[[VAR2]], %[[VAR2]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
  // CHECK: %[[VAR5:.*]] = llvm.shufflevector %[[VAR3]], %[[VAR3]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi8>
  // Concatenate the two 8-byte pieces. The operands are matched through the
  // captured values rather than raw SSA numbers, so the pattern does not
  // depend on how the printer numbers values.
  // CHECK: %[[VAR6:.*]] = llvm.shufflevector %[[VAR4]], %[[VAR5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xi8>
  %dst = xevm.truncf %src { src_etype = f16, dst_etype = bf8 } : (vector<16xf16>) -> vector<16xi8>
  llvm.return %dst : vector<16xi8>
}
// -----
// Lowering of xevm.mma_mx with shape m=8, n=16, k=32, bf8 A/B operands and
// f32 C/D operands. The patterns first match the declaration of the builtin
// that the op is lowered to, then the test function and the call site.
// CHECK-LABEL: llvm.func spir_funccc @__builtin_IB_sub_group16_bdpas_f_f_bf8_bf8_8_8
// CHECK-SAME: (vector<8xf32>, vector<8xi16>, vector<8xi32>, i8, i8) -> vector<8xf32>
// CHECK-SAME: attributes {convergent, memory_effects = #llvm.memory_effects<other = none,
// CHECK-SAME: argMem = none, inaccessibleMem = none, errnoMem = none,
// CHECK-SAME: targetMem0 = none, targetMem1 = none>, no_unwind, will_return}
// CHECK: llvm.func @mma_mx_bf8_bf8_k32_f32
// CHECK-SAME: %[[A:.*]]: vector<8xi16>, %[[B:.*]]: vector<8xi32>
// CHECK-SAME: %[[SCALE_A:.*]]: i8, %[[SCALE_B:.*]]: i8, %[[ACC:.*]]: vector<8xf32>
llvm.func @mma_mx_bf8_bf8_k32_f32(
    %lhs: vector<8xi16>, %rhs: vector<8xi32>,
    %lhs_scale: i8, %rhs_scale: i8,
    %acc: vector<8xf32>) -> vector<8xf32> {
  // The call passes the accumulator first, followed by A, B and the two
  // scale operands.
  // CHECK: %[[RES:.*]] = llvm.call spir_funccc @__builtin_IB_sub_group16_bdpas_f_f_bf8_bf8_8_8
  // CHECK-SAME: (%[[ACC]], %[[A]], %[[B]], %[[SCALE_A]], %[[SCALE_B]])
  // CHECK-SAME: {convergent, function_type = !llvm.func<vector<8xf32> (vector<8xf32>, vector<8xi16>, vector<8xi32>, i8, i8)>,
  // CHECK-SAME: linkage = #llvm.linkage<external>, memory_effects = #llvm.memory_effects<other = none,
  // CHECK-SAME: argMem = none, inaccessibleMem = none, errnoMem = none,
  // CHECK-SAME: targetMem0 = none, targetMem1 = none>,
  // CHECK-SAME: no_unwind, sym_name = "__builtin_IB_sub_group16_bdpas_f_f_bf8_bf8_8_8",
  // CHECK-SAME: visibility_ = 0 : i64, will_return} :
  // CHECK-SAME: (vector<8xf32>, vector<8xi16>, vector<8xi32>, i8, i8) -> vector<8xf32>
  %d = xevm.mma_mx %lhs, %rhs, %lhs_scale, %rhs_scale, %acc
      {shape = <m = 8, n = 16, k = 32>, types = <d = f32, a = bf8, b = bf8, c = f32>}
      : (vector<8xi16>, vector<8xi32>, i8, i8, vector<8xf32>) -> vector<8xf32>
  llvm.return %d : vector<8xf32>
}