Skip to content

Passes

passes

ModulePassT = TypeVar('ModulePassT', bound='ModulePass') module-attribute

ModulePass dataclass

Bases: ABC

A Pass is a named rewrite pass over an IR module that can accept arguments.

All passes are expected to leave the IR in a valid state after application, meaning that a call to .verify() succeeds on the whole module. In turn, all passes can expect that the IR they are applied to is in a valid state. It is not required that the IR verifies at any point while the pass is being applied.

In order to make a pass accept arguments, it must be a dataclass. Furthermore, only the following types are supported as argument types:

Base types: int | float | bool | str. N-tuples of base types: tuple[int, ...], tuple[int|float, ...], tuple[int, ...] | tuple[float, ...]. Top-level optional: ... | None.

Pass arguments on the CLI are formatted as follows:

CLI arg                  Mapped to class field
-----------------------  ------------------------------------
my-pass{arg-1=1}         arg_1: int = 1
my-pass{arg-1}           arg_1: int | None = None
my-pass{arg-1=1,2,3}     arg_1: tuple[int, ...] = (1, 2, 3)
my-pass{arg-1=true}      arg_1: bool | None = True

Source code in xdsl/passes.py
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
@dataclass(frozen=True)
class ModulePass(ABC):
    """
    A Pass is a named rewrite pass over an IR module that can accept arguments.

    All passes are expected to leave the IR in a valid state *after* application,
    meaning that a call to .verify() succeeds on the whole module. In turn, all
    passes can expect that the IR they are applied to is in a valid state. It
    is not required that the IR verifies at any point while the pass is being
    applied.

    In order to make a pass accept arguments, it must be a dataclass. Furthermore,
    only the following types are supported as argument types:

    Base types:                int | float | bool | str
    N-tuples of base types:
        tuple[int, ...], tuple[int|float, ...], tuple[int, ...] | tuple[float, ...]
    Top-level optional:        ... | None

    Pass arguments on the CLI are formatted as follows:

    CLI arg                             Mapped to class field
    -------------------------           ------------------------------
    my-pass{arg-1=1}                    arg_1: int             = 1
    my-pass{arg-1}                      arg_1: int | None      = None
    my-pass{arg-1=1,2,3}                arg_1: tuple[int, ...] = (1, 2, 3)
    my-pass{arg-1=true}                 arg_1: bool | None     = True
    """

    # Name of the pass; `from_pass_spec` requires it to equal the spec's name.
    name: ClassVar[str]

    @abstractmethod
    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """
        Rewrite `op` in place using `ctx`. Concrete passes must implement this.
        """

    def apply_to_clone(
        self, ctx: Context, op: builtin.ModuleOp
    ) -> tuple[Context, builtin.ModuleOp]:
        """
        Creates deep copies of the context and the module, calls `apply` on the
        copies, and returns the (now rewritten) copies as `(ctx, op)`.

        The `ctx` and `op` passed in are left untouched.
        """
        ctx = ctx.clone()
        op = op.clone()
        self.apply(ctx, op)
        return ctx, op

    @classmethod
    def from_pass_spec(cls, spec: PipelinePassSpec) -> Self:
        """
        This method takes a PipelinePassSpec, does type checking on the
        arguments, and then instantiates an instance of the ModulePass
        from the spec.

        Raises:
            ValueError: if the spec is for a different pass, if a required
                argument is missing, or if the spec contains arguments that
                do not correspond to any field of the pass.
        """
        if spec.name != cls.name:
            raise ValueError(
                f"Cannot create Pass {cls.name} from pass arguments for pass {spec.name}"
            )

        # normalize spec arg names:
        spec_arguments_dict: dict[str, PassArgListType] = (
            spec.normalize_arg_names().args
        )

        # get all dataclass fields
        fields: tuple[Field[Any], ...] = dataclasses.fields(cls)

        # start constructing the argument dict for the dataclass
        arg_dict = dict[str, PassArgListType | PassArgElementType | None]()

        required_fields = cls.required_fields()

        field_types = get_type_hints(cls)

        # iterate over all fields of the dataclass
        for op_field in fields:
            # ignore the name field and everything that's not used by __init__
            if op_field.name == "name" or not op_field.init:
                continue
            # check that non-optional fields are present
            if op_field.name not in spec_arguments_dict:
                if op_field.name not in required_fields:
                    arg_dict[op_field.name] = _get_default(op_field)
                    continue
                raise ValueError(f'Pass {cls.name} requires argument "{op_field.name}"')

            # convert pass arg to the correct type; .pop also removes the arg from
            # the dict so leftovers can be reported as unknown below
            field_type = field_types[op_field.name]
            arg_dict[op_field.name] = _convert_pass_arg_to_type(
                spec_arguments_dict.pop(op_field.name),
                field_type,
            )

        # if not all args were removed we raise an error
        if len(spec_arguments_dict) != 0:
            arguments_str = ", ".join(f'"{arg}"' for arg in spec_arguments_dict)
            fields_str = ", ".join(f'"{field.name}"' for field in fields)
            raise ValueError(
                f"Provided arguments [{arguments_str}] not found in expected pass "
                f"arguments [{fields_str}]"
            )

        # instantiate the dataclass using kwargs
        return cls(**arg_dict)

    @classmethod
    def required_fields(cls) -> set[str]:
        """
        Inspects the definition of the pass for fields that do not have default values.
        """
        return {
            field.name for field in dataclasses.fields(cls) if not _is_optional(field)
        }

    def pipeline_pass_spec(self, *, include_default: bool = False) -> PipelinePassSpec:
        """
        This function takes a ModulePass and returns a PipelinePassSpec.

        If `include_default` is `False` (the default), optional arguments whose
        value equals their default are omitted from the spec; if it is `True`,
        they are included.
        """
        # get all dataclass fields
        fields = dataclasses.fields(self)
        args: dict[str, PassArgListType] = {}

        # iterate over all fields of the dataclass
        for op_field in fields:
            name = op_field.name
            # ignore the name field and everything that's not used by __init__
            if name == "name" or not op_field.init:
                continue

            val = getattr(self, name)

            if _is_optional(op_field):
                if val == _get_default(op_field) and not include_default:
                    continue

            # None -> empty list, single base value -> 1-tuple, tuples as-is
            if val is None:
                arg_list = ()
            elif isinstance(val, PassArgElementType):
                arg_list = (val,)
            else:
                arg_list = val

            args[name] = arg_list
        return PipelinePassSpec(self.name, args)

    @classmethod
    def schedule_space(
        cls, ctx: Context, module_op: builtin.ModuleOp
    ) -> tuple[Self, ...]:
        """
        Returns a tuple of `Self` that can be applied to rewrite the given module with
        the given context without error.
        The default implementation attempts to construct an instance with no parameters
        and run it on a clone of module_op; if the rewritten clone is not structurally
        equivalent to module_op, the pass instance is returned.
        Parametrizable passes should override this implementation to provide a full
        schedule space of transformations.
        """
        try:
            pass_instance = cls()
            _, cloned_module = pass_instance.apply_to_clone(ctx, module_op)
            if module_op.is_structurally_equivalent(cloned_module):
                return ()
        except Exception:
            # Best effort: a pass that cannot be built or run is not schedulable.
            return ()
        return (pass_instance,)

    def __str__(self) -> str:
        """Textual pipeline spec of this pass, with default-valued options omitted."""
        return str(self.pipeline_pass_spec())

name: str class-attribute

__init__() -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None abstractmethod

Source code in xdsl/passes.py
64
65
@abstractmethod
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Rewrite `op` in place using `ctx`; concrete passes must provide this.
    """

apply_to_clone(ctx: Context, op: builtin.ModuleOp) -> tuple[Context, builtin.ModuleOp]

Creates deep copies of the module and the context, calls apply on the copies, and returns the (now rewritten) copies; the originals are left untouched.

Source code in xdsl/passes.py
67
68
69
70
71
72
73
74
75
76
77
def apply_to_clone(
    self, ctx: Context, op: builtin.ModuleOp
) -> tuple[Context, builtin.ModuleOp]:
    """
    Creates deep copies of the context and the module, calls `apply` on the
    copies, and returns the (now rewritten) copies as `(ctx, op)`.

    The `ctx` and `op` passed in are left untouched.
    """
    ctx = ctx.clone()
    op = op.clone()
    self.apply(ctx, op)
    return ctx, op

from_pass_spec(spec: PipelinePassSpec) -> Self classmethod

This method takes a PipelinePassSpec, does type checking on the arguments, and then instantiates an instance of the ModulePass from the spec.

Source code in xdsl/passes.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
@classmethod
def from_pass_spec(cls, spec: PipelinePassSpec) -> Self:
    """
    Build an instance of this pass from a parsed `PipelinePassSpec`.

    The spec's argument names are normalized first. Each dataclass field that
    `__init__` accepts is then filled either from the spec (converted to the
    field's annotated type) or, when absent and not required, from its default.

    Raises:
        ValueError: if the spec names a different pass, if a required argument
            is missing, or if the spec carries arguments no field consumes.
    """
    if spec.name != cls.name:
        raise ValueError(
            f"Cannot create Pass {cls.name} from pass arguments for pass {spec.name}"
        )

    # Entries are popped as fields consume them; anything left over is unknown.
    remaining: dict[str, PassArgListType] = spec.normalize_arg_names().args
    pass_fields: tuple[Field[Any], ...] = dataclasses.fields(cls)
    kwargs: dict[str, PassArgListType | PassArgElementType | None] = {}
    required = cls.required_fields()
    hints = get_type_hints(cls)

    for fld in pass_fields:
        fname = fld.name
        # The pass name and non-__init__ fields are never set from the spec.
        if fname == "name" or not fld.init:
            continue
        if fname in remaining:
            target_type = hints[fname]
            kwargs[fname] = _convert_pass_arg_to_type(remaining.pop(fname), target_type)
        elif fname in required:
            raise ValueError(f'Pass {cls.name} requires argument "{fname}"')
        else:
            kwargs[fname] = _get_default(fld)

    if remaining:
        arguments_str = ", ".join(f'"{arg}"' for arg in remaining)
        fields_str = ", ".join(f'"{fld.name}"' for fld in pass_fields)
        raise ValueError(
            f"Provided arguments [{arguments_str}] not found in expected pass "
            f"arguments [{fields_str}]"
        )

    return cls(**kwargs)

required_fields() -> set[str] classmethod

Inspects the definition of the pass for fields that do not have default values.

Source code in xdsl/passes.py
138
139
140
141
142
143
144
145
@classmethod
def required_fields(cls) -> set[str]:
    """
    Names of this pass's dataclass fields that have no default value.
    """
    all_fields = dataclasses.fields(cls)
    names = {f.name for f in all_fields if not _is_optional(f)}
    return names

pipeline_pass_spec(*, include_default: bool = False) -> PipelinePassSpec

This function takes a ModulePass and returns a PipelinePassSpec.

If include_default is True, optional arguments whose value equals their default are included in the spec; if it is False (the default), they are omitted.

Source code in xdsl/passes.py
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def pipeline_pass_spec(self, *, include_default: bool = False) -> PipelinePassSpec:
    """
    This function takes a ModulePass and returns a PipelinePassSpec.

    Every dataclass field accepted by `__init__` becomes a spec argument: `None`
    maps to an empty argument list, a single base value to a one-element tuple,
    and any other value (e.g. a tuple) is used as-is.

    If `include_default` is `False` (the default), optional arguments whose value
    equals their default are omitted from the spec; if it is `True`, they are
    included.
    """
    args: dict[str, PassArgListType] = {}

    # iterate over all fields of the dataclass
    for op_field in dataclasses.fields(self):
        name = op_field.name
        # ignore the name field and everything that's not used by __init__
        if name == "name" or not op_field.init:
            continue

        val = getattr(self, name)

        # Default-valued optional arguments are only emitted when requested.
        if (
            _is_optional(op_field)
            and val == _get_default(op_field)
            and not include_default
        ):
            continue

        if val is None:
            arg_list = ()
        elif isinstance(val, PassArgElementType):
            arg_list = (val,)
        else:
            arg_list = val

        args[name] = arg_list
    return PipelinePassSpec(self.name, args)

schedule_space(ctx: Context, module_op: builtin.ModuleOp) -> tuple[Self, ...] classmethod

Returns a tuple of Self that can be applied to rewrite the given module with the given context without error. The default implementation attempts to construct an instance with no parameters and run it on a clone of module_op; if the rewritten clone is not structurally equivalent to module_op, the pass instance is returned. Parametrizable passes should override this implementation to provide a full schedule space of transformations.

Source code in xdsl/passes.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
@classmethod
def schedule_space(
    cls, ctx: Context, module_op: builtin.ModuleOp
) -> tuple[Self, ...]:
    """
    Return the instances of this pass that can rewrite `module_op` under `ctx`.

    This default only tries the parameterless instance `cls()`. It is run on
    clones of the context and module, and is returned as a one-element tuple
    only if it ran without raising and the rewritten clone is not structurally
    equivalent to `module_op`; otherwise the result is empty. Parametrizable
    passes should override this to enumerate their full schedule space.
    """
    try:
        candidate = cls()
        _, rewritten = candidate.apply_to_clone(ctx, module_op)
        unchanged = bool(module_op.is_structurally_equivalent(rewritten))
    except Exception:
        # Any failure just means this candidate is not schedulable here.
        return ()
    if unchanged:
        return ()
    return (candidate,)

__str__() -> str

Source code in xdsl/passes.py
203
204
def __str__(self) -> str:
    """Render the pass as its pipeline spec text (default-valued options omitted)."""
    spec = self.pipeline_pass_spec()
    return str(spec)

PassOptionInfo

Bases: NamedTuple

The name, expected type, and default value for one option of a module pass.

Source code in xdsl/passes.py
207
208
209
210
211
212
class PassOptionInfo(NamedTuple):
    """The name, expected type, and default value for one option of a module pass."""

    # Option name (get_pass_option_infos uses the dataclass field name).
    name: str
    # String form of the option's type (get_pass_option_infos fills it via type_repr).
    expected_type: str
    # Default value rendered as a string, or None when no default is known.
    default_value: str | None = None

name: str instance-attribute

expected_type: str instance-attribute

default_value: str | None = None class-attribute instance-attribute

PassPipeline dataclass

A representation of a pass pipeline, with an optional callback to be executed between each of the passes.

Source code in xdsl/passes.py
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
@dataclass(frozen=True)
class PassPipeline:
    """
    A representation of a pass pipeline, with an optional callback to be executed
    between each of the passes.
    """

    passes: tuple[ModulePass, ...]
    """
    These will be executed sequentially during the execution of the pipeline.
    """
    callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = field(
        default=None
    )
    """
    Function called in between every pass, taking the pass that just ran, the module,
    and the next pass.
    """

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """
        Apply every pass in order to `op`, invoking `callback` (if set) between
        consecutive passes. The callback is not called after the last pass.
        """
        if not self.passes:
            # Early exit to avoid fetching a non-existing last pass.
            return
        callback = self.callback

        # `following` (not `next`) to avoid shadowing the builtin.
        for prev, following in zip(self.passes[:-1], self.passes[1:]):
            prev.apply(ctx, op)
            if callback is not None:
                callback(prev, op, following)

        self.passes[-1].apply(ctx, op)

    @staticmethod
    def parse_spec(
        available_passes: dict[str, Callable[[], type[ModulePass]]],
        spec: str,
        callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None]
        | None = None,
    ) -> PassPipeline:
        """
        Parse a textual pipeline `spec` into a `PassPipeline`.

        `available_passes` maps pass names to zero-argument factories returning the
        pass class; each class is instantiated via `from_pass_spec`.

        Raises:
            ValueError: if `spec` names passes missing from `available_passes`.
        """
        specs = tuple(parse_pipeline(spec))
        unrecognised_passes = tuple(
            p.name for p in specs if p.name not in available_passes
        )
        if unrecognised_passes:
            raise ValueError(f"Unrecognized passes: {list(unrecognised_passes)}")

        passes = tuple(available_passes[p.name]().from_pass_spec(p) for p in specs)

        return PassPipeline(passes, callback)

passes: tuple[ModulePass, ...] instance-attribute

These will be executed sequentially during the execution of the pipeline.

callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = field(default=None) class-attribute instance-attribute

Function called in between every pass, taking the pass that just ran, the module, and the next pass.

__init__(passes: tuple[ModulePass, ...], callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = None) -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/passes.py
252
253
254
255
256
257
258
259
260
261
262
263
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Apply every pass in order to `op`, invoking `callback` (if set) between
    consecutive passes. The callback is not called after the last pass.
    """
    if not self.passes:
        # Early exit to avoid fetching a non-existing last pass.
        return
    callback = self.callback

    # `following` (not `next`) to avoid shadowing the builtin.
    for prev, following in zip(self.passes[:-1], self.passes[1:]):
        prev.apply(ctx, op)
        if callback is not None:
            callback(prev, op, following)

    self.passes[-1].apply(ctx, op)

parse_spec(available_passes: dict[str, Callable[[], type[ModulePass]]], spec: str, callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None] | None = None) -> PassPipeline staticmethod

Source code in xdsl/passes.py
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
@staticmethod
def parse_spec(
    available_passes: dict[str, Callable[[], type[ModulePass]]],
    spec: str,
    callback: Callable[[ModulePass, builtin.ModuleOp, ModulePass], None]
    | None = None,
) -> PassPipeline:
    """
    Parse a textual pipeline `spec` into a `PassPipeline`.

    `available_passes` maps pass names to zero-argument factories that return the
    pass class; each class is instantiated from its spec via `from_pass_spec`.

    Raises:
        ValueError: if `spec` names passes missing from `available_passes`.
    """
    pass_specs = tuple(parse_pipeline(spec))

    # Validate every name up front so no pass is built for a bad pipeline.
    unknown = [ps.name for ps in pass_specs if ps.name not in available_passes]
    if unknown:
        raise ValueError(f"Unrecognized passes: {unknown}")

    built: list[ModulePass] = []
    for ps in pass_specs:
        pass_cls = available_passes[ps.name]()
        built.append(pass_cls.from_pass_spec(ps))

    return PassPipeline(tuple(built), callback)

get_pass_option_infos(arg: type[ModulePassT]) -> tuple[PassOptionInfo, ...]

Returns the name, expected type, and default value (if one is known) of each option of the given pass.

Source code in xdsl/passes.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def get_pass_option_infos(
    arg: type[ModulePassT],
) -> tuple[PassOptionInfo, ...]:
    """
    Describe every dataclass field of the pass class `arg` as a `PassOptionInfo`.

    Each entry holds the field's name, the `type_repr` rendering of its type,
    and, when the class has an attribute of that name, the lower-cased `str()`
    of that attribute (otherwise `None`).
    """
    infos: list[PassOptionInfo] = []
    for option in dataclasses.fields(arg):
        option_name = option.name
        type_text = type_repr(option.type)
        # Dataclasses keep plain field defaults as class attributes.
        if hasattr(arg, option_name):
            default_text = str(getattr(arg, option_name)).lower()
        else:
            default_text = None
        infos.append(PassOptionInfo(option_name, type_text, default_text))
    return tuple(infos)