| # RUN: %PYTHON %s | FileCheck %s |
| |
| from mlir.ir import * |
| import mlir.dialects.builtin as builtin |
| import mlir.dialects.func as func |
| import mlir.dialects.x86vector as x86vector |
| |
| |
def run(f):
    """Run ``f`` immediately as a named test case and hand it back.

    A blank line followed by ``TEST: <function name>`` is printed first so
    the FileCheck labels in this file have an anchor. ``f`` is then invoked
    with a fresh MLIR ``Context`` entered and an unknown ``Location`` active.
    Returning ``f`` lets this be used as a plain decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    # The Context must already be entered when Location.unknown() is
    # evaluated, because the location is created in the current context.
    with Context():
        with Location.unknown():
            f()
    return f
| |
| |
| # CHECK-LABEL: TEST: testAvxOp |
@run
def testAvxOp():
    """Build and print a function wrapping ``x86vector.BcstToPackedF32Op``.

    The generated ``avx_op`` function takes a ``memref<1xbf16>`` argument
    and returns the ``vector<8xf32>`` result of the op.
    """
    m = Module.create()
    # Types depend only on the active Context, not on the insertion point,
    # so they can be created up front.
    bf16_buffer = MemRefType.get((1,), BF16Type.get())
    f32x8 = VectorType.get((8,), F32Type.get())

    with InsertionPoint(m.body):

        # The Python function name becomes the func symbol name.
        @func.FuncOp.from_py_func(bf16_buffer)
        def avx_op(src):
            bcst = x86vector.BcstToPackedF32Op(a=src, dst=f32x8)
            return bcst

    # CHECK-LABEL: func @avx_op(
    # CHECK-SAME: %[[ARG:.+]]: memref<1xbf16>) -> vector<8xf32> {
    # CHECK: %[[VAL:.+]] = x86vector.avx.bcst_to_f32.packed %[[ARG]]
    # CHECK: return %[[VAL]] : vector<8xf32>
    # CHECK: }
    print(m)
| |
| |
| # CHECK-LABEL: TEST: testAvx512Op |
@run
def testAvx512Op():
    """Build and print a function wrapping ``x86vector.CvtPackedF32ToBF16Op``.

    The generated ``avx512_op`` function takes a ``vector<8xf32>`` argument
    and returns the ``vector<8xbf16>`` result of the op.
    """
    m = Module.create()
    # Types depend only on the active Context, not on the insertion point,
    # so they can be created up front.
    f32x8 = VectorType.get((8,), F32Type.get())
    bf16x8 = VectorType.get((8,), BF16Type.get())

    with InsertionPoint(m.body):

        # The Python function name becomes the func symbol name.
        @func.FuncOp.from_py_func(f32x8)
        def avx512_op(src):
            cvt = x86vector.CvtPackedF32ToBF16Op(a=src, dst=bf16x8)
            return cvt

    # CHECK-LABEL: func @avx512_op(
    # CHECK-SAME: %[[ARG:.+]]: vector<8xf32>) -> vector<8xbf16> {
    # CHECK: %[[VAL:.+]] = x86vector.avx512.cvt.packed.f32_to_bf16 %[[ARG]]
    # CHECK: return %[[VAL]] : vector<8xbf16>
    # CHECK: }
    print(m)
| |
| |
| # CHECK-LABEL: TEST: testAvx10Op |
@run
def testAvx10Op():
    """Build and print a function wrapping ``x86vector.AVX10DotInt8Op``.

    The generated ``avx10_op`` function takes a ``vector<16xi32>`` operand
    ``w`` and two ``vector<64xi8>`` operands ``a`` and ``b``. It returns the
    ``vector<16xi32>`` result of the ``x86vector.avx10.dot.i8`` op.
    """
    module = Module.create()
    with InsertionPoint(module.body):
        # The expected IR uses signless integer elements (i32 / i8). The
        # Python bindings construct those with IntegerType.get_signless.
        i32x16 = VectorType.get((16,), IntegerType.get_signless(32))
        i8x64 = VectorType.get((64,), IntegerType.get_signless(8))

        # from_py_func passes the entry-block arguments positionally, in the
        # same order as the types listed here.
        @func.FuncOp.from_py_func(i32x16, i8x64, i8x64)
        def avx10_op(w, a, b):
            return x86vector.AVX10DotInt8Op(w=w, a=a, b=b)

        # CHECK-LABEL: func @avx10_op(
        # CHECK-SAME: %[[W:.+]]: vector<16xi32>, %[[A:.+]]: vector<64xi8>,
        # CHECK-SAME: %[[B:.+]]: vector<64xi8>) -> vector<16xi32> {
        # CHECK: %[[VAL:.+]] = x86vector.avx10.dot.i8 %[[W]], %[[A]], %[[B]]
        # CHECK: return %[[VAL]] : vector<16xi32>
        # CHECK: }
    print(module)