[mlir][llvm] Support nusw and nuw in GEP (#137272)
nusw and nuw were introduced in getelementptr; this patch plumbs them into
MLIR.
Since inbounds implies nusw, this patch also adds an inboundsFlag to
represent the concept of raw inbounds with no nusw implication, and has
the inbounds literal captured as the combination of inboundsFlag and
nusw.
Fixes: iree#20482
Signed-off-by: Lin, Peiyong <linpyong@gmail.com>
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index 45ccf30..6c0fe36 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -876,4 +876,32 @@
let cppNamespace = "::mlir::LLVM::uwtable";
}
+//===----------------------------------------------------------------------===//
+// GEPNoWrapFlags
+//===----------------------------------------------------------------------===//
+
+// These values must match those of llvm::GEPNoWrapFlags.
+// See llvm/include/llvm/IR/GEPNoWrapFlags.h.
+// Since inbounds implies nusw, inboundsFlag represents the raw inbounds bit
+// with no nusw implication; the actual inbounds literal is captured as the
+// combination of inboundsFlag and nusw.
+
+def GEPNone : I32BitEnumCaseNone<"none">;
+def GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
+def GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
+def GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
+def GEPInbounds : BitEnumCaseGroup<"inbounds", [GEPInboundsFlag, GEPNusw]>;
+
+def GEPNoWrapFlags : I32BitEnum<
+ "GEPNoWrapFlags",
+ "::mlir::LLVM::GEPNoWrapFlags",
+ [GEPNone, GEPInboundsFlag, GEPNusw, GEPNuw, GEPInbounds]> {
+ let cppNamespace = "::mlir::LLVM";
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def GEPNoWrapFlagsProp : EnumProp<GEPNoWrapFlags> {
+ let defaultValue = interfaceType # "::none";
+}
+
#endif // LLVMIR_ENUMS
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 5745d37..5315e39 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -291,7 +291,7 @@
Variadic<LLVM_ScalarOrVectorOf<AnySignlessInteger>>:$dynamicIndices,
DenseI32ArrayAttr:$rawConstantIndices,
TypeAttr:$elem_type,
- UnitAttr:$inbounds);
+ GEPNoWrapFlagsProp:$noWrapFlags);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
let skipDefaultBuilders = 1;
@@ -303,8 +303,12 @@
as indices. In the case of indexing within a structure, it is required to
either use constant indices directly, or supply a constant SSA value.
- An optional 'inbounds' attribute specifies the low-level pointer arithmetic
+ The no-wrap flags can be used to specify the low-level pointer arithmetic
overflow behavior that LLVM uses after lowering the operation to LLVM IR.
+ Valid options include 'inbounds' (pointer arithmetic must be within object
+ bounds), 'nusw' (no unsigned signed wrap), and 'nuw' (no unsigned wrap).
+ Note that 'inbounds' implies 'nusw', which is ensured by the enum
+ definition. The flags can be set individually or in combination.
Examples:
@@ -323,10 +327,12 @@
let builders = [
OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr,
- "ValueRange":$indices, CArg<"bool", "false">:$inbounds,
+ "ValueRange":$indices,
+ CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$noWrapFlags,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr,
- "ArrayRef<GEPArg>":$indices, CArg<"bool", "false">:$inbounds,
+ "ArrayRef<GEPArg>":$indices,
+ CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$noWrapFlags,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
];
let llvmBuilder = [{
@@ -343,10 +349,13 @@
}
Type baseElementType = op.getElemType();
llvm::Type *elementType = moduleTranslation.convertType(baseElementType);
- $res = builder.CreateGEP(elementType, $base, indices, "", $inbounds);
+ $res = builder.CreateGEP(elementType, $base, indices, "",
+ llvm::GEPNoWrapFlags::fromRaw(
+ static_cast<unsigned>(
+ op.getNoWrapFlags())));
}];
let assemblyFormat = [{
- (`inbounds` $inbounds^)?
+ ($noWrapFlags^)?
$base `[` custom<GEPIndices>($dynamicIndices, $rawConstantIndices) `]` attr-dict
`:` functional-type(operands, results) `,` $elem_type
}];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 26c3ef1..d0ac39f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -673,29 +673,29 @@
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
- bool inbounds, ArrayRef<NamedAttribute> attributes) {
+ GEPNoWrapFlags noWrapFlags,
+ ArrayRef<NamedAttribute> attributes) {
SmallVector<int32_t> rawConstantIndices;
SmallVector<Value> dynamicIndices;
destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
result.addTypes(resultType);
result.addAttributes(attributes);
- result.addAttribute(getRawConstantIndicesAttrName(result.name),
- builder.getDenseI32ArrayAttr(rawConstantIndices));
- if (inbounds) {
- result.addAttribute(getInboundsAttrName(result.name),
- builder.getUnitAttr());
- }
- result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
+ result.getOrAddProperties<Properties>().rawConstantIndices =
+ builder.getDenseI32ArrayAttr(rawConstantIndices);
+ result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
+ result.getOrAddProperties<Properties>().elem_type =
+ TypeAttr::get(elementType);
result.addOperands(basePtr);
result.addOperands(dynamicIndices);
}
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
Type elementType, Value basePtr, ValueRange indices,
- bool inbounds, ArrayRef<NamedAttribute> attributes) {
+ GEPNoWrapFlags noWrapFlags,
+ ArrayRef<NamedAttribute> attributes) {
build(builder, result, resultType, elementType, basePtr,
- SmallVector<GEPArg>(indices), inbounds, attributes);
+ SmallVector<GEPArg>(indices), noWrapFlags, attributes);
}
static ParseResult
@@ -794,6 +794,9 @@
return emitOpError("expected as many dynamic indices as specified in '")
<< getRawConstantIndicesAttrName().getValue() << "'";
+ if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
+ return emitOpError("'inbounds_flag' cannot be used directly.");
+
return verifyStructIndices(getElemType(), getIndices(),
[&] { return emitOpError(); });
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 8640ef2..bc451f8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -891,7 +891,7 @@
auto byteType = IntegerType::get(builder.getContext(), 8);
auto newPtr = builder.createOrFold<LLVM::GEPOp>(
getLoc(), getResult().getType(), byteType, newSlot.ptr,
- ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
+ ArrayRef<GEPArg>(accessInfo->subslotOffset), getNoWrapFlags());
getResult().replaceAllUsesWith(newPtr);
return DeletionKind::Delete;
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 0a3371c..7f2c0ca1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -2035,8 +2035,9 @@
}
Type type = convertType(inst->getType());
- auto gepOp = builder.create<GEPOp>(loc, type, sourceElementType, *basePtr,
- indices, gepInst->isInBounds());
+ auto gepOp = builder.create<GEPOp>(
+ loc, type, sourceElementType, *basePtr, indices,
+ static_cast<GEPNoWrapFlags>(gepInst->getNoWrapFlags().getRaw()));
mapValue(inst, gepOp);
return success();
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index e8902c4..5dea940 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1819,3 +1819,11 @@
^bb1:
llvm.return %0 : !llvm.ptr
}
+
+// -----
+
+llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) {
+ // expected-error@+1 {{'inbounds_flag' cannot be used directly}}
+ llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ llvm.return
+}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 93916f6..f30c8f2 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -236,6 +236,16 @@
llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>
// CHECK: llvm.getelementptr inbounds %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
llvm.getelementptr inbounds %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: llvm.getelementptr inbounds|nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ llvm.getelementptr inbounds | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: llvm.getelementptr inbounds %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ llvm.getelementptr inbounds | nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: llvm.getelementptr nusw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ llvm.getelementptr nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: llvm.getelementptr nusw|nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ llvm.getelementptr nusw | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: llvm.getelementptr nuw %{{.*}}[%{{.*}}, 0, %{{.*}}] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ llvm.getelementptr nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/Import/instructions.ll b/mlir/test/Target/LLVMIR/Import/instructions.ll
index c294e1b..2098d85 100644
--- a/mlir/test/Target/LLVMIR/Import/instructions.ll
+++ b/mlir/test/Target/LLVMIR/Import/instructions.ll
@@ -557,6 +557,25 @@
; // -----
+; CHECK-LABEL: @gep_no_wrap_flags
+; CHECK-SAME: %[[PTR:[a-zA-Z0-9]+]]
+define void @gep_no_wrap_flags(ptr %ptr) {
+ ; CHECK: %[[IDX:.+]] = llvm.mlir.constant(7 : i32)
+ ; CHECK: llvm.getelementptr inbounds %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
+ %1 = getelementptr inbounds float, ptr %ptr, i32 7
+ ; CHECK: llvm.getelementptr nusw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
+ %2 = getelementptr nusw float, ptr %ptr, i32 7
+ ; CHECK: llvm.getelementptr nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
+ %3 = getelementptr nuw float, ptr %ptr, i32 7
+ ; CHECK: llvm.getelementptr nusw|nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
+ %4 = getelementptr nusw nuw float, ptr %ptr, i32 7
+ ; CHECK: llvm.getelementptr inbounds|nuw %[[PTR]][%[[IDX]]] : (!llvm.ptr, i32) -> !llvm.ptr, f32
+ %5 = getelementptr inbounds nuw float, ptr %ptr, i32 7
+ ret void
+}
+
+; // -----
+
; CHECK: @varargs(...)
declare void @varargs(...)
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index 80778e4..e2eac7d 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1057,6 +1057,14 @@
llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>
// CHECK: = getelementptr inbounds { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
llvm.getelementptr inbounds %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: = getelementptr inbounds nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
+ llvm.getelementptr inbounds | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: = getelementptr nusw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
+ llvm.getelementptr nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: = getelementptr nusw nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
+ llvm.getelementptr nusw | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
+ // CHECK: = getelementptr nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
+ llvm.getelementptr nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
llvm.return
}