Skip to content

Stencil

stencil

ApplyRedundantOperands

Bases: RewritePattern

Merge duplicate operands of a `stencil.apply`.

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class ApplyRedundantOperands(RewritePattern):
    """
    Merge duplicate operands of a `stencil.apply`.

    When the same SSA value is passed to a `stencil.apply` more than once, every
    use of the body block argument belonging to a repeated operand is redirected
    to the block argument of that operand's first occurrence, and the body is
    then CSE'd. The repeated operands themselves are left in place; their block
    arguments merely become unused (dropping those is the job of a separate
    pattern that removes unused operands, such as `ApplyUnusedOperands`).
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> None:
        unique_operands = list[SSAValue]()
        # first_index[k] is the position in `op.args` where unique_operands[k]
        # first appears.
        first_index = list[int]()
        # rbargs[i] is the position of the block argument that should stand in
        # for block argument i (i itself when operand i is not a repeat).
        rbargs = list[int]()

        found_duplicate: bool = False

        for i, o in enumerate(op.args):
            try:
                ui = unique_operands.index(o)
            except ValueError:
                unique_operands.append(o)
                first_index.append(i)
                rbargs.append(i)
            else:
                # Map to the operand position of the first occurrence, not to
                # its index in `unique_operands`: the two differ as soon as an
                # earlier duplicate has been skipped (e.g. args (a, a, b, b)).
                rbargs.append(first_index[ui])
                found_duplicate = True

        if not found_duplicate:
            return

        bbargs = op.region.block.args
        for i, a in enumerate(bbargs):
            if rbargs[i] == i:
                continue
            a.replace_all_uses_with(bbargs[rbargs[i]])

        # Merging block arguments can make previously distinct body ops
        # identical; let CSE fold them.
        cse(op.region.block, rewriter)

match_and_rewrite(op: stencil.ApplyOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> None:
    """
    Redirect uses of block arguments of repeated `stencil.apply` operands to
    the block argument of the operand's first occurrence, then CSE the body.
    """
    unique_operands = list[SSAValue]()
    # first_index[k] is the position in `op.args` where unique_operands[k]
    # first appears.
    first_index = list[int]()
    # rbargs[i] is the position of the block argument that should stand in
    # for block argument i (i itself when operand i is not a repeat).
    rbargs = list[int]()

    found_duplicate: bool = False

    for i, o in enumerate(op.args):
        try:
            ui = unique_operands.index(o)
        except ValueError:
            unique_operands.append(o)
            first_index.append(i)
            rbargs.append(i)
        else:
            # Map to the operand position of the first occurrence, not to its
            # index in `unique_operands`: the two differ as soon as an earlier
            # duplicate has been skipped (e.g. args (a, a, b, b)).
            rbargs.append(first_index[ui])
            found_duplicate = True

    if not found_duplicate:
        return

    bbargs = op.region.block.args
    for i, a in enumerate(bbargs):
        if rbargs[i] == i:
            continue
        a.replace_all_uses_with(bbargs[rbargs[i]])

    # Merging block arguments can make previously distinct body ops identical;
    # let CSE fold them.
    cse(op.region.block, rewriter)

ApplyUnusedOperands

Bases: RewritePattern

Remove unused operands of a `stencil.apply`.

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class ApplyUnusedOperands(RewritePattern):
    """
    Remove unused operands of a `stencil.apply`.

    An operand is unused when its corresponding body block argument has no
    uses. Those arguments are erased from the body, and the apply is rebuilt
    with only the remaining operands. The original op's `dest` operands,
    result types, properties and attributes are carried over unchanged.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> None:
        op_args = op.region.block.args
        unused = {a for a in op_args if not a.uses}
        if not unused:
            return
        bbargs_type = [a.type for a in op_args if a not in unused]
        # Keep operand i exactly when body block argument i is still used.
        operands = [o for o, a in zip(op.args, op_args) if a not in unused]

        for arg in unused:
            op.region.block.erase_arg(arg)

        # Build explicitly rather than through `ApplyOp.get(operands, block,
        # result_types)`: that call has no way to pass `dest`, properties or
        # attributes, so they would be lost from the rewritten op.
        block = Block(arg_types=bbargs_type)
        new = stencil.ApplyOp.build(
            operands=[operands, op.dest],
            regions=[Region(block)],
            result_types=[[r.type for r in op.res]],
            properties=op.properties.copy(),
            attributes=op.attributes.copy(),
        )

        # Move the (now argument-pruned) body into the new block.
        rewriter.inline_block(op.region.block, InsertPoint.at_start(block), block.args)
        rewriter.replace_op(op, new)

match_and_rewrite(op: stencil.ApplyOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> None:
    """
    Drop `stencil.apply` operands whose body block argument has no uses,
    preserving the op's `dest` operands, result types, properties and
    attributes.
    """
    op_args = op.region.block.args
    unused = {a for a in op_args if not a.uses}
    if not unused:
        return
    bbargs_type = [a.type for a in op_args if a not in unused]
    # Keep operand i exactly when body block argument i is still used.
    operands = [o for o, a in zip(op.args, op_args) if a not in unused]

    for arg in unused:
        op.region.block.erase_arg(arg)

    # Build explicitly rather than through `ApplyOp.get(operands, block,
    # result_types)`: that call has no way to pass `dest`, properties or
    # attributes, so they would be lost from the rewritten op.
    block = Block(arg_types=bbargs_type)
    new = stencil.ApplyOp.build(
        operands=[operands, op.dest],
        regions=[Region(block)],
        result_types=[[r.type for r in op.res]],
        properties=op.properties.copy(),
        attributes=op.attributes.copy(),
    )

    # Move the (now argument-pruned) body into the new block.
    rewriter.inline_block(op.region.block, InsertPoint.at_start(block), block.args)
    rewriter.replace_op(op, new)

ApplyUnusedResults

Bases: RewritePattern

Remove unused results of a `stencil.apply`.

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class ApplyUnusedResults(RewritePattern):
    """
    Remove unused results of a `stencil.apply`.

    Results without uses are dropped, together with the matching operands of the
    body's terminating `stencil.return`. The body block is moved into a freshly
    built apply. That apply keeps the original operands (`args` and `dest`),
    properties and attributes. Remaining uses of the old results are rewired to
    the new op.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> None:
        dead = [idx for idx, value in enumerate(op.res) if not value.uses]
        if not dead:
            # Every result is used: nothing to prune.
            return

        # Take the body out of the old op so it can be re-parented.
        body = op.region.block
        op.region.detach_block(body)
        # The body is assumed to end in a stencil.return (cast, not checked).
        terminator = cast(stencil.ReturnOp, body.last_op)

        kept_results = list(op.res)
        result_count = len(kept_results)
        kept_returned = list(terminator.arg)
        # Delete from the back so the remaining indices in `dead` stay valid.
        for idx in dead[::-1]:
            del kept_results[idx]
            del kept_returned[idx]

        pruned = stencil.ApplyOp.build(
            operands=[op.args, op.dest],
            regions=[Region(body)],
            result_types=[[value.type for value in kept_results]],
            properties=op.properties.copy(),
            attributes=op.attributes.copy(),
        )

        # Old result i maps to None if it was dropped, otherwise to the next
        # surviving result of the new op, in order.
        dead_set = set(dead)
        survivors = iter(pruned.res)
        replace_results: list[SSAValue | None] = [
            None if idx in dead_set else next(survivors)
            for idx in range(result_count)
        ]

        rewriter.replace_op(terminator, stencil.ReturnOp.get(kept_returned))
        rewriter.replace_op(op, pruned, replace_results)

match_and_rewrite(op: stencil.ApplyOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter) -> None:
    """
    Rebuild a `stencil.apply` without its use-less results, trimming the
    body's `stencil.return` to match and rewiring the surviving result uses.
    """
    dead = [idx for idx, value in enumerate(op.res) if not value.uses]
    if not dead:
        # Every result is used: nothing to prune.
        return

    # Take the body out of the old op so it can be re-parented.
    body = op.region.block
    op.region.detach_block(body)
    # The body is assumed to end in a stencil.return (cast, not checked).
    terminator = cast(stencil.ReturnOp, body.last_op)

    kept_results = list(op.res)
    result_count = len(kept_results)
    kept_returned = list(terminator.arg)
    # Delete from the back so the remaining indices in `dead` stay valid.
    for idx in dead[::-1]:
        del kept_results[idx]
        del kept_returned[idx]

    pruned = stencil.ApplyOp.build(
        operands=[op.args, op.dest],
        regions=[Region(body)],
        result_types=[[value.type for value in kept_results]],
        properties=op.properties.copy(),
        attributes=op.attributes.copy(),
    )

    # Old result i maps to None if it was dropped, otherwise to the next
    # surviving result of the new op, in order.
    dead_set = set(dead)
    survivors = iter(pruned.res)
    replace_results: list[SSAValue | None] = [
        None if idx in dead_set else next(survivors)
        for idx in range(result_count)
    ]

    rewriter.replace_op(terminator, stencil.ReturnOp.get(kept_returned))
    rewriter.replace_op(op, pruned, replace_results)

RemoveCastWithNoEffect

Bases: RewritePattern

Remove `stencil.cast` where input and output types are equal.

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
114
115
116
117
118
119
120
121
122
class RemoveCastWithNoEffect(RewritePattern):
    """
    Remove `stencil.cast` where input and output types are equal.

    Such a cast is an identity. It is erased, and every user of its result is
    pointed at the input field instead.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: stencil.CastOp, rewriter: PatternRewriter) -> None:
        if op.result.type != op.field.type:
            # The cast really changes the type; keep it.
            return
        rewriter.replace_op(op, [], new_results=[op.field])

match_and_rewrite(op: stencil.CastOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/canonicalization_patterns/stencil.py
119
120
121
122
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.CastOp, rewriter: PatternRewriter) -> None:
    """Erase a `stencil.cast` whose result type equals its input field's type."""
    if op.result.type != op.field.type:
        # The cast really changes the type; keep it.
        return
    rewriter.replace_op(op, [], new_results=[op.field])