Skip to content

Stencil shape minimize

stencil_shape_minimize

ShapeAnalysis dataclass

Bases: TypeConversionPattern

Source code in xdsl/transforms/stencil_shape_minimize.py
21
22
23
24
25
26
27
28
29
@dataclass
class ShapeAnalysis(TypeConversionPattern):
    """
    Record every `stencil.TempType` visited while walking the IR.

    `convert_type` always returns None, so this pattern never replaces a
    type; its only effect is to populate `seen`.
    """

    # Distinct temp types encountered during the walk.
    seen: set[stencil.TempType[Attribute]] = field(
        default_factory=set[stencil.TempType[Attribute]]
    )

    @attr_type_rewrite_pattern
    def convert_type(self, typ: stencil.TempType[Attribute], /) -> Attribute | None:
        """Remember `typ` and leave it unchanged."""
        collected = self.seen
        collected.add(typ)
        return None

seen: set[stencil.TempType[Attribute]] = field(default_factory=(set[stencil.TempType[Attribute]])) class-attribute instance-attribute

__init__(recursive: bool = False, ops: tuple[type[Operation], ...] | None = None, seen: set[stencil.TempType[Attribute]] = set[stencil.TempType[Attribute]]()) -> None

convert_type(typ: stencil.TempType[Attribute]) -> Attribute | None

Source code in xdsl/transforms/stencil_shape_minimize.py
27
28
29
@attr_type_rewrite_pattern
def convert_type(self, typ: stencil.TempType[Attribute], /) -> Attribute | None:
    """Remember `typ` in `self.seen` and leave it unchanged (returns None)."""
    collected = self.seen
    collected.add(typ)
    return None

ShapeMinimisation dataclass

Bases: TypeConversionPattern

Source code in xdsl/transforms/stencil_shape_minimize.py
32
33
34
35
36
37
38
39
40
41
42
43
@dataclass
class ShapeMinimisation(TypeConversionPattern):
    """
    Rewrite `stencil.FieldType`s so that their bounds become `shape`.

    A field type is converted only when `shape` is set, has the same number
    of dimensions as the field, and differs from the field's current bounds.
    """

    # Target bounds; when None (falsy), no field type is rewritten.
    shape: stencil.StencilBoundsAttr | None = None

    @attr_type_rewrite_pattern
    def convert_type(self, typ: stencil.FieldType[Attribute], /) -> Attribute | None:
        """Return `typ` re-bounded to `self.shape`, or None to keep it."""
        target = self.shape
        if not target:
            return None
        if typ.bounds == target:
            return None
        if len(target.ub) != typ.get_num_dims():
            return None
        return stencil.FieldType(target, typ.element_type)

shape: stencil.StencilBoundsAttr | None = None class-attribute instance-attribute

__init__(recursive: bool = False, ops: tuple[type[Operation], ...] | None = None, shape: stencil.StencilBoundsAttr | None = None) -> None

convert_type(typ: stencil.FieldType[Attribute]) -> Attribute | None

Source code in xdsl/transforms/stencil_shape_minimize.py
36
37
38
39
40
41
42
43
@attr_type_rewrite_pattern
def convert_type(self, typ: stencil.FieldType[Attribute], /) -> Attribute | None:
    """
    Return `typ` re-bounded to `self.shape`, or None to leave it unchanged.

    A replacement is produced only when `self.shape` is set, has as many
    dimensions as `typ`, and differs from `typ.bounds`.
    """
    target = self.shape
    if not target:
        return None
    if typ.bounds == target:
        return None
    if len(target.ub) != typ.get_num_dims():
        return None
    return stencil.FieldType(target, typ.element_type)

InvalidateTemps dataclass

Bases: TypeConversionPattern

Source code in xdsl/transforms/stencil_shape_minimize.py
46
47
48
49
50
51
@dataclass
class InvalidateTemps(TypeConversionPattern):
    """
    Strip concrete bounds from `stencil.TempType`s, keeping only their rank.

    A temp whose bounds are a `StencilBoundsAttr` is replaced by a temp type
    built from its number of dimensions and its element type. In this file
    the pattern runs just before `infer_shapes`, presumably so that shape
    inference can recompute the bounds.
    """

    @attr_type_rewrite_pattern
    def convert_type(self, typ: stencil.TempType[Attribute], /) -> Attribute | None:
        """Return an unbounded temp of the same rank, or None if not bounded."""
        current = typ.bounds
        if not isinstance(current, stencil.StencilBoundsAttr):
            return None
        rank = len(current.lb)
        return stencil.TempType(rank, typ.element_type)

__init__(recursive: bool = False, ops: tuple[type[Operation], ...] | None = None) -> None

convert_type(typ: stencil.TempType[Attribute]) -> Attribute | None

Source code in xdsl/transforms/stencil_shape_minimize.py
48
49
50
51
@attr_type_rewrite_pattern
def convert_type(self, typ: stencil.TempType[Attribute], /) -> Attribute | None:
    """
    Return a temp type with the same rank and element type but no explicit
    bounds, or None when `typ` has no `StencilBoundsAttr` bounds.
    """
    current = typ.bounds
    if not isinstance(current, stencil.StencilBoundsAttr):
        return None
    rank = len(current.lb)
    return stencil.TempType(rank, typ.element_type)

FuncOpShapeUpdate dataclass

Bases: RewritePattern

Source code in xdsl/transforms/stencil_shape_minimize.py
54
55
56
57
58
59
@dataclass(frozen=True)
class FuncOpShapeUpdate(RewritePattern):
    """
    Call `update_function_type()` on every `func.FuncOp` that has a body.

    Presumably this keeps function signatures in sync after argument/result
    types have been rewritten; declarations are skipped.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
        """Refresh `op`'s function type in place unless it is a declaration."""
        if op.is_declaration:
            return
        op.update_function_type()

__init__() -> None

match_and_rewrite(op: func.FuncOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_shape_minimize.py
56
57
58
59
@op_type_rewrite_pattern
def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
    """Refresh `op`'s function type in place unless it is a declaration."""
    if op.is_declaration:
        return
    op.update_function_type()

RestrictStoreOp dataclass

Bases: RewritePattern

Source code in xdsl/transforms/stencil_shape_minimize.py
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@dataclass
class RestrictStoreOp(RewritePattern):
    """
    Clamp the bounds of `stencil.StoreOp`s to a per-dimension limit.

    Every lower and upper bound is capped at the matching entry of
    `restrict`. Stores whose rank differs from `len(restrict)` are left
    untouched, as are stores whose bounds are already within the limit.
    """

    # Per-dimension cap applied to both the lower and upper store bounds.
    restrict: tuple[int, ...]

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.StoreOp, rewriter: PatternRewriter, /):
        """Replace `op` with an equivalent store whose bounds are clamped."""
        old_bounds = op.bounds
        if len(old_bounds.ub) != len(self.restrict):
            return
        clamped = stencil.StencilBoundsAttr(
            [
                (min(lo, limit), min(hi, limit))
                for lo, hi, limit in zip(
                    old_bounds.lb, old_bounds.ub, self.restrict
                )
            ]
        )
        if clamped == old_bounds:
            # Nothing to clamp; avoid a no-op replacement.
            return
        replacement = stencil.StoreOp.get(
            temp=op.temp, field=op.field, bounds=clamped
        )
        rewriter.replace_op(op, replacement)

restrict: tuple[int, ...] instance-attribute

__init__(restrict: tuple[int, ...]) -> None

match_and_rewrite(op: stencil.StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/stencil_shape_minimize.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.StoreOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a store whose bounds are capped by `self.restrict`.

    Stores of a different rank than `self.restrict`, or whose bounds are
    already within the limit, are left untouched.
    """
    old_bounds = op.bounds
    if len(old_bounds.ub) != len(self.restrict):
        return
    clamped = stencil.StencilBoundsAttr(
        [
            (min(lo, limit), min(hi, limit))
            for lo, hi, limit in zip(old_bounds.lb, old_bounds.ub, self.restrict)
        ]
    )
    if clamped == old_bounds:
        # Nothing to clamp; avoid a no-op replacement.
        return
    replacement = stencil.StoreOp.get(temp=op.temp, field=op.field, bounds=clamped)
    rewriter.replace_op(op, replacement)

StencilShapeMinimize dataclass

Bases: ModulePass

Minimises the shapes of `stencil.field` types that have been over-allocated and are larger than necessary.

Source code in xdsl/transforms/stencil_shape_minimize.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@dataclass(frozen=True)
class StencilShapeMinimize(ModulePass):
    """
    Minimises the shapes of `stencil.field` types that have been over-allocated and are larger than necessary.

    For each number of dimensions, all field types of that rank are rewritten
    to one shape built from the bounds of every `stencil.temp` of that rank.
    If `restrict` is set, store bounds are first clamped to it and temp
    shapes are re-inferred before the minimal shapes are computed.
    """

    name = "stencil-shape-minimize"

    # Optional per-dimension cap on store bounds. When set, the stencil
    # program must have a single dimensionality equal to len(restrict).
    restrict: tuple[int, ...] | None = None

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """
        Shrink the `stencil.field` types in `op` to the minimal needed shape.

        Raises:
            PassFailedException: if `restrict` is set and the module contains
                temps of different dimensionality, or if the length of
                `restrict` does not match the module's dimensionality.
        """
        if self.restrict:
            # Drop explicit temp bounds and clamp store bounds to `restrict`,
            # then let shape inference recompute the temp shapes.
            PatternRewriteWalker(
                GreedyRewritePatternApplier(
                    [
                        InvalidateTemps(),
                        RestrictStoreOp(restrict=self.restrict),
                    ]
                )
            ).rewrite_module(op)
            infer_shapes(op)

        # Collect the bounds of every temp type that has concrete bounds.
        analysis = ShapeAnalysis(seen=set())
        PatternRewriteWalker(analysis).rewrite_module(op)
        bounds = {
            t.bounds
            for t in analysis.seen
            if isinstance(t.bounds, stencil.StencilBoundsAttr)
        }
        if not bounds:
            return

        # Construct one minimal shape for each number of dimensions by
        # combining all bounds of that rank with `|`.
        dim_shapes: dict[int, stencil.StencilBoundsAttr] = {}
        for b in bounds:
            dims = len(b.ub)
            if dims in dim_shapes:
                dim_shapes[dims] |= b
            else:
                dim_shapes[dims] = b

        if self.restrict:
            if len(dim_shapes) != 1:
                raise PassFailedException(
                    "Cannot restrict stencil programs with different dimensionality"
                )
            # RestrictStoreOp silently skips stores whose rank differs from
            # `restrict`, so a mismatched `restrict` would otherwise be
            # ignored without any notice to the user.
            (program_dims,) = dim_shapes
            if len(self.restrict) != program_dims:
                raise PassFailedException(
                    f"Cannot restrict a {program_dims}-dimensional stencil "
                    f"program with a {len(self.restrict)}-dimensional restriction"
                )

        shape_minimisations = [
            ShapeMinimisation(shape=shape) for shape in dim_shapes.values()
        ]

        # Rewrite field types, then refresh func signatures (presumably so
        # they reflect any rewritten argument types).
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    *shape_minimisations,
                    FuncOpShapeUpdate(),
                ]
            ),
        ).rewrite_module(op)

name = 'stencil-shape-minimize' class-attribute instance-attribute

restrict: tuple[int, ...] | None = None class-attribute instance-attribute

__init__(restrict: tuple[int, ...] | None = None) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/stencil_shape_minimize.py
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Shrink the `stencil.field` types in `op` to the minimal needed shape.

    Raises:
        PassFailedException: if `restrict` is set and the module contains
            temps of different dimensionality, or if the length of
            `restrict` does not match the module's dimensionality.
    """
    if self.restrict:
        # Drop explicit temp bounds and clamp store bounds to `restrict`,
        # then let shape inference recompute the temp shapes.
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    InvalidateTemps(),
                    RestrictStoreOp(restrict=self.restrict),
                ]
            )
        ).rewrite_module(op)
        infer_shapes(op)

    # Collect the bounds of every temp type that has concrete bounds.
    analysis = ShapeAnalysis(seen=set())
    PatternRewriteWalker(analysis).rewrite_module(op)
    bounds = {
        t.bounds
        for t in analysis.seen
        if isinstance(t.bounds, stencil.StencilBoundsAttr)
    }
    if not bounds:
        return

    # Construct one minimal shape for each number of dimensions by
    # combining all bounds of that rank with `|`.
    dim_shapes: dict[int, stencil.StencilBoundsAttr] = {}
    for b in bounds:
        dims = len(b.ub)
        if dims in dim_shapes:
            dim_shapes[dims] |= b
        else:
            dim_shapes[dims] = b

    if self.restrict:
        if len(dim_shapes) != 1:
            raise PassFailedException(
                "Cannot restrict stencil programs with different dimensionality"
            )
        # RestrictStoreOp silently skips stores whose rank differs from
        # `restrict`, so a mismatched `restrict` would otherwise be
        # ignored without any notice to the user.
        (program_dims,) = dim_shapes
        if len(self.restrict) != program_dims:
            raise PassFailedException(
                f"Cannot restrict a {program_dims}-dimensional stencil "
                f"program with a {len(self.restrict)}-dimensional restriction"
            )

    shape_minimisations = [
        ShapeMinimisation(shape=shape) for shape in dim_shapes.values()
    ]

    # Rewrite field types, then refresh func signatures (presumably so
    # they reflect any rewritten argument types).
    PatternRewriteWalker(
        GreedyRewritePatternApplier(
            [
                *shape_minimisations,
                FuncOpShapeUpdate(),
            ]
        ),
    ).rewrite_module(op)