Module graia.application.message.parser.literature

Expand source code
import re
import shlex
import getopt
import itertools
from typing import Dict, List, Tuple
from graia.broadcast.entities.dispatcher import BaseDispatcher
from graia.broadcast.entities.signatures import Force
from graia.broadcast.exceptions import ExecutionStop

from graia.broadcast.interfaces.dispatcher import DispatcherInterface
from graia.broadcast.utilles import printer
from graia.application.message.chain import MessageChain, MessageIndex
from graia.application.message.elements import Element
from graia.application.message.elements.internal import (
    At,
    App,
    Json,
    Plain,
    Quote,
    Source,
    Xml,
    Voice,
    Poke,
    FlashImage,
)

from graia.application.message.parser.pattern import (
    BoxParameter,
    ParamPattern,
    SwitchParameter,
)

# Element types that make Literature.beforeDispatch reject a message outright.
BLOCKING_ELEMENTS = (
    Xml,
    Json,
    App,
    Poke,
    Voice,
    FlashImage,
)


class Literature(BaseDispatcher):
    """Dispatcher for command-style messages such as ``prefix ... --opt value words``.

    A message matches when its leading space-separated words equal
    ``prefixs`` in order. The remainder is tokenised with ``shlex.split`` and
    parsed with ``getopt.getopt`` according to ``arguments``; ``catch`` then
    injects the parsed values into listener parameters. Messages that do not
    match raise ``ExecutionStop``.
    """

    always = False
    prefixs: Tuple[str]  # words that must open the message, in order
    arguments: Dict[str, ParamPattern]  # listener parameter name -> option pattern

    allow_quote: bool
    skip_one_at_in_quote: bool

    def __init__(
        self,
        *prefixs,
        arguments: Dict[str, ParamPattern] = None,
        allow_quote: bool = False,
        skip_one_at_in_quote: bool = False,
    ) -> None:
        """
        :param prefixs: words the message must start with, in order.
        :param arguments: maps listener parameter names to option patterns;
            ``None`` means no options.
        :param allow_quote: when True and the chain contains a Quote, the chain
            is sliced from element index 1 before prefix matching.
        :param skip_one_at_in_quote: together with ``allow_quote``, also skip
            one At element directly following that first element.
        """
        self.prefixs = prefixs
        self.arguments = arguments or {}
        self.allow_quote = allow_quote
        self.skip_one_at_in_quote = skip_one_at_in_quote

    def trans_to_map(self, message_chain: MessageChain):
        """Encode ``message_chain`` as one string plus a table of non-text elements.

        Each non-Plain element becomes the placeholder ``$<n>`` (n = 1, 2, ...)
        and is stored under ``n`` in the returned dict. In Plain text, any
        existing ``$<digits>`` sequence is prefixed with a backslash so it is
        not mistaken for a placeholder.

        NOTE(review): ``shlex.split`` (POSIX mode) drops that backslash outside
        quotes, and a placeholder directly followed by digit text fuses into a
        different number, so the encoding is not fully reversible.

        :returns: ``(encoded_text, {placeholder_index: element})``.
        """
        string_result: List[str] = []
        id_elem_map: Dict[int, Element] = {}

        for elem in message_chain.__root__:
            if isinstance(elem, Plain):
                string_result.append(
                    re.sub(
                        r"\$(?P<id>\d+)",
                        lambda match: f'\\${match.group("id")}',
                        elem.text,
                    )
                )
            else:
                index = len(id_elem_map) + 1
                string_result.append(f"${index}")
                id_elem_map[index] = elem

        return ("".join(string_result), id_elem_map)

    def gen_long_map(self):
        """Map every long option name to the parameter that declares it.

        :raises ValueError: if two patterns declare the same long name.
        """
        result = {}
        for param_name, arg in self.arguments.items():
            for long in arg.longs:
                if long in result:
                    raise ValueError("conflict item")
                result[long] = param_name
        return result

    def gen_short_map(self):
        """Map every declared short option letter to its parameter name.

        Patterns without a short name are skipped, matching the getopt short
        option spec built in ``parse_message``.

        :raises ValueError: if two patterns declare the same short name.
        """
        result = {}
        for param_name, arg in self.arguments.items():
            if not arg.short:
                # No short form: several such patterns must not collide on a
                # shared ``None`` key.
                continue
            if arg.short in result:
                raise ValueError("conflict item")
            result[arg.short] = param_name
        return result

    def gen_long_map_with_bar(self):
        """Like ``gen_long_map`` but keyed as getopt reports long options (``--name``)."""
        return {("--" + k): v for k, v in self.gen_long_map().items()}

    def gen_short_map_with_bar(self):
        """Like ``gen_short_map`` but keyed as getopt reports short options (``-x``)."""
        return {("-" + k): v for k, v in self.gen_short_map().items() if k is not None}

    def parse_message(self, message_chain: MessageChain):
        """Parse options and positional words out of ``message_chain``.

        NOTE: ``getopt.getopt`` stops at the first non-option word, so every
        option must come before the positional words.

        :returns: ``(parsed_args, variables)``. ``parsed_args`` maps each
            parameter name to ``(value, pattern)``: a MessageChain for a
            BoxParameter, a bool for a switch, or the pattern's default when the
            option is absent. ``variables`` is the list of positional words,
            each restored to a MessageChain.
        :raises ExecutionStop: if the text cannot be tokenised (e.g. an
            unclosed quote), getopt rejects it (unknown option, missing option
            value), or an argument whose default is None is absent.
        :raises ValueError: if two patterns declare the same option name.
        """
        string_result, id_elem_map = self.trans_to_map(message_chain)

        def restore(text: str) -> MessageChain:
            # re.split with a capturing group puts the captured placeholders at
            # odd indices; even indices hold the literal text between them.
            pieces = re.split(r"((?<!\\)\$[0-9]+)", text)
            elements = []
            for position, piece in enumerate(pieces):
                if not piece:
                    continue
                if position % 2:
                    # An unknown index (e.g. a literal "$5" typed outside quotes,
                    # whose escaping backslash shlex removed) is kept as text
                    # instead of raising KeyError.
                    elements.append(id_elem_map.get(int(piece[1:]), Plain(piece)))
                else:
                    elements.append(Plain(piece))
            return MessageChain.create(elements).asMerged()

        short_spec = "".join(
            arg.short if isinstance(arg, SwitchParameter) else arg.short + ":"
            for arg in self.arguments.values()
            if arg.short
        )
        long_spec = [
            long if isinstance(arg, SwitchParameter) else long + "="
            for arg in self.arguments.values()
            for long in arg.longs
        ]
        try:
            tokens = shlex.split(string_result)
            options, positionals = getopt.getopt(tokens, short_spec, long_spec)
        except (ValueError, getopt.GetoptError) as err:
            # Unbalanced quotes (shlex) or unknown/incomplete options (getopt):
            # the message is not a valid invocation of this command.
            raise ExecutionStop() from err

        map_with_bar = {**self.gen_long_map_with_bar(), **self.gen_short_map_with_bar()}

        parsed_args = {}
        for option, value in options:
            param_name = map_with_bar[option]
            setting = self.arguments[param_name]
            if isinstance(setting, BoxParameter):
                parsed_value = restore(value)
            elif setting.auto_reverse:
                # A present auto-reversing switch yields the inverse of its
                # default (the old ``a and b or True`` form always gave True).
                parsed_value = not setting.default
            else:
                parsed_value = True
            parsed_args[param_name] = (parsed_value, setting)

        variables = [restore(word) for word in positionals]

        for param_name, argument_setting in self.arguments.items():
            if param_name in parsed_args:
                continue
            if argument_setting.default is None:
                # No default to fall back on: this message does not match.
                raise ExecutionStop()
            parsed_args[param_name] = (argument_setting.default, argument_setting)

        return (parsed_args, variables)

    def prefix_match(self, target_chain: MessageChain):
        """Strip ``self.prefixs`` from the start of ``target_chain``.

        The merged chain is split on single spaces. Each of the first
        ``len(self.prefixs)`` frames must consist of exactly one Plain whose
        text equals the corresponding prefix.

        :returns: the remaining frames re-joined with single spaces, or None if
            the prefixes do not match.
        """
        target_chain = target_chain.asMerged()

        chain_frames: List[MessageChain] = target_chain.split(" ", raw_string=True)

        if len(self.prefixs) > len(chain_frames):
            return None
        for current_prefix, current_frame in zip(self.prefixs, chain_frames):
            frame_elements = current_frame.__root__
            # The whole frame must be the prefix. Otherwise e.g. "cmd" directly
            # followed by an image would match, and the image would be lost
            # when the prefix frames are dropped below.
            if (
                len(frame_elements) != 1
                or not isinstance(frame_elements[0], Plain)
                or frame_elements[0].text != current_prefix
            ):
                return None

        remaining_frames = chain_frames[len(self.prefixs) :]
        # Re-join the remaining frames with one space between them; the
        # trailing separator is sliced off.
        return MessageChain.create(
            list(
                itertools.chain.from_iterable(
                    frame.__root__ + [Plain(" ")] for frame in remaining_frames
                )
            )[:-1]
        ).asMerged()

    async def beforeDispatch(self, interface: DispatcherInterface):
        """Match and parse the message before listener parameters are resolved.

        Looks up the chain through the ``__literature_messagechain__``
        parameter, removes Source elements, and rejects chains containing any
        BLOCKING_ELEMENTS. It optionally strips a leading element (and one At)
        when a Quote is present, then strips the prefixes. The
        ``parse_message`` result is stored on the current execution context
        for ``catch``.

        :raises ExecutionStop: when the message does not match this command.
        """
        message_chain: MessageChain = (
            await interface.lookup_param(
                "__literature_messagechain__", MessageChain, None
            )
        ).exclude(Source)
        if any(isinstance(elem, BLOCKING_ELEMENTS) for elem in message_chain.__root__):
            raise ExecutionStop()
        if self.allow_quote and message_chain.has(Quote):
            # Drop the first element -- presumably the Quote itself; verify the
            # Quote is always at index 0 once Source is excluded.
            message_chain = message_chain[(1, None):]
            if self.skip_one_at_in_quote and message_chain.__root__:
                if isinstance(message_chain.__root__[0], At):
                    # Skip one At following the Quote. NOTE(review): the (1, 1)
                    # index presumably also drops the first character of the
                    # next element -- confirm against MessageChain.__getitem__.
                    message_chain = message_chain[(1, 1):]
        noprefix = self.prefix_match(message_chain)
        if noprefix is None:
            raise ExecutionStop()

        interface.execution_contexts[-1].literature_detect_result = self.parse_message(
            noprefix
        )

    async def catch(self, interface: DispatcherInterface):
        """Supply a listener parameter from the result stored by ``beforeDispatch``.

        - ``__literature_messagechain__`` is never supplied here (it is the
          parameter ``beforeDispatch`` itself looks up).
        - A parameter whose default is ``"__literature_variables__"`` receives
          the list of positional word chains.
        - A SwitchParameter's value is returned wrapped in ``Force``
          (presumably so falsy values are still injected -- confirm against
          graia.broadcast).
        - A parameter annotated ``ParamPattern`` receives the pattern object.
        - Otherwise the parsed value is returned when it is not None.

        Returns None when none of these apply.
        """
        if interface.name == "__literature_messagechain__":
            return

        result = interface.execution_contexts[-1].literature_detect_result
        if result:
            match_result, variargs = result
            if interface.default == "__literature_variables__":
                return variargs

            arg_fetch_result = match_result.get(interface.name)
            if arg_fetch_result:

                match_value, raw_argument = arg_fetch_result
                if isinstance(raw_argument, SwitchParameter):
                    return Force(match_value)
                elif interface.annotation is ParamPattern:
                    return raw_argument
                elif match_value is not None:
                    return match_value


if __name__ == "__main__":
    # Manual smoke test: parse a sample command and pretty-print the result.
    from devtools import debug

    mc = MessageChain.create(
        [
            Plain('test n --f3 "1 2 tsthd thsd ydj re7u  '),
            At(351453455),
            Plain(' " --f34 "arg arega er ae aghr ae rtyh'),
            Plain(' "'),
        ]
    )

    literature = Literature(
        "test",
        "n",
        arguments={
            "a": BoxParameter(["test_f1", "f23"], "f"),
            "b": SwitchParameter(["f34"], "d"),
        },
    )

    debug(literature.parse_message(literature.prefix_match(mc)))
    print(mc.asDisplay())

Classes

class Literature (*prefixs, arguments: Dict[str, ParamPattern] = None, allow_quote: bool = False, skip_one_at_in_quote: bool = False)

旅途的浪漫 (English: "the romance of the journey")

Expand source code
class Literature(BaseDispatcher):
    """Dispatcher for command-style messages such as ``prefix ... --opt value words``.

    A message matches when its leading space-separated words equal
    ``prefixs`` in order. The remainder is tokenised with ``shlex.split`` and
    parsed with ``getopt.getopt`` according to ``arguments``; ``catch`` then
    injects the parsed values into listener parameters. Messages that do not
    match raise ``ExecutionStop``.
    """

    always = False
    prefixs: Tuple[str]  # words that must open the message, in order
    arguments: Dict[str, ParamPattern]  # listener parameter name -> option pattern

    allow_quote: bool
    skip_one_at_in_quote: bool

    def __init__(
        self,
        *prefixs,
        arguments: Dict[str, ParamPattern] = None,
        allow_quote: bool = False,
        skip_one_at_in_quote: bool = False,
    ) -> None:
        """
        :param prefixs: words the message must start with, in order.
        :param arguments: maps listener parameter names to option patterns;
            ``None`` means no options.
        :param allow_quote: when True and the chain contains a Quote, the chain
            is sliced from element index 1 before prefix matching.
        :param skip_one_at_in_quote: together with ``allow_quote``, also skip
            one At element directly following that first element.
        """
        self.prefixs = prefixs
        self.arguments = arguments or {}
        self.allow_quote = allow_quote
        self.skip_one_at_in_quote = skip_one_at_in_quote

    def trans_to_map(self, message_chain: MessageChain):
        """Encode ``message_chain`` as one string plus a table of non-text elements.

        Each non-Plain element becomes the placeholder ``$<n>`` (n = 1, 2, ...)
        and is stored under ``n`` in the returned dict. In Plain text, any
        existing ``$<digits>`` sequence is prefixed with a backslash so it is
        not mistaken for a placeholder.

        NOTE(review): ``shlex.split`` (POSIX mode) drops that backslash outside
        quotes, and a placeholder directly followed by digit text fuses into a
        different number, so the encoding is not fully reversible.

        :returns: ``(encoded_text, {placeholder_index: element})``.
        """
        string_result: List[str] = []
        id_elem_map: Dict[int, Element] = {}

        for elem in message_chain.__root__:
            if isinstance(elem, Plain):
                string_result.append(
                    re.sub(
                        r"\$(?P<id>\d+)",
                        lambda match: f'\\${match.group("id")}',
                        elem.text,
                    )
                )
            else:
                index = len(id_elem_map) + 1
                string_result.append(f"${index}")
                id_elem_map[index] = elem

        return ("".join(string_result), id_elem_map)

    def gen_long_map(self):
        """Map every long option name to the parameter that declares it.

        :raises ValueError: if two patterns declare the same long name.
        """
        result = {}
        for param_name, arg in self.arguments.items():
            for long in arg.longs:
                if long in result:
                    raise ValueError("conflict item")
                result[long] = param_name
        return result

    def gen_short_map(self):
        """Map every declared short option letter to its parameter name.

        Patterns without a short name are skipped, matching the getopt short
        option spec built in ``parse_message``.

        :raises ValueError: if two patterns declare the same short name.
        """
        result = {}
        for param_name, arg in self.arguments.items():
            if not arg.short:
                # No short form: several such patterns must not collide on a
                # shared ``None`` key.
                continue
            if arg.short in result:
                raise ValueError("conflict item")
            result[arg.short] = param_name
        return result

    def gen_long_map_with_bar(self):
        """Like ``gen_long_map`` but keyed as getopt reports long options (``--name``)."""
        return {("--" + k): v for k, v in self.gen_long_map().items()}

    def gen_short_map_with_bar(self):
        """Like ``gen_short_map`` but keyed as getopt reports short options (``-x``)."""
        return {("-" + k): v for k, v in self.gen_short_map().items() if k is not None}

    def parse_message(self, message_chain: MessageChain):
        """Parse options and positional words out of ``message_chain``.

        NOTE: ``getopt.getopt`` stops at the first non-option word, so every
        option must come before the positional words.

        :returns: ``(parsed_args, variables)``. ``parsed_args`` maps each
            parameter name to ``(value, pattern)``: a MessageChain for a
            BoxParameter, a bool for a switch, or the pattern's default when the
            option is absent. ``variables`` is the list of positional words,
            each restored to a MessageChain.
        :raises ExecutionStop: if the text cannot be tokenised (e.g. an
            unclosed quote), getopt rejects it (unknown option, missing option
            value), or an argument whose default is None is absent.
        :raises ValueError: if two patterns declare the same option name.
        """
        string_result, id_elem_map = self.trans_to_map(message_chain)

        def restore(text: str) -> MessageChain:
            # re.split with a capturing group puts the captured placeholders at
            # odd indices; even indices hold the literal text between them.
            pieces = re.split(r"((?<!\\)\$[0-9]+)", text)
            elements = []
            for position, piece in enumerate(pieces):
                if not piece:
                    continue
                if position % 2:
                    # An unknown index (e.g. a literal "$5" typed outside quotes,
                    # whose escaping backslash shlex removed) is kept as text
                    # instead of raising KeyError.
                    elements.append(id_elem_map.get(int(piece[1:]), Plain(piece)))
                else:
                    elements.append(Plain(piece))
            return MessageChain.create(elements).asMerged()

        short_spec = "".join(
            arg.short if isinstance(arg, SwitchParameter) else arg.short + ":"
            for arg in self.arguments.values()
            if arg.short
        )
        long_spec = [
            long if isinstance(arg, SwitchParameter) else long + "="
            for arg in self.arguments.values()
            for long in arg.longs
        ]
        try:
            tokens = shlex.split(string_result)
            options, positionals = getopt.getopt(tokens, short_spec, long_spec)
        except (ValueError, getopt.GetoptError) as err:
            # Unbalanced quotes (shlex) or unknown/incomplete options (getopt):
            # the message is not a valid invocation of this command.
            raise ExecutionStop() from err

        map_with_bar = {**self.gen_long_map_with_bar(), **self.gen_short_map_with_bar()}

        parsed_args = {}
        for option, value in options:
            param_name = map_with_bar[option]
            setting = self.arguments[param_name]
            if isinstance(setting, BoxParameter):
                parsed_value = restore(value)
            elif setting.auto_reverse:
                # A present auto-reversing switch yields the inverse of its
                # default (the old ``a and b or True`` form always gave True).
                parsed_value = not setting.default
            else:
                parsed_value = True
            parsed_args[param_name] = (parsed_value, setting)

        variables = [restore(word) for word in positionals]

        for param_name, argument_setting in self.arguments.items():
            if param_name in parsed_args:
                continue
            if argument_setting.default is None:
                # No default to fall back on: this message does not match.
                raise ExecutionStop()
            parsed_args[param_name] = (argument_setting.default, argument_setting)

        return (parsed_args, variables)

    def prefix_match(self, target_chain: MessageChain):
        """Strip ``self.prefixs`` from the start of ``target_chain``.

        The merged chain is split on single spaces. Each of the first
        ``len(self.prefixs)`` frames must consist of exactly one Plain whose
        text equals the corresponding prefix.

        :returns: the remaining frames re-joined with single spaces, or None if
            the prefixes do not match.
        """
        target_chain = target_chain.asMerged()

        chain_frames: List[MessageChain] = target_chain.split(" ", raw_string=True)

        if len(self.prefixs) > len(chain_frames):
            return None
        for current_prefix, current_frame in zip(self.prefixs, chain_frames):
            frame_elements = current_frame.__root__
            # The whole frame must be the prefix. Otherwise e.g. "cmd" directly
            # followed by an image would match, and the image would be lost
            # when the prefix frames are dropped below.
            if (
                len(frame_elements) != 1
                or not isinstance(frame_elements[0], Plain)
                or frame_elements[0].text != current_prefix
            ):
                return None

        remaining_frames = chain_frames[len(self.prefixs) :]
        # Re-join the remaining frames with one space between them; the
        # trailing separator is sliced off.
        return MessageChain.create(
            list(
                itertools.chain.from_iterable(
                    frame.__root__ + [Plain(" ")] for frame in remaining_frames
                )
            )[:-1]
        ).asMerged()

    async def beforeDispatch(self, interface: DispatcherInterface):
        """Match and parse the message before listener parameters are resolved.

        Looks up the chain through the ``__literature_messagechain__``
        parameter, removes Source elements, and rejects chains containing any
        BLOCKING_ELEMENTS. It optionally strips a leading element (and one At)
        when a Quote is present, then strips the prefixes. The
        ``parse_message`` result is stored on the current execution context
        for ``catch``.

        :raises ExecutionStop: when the message does not match this command.
        """
        message_chain: MessageChain = (
            await interface.lookup_param(
                "__literature_messagechain__", MessageChain, None
            )
        ).exclude(Source)
        if any(isinstance(elem, BLOCKING_ELEMENTS) for elem in message_chain.__root__):
            raise ExecutionStop()
        if self.allow_quote and message_chain.has(Quote):
            # Drop the first element -- presumably the Quote itself; verify the
            # Quote is always at index 0 once Source is excluded.
            message_chain = message_chain[(1, None):]
            if self.skip_one_at_in_quote and message_chain.__root__:
                if isinstance(message_chain.__root__[0], At):
                    # Skip one At following the Quote. NOTE(review): the (1, 1)
                    # index presumably also drops the first character of the
                    # next element -- confirm against MessageChain.__getitem__.
                    message_chain = message_chain[(1, 1):]
        noprefix = self.prefix_match(message_chain)
        if noprefix is None:
            raise ExecutionStop()

        interface.execution_contexts[-1].literature_detect_result = self.parse_message(
            noprefix
        )

    async def catch(self, interface: DispatcherInterface):
        """Supply a listener parameter from the result stored by ``beforeDispatch``.

        - ``__literature_messagechain__`` is never supplied here (it is the
          parameter ``beforeDispatch`` itself looks up).
        - A parameter whose default is ``"__literature_variables__"`` receives
          the list of positional word chains.
        - A SwitchParameter's value is returned wrapped in ``Force``
          (presumably so falsy values are still injected -- confirm against
          graia.broadcast).
        - A parameter annotated ``ParamPattern`` receives the pattern object.
        - Otherwise the parsed value is returned when it is not None.

        Returns None when none of these apply.
        """
        if interface.name == "__literature_messagechain__":
            return

        result = interface.execution_contexts[-1].literature_detect_result
        if result:
            match_result, variargs = result
            if interface.default == "__literature_variables__":
                return variargs

            arg_fetch_result = match_result.get(interface.name)
            if arg_fetch_result:

                match_value, raw_argument = arg_fetch_result
                if isinstance(raw_argument, SwitchParameter):
                    return Force(match_value)
                elif interface.annotation is ParamPattern:
                    return raw_argument
                elif match_value is not None:
                    return match_value

Ancestors

graia.broadcast.entities.dispatcher.BaseDispatcher

Class variables

var allow_quote : bool
var always
var arguments : Dict[str, ParamPattern]
var prefixs : Tuple[str]
var skip_one_at_in_quote : bool

Methods

def gen_long_map(self)
Expand source code
def gen_long_map(self):
    """Map every long option name to the parameter that declares it.

    :raises ValueError: if two patterns declare the same long name.
    """
    long_to_param = {}
    for owner_name, pattern in self.arguments.items():
        for long_name in pattern.longs:
            if long_name not in long_to_param:
                long_to_param[long_name] = owner_name
                continue
            raise ValueError("conflict item")
    return long_to_param
def gen_long_map_with_bar(self)
Expand source code
def gen_long_map_with_bar(self):
    """Return ``gen_long_map()`` keyed as getopt reports long options (``--name``)."""
    bar_map = {}
    for long_name, param_name in self.gen_long_map().items():
        bar_map["--" + long_name] = param_name
    return bar_map
def gen_short_map(self)
Expand source code
def gen_short_map(self):
    """Map every declared short option letter to its parameter name.

    Patterns without a short name are skipped, matching the getopt short
    option spec built in ``parse_message``.

    :raises ValueError: if two patterns declare the same short name.
    """
    result = {}
    for param_name, arg in self.arguments.items():
        if not arg.short:
            # No short form: several such patterns must not collide on a
            # shared ``None`` key.
            continue
        if arg.short in result:
            raise ValueError("conflict item")
        result[arg.short] = param_name
    return result
def gen_short_map_with_bar(self)
Expand source code
def gen_short_map_with_bar(self):
    """Return ``gen_short_map()`` keyed as getopt reports short options (``-x``).

    A ``None`` key (a pattern without a short name) is left out.
    """
    bar_map = {}
    for letter, param_name in self.gen_short_map().items():
        if letter is None:
            continue
        bar_map["-" + letter] = param_name
    return bar_map
def parse_message(self, message_chain: MessageChain)
Expand source code
def parse_message(self, message_chain: MessageChain):
    """Parse options and positional words out of ``message_chain``.

    NOTE: ``getopt.getopt`` stops at the first non-option word, so every
    option must come before the positional words.

    :returns: ``(parsed_args, variables)``. ``parsed_args`` maps each
        parameter name to ``(value, pattern)``: a MessageChain for a
        BoxParameter, a bool for a switch, or the pattern's default when the
        option is absent. ``variables`` is the list of positional words, each
        restored to a MessageChain.
    :raises ExecutionStop: if the text cannot be tokenised (e.g. an unclosed
        quote), getopt rejects it (unknown option, missing option value), or an
        argument whose default is None is absent.
    :raises ValueError: if two patterns declare the same option name.
    """
    string_result, id_elem_map = self.trans_to_map(message_chain)

    def restore(text: str) -> MessageChain:
        # re.split with a capturing group puts the captured placeholders at odd
        # indices; even indices hold the literal text between them.
        pieces = re.split(r"((?<!\\)\$[0-9]+)", text)
        elements = []
        for position, piece in enumerate(pieces):
            if not piece:
                continue
            if position % 2:
                # An unknown index (e.g. a literal "$5" typed outside quotes,
                # whose escaping backslash shlex removed) is kept as text
                # instead of raising KeyError.
                elements.append(id_elem_map.get(int(piece[1:]), Plain(piece)))
            else:
                elements.append(Plain(piece))
        return MessageChain.create(elements).asMerged()

    short_spec = "".join(
        arg.short if isinstance(arg, SwitchParameter) else arg.short + ":"
        for arg in self.arguments.values()
        if arg.short
    )
    long_spec = [
        long if isinstance(arg, SwitchParameter) else long + "="
        for arg in self.arguments.values()
        for long in arg.longs
    ]
    try:
        tokens = shlex.split(string_result)
        options, positionals = getopt.getopt(tokens, short_spec, long_spec)
    except (ValueError, getopt.GetoptError) as err:
        # Unbalanced quotes (shlex) or unknown/incomplete options (getopt): the
        # message is not a valid invocation of this command.
        raise ExecutionStop() from err

    map_with_bar = {**self.gen_long_map_with_bar(), **self.gen_short_map_with_bar()}

    parsed_args = {}
    for option, value in options:
        param_name = map_with_bar[option]
        setting = self.arguments[param_name]
        if isinstance(setting, BoxParameter):
            parsed_value = restore(value)
        elif setting.auto_reverse:
            # A present auto-reversing switch yields the inverse of its default
            # (the old ``a and b or True`` form always gave True).
            parsed_value = not setting.default
        else:
            parsed_value = True
        parsed_args[param_name] = (parsed_value, setting)

    variables = [restore(word) for word in positionals]

    for param_name, argument_setting in self.arguments.items():
        if param_name in parsed_args:
            continue
        if argument_setting.default is None:
            # No default to fall back on: this message does not match.
            raise ExecutionStop()
        parsed_args[param_name] = (argument_setting.default, argument_setting)

    return (parsed_args, variables)
def prefix_match(self, target_chain: MessageChain)
Expand source code
def prefix_match(self, target_chain: MessageChain):
    """Strip ``self.prefixs`` from the start of ``target_chain``.

    The merged chain is split on single spaces. Each of the first
    ``len(self.prefixs)`` frames must consist of exactly one Plain whose text
    equals the corresponding prefix.

    :returns: the remaining frames re-joined with single spaces, or None if the
        prefixes do not match.
    """
    target_chain = target_chain.asMerged()

    chain_frames: List[MessageChain] = target_chain.split(" ", raw_string=True)

    if len(self.prefixs) > len(chain_frames):
        return None
    for current_prefix, current_frame in zip(self.prefixs, chain_frames):
        frame_elements = current_frame.__root__
        # The whole frame must be the prefix. Otherwise e.g. "cmd" directly
        # followed by an image would match, and the image would be lost when
        # the prefix frames are dropped below.
        if (
            len(frame_elements) != 1
            or not isinstance(frame_elements[0], Plain)
            or frame_elements[0].text != current_prefix
        ):
            return None

    remaining_frames = chain_frames[len(self.prefixs) :]
    # Re-join the remaining frames with one space between them; the trailing
    # separator is sliced off.
    return MessageChain.create(
        list(
            itertools.chain.from_iterable(
                frame.__root__ + [Plain(" ")] for frame in remaining_frames
            )
        )[:-1]
    ).asMerged()
def trans_to_map(self, message_chain: MessageChain)
Expand source code
def trans_to_map(self, message_chain: MessageChain):
    """Encode ``message_chain`` as one string plus a table of non-text elements.

    Each non-Plain element becomes the placeholder ``$<n>`` (n = 1, 2, ...) and
    is stored under ``n`` in the returned dict. In Plain text, any existing
    ``$<digits>`` sequence is prefixed with a backslash so it is not mistaken
    for a placeholder.

    NOTE(review): ``shlex.split`` (POSIX mode) drops that backslash outside
    quotes, and a placeholder directly followed by digit text fuses into a
    different number, so the encoding is not fully reversible.

    :returns: ``(encoded_text, {placeholder_index: element})``.
    """

    def escape_dollar(found):
        # Put a backslash before the "$", keeping the digits unchanged.
        return "\\$" + found.group("id")

    parts: List[str] = []
    elements: Dict[int, Element] = {}
    for element in message_chain.__root__:
        if not isinstance(element, Plain):
            placeholder_no = len(elements) + 1
            elements[placeholder_no] = element
            parts.append(f"${placeholder_no}")
            continue
        parts.append(re.sub(r"\$(?P<id>\d+)", escape_dollar, element.text))
    return "".join(parts), elements

Inherited members