blob: 4fb88dd1651263ad9611670799abd550f7f81208 [file] [edit]
// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
/// VNNI format is Intel's packed data layout.
/// For matrix multiplication, elements from the reduction dimension `k`
/// are packed into 32-bit tuples. Then the appropriate AMX operations can
/// perform tile multiplication directly on the packed data.
///
/// These packed elements are represented in the indexing maps by a separate
/// reduction dimension `vnni`.
/// Baseline positive case: f16 A (m x k x vnni) times B (k x n x vnni)
/// accumulated into f32 C (m x n). With 16-bit elements, two values form
/// one 32-bit VNNI tuple, hence the unit size 2 in the innermost dim.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @contract_vnni_f16(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
%C: vector<4x16xf32>) -> vector<4x16xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
return %0 : vector<4x16xf32>
}
// CHECK-LABEL: @contract_vnni_f16(
// CHECK-SAME: %[[A:.+]]: vector<4x8x2xf16>,
// CHECK-SAME: %[[B:.+]]: vector<8x16x2xf16>,
// CHECK-SAME: %[[C:.+]]: vector<4x16xf32>
/// AMX hardware has no direct access to the registers. Thus, data must
/// be transferred through intermediate buffers.
///
/// Load A vector into an AMX tile
// CHECK: %[[A_BUF:.+]] = memref.alloca() : memref<4x8x2xf16>
// CHECK: vector.transfer_write %[[A]], %[[A_BUF]]
// CHECK: %[[A_BUF_2D:.+]] = memref.collapse_shape %[[A_BUF]]
// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<4x8x2xf16> into memref<4x16xf16>
// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_BUF_2D]]
/// Load B vector into an AMX tile
// CHECK: %[[B_BUF:.+]] = memref.alloca() : memref<8x16x2xf16>
// CHECK: vector.transfer_write %[[B]], %[[B_BUF]]
// CHECK: %[[B_BUF_2D:.+]] = memref.collapse_shape %[[B_BUF]]
// CHECK-SAME: {{\[}}[0], [1, 2]] : memref<8x16x2xf16> into memref<8x32xf16>
// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_BUF_2D]]
/// Load C vector into an AMX tile
/// (the accumulator is already 2D, so no collapse is needed)
// CHECK: %[[C_BUF:.+]] = memref.alloca() : memref<4x16xf32>
// CHECK: vector.transfer_write %[[C]], %[[C_BUF]]
// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_BUF]]
/// Perform tile multiplication
// CHECK: %[[RES:.+]] = amx.tile_mulf
// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
/// Load the result back into a vector
// CHECK: %[[RES_BUF:.+]] = memref.alloca() : memref<4x16xf32>
// CHECK: amx.tile_store %[[RES_BUF]]{{.*}}, %[[RES]]
// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
// CHECK: return %[[RES_VEC]]
// -----
/// Same contraction shape as the f16 case above but with bf16 operands;
/// bf16 is also a 16-bit type (VNNI factor 2) and lowers to the floating
/// point tile multiply.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @contract_vnni_bf16(%A: vector<4x8x2xbf16>, %B: vector<8x16x2xbf16>,
%C: vector<4x16xf32>) -> vector<4x16xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x8x2xbf16>, vector<8x16x2xbf16> into vector<4x16xf32>
return %0 : vector<4x16xf32>
}
// CHECK-LABEL: @contract_vnni_bf16(
// CHECK-COUNT-3: amx.tile_load
// CHECK: amx.tile_mulf
// CHECK: amx.tile_store
// -----
/// Integer variant: i8 elements are packed four per 32-bit tuple (VNNI
/// factor 4) and the contraction lowers to the integer tile multiply
/// with an i32 accumulator.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @contract_vnni_i8(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
%C: vector<4x8xi32>) -> vector<4x8xi32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
return %0 : vector<4x8xi32>
}
// CHECK-LABEL: @contract_vnni_i8(
// CHECK-COUNT-3: amx.tile_load
// CHECK: amx.tile_muli
// CHECK: amx.tile_store
// -----
/// The iterator declaration order is permuted relative to the canonical
/// (m, n, k, vnni) order used by the other tests; conversion still applies
/// because the operand access patterns expressed by the maps are unchanged.
#map = affine_map<(vnni, m, k, n) -> (m, k, vnni)>
#map1 = affine_map<(vnni, m, k, n) -> (k, n, vnni)>
#map2 = affine_map<(vnni, m, k, n) -> (m, n)>
func.func @contract_shuffled_iterators(%A: vector<4x16x4xi8>, %B: vector<16x8x4xi8>,
%C: vector<4x8xi32>) -> vector<4x8xi32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "parallel", "reduction", "parallel"]}
%A, %B, %C : vector<4x16x4xi8>, vector<16x8x4xi8> into vector<4x8xi32>
return %0 : vector<4x8xi32>
}
// CHECK-LABEL: @contract_shuffled_iterators(
// CHECK-COUNT-3: amx.tile_load
// CHECK: amx.tile_muli
// CHECK: amx.tile_store
// -----
/// Negative case: AMX tile multiplication accumulates additively, so a
/// <mul> combining kind must leave the contract untouched.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @negative_invalid_kind(%A: vector<4x8x2xf16>, %B: vector<8x16x2xf16>,
%C: vector<4x16xf32>) -> vector<4x16xf32> {
%0 = vector.contract
{kind = #vector.kind<mul>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
return %0 : vector<4x16xf32>
}
// CHECK-LABEL: @negative_invalid_kind(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: a full reduction into a scalar accumulator (empty
/// result map, all-reduction iterators) cannot map to a 2D AMX tile.
#map = affine_map<(m, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, k, vnni) -> (k, m, vnni)>
#map2 = affine_map<(m, k, vnni) -> ()>
func.func @negative_non_vector_acc(%A: vector<4x8x2xf16>, %B: vector<8x4x2xf16>,
%C: f32) -> f32 {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["reduction", "reduction", "reduction"]}
%A, %B, %C : vector<4x8x2xf16>, vector<8x4x2xf16> into f32
return %0 : f32
}
// CHECK-LABEL: @negative_non_vector_acc(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: f32 inputs are not a supported AMX multiplication
/// element type (32-bit elements do not pack into VNNI tuples), so the
/// contract must be left as is.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @negative_invalid_operand_types(%A: vector<4x8x2xf32>,
%B: vector<8x16x2xf32>, %C: vector<4x16xf32>) -> vector<4x16xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x8x2xf32>, vector<8x16x2xf32> into vector<4x16xf32>
return %0 : vector<4x16xf32>
}
// CHECK-LABEL: @negative_invalid_operand_types(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: a plain 2D contraction without the separate vnni
/// reduction dimension is not in packed layout and must not convert.
#map = affine_map<(m, n, k) -> (m, k)>
#map1 = affine_map<(m, n, k) -> (k, n)>
#map2 = affine_map<(m, n, k) -> (m, n)>
func.func @negative_non_packed_layout(%A: vector<4x16xf16>, %B: vector<16x16xf16>,
%C: vector<4x16xf32>) -> vector<4x16xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction"]}
%A, %B, %C : vector<4x16xf16>, vector<16x16xf16> into vector<4x16xf32>
return %0 : vector<4x16xf32>
}
// CHECK-LABEL: @negative_non_packed_layout(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: 16-bit f16 elements pack two per 32-bit tuple, so a
/// VNNI unit of 4 is invalid for f16 and conversion must not fire.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @negative_invalid_vnni_factor(%A: vector<4x2x4xf16>, %B: vector<2x2x4xf16>,
%C: vector<4x2xf32>) -> vector<4x2xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x2x4xf16>, vector<2x2x4xf16> into vector<4x2xf32>
return %0 : vector<4x2xf32>
}
// CHECK-LABEL: @negative_invalid_vnni_factor(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: an extra parallel batch dimension yields operand ranks
/// that cannot map onto 2D AMX tiles (even with a unit-sized batch), so
/// the contract must be left unchanged.
#map = affine_map<(batch, m, n, k, vnni) -> (batch, m, k, vnni)>
#map1 = affine_map<(batch, m, n, k, vnni) -> (batch, k, n, vnni)>
#map2 = affine_map<(batch, m, n, k, vnni) -> (batch, m, n)>
func.func @negative_invalid_operands_shapes(%A: vector<1x4x8x2xf16>,
%B: vector<1x8x16x2xf16>, %C: vector<1x4x16xf32>) -> vector<1x4x16xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<1x4x8x2xf16>, vector<1x8x16x2xf16> into vector<1x4x16xf32>
return %0 : vector<1x4x16xf32>
}
// CHECK-LABEL: @negative_invalid_operands_shapes(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: A has 32 rows, which exceeds the AMX tile row limit
/// (positive tests above stay within 16 rows), so conversion is rejected.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @negative_too_many_rows(%A: vector<32x8x2xf16>, %B: vector<8x16x2xf16>,
%C: vector<32x16xf32>) -> vector<32x16xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<32x8x2xf16>, vector<8x16x2xf16> into vector<32x16xf32>
return %0 : vector<32x16xf32>
}
// CHECK-LABEL: @negative_too_many_rows(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: each A row carries 32x2 f16 elements, exceeding the
/// AMX tile row width limit, so conversion is rejected.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @negative_too_wide_rows(%A: vector<4x32x2xf16>, %B: vector<32x16x2xf16>,
%C: vector<4x16xf32>) -> vector<4x16xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x32x2xf16>, vector<32x16x2xf16> into vector<4x16xf32>
return %0 : vector<4x16xf32>
}
// CHECK-LABEL: @negative_too_wide_rows(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: operand dimensions are permuted away from the required
/// (row, reduction, vnni) order — A accesses (k, vnni, m) and B accesses
/// (n, k, vnni) — so the contract must not be converted.
#map = affine_map<(m, n, k, vnni) -> (k, vnni, m)>
#map1 = affine_map<(m, n, k, vnni) -> (n, k, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
func.func @negative_input_dim_permutation(%A: vector<2x2x2xf16>,
%B: vector<2x2x2xf16>, %C: vector<2x2xf32>) -> vector<2x2xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<2x2x2xf16>, vector<2x2x2xf16> into vector<2x2xf32>
return %0 : vector<2x2xf32>
}
// CHECK-LABEL: @negative_input_dim_permutation(
// CHECK-NOT: amx
// CHECK: vector.contract
// -----
/// Negative case: the accumulator map is transposed to (n, m) while the
/// tile multiply produces an (m, n) result, so conversion is rejected.
#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
#map2 = affine_map<(m, n, k, vnni) -> (n, m)>
func.func @negative_output_dim_permutation(%A: vector<4x8x2xf16>,
%B: vector<8x16x2xf16>, %C: vector<16x4xf32>) -> vector<16x4xf32> {
%0 = vector.contract
{kind = #vector.kind<add>,
indexing_maps = [#map, #map1, #map2],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
%A, %B, %C : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<16x4xf32>
return %0 : vector<16x4xf32>
}
// CHECK-LABEL: @negative_output_dim_permutation(
// CHECK-NOT: amx
// CHECK: vector.contract