Skip to content

Constant fold interp

constant_fold_interp

A pass that applies the interpreter to operations with no side effects where all the inputs are constant, replacing the computation with a constant value.

ConstantFoldInterpPattern dataclass

Bases: RewritePattern

Source code in xdsl/transforms/constant_fold_interp.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@dataclass
class ConstantFoldInterpPattern(RewritePattern):
    ctx: Context
    interpreter: Interpreter

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
        if not op.has_trait(Pure):
            # Only rewrite operations that don't have side-effects
            return

        # No need to rewrite operations that are already constant-like
        if op.has_trait(ConstantLike):
            return

        if not all(
            isinstance(operand, OpResult) and operand.op.has_trait(ConstantLike)
            for operand in op.operands
        ):
            # Only rewrite operations where all the operands are constants
            return

        dialect = self.ctx.get_dialect(op.dialect_name())

        if (
            materializer := dialect.get_interface(ConstantMaterializationInterface)
        ) is None:
            return

        try:
            args = tuple(
                self.interpreter.run_op(cast(OpResult, operand).op, ())[0]
                for operand in op.operands
            )
            results = self.interpreter.run_op(op, args)
        except InterpretationError:
            return

        new_ops: list[Operation] = []
        for interp_result, op_result in zip(results, op.results):
            result_attr = self.convert_to_attr(interp_result, op_result.type)
            if result_attr is None:
                return
            new_op = materializer.materialize_constant(result_attr, op_result.type)
            if new_op is None:
                return
            new_ops.append(new_op)

        rewriter.replace_op(op, new_ops, [new_op.results[0] for new_op in new_ops])

    def convert_to_attr(self, value: Any, value_type: Attribute) -> Attribute | None:
        match (value, value_type):
            case int(), IntegerType():
                return IntegerAttr(value, cast(IntegerType, value_type))
            case _:
                return None

ctx: Context instance-attribute

interpreter: Interpreter instance-attribute

__init__(ctx: Context, interpreter: Interpreter) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter)

Source code in xdsl/transforms/constant_fold_interp.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
    """
    Replace `op` with materialized constants if it can be evaluated statically.

    Returns without rewriting when `op` is not `Pure`, is already `ConstantLike`,
    has an operand that is not a result of a `ConstantLike` op, belongs to a
    dialect without a `ConstantMaterializationInterface`, raises an
    `InterpretationError` when interpreted, or produces a value that cannot be
    converted to an attribute or materialized.
    """
    if not op.has_trait(Pure):
        # Only rewrite operations that don't have side-effects
        return

    # No need to rewrite operations that are already constant-like
    if op.has_trait(ConstantLike):
        return

    # Only rewrite operations where all the operands are results of constants.
    # Collect them once so they need no re-narrowing later.
    const_operands: list[OpResult] = []
    for operand in op.operands:
        if not (isinstance(operand, OpResult) and operand.op.has_trait(ConstantLike)):
            return
        const_operands.append(operand)

    dialect = self.ctx.get_dialect(op.dialect_name())

    if (
        materializer := dialect.get_interface(ConstantMaterializationInterface)
    ) is None:
        return

    try:
        # Select the value at the operand's own result index. A constant-like
        # producer may define several results, and the operand is not
        # necessarily its first one.
        args = tuple(
            self.interpreter.run_op(operand.op, ())[operand.index]
            for operand in const_operands
        )
        results = self.interpreter.run_op(op, args)
    except InterpretationError:
        return

    if len(results) != len(op.results):
        # zip() below would silently drop the unmatched values; do not build a
        # partial replacement.
        return

    new_ops: list[Operation] = []
    for interp_result, op_result in zip(results, op.results):
        result_attr = self.convert_to_attr(interp_result, op_result.type)
        if result_attr is None:
            return
        new_op = materializer.materialize_constant(result_attr, op_result.type)
        if new_op is None:
            return
        new_ops.append(new_op)

    # Each materialized op stands in for one result of `op`, in order.
    rewriter.replace_op(op, new_ops, [new_op.results[0] for new_op in new_ops])

convert_to_attr(value: Any, value_type: Attribute) -> Attribute | None

Source code in xdsl/transforms/constant_fold_interp.py
77
78
79
80
81
82
def convert_to_attr(self, value: Any, value_type: Attribute) -> Attribute | None:
    """
    Turn an interpreter value into an attribute typed as `value_type`.

    Only the pairing of a Python `int` (including `bool`) with an `IntegerType`
    is handled. Everything else yields None.
    """
    if isinstance(value, int) and isinstance(value_type, IntegerType):
        return IntegerAttr(value, cast(IntegerType, value_type))
    return None

ConstantFoldInterpPass dataclass

Bases: ModulePass

A pass that applies the interpreter to operations with no side effects where all the inputs are constant, replacing the computation with a constant value.

Source code in xdsl/transforms/constant_fold_interp.py
85
86
87
88
89
90
91
92
93
94
95
96
97
class ConstantFoldInterpPass(ModulePass):
    """
    A pass that applies the interpreter to operations with no side effects where all the
    inputs are constant, replacing the computation with a constant value.
    """

    name = "constant-fold-interp"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Fold eligible operations inside `op` in place using an interpreter over it."""
        interp = Interpreter(op)
        # Make the implementations known to `ctx` available to the interpreter.
        register_implementations(interp, ctx)
        walker = PatternRewriteWalker(ConstantFoldInterpPattern(ctx, interp))
        walker.rewrite_module(op)

name = 'constant-fold-interp' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/constant_fold_interp.py
93
94
95
96
97
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Fold eligible operations inside `op` in place using an interpreter over it."""
    interp = Interpreter(op)
    # Make the implementations known to `ctx` available to the interpreter.
    register_implementations(interp, ctx)
    walker = PatternRewriteWalker(ConstantFoldInterpPattern(ctx, interp))
    walker.rewrite_module(op)