Skip to content

Shape inference

shape_inference

ShapeInferenceRewritePattern

Bases: RewritePattern

Rewrite pattern that applies the shape inference patterns declared by an operation's HasShapeInferencePatternsTrait.

Source code in xdsl/transforms/shape_inference.py (lines 14–25)
class ShapeInferenceRewritePattern(RewritePattern):
    """
    Rewrite pattern that runs the shape inference patterns an operation
    declares through its ``HasShapeInferencePatternsTrait``.
    """

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
        """
        Apply ``op``'s shape inference patterns, if it declares any.

        Operations that lack ``HasShapeInferencePatternsTrait`` are left
        untouched. When the trait provides exactly one pattern it is invoked
        directly; otherwise the patterns are combined in a
        ``GreedyRewritePatternApplier``.
        """
        shape_trait = op.get_trait(HasShapeInferencePatternsTrait)
        if shape_trait is not None:
            inference_patterns = shape_trait.get_shape_inference_patterns()
            if len(inference_patterns) == 1:
                # A single pattern needs no applier wrapper around it.
                inference_patterns[0].match_and_rewrite(op, rewriter)
            else:
                applier = GreedyRewritePatternApplier(list(inference_patterns))
                applier.match_and_rewrite(op, rewriter)

match_and_rewrite(op: Operation, rewriter: PatternRewriter)

Source code in xdsl/transforms/shape_inference.py (lines 17–25)
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter, /):
    """
    Apply ``op``'s shape inference patterns, if it declares any.

    Operations that lack ``HasShapeInferencePatternsTrait`` are left
    untouched. When the trait provides exactly one pattern it is invoked
    directly; otherwise the patterns are combined in a
    ``GreedyRewritePatternApplier``.
    """
    shape_trait = op.get_trait(HasShapeInferencePatternsTrait)
    if shape_trait is not None:
        inference_patterns = shape_trait.get_shape_inference_patterns()
        if len(inference_patterns) == 1:
            # A single pattern needs no applier wrapper around it.
            inference_patterns[0].match_and_rewrite(op, rewriter)
        else:
            applier = GreedyRewritePatternApplier(list(inference_patterns))
            applier.match_and_rewrite(op, rewriter)

ShapeInferencePass (dataclass)

Bases: ModulePass

Module pass that applies all shape inference patterns to a module by delegating to infer_shapes.

Source code in xdsl/transforms/shape_inference.py (lines 28–36)
class ShapeInferencePass(ModulePass):
    """
    Module pass that runs shape inference over a whole module by
    delegating to ``infer_shapes``.
    """

    name = "shape-inference"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        # ``ctx`` is part of the ModulePass interface but is not consulted:
        # shape inference only needs the module itself.
        infer_shapes(op)

name = 'shape-inference' — class attribute (also accessible as an instance attribute)

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/shape_inference.py (lines 35–36)
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Run shape inference on ``op``; ``ctx`` is accepted but not used."""
    infer_shapes(op)

infer_shapes(op: builtin.ModuleOp)

A helper function for ShapeInferencePass that lets shape inference be invoked from within other passes; unlike ShapeInferencePass.apply, it requires only the module and no Context.

Source code in xdsl/transforms/shape_inference.py (lines 39–45)
def infer_shapes(op: builtin.ModuleOp):
    """
    Run shape inference over ``op`` outside of a pass pipeline.

    Walks the module with a ``PatternRewriteWalker`` driving a
    ``ShapeInferenceRewritePattern``. Only the module is required (no
    ``Context``), so other passes can call this directly.
    """
    walker = PatternRewriteWalker(ShapeInferenceRewritePattern())
    walker.rewrite_module(op)