blob: d0bdadf500f6f6ccd48441723bfc2fc4b26e9b7d [file] [log] [blame]
//===- SubsetOpInterface.cpp - Tensor Subsets -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/SubsetOpInterface.cpp.inc"
using namespace mlir;
/// Default lookup of the destination operand for a subset insertion op.
///
/// `op` is expected to implement `DestinationStyleOpInterface` and to have
/// exactly one DPS init; both conditions are checked with assertions. The
/// returned reference is that single init operand.
OpOperand &detail::defaultGetDestinationOperand(Operation *op) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
  assert(dpsOp && "getDestination must be implemented for non-DPS ops");
  assert(
      dpsOp.getNumDpsInits() == 1 &&
      "getDestination must be implemented for ops with 0 or more than 1 init");
  // With exactly one init, index 0 is the only candidate.
  OpOperand *onlyInit = dpsOp.getDpsInitOperand(0);
  return *onlyInit;
}
/// Default lookup of the result that carries the updated destination.
///
/// `op` is expected to implement `DestinationStyleOpInterface` (asserted) and
/// `SubsetInsertionOpInterface`. The returned value is the op result tied to
/// the destination operand reported by the insertion interface.
OpResult detail::defaultGetUpdatedDestination(Operation *op) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
  assert(dpsOp && "getUpdatedDestination must be implemented for non-DPS ops");
  OpOperand &destination =
      cast<SubsetInsertionOpInterface>(op).getDestinationOperand();
  return dpsOp.getTiedOpResult(&destination);
}
/// Default check whether `candidate` denotes a subset equivalent to the one
/// that the insertion op `op` operates on.
///
/// Returns false when `candidate` is not defined by an op implementing
/// `SubsetExtractionOpInterface`. Otherwise the decision is delegated to
/// `operatesOnEquivalentSubset` on `op`, forwarding `equivalenceFn`.
bool detail::defaultIsEquivalentSubset(
    Operation *op, Value candidate,
    function_ref<bool(Value, Value)> equivalenceFn) {
  assert(isa<SubsetInsertionOpInterface>(op) &&
         "expected SubsetInsertionOpInterface");

  // Only values produced by a subset extraction op can match.
  auto extractionOp = candidate.getDefiningOp<SubsetExtractionOpInterface>();
  if (!extractionOp)
    return false;

  auto thisSubsetOp = cast<SubsetOpInterface>(op);
  auto candidateSubsetOp = candidate.getDefiningOp<SubsetOpInterface>();
  return thisSubsetOp.operatesOnEquivalentSubset(candidateSubsetOp,
                                                 equivalenceFn);
}
/// Default equivalence check between the subsets of `op` and `candidate`,
/// based on their accessed hyperrectangular slices.
///
/// `op` must provide a hyperrectangular slice (asserted). The result is true
/// only if all of the following hold:
///   * `candidate` also provides a hyperrectangular slice,
///   * `equivalenceFn` accepts the two tensor containers, and
///   * `ValueBoundsConstraintSet::areEquivalentSlices` succeeds and yields
///     true for the two slices.
bool detail::defaultOperatesOnEquivalentSubset(
    Operation *op, SubsetOpInterface candidate,
    function_ref<bool(Value, Value)> equivalenceFn) {
  auto self = cast<SubsetOpInterface>(op);

  FailureOr<HyperrectangularSlice> ownSlice =
      self.getAccessedHyperrectangularSlice();
  assert(succeeded(ownSlice) &&
         "operatesOnEquivalentSubset must be implemented if "
         "getAccessedHyperrectangularSlice is not implemented");

  FailureOr<HyperrectangularSlice> candidateSlice =
      candidate.getAccessedHyperrectangularSlice();
  if (failed(candidateSlice))
    return false;

  // Slices are only comparable when they refer to equivalent containers.
  const bool sameContainer = equivalenceFn(self.getTensorContainer(),
                                           candidate.getTensorContainer());
  if (!sameContainer)
    return false;

  FailureOr<bool> slicesMatch = ValueBoundsConstraintSet::areEquivalentSlices(
      op->getContext(), *ownSlice, *candidateSlice);
  // A failed analysis is treated the same as "not equivalent".
  if (failed(slicesMatch))
    return false;
  return *slicesMatch;
}
/// Default disjointness check between the subsets of `op` and `candidate`,
/// based on their accessed hyperrectangular slices.
///
/// `op` must provide a hyperrectangular slice (asserted); ops that cannot do
/// so have to implement `operatesOnDisjointSubset` themselves. The result is
/// true only if all of the following hold:
///   * `candidate` also provides a hyperrectangular slice,
///   * `equivalenceFn` accepts the two tensor containers, and
///   * `ValueBoundsConstraintSet::areOverlappingSlices` succeeds and reports
///     that the two slices do not overlap.
/// In every other case (including a failed analysis) the function returns
/// false.
bool detail::defaultOperatesOnDisjointSubset(
    Operation *op, SubsetOpInterface candidate,
    function_ref<bool(Value, Value)> equivalenceFn) {
  auto subsetOp = cast<SubsetOpInterface>(op);
  FailureOr<HyperrectangularSlice> slice =
      subsetOp.getAccessedHyperrectangularSlice();
  // The message names the interface method that an op has to implement, not
  // this fallback.
  assert(succeeded(slice) &&
         "operatesOnDisjointSubset must be implemented if "
         "getAccessedHyperrectangularSlice is not implemented");
  FailureOr<HyperrectangularSlice> otherSlice =
      candidate.getAccessedHyperrectangularSlice();
  if (failed(otherSlice))
    return false;
  // Slices are only comparable when they refer to equivalent containers.
  if (!equivalenceFn(subsetOp.getTensorContainer(),
                     candidate.getTensorContainer()))
    return false;
  FailureOr<bool> overlapping = ValueBoundsConstraintSet::areOverlappingSlices(
      op->getContext(), *slice, *otherSlice);
  // Disjointness is only reported when the overlap analysis succeeded.
  return succeeded(overlapping) && !*overlapping;
}
/// Returns the value that `op` treats as its tensor container: the
/// destination operand's value if `op` implements
/// `SubsetInsertionOpInterface`, otherwise the source operand's value of the
/// `SubsetExtractionOpInterface` that `op` is then required to implement.
Value detail::getTensorContainer(Operation *op) {
  auto insertion = dyn_cast<::mlir::SubsetInsertionOpInterface>(op);
  if (!insertion) {
    auto extraction = cast<::mlir::SubsetExtractionOpInterface>(op);
    return extraction.getSourceOperand().get();
  }
  return insertion.getDestinationOperand().get();
}
/// Verifies that `op` implements exactly one of
/// `SubsetExtractionOpInterface` and `SubsetInsertionOpInterface`; emits an
/// op error if it implements neither or both.
LogicalResult detail::verifySubsetOpInterface(SubsetOpInterface op) {
  Operation *rawOp = op.getOperation();
  const bool extracts = isa<SubsetExtractionOpInterface>(rawOp);
  const bool inserts = isa<SubsetInsertionOpInterface>(rawOp);

  // Exactly one of the two interfaces is implemented iff the flags differ.
  if (extracts != inserts)
    return success();

  return op->emitOpError(
      "SubsetOpInterface ops must implement either "
      "SubsetExtractionOpInterface or SubsetInsertionOpInterface");
}
/// Verifies that a subset extraction op has exactly one result; emits an op
/// error otherwise.
LogicalResult
detail::verifySubsetExtractionOpInterface(SubsetExtractionOpInterface op) {
  const bool hasSingleResult = op->getNumResults() == 1;
  if (hasSingleResult)
    return success();
  return op->emitOpError(
      "SubsetExtractionOpInterface ops must have one result");
}