Skip to content

Convert ptr to riscv

convert_ptr_to_riscv

PtrTypeConversion dataclass

Bases: TypeConversionPattern

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 30-33)
class PtrTypeConversion(TypeConversionPattern):
    """Type conversion that rewrites `ptr.PtrType` to a RISC-V integer register type."""

    @attr_type_rewrite_pattern
    def convert_type(self, typ: ptr.PtrType) -> riscv.IntRegisterType:
        """Return the unallocated integer register type for a pointer type.

        The incoming pointer type is not inspected; every match yields
        ``riscv.Registers.UNALLOCATED_INT``.
        """
        register_type = riscv.Registers.UNALLOCATED_INT
        return register_type

convert_type(typ: ptr.PtrType) -> riscv.IntRegisterType

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 31-33)
@attr_type_rewrite_pattern
def convert_type(self, typ: ptr.PtrType) -> riscv.IntRegisterType:
    """Return the unallocated integer register type for a pointer type.

    The incoming pointer type is not inspected; every match yields
    ``riscv.Registers.UNALLOCATED_INT``.
    """
    register_type = riscv.Registers.UNALLOCATED_INT
    return register_type

ConvertPtrAddOp dataclass

Bases: RewritePattern

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 36-41)
@dataclass
class ConvertPtrAddOp(RewritePattern):
    """Lower `ptr.PtrAddOp` to a RISC-V `add` of its two operands."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ptr.PtrAddOp, rewriter: PatternRewriter, /):
        # Operands are first cast to register-typed values by the shared helper
        # so they can feed the RISC-V add directly.
        lhs, rhs = cast_operands_to_regs(rewriter, op)
        add_op = riscv.AddOp(lhs, rhs)
        rewriter.replace_op(op, add_op)

__init__() -> None

match_and_rewrite(op: ptr.PtrAddOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 38-41)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.PtrAddOp, rewriter: PatternRewriter, /):
    """Replace a `ptr.PtrAddOp` with a RISC-V `add` of its two operands."""
    # Operands are first cast to register-typed values by the shared helper
    # so they can feed the RISC-V add directly.
    lhs, rhs = cast_operands_to_regs(rewriter, op)
    add_op = riscv.AddOp(lhs, rhs)
    rewriter.replace_op(op, add_op)

ConvertStoreOp dataclass

Bases: RewritePattern

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 44-79)
@dataclass
class ConvertStoreOp(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ptr.StoreOp, rewriter: PatternRewriter, /):
        addr, value = cast_operands_to_regs(rewriter, op)

        match value.type:
            case riscv.IntRegisterType():
                new_op = riscv.SwOp(
                    addr, value, 0, comment="store int value to pointer"
                )
            case riscv.FloatRegisterType():
                float_type = cast(AnyFloat, op.value.type)
                match float_type:
                    case Float32Type():
                        new_op = riscv.FSwOp(
                            addr,
                            value,
                            0,
                            comment="store float value to pointer",
                        )
                    case Float64Type():
                        new_op = riscv.FSdOp(
                            addr,
                            value,
                            0,
                            comment="store double value to pointer",
                        )
                    case _:
                        raise DiagnosticException(
                            f"Lowering memref.store op with floating point type {float_type} not yet implemented"
                        )
            case _:
                raise ValueError(f"Unexpected register type {op.value.type}")

        rewriter.replace_op(op, new_op)

__init__() -> None

match_and_rewrite(op: ptr.StoreOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 46-79)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.StoreOp, rewriter: PatternRewriter, /):
    addr, value = cast_operands_to_regs(rewriter, op)

    match value.type:
        case riscv.IntRegisterType():
            new_op = riscv.SwOp(
                addr, value, 0, comment="store int value to pointer"
            )
        case riscv.FloatRegisterType():
            float_type = cast(AnyFloat, op.value.type)
            match float_type:
                case Float32Type():
                    new_op = riscv.FSwOp(
                        addr,
                        value,
                        0,
                        comment="store float value to pointer",
                    )
                case Float64Type():
                    new_op = riscv.FSdOp(
                        addr,
                        value,
                        0,
                        comment="store double value to pointer",
                    )
                case _:
                    raise DiagnosticException(
                        f"Lowering memref.store op with floating point type {float_type} not yet implemented"
                    )
        case _:
            raise ValueError(f"Unexpected register type {op.value.type}")

    rewriter.replace_op(op, new_op)

ConvertLoadOp dataclass

Bases: RewritePattern

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 82-108)
@dataclass
class ConvertLoadOp(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ptr.LoadOp, rewriter: PatternRewriter, /):
        casted = cast_operands_to_regs(rewriter, op)
        addr = casted[0]

        result_register_type = register_type_for_type(op.res.type)

        if issubclass(result_register_type, riscv.IntRegisterType):
            lw_op = riscv.LwOp(addr, 0, comment="load word from pointer")
        else:
            float_type = cast(AnyFloat, op.res.type)
            match float_type:
                case Float32Type():
                    lw_op = riscv.FLwOp(addr, 0, comment="load float from pointer")
                case Float64Type():
                    lw_op = riscv.FLdOp(addr, 0, comment="load double from pointer")
                case _:
                    raise DiagnosticException(
                        f"Lowering memref.load op with floating point type {float_type} not yet implemented"
                    )

        rewriter.replace_op(
            op,
            (lw := lw_op, UnrealizedConversionCastOp.get(lw.results, (op.res.type,))),
        )

__init__() -> None

match_and_rewrite(op: ptr.LoadOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 84-108)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.LoadOp, rewriter: PatternRewriter, /):
    casted = cast_operands_to_regs(rewriter, op)
    addr = casted[0]

    result_register_type = register_type_for_type(op.res.type)

    if issubclass(result_register_type, riscv.IntRegisterType):
        lw_op = riscv.LwOp(addr, 0, comment="load word from pointer")
    else:
        float_type = cast(AnyFloat, op.res.type)
        match float_type:
            case Float32Type():
                lw_op = riscv.FLwOp(addr, 0, comment="load float from pointer")
            case Float64Type():
                lw_op = riscv.FLdOp(addr, 0, comment="load double from pointer")
            case _:
                raise DiagnosticException(
                    f"Lowering memref.load op with floating point type {float_type} not yet implemented"
                )

    rewriter.replace_op(
        op,
        (lw := lw_op, UnrealizedConversionCastOp.get(lw.results, (op.res.type,))),
    )

ConvertMemRefToPtrOp dataclass

Bases: RewritePattern

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 111-120)
@dataclass
class ConvertMemRefToPtrOp(RewritePattern):
    """Replace `ptr.ToPtrOp` with a conversion cast to an unallocated integer register."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ptr.ToPtrOp, rewriter: PatternRewriter, /):
        # The source value is reinterpreted as a register-typed value; no
        # machine instruction is emitted here.
        target_types = (riscv.Registers.UNALLOCATED_INT,)
        cast_op = UnrealizedConversionCastOp.get((op.source,), target_types)
        rewriter.replace_op(op, cast_op)

__init__() -> None

match_and_rewrite(op: ptr.ToPtrOp, rewriter: PatternRewriter)

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 113-120)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: ptr.ToPtrOp, rewriter: PatternRewriter, /):
    """Replace a `ptr.ToPtrOp` with a cast of its source to an integer register type."""
    # The source value is reinterpreted as a register-typed value; no machine
    # instruction is emitted here.
    target_types = (riscv.Registers.UNALLOCATED_INT,)
    cast_op = UnrealizedConversionCastOp.get((op.source,), target_types)
    rewriter.replace_op(op, cast_op)

ConvertPtrToRiscvPass dataclass

Bases: ModulePass

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 123-138)
class ConvertPtrToRiscvPass(ModulePass):
    """Module pass lowering `ptr` dialect types and operations to RISC-V."""

    name = "convert-ptr-to-riscv"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        """Greedily apply every ptr-to-RISC-V pattern across ``op``.

        Dead-code elimination is disabled in the pattern applier.
        """
        patterns: list[RewritePattern] = [
            PtrTypeConversion(),
            ConvertPtrAddOp(),
            ConvertStoreOp(),
            ConvertLoadOp(),
            ConvertMemRefToPtrOp(),
        ]
        applier = GreedyRewritePatternApplier(patterns, dce_enabled=False)
        PatternRewriteWalker(applier).rewrite_module(op)

name = 'convert-ptr-to-riscv' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/transforms/convert_ptr_to_riscv.py (lines 126-138)
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Greedily apply every ptr-to-RISC-V pattern across ``op``.

    Dead-code elimination is disabled in the pattern applier.
    """
    patterns: list[RewritePattern] = [
        PtrTypeConversion(),
        ConvertPtrAddOp(),
        ConvertStoreOp(),
        ConvertLoadOp(),
        ConvertMemRefToPtrOp(),
    ]
    applier = GreedyRewritePatternApplier(patterns, dce_enabled=False)
    PatternRewriteWalker(applier).rewrite_module(op)