// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
// This file contains some tests of folding/canonicalizing vector.from_elements
///===----------------------------------------------===//
/// Tests of `rewriteFromElementsAsSplat`
///===----------------------------------------------===//
// CHECK-LABEL: func @extract_scalar_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) {
  // Every scalar extract below reads a static position of a
  // vector.from_elements result; the expected output returns the matching
  // operand (%a or %b) directly.

  // Rank-0 source.
  %vec0d = vector.from_elements %a : vector<f32>
  %from0d = vector.extract %vec0d[] : f32 from vector<f32>

  // Rank-1 sources: a single element, then the last of five elements.
  %vec1 = vector.from_elements %a : vector<1xf32>
  %from1 = vector.extract %vec1[0] : f32 from vector<1xf32>
  %vec5 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32>
  %from5 = vector.extract %vec5[4] : f32 from vector<5xf32>

  // Rank-2 source: the first three operands are %a, the last three are %b.
  %vec2x3 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
  %at00 = vector.extract %vec2x3[0, 0] : f32 from vector<2x3xf32>
  %at01 = vector.extract %vec2x3[0, 1] : f32 from vector<2x3xf32>
  %at11 = vector.extract %vec2x3[1, 1] : f32 from vector<2x3xf32>
  %at12 = vector.extract %vec2x3[1, 2] : f32 from vector<2x3xf32>

  // CHECK: return %[[A]], %[[A]], %[[B]], %[[A]], %[[A]], %[[B]], %[[B]]
  return %from0d, %from1, %from5, %at00, %at01, %at11, %at12 : f32, f32, f32, f32, f32, f32, f32
}
// -----
// CHECK-LABEL: func @extract_1d_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) {
  // Each extracted 3-element slice of %rows is built from one repeated
  // scalar, so the expected output replaces each extract with a broadcast of
  // that scalar.
  %rows = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32>
  // CHECK: %[[SPLAT1:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
  %row0 = vector.extract %rows[0] : vector<3xf32> from vector<2x3xf32>
  // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[B]] : f32 to vector<3xf32>
  %row1 = vector.extract %rows[1] : vector<3xf32> from vector<2x3xf32>
  // CHECK: return %[[SPLAT1]], %[[SPLAT2]]
  return %row0, %row1 : vector<3xf32>, vector<3xf32>
}
// -----
// CHECK-LABEL: func @extract_2d_from_from_elements(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) {
  // Extracting a 2x2 slice whose four operands are not all the same scalar:
  // the expected output is a smaller vector.from_elements per slice, built
  // from that slice's operands. These results are not splats, so the pattern
  // variables are named FROM_EL1 / FROM_EL2.
  %src = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32>
  // CHECK: %[[FROM_EL1:.*]] = vector.from_elements %[[A]], %[[A]], %[[A]], %[[B]] : vector<2x2xf32>
  %slice0 = vector.extract %src[0] : vector<2x2xf32> from vector<3x2x2xf32>
  // CHECK: %[[FROM_EL2:.*]] = vector.from_elements %[[B]], %[[B]], %[[B]], %[[A]] : vector<2x2xf32>
  %slice1 = vector.extract %src[1] : vector<2x2xf32> from vector<3x2x2xf32>
  // CHECK: return %[[FROM_EL1]], %[[FROM_EL2]]
  return %slice0, %slice1 : vector<2x2xf32>, vector<2x2xf32>
}
// -----
// CHECK-LABEL: func @from_elements_to_splat(
// CHECK-SAME: %[[A:.*]]: f32, %[[B:.*]]: f32)
func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector<f32>) {
  // All six operands are %a: expected to become a single broadcast.
  // CHECK: %[[SPLAT:.*]] = vector.broadcast %[[A]] : f32 to vector<2x3xf32>
  %uniform = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32>
  // One operand is %b, so the op is expected to stay a from_elements.
  // CHECK: %[[FROM_EL:.*]] = vector.from_elements {{.*}} : vector<2x3xf32>
  %mixed = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32>
  // Rank-0 result with a single operand: also expected to become a broadcast.
  // CHECK: %[[SPLAT2:.*]] = vector.broadcast %[[A]] : f32 to vector<f32>
  %rank0 = vector.from_elements %a : vector<f32>
  // CHECK: return %[[SPLAT]], %[[FROM_EL]], %[[SPLAT2]]
  return %uniform, %mixed, %rank0 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
}
// -----
///===----------------------------------------------===//
/// Tests of `FromElementsToShapeCast`
///===----------------------------------------------===//
// CHECK-LABEL: func @to_shape_cast_rank2_to_rank1(
// CHECK-SAME: %[[A:.*]]: vector<1x2xi8>)
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<1x2xi8> to vector<2xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<2xi8>
func.func @to_shape_cast_rank2_to_rank1(%arg0: vector<1x2xi8>) -> vector<2xi8> {
  // Both elements of %arg0 are extracted in order and repacked as a rank-1
  // vector; the expected output is a single shape_cast of %arg0.
  %elt0 = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
  %elt1 = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
  %repacked = vector.from_elements %elt0, %elt1 : vector<2xi8>
  return %repacked : vector<2xi8>
}
// -----
// CHECK-LABEL: func @to_shape_cast_rank1_to_rank3(
// CHECK-SAME: %[[A:.*]]: vector<8xi8>)
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[A]] : vector<8xi8> to vector<2x2x2xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<2x2x2xi8>
func.func @to_shape_cast_rank1_to_rank3(%arg0: vector<8xi8>) -> vector<2x2x2xi8> {
  // Read all eight elements of %arg0 in ascending index order ...
  %e0 = vector.extract %arg0[0] : i8 from vector<8xi8>
  %e1 = vector.extract %arg0[1] : i8 from vector<8xi8>
  %e2 = vector.extract %arg0[2] : i8 from vector<8xi8>
  %e3 = vector.extract %arg0[3] : i8 from vector<8xi8>
  %e4 = vector.extract %arg0[4] : i8 from vector<8xi8>
  %e5 = vector.extract %arg0[5] : i8 from vector<8xi8>
  %e6 = vector.extract %arg0[6] : i8 from vector<8xi8>
  %e7 = vector.extract %arg0[7] : i8 from vector<8xi8>
  // ... and repack them, in the same order, as a 2x2x2 vector. The expected
  // output is a single shape_cast of %arg0.
  %repacked = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5, %e6, %e7 : vector<2x2x2xi8>
  return %repacked : vector<2x2x2xi8>
}
// -----
// CHECK-LABEL: func @source_larger_than_out(
// CHECK-SAME: %[[A:.*]]: vector<2x3x4xi8>)
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][1] : vector<3x4xi8> from vector<2x3x4xi8>
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<3x4xi8> to vector<12xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<12xi8>
func.func @source_larger_than_out(%arg0: vector<2x3x4xi8>) -> vector<12xi8> {
  // Extract every element of the 3x4 slice at leading index 1, walking the
  // trailing two indices in ascending order. Names encode the
  // trailing [row, column] position.
  %r0c0 = vector.extract %arg0[1, 0, 0] : i8 from vector<2x3x4xi8>
  %r0c1 = vector.extract %arg0[1, 0, 1] : i8 from vector<2x3x4xi8>
  %r0c2 = vector.extract %arg0[1, 0, 2] : i8 from vector<2x3x4xi8>
  %r0c3 = vector.extract %arg0[1, 0, 3] : i8 from vector<2x3x4xi8>
  %r1c0 = vector.extract %arg0[1, 1, 0] : i8 from vector<2x3x4xi8>
  %r1c1 = vector.extract %arg0[1, 1, 1] : i8 from vector<2x3x4xi8>
  %r1c2 = vector.extract %arg0[1, 1, 2] : i8 from vector<2x3x4xi8>
  %r1c3 = vector.extract %arg0[1, 1, 3] : i8 from vector<2x3x4xi8>
  %r2c0 = vector.extract %arg0[1, 2, 0] : i8 from vector<2x3x4xi8>
  %r2c1 = vector.extract %arg0[1, 2, 1] : i8 from vector<2x3x4xi8>
  %r2c2 = vector.extract %arg0[1, 2, 2] : i8 from vector<2x3x4xi8>
  %r2c3 = vector.extract %arg0[1, 2, 3] : i8 from vector<2x3x4xi8>
  // Repack the twelve scalars, in extraction order, as a rank-1 vector. The
  // expected output extracts the 3x4 slice once and shape_casts it.
  %flat = vector.from_elements %r0c0, %r0c1, %r0c2, %r0c3, %r1c0, %r1c1, %r1c2, %r1c3, %r2c0, %r2c1, %r2c2, %r2c3 : vector<12xi8>
  return %flat : vector<12xi8>
}
// -----
// This test is similar to `source_larger_than_out` except here the number of elements
// extracted contiguously starting from the first position [0,0] could be 6 instead of 3
// and the pattern would still match.
// CHECK-LABEL: func @suffix_with_excess_zeros(
// CHECK: %[[EXT:.*]] = vector.extract {{.*}}[0] : vector<3xi8> from vector<2x3xi8>
// CHECK: return %[[EXT]] : vector<3xi8>
func.func @suffix_with_excess_zeros(%arg0: vector<2x3xi8>) -> vector<3xi8> {
  // Extract the three elements at leading index 0 of %arg0, in order, and
  // repack them as a rank-1 vector. The expected output is one 1-D extract at
  // index [0], with no shape_cast needed.
  %col0 = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
  %col1 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
  %col2 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
  %row = vector.from_elements %col0, %col1, %col2 : vector<3xi8>
  return %row : vector<3xi8>
}
// -----
// CHECK-LABEL: func @large_source_with_shape_cast_required(
// CHECK-SAME: %[[A:.*]]: vector<2x2x2x2xi8>)
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[A]][0, 1] : vector<2x2xi8> from vector<2x2x2x2xi8>
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[EXTRACT]] : vector<2x2xi8> to vector<1x4x1xi8>
// CHECK: return %[[SHAPE_CAST]] : vector<1x4x1xi8>
func.func @large_source_with_shape_cast_required(%arg0: vector<2x2x2x2xi8>) -> vector<1x4x1xi8> {
  // Extract the four elements of the 2x2 slice at leading indices [0, 1],
  // in ascending order of the trailing two indices.
  %p00 = vector.extract %arg0[0, 1, 0, 0] : i8 from vector<2x2x2x2xi8>
  %p01 = vector.extract %arg0[0, 1, 0, 1] : i8 from vector<2x2x2x2xi8>
  %p10 = vector.extract %arg0[0, 1, 1, 0] : i8 from vector<2x2x2x2xi8>
  %p11 = vector.extract %arg0[0, 1, 1, 1] : i8 from vector<2x2x2x2xi8>
  // The result shape (1x4x1) differs from the slice shape (2x2), so the
  // expected output is an extract of the slice followed by a shape_cast.
  %reshaped = vector.from_elements %p00, %p01, %p10, %p11 : vector<1x4x1xi8>
  return %reshaped : vector<1x4x1xi8>
}
// -----
// Could match, but handled by `rewriteFromElementsAsSplat`.
// CHECK-LABEL: func @extract_single_elm(
// CHECK-NEXT: vector.extract
// CHECK-NEXT: vector.broadcast
// CHECK-NEXT: return
func.func @extract_single_elm(%arg0 : vector<2x3xi8>) -> vector<1xi8> {
  // A single extracted scalar feeds a one-element from_elements. The expected
  // output keeps the scalar extract and turns the from_elements into a
  // broadcast.
  %scalar = vector.extract %arg0[0, 0] : i8 from vector<2x3xi8>
  %single = vector.from_elements %scalar : vector<1xi8>
  return %single : vector<1xi8>
}
// -----
// CHECK-LABEL: func @negative_source_contiguous_but_not_suffix(
// CHECK-NOT: shape_cast
// CHECK: from_elements
func.func @negative_source_contiguous_but_not_suffix(%arg0: vector<2x3xi8>) -> vector<3xi8> {
  // The three positions are consecutive in row-major order, but they start in
  // the middle of the slice at leading index 0 and end in the slice at
  // leading index 1. The expected output keeps the from_elements.
  %tail0 = vector.extract %arg0[0, 1] : i8 from vector<2x3xi8>
  %tail1 = vector.extract %arg0[0, 2] : i8 from vector<2x3xi8>
  %head = vector.extract %arg0[1, 0] : i8 from vector<2x3xi8>
  %packed = vector.from_elements %tail0, %tail1, %head : vector<3xi8>
  return %packed : vector<3xi8>
}
// -----
// The extracted elements are recombined into a single vector, but in a new order.
// CHECK-LABEL: func @negative_nonascending_order(
// CHECK-NOT: shape_cast
// CHECK: from_elements
func.func @negative_nonascending_order(%arg0: vector<1x2xi8>) -> vector<2xi8> {
  // Position [0, 1] is extracted first and [0, 0] second, so the repacked
  // vector holds the source elements in swapped order. The expected output
  // keeps the from_elements.
  %second = vector.extract %arg0[0, 1] : i8 from vector<1x2xi8>
  %first = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
  %swapped = vector.from_elements %second, %first : vector<2xi8>
  return %swapped : vector<2xi8>
}
// -----
// CHECK-LABEL: func @negative_nonstatic_extract(
// CHECK-NOT: shape_cast
// CHECK: from_elements
func.func @negative_nonstatic_extract(%arg0: vector<1x2xi8>, %i0 : index, %i1 : index) -> vector<2xi8> {
  // The trailing extract positions are the runtime values %i0 and %i1 rather
  // than constants. The expected output keeps the from_elements.
  %dyn0 = vector.extract %arg0[0, %i0] : i8 from vector<1x2xi8>
  %dyn1 = vector.extract %arg0[0, %i1] : i8 from vector<1x2xi8>
  %packed = vector.from_elements %dyn0, %dyn1 : vector<2xi8>
  return %packed : vector<2xi8>
}
// -----
// CHECK-LABEL: func @negative_different_sources(
// CHECK-NOT: shape_cast
// CHECK: from_elements
func.func @negative_different_sources(%arg0: vector<1x2xi8>, %arg1: vector<1x2xi8>) -> vector<2xi8> {
  // The two scalars come from two different source vectors (%arg0 and
  // %arg1). The expected output keeps the from_elements.
  %from_lhs = vector.extract %arg0[0, 0] : i8 from vector<1x2xi8>
  %from_rhs = vector.extract %arg1[0, 1] : i8 from vector<1x2xi8>
  %packed = vector.from_elements %from_lhs, %from_rhs : vector<2xi8>
  return %packed : vector<2xi8>
}
// -----
// CHECK-LABEL: func @negative_source_not_suffix(
// CHECK-NOT: shape_cast
// CHECK: from_elements
func.func @negative_source_not_suffix(%arg0: vector<1x3xi8>) -> vector<2xi8> {
  // Only the first two of the three source elements are extracted and
  // repacked. The expected output keeps the from_elements.
  %lead0 = vector.extract %arg0[0, 0] : i8 from vector<1x3xi8>
  %lead1 = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
  %packed = vector.from_elements %lead0, %lead1 : vector<2xi8>
  return %packed : vector<2xi8>
}
// -----
// The inserted elements are a subset of the extracted elements.
// [0, 1, 2] -> [1, 1, 2]
// CHECK-LABEL: func @negative_nobijection_order(
// CHECK-NOT: shape_cast
// CHECK: from_elements
func.func @negative_nobijection_order(%arg0: vector<1x3xi8>) -> vector<3xi8> {
  // Source positions [0, 1] and [0, 2] are extracted; the first of them is
  // used twice in the result and position [0, 0] is never read. The expected
  // output keeps the from_elements.
  %mid = vector.extract %arg0[0, 1] : i8 from vector<1x3xi8>
  %last = vector.extract %arg0[0, 2] : i8 from vector<1x3xi8>
  %packed = vector.from_elements %mid, %mid, %last : vector<3xi8>
  return %packed : vector<3xi8>
}
// -----
// CHECK-LABEL: func @negative_source_too_small(
// CHECK-NOT: shape_cast
// CHECK: from_elements
func.func @negative_source_too_small(%arg0: vector<2xi8>) -> vector<4xi8> {
  // The source has two elements but the result has four; the second source
  // element is repeated to fill the remaining slots. The expected output
  // keeps the from_elements.
  %lo = vector.extract %arg0[0] : i8 from vector<2xi8>
  %hi = vector.extract %arg0[1] : i8 from vector<2xi8>
  %padded = vector.from_elements %lo, %hi, %hi, %hi : vector<4xi8>
  return %padded : vector<4xi8>
}