109 lines
3.6 KiB
Python
109 lines
3.6 KiB
Python
import logging
|
|
|
|
import zmq
|
|
from zmq import Context, Socket
|
|
|
|
from flandre.utils.Msg import InterruptMsg, Msg
|
|
|
|
|
|
class BusClient:
    """Client for a ZeroMQ message bus reachable on ``127.0.0.1``.

    Depending on the constructor flags an instance owns:

    * ``self.pub`` - a PUB socket connected to port ``fp``;
    * ``self.sub`` - a SUB socket connected to port ``bp``, optionally with
      a ``zmq.Poller`` in ``self.poller``;
    * ``self.req_socket`` - a REQ socket used by :meth:`req_interrupt`,
      together with a second SUB socket that listens for ``InterruptMsg``
      while a request is outstanding.

    Attributes that were not requested are left as ``None``.
    """

    # Ports on 127.0.0.1 that PUB / SUB sockets connect to.
    # NOTE(review): presumably the front/back ports of a bus proxy - confirm.
    fp = 5001
    bp = 5002

    def __init__(
        self,
        *msgs: "type[Msg]",
        ctx=None,
        pub=True,
        sub=True,
        conflare=False,
        poller=False,
        req_socket_str: str = None,
    ):
        """Create and connect the sockets selected by the flags.

        Args:
            *msgs: Message classes to subscribe to. Each one contributes the
                subscription prefix ``msg.magic() + msg.eid()``. When no
                class is given the subscriber receives every message.
            ctx: ZeroMQ context used to create the sockets. A new
                ``zmq.Context()`` is created when omitted.
            pub: Create ``self.pub`` and connect it to port ``fp``.
            sub: Create ``self.sub`` and connect it to port ``bp``.
            conflare: Set ``zmq.CONFLATE`` on the subscriber. (The parameter
                name is kept as-is for backward compatibility.)
            poller: Also create ``self.poller`` with the subscriber
                registered for ``zmq.POLLIN``. Ignored when ``sub`` is false
                because there is no socket to register.
            req_socket_str: Endpoint the REQ socket connects to. ``None``
                disables the request channel and :meth:`req_interrupt`.
        """
        self.ctx: Context = zmq.Context() if ctx is None else ctx

        # Define every optional attribute up front so that a disabled
        # feature is an explicit None rather than a missing attribute.
        self.sub: Socket = None
        self.pub: Socket = None
        self.poller = None
        self.req_socket: Socket = None
        self.req_socket_str = req_socket_str
        self.sub_for_interrupt: Socket = None
        self.poller_for_interrupt = None

        if sub:
            self.sub = self.ctx.socket(zmq.SUB)
            for msg in msgs:
                self.sub.setsockopt(zmq.SUBSCRIBE, msg.magic() + msg.eid())
            if not msgs:
                # An empty prefix subscribes to everything.
                self.sub.setsockopt(zmq.SUBSCRIBE, b"")
            if conflare:
                # Socket options are applied before connect().
                self.sub.setsockopt(zmq.CONFLATE, 1)
            self.sub.connect(f"tcp://127.0.0.1:{self.bp}")
            if poller:
                self.poller = zmq.Poller()
                self.poller.register(self.sub, zmq.POLLIN)

        if pub:
            self.pub = self.ctx.socket(zmq.PUB)
            self.pub.connect(f"tcp://127.0.0.1:{self.fp}")

        if req_socket_str is not None:
            self.poller_for_interrupt = zmq.Poller()
            self.sub_for_interrupt = self.ctx.socket(zmq.SUB)
            self.sub_for_interrupt.setsockopt(
                zmq.SUBSCRIBE, InterruptMsg.magic() + InterruptMsg.eid()
            )
            # The interrupt subscriber is deliberately not connected here:
            # req_interrupt() connects it only while a request is in flight.
            self.req_socket = self.ctx.socket(zmq.REQ)
            self.poller_for_interrupt.register(self.req_socket, zmq.POLLIN)
            self.poller_for_interrupt.register(self.sub_for_interrupt, zmq.POLLIN)
            self.req_socket.connect(req_socket_str)
            # TODO: fix poller

    def recv(self):
        """Block until a message arrives on ``self.sub`` and return it decoded."""
        return Msg.decode_msg(self.sub.recv())

    def poll(self, timeout):
        """Wait up to ``timeout`` for a message on ``self.sub``.

        Returns:
            The decoded message, or ``None`` when the poll reports that
            nothing is ready.
        """
        if self.sub.poll(timeout) == 0:
            return None
        return self.recv()

    def send(self, msg: "Msg"):
        """Publish ``msg`` on ``self.pub`` and return the socket's send result."""
        return self.pub.send(msg.encode_msg())

    async def recv_async(self):
        """Await a message on ``self.sub`` and return it decoded.

        ``self.sub.recv()`` is awaited, so the client must have been built
        with a context whose sockets return awaitables (for example
        ``zmq.asyncio.Context``).
        """
        raw = await self.sub.recv()
        return Msg.decode_msg(raw)

    async def send_async(self, msg: "Msg"):
        """Publish ``msg`` on ``self.pub``, awaiting the send when it is awaitable.

        With an asyncio socket ``send()`` returns an awaitable; it is awaited
        here so that send errors propagate to the caller instead of being
        left on a discarded future. With a blocking socket the plain result
        of ``send()`` is returned unchanged.
        """
        import inspect  # local import: only this method needs it

        outcome = self.pub.send(msg.encode_msg())
        if inspect.isawaitable(outcome):
            return await outcome
        return outcome

    def req_interrupt(
        self,
        data: bytes,
        interrupt_name: str,
        timeout=3000,
        retry_times=114514,
        cb_retry=None,
    ):
        """Send a request and wait for its reply, abortable by an interrupt.

        The interrupt subscriber is connected to port ``bp`` for the duration
        of the call and is always disconnected again, including when an
        exception escapes.

        Each attempt sends ``data`` on the REQ socket and polls both the REQ
        socket and the interrupt subscriber. If the poll ends without a reply
        (timeout or interrupt traffic), ``cb_retry`` is invoked, the REQ
        socket is replaced, and the request is retried unless a matching
        interrupt arrived.

        Args:
            data: Raw request payload.
            interrupt_name: An ``InterruptMsg`` whose ``value`` equals this
                name aborts the request.
            timeout: Per-attempt timeout handed to ``Poller.poll``
                (milliseconds in pyzmq).
            retry_times: Maximum number of attempts.
            cb_retry: Optional zero-argument callable invoked after every
                attempt that produced no reply.

        Returns:
            The reply bytes on success, ``None`` when interrupted, or the
            string ``"timeout"`` once all attempts are used up.
        """
        bus_endpoint = f"tcp://127.0.0.1:{self.bp}"
        self.sub_for_interrupt.connect(bus_endpoint)
        try:
            for _ in range(retry_times):
                self.req_socket.send(data)
                ready = dict(self.poller_for_interrupt.poll(timeout))
                if self.req_socket in ready:
                    return self.req_socket.recv()

                if cb_retry is not None:
                    cb_retry()
                # A REQ socket is strictly send/recv lock-step, so after an
                # unanswered request it has to be replaced before sending
                # again (also before returning on an interrupt, so the next
                # call starts from a usable socket).
                self._reset_req_socket()

                if self.sub_for_interrupt in ready:
                    msg = Msg.decode_msg(self.sub_for_interrupt.recv())
                    if isinstance(msg, InterruptMsg) and msg.value == interrupt_name:
                        return None

            logging.warning("timeout")
            return "timeout"
        finally:
            # Single exit point for the temporary interrupt subscription.
            self.sub_for_interrupt.disconnect(bus_endpoint)

    def _reset_req_socket(self):
        """Replace the REQ socket with a fresh one on the same endpoint."""
        self.poller_for_interrupt.unregister(self.req_socket)
        # linger=0 discards the unanswered request instead of leaving it
        # queued on the closed socket, where it could still be delivered
        # later (duplicating the retry) and delay context termination.
        self.req_socket.close(linger=0)
        self.req_socket = self.ctx.socket(zmq.REQ)
        self.poller_for_interrupt.register(self.req_socket, zmq.POLLIN)
        self.req_socket.connect(self.req_socket_str)
|