A rewrite pattern that applies the interpreter to operations with no side effects whose
inputs are all constant, replacing the computation with a constant value.
Bases: RewritePattern
Source code in xdsl/transforms/constant_fold_interp.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
@dataclass
class ConstantFoldInterpPattern(RewritePattern):
    """
    Rewrite pattern that evaluates a side-effect-free operation with the
    interpreter when every operand is produced by a constant-like operation,
    and replaces it with constants materialized by the op's dialect.
    """

    # Context used to look up the dialect (and its materialization interface)
    # of the operation being folded.
    ctx: Context
    # Interpreter used to evaluate both the constant operands and the folded op.
    interpreter: Interpreter

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
        """
        Replace `op` with constant operations if it can be folded.

        Folding only happens when `op` is Pure, is not already ConstantLike,
        all of its operands are results of ConstantLike ops, its dialect
        provides a `ConstantMaterializationInterface`, interpretation does not
        raise `InterpretationError`, and every result can be converted to an
        attribute and materialized. Otherwise the IR is left untouched.
        """
        if not op.has_trait(Pure):
            # Only rewrite operations that don't have side-effects
            return
        # No need to rewrite operations that are already constant-like
        if op.has_trait(ConstantLike):
            return
        if not all(
            isinstance(operand, OpResult) and operand.op.has_trait(ConstantLike)
            for operand in op.operands
        ):
            # Only rewrite operations where all the operands are constants
            return
        dialect = self.ctx.get_dialect(op.dialect_name())
        if (
            materializer := dialect.get_interface(ConstantMaterializationInterface)
        ) is None:
            return
        try:
            # All operands were checked above to be OpResults of constant ops.
            operand_results = [cast(OpResult, operand) for operand in op.operands]
            # Pick the value that is actually used as the operand: a
            # constant-like op may define several results, so use the
            # result's own index rather than always taking result 0.
            args = tuple(
                self.interpreter.run_op(result.op, ())[result.index]
                for result in operand_results
            )
            results = self.interpreter.run_op(op, args)
        except InterpretationError:
            return
        if len(results) != len(op.results):
            # A mismatch would make zip() silently drop or misalign values;
            # refuse to fold instead.
            return
        new_ops: list[Operation] = []
        for interp_result, op_result in zip(results, op.results):
            result_attr = self.convert_to_attr(interp_result, op_result.type)
            if result_attr is None:
                return
            new_op = materializer.materialize_constant(result_attr, op_result.type)
            if new_op is None:
                return
            new_ops.append(new_op)
        rewriter.replace_op(
            op, new_ops, [constant_op.results[0] for constant_op in new_ops]
        )

    def convert_to_attr(self, value: Any, value_type: Attribute) -> Attribute | None:
        """
        Convert an interpreter value to an attribute of type `value_type`.

        Only Python ints paired with an `IntegerType` are supported; any other
        combination returns None, which makes the caller skip folding.
        """
        if isinstance(value, int) and isinstance(value_type, IntegerType):
            return IntegerAttr(value, value_type)
        return None
|
Source code in xdsl/transforms/constant_fold_interp.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
    """
    Replace `op` with constant operations if it can be folded.

    Folding only happens when `op` is Pure, is not already ConstantLike, all
    of its operands are results of ConstantLike ops, its dialect provides a
    `ConstantMaterializationInterface`, interpretation does not raise
    `InterpretationError`, and every result can be converted to an attribute
    and materialized. Otherwise the IR is left untouched.
    """
    if not op.has_trait(Pure):
        # Only rewrite operations that don't have side-effects
        return
    # No need to rewrite operations that are already constant-like
    if op.has_trait(ConstantLike):
        return
    if not all(
        isinstance(operand, OpResult) and operand.op.has_trait(ConstantLike)
        for operand in op.operands
    ):
        # Only rewrite operations where all the operands are constants
        return
    dialect = self.ctx.get_dialect(op.dialect_name())
    if (
        materializer := dialect.get_interface(ConstantMaterializationInterface)
    ) is None:
        return
    try:
        # All operands were checked above to be OpResults of constant ops.
        operand_results = [cast(OpResult, operand) for operand in op.operands]
        # Pick the value that is actually used as the operand: a constant-like
        # op may define several results, so use the result's own index rather
        # than always taking result 0.
        args = tuple(
            self.interpreter.run_op(result.op, ())[result.index]
            for result in operand_results
        )
        results = self.interpreter.run_op(op, args)
    except InterpretationError:
        return
    if len(results) != len(op.results):
        # A mismatch would make zip() silently drop or misalign values;
        # refuse to fold instead.
        return
    new_ops: list[Operation] = []
    for interp_result, op_result in zip(results, op.results):
        result_attr = self.convert_to_attr(interp_result, op_result.type)
        if result_attr is None:
            return
        new_op = materializer.materialize_constant(result_attr, op_result.type)
        if new_op is None:
            return
        new_ops.append(new_op)
    rewriter.replace_op(
        op, new_ops, [constant_op.results[0] for constant_op in new_ops]
    )
|
Source code in xdsl/transforms/constant_fold_interp.py
def convert_to_attr(self, value: Any, value_type: Attribute) -> Attribute | None:
    """
    Turn an interpreter result into an attribute of type `value_type`.

    Only a Python int paired with an `IntegerType` is handled; every other
    pairing yields None so that the caller gives up on folding.
    """
    is_integer = isinstance(value, int) and isinstance(value_type, IntegerType)
    if not is_integer:
        return None
    return IntegerAttr(value, value_type)
|
Bases: ModulePass
A pass that applies the interpreter to operations with no side effects where all the
inputs are constant, replacing the computation with a constant value.
Source code in xdsl/transforms/constant_fold_interp.py
85
86
87
88
89
90
91
92
93
94
95
96
class ConstantFoldInterpPass(ModulePass):
    """
    Module pass that folds side-effect-free operations whose inputs are all
    constant: each such computation is evaluated with the interpreter and
    replaced by a constant value.
    """

    name = "constant-fold-interp"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # Interpreter scoped to this module, with every registered
        # implementation available for evaluating operations.
        interp = Interpreter(op)
        register_implementations(interp, ctx)
        walker = PatternRewriteWalker(ConstantFoldInterpPattern(ctx, interp))
        walker.rewrite_module(op)
|
Source code in xdsl/transforms/constant_fold_interp.py
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Run interpreter-based constant folding over every op in `op`."""
    # Interpreter scoped to this module, with every registered implementation
    # available for evaluating operations.
    interp = Interpreter(op)
    register_implementations(interp, ctx)
    walker = PatternRewriteWalker(ConstantFoldInterpPattern(ctx, interp))
    walker.rewrite_module(op)
|