Skip to content

Printf to llvm

printf_to_llvm

i8 = builtin.IntegerType(8) module-attribute

PrintlnOpToPrintfCall

Bases: RewritePattern

Source code in xdsl/transforms/printf_to_llvm.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class PrintlnOpToPrintfCall(RewritePattern):
    """
    Rewrites each `printf.PrintFormatOp` into an `llvm.call` to the C `printf`
    function.

    The format string of every rewritten op is materialized as an internal,
    constant `llvm.GlobalOp`. This pattern does not insert those globals into
    the IR itself; it accumulates them in `collected_global_symbs` so that the
    driving pass can add them to the module once rewriting is complete.
    """

    # Maps a global's symbol name to the llvm.global op holding the string.
    # Because the symbol name is derived from the string value, identical
    # format strings map to the same entry.
    collected_global_symbs: dict[str, llvm.GlobalOp]

    def __init__(self):
        self.collected_global_symbs = {}

    def _construct_global(self, val: str) -> llvm.GlobalOp:
        """
        Constructs an llvm.global operation containing the string. Assigns a unique
        symbol name to the value that is derived from the string value.

        The string is UTF-8 encoded and NUL-terminated, and stored as an
        internal, constant i8 array.
        """
        data = val.encode() + b"\x00"

        t_type = builtin.TensorType(i8, [len(data)])

        return llvm.GlobalOp(
            llvm.LLVMArrayType(len(data), i8),
            _key_from_str(val),
            constant=True,
            linkage="internal",
            value=builtin.DenseIntOrFPElementsAttr(t_type, builtin.BytesAttr(data)),
        )

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: printf.PrintFormatOp, rewriter: PatternRewriter, /):
        """
        Replace `op` with a call to `printf`.

        The printf format string is assembled part by part. Literal text is
        copied, with `%` escaped so that printf prints it verbatim. Each SSA
        argument contributes a conversion specifier. Arguments that cannot be
        handed to a C variadic function as-is are converted first:

          - `index` values are cast to `i64` and printed with `%li`;
          - `f32` values are extended to `f64` and printed with `%f`, because
            C's default argument promotions pass floats to variadic functions
            as doubles.

        Any other argument is passed unchanged, and `_format_str_for_type`
        chooses its specifier.

        The format-string global is recorded in `collected_global_symbs` rather
        than inserted here. `op` is replaced by the conversion ops, the address
        of that global, and the `printf` call, in that order.
        """
        format_str = ""
        args: list[SSAValue] = []
        casts: list[Operation] = []
        for part in _format_string_spec_from_print_op(op):
            if isinstance(part, str):
                # Literal text of the printf dialect is not a C format string:
                # a bare `%` would start a conversion specification with no
                # matching argument (undefined behaviour), so escape it.
                # NOTE(review): assumes _format_string_spec_from_print_op yields
                # literal text unescaped - confirm.
                format_str += part.replace("%", "%%")
            elif isinstance(part.type, builtin.IndexType):
                # index must be cast to fixed bitwidth before printing
                cast = arith.IndexCastOp(part, builtin.i64)
                casts.append(cast)
                args.append(cast.result)
                format_str += "%li"
            elif part.type == builtin.f32:
                # f32 must be promoted to f64 before printing
                ext = arith.ExtFOp(part, builtin.f64)
                casts.append(ext)
                args.append(ext.result)
                format_str += "%f"
            else:
                args.append(part)
                format_str += _format_str_for_type(part.type)

        globl = self._construct_global(format_str)
        self.collected_global_symbs[globl.sym_name.data] = globl

        ptr = llvm.AddressOfOp(globl.sym_name, llvm.LLVMPointerType())
        call = llvm.CallOp("printf", ptr.result, *args, variadic_args=len(args))
        rewriter.replace_op(op, [*casts, ptr, call])

collected_global_symbs: dict[str, llvm.GlobalOp] = dict() instance-attribute

__init__()

Source code in xdsl/transforms/printf_to_llvm.py
89
90
def __init__(self):
    """Start with an empty registry of format-string globals."""
    fresh_registry = {}
    self.collected_global_symbs = fresh_registry

match_and_rewrite(op: printf.PrintFormatOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/printf_to_llvm.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
@op_type_rewrite_pattern
def match_and_rewrite(self, op: printf.PrintFormatOp, rewriter: PatternRewriter, /):
    """
    Replace `op` with a call to `printf`.

    The printf format string is assembled part by part. Literal text is copied,
    with `%` escaped so that printf prints it verbatim. Each SSA argument
    contributes a conversion specifier. Arguments that cannot be handed to a C
    variadic function as-is are converted first:

      - `index` values are cast to `i64` and printed with `%li`;
      - `f32` values are extended to `f64` and printed with `%f`, because C's
        default argument promotions pass floats to variadic functions as doubles.

    Any other argument is passed unchanged, and `_format_str_for_type` chooses
    its specifier.

    The format-string global is recorded in `collected_global_symbs` rather than
    inserted here. `op` is replaced by the conversion ops, the address of that
    global, and the `printf` call, in that order.
    """
    format_str = ""
    args: list[SSAValue] = []
    casts: list[Operation] = []
    for part in _format_string_spec_from_print_op(op):
        if isinstance(part, str):
            # Literal text of the printf dialect is not a C format string:
            # a bare `%` would start a conversion specification with no
            # matching argument (undefined behaviour), so escape it.
            # NOTE(review): assumes _format_string_spec_from_print_op yields
            # literal text unescaped - confirm.
            format_str += part.replace("%", "%%")
        elif isinstance(part.type, builtin.IndexType):
            # index must be cast to fixed bitwidth before printing
            cast = arith.IndexCastOp(part, builtin.i64)
            casts.append(cast)
            args.append(cast.result)
            format_str += "%li"
        elif part.type == builtin.f32:
            # f32 must be promoted to f64 before printing
            ext = arith.ExtFOp(part, builtin.f64)
            casts.append(ext)
            args.append(ext.result)
            format_str += "%f"
        else:
            args.append(part)
            format_str += _format_str_for_type(part.type)

    globl = self._construct_global(format_str)
    self.collected_global_symbs[globl.sym_name.data] = globl

    ptr = llvm.AddressOfOp(globl.sym_name, llvm.LLVMPointerType())
    call = llvm.CallOp("printf", ptr.result, *args, variadic_args=len(args))
    rewriter.replace_op(op, [*casts, ptr, call])

PrintfToLLVM dataclass

Bases: ModulePass

Source code in xdsl/transforms/printf_to_llvm.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
class PrintfToLLVM(ModulePass):
    """
    Lowers `printf.PrintFormatOp` operations to calls to the C `printf` function.

    If at least one op was rewritten, the module also receives:
      - an external, variadic `printf` declaration, unless a top-level op named
        `printf` already exists. Adding a second one would create a duplicate
        symbol.
      - every format-string global collected by the rewrite pattern.
    """

    name = "printf-to-llvm"

    @staticmethod
    def _has_top_level_symbol(module: builtin.ModuleOp, sym_name: str) -> bool:
        """Return True if a top-level op of `module` has symbol name `sym_name`."""
        for child in module.body.block.ops:
            # Ops without a `sym_name` (non-symbol ops) yield None here.
            existing = getattr(child, "sym_name", None)
            if isinstance(existing, builtin.StringAttr) and existing.data == sym_name:
                return True
        return False

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        add_printf_call = PrintlnOpToPrintfCall()

        PatternRewriteWalker(add_printf_call).rewrite_module(op)

        # Nothing was rewritten: neither the declaration nor any global is needed.
        if not add_printf_call.collected_global_symbs:
            return

        new_ops = []
        # Lookup-or-create: reuse an existing `printf` symbol instead of
        # declaring it a second time.
        if not self._has_top_level_symbol(op, "printf"):
            new_ops.append(
                llvm.FuncOp(
                    "printf",
                    llvm.LLVMFunctionType([llvm.LLVMPointerType()], is_variadic=True),
                    linkage=llvm.LinkageAttr("external"),
                )
            )
        new_ops.extend(add_printf_call.collected_global_symbs.values())

        op.body.block.add_ops(new_ops)

name = 'printf-to-llvm' class-attribute instance-attribute

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/printf_to_llvm.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """
    Rewrite every `printf.PrintFormatOp` in `op` into a `printf` call.

    If at least one op was rewritten, this also adds an external, variadic
    `printf` declaration and all collected format-string globals to the module.
    The declaration is only added when no top-level op named `printf` already
    exists, because a second one would be a duplicate symbol.
    """
    add_printf_call = PrintlnOpToPrintfCall()

    PatternRewriteWalker(add_printf_call).rewrite_module(op)

    # Nothing was rewritten: neither the declaration nor any global is needed.
    if not add_printf_call.collected_global_symbs:
        return

    printf_declared = False
    for child in op.body.block.ops:
        # Ops without a `sym_name` (non-symbol ops) yield None here.
        existing = getattr(child, "sym_name", None)
        if isinstance(existing, builtin.StringAttr) and existing.data == "printf":
            printf_declared = True
            break

    new_ops = []
    if not printf_declared:
        new_ops.append(
            llvm.FuncOp(
                "printf",
                llvm.LLVMFunctionType([llvm.LLVMPointerType()], is_variadic=True),
                linkage=llvm.LinkageAttr("external"),
            )
        )
    new_ops.extend(add_printf_call.collected_global_symbs.values())

    op.body.block.add_ops(new_ops)

legalize_str_for_symbol_name(val: str)

Takes any string and legalizes it to be a global llvm symbol. (for the strictest possible interpretation of this)

  • Replaces each run of whitespace, and each dot, with _
  • Deletes all characters that are not ASCII letters, digits, or underscores
  • Strips all underscores from the start and end of the string

The resulting string consists only of ascii letters, underscores and digits.

This mapping is not injective: multiple inputs can produce the same output (e.g. "a b" and "a.b" both become "a_b"). This function alone cannot be used to get a uniquely identifying global symbol name for a string!

Source code in xdsl/transforms/printf_to_llvm.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def legalize_str_for_symbol_name(val: str) -> str:
    """
    Takes any string and legalizes it to be a global llvm symbol.
    (for the strictest possible interpretation of this)

     - Replaces every run of whitespace, and every individual dot, with _
     - Deletes all characters that are not ASCII letters, digits or _
     - Strips all underscores from the start and end of the string

    The resulting string consists only of ascii letters, underscores and digits,
    and may be empty (e.g. for an input made only of punctuation).

    This mapping is not injective: multiple inputs can produce the same output
    (e.g. "a b" and "a.b" both become "a_b"). This function alone cannot be used
    to get a uniquely identifying global symbol name for a string!
    """
    # Whitespace runs collapse to a single "_", each dot becomes its own "_".
    val = re.sub(r"(\s+|\.)", "_", val)
    # Drop everything outside [A-Za-z0-9_], then trim edge underscores.
    val = re.sub(r"[^A-Za-z0-9_]+", "", val).strip("_")
    return val