# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import operator
from itertools import accumulate
from typing import Optional

from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation


def _is_constant_int_like(i):
    """Returns True if `i` is the result of an arith.constant with an integer-like type."""
    return (
        isinstance(i, Value)
        and isinstance(i.owner, Operation)
        and isinstance(i.owner.opview, ConstantOp)
        and _is_integer_like_type(i.type)
    )


def _is_static_int_like(i):
    """Returns True if `i` is a static (non-dynamic) Python int or a constant integer Value."""
    return (
        isinstance(i, int) and not ShapedType.is_dynamic_size(i)
    ) or _is_constant_int_like(i)
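
# For example, the Python int `4` and the result of `arith.constant 4 : index`
# are static-int-like, while `ShapedType.get_dynamic_size()` (the dynamic
# dimension sentinel) is not.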


def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Infers the result MemRefType of a subview when all operands are statically known."""
    source_strides, source_offset = source_memref_type.get_strides_and_offset()
    # "canonicalize" from tuple|list -> list
    offsets, static_sizes, static_strides, source_strides = map(
        list, (offsets, static_sizes, static_strides, source_strides)
    )
    if not all(
        all(_is_static_int_like(i) for i in s)
        for s in [
            static_sizes,
            static_strides,
            source_strides,
        ]
    ):
        raise ValueError(
            "Only inferring from Python or MLIR integer constants is supported."
        )

    # Fold arith.constant results to their literal Python int values.
    for s in [offsets, static_sizes, static_strides]:
        for idx, i in enumerate(s):
            if _is_constant_int_like(i):
                s[idx] = i.owner.opview.literal_value

    # The target offset is only computable when the source offset and all
    # subview offsets are static; otherwise it is marked dynamic.
    if any(not _is_static_int_like(i) for i in offsets + [source_offset]):
        target_offset = ShapedType.get_dynamic_size()
    else:
        target_offset = source_offset
        for offset, target_stride in zip(offsets, source_strides):
            target_offset += offset * target_stride

    target_strides = []
    for source_stride, static_stride in zip(source_strides, static_strides):
        target_strides.append(source_stride * static_stride)

    # If default striding then no need to complicate things for downstream ops (e.g., expand_shape).
    default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1]
    if target_strides == default_strides and target_offset == 0:
        layout = None
    else:
        layout = StridedLayoutAttr.get(target_offset, target_strides)
    return (
        offsets,
        static_sizes,
        static_strides,
        MemRefType.get(
            static_sizes,
            source_memref_type.element_type,
            layout,
            source_memref_type.memory_space,
        ),
    )
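
# Worked example (a sketch, not exercised here): for a source memref<4x4xf32>
# (strides [4, 1], offset 0) with offsets=[1, 1], sizes=[2, 2], strides=[1, 1],
# the inferred offset is 1*4 + 1*1 = 5 and the strides are [4*1, 1*1] = [4, 1];
# since [4, 1] is not the default [2, 1] striding for a 2x2 shape, the inferred
# result type is memref<2x2xf32, strided<[4, 1], offset: 5>>.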


# Keep a handle to the generated op builder before shadowing it with the
# type-inferring wrapper below.
_generated_subview = subview


def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Builds a memref.subview op, inferring the result type when it is fully static."""
    if offsets is None:
        offsets = []
    if sizes is None:
        sizes = []
    if strides is None:
        strides = []
    source_strides, source_offset = source.type.get_strides_and_offset()
    if result_type is None and all(
        all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
    ):
        # If any are arith.constant results then this will canonicalize to Python int
        # (which can then be used to fully specify the subview).
        (
            offsets,
            sizes,
            strides,
            result_type,
        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
    elif result_type is None:
        raise ValueError(
            "Mixed static/dynamic offsets/sizes/strides require an explicit result type."
        )

    # Split each mixed list into dynamic operand Values and static int attributes.
    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        offsets,
        sizes,
        strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )
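

# Usage sketch (illustrative only; assumes an enclosing Context/Location and
# the `func` dialect helpers, which are not used by this module itself):
#
#   from mlir.ir import Context, Location, Module, InsertionPoint, MemRefType, F32Type
#   from mlir.dialects import func, memref
#
#   with Context(), Location.unknown():
#       module = Module.create()
#       with InsertionPoint(module.body):
#           @func.FuncOp.from_py_func(MemRefType.get((4, 4), F32Type.get()))
#           def view(src):
#               # Fully static operands: the result type is inferred as
#               # memref<2x2xf32, strided<[4, 1], offset: 5>>.
#               return memref.subview(src, offsets=[1, 1], sizes=[2, 2], strides=[1, 1])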