blob: 588704f24574f9026211aa085b42d29c25ee0c4e [file]
//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
// Tag used to identify debug output originating from this file.
#define DEBUG_TYPE "mesh-ops"
// Expands to the llvm::dbgs() stream after writing the "[mesh-ops]: " prefix,
// so callers can continue streaming their message onto the result.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::mesh;
#include "mlir/Dialect/Mesh/IR/MeshOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// Mesh dialect
//===----------------------------------------------------------------------===//
/// Registers the operations and attributes of the mesh dialect.
///
/// The template argument lists are not spelled out here: each one is produced
/// by including a generated `.inc` file after defining the macro that selects
/// the desired list (`GET_OP_LIST` / `GET_ATTRDEF_LIST`).
void MeshDialect::initialize() {
// Operation list, selected from MeshOps.cpp.inc through GET_OP_LIST.
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
>();
// Attribute list, selected from MeshOpsAttributes.cpp.inc through
// GET_ATTRDEF_LIST.
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
>();
}
/// Creates a constant operation holding `value` with result type `type` at
/// location `loc`.
///
/// The mesh dialect does not define a constant op of its own; the request is
/// forwarded unchanged to `arith::ConstantOp::materialize`, and whatever that
/// call yields is handed back to the caller.
Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  Operation *materialized =
      arith::ConstantOp::materialize(builder, value, type, loc);
  return materialized;
}
//===----------------------------------------------------------------------===//
// Mesh utilities
//===----------------------------------------------------------------------===//
/// Returns true when `iType` denotes a reduction loop, i.e. when it is neither
/// `IteratorType::Parallel` nor `IteratorType::Invalid`.
bool mesh::isReductionLoop(IteratorType iType) {
  switch (iType) {
  case IteratorType::Parallel:
  case IteratorType::Invalid:
    return false;
  default:
    // Every remaining iterator type is treated as a reduction.
    return true;
  }
}
/// Returns true when the reduction iterator type `iType` and the partial kind
/// `partial` describe the same reduction (generic/generic, sum/sum, max/max or
/// min/min). Any other combination, including non-reduction iterator types,
/// yields false.
bool mesh::areReductionAndPartialMatch(IteratorType iType, Partial partial) {
  switch (iType) {
  case IteratorType::ReductionGeneric:
    return partial == Partial::Generic;
  case IteratorType::ReductionSum:
    return partial == Partial::Sum;
  case IteratorType::ReductionMax:
    return partial == Partial::Max;
  case IteratorType::ReductionMin:
    return partial == Partial::Min;
  default:
    // Iterator types that are not one of the four reductions never match.
    return false;
  }
}
/// Maps a reduction iterator type onto the partial kind of the same flavor
/// (generic, sum, max or min).
///
/// Calling this with an iterator type that is not one of the four reduction
/// kinds is a programming error and hits `llvm_unreachable`.
Partial mesh::getPartialTypeFromReduction(IteratorType iType) {
  if (iType == IteratorType::ReductionGeneric)
    return Partial::Generic;
  if (iType == IteratorType::ReductionSum)
    return Partial::Sum;
  if (iType == IteratorType::ReductionMax)
    return Partial::Max;
  if (iType == IteratorType::ReductionMin)
    return Partial::Min;
  llvm_unreachable("No corresponding partial type can be found");
}
//===----------------------------------------------------------------------===//
// mesh.cluster op
//===----------------------------------------------------------------------===//
/// Verifies the structural invariants of a `mesh.cluster` op:
///   1. the cluster rank is non-zero,
///   2. `dim_sizes` has no more entries than the cluster rank,
///   3. no entry of `dim_sizes` is negative.
/// The checks run in this order and the first violation is reported.
LogicalResult ClusterOp::verify() {
  const uint64_t clusterRank = getRank();
  if (clusterRank == 0)
    return emitOpError("rank of cluster is expected to be a positive integer");

  ArrayRef<int64_t> sizes = getDimSizes();
  if (clusterRank < sizes.size())
    return emitOpError(
        "rank of dim_sizes is not expected to be larger than rank of cluster");

  for (int64_t size : sizes) {
    if (size >= 0)
      continue;
    return emitOpError(
        "dimension size of a mesh cluster is expected to be non-negative");
  }

  return success();
}
//===----------------------------------------------------------------------===//
// mesh.shard op
//===----------------------------------------------------------------------===//
/// Verifies the mesh axes referenced by a sharding attribute.
///
/// Every axis listed in `splitAxes` (visited sub-array by sub-array) and then
/// in `partialAxes` must be non-negative and must not have been seen before,
/// across both lists. The first offending axis produces a diagnostic through
/// `emitError` and verification fails. The cluster symbol and the partial kind
/// are not inspected.
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                         SymbolRefAttr, ArrayRef<DenseI32ArrayAttr> splitAxes,
                         ArrayRef<int32_t> partialAxes, Partial) {
  // TODO: At present cluster symbol ref is not verified. This is due to the
  // difficulty in fetching the corresponding symbol op based on an attribute.

  // Axes encountered so far, shared between split and partial axes.
  llvm::SmallSet<int32_t, 4> seenAxes;

  // Validates a single axis and records it as seen.
  auto verifyAxis = [&](int32_t meshAxis) -> LogicalResult {
    if (meshAxis < 0)
      return emitError() << "mesh axis is expected to be non-negative";
    if (seenAxes.insert(meshAxis).second)
      return success();
    return emitError() << "mesh axis duplicated";
  };

  for (DenseI32ArrayAttr axesAttr : splitAxes) {
    ArrayRef<int32_t> axes = axesAttr.asArrayRef();
    for (int32_t meshAxis : axes) {
      if (failed(verifyAxis(meshAxis)))
        return failure();
    }
  }

  for (int32_t meshAxis : partialAxes) {
    if (failed(verifyAxis(meshAxis)))
      return failure();
  }

  return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Mesh/IR/MeshOpsAttributes.cpp.inc"
#include "mlir/Dialect/Mesh/IR/MeshOpsEnums.cpp.inc"