| //===- LegalityTest.cpp ---------------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" |
| #include "llvm/Analysis/AssumptionCache.h" |
| #include "llvm/Analysis/LoopInfo.h" |
| #include "llvm/Analysis/ScalarEvolution.h" |
| #include "llvm/Analysis/TargetLibraryInfo.h" |
| #include "llvm/AsmParser/Parser.h" |
| #include "llvm/IR/DataLayout.h" |
| #include "llvm/IR/Dominators.h" |
| #include "llvm/SandboxIR/Function.h" |
| #include "llvm/SandboxIR/Instruction.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include "gtest/gtest.h" |
| |
| using namespace llvm; |
| |
| struct LegalityTest : public testing::Test { |
| LLVMContext C; |
| std::unique_ptr<Module> M; |
| std::unique_ptr<DominatorTree> DT; |
| std::unique_ptr<TargetLibraryInfoImpl> TLII; |
| std::unique_ptr<TargetLibraryInfo> TLI; |
| std::unique_ptr<AssumptionCache> AC; |
| std::unique_ptr<LoopInfo> LI; |
| std::unique_ptr<ScalarEvolution> SE; |
| |
| ScalarEvolution &getSE(llvm::Function &LLVMF) { |
| DT = std::make_unique<DominatorTree>(LLVMF); |
| TLII = std::make_unique<TargetLibraryInfoImpl>(); |
| TLI = std::make_unique<TargetLibraryInfo>(*TLII); |
| AC = std::make_unique<AssumptionCache>(LLVMF); |
| LI = std::make_unique<LoopInfo>(*DT); |
| SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI); |
| return *SE; |
| } |
| |
| void parseIR(LLVMContext &C, const char *IR) { |
| SMDiagnostic Err; |
| M = parseAssemblyString(IR, Err, C); |
| if (!M) |
| Err.print("LegalityTest", errs()); |
| } |
| }; |
| |
| TEST_F(LegalityTest, Legality) { |
| parseIR(C, R"IR( |
| define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) { |
| %gep0 = getelementptr float, ptr %ptr, i32 0 |
| %gep1 = getelementptr float, ptr %ptr, i32 1 |
| %gep3 = getelementptr float, ptr %ptr, i32 3 |
| %ld0 = load float, ptr %gep0 |
| %ld0b = load float, ptr %gep0 |
| %ld1 = load float, ptr %gep1 |
| %ld3 = load float, ptr %gep3 |
| store float %ld0, ptr %gep0 |
| store float %ld1, ptr %gep1 |
| store <2 x float> %vec2, ptr %gep1 |
| store <3 x float> %vec3, ptr %gep3 |
| store i8 %arg, ptr %gep1 |
| %fadd0 = fadd float %farg0, %farg0 |
| %fadd1 = fadd fast float %farg1, %farg1 |
| %trunc0 = trunc nuw nsw i64 %v0 to i8 |
| %trunc1 = trunc nsw i64 %v1 to i8 |
| %trunc64to8 = trunc i64 %v0 to i8 |
| %trunc32to8 = trunc i32 %v2 to i8 |
| %cmpSLT = icmp slt i64 %v0, %v1 |
| %cmpSGT = icmp sgt i64 %v0, %v1 |
| ret void |
| } |
| )IR"); |
| llvm::Function *LLVMF = &*M->getFunction("foo"); |
| auto &SE = getSE(*LLVMF); |
| const auto &DL = M->getDataLayout(); |
| |
| sandboxir::Context Ctx(C); |
| auto *F = Ctx.createFunction(LLVMF); |
| auto *BB = &*F->begin(); |
| auto It = BB->begin(); |
| [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++); |
| [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++); |
| [[maybe_unused]] auto *Gep3 = cast<sandboxir::GetElementPtrInst>(&*It++); |
| auto *Ld0 = cast<sandboxir::LoadInst>(&*It++); |
| auto *Ld0b = cast<sandboxir::LoadInst>(&*It++); |
| auto *Ld1 = cast<sandboxir::LoadInst>(&*It++); |
| auto *Ld3 = cast<sandboxir::LoadInst>(&*It++); |
| auto *St0 = cast<sandboxir::StoreInst>(&*It++); |
| auto *St1 = cast<sandboxir::StoreInst>(&*It++); |
| auto *StVec2 = cast<sandboxir::StoreInst>(&*It++); |
| auto *StVec3 = cast<sandboxir::StoreInst>(&*It++); |
| auto *StI8 = cast<sandboxir::StoreInst>(&*It++); |
| auto *FAdd0 = cast<sandboxir::BinaryOperator>(&*It++); |
| auto *FAdd1 = cast<sandboxir::BinaryOperator>(&*It++); |
| auto *Trunc0 = cast<sandboxir::TruncInst>(&*It++); |
| auto *Trunc1 = cast<sandboxir::TruncInst>(&*It++); |
| auto *Trunc64to8 = cast<sandboxir::TruncInst>(&*It++); |
| auto *Trunc32to8 = cast<sandboxir::TruncInst>(&*It++); |
| auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++); |
| auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++); |
| |
| sandboxir::LegalityAnalysis Legality(SE, DL); |
| const auto &Result = Legality.canVectorize({St0, St1}); |
| EXPECT_TRUE(isa<sandboxir::Widen>(Result)); |
| |
| { |
| // Check NotInstructions |
| auto &Result = Legality.canVectorize({F, St0}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::NotInstructions); |
| } |
| { |
| // Check DiffOpcodes |
| const auto &Result = Legality.canVectorize({St0, Ld0}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::DiffOpcodes); |
| } |
| { |
| // Check DiffTypes |
| EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2}))); |
| EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3}))); |
| |
| const auto &Result = Legality.canVectorize({St0, StI8}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::DiffTypes); |
| } |
| { |
| // Check DiffMathFlags |
| const auto &Result = Legality.canVectorize({FAdd0, FAdd1}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::DiffMathFlags); |
| } |
| { |
| // Check DiffWrapFlags |
| const auto &Result = Legality.canVectorize({Trunc0, Trunc1}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::DiffWrapFlags); |
| } |
| { |
| // Check DiffTypes for unary operands that have a different type. |
| const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::DiffTypes); |
| } |
| { |
| // Check DiffOpcodes for CMPs with different predicates. |
| const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::DiffOpcodes); |
| } |
| { |
| // Check NotConsecutive Ld0,Ld0b |
| const auto &Result = Legality.canVectorize({Ld0, Ld0b}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::NotConsecutive); |
| } |
| { |
| // Check NotConsecutive Ld0,Ld3 |
| const auto &Result = Legality.canVectorize({Ld0, Ld3}); |
| EXPECT_TRUE(isa<sandboxir::Pack>(Result)); |
| EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(), |
| sandboxir::ResultReason::NotConsecutive); |
| } |
| { |
| // Check Widen Ld0,Ld1 |
| const auto &Result = Legality.canVectorize({Ld0, Ld1}); |
| EXPECT_TRUE(isa<sandboxir::Widen>(Result)); |
| } |
| } |
| |
| #ifndef NDEBUG |
| TEST_F(LegalityTest, LegalityResultDump) { |
| parseIR(C, R"IR( |
| define void @foo() { |
| ret void |
| } |
| )IR"); |
| llvm::Function *LLVMF = &*M->getFunction("foo"); |
| auto &SE = getSE(*LLVMF); |
| const auto &DL = M->getDataLayout(); |
| |
| auto Matches = [](const sandboxir::LegalityResult &Result, |
| const std::string &ExpectedStr) -> bool { |
| std::string Buff; |
| raw_string_ostream OS(Buff); |
| Result.print(OS); |
| return Buff == ExpectedStr; |
| }; |
| |
| sandboxir::LegalityAnalysis Legality(SE, DL); |
| EXPECT_TRUE( |
| Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen")); |
| EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( |
| sandboxir::ResultReason::NotInstructions), |
| "Pack Reason: NotInstructions")); |
| EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( |
| sandboxir::ResultReason::DiffOpcodes), |
| "Pack Reason: DiffOpcodes")); |
| EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( |
| sandboxir::ResultReason::DiffTypes), |
| "Pack Reason: DiffTypes")); |
| EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( |
| sandboxir::ResultReason::DiffMathFlags), |
| "Pack Reason: DiffMathFlags")); |
| EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>( |
| sandboxir::ResultReason::DiffWrapFlags), |
| "Pack Reason: DiffWrapFlags")); |
| } |
| #endif // NDEBUG |