Skip to content

Convert arith to riscv snitch

convert_arith_to_riscv_snitch

lower_arith_addf = LowerBinaryFloatVectorOp(arith.AddfOp, riscv.FAddDOp, riscv_snitch.VFAddSOp, riscv_snitch.VFAddHOp) module-attribute

LowerBinaryFloatVectorOp dataclass

Bases: RewritePattern

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
@dataclass
class LowerBinaryFloatVectorOp(RewritePattern):
    arith_op_cls: type[arith.FloatingPointLikeBinaryOperation]
    riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
    riscv_snitch_v_f32_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]
    riscv_snitch_v_f16_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]

    def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
        if not isinstance(op, self.arith_op_cls):
            return

        operand_type = op.result.type
        if not isinstance(operand_type, VectorType):
            return
        shape = operand_type.shape
        count = prod(dim.data for dim in shape.data)

        operand_type = cast(VectorType[Any], operand_type)
        scalar_type = operand_type.element_type

        lhs = UnrealizedConversionCastOp.get(
            (op.lhs,), (riscv.Registers.UNALLOCATED_FLOAT,)
        )
        rhs = UnrealizedConversionCastOp.get(
            (op.rhs,), (riscv.Registers.UNALLOCATED_FLOAT,)
        )

        match scalar_type:
            case Float64Type():
                if count != 1:
                    return
                cls = self.riscv_d_op_cls
            case Float32Type():
                if count != 2:
                    return
                cls = self.riscv_snitch_v_f32_op_cls
            case Float16Type():
                if count != 4:
                    return
                cls = self.riscv_snitch_v_f16_op_cls
            case _:
                raise ValueError(f"Unexpected float type {op.lhs.type}")

        rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data)

        new_op = cls(lhs, rhs, rd=riscv.Registers.UNALLOCATED_FLOAT, fastmath=rv_flags)
        cast_op = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,))

        rewriter.replace_op(op, (lhs, rhs, new_op, cast_op))

arith_op_cls: type[arith.FloatingPointLikeBinaryOperation] instance-attribute

riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] instance-attribute

riscv_snitch_v_f32_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] instance-attribute

riscv_snitch_v_f16_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath] instance-attribute

__init__(arith_op_cls: type[arith.FloatingPointLikeBinaryOperation], riscv_d_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath], riscv_snitch_v_f32_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath], riscv_snitch_v_f16_op_cls: type[riscv.RdRsRsFloatOperationWithFastMath]) -> None

match_and_rewrite(op: Operation, rewriter: PatternRewriter) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None:
    if not isinstance(op, self.arith_op_cls):
        return

    operand_type = op.result.type
    if not isinstance(operand_type, VectorType):
        return
    shape = operand_type.shape
    count = prod(dim.data for dim in shape.data)

    operand_type = cast(VectorType[Any], operand_type)
    scalar_type = operand_type.element_type

    lhs = UnrealizedConversionCastOp.get(
        (op.lhs,), (riscv.Registers.UNALLOCATED_FLOAT,)
    )
    rhs = UnrealizedConversionCastOp.get(
        (op.rhs,), (riscv.Registers.UNALLOCATED_FLOAT,)
    )

    match scalar_type:
        case Float64Type():
            if count != 1:
                return
            cls = self.riscv_d_op_cls
        case Float32Type():
            if count != 2:
                return
            cls = self.riscv_snitch_v_f32_op_cls
        case Float16Type():
            if count != 4:
                return
            cls = self.riscv_snitch_v_f16_op_cls
        case _:
            raise ValueError(f"Unexpected float type {op.lhs.type}")

    rv_flags = riscv.FastMathFlagsAttr(op.fastmath.data)

    new_op = cls(lhs, rhs, rd=riscv.Registers.UNALLOCATED_FLOAT, fastmath=rv_flags)
    cast_op = UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,))

    rewriter.replace_op(op, (lhs, rhs, new_op, cast_op))

ConvertArithToRiscvSnitchPass dataclass

Bases: ModulePass

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py
76
77
78
79
80
81
82
83
84
class ConvertArithToRiscvSnitchPass(ModulePass):
    """
    Module pass that applies the `lower_arith_addf` pattern.

    That pattern lowers `arith.AddfOp` on small float vectors to
    `riscv.FAddDOp`, `riscv_snitch.VFAddSOp` or `riscv_snitch.VFAddHOp`.
    """

    name = "convert-arith-to-riscv-snitch"

    def apply(self, ctx: Context, op: ModuleOp) -> None:
        # Walk the module once with the addf lowering pattern; the walker is
        # configured with apply_recursively=False.
        PatternRewriteWalker(
            lower_arith_addf, apply_recursively=False
        ).rewrite_module(op)

name = 'convert-arith-to-riscv-snitch' class-attribute instance-attribute

apply(ctx: Context, op: ModuleOp) -> None

Source code in xdsl/backend/riscv/lowering/convert_arith_to_riscv_snitch.py
79
80
81
82
83
84
def apply(self, ctx: Context, op: ModuleOp) -> None:
    """Rewrite `op` in place using the `lower_arith_addf` pattern."""
    # Single walker, configured with apply_recursively=False.
    PatternRewriteWalker(
        lower_arith_addf, apply_recursively=False
    ).rewrite_module(op)