]> average.org Git - loctrkd.git/commitdiff
wsgateway properly handle write-busy websockets
authorEugene Crosser <crosser@average.org>
Fri, 6 May 2022 10:50:28 +0000 (12:50 +0200)
committerEugene Crosser <crosser@average.org>
Fri, 6 May 2022 10:50:28 +0000 (12:50 +0200)
gps303/wsgateway.py

index a4a0b6d441d31d8a42800adc3e3d77eebd196547..127dbb4b002fc8c62dce6be2cd42e35b6df54e3e 100644 (file)
@@ -60,7 +60,7 @@ def try_http(data, fd, e):
                 return (
                     f"{proto} 404 File not found\r\n"
                     f"Content-Type: text/plain\r\n\r\n"
-                    f"We can only serve \"/\"\r\n".encode()
+                    f'We can only serve "/"\r\n'.encode()
                 )
         else:
             return (
@@ -143,16 +143,18 @@ class Client:
         self.ws_data += self.ws.send(Message(data=message.json))
 
     def write(self):
-        try:
-            sent = self.sock.send(self.ws_data)
-            self.ws_data = self.ws_data[sent:]
-        except OSError as e:
-            log.error(
-                "Sending to fd %d: %s",
-                self.sock.fileno(),
-                e,
-            )
-            self.ws_data = b""
+        if self.ws_data:
+            try:
+                sent = self.sock.send(self.ws_data)
+                self.ws_data = self.ws_data[sent:]
+            except OSError as e:
+                log.error(
+                    "Sending to fd %d: %s",
+                    self.sock.fileno(),
+                    e,
+                )
+                self.ws_data = b""
+        return bool(self.ws_data)
 
 
 class Clients:
@@ -182,18 +184,25 @@ class Clients:
         return result
 
     def send(self, msg):
-        for clnt in self.by_fd.values():
+        towrite = set()
+        for fd, clnt in self.by_fd.items():
             if clnt.wants(msg.imei):
                 clnt.send(msg)
-                clnt.write()
+                towrite.add(fd)
+        return towrite
+
+    def write(self, towrite):
+        waiting = set()
+        for fd, clnt in [(fd, self.by_fd.get(fd)) for fd in towrite]:
+            if clnt.write():
+                waiting.add(fd)
+        return waiting
 
 
 def runserver(conf):
     global htmldata
     try:
-        with open(
-            conf.get("wsgateway", "htmlfile"), encoding="utf-8"
-        ) as fl:
+        with open(conf.get("wsgateway", "htmlfile"), encoding="utf-8") as fl:
             htmldata = fl.read()
     except OSError:
         pass
@@ -212,10 +221,12 @@ def runserver(conf):
     poller.register(tcpfd, flags=zmq.POLLIN)
     clients = Clients()
     try:
+        towait = set()
         while True:
             tosend = []
             topoll = []
             tostop = []
+            towrite = set()
             events = poller.poll(5000)
             for sk, fl in events:
                 if sk is zsub:
@@ -236,6 +247,10 @@ def runserver(conf):
                     else:
                         for msg in received:
                             log.debug("Received from %d: %s", sk, msg)
+                elif fl & zmq.POLLOUT:
+                    log.debug("Write now open for fd %d", sk)
+                    towrite.add(sk)
+                    towait.discard(sk)
                 else:
                     log.debug("Stray event: %s on socket %s", fl, sk)
             # poll queue consumed, make changes now
@@ -243,12 +258,26 @@ def runserver(conf):
                 poller.unregister(fd)
                 clients.stop(fd)
             for zmsg in tosend:
-                log.debug("Sending to the client: %s", zmsg)
-                clients.send(zmsg)
+                log.debug("Sending to the clients: %s", zmsg)
+                towrite |= clients.send(zmsg)
             for clntsock, clntaddr in topoll:
                 fd = clients.add(clntsock, clntaddr)
                 poller.register(fd, flags=zmq.POLLIN)
-            # TODO: Handle write overruns (register for POLLOUT)
+            # Deal with actually writing the data out
+            trywrite = towrite - towait
+            morewait = clients.write(trywrite)
+            log.debug(
+                "towait %s, tried %s, still busy %s",
+                towait,
+                trywrite,
+                morewait,
+            )
+            for fd in morewait - trywrite:  # new fds waiting for write
+                poller.modify(fd, flags=zmq.POLLIN | zmq.POLLOUT)
+            for fd in trywrite - morewait:  # no longer waiting for write
+                poller.modify(fd, flags=zmq.POLLIN)
+            towait &= trywrite
+            towait |= morewait
     except KeyboardInterrupt:
         pass