blob: 6e5df86b13106acdae359d81ab0e26f9c25fd139 [file] [log] [blame]
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Mesh symbols referenced by the tests below. They cover a fully static mesh
// (@mesh0, @mesh4), meshes mixing static and dynamic (`?`) extents (@mesh1,
// @mesh2) and fully dynamic meshes (@mesh3, @mesh5). Every declaration is
// expected to round-trip verbatim, including its shape.
// CHECK: mesh.mesh @mesh0(shape = 2x2x4)
mesh.mesh @mesh0(shape = 2x2x4)
// CHECK: mesh.mesh @mesh1(shape = 4x?)
mesh.mesh @mesh1(shape = 4x?)
// CHECK: mesh.mesh @mesh2(shape = ?x4)
mesh.mesh @mesh2(shape = ?x4)
// CHECK: mesh.mesh @mesh3(shape = ?x?)
mesh.mesh @mesh3(shape = ?x?)
// CHECK: mesh.mesh @mesh4(shape = 3)
mesh.mesh @mesh4(shape = 3)
// CHECK: mesh.mesh @mesh5(shape = ?)
mesh.mesh @mesh5(shape = ?)
// CHECK-LABEL: func @mesh_shard_op_fully_replicated
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// Sharding attribute whose only split-axes list is empty.
func.func @mesh_shard_op_fully_replicated(%input : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}]]> : tensor<4x8xf32>
  %sharded = mesh.shard %input to <@mesh0, [[]]> : tensor<4x8xf32>
  return %sharded : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_1st_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// The first (and only) split-axes list names mesh axis 0.
func.func @mesh_shard_op_1st_dim(%input : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
  %sharded = mesh.shard %input to <@mesh0, [[0]]> : tensor<4x8xf32>
  return %sharded : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_2nd_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// An empty first list followed by [0]: only the second list names a mesh axis.
func.func @mesh_shard_op_2nd_dim(%input : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh1, {{\[\[}}], [0]]> : tensor<4x8xf32>
  %sharded = mesh.shard %input to <@mesh1, [[], [0]]> : tensor<4x8xf32>
  return %sharded : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32>
// Rank-3 tensor with split-axes lists [0], [] and [1] on the dynamic @mesh3.
func.func @mesh_shard_op_1st_and_3rd_dim(%input : tensor<4x8x16xf32>)
    -> tensor<4x8x16xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8x16xf32>
  %sharded = mesh.shard %input to <@mesh3, [[0], [], [1]]> : tensor<4x8x16xf32>
  return %sharded : tensor<4x8x16xf32>
}
// CHECK-LABEL: func @mesh_shard_op_partial_max
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// Round-trips the `partial = max[...]` clause of a sharding attribute.
func.func @mesh_shard_op_partial_max(%input : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = max[1]> : tensor<4x8xf32>
  %sharded = mesh.shard %input to <@mesh3, [[0]], partial = max[1]> : tensor<4x8xf32>
  return %sharded : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_partial_min
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// Round-trips the `partial = min[...]` clause of a sharding attribute.
func.func @mesh_shard_op_partial_min(%input : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = min[1]> : tensor<4x8xf32>
  %sharded = mesh.shard %input to <@mesh3, [[0]], partial = min[1]> : tensor<4x8xf32>
  return %sharded : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_partial_generic
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// Round-trips the `partial = generic[...]` clause of a sharding attribute.
func.func @mesh_shard_op_partial_generic(%input : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = generic[1]> : tensor<4x8xf32>
  %sharded = mesh.shard %input to <@mesh3, [[0]], partial = generic[1]> : tensor<4x8xf32>
  return %sharded : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_partial_sum
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// Round-trips the `partial = sum[...]` clause of a sharding attribute.
func.func @mesh_shard_op_partial_sum(%input : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1]> : tensor<4x8xf32>
  %sharded = mesh.shard %input to <@mesh3, [[0]], partial = sum[1]> : tensor<4x8xf32>
  return %sharded : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// Round-trips a `partial = sum[...]` clause listing more than one mesh axis.
// The partial axes 1 and 2 must exist on the referenced mesh, so this test
// uses the rank-3 @mesh0 (shape 2x2x4). @mesh3 (shape ?x?) has only axes 0
// and 1.
func.func @mesh_shard_op_partial_sum_multi_axes(
    %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]], partial = sum[1, 2]> : tensor<4x8xf32>
  %0 = mesh.shard %arg0 to <@mesh0, [[0]], partial = sum[1, 2]> : tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shard_op_two_users
// CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32>
// One sharded value (%0) feeds two `annotate_for_users` shard ops, each with a
// different split-axes list; both results are returned.
func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
(tensor<4x8xf32>, tensor<4x8xf32>) {
// Capture the first shard's result as V0 so both users can be tied back to it.
// CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32>
%0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32>
// The two user annotations are matched with DAG directives, so the test does
// not depend on the order in which they are printed.
// CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}1]]> annotate_for_users : tensor<4x8xf32>
%1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32>
// CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}2]]> annotate_for_users : tensor<4x8xf32>
%2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32>
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
// CHECK-LABEL: func @mesh_shape
// Queries two explicitly listed axes of @mesh0; one index result per axis.
func.func @mesh_shape() -> (index, index) {
  // CHECK: %[[RES:.*]]:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
  %extents:2 = mesh.mesh_shape @mesh0 axes = [0, 1] : index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
  return %extents#0, %extents#1 : index, index
}
// CHECK-LABEL: func @mesh_shape_default_axes
// No `axes` clause is written; three index results are requested for the
// rank-3 @mesh0.
func.func @mesh_shape_default_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
  %extents:3 = mesh.mesh_shape @mesh0 : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %extents#0, %extents#1, %extents#2 : index, index, index
}
// CHECK-LABEL: func @mesh_shape_empty_axes
// An explicit `axes = []` is expected to print without any `axes` clause.
func.func @mesh_shape_empty_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.mesh_shape @mesh0 : index, index, index
  %extents:3 = mesh.mesh_shape @mesh0 axes = [] : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %extents#0, %extents#1, %extents#2 : index, index, index
}
// CHECK-LABEL: func @process_multi_index
// Requests the process index along two explicitly listed axes of @mesh0.
func.func @process_multi_index() -> (index, index) {
  // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
  %coords:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
  return %coords#0, %coords#1 : index, index
}
// CHECK-LABEL: func @process_multi_index_default_axes
// No `axes` clause is written; three index results for the rank-3 @mesh0.
func.func @process_multi_index_default_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
  %coords:3 = mesh.process_multi_index on @mesh0 : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %coords#0, %coords#1, %coords#2 : index, index, index
}
// CHECK-LABEL: func @process_multi_index_empty_axes
// An explicit `axes = []` is expected to print without any `axes` clause.
func.func @process_multi_index_empty_axes() -> (index, index, index) {
  // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index
  %coords:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index
  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
  return %coords#0, %coords#1, %coords#2 : index, index, index
}
// CHECK-LABEL: func @process_linear_index
// Single index result on @mesh0.
func.func @process_linear_index() -> index {
  // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index
  %linear = mesh.process_linear_index on @mesh0 : index
  // CHECK: return %[[RES]] : index
  return %linear : index
}
// CHECK-LABEL: func @all_reduce
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
// Mesh axes listed out of order ([1, 0]), a `max` reduction, and a result
// element type (f64) that differs from the operand's (f32).
func.func @all_reduce(%input : tensor<3x4xf32>) -> tensor<3x4xf64> {
  // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = <max>
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64>
  %reduced = mesh.all_reduce %input on @mesh0 mesh_axes = [1, 0]
      reduction = <max> : tensor<3x4xf32> -> tensor<3x4xf64>
  return %reduced : tensor<3x4xf64>
}
// CHECK-LABEL: func @all_gather
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
// Gathers along tensor dim 1 over mesh axis 2 of @mesh0 (extent 4), so the
// result dim 1 is 4 * 4 = 16.
func.func @all_gather(%input : tensor<3x4xf32>) -> tensor<3x16xf32> {
  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32>
  %gathered = mesh.all_gather %input on @mesh0 mesh_axes = [2]
      gather_axis = 1 : tensor<3x4xf32> -> tensor<3x16xf32>
  return %gathered : tensor<3x16xf32>
}
// CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
// Same gather as @all_gather, but the operand and result tensors are fully
// dynamic.
func.func @all_gather_dynamic_dims_in_tensor(%input : tensor<?x?xf32>)
    -> tensor<?x?xf32> {
  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1
  // CHECK-SAME: : tensor<?x?xf32> -> tensor<?x?xf32>
  %gathered = mesh.all_gather %input on @mesh0 mesh_axes = [2]
      gather_axis = 1 : tensor<?x?xf32> -> tensor<?x?xf32>
  return %gathered : tensor<?x?xf32>
}
// CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh
// CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32>
// Static operand gathered over a dynamic mesh axis (@mesh3 is ?x?). The
// gathered result dimension is written as `?`.
func.func @all_gather_dynamic_dims_in_mesh(%input : tensor<5x6xf32>)
    -> tensor<5x?xf32> {
  // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1
  // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32>
  %gathered = mesh.all_gather %input on @mesh3 mesh_axes = [1]
      gather_axis = 1 : tensor<5x6xf32> -> tensor<5x?xf32>
  return %gathered : tensor<5x?xf32>
}
// CHECK-LABEL: func @all_slice_static_dimensions
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
// Slices tensor dim 1 over mesh axis 2 of @mesh0 (extent 4): 4 / 4 = 1.
func.func @all_slice_static_dimensions(%input : tensor<3x4xf32>)
    -> tensor<3x1xf32> {
  // CHECK-NEXT: mesh.all_slice %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [2] slice_axis = 1
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf32>
  %slice = mesh.all_slice %input on @mesh0 mesh_axes = [2]
      slice_axis = 1 : tensor<3x4xf32> -> tensor<3x1xf32>
  return %slice : tensor<3x1xf32>
}
// CHECK-LABEL: func @all_slice_dynamic_dimensions
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
// Dynamic 1-D tensor sliced over both axes of the dynamic @mesh3.
func.func @all_slice_dynamic_dimensions(%input : tensor<?xf32>)
    -> tensor<?xf32> {
  // CHECK-NEXT: mesh.all_slice %[[ARG]]
  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] slice_axis = 0
  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf32>
  %slice = mesh.all_slice %input on @mesh3 mesh_axes = [0, 1]
      slice_axis = 0 : tensor<?xf32> -> tensor<?xf32>
  return %slice : tensor<?xf32>
}
// CHECK-LABEL: func @all_to_all
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
// all_to_all on @mesh4 with no `mesh_axes` clause; the result type equals the
// operand type.
func.func @all_to_all(%input : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8>
  %exchanged = mesh.all_to_all %input on @mesh4 split_axis = 1 concat_axis = 0
      : tensor<3x6xi8> -> tensor<3x6xi8>
  return %exchanged : tensor<3x6xi8>
}
// CHECK-LABEL: func @all_to_all_dynamic_dims_in_result
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
// Like @all_to_all, but the split dimension of the result is written as `?`
// even though the operand is fully static.
func.func @all_to_all_dynamic_dims_in_result(%input : tensor<3x6xi8>)
    -> tensor<3x?xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0
  // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8>
  %exchanged = mesh.all_to_all %input on @mesh4 split_axis = 1 concat_axis = 0
      : tensor<3x6xi8> -> tensor<3x?xi8>
  return %exchanged : tensor<3x?xi8>
}
// CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size
// CHECK-SAME: %[[ARG:.*]]: tensor<3xi8>
// split_axis and concat_axis coincide while the device group spans mesh
// axis 0 of @mesh5 (shape `?`), whose extent is not known statically. The
// test expects the result shape to stay identical to the operand shape.
// The previous version used @mesh4 without `mesh_axes`, which does not
// involve any dynamic mesh extent and so did not exercise what the test
// name describes.
func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size(
    %arg0 : tensor<3xi8>) -> tensor<3xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: @mesh5 mesh_axes = [0] split_axis = 0 concat_axis = 0
  // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8>
  %0 = mesh.all_to_all %arg0 on @mesh5 mesh_axes = [0]
    split_axis = 0 concat_axis = 0
    : tensor<3xi8> -> tensor<3xi8>
  return %0 : tensor<3xi8>
}
// CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size
// CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8>
// Group = mesh axes [0, 1] of @mesh0 (2 * 2 = 4 devices). The split dim (2)
// is not a multiple of 4 and is written as `?`. The concat dim grows to
// 3 * 4 = 12.
func.func @all_to_all_non_divisible_split_axis_size(%input : tensor<2x3xi8>)
    -> tensor<?x12xi8> {
  // CHECK-NEXT: mesh.all_to_all %[[ARG]]
  // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1
  // CHECK-SAME: : tensor<2x3xi8> -> tensor<?x12xi8>
  %exchanged = mesh.all_to_all %input on @mesh0 mesh_axes = [0, 1]
      split_axis = 0 concat_axis = 1 : tensor<2x3xi8> -> tensor<?x12xi8>
  return %exchanged : tensor<?x12xi8>
}
// CHECK-LABEL: func @broadcast_static_root
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
// Broadcast whose root multi-index consists only of constants.
func.func @broadcast_static_root(%input : tensor<3x6xi8>) -> tensor<3x6xi8> {
  // CHECK-NEXT: mesh.broadcast %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8>
  %bcast = mesh.broadcast %input on @mesh0 mesh_axes = [0, 2] root = [0, 1]
      : (tensor<3x6xi8>) -> tensor<3x6xi8>
  return %bcast : tensor<3x6xi8>
}
// CHECK-LABEL: func @broadcast_dynamic_root
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
// CHECK-SAME: %[[ARG1:.*]]: index
// Broadcast whose root mixes a constant and an SSA index; the index operand
// also appears in the functional type.
func.func @broadcast_dynamic_root(%input : tensor<3x6xi8>, %root_idx : index)
    -> tensor<3x6xi8> {
  // CHECK-NEXT: mesh.broadcast %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
  %bcast = mesh.broadcast %input on @mesh0 mesh_axes = [0, 2]
      root = [1, %root_idx] : (tensor<3x6xi8>, index) -> tensor<3x6xi8>
  return %bcast : tensor<3x6xi8>
}
// CHECK-LABEL: func @gather_static_root
// CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8>
// Gather along tensor dim 0 over mesh axes [0, 2] of @mesh0 (2 * 4 = 8
// devices): 3 * 8 = 24 rows; the root is fully constant.
func.func @gather_static_root(%input : tensor<3x6xi8>) -> tensor<24x6xi8> {
  // CHECK-NEXT: mesh.gather %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: gather_axis = 0
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8>
  %gathered = mesh.gather %input on @mesh0 mesh_axes = [0, 2] gather_axis = 0
      root = [0, 1] : (tensor<3x6xi8>) -> tensor<24x6xi8>
  return %gathered : tensor<24x6xi8>
}
// CHECK-LABEL: func @gather_dynamic_root
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8>
// CHECK-SAME: %[[ARG1:.*]]: index
// Same gather as @gather_static_root, but the root's second component is an
// SSA index.
func.func @gather_dynamic_root(%input : tensor<3x6xi8>, %root_idx : index)
    -> tensor<24x6xi8> {
  // CHECK-NEXT: mesh.gather %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: gather_axis = 0
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
  %gathered = mesh.gather %input on @mesh0 mesh_axes = [0, 2] gather_axis = 0
      root = [1, %root_idx] : (tensor<3x6xi8>, index) -> tensor<24x6xi8>
  return %gathered : tensor<24x6xi8>
}
// CHECK-LABEL: func @receive_static_source
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
// mesh.recv with a fully constant `source` multi-index.
func.func @receive_static_source(%input : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.recv %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: source = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
  %received = mesh.recv %input on @mesh0 mesh_axes = [0, 2] source = [0, 1]
      : (tensor<2xi8>) -> tensor<2xi8>
  return %received : tensor<2xi8>
}
// CHECK-LABEL: func @receive_dynamic_source
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
// CHECK-SAME: %[[ARG1:.*]]: index
// mesh.recv whose `source` mixes a constant and an SSA index.
func.func @receive_dynamic_source(%input : tensor<2xi8>, %src_idx : index)
    -> tensor<2xi8> {
  // CHECK-NEXT: mesh.recv %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: source = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
  %received = mesh.recv %input on @mesh0 mesh_axes = [0, 2]
      source = [1, %src_idx] : (tensor<2xi8>, index) -> tensor<2xi8>
  return %received : tensor<2xi8>
}
// CHECK-LABEL: func @receive_no_source
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
// mesh.recv written without a `source` clause must print without one. The
// printed text before and after the spot where the clause would appear is
// pinned on the same line, and the word is forbidden in between. This also
// verifies the mesh, the axes and the functional type, which the earlier
// version did not check.
func.func @receive_no_source(
    %arg0 : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.recv %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-NOT: source
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
  %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2]
    : (tensor<2xi8>) -> tensor<2xi8>
  return %0 : tensor<2xi8>
}
// CHECK-LABEL: func @reduce_static_root
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
// mesh.reduce with a fully constant root and no explicit reduction kind.
func.func @reduce_static_root(%input : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.reduce %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
  %reduced = mesh.reduce %input on @mesh0 mesh_axes = [0, 2] root = [0, 1]
      : (tensor<2xi8>) -> tensor<2xi8>
  return %reduced : tensor<2xi8>
}
// CHECK-LABEL: func @reduce_dynamic_root
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
// CHECK-SAME: %[[ARG1:.*]]: index
// mesh.reduce whose root's second component is an SSA index.
func.func @reduce_dynamic_root(%input : tensor<2xi8>, %root_idx : index)
    -> tensor<2xi8> {
  // CHECK-NEXT: mesh.reduce %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
  %reduced = mesh.reduce %input on @mesh0 mesh_axes = [0, 2]
      root = [1, %root_idx] : (tensor<2xi8>, index) -> tensor<2xi8>
  return %reduced : tensor<2xi8>
}
// CHECK-LABEL: func @reduce_different_return_element_type
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
// The result element type (i16) differs from the operand's (i8).
func.func @reduce_different_return_element_type(%input : tensor<2xi8>)
    -> tensor<2xi16> {
  // CHECK-NEXT: mesh.reduce %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: root = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16>
  %reduced = mesh.reduce %input on @mesh0 mesh_axes = [0, 2] root = [0, 1]
      : (tensor<2xi8>) -> tensor<2xi16>
  return %reduced : tensor<2xi16>
}
// CHECK-LABEL: func @reduce_scatter_static_dimensions
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
// `max` reduce_scatter along tensor dim 1 over mesh axis 2 of @mesh0
// (extent 4): 4 / 4 = 1. The element type widens from f32 to f64.
func.func @reduce_scatter_static_dimensions(%input : tensor<3x4xf32>)
    -> tensor<3x1xf64> {
  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = <max> scatter_axis = 1
  // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64>
  %scattered = mesh.reduce_scatter %input on @mesh0 mesh_axes = [2]
      reduction = <max> scatter_axis = 1 : tensor<3x4xf32> -> tensor<3x1xf64>
  return %scattered : tensor<3x1xf64>
}
// CHECK-LABEL: func @reduce_scatter_dynamic_dimensions
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
// No `reduction` clause is written, and the expected printed form contains
// none between `mesh_axes` and `scatter_axis`.
func.func @reduce_scatter_dynamic_dimensions(%input : tensor<?xf32>)
    -> tensor<?xf64> {
  // CHECK-NEXT: mesh.reduce_scatter %[[ARG]]
  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
  // CHECK-SAME: : tensor<?xf32> -> tensor<?xf64>
  %scattered = mesh.reduce_scatter %input on @mesh3 mesh_axes = [0, 1]
      scatter_axis = 0 : tensor<?xf32> -> tensor<?xf64>
  return %scattered : tensor<?xf64>
}
// CHECK-LABEL: func @scatter_static_dimensions
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
// Scatter along tensor dim 1 over mesh axis 2 of @mesh0 (extent 4):
// 4 / 4 = 1, with a one-component constant root.
func.func @scatter_static_dimensions(%input : tensor<3x4xf32>)
    -> tensor<3x1xf32> {
  // CHECK-NEXT: mesh.scatter %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [2]
  // CHECK-SAME: scatter_axis = 1 root = [1]
  // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32>
  %scattered = mesh.scatter %input on @mesh0 mesh_axes = [2] scatter_axis = 1
      root = [1] : (tensor<3x4xf32>) -> tensor<3x1xf32>
  return %scattered : tensor<3x1xf32>
}
// CHECK-LABEL: func @scatter_dynamic_dimensions
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
// Dynamic 1-D tensor scattered over both axes of the dynamic @mesh3.
func.func @scatter_dynamic_dimensions(%input : tensor<?xf32>)
    -> tensor<?xf32> {
  // CHECK-NEXT: mesh.scatter %[[ARG]]
  // CHECK-SAME: on @mesh3 mesh_axes = [0, 1]
  // CHECK-SAME: scatter_axis = 0 root = [1, 2]
  // CHECK-SAME: : (tensor<?xf32>) -> tensor<?xf32>
  %scattered = mesh.scatter %input on @mesh3 mesh_axes = [0, 1] scatter_axis = 0
      root = [1, 2] : (tensor<?xf32>) -> tensor<?xf32>
  return %scattered : tensor<?xf32>
}
// CHECK-LABEL: func @scatter_dynamic_root
// CHECK-SAME: %[[ARG0:.*]]: tensor<8xi8>
// CHECK-SAME: %[[ARG1:.*]]: index
// Scatter of 8 elements over mesh axes [0, 2] of @mesh0 (2 * 4 = 8 devices),
// leaving 1 element each; the root's second component is an SSA index.
func.func @scatter_dynamic_root(%input : tensor<8xi8>, %root_idx : index)
    -> tensor<1xi8> {
  // CHECK-NEXT: mesh.scatter %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: scatter_axis = 0
  // CHECK-SAME: root = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8>
  %scattered = mesh.scatter %input on @mesh0 mesh_axes = [0, 2] scatter_axis = 0
      root = [1, %root_idx] : (tensor<8xi8>, index) -> tensor<1xi8>
  return %scattered : tensor<1xi8>
}
// CHECK-LABEL: func @send_static_destination
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
// mesh.send with a fully constant `destination` multi-index.
func.func @send_static_destination(%input : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.send %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: destination = [0, 1]
  // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8>
  %sent = mesh.send %input on @mesh0 mesh_axes = [0, 2] destination = [0, 1]
      : (tensor<2xi8>) -> tensor<2xi8>
  return %sent : tensor<2xi8>
}
// CHECK-LABEL: func @send_dynamic_destination
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8>
// CHECK-SAME: %[[ARG1:.*]]: index
// mesh.send whose `destination` mixes a constant and an SSA index.
func.func @send_dynamic_destination(%input : tensor<2xi8>, %dst_idx : index)
    -> tensor<2xi8> {
  // CHECK-NEXT: mesh.send %[[ARG0]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: destination = [1, %[[ARG1]]]
  // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8>
  %sent = mesh.send %input on @mesh0 mesh_axes = [0, 2]
      destination = [1, %dst_idx] : (tensor<2xi8>, index) -> tensor<2xi8>
  return %sent : tensor<2xi8>
}
// CHECK-LABEL: func @shift
// CHECK-SAME: %[[ARG:.*]]: tensor<2xi8>
// mesh.shift with a negative offset and the `rotate` flag set.
func.func @shift(%input : tensor<2xi8>) -> tensor<2xi8> {
  // CHECK-NEXT: mesh.shift %[[ARG]]
  // CHECK-SAME: on @mesh0 mesh_axes = [0, 2]
  // CHECK-SAME: shift_axis = 2 offset = -2 rotate
  // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8>
  %shifted = mesh.shift %input on @mesh0 mesh_axes = [0, 2] shift_axis = 2
      offset = -2 rotate : tensor<2xi8> -> tensor<2xi8>
  return %shifted : tensor<2xi8>
}