| // RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file -allow-unregistered-dialect | FileCheck %s |
| |
| // CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32 |
| // CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>, |
| // CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1> |
| // CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32> |
| // CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> |
| // CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> |
| // CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> |
| // CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> |
| // CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> |
| func.func @outerproduct_add_widening_2way_f16f16f32( |
| %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, |
| %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { |
| // Zero-filled f32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32> |
| |
| // Left-hand operands, widened from f16 to f32. |
| %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> |
| // Right-hand operands, widened from f16 to f32. |
| %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> |
| |
| // Two masked outer products; the second accumulates onto the first. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %step1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| /// Verify that a chain of 4 outer products is fused into two 2-way widening |
| /// outer products. |
| |
| // CHECK-LABEL: @outerproduct_x2_add_widening_2way_f16f16f32 |
| // CHECK-COUNT-2: arm_sme.fmopa_2way |
| func.func @outerproduct_x2_add_widening_2way_f16f16f32( |
| %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, |
| %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>, |
| %a2 : vector<[4]xf16>, %b2 : vector<[4]xf16>, |
| %a3 : vector<[4]xf16>, %b3 : vector<[4]xf16>) -> vector<[4]x[4]xf32> { |
| // Left-hand operands, widened from f16 to f32. |
| %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs2 = arith.extf %a2 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs3 = arith.extf %a3 : vector<[4]xf16> to vector<[4]xf32> |
| |
| // Right-hand operands, widened from f16 to f32. |
| %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs2 = arith.extf %b2 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs3 = arith.extf %b3 : vector<[4]xf16> to vector<[4]xf32> |
| |
| // Unmasked chain of four outer products. The first one has no accumulator |
| // operand; each later one accumulates onto the previous result. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 : vector<[4]xf32>, vector<[4]xf32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) : vector<[4]xf32>, vector<[4]xf32> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%step1) : vector<[4]xf32>, vector<[4]xf32> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%step2) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %step3 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_2way_f16f16f32 |
| // CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> |
| func.func @outerproduct_sub_widening_2way_f16f16f32( |
| %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, |
| %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { |
| // Zero-filled f32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32> |
| |
| // Left-hand operands, widened from f16 to f32. |
| %lhs0 = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs1 = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> |
| // Right-hand operands, widened from f16 to f32. |
| %rhs0 = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs1 = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> |
| |
| // Two masked outer products, both marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %step1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_2way_bf16bf16f32 |
| // CHECK: arm_sme.fmopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> |
| func.func @outerproduct_add_widening_2way_bf16bf16f32( |
| %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>, |
| %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { |
| // Zero-filled f32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32> |
| |
| // Left-hand operands, widened from bf16 to f32. |
| %lhs0 = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32> |
| %lhs1 = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32> |
| // Right-hand operands, widened from bf16 to f32. |
| %rhs0 = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32> |
| %rhs1 = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32> |
| |
| // Two masked outer products; the second accumulates onto the first. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %step1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_2way_bf16bf16f32 |
| // CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> |
| func.func @outerproduct_sub_widening_2way_bf16bf16f32( |
| %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>, |
| %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { |
| // Zero-filled f32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32> |
| |
| // Left-hand operands, widened from bf16 to f32. |
| %lhs0 = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32> |
| %lhs1 = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32> |
| // Right-hand operands, widened from bf16 to f32. |
| %rhs0 = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32> |
| %rhs1 = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32> |
| |
| // Two masked outer products, both marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %step1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_2way_signed_i16i16i32 |
| // CHECK: arm_sme.smopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> |
| func.func @outerproduct_add_widening_2way_signed_i16i16i32( |
| %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, |
| %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, sign-extended from i16 to i32. |
| %lhs0 = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32> |
| %lhs1 = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32> |
| // Right-hand operands, sign-extended from i16 to i32. |
| %rhs0 = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32> |
| %rhs1 = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32> |
| |
| // Two masked outer products; the second accumulates onto the first. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step1 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_2way_signed_i16i16i32 |
| // CHECK: arm_sme.smops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> |
| func.func @outerproduct_sub_widening_2way_signed_i16i16i32( |
| %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, |
| %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, sign-extended from i16 to i32. |
| %lhs0 = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32> |
| %lhs1 = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32> |
| // Right-hand operands, sign-extended from i16 to i32. |
| %rhs0 = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32> |
| %rhs1 = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32> |
| |
| // Two masked outer products, both marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step1 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_2way_unsigned_i16i16i32 |
| // CHECK: arm_sme.umopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> |
| func.func @outerproduct_add_widening_2way_unsigned_i16i16i32( |
| %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, |
| %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, zero-extended from i16 to i32. |
| %lhs0 = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32> |
| %lhs1 = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32> |
| // Right-hand operands, zero-extended from i16 to i32. |
| %rhs0 = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32> |
| %rhs1 = arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32> |
| |
| // Two masked outer products; the second accumulates onto the first. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step1 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_2way_unsigned_i16i16i32 |
| // CHECK: arm_sme.umops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> |
| func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32( |
| %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, |
| %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, zero-extended from i16 to i32. |
| %lhs0 = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32> |
| %lhs1 = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32> |
| // Right-hand operands, zero-extended from i16 to i32. |
| %rhs0 = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32> |
| %rhs1 = arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32> |
| |
| // Two masked outer products, both marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step1 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_signed_i8i8i32 |
| // CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[4]xi8>, %[[B0:[a-z0-9]+]]: vector<[4]xi8>, |
| // CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[4]xi8>, %[[B1:[a-z0-9]+]]: vector<[4]xi8>, |
| // CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[4]xi8>, %[[B2:[a-z0-9]+]]: vector<[4]xi8>, |
| // CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[4]xi8>, %[[B3:[a-z0-9]+]]: vector<[4]xi8>, |
| // CHECK-SAME: %[[A0_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B0_MASK:[a-z0-9]+]]: vector<[4]xi1>, |
| // CHECK-SAME: %[[A1_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B1_MASK:[a-z0-9]+]]: vector<[4]xi1>, |
| // CHECK-SAME: %[[A2_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B2_MASK:[a-z0-9]+]]: vector<[4]xi1>, |
| // CHECK-SAME: %[[A3_MASK:[a-z0-9]+]]: vector<[4]xi1>, %[[B3_MASK:[a-z0-9]+]]: vector<[4]xi1> |
| // CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0> : vector<[4]x[4]xi32> |
| // CHECK-DAG: %[[LHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8> |
| // CHECK-DAG: %[[LHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1]], %[[A3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8> |
| // CHECK-DAG: %[[RHS0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B2]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8> |
| // CHECK-DAG: %[[RHS1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1]], %[[B3]]) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8> |
| // CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0]], %[[LHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8> |
| // CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0]], %[[RHS1]]) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8> |
| // CHECK-DAG: %[[LHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> |
| // CHECK-DAG: %[[LHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A1_MASK]], %[[A3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> |
| // CHECK-DAG: %[[RHS0_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B2_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> |
| // CHECK-DAG: %[[RHS1_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B1_MASK]], %[[B3_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> |
| // CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS0_MASK]], %[[LHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1> |
| // CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[RHS0_MASK]], %[[RHS1_MASK]]) : (vector<[8]xi1>, vector<[8]xi1>) -> vector<[16]xi1> |
| // CHECK-DAG: arm_sme.smopa_4way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| func.func @outerproduct_add_widening_4way_signed_i8i8i32( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>, |
| %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>, |
| %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, sign-extended from i8 to i32. |
| %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Right-hand operands, sign-extended from i8 to i32. |
| %rhs0 = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1 = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2 = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3 = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Four masked outer products; each accumulates onto the previous result. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%step1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%step2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i8i8i32 |
| // CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| func.func @outerproduct_sub_widening_4way_signed_i8i8i32( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>, |
| %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>, |
| %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, sign-extended from i8 to i32. |
| %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Right-hand operands, sign-extended from i8 to i32. |
| %rhs0 = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1 = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2 = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3 = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Four masked outer products, all marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%step1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%step2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_signed_i16i16i64 |
| // CHECK: arm_sme.smopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| func.func @outerproduct_add_widening_4way_signed_i16i16i64( |
| %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>, |
| %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>, |
| %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>, |
| %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>, |
| %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>, |
| %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>, |
| %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>, |
| %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Zero-filled i64 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Left-hand operands, sign-extended from i16 to i64. |
| %lhs0 = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs1 = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs2 = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs3 = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Right-hand operands, sign-extended from i16 to i64. |
| %rhs0 = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1 = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2 = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3 = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Four masked outer products; each accumulates onto the previous result. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%step1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%step2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %step3 : vector<[2]x[2]xi64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_signed_i16i16i64 |
| // CHECK: arm_sme.smops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| func.func @outerproduct_sub_widening_4way_signed_i16i16i64( |
| %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>, |
| %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>, |
| %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>, |
| %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>, |
| %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>, |
| %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>, |
| %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>, |
| %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Zero-filled i64 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Left-hand operands, sign-extended from i16 to i64. |
| %lhs0 = arith.extsi %a0 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs1 = arith.extsi %a1 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs2 = arith.extsi %a2 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs3 = arith.extsi %a3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Right-hand operands, sign-extended from i16 to i64. |
| %rhs0 = arith.extsi %b0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1 = arith.extsi %b1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2 = arith.extsi %b2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3 = arith.extsi %b3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Four masked outer products, all marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%step1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%step2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %step3 : vector<[2]x[2]xi64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i8i8i32 |
| // CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| func.func @outerproduct_add_widening_4way_unsigned_i8i8i32( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>, |
| %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>, |
| %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, zero-extended from i8 to i32. |
| %lhs0 = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs1 = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs2 = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs3 = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Right-hand operands, zero-extended from i8 to i32. |
| %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Four masked outer products; each accumulates onto the previous result. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%step1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%step2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i8i8i32 |
| // CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| func.func @outerproduct_sub_widening_4way_unsigned_i8i8i32( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>, |
| %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>, |
| %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands, zero-extended from i8 to i32. |
| %lhs0 = arith.extui %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs1 = arith.extui %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs2 = arith.extui %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs3 = arith.extui %a3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Right-hand operands, zero-extended from i8 to i32. |
| %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Four masked outer products, all marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%step1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%step2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_i16i16i64 |
| // CHECK: arm_sme.umopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| func.func @outerproduct_add_widening_4way_unsigned_i16i16i64( |
| %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>, |
| %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>, |
| %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>, |
| %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>, |
| %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>, |
| %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>, |
| %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>, |
| %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Zero-filled i64 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Left-hand operands, zero-extended from i16 to i64. |
| %lhs0 = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs1 = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs2 = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs3 = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Right-hand operands, zero-extended from i16 to i64. |
| %rhs0 = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1 = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2 = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3 = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Four masked outer products; each accumulates onto the previous result. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%step1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%step2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %step3 : vector<[2]x[2]xi64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_i16i16i64 |
| // CHECK: arm_sme.umops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| func.func @outerproduct_sub_widening_4way_unsigned_i16i16i64( |
| %a0 : vector<[2]xi16>, %b0 : vector<[2]xi16>, |
| %a1 : vector<[2]xi16>, %b1 : vector<[2]xi16>, |
| %a2 : vector<[2]xi16>, %b2 : vector<[2]xi16>, |
| %a3 : vector<[2]xi16>, %b3 : vector<[2]xi16>, |
| %a0_mask : vector<[2]xi1>, %b0_mask : vector<[2]xi1>, |
| %a1_mask : vector<[2]xi1>, %b1_mask : vector<[2]xi1>, |
| %a2_mask : vector<[2]xi1>, %b2_mask : vector<[2]xi1>, |
| %a3_mask : vector<[2]xi1>, %b3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Zero-filled i64 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Left-hand operands, zero-extended from i16 to i64. |
| %lhs0 = arith.extui %a0 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs1 = arith.extui %a1 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs2 = arith.extui %a2 : vector<[2]xi16> to vector<[2]xi64> |
| %lhs3 = arith.extui %a3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Right-hand operands, zero-extended from i16 to i64. |
| %rhs0 = arith.extui %b0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1 = arith.extui %b1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2 = arith.extui %b2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3 = arith.extui %b3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| // Four masked outer products, all marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%step1) masks(%a2_mask, %b2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%step2) masks(%a3_mask, %b3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %step3 : vector<[2]x[2]xi64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32 |
| // CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| func.func @outerproduct_add_widening_4way_signed_by_unsigned_i8i8i32( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>, |
| %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>, |
| %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands are sign-extended from i8 to i32. |
| %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Right-hand operands are zero-extended from i8 to i32. |
| %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Four masked outer products; each accumulates onto the previous result. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 acc(%step1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 acc(%step2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32 |
| // CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i8i8i32( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>, |
| %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, |
| %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>, |
| %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>, |
| %a3_mask : vector<[4]xi1>, %b3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Zero-filled i32 accumulator that seeds the chain below. |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Left-hand operands are sign-extended from i8 to i32. |
| %lhs0 = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs1 = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs2 = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %lhs3 = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Right-hand operands are zero-extended from i8 to i32. |
| %rhs0 = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1 = arith.extui %b1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2 = arith.extui %b2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3 = arith.extui %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // Four masked outer products, all marked kind<sub>, chained via acc. |
| %step0 = arm_sme.outerproduct %lhs0, %rhs0 kind<sub> acc(%zero) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step1 = arm_sme.outerproduct %lhs1, %rhs1 kind<sub> acc(%step0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step2 = arm_sme.outerproduct %lhs2, %rhs2 kind<sub> acc(%step1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %step3 = arm_sme.outerproduct %lhs3, %rhs3 kind<sub> acc(%step2) masks(%a3_mask, %b3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %step3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64 |
| // CHECK: arm_sme.sumopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| // Four chained, masked outer products (sign-extended LHS, zero-extended RHS, |
| // i16 widened to i64) rooted at a zero accumulator. |
| func.func @outerproduct_add_widening_4way_signed_by_unsigned_i16i16i64( |
| %lhs0 : vector<[2]xi16>, %rhs0 : vector<[2]xi16>, |
| %lhs1 : vector<[2]xi16>, %rhs1 : vector<[2]xi16>, |
| %lhs2 : vector<[2]xi16>, %rhs2 : vector<[2]xi16>, |
| %lhs3 : vector<[2]xi16>, %rhs3 : vector<[2]xi16>, |
| %lhs0_mask : vector<[2]xi1>, %rhs0_mask : vector<[2]xi1>, |
| %lhs1_mask : vector<[2]xi1>, %rhs1_mask : vector<[2]xi1>, |
| %lhs2_mask : vector<[2]xi1>, %rhs2_mask : vector<[2]xi1>, |
| %lhs3_mask : vector<[2]xi1>, %rhs3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Widen every operand pair: LHS via extsi, RHS via extui. |
| %lhs0_wide = arith.extsi %lhs0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs0_wide = arith.extui %rhs0 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs1_wide = arith.extsi %lhs1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1_wide = arith.extui %rhs1 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs2_wide = arith.extsi %lhs2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2_wide = arith.extui %rhs2 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs3_wide = arith.extsi %lhs3 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3_wide = arith.extui %rhs3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Each outer product accumulates onto the result of the previous one. |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%tile0) masks(%lhs1_mask, %rhs1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide acc(%tile1) masks(%lhs2_mask, %rhs2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide acc(%tile2) masks(%lhs3_mask, %rhs3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %tile3 : vector<[2]x[2]xi64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64 |
| // CHECK: arm_sme.sumops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| // Four chained, masked kind<sub> outer products (sign-extended LHS, |
| // zero-extended RHS, i16 widened to i64) rooted at a zero accumulator. |
| func.func @outerproduct_sub_widening_4way_signed_by_unsigned_i16i16i64( |
| %lhs0 : vector<[2]xi16>, %rhs0 : vector<[2]xi16>, |
| %lhs1 : vector<[2]xi16>, %rhs1 : vector<[2]xi16>, |
| %lhs2 : vector<[2]xi16>, %rhs2 : vector<[2]xi16>, |
| %lhs3 : vector<[2]xi16>, %rhs3 : vector<[2]xi16>, |
| %lhs0_mask : vector<[2]xi1>, %rhs0_mask : vector<[2]xi1>, |
| %lhs1_mask : vector<[2]xi1>, %rhs1_mask : vector<[2]xi1>, |
| %lhs2_mask : vector<[2]xi1>, %rhs2_mask : vector<[2]xi1>, |
| %lhs3_mask : vector<[2]xi1>, %rhs3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Widen every operand pair: LHS via extsi, RHS via extui. |
| %lhs0_wide = arith.extsi %lhs0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs0_wide = arith.extui %rhs0 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs1_wide = arith.extsi %lhs1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1_wide = arith.extui %rhs1 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs2_wide = arith.extsi %lhs2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2_wide = arith.extui %rhs2 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs3_wide = arith.extsi %lhs3 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3_wide = arith.extui %rhs3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Each outer product uses the previous result as its accumulator. |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide kind<sub> acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide kind<sub> acc(%tile0) masks(%lhs1_mask, %rhs1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide kind<sub> acc(%tile1) masks(%lhs2_mask, %rhs2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide kind<sub> acc(%tile2) masks(%lhs3_mask, %rhs3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %tile3 : vector<[2]x[2]xi64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32 |
| // CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| // Four chained, masked outer products (zero-extended LHS, sign-extended RHS, |
| // i8 widened to i32) rooted at a zero accumulator. |
| func.func @outerproduct_add_widening_4way_unsigned_by_signed_i8i8i32( |
| %lhs0 : vector<[4]xi8>, %rhs0 : vector<[4]xi8>, |
| %lhs1 : vector<[4]xi8>, %rhs1 : vector<[4]xi8>, |
| %lhs2 : vector<[4]xi8>, %rhs2 : vector<[4]xi8>, |
| %lhs3 : vector<[4]xi8>, %rhs3 : vector<[4]xi8>, |
| %lhs0_mask : vector<[4]xi1>, %rhs0_mask : vector<[4]xi1>, |
| %lhs1_mask : vector<[4]xi1>, %rhs1_mask : vector<[4]xi1>, |
| %lhs2_mask : vector<[4]xi1>, %rhs2_mask : vector<[4]xi1>, |
| %lhs3_mask : vector<[4]xi1>, %rhs3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Widen every operand pair: LHS via extui, RHS via extsi. |
| %lhs0_wide = arith.extui %lhs0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs0_wide = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %lhs1_wide = arith.extui %lhs1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1_wide = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %lhs2_wide = arith.extui %lhs2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2_wide = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %lhs3_wide = arith.extui %lhs3 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3_wide = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Each outer product accumulates onto the result of the previous one. |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%tile0) masks(%lhs1_mask, %rhs1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %tile2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide acc(%tile1) masks(%lhs2_mask, %rhs2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %tile3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide acc(%tile2) masks(%lhs3_mask, %rhs3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %tile3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32 |
| // CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32> |
| // Four chained, masked kind<sub> outer products (zero-extended LHS, |
| // sign-extended RHS, i8 widened to i32) rooted at a zero accumulator. |
| func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i8i8i32( |
| %lhs0 : vector<[4]xi8>, %rhs0 : vector<[4]xi8>, |
| %lhs1 : vector<[4]xi8>, %rhs1 : vector<[4]xi8>, |
| %lhs2 : vector<[4]xi8>, %rhs2 : vector<[4]xi8>, |
| %lhs3 : vector<[4]xi8>, %rhs3 : vector<[4]xi8>, |
| %lhs0_mask : vector<[4]xi1>, %rhs0_mask : vector<[4]xi1>, |
| %lhs1_mask : vector<[4]xi1>, %rhs1_mask : vector<[4]xi1>, |
| %lhs2_mask : vector<[4]xi1>, %rhs2_mask : vector<[4]xi1>, |
| %lhs3_mask : vector<[4]xi1>, %rhs3_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| // Widen every operand pair: LHS via extui, RHS via extsi. |
| %lhs0_wide = arith.extui %lhs0 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs0_wide = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %lhs1_wide = arith.extui %lhs1 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs1_wide = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %lhs2_wide = arith.extui %lhs2 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs2_wide = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %lhs3_wide = arith.extui %lhs3 : vector<[4]xi8> to vector<[4]xi32> |
| %rhs3_wide = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %zero = arith.constant dense<0> : vector<[4]x[4]xi32> |
| |
| // Each outer product uses the previous result as its accumulator. |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide kind<sub> acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide kind<sub> acc(%tile0) masks(%lhs1_mask, %rhs1_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %tile2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide kind<sub> acc(%tile1) masks(%lhs2_mask, %rhs2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %tile3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide kind<sub> acc(%tile2) masks(%lhs3_mask, %rhs3_mask) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %tile3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64 |
| // CHECK: arm_sme.usmopa_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| // Four chained, masked outer products (zero-extended LHS, sign-extended RHS, |
| // i16 widened to i64) rooted at a zero accumulator. |
| func.func @outerproduct_add_widening_4way_unsigned_by_signed_i16i16i64( |
| %lhs0 : vector<[2]xi16>, %rhs0 : vector<[2]xi16>, |
| %lhs1 : vector<[2]xi16>, %rhs1 : vector<[2]xi16>, |
| %lhs2 : vector<[2]xi16>, %rhs2 : vector<[2]xi16>, |
| %lhs3 : vector<[2]xi16>, %rhs3 : vector<[2]xi16>, |
| %lhs0_mask : vector<[2]xi1>, %rhs0_mask : vector<[2]xi1>, |
| %lhs1_mask : vector<[2]xi1>, %rhs1_mask : vector<[2]xi1>, |
| %lhs2_mask : vector<[2]xi1>, %rhs2_mask : vector<[2]xi1>, |
| %lhs3_mask : vector<[2]xi1>, %rhs3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Widen every operand pair: LHS via extui, RHS via extsi. |
| %lhs0_wide = arith.extui %lhs0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs0_wide = arith.extsi %rhs0 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs1_wide = arith.extui %lhs1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1_wide = arith.extsi %rhs1 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs2_wide = arith.extui %lhs2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2_wide = arith.extsi %rhs2 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs3_wide = arith.extui %lhs3 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3_wide = arith.extsi %rhs3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Each outer product accumulates onto the result of the previous one. |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%tile0) masks(%lhs1_mask, %rhs1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide acc(%tile1) masks(%lhs2_mask, %rhs2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide acc(%tile2) masks(%lhs3_mask, %rhs3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %tile3 : vector<[2]x[2]xi64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64 |
| // CHECK: arm_sme.usmops_4way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64> |
| // Four chained, masked kind<sub> outer products (zero-extended LHS, |
| // sign-extended RHS, i16 widened to i64) rooted at a zero accumulator. |
| func.func @outerproduct_sub_widening_4way_unsigned_by_signed_i16i16i64( |
| %lhs0 : vector<[2]xi16>, %rhs0 : vector<[2]xi16>, |
| %lhs1 : vector<[2]xi16>, %rhs1 : vector<[2]xi16>, |
| %lhs2 : vector<[2]xi16>, %rhs2 : vector<[2]xi16>, |
| %lhs3 : vector<[2]xi16>, %rhs3 : vector<[2]xi16>, |
| %lhs0_mask : vector<[2]xi1>, %rhs0_mask : vector<[2]xi1>, |
| %lhs1_mask : vector<[2]xi1>, %rhs1_mask : vector<[2]xi1>, |
| %lhs2_mask : vector<[2]xi1>, %rhs2_mask : vector<[2]xi1>, |
| %lhs3_mask : vector<[2]xi1>, %rhs3_mask : vector<[2]xi1>) -> vector<[2]x[2]xi64> { |
| // Widen every operand pair: LHS via extui, RHS via extsi. |
| %lhs0_wide = arith.extui %lhs0 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs0_wide = arith.extsi %rhs0 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs1_wide = arith.extui %lhs1 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs1_wide = arith.extsi %rhs1 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs2_wide = arith.extui %lhs2 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs2_wide = arith.extsi %rhs2 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %lhs3_wide = arith.extui %lhs3 : vector<[2]xi16> to vector<[2]xi64> |
| %rhs3_wide = arith.extsi %rhs3 : vector<[2]xi16> to vector<[2]xi64> |
| |
| %zero = arith.constant dense<0> : vector<[2]x[2]xi64> |
| |
| // Each outer product uses the previous result as its accumulator. |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide kind<sub> acc(%zero) masks(%lhs0_mask, %rhs0_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide kind<sub> acc(%tile0) masks(%lhs1_mask, %rhs1_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile2 = arm_sme.outerproduct %lhs2_wide, %rhs2_wide kind<sub> acc(%tile1) masks(%lhs2_mask, %rhs2_mask) : vector<[2]xi64>, vector<[2]xi64> |
| %tile3 = arm_sme.outerproduct %lhs3_wide, %rhs3_wide kind<sub> acc(%tile2) masks(%lhs3_mask, %rhs3_mask) : vector<[2]xi64>, vector<[2]xi64> |
| |
| return %tile3 : vector<[2]x[2]xi64> |
| } |
| |
| /// Tests for related patterns. |
| |
| // ----- |
| |
| // CHECK-LABEL: @extract_from_arith_ext( |
| // CHECK-SAME: %[[SRC:.*]]: vector<4x[8]xi8> |
| // CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][0] : vector<[8]xi8> from vector<4x[8]xi8> |
| // CHECK: %[[EXTEND:.*]] = arith.extsi %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32> |
| // CHECK: return %[[EXTEND]] |
| // Sign-extends the whole 4x[8] vector, then extracts row 0 of the widened value. |
| func.func @extract_from_arith_ext(%input: vector<4x[8]xi8>) -> vector<[8]xi32> { |
| %widened = arith.extsi %input : vector<4x[8]xi8> to vector<4x[8]xi32> |
| %row = vector.extract %widened[0] : vector<[8]xi32> from vector<4x[8]xi32> |
| return %row : vector<[8]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @non_constant_extract_from_arith_ext( |
| // CHECK-SAME: %[[SRC:[a-z0-9]+]]: vector<4x[8]xi8>, |
| // CHECK-SAME: %[[DIM:[a-z0-9]+]]: index |
| // CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]]] : vector<[8]xi8> from vector<4x[8]xi8> |
| // CHECK: %[[EXTEND:.*]] = arith.extui %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32> |
| // CHECK: return %[[EXTEND]] |
| // Zero-extends the whole 4x[8] vector, then extracts a row at a dynamic index. |
| func.func @non_constant_extract_from_arith_ext(%input: vector<4x[8]xi8>, %row_idx: index) -> vector<[8]xi32> { |
| %widened = arith.extui %input : vector<4x[8]xi8> to vector<4x[8]xi32> |
| %row = vector.extract %widened[%row_idx] : vector<[8]xi32> from vector<4x[8]xi32> |
| return %row : vector<[8]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @scalable_extract_from_arith_ext( |
| // CHECK-SAME: %[[SRC:.*]]: vector<[8]xf16> |
| // CHECK: %[[EXTRACT:.*]] = vector.scalable.extract %[[SRC]][0] : vector<[4]xf16> from vector<[8]xf16> |
| // CHECK: %[[EXTEND:.*]] = arith.extf %[[EXTRACT]] : vector<[4]xf16> to vector<[4]xf32> |
| // CHECK: return %[[EXTEND]] |
| // Extends the [8]xf16 vector to f32, then takes a [4]-element scalable |
| // sub-vector at offset 0 of the widened value. |
| func.func @scalable_extract_from_arith_ext(%input: vector<[8]xf16>) -> vector<[4]xf32> { |
| %widened = arith.extf %input : vector<[8]xf16> to vector<[8]xf32> |
| %part = vector.scalable.extract %widened[0] : vector<[4]xf32> from vector<[8]xf32> |
| return %part : vector<[4]xf32> |
| } |
| |
| /// Negative tests |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_2way__no_acc |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // A single outer product of extended f16 operands, with no accumulator operand. |
| func.func @outerproduct_widening_2way__no_acc(%lhs : vector<[4]xf16>, %rhs : vector<[4]xf16>) -> vector<[4]x[4]xf32> { |
| %lhs_wide = arith.extf %lhs : vector<[4]xf16> to vector<[4]xf32> |
| %rhs_wide = arith.extf %rhs : vector<[4]xf16> to vector<[4]xf32> |
| |
| %tile = arm_sme.outerproduct %lhs_wide, %rhs_wide : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %tile : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_4way__no_acc |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // The negative guards use a regex over the op mnemonic so that any fused 4-way |
| // op is rejected. The integer 4-way ops in this file are named like |
| // arm_sme.sumopa_4way, which a guard spelled "fmopa_4way" would never match. |
| // Only three outer products are chained here (the first has no accumulator); |
| // the expectations require all three to remain unfused. |
| func.func @outerproduct_widening_4way__no_acc( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>) -> vector<[4]x[4]xi32> { |
| %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32> |
| %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32> |
| %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %2 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| /// Defining op of accumulator operand must be an 'arm_sme.outerproduct'. |
| |
| // CHECK-LABEL: @outerproduct_widening_2way__bad_acc |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // A single outer product whose accumulator is a function argument rather than |
| // the result of another op. |
| func.func @outerproduct_widening_2way__bad_acc(%lhs : vector<[4]xf16>, %rhs : vector<[4]xf16>, %init : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> { |
| %lhs_wide = arith.extf %lhs : vector<[4]xf16> to vector<[4]xf32> |
| %rhs_wide = arith.extf %rhs : vector<[4]xf16> to vector<[4]xf32> |
| |
| %tile = arm_sme.outerproduct %lhs_wide, %rhs_wide acc(%init) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %tile : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_4way__missing_acc |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // The negative guards use a regex over the op mnemonic so that any fused 4-way |
| // op is rejected. The integer 4-way ops in this file are named like |
| // arm_sme.sumopa_4way, which a guard spelled "fmopa_4way" would never match. |
| func.func @outerproduct_widening_4way__missing_acc( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> { |
| %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32> |
| %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32> |
| %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32> |
| // Missing accumulator breaks use-def chain. |
| %3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| /// Combining kinds of outer products must match to be fused. |
| |
| // CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // Two chained outer products whose combining kinds differ: the first is |
| // kind<add>, the second kind<sub>. |
| func.func @outerproduct_widening_2way__bad_combining_kind( |
| %lhs0 : vector<[4]xf16>, %rhs0 : vector<[4]xf16>, |
| %lhs1 : vector<[4]xf16>, %rhs1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> { |
| %lhs0_wide = arith.extf %lhs0 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs0_wide = arith.extf %rhs0 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs1_wide = arith.extf %lhs1 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs1_wide = arith.extf %rhs1 : vector<[4]xf16> to vector<[4]xf32> |
| |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide kind<add> : vector<[4]xf32>, vector<[4]xf32> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide kind<sub> acc(%tile0) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %tile1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_4way__inconsistent_combining_kind |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // The negative guards use a regex over the op mnemonic so that any fused 4-way |
| // op (accumulating or subtracting) is rejected. The integer 4-way ops in this |
| // file are named like arm_sme.sumopa_4way / arm_sme.sumops_4way, which a guard |
| // spelled "fmopa_4way" would never match. |
| func.func @outerproduct_widening_4way__inconsistent_combining_kind( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> { |
| %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| // The first outer product is kind<sub>; the remaining three are kind<add>. |
| %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> : vector<[4]xi32>, vector<[4]xi32> |
| %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<add> acc(%0) : vector<[4]xi32>, vector<[4]xi32> |
| %2 = arm_sme.outerproduct %a2_ext, %b2_ext kind<add> acc(%1) : vector<[4]xi32>, vector<[4]xi32> |
| %3 = arm_sme.outerproduct %a3_ext, %b3_ext kind<add> acc(%2) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| /// If the first outer product has uses other than as the input to another |
| /// outer product, it can't be erased after fusion. This is a problem when |
| /// it also has an accumulator as this will be used as the root for tile |
| /// allocation and since the widening outer product uses the same |
| /// accumulator it will get assigned the same tile ID, resulting in 3 |
| /// outer products and incorrect results. Check this is prevented. |
| |
| // CHECK-LABEL: @outerproduct_widening_2way__cant_erase |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // Two chained outer products where the first result also feeds an unrelated |
| // (unregistered) user op. |
| func.func @outerproduct_widening_2way__cant_erase( |
| %lhs0 : vector<[4]xf16>, %rhs0 : vector<[4]xf16>, |
| %lhs1 : vector<[4]xf16>, %rhs1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> { |
| %lhs0_wide = arith.extf %lhs0 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs0_wide = arith.extf %rhs0 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs1_wide = arith.extf %lhs1 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs1_wide = arith.extf %rhs1 : vector<[4]xf16> to vector<[4]xf32> |
| |
| %init = arith.constant dense<1.0> : vector<[4]x[4]xf32> |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide acc(%init) : vector<[4]xf32>, vector<[4]xf32> |
| // Extra user of the first outer product, besides the chained one below. |
| "fake.use"(%tile0) : (vector<[4]x[4]xf32>) -> () |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%tile0) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %tile1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_4way__multi_use_cant_erase |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // The negative guards use a regex over the op mnemonic so that any fused 4-way |
| // op is rejected. The integer 4-way ops in this file are named like |
| // arm_sme.sumopa_4way, which a guard spelled "fmopa_4way" would never match. |
| func.func @outerproduct_widening_4way__multi_use_cant_erase( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> { |
| %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32> |
| %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32> |
| // Extra user of an intermediate result, besides the chained outer product. |
| "fake.use"(%1) : (vector<[4]x[4]xi32>) -> () |
| %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32> |
| %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64 |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // Two chained outer products whose operands are widened from f32 to f64. |
| func.func @outerproduct_widening_2way__unsupported_type_f32f32f64( |
| %lhs0 : vector<[2]xf32>, %rhs0 : vector<[2]xf32>, |
| %lhs1 : vector<[2]xf32>, %rhs1 : vector<[2]xf32>) -> vector<[2]x[2]xf64> { |
| %lhs0_wide = arith.extf %lhs0 : vector<[2]xf32> to vector<[2]xf64> |
| %rhs0_wide = arith.extf %rhs0 : vector<[2]xf32> to vector<[2]xf64> |
| %lhs1_wide = arith.extf %lhs1 : vector<[2]xf32> to vector<[2]xf64> |
| %rhs1_wide = arith.extf %rhs1 : vector<[2]xf32> to vector<[2]xf64> |
| |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide : vector<[2]xf64>, vector<[2]xf64> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%tile0) : vector<[2]xf64>, vector<[2]xf64> |
| |
| return %tile1 : vector<[2]x[2]xf64> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_4way__unsupported_type_f16f16f64 |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // The negative guards use a regex over the op mnemonic so that any fused 4-way |
| // op is rejected, whatever its exact name, rather than one specific spelling. |
| func.func @outerproduct_widening_4way__unsupported_type_f16f16f64( |
| %a0 : vector<[2]xf16>, %b0 : vector<[2]xf16>, |
| %a1 : vector<[2]xf16>, %b1 : vector<[2]xf16>, |
| %a2 : vector<[2]xf16>, %b2 : vector<[2]xf16>, |
| %a3 : vector<[2]xf16>, %b3 : vector<[2]xf16>) -> vector<[2]x[2]xf64> { |
| // Operands are widened from f16 straight to f64. |
| %a0_ext = arith.extf %a0 : vector<[2]xf16> to vector<[2]xf64> |
| %b0_ext = arith.extf %b0 : vector<[2]xf16> to vector<[2]xf64> |
| |
| %a1_ext = arith.extf %a1 : vector<[2]xf16> to vector<[2]xf64> |
| %b1_ext = arith.extf %b1 : vector<[2]xf16> to vector<[2]xf64> |
| |
| %a2_ext = arith.extf %a2 : vector<[2]xf16> to vector<[2]xf64> |
| %b2_ext = arith.extf %b2 : vector<[2]xf16> to vector<[2]xf64> |
| |
| %a3_ext = arith.extf %a3 : vector<[2]xf16> to vector<[2]xf64> |
| %b3_ext = arith.extf %b3 : vector<[2]xf16> to vector<[2]xf64> |
| |
| %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64> |
| %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64> |
| %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[2]xf64>, vector<[2]xf64> |
| %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[2]xf64>, vector<[2]xf64> |
| |
| return %3 : vector<[2]x[2]xf64> |
| } |
| |
| // ----- |
| |
| /// Fusion only occurs if either both outer products are masked, or neither. |
| |
| // CHECK-LABEL: @outerproduct_widening_2way__bad_masking |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // Two chained outer products where only the second one carries masks. |
| func.func @outerproduct_widening_2way__bad_masking( |
| %lhs0 : vector<[4]xf16>, %rhs0 : vector<[4]xf16>, |
| %lhs1 : vector<[4]xf16>, %rhs1 : vector<[4]xf16>, |
| %lhs1_mask : vector<[4]xi1>, %rhs1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { |
| %lhs0_wide = arith.extf %lhs0 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs0_wide = arith.extf %rhs0 : vector<[4]xf16> to vector<[4]xf32> |
| %lhs1_wide = arith.extf %lhs1 : vector<[4]xf16> to vector<[4]xf32> |
| %rhs1_wide = arith.extf %rhs1 : vector<[4]xf16> to vector<[4]xf32> |
| |
| %tile0 = arm_sme.outerproduct %lhs0_wide, %rhs0_wide : vector<[4]xf32>, vector<[4]xf32> |
| %tile1 = arm_sme.outerproduct %lhs1_wide, %rhs1_wide acc(%tile0) masks(%lhs1_mask, %rhs1_mask) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %tile1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_4way__inconsistent_masking |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // The negative guards use a regex over the op mnemonic so that any fused 4-way |
| // op is rejected. The integer 4-way ops in this file are named like |
| // arm_sme.sumopa_4way, which a guard spelled "fmopa_4way" would never match. |
| func.func @outerproduct_widening_4way__inconsistent_masking( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi8>, %b2 : vector<[4]xi8>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>, |
| %a2_mask : vector<[4]xi1>, %b2_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { |
| %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32> |
| %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32> |
| %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32> |
| // Only this outer product in the chain is masked. |
| %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) masks(%a2_mask, %b2_mask) : vector<[4]xi32>, vector<[4]xi32> |
| %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %3 : vector<[4]x[4]xi32> |
| } |
| |
| // ----- |
| |
| /// Defining op of each outer product operand must be a supported extension op. |
| |
| // CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.fmopa_2way |
| // Two chained outer products fed directly by f32 function arguments, with no |
| // extension op in between. |
| func.func @outerproduct_widening_2way__bad_defining_op( |
| %lhs0 : vector<[4]xf32>, %rhs0 : vector<[4]xf32>, |
| %lhs1 : vector<[4]xf32>, %rhs1 : vector<[4]xf32>) -> vector<[4]x[4]xf32> { |
| %tile0 = arm_sme.outerproduct %lhs0, %rhs0 : vector<[4]xf32>, vector<[4]xf32> |
| %tile1 = arm_sme.outerproduct %lhs1, %rhs1 acc(%tile0) : vector<[4]xf32>, vector<[4]xf32> |
| |
| return %tile1 : vector<[4]x[4]xf32> |
| } |
| |
| // ----- |
| |
| // CHECK-LABEL: @outerproduct_widening_4way__bad_defining_op |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK: arm_sme.outerproduct |
| // CHECK-NOT: arm_sme.{{[a-z]+}}_4way |
| // The negative guards use a regex over the op mnemonic so that any fused 4-way |
| // op is rejected. The integer 4-way ops in this file are named like |
| // arm_sme.sumopa_4way, which a guard spelled "fmopa_4way" would never match. |
| func.func @outerproduct_widening_4way__bad_defining_op( |
| %a0 : vector<[4]xi8>, %b0 : vector<[4]xi8>, |
| %a1 : vector<[4]xi8>, %b1 : vector<[4]xi8>, |
| %a2 : vector<[4]xi32>, %b2 : vector<[4]xi32>, |
| %a3 : vector<[4]xi8>, %b3 : vector<[4]xi8>) -> vector<[4]x[4]xi32> { |
| %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32> |
| %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32> |
| %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32> |
| %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32> |
| |
| %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32> |
| %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32> |
| /// Inputs must come from an arith.ext. |
| %2 = arm_sme.outerproduct %a2, %b2 acc(%1) : vector<[4]xi32>, vector<[4]xi32> |
| %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32> |
| |
| return %3 : vector<[4]x[4]xi32> |
| } |
| |
| /// Negative tests for related patterns. |
| |
| // ----- |
| |
| /// Non-vector extracts should be ignored. |
| |
| // CHECK-LABEL: @extract_scalar_from_arith_ext |
| // CHECK-NEXT: arith.extsi |
| // CHECK-NEXT: vector.extract |
| // Sign-extends the whole vector, then extracts a single i32 element at [0, 0]. |
| func.func @extract_scalar_from_arith_ext(%input: vector<4x[8]xi8>) -> i32 { |
| %widened = arith.extsi %input : vector<4x[8]xi8> to vector<4x[8]xi32> |
| %elem = vector.extract %widened[0, 0] : i32 from vector<4x[8]xi32> |
| return %elem : i32 |
| } |
| |
| // ----- |
| |
| /// Extracted type should be a 1-D scalable vector type. |
| |
| // CHECK-LABEL: @extract_fixed_1d_vec_from_arith_ext |
| // CHECK-NEXT: arith.extsi |
| // CHECK-NEXT: vector.extract |
| // Sign-extends a fixed-size 4x8 vector, then extracts its fixed-size row 0. |
| func.func @extract_fixed_1d_vec_from_arith_ext(%input: vector<4x8xi8>) -> vector<8xi32> { |
| %widened = arith.extsi %input : vector<4x8xi8> to vector<4x8xi32> |
| %row = vector.extract %widened[0] : vector<8xi32> from vector<4x8xi32> |
| return %row : vector<8xi32> |
| } |
| |
| // ----- |
| |
| /// Extract must come from an arith extend. |
| |
| // CHECK-LABEL: @extract_from_non_arith_ext |
| // CHECK-NEXT: vector.extract |
| // CHECK-NEXT: return |
| // Extracts row 0 directly from a function argument; no extension op is involved. |
| func.func @extract_from_non_arith_ext(%input: vector<4x[8]xi32>) -> vector<[8]xi32> { |
| %row = vector.extract %input[0] : vector<[8]xi32> from vector<4x[8]xi32> |
| return %row : vector<[8]xi32> |
| } |
| |
| // ----- |
| |
| /// Scalable extract must come from an arith extend. |
| |
| // CHECK-LABEL: @scalable_extract_from_non_arith_ext |
| // CHECK-NEXT: vector.scalable.extract |
| // CHECK-NEXT: return |
| // Takes a [4]-element scalable sub-vector at offset 0 directly from a function |
| // argument; no extension op is involved. |
| func.func @scalable_extract_from_non_arith_ext(%input: vector<[8]xf32>) -> vector<[4]xf32> { |
| %part = vector.scalable.extract %input[0] : vector<[4]xf32> from vector<[8]xf32> |
| return %part : vector<[4]xf32> |
| } |