Skip to content

Commit

Permalink
馃敡 Add typing of rule functions (#283)
Browse files Browse the repository at this point in the history
Rule functions signature is specific to the state it acts on.
  • Loading branch information
chrisjsewell committed Jun 2, 2023
1 parent 64965cf commit 90b367d
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 48 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Expand Up @@ -53,7 +53,6 @@
".*Literal.*",
".*_Result",
"EnvType",
"RuleFunc",
"Path",
"Ellipsis",
)
Expand Down
12 changes: 9 additions & 3 deletions markdown_it/parser_block.py
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Callable

from . import rules_block
from .ruler import Ruler
Expand All @@ -16,7 +16,13 @@
LOGGER = logging.getLogger(__name__)


_rules: list[tuple[str, Any, list[str]]] = [
RuleFuncBlockType = Callable[[StateBlock, int, int, bool], bool]
"""(state: StateBlock, startLine: int, endLine: int, silent: bool) -> matched: bool)
`silent` disables token generation, useful for lookahead.
"""

_rules: list[tuple[str, RuleFuncBlockType, list[str]]] = [
# First 2 params - rule name & source. Secondary array - list of rules,
# which can be terminated by this one.
("table", rules_block.table, ["paragraph", "reference"]),
Expand Down Expand Up @@ -45,7 +51,7 @@ class ParserBlock:
"""

def __init__(self) -> None:
self.ruler = Ruler()
self.ruler = Ruler[RuleFuncBlockType]()
for name, rule, alt in _rules:
self.ruler.push(name, rule, {"alt": alt})

Expand Down
10 changes: 7 additions & 3 deletions markdown_it/parser_core.py
Expand Up @@ -6,7 +6,9 @@
"""
from __future__ import annotations

from .ruler import RuleFunc, Ruler
from typing import Callable

from .ruler import Ruler
from .rules_core import (
block,
inline,
Expand All @@ -18,7 +20,9 @@
)
from .rules_core.state_core import StateCore

_rules: list[tuple[str, RuleFunc]] = [
RuleFuncCoreType = Callable[[StateCore], None]

_rules: list[tuple[str, RuleFuncCoreType]] = [
("normalize", normalize),
("block", block),
("inline", inline),
Expand All @@ -31,7 +35,7 @@

class ParserCore:
def __init__(self) -> None:
self.ruler = Ruler()
self.ruler = Ruler[RuleFuncCoreType]()
for name, rule in _rules:
self.ruler.push(name, rule)

Expand Down
19 changes: 13 additions & 6 deletions markdown_it/parser_inline.py
Expand Up @@ -2,19 +2,25 @@
"""
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable

from . import rules_inline
from .ruler import RuleFunc, Ruler
from .ruler import Ruler
from .rules_inline.state_inline import StateInline
from .token import Token
from .utils import EnvType

if TYPE_CHECKING:
from markdown_it import MarkdownIt


# Parser rules
_rules: list[tuple[str, RuleFunc]] = [
RuleFuncInlineType = Callable[[StateInline, bool], bool]
"""(state: StateInline, silent: bool) -> matched: bool)
`silent` disables token generation, useful for lookahead.
"""
_rules: list[tuple[str, RuleFuncInlineType]] = [
("text", rules_inline.text),
("linkify", rules_inline.linkify),
("newline", rules_inline.newline),
Expand All @@ -34,7 +40,8 @@
#
# Don't use this for anything except pairs (plugins working with `balance_pairs`).
#
_rules2: list[tuple[str, RuleFunc]] = [
RuleFuncInline2Type = Callable[[StateInline], None]
_rules2: list[tuple[str, RuleFuncInline2Type]] = [
("balance_pairs", rules_inline.link_pairs),
("strikethrough", rules_inline.strikethrough.postProcess),
("emphasis", rules_inline.emphasis.postProcess),
Expand All @@ -46,11 +53,11 @@

class ParserInline:
def __init__(self) -> None:
self.ruler = Ruler()
self.ruler = Ruler[RuleFuncInlineType]()
for name, rule in _rules:
self.ruler.push(name, rule)
# Second ruler used for post-processing (e.g. in emphasis-like rules)
self.ruler2 = Ruler()
self.ruler2 = Ruler[RuleFuncInline2Type]()
for name, rule2 in _rules2:
self.ruler2.push(name, rule2)

Expand Down
45 changes: 23 additions & 22 deletions markdown_it/ruler.py
Expand Up @@ -17,9 +17,9 @@ class Ruler
"""
from __future__ import annotations

from collections.abc import Callable, Iterable
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypedDict
from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar
import warnings

from markdown_it._compat import DATACLASS_KWARGS
Expand Down Expand Up @@ -57,33 +57,30 @@ def srcCharCode(self) -> tuple[int, ...]:
return self._srcCharCode


# The first positional arg is always a subtype of `StateBase`. Other
# arguments may or may not exist, based on the rule's type (block,
# core, inline). Return type is either `None` or `bool` based on the
# rule's type.
RuleFunc = Callable # type: ignore


class RuleOptionsType(TypedDict, total=False):
alt: list[str]


RuleFuncTv = TypeVar("RuleFuncTv")
"""A rule function, whose signature is dependent on the state type."""


@dataclass(**DATACLASS_KWARGS)
class Rule:
class Rule(Generic[RuleFuncTv]):
name: str
enabled: bool
fn: RuleFunc = field(repr=False)
fn: RuleFuncTv = field(repr=False)
alt: list[str]


class Ruler:
class Ruler(Generic[RuleFuncTv]):
def __init__(self) -> None:
# List of added rules.
self.__rules__: list[Rule] = []
self.__rules__: list[Rule[RuleFuncTv]] = []
# Cached rule chains.
# First level - chain name, '' for default.
# Second level - diginal anchor for fast filtering by charcodes.
self.__cache__: dict[str, list[RuleFunc]] | None = None
self.__cache__: dict[str, list[RuleFuncTv]] | None = None

def __find__(self, name: str) -> int:
"""Find rule index by name"""
Expand Down Expand Up @@ -112,7 +109,7 @@ def __compile__(self) -> None:
self.__cache__[chain].append(rule.fn)

def at(
self, ruleName: str, fn: RuleFunc, options: RuleOptionsType | None = None
self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
) -> None:
"""Replace rule by name with new function & options.
Expand All @@ -133,7 +130,7 @@ def before(
self,
beforeName: str,
ruleName: str,
fn: RuleFunc,
fn: RuleFuncTv,
options: RuleOptionsType | None = None,
) -> None:
"""Add new rule to chain before one with given name.
Expand All @@ -148,14 +145,16 @@ def before(
options = options or {}
if index == -1:
raise KeyError(f"Parser rule not found: {beforeName}")
self.__rules__.insert(index, Rule(ruleName, True, fn, options.get("alt", [])))
self.__rules__.insert(
index, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
)
self.__cache__ = None

def after(
self,
afterName: str,
ruleName: str,
fn: RuleFunc,
fn: RuleFuncTv,
options: RuleOptionsType | None = None,
) -> None:
"""Add new rule to chain after one with given name.
Expand All @@ -171,12 +170,12 @@ def after(
if index == -1:
raise KeyError(f"Parser rule not found: {afterName}")
self.__rules__.insert(
index + 1, Rule(ruleName, True, fn, options.get("alt", []))
index + 1, Rule[RuleFuncTv](ruleName, True, fn, options.get("alt", []))
)
self.__cache__ = None

def push(
self, ruleName: str, fn: RuleFunc, options: RuleOptionsType | None = None
self, ruleName: str, fn: RuleFuncTv, options: RuleOptionsType | None = None
) -> None:
"""Push new rule to the end of chain.
Expand All @@ -185,7 +184,9 @@ def push(
:param options: new rule options (not mandatory).
"""
self.__rules__.append(Rule(ruleName, True, fn, (options or {}).get("alt", [])))
self.__rules__.append(
Rule[RuleFuncTv](ruleName, True, fn, (options or {}).get("alt", []))
)
self.__cache__ = None

def enable(
Expand Down Expand Up @@ -252,7 +253,7 @@ def disable(
self.__cache__ = None
return result

def getRules(self, chainName: str) -> list[RuleFunc]:
def getRules(self, chainName: str = "") -> list[RuleFuncTv]:
"""Return array of active functions (rules) for given chain name.
It analyzes rules configuration, compiles caches if not exists and returns result.
Expand Down
3 changes: 1 addition & 2 deletions markdown_it/rules_block/lheading.py
@@ -1,7 +1,6 @@
# lheading (---, ==)
import logging

from ..ruler import Ruler
from .state_block import StateBlock

LOGGER = logging.getLogger(__name__)
Expand All @@ -12,7 +11,7 @@ def lheading(state: StateBlock, startLine: int, endLine: int, silent: bool) -> b

level = None
nextLine = startLine + 1
ruler: Ruler = state.md.block.ruler
ruler = state.md.block.ruler
terminatorRules = ruler.getRules("paragraph")

if state.is_code_block(startLine):
Expand Down
3 changes: 1 addition & 2 deletions markdown_it/rules_block/paragraph.py
@@ -1,7 +1,6 @@
"""Paragraph."""
import logging

from ..ruler import Ruler
from .state_block import StateBlock

LOGGER = logging.getLogger(__name__)
Expand All @@ -13,7 +12,7 @@ def paragraph(state: StateBlock, startLine: int, endLine: int, silent: bool) ->
)

nextLine = startLine + 1
ruler: Ruler = state.md.block.ruler
ruler = state.md.block.ruler
terminatorRules = ruler.getRules("paragraph")
endLine = state.lineMax

Expand Down
20 changes: 11 additions & 9 deletions tests/test_api/test_plugin_creation.py
Expand Up @@ -6,26 +6,27 @@

def inline_rule(state, silent):
print("plugin called")
return False


def test_inline_after(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.inline.ruler.after("text", "new_rule", inline_rule)

MarkdownIt().use(_plugin).parse("[")
assert "plugin called" in capsys.readouterr().out


def test_inline_before(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.inline.ruler.before("text", "new_rule", inline_rule)

MarkdownIt().use(_plugin).parse("a")
assert "plugin called" in capsys.readouterr().out


def test_inline_at(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.inline.ruler.at("text", inline_rule)

MarkdownIt().use(_plugin).parse("a")
Expand All @@ -34,26 +35,27 @@ def _plugin(_md):

def block_rule(state, startLine, endLine, silent):
print("plugin called")
return False


def test_block_after(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.block.ruler.after("hr", "new_rule", block_rule)

MarkdownIt().use(_plugin).parse("a")
assert "plugin called" in capsys.readouterr().out


def test_block_before(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.block.ruler.before("hr", "new_rule", block_rule)

MarkdownIt().use(_plugin).parse("a")
assert "plugin called" in capsys.readouterr().out


def test_block_at(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.block.ruler.at("hr", block_rule)

MarkdownIt().use(_plugin).parse("a")
Expand All @@ -65,23 +67,23 @@ def core_rule(state):


def test_core_after(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.core.ruler.after("normalize", "new_rule", core_rule)

MarkdownIt().use(_plugin).parse("a")
assert "plugin called" in capsys.readouterr().out


def test_core_before(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.core.ruler.before("normalize", "new_rule", core_rule)

MarkdownIt().use(_plugin).parse("a")
assert "plugin called" in capsys.readouterr().out


def test_core_at(capsys):
def _plugin(_md):
def _plugin(_md: MarkdownIt) -> None:
_md.core.ruler.at("normalize", core_rule)

MarkdownIt().use(_plugin).parse("a")
Expand Down

0 comments on commit 90b367d

Please sign in to comment.