blob: 35572debb3450d77003cc3bfcf43ce0c1542766d [file] [log] [blame]
//===-- TosaTypesBase.td - TOSA type definitions -----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the type definitions for the TOSA dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TOSA_TYPES_BASE
#define TOSA_TYPES_BASE
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// Tosa Type Definitions.
//===----------------------------------------------------------------------===//
// The base class of a quantized type.
// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end],
// where low_end and high_end are 0 and 255 when unsigned, or -128 and 127 when
// signed, for the 8-bit case.
// Matches any mlir::quant::QuantizedType whose storage integral width equals
// the first entry of `params`. `signed` only selects the "int"/"uint" spelling
// of the summary and the trailing flag of asTraitArgsStr.
// NOTE(review): the predicate checks the storage width only, not signedness.
class Tosa_QuantizedType<string n, list<int> params, bit signed>
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
"Q" # !if (signed, "int", "uint") # !head(params) # " type"> {
// Short name of the type, e.g. "int8".
string name = n;
// All params joined with ", ", followed by ", true" or ", false" depending
// on `signed`.
string asTraitArgsStr = !interleave(params, ", ") #
!if(signed, ", true", ", false");
}
//===----------------------------------------------------------------------===//
// Non-Quantized Signed Integer Types.
// Used to express accumulator results or compare results.
//===----------------------------------------------------------------------===//
// Fixed-width integer element types used for non-quantized data.
def Tosa_Int8  : I<8>;
def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
def Tosa_Int48 : I<48>;
def Tosa_Int64 : I<64>;

// Any of the integer widths above (the 1-bit type is not part of this set).
def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8, Tosa_Int16, Tosa_Int32,
                                Tosa_Int48, Tosa_Int64]>;

// 1-bit integer, used as the boolean element type.
def Tosa_Bool : I<1>;

// Booleans plus every width in Tosa_SignedInt. There are intentionally no
// unsigned unquantized integer types.
def Tosa_Int : AnyTypeOf<[Tosa_Bool, Tosa_SignedInt]>;

// Only the 32- and 64-bit integer widths.
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, Tosa_Int64]>;
//===----------------------------------------------------------------------===//
// Quantized Integer Types.
// Datatype for network feature map or weight content.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Name Symmetry Grouping Sign
//===----------------------------------------------------------------------===//
// uint8 : asymmetric per tensor, unsigned
// int4 : symmetric per channel, signed
// int8 : symmetric per tensor/per channel, signed
// int16 : symmetric per tensor, signed
//===----------------------------------------------------------------------===//
// One entry per row of the table above. Each entry passes its name, a
// [bitwidth] or [bitwidth, zeropt] param list, and a signedness bit.
def Tosa_QuantizedInt : AnyTypeOf<[
  Tosa_QuantizedType<"uint8", [8], 0>,
  Tosa_QuantizedType<"int4", [4, 0], 1>,
  Tosa_QuantizedType<"int8", [8, 0], 1>,
  Tosa_QuantizedType<"int16", [16, 0], 1>
]>;
//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
// Floating-point element types: F32, F16 and BF16.
def Tosa_Float : AnyTypeOf<[F32, F16, BF16]>;
//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
// Union of integer, quantized-integer and floating-point element types,
// summarized as "number".
def Tosa_AnyNumber : AnyTypeOf<
    [Tosa_Int, Tosa_QuantizedInt, Tosa_Float], "number">;
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
// Tensor whose elements are 32- or 64-bit integers.
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;

// Tensor of any TOSA numeric element type.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;

// Any tensor element type allowed in Tosa ops. The predicate is the union of
// the integer, quantized-integer and float predicates, summarized as
// "tosa.dtype".
def Tosa_ElementType : Type<
    Or<[Tosa_Int.predicate,
        Tosa_QuantizedInt.predicate,
        Tosa_Float.predicate]>,
    "tosa.dtype">;

// Accepts either a tensor whose element type is in `allowedTypes`, or
// NoneType.
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
    AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
//===----------------------------------------------------------------------===//
// Tensor types with constrained ranks.
//===----------------------------------------------------------------------===//
// Tensors that must have exactly the listed rank.
def Tosa_Tensor1D : 1DTensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor2D : 2DTensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor3D : 3DTensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor4D : 4DTensorOf<[Tosa_AnyNumber]>;
// Ranks 5 and 6 are spelled with TensorRankOf directly.
def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>;
def Tosa_Tensor6D : TensorRankOf<[Tosa_AnyNumber], [6]>;

// Ranked tensors whose rank is any value in the listed set.
def Tosa_Tensor1Dto2D : TensorRankOf<[Tosa_AnyNumber], [1, 2]>;
def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1, 2, 3, 4]>;
def Tosa_Tensor1Dto5D : TensorRankOf<[Tosa_AnyNumber], [1, 2, 3, 4, 5]>;
def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1, 2, 3, 4, 5, 6]>;
// The "Upto" variants also admit rank 0.
def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0, 1, 2, 3, 4]>;
def Tosa_TensorUpto6D : TensorRankOf<[Tosa_AnyNumber], [0, 1, 2, 3, 4, 5, 6]>;
// Rank 0 through 4, restricted to 32-bit integer elements.
def Tosa_Int32TensorUpto4D : TensorRankOf<[Tosa_Int32], [0, 1, 2, 3, 4]>;
//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
//===----------------------------------------------------------------------===//
// Matches a scalar of one of `types`, or a vector or tensor whose element type
// is one of `types`. `description` becomes the constraint's summary, so each
// instantiation reports its own allowed type rather than a shared fixed label.
class Tosa_TypeLike<list<Type> types, string description = ""> :
    TypeConstraint<Or<[AnyTypeOf<types>.predicate,
                       VectorOf<types>.predicate,
                       TensorOf<types>.predicate]>,
                   description>;
// Scalar/vector/tensor constraints, one per fixed integer width.
def Tosa_Int8Like :
    Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
def Tosa_Int16Like :
    Tosa_TypeLike<[Tosa_Int16], "signless-integer-16-bit-like">;
def Tosa_Int32Like :
    Tosa_TypeLike<[Tosa_Int32], "signless-integer-32-bit-like">;
def Tosa_Int64Like :
    Tosa_TypeLike<[Tosa_Int64], "signless-integer-64-bit-like">;
//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
//===----------------------------------------------------------------------===//
// Array attribute constraint that holds when the ArrayAttr has no more than
// `n` elements. The summary text must match the `<=` bound in the predicate.
class ArrayMaxCt<int n> : AttrConstraint<
    CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>,
    "with at most " # n # " elements">;
// F32 array attributes, each confined with ArrayCount<N> for its suffix N.
def Tosa_Fp32ArrayAttr2 : Confined<F32ArrayAttr, [ArrayCount<2>]>;
def Tosa_Fp32ArrayAttr3 : Confined<F32ArrayAttr, [ArrayCount<3>]>;
def Tosa_Fp32ArrayAttr4 : Confined<F32ArrayAttr, [ArrayCount<4>]>;
def Tosa_Fp32ArrayAttr5 : Confined<F32ArrayAttr, [ArrayCount<5>]>;
def Tosa_Fp32ArrayAttr6 : Confined<F32ArrayAttr, [ArrayCount<6>]>;

// I64 array attributes, each confined with ArrayCount<N> for its suffix N.
def Tosa_IntArrayAttr2 : Confined<I64ArrayAttr, [ArrayCount<2>]>;
def Tosa_IntArrayAttr3 : Confined<I64ArrayAttr, [ArrayCount<3>]>;
def Tosa_IntArrayAttr4 : Confined<I64ArrayAttr, [ArrayCount<4>]>;
def Tosa_IntArrayAttr5 : Confined<I64ArrayAttr, [ArrayCount<5>]>;
def Tosa_IntArrayAttr6 : Confined<I64ArrayAttr, [ArrayCount<6>]>;

// I64 array attributes holding at most N elements (via ArrayMaxCt<N>).
def Tosa_IntArrayAttrUpto2 : Confined<I64ArrayAttr, [ArrayMaxCt<2>]>;
def Tosa_IntArrayAttrUpto4 : Confined<I64ArrayAttr, [ArrayMaxCt<4>]>;
def Tosa_IntArrayAttrUpto5 : Confined<I64ArrayAttr, [ArrayMaxCt<5>]>;
//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
// Resize modes accepted by tosa.resize: the string attribute must be exactly
// "BILINEAR" or "NEAREST_NEIGHBOR".
def Tosa_ResizeTypeAttr : StringBasedAttr<
    CPred<"$_self.cast<StringAttr>().getValue() == \"BILINEAR\" || " #
          "$_self.cast<StringAttr>().getValue() == \"NEAREST_NEIGHBOR\"">,
    "Supported resize/upsampling strategies">;
// Type attribute whose wrapped type is a TensorType.
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;

// Buffer-side counterparts of the tensor types: a memref of any TOSA number,
// a nested tuple of such memrefs, or either of the two.
def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
#endif // TOSA_TYPES_BASE