#!/usr/bin/env python3
# coding=utf-8
#
# File: EmBCI/embci/io/base.py
# Authors: Hank <hankso1106@gmail.com>
# Create: 2019-03-16 19:37:41
'''Save and Load Utilities'''
# built-in
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import time
import traceback
# requirements.txt: data: numpy, scipy, mne==0.17
# requirements.txt: necessary: six
import numpy as np
import scipy.io
import mne
from six import string_types
from ..utils import (
mkuserdir, check_input, typename, validate_filename,
TempStream
)
from ..configs import DIR_DATA
from . import logger
# Names exported by a star-import of this module.
__all__ = [
    'create_data_dict', 'find_data_info',
    'save_trials', 'save_chunks', 'save_action',
    'load_data', 'load_mat', 'load_label_data',
    'validate_datafile',
]
# Data files are named ``<label>-<num>.<suffix>[.gz]``. The three groups
# capture label, num and suffix; an optional trailing ``.gz`` is matched
# but not captured.
_name_datafile_pattern = re.compile(r'^([ \w\.-]+)-(\d+)\.(\w+)(?:\.gz)?$')
def find_data_info(username):  # noqa: W611
    '''
    Index every saved data file found in a user's data directory.

    Only entries of ``${DIR_DATA}/${username}`` named like
    ``${label}-${num}.${suffix}[.gz]`` with a suffix of ``mat``, ``fif``,
    ``h5`` or ``csv`` are taken into account.

    Returns
    -------
    (label_dict, filename_dict, summary string)
        ``label_dict`` maps each label to the list of its file numbers,
        ``filename_dict`` maps each label to the list of full file paths and
        the summary is a human readable report. If the user directory does
        not exist, two empty dicts and an empty string are returned.

    Examples
    --------
    >>> find_data_info('test')
    ({
        'left': [0, 1, 2, 5, 8, 9, 10],
        'right': [2, 3, 4, 5],
        'thumb_cross': [1, 2, 3, 4, 5]
    }, {
        'left': ['/path/to/left-0.mat', '/path/to/left-1.fif', ...],
        'right': ['/path/to/right-2.h5', '/path/to/right-3.csv', ...],
        'thumb_cross': [...]
    }, 'There are 3 actions with 16 data recorded.\\n * left ...')
    '''
    numbers = {}
    filenames = {}
    userdir = os.path.join(DIR_DATA, username)
    if not os.path.exists(userdir):
        return numbers, filenames, ''

    for entry in sorted(os.listdir(userdir)):
        matched = _name_datafile_pattern.match(entry)
        if matched is None:
            continue
        label, index, ext = matched.groups()
        if ext not in ('mat', 'fif', 'h5', 'csv'):
            continue
        numbers.setdefault(label, []).append(int(index))
        filenames.setdefault(label, []).append(entry)

    # construct a neat summary report
    total = sum(len(nums) for nums in numbers.values())
    summary = 'There are {} actions with {} data recorded.'.format(
        len(numbers), total)
    if numbers:
        width = max(
            len(name) for names in filenames.values() for name in names
        ) - 2
        sections = []
        for label in numbers:
            header = label.ljust(width) + '%2d' % len(numbers[label])
            listing = '\n ' + '\n '.join(filenames[label])
            sections.append(header + listing)
        summary += '\n * ' + '\n * '.join(sections)

    # the report lists bare file names, callers receive full paths
    for label in filenames:
        filenames[label] = [
            os.path.join(userdir, name) for name in filenames[label]
        ]
    return numbers, filenames, summary
@mkuserdir
def validate_datafile(username, label='default', checkname=False):
    '''
    Resolve saved datafiles and generate a valid filename for new data.
    ``${DIR_DATA}/${username}/${label}-${num}.${suffix}[.gz]``

    Parameters
    ----------
    username : str
    label : str
    checkname : bool
        Whether to ensure that the username is valid for filename. Default
        False.

    Returns
    -------
    (datafile name, username)
        The datafile name has no suffix; ``num`` is the smallest
        non-negative number not yet used by files of this label.
        If checkname set to True, username is validated too.
    '''
    if checkname:
        username = ''.join([
            c for c in validate_filename(username) if c not in '()[]'
        ]).replace(' ', '_').replace('.', ' ')
    label = validate_filename(label)
    used = set(find_data_info(username)[0].get(label, []))
    # range(len(used) + 1) always holds at least one unused number. Pick the
    # smallest one explicitly: iteration order of a set is not specified.
    num = min(set(range(len(used) + 1)).difference(used))
    fn = os.path.join(DIR_DATA, username, '%s-%d' % (label, num))
    return fn, username
def create_data_dict(data, label='default', sample_rate=500, suffix=None):
    '''
    Create a data_dict that can be saved by function :func:`save_trials`.

    Parameters
    ----------
    data : ndarray | array list | instance of mne.Raw[Array] | dict
        1-3d array with a shape of [[num_trial x] num_channel x] window_size.
        A dict must hold the array under the name given by its ``key`` entry
        or, without such an entry, under ``raw``, ``data`` or ``array``. The
        remaining items are copied into the result. The dict passed in is
        not modified.
    label : str
        Action name, data label. Char :code:`-` is not suggested in label.
    sample_rate : int
        Sample rate of data, default set to 500Hz.
    suffix : str
        Currently supported formats are MATLAB-style '.mat'(default),
        MNE-style '.fif[.gz]' and raw text '.csv'.

    Returns
    -------
    data_dict : dict
        {'data': 3d array, 'label': str, 'sample_rate': int, 'key': 'data',
        ...}

    Raises
    ------
    TypeError
        If `data` has an unsupported type or is a dict without a data entry.
    ValueError
        If the array has more than three dimensions.
    '''
    data_dict = {
        'label': str(label),
        'sample_rate': int(sample_rate),
        'key': 'data'
    }
    if suffix is not None:
        data_dict['suffix'] = str(suffix)

    if isinstance(data, dict):
        # work on a shallow copy so the caller's dict is left untouched
        dct = dict(data)
        if 'key' in dct:
            key = dct.pop('key')
        else:
            for key in ['raw', 'data', 'array']:
                if key in dct:
                    break
            else:
                raise TypeError('No data key in dict: %s' % dct.keys())
        data = np.atleast_2d(dct.pop(key))
        data_dict.update(dct)
    elif isinstance(data, (tuple, list, np.ndarray)):
        data = np.atleast_2d(data)
    elif isinstance(data, mne.io.BaseRaw):
        data_dict['info'] = data.info
        data_dict['sample_rate'] = data.info['sfreq']
        data = data.get_data()  # num_channel x window_size
    else:
        raise TypeError('Invalid data type: %s' % typename(data))

    if data.ndim == 2:
        # a single trial: prepend the trial axis
        data = data[np.newaxis]
    elif data.ndim > 3:
        # wrap the shape in a tuple, otherwise `%` unpacks it as arguments
        raise ValueError(
            'Array with too many dimensions: %s' % (data.shape, ))
    data_dict[data_dict['key']] = data
    return data_dict
def save_trials(username, data_dict, suffix='mat', summary=False):
    '''
    Save trials of data into ${DIR_DATA}/${username}/${label}-${num}.${suffix}

    Each trial (first axis of the data array) is written to its own file.
    Failures are logged, a partially written file is removed and the next
    trial is tried. An unsupported format is logged and stops the saving.

    Parameters
    ----------
    username : str
    data_dict : dict
        created by function create_data_dict(data, label, sample_rate,
        suffix). The dict is not modified.
    suffix : str
        Currently supported formats are MATLAB-style '.mat'(default),
        MNE-style '.fif[.gz]' and raw text '.csv'. Format setting in
        data_dict will overwrite this argument.
    summary : bool
        Whether to print summary of currently saved data, default `False`.

    Raises
    ------
    TypeError
        If `data_dict` lacks one of the entries `label`, `sample_rate` or
        `key`.

    Examples
    --------
    >>> data = np.random.rand(8, 1000)  # 8chs x 4sec x 250Hz data
    >>> save_trials('test', create_data_dict(data, 'random_data', 250))
    >>> raw = mne.io.RawArray(data, mne.create_info(8, 250))
    >>> save_trials('test', create_data_dict(raw, suffix='fif.gz'))
    '''
    # Shallow copy: the entries popped below must not vanish from the
    # caller's dict, otherwise it could not be saved a second time.
    data_dict = dict(data_dict)
    try:
        label = data_dict.pop('label')
        srate = data_dict['sample_rate']
        key = data_dict['key']
    except KeyError:
        raise TypeError('`data_dict` object created by function '
                        '`create_data_dict` is preferred.')
    suffix = data_dict.pop('suffix', suffix).strip('.')
    data_dict.pop('event', [])  # TODO: save event channel
    info = data_dict.pop('info', None)
    data = data_dict.pop(key)
    data_dict['key'] = key = 'trial'

    # function create_data_dict maybe offer mne.Info object
    if 'fif' in suffix and not isinstance(info, mne.Info):
        info = mne.create_info(data.shape[1], srate)

    username = validate_datafile(username, checkname=True)[1]
    for trial in data:
        # resolved per trial: each saved file takes the next free number
        fn = '{}.{}'.format(
            validate_datafile(username, label)[0], suffix)
        try:
            if suffix == 'mat':
                data_dict[key] = trial
                scipy.io.savemat(fn, data_dict, do_compression=True)
            elif suffix == 'csv':
                np.savetxt(fn, trial, delimiter=',')
            elif 'fif' in suffix:
                # mute mne.io.BaseRaw.save info from stdout and stderr
                with TempStream(stdout=None, stderr=None) as ts:
                    try:
                        mne.io.RawArray(trial, info).save(fn)
                    except Exception:
                        logger.error(traceback.format_exc())
                logger.debug('%s %s' % (ts.stdout, ts.stderr))
            else:
                logger.error('format `%s` is not supported.' % suffix)
                break
            logger.info('save {} data to {}'.format(trial.shape, fn))
        except Exception:
            logger.warning('save %s failed.' % fn)
            logger.error(traceback.format_exc())
            if os.path.exists(fn):
                os.remove(fn)
    if summary:
        print('\n' + find_data_info(username)[2])
# Cache used by `_sort_mat`: file name -> variable names already stored in
# that mat-file.
_append_keys = {}
def save_chunks(fn, data_dict, suffix='mat', append=False):
    '''
    Save one chunk of data into a file or an already opened file object.

    Parameters
    ----------
    fn : str or file-like
        A filename that does not start with ${DIR_DATA} is joined onto it.
        A file-like object needs the attributes `name` and `write`, must not
        be closed and always implies `append`.
    data_dict : dict
        created by function create_data_dict(data, label, sample_rate,
        suffix). It is modified in place: the data is moved to the entry
        `chunk` and `key` is set accordingly.
    suffix : str
        Formats written so far are MATLAB-style '.mat'(default) and raw
        text '.csv'. MNE-style '.fif' and HDF5-style '.h5' are accepted but
        nothing is written for them yet. Format setting in data_dict will
        overwrite this argument.
    append : bool
        Open the file for appending and keep it open afterwards, default
        `False`.

    Returns
    -------
    fobj : file object
        Closed if the file was not used in append mode and saving succeeded.

    Raises
    ------
    TypeError
        If `data_dict` lacks `label`, `sample_rate` or `key`, or if `fn` is
        neither a filename nor a file-like object.
    ValueError
        If the file object passed in is already closed.

    Notes
    -----
    Errors raised while writing are logged but not re-raised.
    '''
    try:
        label = data_dict['label']
        srate = data_dict['sample_rate']
        key = data_dict.pop('key')
    except KeyError:
        raise TypeError('`data_dict` object created by function '
                        '`create_data_dict` is preferred.')
    suffix = data_dict.pop('suffix', suffix).lstrip('.')
    info = data_dict.pop('info', None)
    data = data_dict.pop(key)
    key = 'chunk'
    data_dict['key'] = key
    data_dict[key] = data

    if isinstance(fn, string_types):
        if not fn.startswith(DIR_DATA):
            fn = os.path.join(DIR_DATA, fn)
        mode = 'a+b' if append else 'wb'
        fobj = open(fn, mode)
    elif hasattr(fn, 'name') and hasattr(fn, 'write'):
        fobj = fn
        fn = fobj.name
        if fobj.closed:
            raise ValueError('File object %s has already closed.' % fobj)
        # a stream owned by the caller is never closed here
        append = True
    else:
        raise TypeError('Param `fn` only accepts filename or file-like object')

    try:
        if suffix == 'mat':
            if append:
                data_dict = _sort_mat(fn, data_dict)
            scipy.io.savemat(fobj, data_dict, do_compression=True)
        elif 'fif' in suffix:
            # TODO: append data to fif file
            info, srate, label
        elif suffix == 'h5':
            # TODO: append data to HDF5 file
            pass
        elif suffix == 'csv':
            np.savetxt(fobj, data, delimiter=',')
        else:
            raise TypeError('format `%s` is not supported.' % suffix)
        fobj.flush()
        logger.info('save {} data to {}'.format(data_dict.keys(), fn))
    except Exception:
        logger.warning('save %s failed.' % fn)
        logger.error(traceback.format_exc())
    else:
        if not append:
            fobj.close()
    return fobj
def _sort_mat(fn, data_dict):
    '''
    Rename entries of `data_dict` that clash with names already used in `fn`.

    The names stored in the mat-file `fn` are read once and cached in
    `_append_keys`; a file that cannot be loaded counts as empty. An entry
    whose name is a prefix of `n` known names is renamed to ``name/n``.
    Afterwards all (renamed) names are added to the cache.

    Returns
    -------
    data_dict : dict
        The same dict object, modified in place.
    '''
    if fn not in _append_keys:
        try:
            stored = scipy.io.loadmat(fn)
        except Exception:
            names = []
        else:
            # names starting with '_' are skipped
            names = [name for name in stored.keys()
                     if not name.startswith('_')]
            del stored
        _append_keys[fn] = names

    known = _append_keys[fn]
    for name in list(data_dict.keys()):
        count = len([old for old in known if old.startswith(name)])
        if count:
            data_dict['%s/%d' % (name, count)] = data_dict.pop(name)
    known.extend(data_dict.keys())
    return data_dict
def load_mat(fn):
    '''
    Load a mat-file and move its data array to the entry ``raw``.

    Parameters
    ----------
    fn : str | file-like | dict
        Filename or file object passed to :func:`scipy.io.loadmat`, or an
        already loaded dict. A dict is modified in place.

    Returns
    -------
    dct : dict
        Content of the file with ``dct['key'] == 'raw'`` and the data stored
        in ``dct['raw']``. Three layouts are recognized by the presence of
        the entry ``trial``, ``chunk`` or ``data``.

    Raises
    ------
    IOError
        If no data entry is found or a trial is not a 2d array.
    '''
    if isinstance(fn, dict):
        dct = fn
    else:  # file-like object or filename
        dct = scipy.io.loadmat(fn)

    if 'trial' in dct:
        # loadmat returns strings as arrays, hence the indexing. A plain
        # string (or the default) must not be indexed: 'trial'[0] is 't'.
        key = dct.get('key', 'trial')
        if not isinstance(key, str):
            key = key[0]
        data = dct.pop(key)
        if isinstance(data, dict):
            data = data['raw']
        if not isinstance(data, np.ndarray) or data.ndim != 2:
            raise IOError('Data file {} not support'.format(fn))
    elif 'chunk' in dct:
        keys = sorted(dct.keys())
        while keys:
            k = keys[0]
            if k[0] == k[-1] == '_':
                # skip names like __header__
                keys.remove(k)
                continue
            # merge all entries whose name starts with `k` into one entry
            # NOTE(review): entry `k` itself is placed last and the others
            # are in string order - confirm this is the intended order.
            replst = [_ for _ in keys[1:] if _.startswith(k)] + [k]
            arrays = [dct.pop(keys.pop(keys.index(_))) for _ in replst]
            dct[k] = [arr[0] if arr.size else [] for arr in arrays]
            if isinstance(dct[k][0], (np.ndarray, list, tuple)):
                try:
                    dct[k] = np.concatenate(dct[k], -1)
                except Exception:
                    # keep the plain list if the parts cannot be joined
                    pass
        key = dct.get('key', 'chunk')
        if not isinstance(key, str):
            key = key[0]
        data = dct.pop(key)
    elif 'data' in dct:
        data = dct.pop('data')
    else:
        logger.error('Unrecognized mat file: %s' % dct.keys())
        raise IOError('Data file {} not support'.format(fn))
    dct['key'] = key = 'raw'
    dct[key] = data
    return dct
def load_label_data(username, label='default'):
    '''
    Load all data files that match ${DIR_DATA}/${username}/${label}-*.*

    Files that fail to load (including files of a format that cannot be
    read here) are logged and skipped.

    Returns
    -------
    data_list : list
        One data array per successfully loaded file.
    '''
    loaded = []
    filenames = find_data_info(username)[1].get(label, [])
    for fn in filenames:
        stem, ext = os.path.splitext(fn)
        if ext == '.gz':
            # compressed file: the format is the extension before `.gz`
            stem, ext = os.path.splitext(stem)
        ext = ext.strip('.')
        try:
            if ext == 'mat':
                content = load_mat(fn)
                data = content[content['key']]
                if isinstance(data, (tuple, list)):
                    data = np.concatenate(data, -1)
            elif ext == 'csv':
                data = np.loadtxt(fn, np.float32, delimiter=',')
            elif ext == 'fif':
                with TempStream(stdout=None, stderr=None):
                    # data = mne.io.RawFIF(fn).get_data()
                    data = mne.io.RawFIF(fn, preload=True)._data
            else:
                raise ValueError('format `%s` is not supported.' % ext)
            loaded.append(data)
            logger.info('Load {} data from {}'.format(data.shape, fn))
        except Exception:
            logger.warning('Load %s failed.' % fn)
            logger.error(traceback.format_exc())
    return loaded
def load_data(username, pick=None, summary=True):
    '''
    Load all data files under directory ${DIR_DATA}/${username}

    Parameters
    ----------
    username : str
    pick : str | list or tuple of str | regex pattern | function
        load data files whose label name:
        equal to | inside | match | return True by appling `pick`.
        Any other value (e.g. None) selects all labels.
    summary : bool
        whether to print summary of currently saved data, default `True`.

    Returns
    -------
    out : tuple
        (data_list, label_list)
    data_list : list
        List with a length of n_samples holding one data array per loaded
        file.
    label_list : list
        String list with a length of n_samples. Each element indicate
        label(action name) of corresponding data sample.

    Examples
    --------
    >>> data, label = load_data('test')
    >>> len(data), label
    (5, ['default', 'default', 'default', 'right', 'left'])
    >>> _, _ = load_data('test', pick=('left', 'right'), summary=True)
    There are 3 actions with 5 data recorded.
     * default    3
     default-1.fif.gz
     default-2.fif.gz
     default-3.mat
     * right      1
     right-1.mat
     * left       1
     left-1.fif
    There are 2 actions with 2 data loaded.
     + left       1
     + right      1
    '''
    # `re._pattern_type` was removed in Python 3.7; the type of a compiled
    # pattern is available on every version this way.
    pattern_type = type(re.compile(''))

    def filterer(label):
        if isinstance(pick, str):
            return label == pick
        if isinstance(pick, (tuple, list)):
            return label in pick
        if isinstance(pick, pattern_type):
            return bool(pick.match(label))
        if callable(pick):
            return pick(label)
        return True

    # pick labels that match the rule from label_dict
    label_dict, _, msg = find_data_info(username)
    labels = [label for label in label_dict if filterer(label)]

    data_list = []
    label_list = []
    for label in labels:
        data = load_label_data(username, label)
        data_list.extend(data)
        label_list.extend([label] * len(data))

    if summary:
        msg += '\nThere are {} actions with {} data loaded.'.format(
            len(labels), len(data_list))
        if data_list:
            maxname = max(len(s) for s in msg.split('\n')[1:-1]) - 6
            msg += '\n + ' + '\n + '.join([
                label.ljust(maxname) + '%2d' % label_list.count(label)
                for label in labels
            ])
        print('\n' + msg.strip())

    # data_list: n_samples x (num_channel x window_size)
    # label_list: n_samples
    return data_list, label_list
def save_action(username, reader, action_list=['relax', 'grab']):
    '''
    Guidance on command line interface to save data with label to
    ``${DIR_DATA}/${username}/${action}-*.mat``

    The user is asked how many times each action shall be recorded. The
    actions are then prompted in random order and after each recording
    ``reader.data_frame`` is saved.

    Parameters
    ----------
    username : str
    reader : Reader
        Instance of :class:`embci.io.readers.BaseReader`, representing a
        stream from which data will be read.
    action_list : list of str
        Names of the actions to record. ``-`` in a name is replaced by
        ``_``.

    Returns
    -------
    label_dict : dict | None
        First element of the result of :func:`find_data_info`, or None if
        the answer to the prompt is not a number.
    '''
    logger.info('You have to finish each action in {} seconds.'.format(
        reader.sample_time))
    prompt = ('How many times would you like to record for each '
              'action?(empty to abort): ')
    answer = check_input(prompt, {}, times=999)
    if answer == '' or not answer.isdigit():
        return
    repeat = int(answer)

    actions = []
    for action in validate_filename(*action_list):
        if action:
            actions.append(action.replace('-', '_').replace(' ', ' '))
    schedule = actions * repeat
    np.random.shuffle(schedule)

    # walk the shuffled list from its end
    for action in reversed(schedule):
        print('action name: %s, start recording in 2s' % action)
        time.sleep(2)
        print('Start')
        time.sleep(reader.sample_time)
        print('Stop')
        data_dict = create_data_dict(
            reader.data_frame, action, reader.sample_rate)
        save_trials(username, summary=True, data_dict=data_dict)
    return find_data_info(username)[0]
# THE END