blob: b694e65062df8fbce99897c96e2d5eef47ee2653 [file] [log] [blame]
//===- MemRefTypeTest.cpp - MemRefType unit tests -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::detail;
namespace {

/// Checks that getStridesAndOffset recovers the strides and constant offset
/// of a memref's layout in two configurations:
///   1. a single strided linear layout map with a non-zero offset, and
///   2. a permutation map followed by a zero-offset strided linear layout map.
TEST(MemRefTypeTest, GetStridesAndOffset) {
  MLIRContext ctx;
  Type elementType = FloatType::getF32(&ctx);
  SmallVector<int64_t> dims({2, 3, 4});

  // Case 1: one strided layout with strides [12, 4, 1] and offset 5. The
  // same strides and offset are expected back.
  {
    AffineMap stridedLayout = makeStridedLinearLayoutMap({12, 4, 1}, 5, &ctx);
    MemRefType memrefTy = MemRefType::get(dims, elementType, {stridedLayout});

    SmallVector<int64_t> strides;
    int64_t offset = -1; // Sentinel; must be overwritten by the query.
    LogicalResult status = getStridesAndOffset(memrefTy, strides, offset);

    ASSERT_TRUE(status.succeeded());
    ASSERT_EQ(3u, strides.size());
    EXPECT_EQ(12, strides[0]);
    EXPECT_EQ(4, strides[1]);
    EXPECT_EQ(1, strides[2]);
    ASSERT_EQ(5, offset);
  }

  // Case 2: the permutation {1, 2, 0} composed with strided layout [8, 2, 1]
  // (offset 0). The expected strides [1, 8, 2] are the per-dimension
  // coefficients of the composed layout, i.e. d0 * 1 + d1 * 8 + d2 * 2.
  {
    AffineMap permutation = AffineMap::getPermutationMap({1, 2, 0}, &ctx);
    AffineMap stridedLayout = makeStridedLinearLayoutMap({8, 2, 1}, 0, &ctx);
    MemRefType memrefTy =
        MemRefType::get(dims, elementType, {permutation, stridedLayout});

    SmallVector<int64_t> strides;
    int64_t offset = -1; // Sentinel; must be overwritten by the query.
    LogicalResult status = getStridesAndOffset(memrefTy, strides, offset);

    ASSERT_TRUE(status.succeeded());
    ASSERT_EQ(3u, strides.size());
    EXPECT_EQ(1, strides[0]);
    EXPECT_EQ(8, strides[1]);
    EXPECT_EQ(2, strides[2]);
    ASSERT_EQ(0, offset);
  }
}

} // namespace