X-Git-Url: http://average.org/gitweb/?a=blobdiff_plain;f=loctrkd%2Fcollector.py;h=788cb11ca1ee8b186838ea1ef882a5b6faf4ea83;hb=1a544ee9aa7429cd1a6a3a819c111c1b58d02507;hp=136ecba99aeeb602ee85028caadc0caa0caf744d;hpb=63a086cf3956b93f760b1a0344afd757e0d0392f;p=loctrkd.git diff --git a/loctrkd/collector.py b/loctrkd/collector.py index 136ecba..788cb11 100644 --- a/loctrkd/collector.py +++ b/loctrkd/collector.py @@ -14,10 +14,11 @@ from socket import ( ) from struct import pack from time import time -from typing import Any, cast, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import zmq from . import common +from .protomodule import ProtoModule from .zmsg import Bcast, Resp log = getLogger("loctrkd/collector") @@ -25,54 +26,10 @@ log = getLogger("loctrkd/collector") MAXBUFFER: int = 4096 -class ProtoModule: - class Stream: - @staticmethod - def enframe(buffer: bytes, imei: Optional[str] = None) -> bytes: - ... - - def recv(self, segment: bytes) -> List[Union[bytes, str]]: - ... - - def close(self) -> bytes: - ... - - @staticmethod - def probe_buffer(buffer: bytes) -> bool: - ... - - @staticmethod - def parse_message(packet: bytes, is_incoming: bool = True) -> Any: - ... - - @staticmethod - def inline_response(packet: bytes) -> Optional[bytes]: - ... - - @staticmethod - def is_goodbye_packet(packet: bytes) -> bool: - ... - - @staticmethod - def imei_from_packet(packet: bytes) -> Optional[str]: - ... - - @staticmethod - def proto_of_message(packet: bytes) -> str: - ... - - @staticmethod - def proto_by_name(name: str) -> int: - ... - - -pmods: List[ProtoModule] = [] - - class Client: """Connected socket to the terminal plus buffer and metadata""" - def __init__(self, sock: socket, addr: Tuple[str, int]) -> None: + def __init__(self, sock: socket, addr: Any) -> None: self.sock = sock self.addr = addr self.pmod: Optional[ProtoModule] = None @@ -91,7 +48,7 @@ class Client: "%d bytes in buffer on close: %s", len(rest), rest[:64].hex() ) - def recv(self) -> Optional[List[Tuple[float, Tuple[str, int], bytes]]]: + def recv(self) -> Optional[List[Tuple[float, Any, bytes]]]: """Read from the socket and parse complete messages""" try: segment = self.sock.recv(MAXBUFFER) @@ -111,11 +68,9 @@ class Client: ) return None if self.stream is None: - for pmod in pmods: - if pmod.probe_buffer(segment): - self.pmod = pmod - self.stream = pmod.Stream() - break + self.pmod = common.probe_pmod(segment) + if self.pmod is not None: + self.stream = self.pmod.Stream() if self.stream is None: log.info( "unrecognizable %d bytes of data %s from fd %d", @@ -139,9 +94,9 @@ class Client: return msgs def send(self, buffer: bytes) -> None: - assert self.stream is not None + assert self.stream is not None and self.pmod is not None try: - self.sock.send(self.stream.enframe(buffer, imei=self.imei)) + self.sock.send(self.pmod.enframe(buffer, imei=self.imei)) except OSError as e: log.error( "Sending to fd %d (IMEI %s): %s", @@ -156,25 +111,32 @@ class Clients: self.by_fd: Dict[int, Client] = {} self.by_imei: Dict[str, Client] = {} - def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int: + def fds(self) -> Set[int]: + return set(self.by_fd.keys()) + + def add(self, clntsock: socket, clntaddr: Any) -> int: fd = clntsock.fileno() log.info("Start serving fd %d from %s", fd, clntaddr) self.by_fd[fd] = Client(clntsock, clntaddr) return fd def stop(self, fd: int) -> None: + if fd not in self.by_fd: + log.debug("Fd %d is not served, ingore stop", fd) + return clnt = self.by_fd[fd] log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei) clnt.close() - if clnt.imei: + if clnt.imei and self.by_imei[clnt.imei] == clnt: # could be replaced del self.by_imei[clnt.imei] del self.by_fd[fd] def recv( self, fd: int - ) -> Optional[ - List[Tuple[ProtoModule, Optional[str], float, Tuple[str, int], bytes]] - ]: + ) -> Optional[List[Tuple[ProtoModule, Optional[str], float, Any, bytes]]]: + if fd not in self.by_fd: + log.debug("Client at fd %d gone, ingore event", fd) + return None clnt = self.by_fd[fd] msgs = clnt.recv() if msgs is None: @@ -189,18 +151,11 @@ class Clients: clnt.imei = imei oldclnt = self.by_imei.get(clnt.imei) if oldclnt is not None: - log.info( - "Orphaning fd %d with the same IMEI", - oldclnt.sock.fileno(), - ) + oldfd = oldclnt.sock.fileno() + log.info("Removing stale connection on fd %d", oldfd) oldclnt.imei = None + self.stop(oldfd) self.by_imei[clnt.imei] = clnt - else: - log.warning( - "Login message from %s: %s, but client imei unfilled", - peeraddr, - packet, - ) result.append((clnt.pmod, clnt.imei, when, peeraddr, packet)) log.debug( "Received from %s (IMEI %s): %s", @@ -221,11 +176,6 @@ class Clients: def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None: - global pmods - pmods = [ - cast(ProtoModule, import_module("." + modnm, __package__)) - for modnm in conf.get("collector", "protocols").split(",") - ] # Is this https://github.com/zeromq/pyzmq/issues/1627 still not fixed?! zctx = zmq.Context() # type: ignore zpub = zctx.socket(zmq.PUB) # type: ignore @@ -243,11 +193,11 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None: poller.register(zpull, flags=zmq.POLLIN) poller.register(tcpfd, flags=zmq.POLLIN) clients = Clients() + pollingfds: Set[int] = set() try: while True: - tosend = [] - topoll = [] - tostop = [] + tosend: List[Resp] = [] + toadd: List[Tuple[socket, Any]] = [] events = poller.poll(1000) for sk, fl in events: if sk is zpull: @@ -261,12 +211,12 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None: elif sk == tcpfd: clntsock, clntaddr = tcpl.accept() clntsock.setsockopt(SOL_SOCKET, SO_KEEPALIVE, 1) - topoll.append((clntsock, clntaddr)) + toadd.append((clntsock, clntaddr)) elif fl & zmq.POLLIN: received = clients.recv(sk) if received is None: log.debug("Terminal gone from fd %d", sk) - tostop.append(sk) + clients.stop(sk) else: for pmod, imei, when, peeraddr, packet in received: proto = pmod.proto_of_message(packet) @@ -288,7 +238,7 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None: sk, imei, ) - tostop.append(sk) + clients.stop(sk) respmsg = pmod.inline_response(packet) if respmsg is not None: tosend.append( @@ -310,12 +260,13 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None: packet=zmsg.packet, ).packed ) - for fd in tostop: + for fd in pollingfds - clients.fds(): poller.unregister(fd) # type: ignore - clients.stop(fd) - for clntsock, clntaddr in topoll: + for clntsock, clntaddr in toadd: fd = clients.add(clntsock, clntaddr) + for fd in clients.fds() - pollingfds: poller.register(fd, flags=zmq.POLLIN) + pollingfds = clients.fds() except KeyboardInterrupt: zpub.close() zpull.close()