Source code for embci.io.commanders

#!/usr/bin/env python3
# coding=utf-8
#
# File: EmBCI/embci/io/commanders.py
# Authors: Hank <hankso1106@gmail.com>
# Create: 2018-03-06 21:02:35

'''Commanders'''

# built-in
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import socket
import select
import threading
import traceback

# requirements.txt: data: pylsl
# requirements.txt: drivers: pyserial
import pylsl
import serial

from ..utils import (
    ensure_unicode, ensure_bytes, find_serial_ports, duration,
    LoopTaskInThread, Singleton
)
from ..constants import command_dict_null, command_dict_plane
from . import logger

__all__ = ['SocketTCPServer'] + [
    _ + 'Commander' for _ in ('Torcs', 'Plane', 'LSL', 'Serial')
]


# TODO: embci.io.commander: valid_name
def validate_commandername(name):
    '''
    Return the name a commander should use.

    No validation is implemented yet (see the TODO above), so `name` is
    handed back unchanged.
    '''
    validated = name
    return validated


class BaseCommander(object):
    '''
    Common base of all commanders.

    Subclasses must implement `start`, `send` and `close`. Instances also
    expose a minimal file-like interface: `write` is an alias of `send`
    and `flush`, `read`, `seek`, `truncate`, `tell` are no-op stubs.

    Attributes
    ----------
    name : str
        Name of the commander, used as prefix of log messages.
    cmd_dict : dict
        Mapping from command key to command value, looked up by
        `get_command`.
    '''
    name = '[embci.io.Commander]'
    cmd_dict = command_dict_null

    def __init__(self, cmd_dict=None, name=None, *a, **k):
        '''
        Parameters
        ----------
        cmd_dict : dict, optional
            Command dict of this instance. Falls back to the class
            attribute `cmd_dict` if not given (or empty).
        name : str, optional
            Instance name. Falls back to the class attribute `name`.
        '''
        self.name = validate_commandername(name or self.__class__.name)
        self.cmd_dict = cmd_dict or self.cmd_dict
        try:
            logger.debug('[Command Dict] %s' % self.cmd_dict['_desc'])
        except KeyError:
            logger.warning(
                '[Command Dict] current command dict does not have a '
                'key named _desc to describe itself. please add it.')
        # alias `send` as `write` to make instances file-like object
        self.write = self.send

    def start(self):
        '''Prepare the commander for sending. Must be overridden.'''
        raise NotImplementedError('you can not directly use this class')

    def send(self, key, *args, **kwargs):
        '''Send the command identified by `key`. Must be overridden.'''
        raise NotImplementedError('you can not directly use this class')

    # File-like no-op stubs. They are plain functions stored on the class,
    # so calling them on an instance passes the instance as first argument:
    # every stub must accept arbitrary positional arguments. `flush` used
    # to be `lambda: None`, which made `instance.flush()` raise TypeError.
    flush, read = lambda *a: None, lambda *a: ''
    seek = truncate = tell = lambda *a: 0

    def get_command(self, cmd, warning=True):
        '''
        Look up `cmd` in the command dict.

        Parameters
        ----------
        cmd : hashable
            Key of the command.
        warning : bool
            Whether to log a warning when `cmd` is not supported.

        Returns
        -------
        The value stored under `cmd`, or None if `cmd` is not a key of
        the command dict.
        '''
        if cmd not in self.cmd_dict:
            if warning:
                logger.warning('{} command {} is not supported'.format(
                    self.name, cmd))
            return None
        return self.cmd_dict[cmd]

    def close(self):
        '''Release resources held by the commander. Must be overridden.'''
        raise NotImplementedError('you can not directly use this class')


class TorcsCommander(BaseCommander):
    '''
    Commander that drives TORCS (The Open Race Car Simulator).

    Feed it the prediction of a classifier to steer the race car: the
    command sent to the game is a one-element list whose sign depends on
    whether the key is `right` or not.
    '''
    name = 'TorcsCommander'

    def start(self):
        '''Create the TORCS environment and reset it.'''
        # imported lazily so the gym is only required when actually used
        from ..gyms import TorcsEnv
        logger.debug(self.name + ' initializing TORCS...')
        env = TorcsEnv(vision=True, throttle=False, gear_change=False)
        self.env = env
        env.reset()

    @duration(1, 'TorcsCommander')
    def send(self, key, prob, *args, **kwargs):
        '''
        Step the environment with a steering command.

        Parameters
        ----------
        key : str
            `right` yields a positive value, anything else a negative one.
        prob : number
            Only its absolute value is used as the magnitude.

        Returns
        -------
        list
            The one-element command passed to `env.step`.
        '''
        if key == 'right':
            steer = abs(prob)
        else:
            steer = -abs(prob)
        cmd = [steer]
        logger.debug(self.name + ' sending cmd {}'.format(cmd))
        self.env.step(cmd)
        return cmd

    def close(self):
        '''End the TORCS environment created by `start`.'''
        self.env.end()


class PlaneCommander(BaseCommander, Singleton):
    '''
    Commander of the plane war game. The plane is controlled with the
    commands [`left`, `right`, `up` and `down`].
    '''
    name = 'PlaneCommander'
    cmd_dict = command_dict_plane

    def start(self):
        '''Create the client used to talk to the game.'''
        # imported lazily so the gym is only required when actually used
        from ..gyms import PlaneClient
        self.client = PlaneClient()

    @duration(1, 'PlaneCommander')
    def send(self, key, *args, **kwargs):
        '''
        Send the command stored under `key` to the game.

        The command dict entry is indexed as a sequence: item 0 is handed
        to the client and item 1 is the time in seconds to sleep after
        sending.

        Returns
        -------
        Item 0 of the entry, or None if `key` is not supported.
        '''
        entry = self.get_command(key)
        if entry is not None:
            self.client.send(entry[0])
            time.sleep(entry[1])
            return entry[0]
        return None

    def close(self):
        '''Close the client created by `start`.'''
        self.client.close()


class LSLCommander(BaseCommander):
    '''
    Publish command strings as an online stream through a
    pylsl.StreamOutlet.
    '''
    name = 'LSLCommander'

    def start(self, name=None, type='Result', source=None):
        '''
        Create the pylsl outlet used by `send`.

        The stream has a single channel of format `string`.

        Parameters
        ----------
        name : str, optional
            Name of the data stream or session. Defaults to the name of
            this commander.
        type : str
            Content type passed to pylsl.StreamInfo.
        source : str, optional
            Unique identifier of the device or data generator, such as
            serial number or MAC. Defaults to the name of this commander.

        Examples
        --------
        >>> c = LSLCommander(name='pylsl commander 2')
        >>> c.start('result of recognition',
        ...         source='EmBCI Hardware Re.A7.1221')
        >>> pylsl.resolve_bypred("contains('recognition')")
        [<pylsl.pylsl.StreamInfo instance at 0x7f3e82d8c3b0>]
        '''
        stream_name = name or self.name
        source_id = source or self.name
        info = pylsl.StreamInfo(
            stream_name, type=type, channel_count=1,
            channel_format='string', source_id=source_id)
        self._outlet = pylsl.StreamOutlet(info)

    def send(self, key, *args, **kwargs):
        '''Push `key`, converted to unicode, as a one-channel sample.'''
        sample = [ensure_unicode(key)]
        self._outlet.push_sample(sample)

    def close(self):
        '''Drop the reference to the outlet created by `start`.'''
        del self._outlet


class SerialCommander(BaseCommander):
    '''
    Commander that writes commands to a serial port.
    '''
    name = 'SerialCommander'

    def __init__(self, *a, **k):
        '''Accepts the same arguments as `BaseCommander`.'''
        super(SerialCommander, self).__init__(*a, **k)
        # lock serializes `send` calls; the port is opened later in `start`
        self._command_lock = threading.Lock()
        self._command_serial = serial.Serial()

    def start(self, port=None, baudrate=9600):
        '''
        Configure and open the serial port.

        Parameters
        ----------
        port : optional
            Port to open. Defaults to the result of `find_serial_ports()`.
        baudrate : int
            Baudrate of the connection.
        '''
        device = self._command_serial
        device.port = port or find_serial_ports()
        device.baudrate = baudrate
        device.open()

    def send(self, key, *args, **kwargs):
        '''
        Write the command stored under `key` to the serial port.

        The command dict entry is indexed as a sequence: item 0 is written
        (as bytes) and item 1 is the time in seconds to sleep afterwards.
        Both happen while the command lock is held.

        Returns
        -------
        Item 0 of the entry, or None if `key` is not supported.
        '''
        entry = self.get_command(key)
        if entry is not None:
            with self._command_lock:
                self._command_serial.write(ensure_bytes(entry[0]))
                time.sleep(entry[1])
            return entry[0]
        return None

    def close(self):
        '''Close the serial port.'''
        self._command_serial.close()

    def reconnect(self):
        '''
        Close the port, wait one second and open it again. The outcome is
        logged; a SerialException is not propagated.
        '''
        device = self._command_serial
        try:
            device.close()
            time.sleep(1)
            device.open()
        except serial.serialutil.SerialException:
            logger.error(self.name + ' reconnect failed.')
        else:
            logger.info(self.name + ' reconnect success.')


class SocketTCPServer(LoopTaskInThread):
    '''
    Socket TCP server on host:port, default to 0.0.0.0:0.
    A data broadcaster.

    Connected clients are tracked in two parallel lists: `_conns` holds
    the sockets and `_addrs` the address of the socket at the same index.
    '''

    def __init__(self, host='0.0.0.0', port=0):
        self.host, self.port = host, port
        self._conns, self._addrs = [], []
        self._server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # `manager` is the function executed by the loop task
        LoopTaskInThread.__init__(self, self.manager)

    def __repr__(self):
        return '<%s 0x%x>' % (LoopTaskInThread.__repr__(self)[1:-1], id(self))

    def hook_before(self):
        '''Bind the server socket and start listening.'''
        self._server.bind((self.host, self.port))
        self._server.listen(5)
        self._server.settimeout(0.5)
        # read back the real address: it differs from the requested one
        # when e.g. port 0 is used
        self.host, self.port = self._server.getsockname()
        logger.info('{} socket server is listening on {}:{}'.format(
            self.name, self.host, self.port))

    def handle_client(self, sock):
        '''
        Handle a client socket reported readable by `select`.

        A normal message is logged and passed to `sock.onmessage` if the
        socket has such an attribute. An empty message, the message
        `shutdown` or a failing `recv` are treated as disconnection: the
        socket is closed and removed from the client list.
        '''
        addr = self.getaddr(sock)
        try:
            data = sock.recv(4096)
        except socket.timeout:
            # nothing to read this time, keep the client
            return
        except socket.error:
            # e.g. connection reset by peer: handle as client shutdown
            # instead of letting the error escape into the loop task
            data = b''
        # undecodable bytes sent by a client must not raise here
        msg = data.decode('utf8', 'replace').strip()

        # client sent some data
        if msg not in ['shutdown', '']:
            logger.info('{} recv {} from {}:{}'.format(self.name, msg, *addr))
            if hasattr(sock, 'onmessage'):
                try:
                    sock.onmessage(msg)
                except Exception:
                    logger.error(traceback.format_exc())
            return

        # client shutdown and we should clear correspond server
        try:
            sock.sendall(b'shutdown')
            sock.shutdown(socket.SHUT_RDWR)
        except socket.error:
            # best effort: the peer may already be gone
            pass
        except Exception:
            logger.error(traceback.format_exc())
        finally:
            sock.close()
        self.remove(sock, addr)

    def manager(self):
        '''
        This loop task does following things to manage connections:
            1. wait for new clients and add them into a list
            2. remove closed clients
            3. receive msg from all clients and run callback functions

        Only the first readable socket is handled per call; others are
        picked up by following calls.
        '''
        rlist, _, _ = select.select([self._server] + self._conns, [], [], 3)
        if not rlist:
            return
        if rlist[0] is self._server:
            # new connection
            con, addr = self._server.accept()
            con.settimeout(0.5)
            self.add(con, addr)
        else:
            # some client maybe closed
            self.handle_client(rlist[0])

    def send(self, con, data):
        '''Send `data` to one connection, ignoring socket errors.'''
        try:
            con.sendall(data)
        except socket.error:
            pass

    def multicast(self, data):
        '''Send `data` (converted to bytes) to every connected client.'''
        data = ensure_bytes(data)
        for con in self._conns:
            self.send(con, data)

    def hook_after(self):
        '''Close all client sockets and the server socket.'''
        for con in self._conns:
            con.close()
        self._server.close()
        logger.debug(self.name + ' Socket server shut down.')

    def has_listeners(self):
        '''Return the number of connected clients (0 means none).'''
        return len(self._conns)

    def getaddr(self, sock):
        '''
        Return the address recorded for `sock`, or its peer name if the
        socket is not in the client list.
        '''
        if sock in self._conns:
            return self._addrs[self._conns.index(sock)]
        return sock.getpeername()

    def add(self, sock, addr=None):
        '''
        Add a client socket. Nothing happens if it is already known or if
        its address can not be resolved.
        '''
        if sock in self._conns:
            return
        try:
            addr = addr or self.getaddr(sock)
        except Exception:
            return
        self._conns.append(sock)
        self._addrs.append(addr)
        logger.debug('{} add client from {}:{}'.format(self.name, *addr))

    def remove(self, sock, addr=None):
        '''
        Remove a client socket and its recorded address. Nothing happens
        if the socket is unknown.
        '''
        try:
            index = self._conns.index(sock)
        except ValueError:
            return
        # delete both entries by position so the two lists always stay
        # aligned, even if `addr` differs from the recorded address
        del self._conns[index]
        recorded = self._addrs.pop(index)
        addr = addr or recorded
        logger.debug('{} lost client from {}:{}'.format(self.name, *addr))
# THE END