blob: 12430226173bea759e3ffd854c3b12fd76893077 [file] [log] [blame]
//===- GPUDialect.h - MLIR Dialect for GPU Kernels --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the GPU kernel-related operations and puts them in the
// corresponding dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_GPU_GPUDIALECT_H
#define MLIR_DIALECT_GPU_GPUDIALECT_H
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/FunctionSupport.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
class FuncOp;
namespace gpu {
/// Helper aggregate for the GPU dialect that bundles three `Value`s, exposed
/// as `.x`, `.y` and `.z` similarly to CUDA notation.
struct KernelDim3 {
  /// Component accessed as `.x`.
  Value x;
  /// Component accessed as `.y`.
  Value y;
  /// Component accessed as `.z`.
  Value z;
};
/// Async token type of the GPU dialect. It declares no members of its own and
/// uses the plain `TypeStorage` handed to `Type::TypeBase` as its storage.
class AsyncTokenType
    : public Type::TypeBase<AsyncTokenType, Type, TypeStorage> {
public:
  // Inherit the TypeBase constructors; they are used by the generic hooks in
  // TypeBase.
  using Base::Base;
};
/// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
/// and type.
struct MMAMatrixStorageType : public TypeStorage {
MMAMatrixStorageType(unsigned numDims, const int64_t *dimShapes,
Type elementType, StringRef operand)
: dimShapes(dimShapes), numDims(numDims), elementType(elementType),
operand(operand) {}
/// The hash key for uniquing.
using KeyTy = std::tuple<ArrayRef<int64_t>, Type, StringRef>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(getShape(), elementType, operand);
}
/// Construction.
static MMAMatrixStorageType *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
StringRef operand = allocator.copyInto(std::get<2>(key));
return new (allocator.allocate<MMAMatrixStorageType>())
MMAMatrixStorageType(shape.size(), shape.data(), std::get<1>(key),
operand);
}
ArrayRef<int64_t> getShape() const {
return ArrayRef<int64_t>(dimShapes, numDims);
}
StringRef getOperand() const { return operand; }
/// Reference to the shape of the MMA matrix.
const int64_t *dimShapes;
/// Number of dimensions in the MMA matrix.
unsigned numDims;
/// Element type of elements held in the MMA matrix.
Type elementType;
/// MMA operand that this MMAMatrix holds. The general form of operation this
/// type supports is given by the equation C += A*B. This field specifies
/// which operand in the given equation is held by this type. The valid values
/// are "AOp", "BOp" and "COp".
StringRef operand;
};
/// MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply
/// accumulate operations. MMAMatrices are taken as direct operands by these
/// operations and are also produced as results. These matrices are meant to
/// reside in the registers. A limited number of pointwise operations can be
/// performed on these matrices, i.e., operations which operate uniformly on
/// all the elements in the matrix and do not change the order of matrix
/// elements. The above conditions exist because the layout of matrix elements
/// inside the matrix is opaque i.e., the elements may be present in the
/// matrix in any order. The general usage of this type is shown as follows:-
///
/// %0 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 :
/// index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
///
/// The MMAMatrixType describes the shape of the matrix being loaded and the
/// operand being loaded too. The operand needs to be specified to aid the
/// lowering of this type to dialects such as NVVM where each workitem may
/// hold a different number of elements depending on the elementType of the
/// matrix. For example, each workitem holds 4 vector<2xf16>s for the f16 data
/// type and 8 f32s for the f32 data type of MMAMatrix. Some other instances
/// of usage are:
///
/// %3 = gpu.subgroup_mma_compute %0, %1, %2 :
/// !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
/// -> !gpu.mma_matrix<16x16xf32, "COp">
///
///
/// gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
// TODO: consider moving this to ODS.
class MMAMatrixType
    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
public:
  // Inherit the TypeBase constructors.
  using Base::Base;

  /// Returns the MMAMatrixType for `shape`, `elementType` and `operand`,
  /// verifying the construction invariants.
  static MMAMatrixType get(ArrayRef<int64_t> shape, Type elementType,
                           StringRef operand);

  /// Returns the MMAMatrixType at a particular location, verifying the
  /// construction invariants. `emitError` supplies the diagnostic to report
  /// through.
  static MMAMatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
                                  StringRef operand);

  /// Returns true if `elementType` is a valid MMAMatrixType element type.
  static bool isValidElementType(Type elementType);

  /// Verifies that `shape` and `elementType` are actually allowed for the
  /// MMAMatrixType.
  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
                              ArrayRef<int64_t> shape, Type elementType,
                              StringRef operand);

  /// Returns the number of dimensions.
  unsigned getNumDims() const;

  /// Returns the shape of the matrix.
  ArrayRef<int64_t> getShape() const;

  /// Returns the type of a single element.
  Type getElementType() const;

  /// Returns which operand this type holds. The supported operation has the
  /// general form C += A*B; the returned string can be one of "AOp", "BOp"
  /// and "COp".
  StringRef getOperand() const;
};
/// Adds a `gpu.async.token` to the front of the argument list.
/// NOTE(review): presumably `token` is the async token being added and `op`
/// is the operation whose argument list is extended; the definition is not
/// in this header, so confirm against the implementation.
void addAsyncDependency(Operation *op, Value token);
} // end namespace gpu
} // end namespace mlir
#include "mlir/Dialect/GPU/GPUOpsEnums.h.inc"
#include "mlir/Dialect/GPU/GPUOpsDialect.h.inc"
#include "mlir/Dialect/GPU/GPUOpInterfaces.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/GPU/GPUOps.h.inc"
#endif // MLIR_DIALECT_GPU_GPUDIALECT_H