// blob: 4c88df432b6c7dbc3112d67c0fa1472f9775ab8e [file] [edit]
// RUN: mlir-opt %s -acc-if-clause-lowering -split-input-file | FileCheck %s
// Lowering of the if clause on acc.parallel: the construct together with its
// data entry/exit operations is expected in the then branch of an scf.if,
// while the else branch holds the host fallback (the plain scf.for).
// CHECK-LABEL: func.func @test_parallel_if
func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
  %zero = arith.constant 0 : i32
  %one = arith.constant 1 : index
  %ten = arith.constant 10 : index
  %acc_in = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
  %acc_new = acc.create varPtr(%arg0 : memref<10xi32>) -> memref<10xi32> {dataClause = #acc<data_clause acc_copyout>}
  // CHECK-NOT: acc.parallel if
  // CHECK: scf.if %{{.*}} {
  // CHECK: %[[COPYIN:.*]] = acc.copyin varPtr(%{{.*}}) -> memref<10xi32>
  // CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}}) -> memref<10xi32>
  // CHECK: acc.parallel dataOperands(%[[COPYIN]], %[[CREATE]] : memref<10xi32>, memref<10xi32>) {
  // CHECK: scf.for
  // CHECK: acc.yield
  // CHECK: }
  // CHECK: acc.delete accPtr(%[[CREATE]] : memref<10xi32>)
  // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<10xi32>) to varPtr(%{{.*}} : memref<10xi32>)
  // CHECK: } else {
  // CHECK: scf.for
  // CHECK: }
  acc.parallel dataOperands(%acc_in, %acc_new : memref<10xi32>, memref<10xi32>) if(%cond) {
    // Zero elements 1..9 of the host buffer.
    scf.for %iv = %one to %ten step %one {
      memref.store %zero, %arg0[%iv] : memref<10xi32>
    }
    acc.yield
  }
  acc.delete accPtr(%acc_new : memref<10xi32>)
  acc.copyout accPtr(%acc_in : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
  return
}
// -----
// Test acc.parallel if lowering when host fallback region has multiple blocks.
// The compute region below is an unstructured countdown loop: ^bb1 tests the
// counter, ^bb2 decrements it and branches back, ^bb3 exits. The expectations
// look for that control flow wrapped in an scf.execute_region inside the else
// branch of the generated scf.if.
// CHECK-LABEL: func.func @test_parallel_if_multiblock
func.func @test_parallel_if_multiblock(%cond: i1, %n: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
// Stack slot holding the loop counter, initialized to %n.
%counter = memref.alloca() : memref<i32>
memref.store %n, %counter[] : memref<i32>
// CHECK-NOT: acc.parallel if
// CHECK: scf.if %{{.*}} {
// CHECK: acc.parallel {
// CHECK: } else {
// CHECK: scf.execute_region {
// CHECK: ^bb
// CHECK: cf.cond_br
// CHECK: scf.yield
// CHECK: }
// CHECK: }
acc.parallel if(%cond) {
cf.br ^bb1
// Loop header: continue while the counter is greater than zero.
^bb1:
%v = memref.load %counter[] : memref<i32>
%pred = arith.cmpi sgt, %v, %c0_i32 : i32
cf.cond_br %pred, ^bb2, ^bb3
// Loop body: decrement the counter and jump back to the header.
^bb2:
%next = arith.subi %v, %c1_i32 : i32
memref.store %next, %counter[] : memref<i32>
cf.br ^bb1
// Loop exit.
^bb3:
acc.yield
}
return
}
// -----
// Lowering of the if clause on acc.kernels: same shape as the parallel case,
// except that the compute region is closed by acc.terminator.
// CHECK-LABEL: func.func @test_kernels_if
func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {
  %fill = arith.constant 1 : i32
  %one = arith.constant 1 : index
  %five = arith.constant 5 : index
  %acc_in = acc.copyin varPtr(%arg0 : memref<5xi32>) -> memref<5xi32>
  %acc_new = acc.create varPtr(%arg0 : memref<5xi32>) -> memref<5xi32> {dataClause = #acc<data_clause acc_copyout>}
  // CHECK-NOT: acc.kernels if
  // CHECK: scf.if %{{.*}} {
  // CHECK: %[[COPYIN:.*]] = acc.copyin
  // CHECK: %[[CREATE:.*]] = acc.create
  // CHECK: acc.kernels dataOperands(%[[COPYIN]], %[[CREATE]] : memref<5xi32>, memref<5xi32>) {
  // CHECK: scf.for
  // CHECK: acc.terminator
  // CHECK: }
  // CHECK: acc.delete accPtr(%[[CREATE]] : memref<5xi32>)
  // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<5xi32>) to varPtr(%{{.*}} : memref<5xi32>)
  // CHECK: } else {
  // CHECK: scf.for
  // CHECK: }
  acc.kernels dataOperands(%acc_in, %acc_new : memref<5xi32>, memref<5xi32>) if(%cond) {
    // Store 1 into elements 1..4 of the host buffer.
    scf.for %iv = %one to %five step %one {
      memref.store %fill, %arg0[%iv] : memref<5xi32>
    }
    acc.terminator
  }
  acc.delete accPtr(%acc_new : memref<5xi32>)
  acc.copyout accPtr(%acc_in : memref<5xi32>) to varPtr(%arg0 : memref<5xi32>)
  return
}
// -----
// Lowering of the if clause on acc.serial: device path in the then branch of
// an scf.if, host fallback loop in the else branch.
// CHECK-LABEL: func.func @test_serial_if
func.func @test_serial_if(%arg0: memref<8xi32>, %cond: i1) {
  %fill = arith.constant 2 : i32
  %one = arith.constant 1 : index
  %eight = arith.constant 8 : index
  %acc_in = acc.copyin varPtr(%arg0 : memref<8xi32>) -> memref<8xi32>
  %acc_new = acc.create varPtr(%arg0 : memref<8xi32>) -> memref<8xi32> {dataClause = #acc<data_clause acc_copyout>}
  // CHECK-NOT: acc.serial if
  // CHECK: scf.if %{{.*}} {
  // CHECK: %[[COPYIN:.*]] = acc.copyin
  // CHECK: %[[CREATE:.*]] = acc.create
  // CHECK: acc.serial dataOperands(%[[COPYIN]], %[[CREATE]] : memref<8xi32>, memref<8xi32>) {
  // CHECK: scf.for
  // CHECK: acc.yield
  // CHECK: }
  // CHECK: acc.delete accPtr(%[[CREATE]] : memref<8xi32>)
  // CHECK: acc.copyout accPtr(%[[COPYIN]] : memref<8xi32>) to varPtr(%{{.*}} : memref<8xi32>)
  // CHECK: } else {
  // CHECK: scf.for
  // CHECK: }
  acc.serial dataOperands(%acc_in, %acc_new : memref<8xi32>, memref<8xi32>) if(%cond) {
    // Store 2 into elements 1..7 of the host buffer.
    scf.for %iv = %one to %eight step %one {
      memref.store %fill, %arg0[%iv] : memref<8xi32>
    }
    acc.yield
  }
  acc.delete accPtr(%acc_new : memref<8xi32>)
  acc.copyout accPtr(%acc_in : memref<8xi32>) to varPtr(%arg0 : memref<8xi32>)
  return
}
// -----
// Negative case: an acc.parallel that has no if clause must be left alone,
// i.e. no scf.if may appear ahead of the compute construct.
// CHECK-LABEL: func.func @test_parallel_no_if
func.func @test_parallel_no_if(%arg0: memref<10xi32>) {
  %zero = arith.constant 0 : i32
  %one = arith.constant 1 : index
  %ten = arith.constant 10 : index
  %acc_in = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
  // CHECK-NOT: scf.if
  // CHECK: acc.parallel dataOperands(%{{.*}}) {
  acc.parallel dataOperands(%acc_in : memref<10xi32>) {
    scf.for %iv = %one to %ten step %one {
      memref.store %zero, %arg0[%iv] : memref<10xi32>
    }
    acc.yield
  }
  acc.copyout accPtr(%acc_in : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
  return
}
// -----
// Test with private and reduction clauses inside compute construct
// Privatization recipe: the init region allocates a fresh memref<i32>.
acc.private.recipe @privatization_memref_i32 : memref<i32> init {
^bb0(%arg0: memref<i32>):
%0 = memref.alloca() : memref<i32>
acc.yield %0 : memref<i32>
}
// Add-reduction recipe on memref<f32>: init allocates a slot set to 0.0, and
// the combiner accumulates the second argument into the first.
acc.reduction.recipe @reduction_add_memref_f32 : memref<f32> reduction_operator <add> init {
^bb0(%arg0: memref<f32>):
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.alloca() : memref<f32>
memref.store %cst, %0[] : memref<f32>
acc.yield %0 : memref<f32>
} combiner {
^bb0(%arg0: memref<f32>, %arg1: memref<f32>):
%0 = memref.load %arg0[] : memref<f32>
%1 = memref.load %arg1[] : memref<f32>
%2 = arith.addf %0, %1 : f32
memref.store %2, %arg0[] : memref<f32>
acc.yield %arg0 : memref<f32>
}
// CHECK-LABEL: func.func @test_reduction_if
func.func @test_reduction_if(%r: memref<f32>, %a: memref<8xf32>, %cond: i1) {
%c8_i32 = arith.constant 8 : i32
%c1_i32 = arith.constant 1 : i32
%copyin = acc.copyin varPtr(%r : memref<f32>) -> memref<f32> {dataClause = #acc<data_clause acc_reduction>, implicit = true}
// CHECK: scf.if
// CHECK: acc.parallel
// CHECK: } else {
// The else branch should have acc ops converted to host
// CHECK-NOT: acc.loop
// CHECK-NOT: acc.reduction
// CHECK-NOT: acc.private
// CHECK: }
acc.parallel combined(loop) dataOperands(%copyin : memref<f32>) if(%cond) {
// The reduction and private operands are created inside the compute region
// and consumed by the combined acc.loop below.
%red = acc.reduction varPtr(%r : memref<f32>) recipe(@reduction_add_memref_f32) -> memref<f32>
%iter_var = memref.alloca() : memref<i32>
%priv = acc.private varPtr(%iter_var : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
acc.loop combined(parallel) vector private(%priv : memref<i32>) reduction(%red : memref<f32>) control(%iv : i32) = (%c1_i32 : i32) to (%c8_i32 : i32) step (%c1_i32 : i32) {
// Spill the induction variable to the private slot, reload it as the index
// into %a, and accumulate the loaded element into %r.
memref.store %iv, %priv[] : memref<i32>
%idx = memref.load %priv[] : memref<i32>
%idx_cast = arith.index_cast %idx : i32 to index
%elem = memref.load %a[%idx_cast] : memref<8xf32>
%r_val = memref.load %r[] : memref<f32>
%new_r = arith.addf %r_val, %elem : f32
memref.store %new_r, %r[] : memref<f32>
acc.yield
} attributes {inclusiveUpperbound = array<i1: true>, independent = [#acc.device_type<none>]}
acc.yield
}
acc.copyout accPtr(%copyin : memref<f32>) to varPtr(%r : memref<f32>) {dataClause = #acc<data_clause acc_reduction>, implicit = true}
return
}
// -----
// A use of the device pointer (the acc.copyin result) inside the compute
// region must be rewritten to the corresponding host variable on the host
// fallback path.
// CHECK-LABEL: func.func @test_acc_var_replacement
func.func @test_acc_var_replacement(%arg0: memref<10xi32>, %cond: i1) {
  %zero = arith.constant 0 : i32
  %one = arith.constant 1 : index
  %ten = arith.constant 10 : index
  %dev = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
  // The store in the else branch is expected to target %arg0 rather than the
  // acc.copyin result.
  // CHECK: scf.if
  // CHECK: } else {
  // CHECK: scf.for
  // CHECK: memref.store %{{.*}}, %arg0[%{{.*}}]
  // CHECK: }
  acc.parallel dataOperands(%dev : memref<10xi32>) if(%cond) {
    scf.for %iv = %one to %ten step %one {
      // Deliberately write through the acc pointer, not through %arg0.
      memref.store %zero, %dev[%iv] : memref<10xi32>
    }
    acc.yield
  }
  acc.copyout accPtr(%dev : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
  return
}
// -----
// Test that acc variable uses in host path are replaced with host variables;
// and the firstprivate operands are cloned
// CHECK-LABEL: func.func @test_acc_firstprivate
// Firstprivate recipe: init allocates a fresh memref<i32>; copy transfers the
// value from the first argument into the second.
acc.firstprivate.recipe @memref_i32 : memref<i32> init {
^bb0(%arg0: memref<i32>):
%0 = memref.alloca() : memref<i32>
acc.yield %0 : memref<i32>
} copy {
^bb0(%arg0: memref<i32>, %arg1: memref<i32>):
%0 = memref.load %arg0[] : memref<i32>
memref.store %0, %arg1[] : memref<i32>
acc.terminator
}
func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref<i32>, %cond: i1) {
%c0_i32 = arith.constant 0 : i32
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
%firstprivate = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
// In the else branch, uses of %firstprivate should be replaced with its host
// variable %arg1
// CHECK: scf.if
// CHECK: [[FIRSTPRIVATE:%.*]] = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
// CHECK: acc.parallel {{.*}} firstprivate([[FIRSTPRIVATE]] : memref<i32>) {
// CHECK: } else {
// CHECK: [[LOAD:%.*]] = memref.load %arg1[] : memref<i32>
// CHECK: }
acc.parallel dataOperands(%copyin : memref<10xi32>) firstprivate(%firstprivate : memref<i32>) if(%cond) {
// The loop upper bound is read through the firstprivate pointer.
%load = memref.load %firstprivate[] : memref<i32>
%ub = arith.index_cast %load : i32 to index
scf.for %i = %c1 to %ub step %c1 {
// Use the acc ptr inside the region
memref.store %c0_i32, %copyin[%i] : memref<10xi32>
}
acc.yield
}
acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
return
}
// -----
// Host-path replacement for a reduction operand: the acc.reduction op is
// expected to be cloned into the then branch, and accesses through its result
// are expected to go to the host variable in the else branch.
// CHECK-LABEL: func.func @test_acc_reduction
acc.reduction.recipe @reduction_add_memref_i32 : memref<i32> reduction_operator <add> init {
^bb0(%orig: memref<i32>):
  // Fresh accumulator initialized to the additive identity.
  %zero = arith.constant 0 : i32
  %slot = memref.alloca() : memref<i32>
  memref.store %zero, %slot[] : memref<i32>
  acc.yield %slot : memref<i32>
} combiner {
^bb0(%lhs: memref<i32>, %rhs: memref<i32>):
  // Accumulate %rhs into %lhs.
  %r = memref.load %rhs[] : memref<i32>
  %l = memref.load %lhs[] : memref<i32>
  %sum = arith.addi %l, %r : i32
  memref.store %sum, %lhs[] : memref<i32>
  acc.yield %lhs : memref<i32>
}
func.func @test_acc_reduction(%arg0: memref<i32>, %cond: i1) {
  %zero = arith.constant 0 : i32
  %red = acc.reduction varPtr(%arg0 : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32>
  // On the host path, the load and store through %red are expected to use
  // %arg0 directly.
  // CHECK: scf.if
  // CHECK: [[REDUCTION:%.*]] = acc.reduction varPtr(%arg0 : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32>
  // CHECK: acc.parallel reduction([[REDUCTION]] : memref<i32>) {
  // CHECK: } else {
  // CHECK: [[LOAD:%.*]] = memref.load %arg0[] : memref<i32>
  // CHECK: memref.store {{.*}}, %arg0[] : memref<i32>
  // CHECK: }
  acc.parallel reduction(%red : memref<i32>) if(%cond) {
    %cur = memref.load %red[] : memref<i32>
    %upd = arith.addi %cur, %zero : i32
    memref.store %upd, %red[] : memref<i32>
    acc.yield
  }
  return
}
// -----
// Test that acc variable uses in host path are replaced with host variables;
// and the private operands are cloned
// CHECK-LABEL: func.func @test_acc_private
// Privatization recipe: the init region allocates a fresh memref<i32>.
acc.private.recipe @privatization_memref_i32 : memref<i32> init {
^bb0(%arg0: memref<i32>):
  %0 = memref.alloca() : memref<i32>
  acc.yield %0 : memref<i32>
}
func.func @test_acc_private(%arg0: memref<i32>, %cond: i1) {
  %c0_i32 = arith.constant 0 : i32
  %private = acc.private varPtr(%arg0 : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
  // In the else branch, uses of %private should be replaced with %arg0
  // CHECK: scf.if
  // CHECK: [[PRIVATE:%.*]] = acc.private varPtr(%arg0 : memref<i32>) recipe(@privatization_memref_i32) -> memref<i32>
  // CHECK: acc.parallel private([[PRIVATE]] : memref<i32>) {
  // CHECK: } else {
  // CHECK: memref.store {{.*}}, %arg0[] : memref<i32>
  // CHECK: }
  acc.parallel private(%private : memref<i32>) if(%cond) {
    memref.store %c0_i32, %private[] : memref<i32>
    acc.yield
  }
  return
}