]> average.org Git - loctrkd.git/commitdiff
collector: close old connection on new login
authorEugene Crosser <crosser@average.org>
Wed, 13 Jul 2022 21:32:48 +0000 (23:32 +0200)
committerEugene Crosser <crosser@average.org>
Thu, 14 Jul 2022 20:48:05 +0000 (22:48 +0200)
loctrkd/collector.py

index 22bc9c3ffa2014346d74fd13b19de48d6d098472..26d37f33f60af39a123c8aafa043c50244125ef0 100644 (file)
@@ -14,7 +14,7 @@ from socket import (
 )
 from struct import pack
 from time import time
-from typing import Any, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union
 import zmq
 
 from . import common
@@ -151,6 +151,7 @@ class Clients:
     def __init__(self) -> None:
         self.by_fd: Dict[int, Client] = {}
         self.by_imei: Dict[str, Client] = {}
+        self.tostop: Set[int] = set()
 
     def add(self, clntsock: socket, clntaddr: Tuple[str, int]) -> int:
         fd = clntsock.fileno()
@@ -162,7 +163,7 @@ class Clients:
         clnt = self.by_fd[fd]
         log.info("Stop serving fd %d (IMEI %s)", clnt.sock.fileno(), clnt.imei)
         clnt.close()
-        if clnt.imei:
+        if clnt.imei and self.by_imei[clnt.imei] == clnt:  # could be replaced
             del self.by_imei[clnt.imei]
         del self.by_fd[fd]
 
@@ -185,18 +186,11 @@ class Clients:
                     clnt.imei = imei
                     oldclnt = self.by_imei.get(clnt.imei)
                     if oldclnt is not None:
-                        log.info(
-                            "Orphaning fd %d with the same IMEI",
-                            oldclnt.sock.fileno(),
-                        )
+                        oldfd = oldclnt.sock.fileno()
+                        log.info("Removing stale connection on fd %d", oldfd)
                         oldclnt.imei = None
+                        self.tostop.add(oldfd)
                     self.by_imei[clnt.imei] = clnt
-                else:
-                    log.warning(
-                        "Login message from %s: %s, but client imei unfilled",
-                        peeraddr,
-                        packet,
-                    )
             result.append((clnt.pmod, clnt.imei, when, peeraddr, packet))
             log.debug(
                 "Received from %s (IMEI %s): %s",
@@ -243,7 +237,7 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
         while True:
             tosend = []
             topoll = []
-            tostop = []
+            clients.tostop = set()
             events = poller.poll(1000)
             for sk, fl in events:
                 if sk is zpull:
@@ -262,7 +256,7 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                     received = clients.recv(sk)
                     if received is None:
                         log.debug("Terminal gone from fd %d", sk)
-                        tostop.append(sk)
+                        clients.tostop.add(sk)
                     else:
                         for pmod, imei, when, peeraddr, packet in received:
                             proto = pmod.proto_of_message(packet)
@@ -284,7 +278,7 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                                     sk,
                                     imei,
                                 )
-                                tostop.append(sk)
+                                clients.tostop.add(sk)
                             respmsg = pmod.inline_response(packet)
                             if respmsg is not None:
                                 tosend.append(
@@ -306,7 +300,7 @@ def runserver(conf: ConfigParser, handle_hibernate: bool = True) -> None:
                             packet=zmsg.packet,
                         ).packed
                     )
-            for fd in tostop:
+            for fd in clients.tostop:
                 poller.unregister(fd)  # type: ignore
                 clients.stop(fd)
             for clntsock, clntaddr in topoll: