AMDGPU: Rewrite VGPR MFMAs to AGPR when directly copied to AGPR class (#152480)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
index f580f43..c21a9a1 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteAGPRCopyMFMA.cpp
@@ -109,12 +109,17 @@
// Find AV_* registers assigned to AGPRs.
const TargetRegisterClass *VirtRegRC = MRI.getRegClass(VReg);
- if (!TRI.isVectorSuperClass(VirtRegRC))
+ if (!TRI.hasAGPRs(VirtRegRC))
continue;
- const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
- if (!TRI.isAGPRClass(AssignedRC))
- continue;
+ const TargetRegisterClass *AssignedRC = VirtRegRC;
+ if (TRI.hasVGPRs(VirtRegRC)) {
+ // If this is an AV register, we have to check if the actual assignment is
+ // to an AGPR
+ AssignedRC = TRI.getPhysRegBaseClass(PhysReg);
+ if (!TRI.isAGPRClass(AssignedRC))
+ continue;
+ }
LiveInterval &LI = LIS.getInterval(VReg);
diff --git a/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll b/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll
index 0b43ff2..b35a74e 100644
--- a/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll
+++ b/llvm/test/CodeGen/AMDGPU/rewrite-vgpr-mfma-to-agpr.ll
@@ -200,8 +200,199 @@
ret void
}
-declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32 immarg, i32 immarg, i32 immarg) #1
-declare noundef i32 @llvm.amdgcn.workitem.id.x() #2
+; The inline asm requires the value be copied to an AGPR class, not
+; the AV_* pseudo we usually expect for register allocator live range
+; splits.
+define amdgpu_kernel void @test_rewrite_mfma_direct_copy_to_agpr_class(ptr addrspace(1) %arg) #0 {
+; CHECK-LABEL: test_rewrite_mfma_direct_copy_to_agpr_class:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_load_dwordx2 s[0:1], s[4:5], 0x0
+; CHECK-NEXT: v_and_b32_e32 v0, 0x3ff, v0
+; CHECK-NEXT: v_lshlrev_b32_e32 v0, 7, v0
+; CHECK-NEXT: v_mov_b32_e32 v32, 2.0
+; CHECK-NEXT: v_mov_b32_e32 v33, 4.0
+; CHECK-NEXT: s_waitcnt lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 a[28:31], v0, s[0:1] offset:112
+; CHECK-NEXT: global_load_dwordx4 a[24:27], v0, s[0:1] offset:96
+; CHECK-NEXT: global_load_dwordx4 a[20:23], v0, s[0:1] offset:80
+; CHECK-NEXT: global_load_dwordx4 a[16:19], v0, s[0:1] offset:64
+; CHECK-NEXT: global_load_dwordx4 a[12:15], v0, s[0:1] offset:48
+; CHECK-NEXT: global_load_dwordx4 a[8:11], v0, s[0:1] offset:32
+; CHECK-NEXT: global_load_dwordx4 a[4:7], v0, s[0:1] offset:16
+; CHECK-NEXT: global_load_dwordx4 a[0:3], v0, s[0:1]
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 a[0:31], v32, v33, a[0:31]
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:31]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_endpgm
+bb:
+ %id = call i32 @llvm.amdgcn.workitem.id.x()
+ %gep = getelementptr <32 x float>, ptr addrspace(1) %arg, i32 %id
+ %in = load <32 x float>, ptr addrspace(1) %gep, align 128
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 2.0, float 4.0, <32 x float> %in, i32 0, i32 0, i32 0)
+ call void asm sideeffect "; use $0", "a"(<32 x float> %mai)
+ ret void
+}
+
+; TODO: Handle rewriting this case
+define void @test_rewrite_mfma_imm_src2(float %arg0, float %arg1) #0 {
+; CHECK-LABEL: test_rewrite_mfma_imm_src2:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[0:31], v0, v1, 2.0
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v0
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v1
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v2
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v3
+; CHECK-NEXT: v_accvgpr_write_b32 a4, v4
+; CHECK-NEXT: v_accvgpr_write_b32 a5, v5
+; CHECK-NEXT: v_accvgpr_write_b32 a6, v6
+; CHECK-NEXT: v_accvgpr_write_b32 a7, v7
+; CHECK-NEXT: v_accvgpr_write_b32 a8, v8
+; CHECK-NEXT: v_accvgpr_write_b32 a9, v9
+; CHECK-NEXT: v_accvgpr_write_b32 a10, v10
+; CHECK-NEXT: v_accvgpr_write_b32 a11, v11
+; CHECK-NEXT: v_accvgpr_write_b32 a12, v12
+; CHECK-NEXT: v_accvgpr_write_b32 a13, v13
+; CHECK-NEXT: v_accvgpr_write_b32 a14, v14
+; CHECK-NEXT: v_accvgpr_write_b32 a15, v15
+; CHECK-NEXT: v_accvgpr_write_b32 a16, v16
+; CHECK-NEXT: v_accvgpr_write_b32 a17, v17
+; CHECK-NEXT: v_accvgpr_write_b32 a18, v18
+; CHECK-NEXT: v_accvgpr_write_b32 a19, v19
+; CHECK-NEXT: v_accvgpr_write_b32 a20, v20
+; CHECK-NEXT: v_accvgpr_write_b32 a21, v21
+; CHECK-NEXT: v_accvgpr_write_b32 a22, v22
+; CHECK-NEXT: v_accvgpr_write_b32 a23, v23
+; CHECK-NEXT: v_accvgpr_write_b32 a24, v24
+; CHECK-NEXT: v_accvgpr_write_b32 a25, v25
+; CHECK-NEXT: v_accvgpr_write_b32 a26, v26
+; CHECK-NEXT: v_accvgpr_write_b32 a27, v27
+; CHECK-NEXT: v_accvgpr_write_b32 a28, v28
+; CHECK-NEXT: v_accvgpr_write_b32 a29, v29
+; CHECK-NEXT: v_accvgpr_write_b32 a30, v30
+; CHECK-NEXT: v_accvgpr_write_b32 a31, v31
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:31]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> splat (float 2.0), i32 0, i32 0, i32 0)
+ call void asm sideeffect "; use $0", "a"(<32 x float> %mai)
+ ret void
+}
+
+; TODO: Handle rewriting this case
+define void @test_rewrite_mfma_subreg_extract0(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
+; CHECK-LABEL: test_rewrite_mfma_subreg_extract0:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
+; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
+; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
+; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
+; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
+; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
+; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
+; CHECK-NEXT: s_nop 0
+; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v2
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v3
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v4
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v5
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:3]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %src2 = load <32 x float>, ptr addrspace(1) %ptr
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
+ %extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
+ ret void
+}
+
+define void @test_rewrite_mfma_subreg_extract1(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
+; CHECK-LABEL: test_rewrite_mfma_subreg_extract1:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
+; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
+; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
+; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
+; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
+; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
+; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
+; CHECK-NEXT: s_nop 0
+; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v6
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v7
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v8
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v9
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:3]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %src2 = load <32 x float>, ptr addrspace(1) %ptr
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
+ %extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+ call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
+ ret void
+}
+
+; odd offset
+define void @test_rewrite_mfma_subreg_extract2(float %arg0, float %arg1, ptr addrspace(1) %ptr) #0 {
+; CHECK-LABEL: test_rewrite_mfma_subreg_extract2:
+; CHECK: ; %bb.0: ; %bb
+; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT: global_load_dwordx4 v[30:33], v[2:3], off offset:112
+; CHECK-NEXT: global_load_dwordx4 v[26:29], v[2:3], off offset:96
+; CHECK-NEXT: global_load_dwordx4 v[22:25], v[2:3], off offset:80
+; CHECK-NEXT: global_load_dwordx4 v[18:21], v[2:3], off offset:64
+; CHECK-NEXT: global_load_dwordx4 v[14:17], v[2:3], off offset:48
+; CHECK-NEXT: global_load_dwordx4 v[10:13], v[2:3], off offset:32
+; CHECK-NEXT: global_load_dwordx4 v[6:9], v[2:3], off offset:16
+; CHECK-NEXT: s_nop 0
+; CHECK-NEXT: global_load_dwordx4 v[2:5], v[2:3], off
+; CHECK-NEXT: s_waitcnt vmcnt(0)
+; CHECK-NEXT: v_mfma_f32_32x32x1_2b_f32 v[2:33], v0, v1, v[2:33]
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 7
+; CHECK-NEXT: s_nop 1
+; CHECK-NEXT: v_accvgpr_write_b32 a0, v3
+; CHECK-NEXT: v_accvgpr_write_b32 a1, v4
+; CHECK-NEXT: v_accvgpr_write_b32 a2, v5
+; CHECK-NEXT: v_accvgpr_write_b32 a3, v6
+; CHECK-NEXT: ;;#ASMSTART
+; CHECK-NEXT: ; use a[0:3]
+; CHECK-NEXT: ;;#ASMEND
+; CHECK-NEXT: s_setpc_b64 s[30:31]
+bb:
+ %src2 = load <32 x float>, ptr addrspace(1) %ptr
+ %mai = call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %arg0, float %arg1, <32 x float> %src2, i32 0, i32 0, i32 0)
+ %extract.sub4 = shufflevector <32 x float> %mai, <32 x float> poison, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
+ call void asm sideeffect "; use $0", "a"(<4 x float> %extract.sub4)
+ ret void
+}
+
+declare <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half>, <4 x half>, <4 x float>, i32 immarg, i32 immarg, i32 immarg) #2
+declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32 immarg, i32 immarg, i32 immarg) #2
+declare noundef range(i32 0, 1024) i32 @llvm.amdgcn.workitem.id.x() #3
attributes #0 = { nounwind "amdgpu-flat-work-group-size"="1,256" "amdgpu-waves-per-eu"="4,4" }
attributes #1 = { convergent nocallback nofree nosync nounwind willreturn memory(none) }