| # RUN: %PYTHON -m mlir.dialects.linalg.opdsl.dump_oplib --file %s | FileCheck %s |
| |
| from mlir.dialects.linalg.opdsl.lang import * |
| |
| |
| # CHECK: --- |
| # CHECK-LABEL: matmul |
| # CHECK: assignments: |
| # CHECK: - |
| # CHECK: arg: C |
| # CHECK: value: |
| # CHECK: scalar_fn: |
| # CHECK: kind: binary |
| # CHECK: fn_name: add |
| # CHECK: operands: |
| # CHECK: scalar_fn: |
| # CHECK: kind: binary |
| # CHECK: attr_name: mul |
| # CHECK: operands: |
| # CHECK: scalar_fn: |
| # CHECK: kind: type |
| # CHECK: attr_name: cast |
| # CHECK: type_var: U |
| # CHECK: operands: |
| # CHECK: scalar_arg: A |
| # CHECK: scalar_fn: |
| # CHECK: kind: type |
| # CHECK: attr_name: cast |
| # CHECK: type_var: U |
| # CHECK: operands: |
| # CHECK: scalar_arg: B |
| @linalg_structured_op |
| def matmul( |
| A=TensorDef(T, S.M, S.K), |
| B=TensorDef(T, S.K, S.N), |
| C=TensorDef(U, S.M, S.N, output=True), |
| mul=BinaryFnAttrDef(default=BinaryFn.mul), |
| cast=TypeFnAttrDef(default=TypeFn.cast_signed), |
| ): |
| C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) |
| |
| |
| # CHECK: --- |
| # CHECK-LABEL: constants |
| # CHECK: assignments: |
| # CHECK: - |
| # CHECK: arg: O |
| # CHECK: scalar_fn: |
| # CHECK: kind: binary |
| # CHECK: fn_name: sub |
| # CHECK: operands: |
| # CHECK: scalar_fn: |
| # CHECK: kind: binary |
| # CHECK: fn_name: add |
| # CHECK: operands: |
| # CHECK: scalar_fn: |
| # CHECK: kind: unary |
| # CHECK: fn_name: exp |
| # CHECK: operands: |
| # CHECK: scalar_fn: |
| # CHECK: kind: type |
| # CHECK: type_var: T |
| # CHECK: operands: |
| # CHECK: scalar_const: '3.1415926535897931 : f64' |
| # CHECK: scalar_fn: |
| # CHECK: kind: type |
| # CHECK: fn_name: cast_signed |
| # CHECK: type_var: T |
| # CHECK: operands: |
| # CHECK: scalar_const: '42 : i64' |
| # CHECK: scalar_fn: |
| # CHECK: kind: type |
| # CHECK: fn_name: cast_signed |
| # CHECK: type_var: T |
| # CHECK: operands: |
| # CHECK: scalar_fn: |
| # CHECK: kind: unary |
| # CHECK: attr_name: exp |
| # CHECK: operands: |
| # CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' |
| @linalg_structured_op |
| def constants( |
| O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp) |
| ): |
| pi = TypeFn.cast_signed(T, const(3.1415926535897931)) |
| cst42 = TypeFn.cast_signed(T, const(42)) |
| cst1000 = TypeFn.cast_signed(T, exp(const(1e3))) |
| O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000 |
| |
| |
| # CHECK: --- |
| # CHECK-LABEL: indices |
| # CHECK: assignments: |
| # CHECK: - |
| # CHECK: arg: O |
| # CHECK: scalar_fn: |
| # CHECK: kind: binary |
| # CHECK: fn_name: add |
| # CHECK: operands: |
| # CHECK: scalar_index: 1 |
| # CHECK: scalar_index: 0 |
| @linalg_structured_op |
| def indices(O=TensorDef(T, S.M, S.K, output=True)): |
| O[D.m, D.n] = index(D.n) + index(D.m) |
| |
| |
| # CHECK: --- |
| # CHECK-LABEL: fill |
| # CHECK: assignments: |
| # CHECK: - |
| # CHECK: arg: O |
| # CHECK: scalar_arg: value |
| @linalg_structured_op |
| def fill(value=ScalarDef(T), O=TensorDef(T, S.M, S.K, output=True)): |
| O[D.m, D.n] = value |