flandre/flandre/BusClient.py
2025-06-10 20:35:01 +08:00

109 lines
3.6 KiB
Python

import logging
import zmq
from zmq import Context, Socket
from flandre.utils.Msg import InterruptMsg, Msg
class BusClient:
    """Client for a ZeroMQ publish/subscribe bus on 127.0.0.1.

    Outgoing messages go through a PUB socket connected to port ``fp``;
    incoming messages arrive on a SUB socket connected to port ``bp``.
    NOTE(review): the process that binds these ports (presumably a
    forwarder/proxy) is not part of this file.

    Optionally a REQ socket (``req_socket_str``) is created for
    :meth:`req_interrupt`, a request/reply exchange that can be aborted by an
    ``InterruptMsg`` published on the bus.
    """

    # Port the PUB socket connects to.
    fp = 5001
    # Port the SUB sockets (regular and interrupt listener) connect to.
    bp = 5002

    def __init__(
        self,
        *msgs: "type[Msg]",
        ctx=None,
        pub=True,
        sub=True,
        conflare=False,
        poller=False,
        req_socket_str: str = None,
    ):
        """Create the sockets selected by the flags.

        Args:
            *msgs: Msg subclasses to subscribe to; each contributes the topic
                prefix ``msg.magic() + msg.eid()``. With none given, the SUB
                socket subscribes to every message.
            ctx: zmq context used to create sockets; a new ``zmq.Context()``
                is created when None.
            pub: create a PUB socket connected to ``fp``.
            sub: create a SUB socket connected to ``bp``.
            conflare: set ``zmq.CONFLATE`` on the SUB socket (the parameter
                name keeps its original spelling for compatibility).
            poller: create ``self.poller`` with the SUB socket registered for
                POLLIN; ignored when ``sub`` is False.
            req_socket_str: endpoint for the REQ socket used by
                :meth:`req_interrupt`; when None, no REQ or interrupt sockets
                are created.
        """
        self.ctx: Context = zmq.Context() if ctx is None else ctx

        # Optional members default to None so the attributes always exist.
        self.sub: Socket = None
        self.pub: Socket = None
        self.poller = None
        self.req_socket: Socket = None
        self.req_socket_str = req_socket_str
        self.sub_for_interrupt: Socket = None
        self.poller_for_interrupt = None

        if sub:
            self.sub = self.ctx.socket(zmq.SUB)
            for msg in msgs:
                self.sub.setsockopt(zmq.SUBSCRIBE, msg.magic() + msg.eid())
            if not msgs:
                # An empty prefix subscribes to everything.
                self.sub.setsockopt(zmq.SUBSCRIBE, b"")
            if conflare:
                self.sub.setsockopt(zmq.CONFLATE, 1)
            self.sub.connect(f"tcp://127.0.0.1:{self.bp}")
            if poller:
                # TODO: "fix poller" (note carried over from the original;
                # the intended fix is not described).
                self.poller = zmq.Poller()
                self.poller.register(self.sub, zmq.POLLIN)

        if pub:
            self.pub = self.ctx.socket(zmq.PUB)
            self.pub.connect(f"tcp://127.0.0.1:{self.fp}")

        if req_socket_str is not None:
            self.poller_for_interrupt = zmq.Poller()
            # Listens only for InterruptMsg; it is connected to the bus just
            # for the duration of each req_interrupt() call.
            self.sub_for_interrupt = self.ctx.socket(zmq.SUB)
            self.sub_for_interrupt.setsockopt(
                zmq.SUBSCRIBE, InterruptMsg.magic() + InterruptMsg.eid()
            )
            self._open_req_socket()
            self.poller_for_interrupt.register(self.sub_for_interrupt, zmq.POLLIN)

    def _open_req_socket(self):
        """Create a REQ socket, register it for POLLIN and connect it to req_socket_str."""
        self.req_socket = self.ctx.socket(zmq.REQ)
        self.poller_for_interrupt.register(self.req_socket, zmq.POLLIN)
        self.req_socket.connect(self.req_socket_str)

    def _reset_req_socket(self):
        """Discard the current REQ socket (and its pending request) and open a new one."""
        self.poller_for_interrupt.unregister(self.req_socket)
        # linger=0 drops the unanswered request instead of keeping it queued
        # after close; the request is re-sent on the fresh socket anyway.
        self.req_socket.close(linger=0)
        self._open_req_socket()

    def recv(self):
        """Block until a message arrives on the SUB socket and return it decoded."""
        return Msg.decode_msg(self.sub.recv())

    def poll(self, timeout):
        """Wait up to ``timeout`` for a message; return it decoded, or None.

        ``timeout`` is forwarded to zmq's ``Socket.poll`` (milliseconds).
        """
        if self.sub.poll(timeout) != 0:
            return self.recv()
        return None

    def send(self, msg: "Msg"):
        """Encode ``msg`` and publish it; returns the result of ``pub.send``."""
        return self.pub.send(msg.encode_msg())

    async def recv_async(self):
        """Await a message on the SUB socket and return it decoded.

        The SUB socket's ``recv()`` must return an awaitable, e.g. a socket
        created from a ``zmq.asyncio.Context`` passed as ``ctx``.
        """
        b = await self.sub.recv()
        return Msg.decode_msg(b)

    async def send_async(self, msg: "Msg"):
        """Encode ``msg`` and publish it.

        NOTE(review): the result of ``pub.send`` is returned without being
        awaited; with a zmq.asyncio socket the caller receives that result.
        """
        return self.pub.send(msg.encode_msg())

    def req_interrupt(
        self,
        data: bytes,
        interrupt_name: str,
        timeout=3000,
        retry_times=114514,
        cb_retry=None,
    ):
        """Send ``data`` on the REQ socket and wait for the reply, retrying on timeout.

        While waiting, ``sub_for_interrupt`` is connected to the bus; an
        ``InterruptMsg`` whose ``value`` equals ``interrupt_name`` aborts the
        request. The interrupt listener is always disconnected on exit, even
        if an exception propagates.

        Each attempt polls for up to ``timeout`` (passed to
        ``zmq.Poller.poll``). If no reply is ready, ``cb_retry`` is called
        (when given) and the REQ socket is recreated, because a REQ socket
        cannot send again before receiving a reply. This reset also happens
        when the poll was woken by an interrupt message, which leaves the REQ
        socket usable after an abort.

        Args:
            data: request payload.
            interrupt_name: ``InterruptMsg.value`` that aborts this request.
            timeout: per-attempt poll timeout.
            retry_times: maximum number of send attempts.
            cb_retry: optional zero-argument callable invoked before each retry.

        Returns:
            The reply bytes; None if interrupted; the string ``"timeout"``
            after ``retry_times`` attempts without a reply.

        Requires the client to have been created with ``req_socket_str``.
        """
        endpoint = f"tcp://127.0.0.1:{self.bp}"
        self.sub_for_interrupt.connect(endpoint)
        try:
            for _ in range(retry_times):
                self.req_socket.send(data)
                ready = dict(self.poller_for_interrupt.poll(timeout))
                if self.req_socket in ready:
                    return self.req_socket.recv()
                if cb_retry is not None:
                    cb_retry()
                self._reset_req_socket()
                if self.sub_for_interrupt in ready:
                    msg = Msg.decode_msg(self.sub_for_interrupt.recv())
                    if isinstance(msg, InterruptMsg) and msg.value == interrupt_name:
                        return None
            logging.warning("timeout")
            return "timeout"
        finally:
            # Detach from the bus on every exit path, including exceptions,
            # so the listener is not left connected between requests.
            self.sub_for_interrupt.disconnect(endpoint)