blob: c270727078a20da65cf2ed52246045acd6014afc [file] [log] [blame] [edit]
# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
import mlir.dialects.builtin as builtin
import mlir.dialects.func as func
import mlir.dialects.x86vector as x86vector
def run(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
f()
return f
# CHECK-LABEL: TEST: testAvxOp
@run
def testAvxOp():
module = Module.create()
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(MemRefType.get((1,), BF16Type.get()))
def avx_op(arg):
return x86vector.BcstToPackedF32Op(
a=arg, dst=VectorType.get((8,), F32Type.get())
)
# CHECK-LABEL: func @avx_op(
# CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
# CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
# CHECK: return %[[VAL]] : vector<8xf32>
# CHECK: }
print(module)
# CHECK-LABEL: TEST: testAvx512Op
@run
def testAvx512Op():
module = Module.create()
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(VectorType.get((8,), F32Type.get()))
def avx512_op(arg):
return x86vector.CvtPackedF32ToBF16Op(
a=arg, dst=VectorType.get((8,), BF16Type.get())
)
# CHECK-LABEL: func @avx512_op(
# CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
# CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
# CHECK: return %[[VAL]] : vector<8xbf16>
# CHECK: }
print(module)
# CHECK-LABEL: TEST: testAvx10Op
@run
def testAvx10Op():
module = Module.create()
with InsertionPoint(module.body):
@func.FuncOp.from_py_func(
VectorType.get((16,), IntegerType.get(32)),
VectorType.get((64,), IntegerType.get(8)),
VectorType.get((64,), IntegerType.get(8)),
)
def avx10_op(*args):
return x86vector.AVX10DotInt8Op(w=args[0], a=args[1], b=args[2])
# CHECK-LABEL: func @avx10_op(
# CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
# CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
# CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
# CHECK: return %[[VAL]] : vector<16xi32>
# CHECK: }
print(module)