Skip to content

Dialect stub

dialect_stub

`DialectStubGenerator` (dataclass)

Generate a typing stub file (.pyi) for a dialect.

Source code in `xdsl/utils/dialect_stub.py` (lines 40–288):
@dataclass
class DialectStubGenerator:
    """Generate a typing stub file (.pyi) for a dialect."""

    dialect: Dialect
    dependencies: dict[str, set[str]] = field(
        init=False, default_factory=dict[str, set[str]]
    )

    def _import(self, module: ModuleType | str, name: str | type[Any]):
        """
        Internal helper to keep track of dependencies, to later generate clean import
        statements.
        """
        # If passed a type, use its name.
        if isinstance(name, type):
            name = name.__name__
        # If passed a module, use its name.
        if isinstance(module, ModuleType):
            module = module.__name__
        # Do not import from builtins, those are the implicitely available ones.
        if module == "builtins":
            return
        # Do no import from banned nested modules
        if module.startswith("xdsl.ir."):
            module = "xdsl.ir"
        if module.startswith("xdsl.irdl."):
            module = "xdsl.irdl"
        # Create a module dependency or add a new name to the ones from an existing
        # dependency
        if module in self.dependencies:
            self.dependencies[module].add(name)
        else:
            self.dependencies[module] = {name}

    def _generate_constraint_type(self, constraint: AttrConstraint) -> str:
        """
        Return a type hint for the member constrained by a constraint, by it an
        attribute parameter, or an operation attribute/property.
        """
        import xdsl.dialects.builtin
        import xdsl.ir
        from xdsl.dialects.builtin import ArrayAttr, ArrayOfConstraint

        match constraint:
            case BaseAttr(attr_type):
                if attr_type not in self.dialect.attributes:
                    self._import(attr_type.__module__, attr_type.__name__)
                return attr_type.__name__
            case EqAttrConstraint(attr):
                if type(attr) not in self.dialect.attributes:
                    self._import(type(attr).__module__, type(attr).__name__)
                return type(attr).__name__

            case AnyOf(attr_constrs=constraints):
                return " | ".join(
                    self._generate_constraint_type(c) for c in constraints
                )
            case AllOf(constraints):
                self._import(typing, "Annotated")
                return f"Annotated[{', '.join(self._generate_constraint_type(c) for c in reversed(constraints))}]"  # noqa: E501
            case ArrayOfConstraint(RangeOf(constraint)):
                self._import(xdsl.dialects.builtin, ArrayAttr)
                return f"ArrayAttr[{self._generate_constraint_type(constraint)}]"
            case AnyAttr():
                self._import(xdsl.ir, Attribute)
                return "Attribute"
            case ParamAttrConstraint():
                base_type = cast(
                    ParamAttrConstraint[ParametrizedAttribute], constraint
                ).base_attr
                return base_type.__name__

            case _:
                raise NotImplementedError(
                    f"Unsupported constraint type: {type(constraint)}"
                )

    def _generate_attribute_stub(self, attr: type[ParametrizedAttribute]):
        """
        Generate type stub for an irdl attribute.
        """
        # They all are ParametrizedAttributes.
        self._import(xdsl.ir, ParametrizedAttribute)
        # Get the bases that are not already bases of ParametrizedAttribute.
        bases = set(attr.__mro__[1:]) - set(ParametrizedAttribute.__mro__)
        # Add them as stub dependencies
        for base in bases:
            self._import(base.__module__, base)

        # Also add them to the Attribute class' bases.
        bases = ", ".join(b.__name__ for b in bases)
        if bases:
            bases += ", "
        yield f"class {attr.__name__}({bases}ParametrizedAttribute):"

        # Generate the parameters' stubs, if any
        attr_def = attr.get_irdl_definition()
        for name, param in attr_def.parameters:
            yield f'    {name} : "{self._generate_constraint_type(param.constr)}"'
        # Otherwise, generate a pass for Python's indentation
        if not attr_def.parameters:
            yield "    pass"
        yield ""
        yield ""

    def _generate_operation_stub(self, op: type[IRDLOperation]):
        """
        Generate type stub for an irdl operation.
        """
        # Keep track of whether the operation has any body, to generate a pass if it
        # does not.
        had_body = False

        # They all are IRDLOperations.
        self._import(xdsl.irdl, IRDLOperation)

        # Currently there's nothing that should be a base class or the IRDLOperations.
        # Traits are not supported, and implemented as fields in PyRDL.
        yield f"class {op.__name__}(IRDLOperation):"

        # Generate the constructs' stubs, if any
        op_def = op.get_irdl_definition()
        for name, o in op_def.operands:
            had_body = True
            match o:
                case VarOperandDef(_):
                    self._import(xdsl.irdl, "VarOperand")
                    yield f"    {name} : VarOperand"
                case OptOperandDef(_):
                    self._import(xdsl.irdl, "OptOperand")
                    yield f"    {name} : OptOperand"
                case OperandDef(_):
                    self._import(xdsl.irdl, "Operand")
                    yield f"    {name} : Operand"
        for name, o in op_def.results:
            had_body = True
            match o:
                case VarResultDef():
                    self._import(xdsl.irdl, "VarOpResult")
                    yield f"    {name} : VarOpResult"
                case OptResultDef():
                    self._import(xdsl.irdl, "OptOpResult")
                    yield f"    {name} : OptOpResult"
                case ResultDef():
                    self._import(xdsl.ir, OpResult)
                    yield f"    {name} : OpResult"
        for name, o in op_def.attributes.items():
            had_body = True
            match o:
                case OptAttributeDef():
                    yield f"    {name} : {self._generate_constraint_type(o.constr)} | None"  # noqa: E501
                case AttributeDef():
                    yield f"    {name} : {self._generate_constraint_type(o.constr)}"
        for name, o in op_def.properties.items():
            had_body = True
            match o:
                case OptPropertyDef():
                    yield f"    {name} : {self._generate_constraint_type(o.constr)} | None"  # noqa: E501
                case PropertyDef():
                    yield f"    {name} : {self._generate_constraint_type(o.constr)}"

        for name, r in op_def.regions:
            had_body = True
            match r:
                case OptRegionDef():
                    self._import(xdsl.irdl, "OptRegion")
                    yield f"    {name} : OptRegion"
                case VarRegionDef():
                    self._import(xdsl.irdl, "VarRegion")
                    yield f"    {name} : VarRegion"
                case RegionDef():
                    self._import(xdsl.ir, Region)
                    yield f"    {name} : Region"

        for name, r in op_def.successors:
            had_body = True
            match r:
                case OptSuccessorDef():
                    self._import(xdsl.irdl, "OptSuccessor")
                    yield f"    {name} : OptSuccessor"
                case VarSuccessorDef():
                    self._import(xdsl.irdl, "VarSuccessor")
                    yield f"    {name} : VarSuccessor"
                case SuccessorDef():
                    self._import(xdsl.irdl, "Successor")
                    yield f"    {name} : Successor"
        # Generate a pass if the operation had no body.
        if not had_body:
            yield "    pass"
        yield ""
        yield ""

    def _generate_dialect_stubs(self):
        """
        Generate a dialect's stubs.

        Just generate stubs for all attributes and operations in the dialect.
        """
        for attr in self.dialect.attributes:
            if issubclass(attr, ParametrizedAttribute):
                for l in self._generate_attribute_stub(attr):
                    yield l

        for op in self.dialect.operations:
            if issubclass(op, IRDLOperation):
                for l in self._generate_operation_stub(op):
                    yield l

    def _generate_imports(self):
        """
        Generate import statements for all the dependencies of the stub.
        """
        # sort modules alphabetically for deterministic and clean output.
        items = list(self.dependencies.items())
        items.sort()

        for module, names in items:
            # If only one name is imported from a module, make a one-liner import.
            if len(names) == 1:
                name = names.pop()
                yield f"from {module} import {name}"
            # Otherwise, import all names in a multi-line import, sorted again for
            # a deterministic and clean output.
            else:
                names = list(names)
                names.sort()
                yield f"from {module} import ("
                for o in names:
                    yield f"    {o},"
                yield ")"

    def generate_dialect_stubs(self):
        """
        The main function, generate stubs for the passed dialect and return as a string.

        NB: probably not optimal perf-wise, but I don't foresee this as a bottleneck.
        """
        self._import(xdsl.ir, Dialect)
        dialect_body = "\n".join(self._generate_dialect_stubs())
        imports = "\n".join(self._generate_imports())
        if imports:
            imports += "\n"

        return f"""\
{imports}
{dialect_body}
{self.dialect.name.capitalize()} : Dialect
"""

dialect: Dialect instance-attribute

dependencies: dict[str, set[str]] = field(init=False, default_factory=(dict[str, set[str]])) class-attribute instance-attribute

__init__(dialect: Dialect) -> None

generate_dialect_stubs()

The main function, generate stubs for the passed dialect and return as a string.

NB: probably not optimal perf-wise, but I don't foresee this as a bottleneck.

Source code in `xdsl/utils/dialect_stub.py` (lines 272–288):
    def generate_dialect_stubs(self):
        """
        Build and return the complete ``.pyi`` stub text for ``self.dialect``.

        The attribute and operation stubs are produced first, because producing them
        records the dependencies that the import header then lists. The text ends
        with a ``<Name> : Dialect`` declaration for the dialect object.

        NB: probably not optimal perf-wise, but I don't foresee this as a bottleneck.
        """
        # The closing dialect declaration always needs `Dialect` in scope.
        self._import(xdsl.ir, Dialect)
        # Body before header: generating the body populates the dependencies.
        body_lines = list(self._generate_dialect_stubs())
        header_lines = list(self._generate_imports())

        header = "\n".join(header_lines)
        if header:
            header = header + "\n"
        body = "\n".join(body_lines)
        declaration = self.dialect.name.capitalize() + " : Dialect"

        return header + "\n" + body + "\n" + declaration + "\n"

make_all_stubs()

Source code in `xdsl/utils/dialect_stub.py` (lines 291–299):
def make_all_stubs():
    """
    Import every ``xdsl.dialects`` submodule that has a matching ``.irdl`` file.

    ``importlib.import_module`` expects a dotted module name, not a filesystem path,
    so each ``<name>.irdl`` file found is imported as ``xdsl.dialects.<name>``.
    Every entry of the package's ``__path__`` is scanned separately; the entries
    are not joined into one path.

    NOTE(review): this function does not itself generate any stubs. Importing
    ``<name>`` for a ``.irdl`` file presumably relies on a matching module or an
    import hook for that extension. Confirm.
    """
    import xdsl.dialects

    dialects = xdsl.dialects
    for directory in dialects.__path__:
        for file in os.listdir(directory):
            name, ext = os.path.splitext(file)
            if ext == ".irdl":
                import_module(f"{dialects.__name__}.{name}")