From: Eugene Crosser Date: Wed, 27 Jul 2022 22:41:54 +0000 (+0200) Subject: abstract protocol selection in `common` X-Git-Tag: 1.90~8 X-Git-Url: http://average.org/gitweb/?a=commitdiff_plain;h=023da3cd78841eb34d8286cf289995be658f0fa2;p=loctrkd.git abstract protocol selection in `common` --- diff --git a/loctrkd/__main__.py b/loctrkd/__main__.py index bba38ff..8808c8a 100644 --- a/loctrkd/__main__.py +++ b/loctrkd/__main__.py @@ -17,17 +17,9 @@ from .zmsg import Bcast, Resp log = getLogger("loctrkd") -pmods: List[ProtoModule] = [] - - def main( conf: ConfigParser, opts: List[Tuple[str, str]], args: List[str] ) -> None: - global pmods - pmods = [ - cast(ProtoModule, import_module("." + modnm, __package__)) - for modnm in conf.get("common", "protocols").split(",") - ] # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?! zctx = zmq.Context() # type: ignore zpush = zctx.socket(zmq.PUSH) # type: ignore @@ -40,12 +32,8 @@ def main( imei = args[0] cmd = args[1] args = args[2:] - handled = False - for pmod in pmods: - if pmod.proto_handled(cmd): - handled = True - break - if not handled: + pmod = common.pmod_for_proto(cmd) + if pmod is None: raise NotImplementedError(f"No protocol can handle {cmd}") cls = pmod.class_by_prefix(cmd) if isinstance(cls, list): diff --git a/loctrkd/collector.py b/loctrkd/collector.py index 17c98d5..788cb11 100644 --- a/loctrkd/collector.py +++ b/loctrkd/collector.py @@ -14,7 +14,7 @@ from socket import ( ) from struct import pack from time import time -from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import zmq from . import common @@ -26,9 +26,6 @@ log = getLogger("loctrkd/collector") MAXBUFFER: int = 4096 -pmods: List[ProtoModule] = [] - - class Client: """Connected socket to the terminal plus buffer and metadata""" @@ -71,11 +68,9 @@ class Client: ) return None if self.stream is None: - for pmod in pmods: - if pmod.probe_buffer(segment): - self.pmod = pmod - self.stream = pmod.Stream() - break + self.pmod = common.probe_pmod(segment) + if self.pmod is not None: + self.stream = self.pmod.Stream() if self.stream is None: log.info( "unrecognizable %d bytes of data %s from fd %d", @@ -181,11 +176,6 @@ class Clients: def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None: - global pmods - pmods = [ - cast(ProtoModule, import_module("." + modnm, __package__)) - for modnm in conf.get("common", "protocols").split(",") - ] # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?! zctx = zmq.Context() # type: ignore zpub = zctx.socket(zmq.PUB) # type: ignore diff --git a/loctrkd/common.py b/loctrkd/common.py index 227611c..02b520f 100644 --- a/loctrkd/common.py +++ b/loctrkd/common.py @@ -1,16 +1,18 @@ """ Common housekeeping for all daemons """ -from configparser import ConfigParser, SectionProxy +from configparser import ConfigParser +from importlib import import_module from getopt import getopt from logging import Formatter, getLogger, Logger, StreamHandler, DEBUG, INFO from logging.handlers import SysLogHandler from pkg_resources import get_distribution, DistributionNotFound from sys import argv, stderr, stdout -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Dict, List, Optional, Tuple, Union + +from .protomodule import ProtoModule CONF = "/etc/loctrkd.conf" -PORT = 4303 -DBFN = "/var/lib/loctrkd/loctrkd.sqlite" +pmods: List[ProtoModule] = [] try: version = get_distribution("loctrkd").version @@ -18,13 +20,22 @@ except DistributionNotFound: version = "" +def init_protocols(conf: ConfigParser) -> None: + global pmods + pmods = [ + cast(ProtoModule, import_module("." + modnm, __package__)) + for modnm in conf.get("common", "protocols").split(",") + ] + + def init( log: Logger, opts: Optional[List[Tuple[str, str]]] = None ) -> ConfigParser: if opts is None: opts, _ = getopt(argv[1:], "c:d") dopts = dict(opts) - conf = readconfig(dopts["-c"] if "-c" in dopts else CONF) + conf = ConfigParser() + conf.read(dopts["-c"] if "-c" in dopts else CONF) log.setLevel(DEBUG if "-d" in dopts else INFO) if stdout.isatty(): fhdl = StreamHandler(stderr) @@ -40,59 +51,19 @@ def init( ) log.addHandler(lhdl) log.info("%s starting with options: %s", version, dopts) + init_protocols(conf) return conf -def readconfig(fname: str) -> ConfigParser: - config = ConfigParser() - config["collector"] = { - "port": str(PORT), - } - config["storage"] = { - "dbfn": DBFN, - } - config["termconfig"] = {} - config.read(fname) - return config - - -def normconf(section: SectionProxy) -> Dict[str, Any]: - result: Dict[str, Any] = {} - for key, val in section.items(): - vals = val.split("\n") - if len(vals) > 1 and vals[0] == "": - vals = vals[1:] - lst: List[Union[str, int]] = [] - for el in vals: - try: - lst.append(int(el, 0)) - except ValueError: - if el[0] == '"' and el[-1] == '"': - el = el.strip('"').rstrip('"') - lst.append(el) - if not ( - all([isinstance(x, int) for x in lst]) - or all([isinstance(x, str) for x in lst]) - ): - raise ValueError( - "Values of %s - %s are of different type", key, vals - ) - if len(lst) == 1: - result[key] = lst[0] - else: - result[key] = lst - return result - - -if __name__ == "__main__": - from sys import argv +def probe_pmod(segment: bytes) -> Optional[ProtoModule]: + for pmod in pmods: + if pmod.probe_buffer(segment): + return pmod + return None - def _print_config(conf: ConfigParser) -> None: - for section in conf.sections(): - print("section", section) - for option in conf.options(section): - print(" ", option, conf[section][option]) - conf = readconfig(argv[1]) - _print_config(conf) - print(normconf(conf["termconfig"])) +def pmod_for_proto(proto: str) -> Optional[ProtoModule]: + for pmod in pmods: + if pmod.proto_handled(proto): + return pmod + return None diff --git a/loctrkd/mkgpx.py b/loctrkd/mkgpx.py index 35ff77e..6d1ee27 100644 --- a/loctrkd/mkgpx.py +++ b/loctrkd/mkgpx.py @@ -19,17 +19,9 @@ from .protomodule import ProtoModule log = getLogger("loctrkd/mkgpx") -pmods: List[ProtoModule] = [] - - def main( conf: ConfigParser, opts: List[Tuple[str, str]], args: List[str] ) -> None: - global pmods - pmods = [ - cast(ProtoModule, import_module("." + modnm, __package__)) - for modnm in conf.get("common", "protocols").split(",") - ] db = connect(conf.get("storage", "dbfn")) c = db.cursor() c.execute( @@ -52,9 +44,9 @@ def main( ) for tstamp, is_incoming, proto, packet in c: - for pmod in pmods: - if pmod.proto_handled(proto): - msg = pmod.parse_message(packet, is_incoming=is_incoming) + pmod = common.pmod_for_proto(proto) + if pmod is not None: + msg = pmod.parse_message(packet, is_incoming=is_incoming) lat, lon = msg.latitude, msg.longitude isotime = ( datetime.fromtimestamp(tstamp) diff --git a/loctrkd/termconfig.py b/loctrkd/termconfig.py index b1ea80a..4968e28 100644 --- a/loctrkd/termconfig.py +++ b/loctrkd/termconfig.py @@ -1,9 +1,10 @@ """ For when responding to the terminal is not trivial """ -from configparser import ConfigParser +from configparser import ConfigParser, SectionProxy from datetime import datetime, timezone from logging import getLogger from struct import pack +from typing import Any, Dict, List, Union import zmq from . import common @@ -14,6 +15,34 @@ from .zmsg import Bcast, Resp, topic log = getLogger("loctrkd/termconfig") +def normconf(section: SectionProxy) -> Dict[str, Any]: + result: Dict[str, Any] = {} + for key, val in section.items(): + vals = val.split("\n") + if len(vals) > 1 and vals[0] == "": + vals = vals[1:] + lst: List[Union[str, int]] = [] + for el in vals: + try: + lst.append(int(el, 0)) + except ValueError: + if el[0] == '"' and el[-1] == '"': + el = el.strip('"').rstrip('"') + lst.append(el) + if not ( + all([isinstance(x, int) for x in lst]) + or all([isinstance(x, str) for x in lst]) + ): + raise ValueError( + "Values of %s - %s are of different type", key, vals + ) + if len(lst) == 1: + result[key] = lst[0] + else: + result[key] = lst + return result + + def runserver(conf: ConfigParser) -> None: # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?! zctx = zmq.Context() # type: ignore @@ -44,9 +73,9 @@ def runserver(conf: ConfigParser) -> None: "%s does not expect externally provided response", msg ) if zmsg.imei is not None and conf.has_section(zmsg.imei): - termconfig = common.normconf(conf[zmsg.imei]) + termconfig = normconf(conf[zmsg.imei]) elif conf.has_section("termconfig"): - termconfig = common.normconf(conf["termconfig"]) + termconfig = normconf(conf["termconfig"]) else: termconfig = {} kwargs = {} diff --git a/loctrkd/watch.py b/loctrkd/watch.py index 6d3dcd9..bda952c 100644 --- a/loctrkd/watch.py +++ b/loctrkd/watch.py @@ -14,15 +14,7 @@ from .zmsg import Bcast log = getLogger("loctrkd/watch") -pmods: List[ProtoModule] = [] - - def runserver(conf: ConfigParser) -> None: - global pmods - pmods = [ - cast(ProtoModule, import_module("." + modnm, __package__)) - for modnm in conf.get("common", "protocols").split(",") - ] # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?! zctx = zmq.Context() # type: ignore zsub = zctx.socket(zmq.SUB) # type: ignore @@ -33,12 +25,12 @@ def runserver(conf: ConfigParser) -> None: while True: zmsg = Bcast(zsub.recv()) print("I" if zmsg.is_incoming else "O", zmsg.proto, zmsg.imei) - for pmod in pmods: - if pmod.proto_handled(zmsg.proto): - msg = pmod.parse_message(zmsg.packet, zmsg.is_incoming) - print(msg) - if zmsg.is_incoming and hasattr(msg, "rectified"): - print(msg.rectified()) + pmod = common.pmod_for_proto(zmsg.proto) + if pmod is not None: + msg = pmod.parse_message(zmsg.packet, zmsg.is_incoming) + print(msg) + if zmsg.is_incoming and hasattr(msg, "rectified"): + print(msg.rectified()) except KeyboardInterrupt: pass diff --git a/test/common.py b/test/common.py index 284ca4d..58954a2 100644 --- a/test/common.py +++ b/test/common.py @@ -23,6 +23,8 @@ from time import sleep from typing import Optional from unittest import TestCase +from loctrkd.common import init_protocols + NUMPORTS = 3 @@ -61,6 +63,7 @@ class TestWithServers(TestCase): self.conf["wsgateway"] = { "port": str(freeports[1]), } + init_protocols(self.conf) self.children = [] for srvname in args: if srvname == "collector":