Source code for jps.subscriber

import threading
import time
import types

import zmq
from zmq.utils.strtypes import cast_bytes

from .common import Error
from .env import get_master_host
from .env import get_sub_port
from .env import get_topic_suffix
from .env import get_default_deserializer
from .env import get_remapped_topic_name


[docs]class Subscriber(object): '''Subscribe the topic and call the callback function Example: >>> def callback(msg): ... print msg ... >>> sub = jps.Subscriber('topic_name', callback) >>> sub.spin() or you can use python generator style >>> import jps >>> for msg in jps.Subscriber('/hoge1'): ... print msg :param topic_name: topic name :param host: host name of publisher/forwarder :param sub_port: port of publisher/forwarder :param deserializer: this function is applied after received (default: None) ''' def __init__(self, topic_name, callback=None, host=None, sub_port=None, deserializer='DEFAULT'): topic_name = get_remapped_topic_name(topic_name) if topic_name.count(' '): raise Error('you can\'t use " " for topic_name') if topic_name == '': raise Error('empty topic name is not supported') if host is None: host = get_master_host() if sub_port is None: sub_port = get_sub_port() if deserializer is 'DEFAULT': deserializer = get_default_deserializer() self._deserializer = deserializer context = zmq.Context() self._socket = context.socket(zmq.SUB) self._socket.connect('tcp://{host}:{port}'.format(host=host, port=sub_port)) self._topic = cast_bytes(topic_name + get_topic_suffix()) self._topic_without_star = self._topic.rstrip('*') self._socket.setsockopt(zmq.SUBSCRIBE, self._topic_without_star) self._user_callback = callback if type(callback) == types.MethodType: # arg=[self, message, topic_name] self._user_callback_takes_topic_name = callback.im_func.func_code.co_argcount == 3 elif type(callback) == types.FunctionType: # arg=[message, topic_name] self._user_callback_takes_topic_name = callback.func_code.co_argcount == 2 elif hasattr(callback, '__call__'): # arg=[self, message, topic_name] self._user_callback_takes_topic_name = callback.__call__.im_func.func_code.co_argcount == 3 else: self._user_callback_takes_topic_name = False if type(callback) == types.InstanceType: print 'argcoutn = ' + callback.func_code.co_argcount self._thread = None self._poller = zmq.Poller() self._poller.register(self._socket, zmq.POLLIN) def _strip_topic_name_if_not_wildcard(self, raw_msg): topic, _, msg = raw_msg.partition(' ') if self._topic != self._topic_without_star: return (msg, topic) elif topic == self._topic: return (msg, topic) return (None, topic)
[docs] def deserialize(self, msg): if self._deserializer is not None: return self._deserializer(msg) return msg
def _callback(self, raw_msg): if self._user_callback is None: return msg, topic_name = self._strip_topic_name_if_not_wildcard(raw_msg) if msg is not None: if self._user_callback_takes_topic_name: self._user_callback(self.deserialize(msg), topic_name) else: self._user_callback(self.deserialize(msg))
[docs] def spin_once(self, polling_sec=0.010): '''Read the queued data and call the callback for them. You have to handle KeyboardInterrupt (\C-c) manually. Example: >>> def callback(msg): ... print msg >>> sub = jps.Subscriber('topic_name', callback) >>> try: ... while True: ... sub.spin_once(): ... time.sleep(0.1) ... except KeyboardInterrupt: ... pass ''' # parse all data while True: socks = dict(self._poller.poll(polling_sec * 1000)) if socks.get(self._socket) == zmq.POLLIN: msg = self._socket.recv() self._callback(msg) else: return
[docs] def spin(self, use_thread=False): '''call callback for all data forever (until \C-c) :param use_thread: use thread for spin (do not block) ''' if use_thread: if self._thread is not None: raise Error('spin called twice') self._thread = threading.Thread(target=self._spin_internal) self._thread.setDaemon(True) self._thread.start() else: self._spin_internal()
def _spin_internal(self): for msg in self: if self._user_callback_takes_topic_name: self._user_callback(*msg) else: self._user_callback(self.deserialize(msg)) def __iter__(self): return self
[docs] def next(self): '''receive next data (block until next data)''' try: raw_msg = self._socket.recv() except KeyboardInterrupt: raise StopIteration() msg, topic_name = self._strip_topic_name_if_not_wildcard(raw_msg) if msg is None: return self.next() if self._user_callback_takes_topic_name: return (self.deserialize(msg), topic_name) else: return self.deserialize(msg)
# for python3 __next__ = next