Skip to content

Arith add fastmath

arith_add_fastmath

Passes to manipulate fastmath flags in FP arith operations.

AddArithFastMathFlags dataclass

Bases: RewritePattern

Adds fastmath flags to FP binary operations from arith dialect.

Source code in xdsl/transforms/arith_add_fastmath.py
27
28
29
30
31
32
33
34
35
36
37
38
39
@dataclass
class AddArithFastMathFlags(RewritePattern):
    """Rewrite pattern that sets fastmath flags on FP binary arith operations.

    Every matched operation (a floating-point-like binary operation or an
    `arith.CmpfOp`) has its `fastmath` attribute overwritten with
    `fastmath_op_attr`.
    """

    # Flags attribute assigned to every matched operation.
    fastmath_op_attr: arith.FastMathFlagsAttr

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self,
        op: arith.FloatingPointLikeBinaryOperation | arith.CmpfOp,
        rewriter: PatternRewriter,
    ) -> None:
        """Overwrite `op.fastmath` with the flags configured on this pattern."""
        # The previous value of `op.fastmath` is replaced, not merged.
        new_flags = self.fastmath_op_attr
        op.fastmath = new_flags

fastmath_op_attr: arith.FastMathFlagsAttr instance-attribute

__init__(fastmath_op_attr: arith.FastMathFlagsAttr) -> None

match_and_rewrite(op: arith.FloatingPointLikeBinaryOperation | arith.CmpfOp, rewriter: PatternRewriter) -> None

Source code in xdsl/transforms/arith_add_fastmath.py
33
34
35
36
37
38
39
@op_type_rewrite_pattern
def match_and_rewrite(
    self,
    op: arith.FloatingPointLikeBinaryOperation | arith.CmpfOp,
    rewriter: PatternRewriter,
) -> None:
    """Overwrite `op.fastmath` with the flags configured on this pattern."""
    # The previous value of `op.fastmath` is replaced, not merged.
    new_flags = self.fastmath_op_attr
    op.fastmath = new_flags

AddArithFastMathFlagsPass dataclass

Bases: ModulePass

Module pass that adds fastmath flags to FP binary operations from arith dialect. It currently does not preserve any existing fastmath flags that may already be part of the operation. By default (no arguments) it adds the "fast" flag.

Arguments (all optional):

  • flags: {"fast", "none"} | list[str]: Set specific fastmath flags. Using "fast" or "none" enables or disables all flags, respectively.
Source code in xdsl/transforms/arith_add_fastmath.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
@dataclass(frozen=True)
class AddArithFastMathFlagsPass(ModulePass):
    """Module pass that adds fastmath flags to FP binary operations from arith dialect.
    It currently does not preserve any existing fastmath flags that may already be part
    of the operation.
    By default (no arguments) it adds the "fast" flag.

    Arguments (all optional):

    - flags: {"fast", "none"} | list[str]: Set specific fastmath flags. Using "fast" or
      "none" enables or disables all flags, respectively.

    Raises `ValueError` when "fast" or "none" appears inside a list of flags, since
    those two values are only meaningful on their own.
    """

    name = "arith-add-fastmath"

    flags: Literal["fast", "none"] | tuple[str, ...] = "fast"

    def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
        """Set the configured fastmath flags on the matching arith ops inside `op`."""
        if isinstance(self.flags, str):
            # A bare string ("fast" or "none") is handed to the attribute as-is.
            fm_flags = arith.FastMathFlagsAttr(self.flags)
        else:
            # "none" and "fast" stand for the whole flag set, so mixing them
            # with individually listed flags is rejected.
            if "none" in self.flags or "fast" in self.flags:
                raise ValueError(
                    '"none" or "fast" cannot be provided along with other fastmath flags'
                )

            fm_flags = arith.FastMathFlagsAttr(_get_flag_list(self.flags))

        # The pattern only assigns an attribute on the ops it matches, so the
        # walker is run without recursive re-application.
        PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [
                    AddArithFastMathFlags(fm_flags),
                ]
            ),
            apply_recursively=False,
        ).rewrite_module(op)

name = 'arith-add-fastmath' class-attribute instance-attribute

flags: Literal['fast', 'none'] | tuple[str, ...] = 'fast' class-attribute instance-attribute

__init__(flags: Literal['fast', 'none'] | tuple[str, ...] = 'fast') -> None

apply(ctx: Context, op: builtin.ModuleOp) -> None

Source code in xdsl/transforms/arith_add_fastmath.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def apply(self, ctx: Context, op: builtin.ModuleOp) -> None:
    """Set the configured fastmath flags on the matching arith ops inside `op`.

    Raises `ValueError` when `self.flags` is a sequence that contains "none" or
    "fast", since those two values are only meaningful on their own.
    """
    if isinstance(self.flags, str):
        # A bare string ("fast" or "none") is handed to the attribute as-is.
        fm_flags = arith.FastMathFlagsAttr(self.flags)
    else:
        # "none" and "fast" stand for the whole flag set, so mixing them
        # with individually listed flags is rejected.
        if "none" in self.flags or "fast" in self.flags:
            raise ValueError(
                '"none" or "fast" cannot be provided along with other fastmath flags'
            )

        fm_flags = arith.FastMathFlagsAttr(_get_flag_list(self.flags))

    # The pattern only assigns an attribute on the ops it matches, so the
    # walker is run without recursive re-application.
    PatternRewriteWalker(
        GreedyRewritePatternApplier(
            [
                AddArithFastMathFlags(fm_flags),
            ]
        ),
        apply_recursively=False,
    ).rewrite_module(op)