| // RUN: mlir-opt --canonicalize %s | FileCheck %s |
| |
| // Mesh symbol of shape 2x4, referenced as @mesh0 by the tests below. |
| mesh.mesh @mesh0(shape = 2x4) |
| |
| // An all_reduce with an empty axis list and identical operand/result types |
| // is expected to be removed, so the function returns its argument directly. |
| // CHECK-LABEL: func @all_reduce_empty_mesh_axes |
| func.func @all_reduce_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.all_reduce |
| %result = mesh.all_reduce %input on @mesh0 mesh_axes = [] |
| : tensor<4xf32> -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // The result element type (f64) differs from the operand's (f32), so the op |
| // is expected to remain; only its empty axis list should not be printed. |
| // CHECK-LABEL: func @all_reduce_empty_mesh_axes_different_return_type |
| func.func @all_reduce_empty_mesh_axes_different_return_type( |
| %input : tensor<4xf32>) -> tensor<4xf64> { |
| // CHECK: mesh.all_reduce |
| // CHECK-NOT: mesh_axes |
| %result = mesh.all_reduce %input on @mesh0 mesh_axes = [] |
| : tensor<4xf32> -> tensor<4xf64> |
| return %result : tensor<4xf64> |
| } |
| |
| // The reduction kind is spelled out explicitly as <sum> in the input; per the |
| // test name this is the default, and it should not be printed in the output. |
| // The op itself must still be present: without the positive match below the |
| // negative match would also pass if the op vanished from the output entirely. |
| // CHECK-LABEL: func @all_reduce_default_reduction |
| func.func @all_reduce_default_reduction( |
| %arg0 : tensor<4xf32>) -> tensor<4xf64> { |
| // CHECK: mesh.all_reduce |
| %0 = mesh.all_reduce %arg0 on @mesh0 |
| mesh_axes = [0] |
| // CHECK-NOT: reduction |
| reduction = <sum> |
| : tensor<4xf32> -> tensor<4xf64> |
| return %0 : tensor<4xf64> |
| } |
| |
| // An all_to_all with an empty axis list is expected to be removed, so the |
| // function returns its argument directly. |
| // CHECK-LABEL: func @all_to_all_empty_mesh_axes |
| func.func @all_to_all_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<8xf32> |
| %input : tensor<8xf32>) -> tensor<8xf32> { |
| // CHECK-NOT: mesh.all_to_all |
| %result = mesh.all_to_all %input on @mesh0 mesh_axes = [] |
| split_axis = 0 concat_axis = 0 |
| : tensor<8xf32> -> tensor<8xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<8xf32> |
| } |
| |
| // An all_gather with an empty axis list is expected to be removed, so the |
| // function returns its argument directly. |
| // CHECK-LABEL: func @all_gather_empty_mesh_axes |
| func.func @all_gather_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.all_gather |
| %result = mesh.all_gather %input on @mesh0 mesh_axes = [] gather_axis = 0 |
| : tensor<4xf32> -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // An all_slice with an empty axis list is expected to be removed, so the |
| // function returns its argument directly. The negative match names the op |
| // under test (mesh.all_slice); it previously named a different op and so did |
| // not verify that this op was removed. |
| // CHECK-LABEL: func @all_slice_empty_mesh_axes |
| func.func @all_slice_empty_mesh_axes( |
| // CHECK-SAME: %[[ARG:.*]]: tensor<4xf32> |
| %arg0 : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.all_slice |
| %0 = mesh.all_slice %arg0 on @mesh0 |
| mesh_axes = [] |
| slice_axis = 0 |
| : tensor<4xf32> -> tensor<4xf32> |
| // CHECK: return %[[ARG]] |
| return %0 : tensor<4xf32> |
| } |
| |
| // A broadcast with an empty axis list and empty root is expected to be |
| // removed, so the function returns its argument directly. |
| // CHECK-LABEL: func @broadcast_empty_mesh_axes |
| func.func @broadcast_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.broadcast |
| %result = mesh.broadcast %input on @mesh0 mesh_axes = [] root = [] |
| : (tensor<4xf32>) -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // A gather with an empty axis list and empty root is expected to be removed, |
| // so the function returns its argument directly. |
| // CHECK-LABEL: func @gather_empty_mesh_axes |
| func.func @gather_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.gather |
| %result = mesh.gather %input on @mesh0 mesh_axes = [] |
| gather_axis = 0 root = [] |
| : (tensor<4xf32>) -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // A recv with an empty axis list is expected to be removed, so the function |
| // returns its argument directly. |
| // CHECK-LABEL: func @receive_empty_mesh_axes |
| func.func @receive_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.recv |
| %result = mesh.recv %input on @mesh0 mesh_axes = [] |
| : (tensor<4xf32>) -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // A reduce with an empty axis list and empty root is expected to be removed, |
| // so the function returns its argument directly. |
| // CHECK-LABEL: func @reduce_empty_mesh_axes |
| func.func @reduce_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.reduce |
| %result = mesh.reduce %input on @mesh0 mesh_axes = [] root = [] |
| : (tensor<4xf32>) -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // A reduce_scatter with an empty axis list and identical operand/result |
| // types is expected to be removed, so the argument is returned directly. |
| // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes |
| func.func @reduce_scatter_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.reduce_scatter |
| %result = mesh.reduce_scatter %input on @mesh0 mesh_axes = [] |
| scatter_axis = 0 |
| : tensor<4xf32> -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // The result element type (f64) differs from the operand's (f32), so the op |
| // is expected to remain; only its empty axis list should not be printed. |
| // CHECK-LABEL: func @reduce_scatter_empty_mesh_axes_different_return_type |
| func.func @reduce_scatter_empty_mesh_axes_different_return_type( |
| %input : tensor<4xf32>) -> tensor<4xf64> { |
| // CHECK: mesh.reduce_scatter |
| // CHECK-NOT: mesh_axes |
| %result = mesh.reduce_scatter %input on @mesh0 mesh_axes = [] |
| scatter_axis = 0 |
| : tensor<4xf32> -> tensor<4xf64> |
| return %result : tensor<4xf64> |
| } |
| |
| // The reduction kind is spelled out explicitly as <sum> in the input; per the |
| // test name this is the default, and it should not be printed in the output. |
| // The op itself must still be present: without the positive match below the |
| // negative match would also pass if the op vanished from the output entirely. |
| // CHECK-LABEL: func @reduce_scatter_default_reduction |
| func.func @reduce_scatter_default_reduction( |
| %arg0 : tensor<4xf32>) -> tensor<2xf64> { |
| // CHECK: mesh.reduce_scatter |
| %0 = mesh.reduce_scatter %arg0 on @mesh0 |
| mesh_axes = [0] |
| // CHECK-NOT: reduction |
| reduction = <sum> |
| scatter_axis = 0 |
| : tensor<4xf32> -> tensor<2xf64> |
| return %0 : tensor<2xf64> |
| } |
| |
| // A scatter with an empty axis list and empty root is expected to be removed, |
| // so the function returns its argument directly. |
| // CHECK-LABEL: func @scatter_empty_mesh_axes |
| func.func @scatter_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.scatter |
| %result = mesh.scatter %input on @mesh0 mesh_axes = [] |
| scatter_axis = 0 root = [] |
| : (tensor<4xf32>) -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |
| |
| // A send with an empty axis list and empty destination is expected to be |
| // removed, so the function returns its argument directly. |
| // CHECK-LABEL: func @send_empty_mesh_axes |
| func.func @send_empty_mesh_axes( |
| // CHECK-SAME: %[[INPUT:.*]]: tensor<4xf32> |
| %input : tensor<4xf32>) -> tensor<4xf32> { |
| // CHECK-NOT: mesh.send |
| %result = mesh.send %input on @mesh0 mesh_axes = [] destination = [] |
| : (tensor<4xf32>) -> tensor<4xf32> |
| // CHECK: return %[[INPUT]] |
| return %result : tensor<4xf32> |
| } |