import warnings
from collections.abc import Iterable
from bisect import bisect_left, bisect_right
import numpy as np
from hdmf.utils import docval, getargs, popargs, popargs_to_dict, get_docval, AllowPositional
from . import register_class, CORE_NAMESPACE
from .base import TimeSeries
from .ecephys import ElectrodeGroup
from hdmf.common import DynamicTable, DynamicTableRegion
# Names exported by a star-import of this module.
# NOTE(review): FrequencyBandsTable is defined and registered in this module but is not
# listed here -- confirm whether the omission is intentional.
__all__ = [
'AnnotationSeries',
'AbstractFeatureSeries',
'IntervalSeries',
'Units',
'DecompositionSeries'
]
@register_class('AnnotationSeries', CORE_NAMESPACE)
class AnnotationSeries(TimeSeries):
    """DEPRECATED. Text-based records about the experiment, stored over time.

    AnnotationSeries is deprecated. Use an EventsTable with an 'annotation' column instead.
    See :py:class:`~pynwb.event.EventsTable`.

    Records can be appended one at a time with add_annotation(). If all annotations are already
    available as a list or numpy array, pass the data and timestamps to the constructor instead.
    """

    __nwbfields__ = ()

    @docval(*get_docval(TimeSeries.__init__, 'name'),  # required
            {'name': 'data', 'type': ('array_data', 'data', TimeSeries), 'shape': (None,),
             'doc': 'The annotations over time. Must be 1D.',
             'default': list()},
            *get_docval(TimeSeries.__init__, 'timestamps', 'comments', 'description'),
            allow_positional=AllowPositional.WARNING,)
    def __init__(self, **kwargs):
        # The unit and resolution are fixed for annotations; every other argument
        # (name, data, timestamps, comments, description) is forwarded unchanged.
        super().__init__(unit='n/a', resolution=-1.0, **kwargs)
        deprecation_msg = (
            "AnnotationSeries is deprecated. Use an EventsTable with an 'annotation' column instead. "
            "Creating a new AnnotationSeries will not be allowed in a future version of PyNWB."
        )
        self._warn_on_new_pass_on_construct(deprecation_msg)

    @docval({'name': 'time', 'type': float, 'doc': 'The time for the annotation'},
            {'name': 'annotation', 'type': str, 'doc': 'the annotation'})
    def add_annotation(self, **kwargs):
        """Append one annotation and its time to the stored data and timestamps."""
        new_values = getargs('time', 'annotation', kwargs)
        # The time goes to 'timestamps' first, then the text goes to 'data'.
        for field_name, value in zip(('timestamps', 'data'), new_values):
            self.fields[field_name].append(value)
@register_class('AbstractFeatureSeries', CORE_NAMESPACE)
class AbstractFeatureSeries(TimeSeries):
    """
    Represents the salient features of a data stream. Typically this
    is used for things like a visual grating stimulus, where the bulk
    of the data (each frame sent to the graphics card) is large and not
    of high value, while the salient characteristics (e.g. orientation,
    spatial frequency, contrast, etc.) are what is important and what
    is used for analysis.
    """

    __nwbfields__ = ('feature_units',
                     'features')

    @docval(*get_docval(TimeSeries.__init__, 'name'),  # required
            {'name': 'feature_units', 'type': Iterable, 'shape': (None, ),  # required
             'doc': 'The unit of each feature'},
            {'name': 'features', 'type': Iterable, 'shape': (None, ),  # required
             'doc': 'Description of each feature'},
            {'name': 'data', 'type': ('array_data', 'data', TimeSeries), 'shape': ((None,), (None, None)),
             'doc': ('The data values. May be 1D or 2D. The first dimension must be time. The optional second '
                     'dimension represents features'),
             'default': list()},
            *get_docval(TimeSeries.__init__, 'resolution', 'conversion', 'timestamps', 'starting_time', 'rate',
                        'comments', 'description', 'control', 'control_description', 'offset'),
            allow_positional=AllowPositional.WARNING,)
    def __init__(self, **kwargs):
        # Only the feature-specific arguments are removed here; the remaining
        # arguments (including name and data) are forwarded to TimeSeries.
        feature_descriptions = popargs('features', kwargs)
        units_per_feature = popargs('feature_units', kwargs)
        super().__init__(unit="see 'feature_units'", **kwargs)
        self.features = feature_descriptions
        self.feature_units = units_per_feature

    @docval({'name': 'time', 'type': float, 'doc': 'the time point of this feature'},
            {'name': 'features', 'type': (list, np.ndarray), 'doc': 'the feature values for this time point'})
    def add_features(self, **kwargs):
        """Append the feature values for one time point.

        Raises a ValueError unless both timestamps and data are Python lists.
        """
        time, features = getargs('time', 'features', kwargs)
        if not isinstance(self.timestamps, list) or not isinstance(self.data, list):
            raise ValueError('Can only add feature if timestamps and data are lists')
        self.timestamps.append(time)
        self.data.append(features)
@register_class('IntervalSeries', CORE_NAMESPACE)
class IntervalSeries(TimeSeries):
    """
    Stores intervals of data. The timestamps field holds the beginning and end of each interval, and
    the data field holds whether an interval just started (>0 value) or ended (<0 value). Several
    interval types can share one series by using multiple key values (e.g. 1 for feature A, 2 for
    feature B, 3 for feature C, etc.). The data field stores an 8-bit integer. This is largely an
    alias of a standard TimeSeries, but one that is identifiable in a machine-readable way as
    representing time intervals.
    """

    __nwbfields__ = ()

    @docval(*get_docval(TimeSeries.__init__, 'name'),  # required
            {'name': 'data', 'type': ('array_data', 'data', TimeSeries), 'shape': (None,),
             'doc': ('The data values. Must be 1D, where the first dimension must be time. Values are >0 if '
                     'interval started, <0 if interval ended.'),
             'default': list()},
            *get_docval(TimeSeries.__init__, 'timestamps', 'comments', 'description', 'control',
                        'control_description'),
            allow_positional=AllowPositional.WARNING,)
    def __init__(self, **kwargs):
        # NOTE(review): the private references are assigned before super().__init__(),
        # presumably because the base initializer may read the overridden
        # `data`/`timestamps` properties -- confirm before reordering.
        self.__interval_timestamps = kwargs['timestamps']
        self.__interval_data = kwargs['data']
        super().__init__(unit='n/a', resolution=-1.0, **kwargs)

    @docval({'name': 'start', 'type': float, 'doc': 'The start time of the interval'},
            {'name': 'stop', 'type': float, 'doc': 'The stop time of the interval'})
    def add_interval(self, **kwargs):
        """Append one interval: its start and stop times, marked with 1 and -1 in the data."""
        start, stop = getargs('start', 'stop', kwargs)
        for boundary_time in (start, stop):
            self.__interval_timestamps.append(boundary_time)
        for marker in (1, -1):
            self.__interval_data.append(marker)

    @property
    def data(self):
        """The interval markers held by this series."""
        return self.__interval_data

    @property
    def timestamps(self):
        """The interval boundary times held by this series."""
        return self.__interval_timestamps
@register_class('Units', CORE_NAMESPACE)
class Units(DynamicTable):
    """
    Event times of observed units (e.g. cell, synapse, etc.).
    """

    __fields__ = (
        'waveform_rate',
        'waveform_unit',
        'resolution'
    )

    waveforms_desc = ('Individual waveforms for each spike. If the dataset is three-dimensional, the third dimension '
                      'shows the response from different electrodes that all observe this unit simultaneously. In this'
                      ' case, the `electrodes` column of this Units table should be used to indicate which electrodes '
                      'are associated with this unit, and the electrodes dimension here should be in the same order as'
                      ' the electrodes referenced in the `electrodes` column of this table.')

    __columns__ = (
        {'name': 'spike_times', 'description': 'the spike times for each unit in seconds', 'index': True},
        {'name': 'obs_intervals', 'description': 'the observation intervals for each unit',
         'index': True},
        {'name': 'electrodes', 'description': 'the electrodes that each spike unit came from',
         'index': True, 'table': True},
        {'name': 'electrode_group', 'description': 'the electrode group that each spike unit came from'},
        {'name': 'waveform_mean', 'description': 'the spike waveform mean for each spike unit'},
        {'name': 'waveform_sd', 'description': 'the spike waveform standard deviation for each spike unit'},
        {'name': 'waveforms', 'description': waveforms_desc, 'index': 2}
    )

    @docval({'name': 'name', 'type': str, 'doc': 'Name of this Units interface', 'default': 'Units'},
            *get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames', 'target_tables', 'meanings_tables'),
            {'name': 'description', 'type': str, 'doc': 'a description of what is in this table', 'default': None},
            {'name': 'electrode_table', 'type': DynamicTable,
             'doc': 'the table that the *electrodes* column indexes', 'default': None},
            {'name': 'waveform_rate', 'type': float,
             'doc': 'Sampling rate of the waveform means', 'default': None},
            {'name': 'waveform_unit', 'type': str,
             'doc': 'Unit of measurement of the waveform means', 'default': 'volts'},
            {'name': 'resolution', 'type': float,
             'doc': 'The smallest possible difference between two spike times', 'default': None},
            allow_positional=AllowPositional.WARNING,
            )
    def __init__(self, **kwargs):
        # These three are attributes of Units itself and are set after the table is built.
        args_to_set = popargs_to_dict(("waveform_rate", "waveform_unit", "resolution"), kwargs)
        electrode_table = popargs("electrode_table", kwargs)
        if kwargs['description'] is None:
            kwargs['description'] = "data on spiking units"
        super().__init__(**kwargs)
        for key, val in args_to_set.items():
            setattr(self, key, val)
        if 'spike_times' not in self.colnames:
            self.__has_spike_times = False
        # Used by add_unit() to resolve the target of the 'electrodes' column.
        self.__electrode_table = electrode_table

    @docval({'name': 'spike_times', 'type': 'array_data', 'doc': 'the spike times for each unit in seconds',
             'default': None, 'shape': (None,)},
            {'name': 'obs_intervals', 'type': 'array_data',
             'doc': 'the observation intervals (valid times) for each unit. All spike_times for a given unit ' +
                    'should fall within these intervals. [[start1, end1], [start2, end2], ...]',
             'default': None, 'shape': (None, 2)},
            {'name': 'electrodes', 'type': 'array_data', 'doc': 'the electrodes that each unit came from',
             'default': None},
            {'name': 'electrode_group', 'type': ElectrodeGroup, 'default': None,
             'doc': 'the electrode group that each unit came from'},
            {'name': 'waveform_mean', 'type': 'array_data',
             'doc': 'the spike waveform mean for each unit. Shape is (time,) or (time, electrodes)',
             'default': None},
            {'name': 'waveform_sd', 'type': 'array_data', 'default': None,
             'doc': 'the spike waveform standard deviation for each unit. Shape is (time,) or (time, electrodes)'},
            {'name': 'waveforms', 'type': 'array_data', 'default': None, 'doc': waveforms_desc,
             'shape': ((None, None), (None, None, None))},
            {'name': 'id', 'type': int, 'default': None, 'doc': 'the id for each unit'},
            allow_extra=True,)
    def add_unit(self, **kwargs):
        """
        Add a unit to this table.

        After the row is added, the target table of the 'electrodes' column is resolved if it
        is still unset: the electrode table given at construction is used when available,
        otherwise the electrodes table of the enclosing NWBFile. A warning is issued when
        neither exists yet.
        """
        super().add_row(**kwargs)
        if 'electrodes' not in self:
            return
        elec_col = self['electrodes'].target
        if elec_col.table is not None:
            return
        if self.__electrode_table is not None:
            elec_col.table = self.__electrode_table
            return
        # No NWBFile ancestor is found when this table has not been added to a file
        # yet; treat that like a file without an electrodes table instead of failing
        # with an AttributeError after the row was already added.
        nwbfile = self.get_ancestor(data_type='NWBFile')
        resolved_table = None if nwbfile is None else nwbfile.electrodes
        if resolved_table is None:
            warnings.warn('Reference to electrode table that does not yet exist')
        else:
            elec_col.table = resolved_table

    @docval({'name': 'index', 'type': (int, list, tuple, np.ndarray),
             'doc': 'the index of the unit in unit_ids to retrieve spike times for'},
            {'name': 'in_interval', 'type': (tuple, list), 'doc': 'only return values within this interval',
             'default': None, 'shape': (2,)})
    def get_unit_spike_times(self, **kwargs):
        """
        Get the spike times of one or more units, optionally restricted to a time interval.

        A list or tuple of indices yields a list with one array per unit. With ``in_interval``
        set, a numpy array of indices is handled the same way. A single index yields one array.

        Raises
        ------
        IndexError
            If ``in_interval`` is given and ``index`` is out of range for this table.
        """
        index, in_interval = getargs('index', 'in_interval', kwargs)
        if isinstance(index, (list, tuple)):
            return [self.get_unit_spike_times(i, in_interval=in_interval) for i in index]
        if in_interval is None:
            return np.asarray(self['spike_times'][index])
        if isinstance(index, np.ndarray) and index.ndim > 0:
            # The scalar arithmetic below is ambiguous for arrays; resolve per unit.
            return [self.get_unit_spike_times(i, in_interval=in_interval) for i in index]

        st = self['spike_times']
        index = int(index)
        n_units = len(st.data)
        # Normalize negative indices so that the "previous unit" lookup below cannot
        # run off the front of the index data (e.g. index == -n_units is unit 0).
        if index < 0:
            index += n_units
        if not 0 <= index < n_units:
            raise IndexError('unit index out of range for %d units' % n_units)
        # st.data holds cumulative end offsets into st.target, one per unit.
        unit_start = 0 if index == 0 else st.data[index - 1]
        unit_stop = st.data[index]
        start_time, stop_time = in_interval
        # Binary search within this unit's slice of the spike times.
        ind_start = bisect_left(st.target, start_time, unit_start, unit_stop)
        ind_stop = bisect_right(st.target, stop_time, ind_start, unit_stop)
        return np.asarray(st.target[ind_start:ind_stop])

    @docval({'name': 'index', 'type': int,
             'doc': 'the index of the unit in unit_ids to retrieve observation intervals for'})
    def get_unit_obs_intervals(self, **kwargs):
        """Get the observation intervals of the unit at ``index`` as a numpy array."""
        index = getargs('index', kwargs)
        return np.asarray(self['obs_intervals'][index])

    def get_starting_time(self):
        """
        Get the earliest spike time across all units in this Units table.

        Returns
        -------
        float or None
            The earliest spike time in seconds, or None if the table is empty,
            has no spike_times column, or has no spike data.

        Notes
        -----
        This method checks the first spike of every unit because units are not
        assumed to be in chronological order (the earliest spike may be in any unit).

        Edge cases:

        - Returns None if the table is empty (no units)
        - Returns None if the spike_times column does not exist
        - Returns None if all units have empty spike_times arrays
        """
        if len(self) == 0:
            return None
        if 'spike_times' not in self:
            return None
        spike_times_col = self['spike_times']
        # Cumulative end positions. Cast to a signed integer type: concatenating the
        # Python int 0 with uint64 offsets would otherwise promote to float64, which
        # cannot be used for indexing.
        indices = np.asarray(spike_times_col.data[:]).astype(np.int64)
        # First spike indices: unit 0 starts at 0, unit i starts at indices[i-1]
        first_spike_indices = np.concatenate([[0], indices[:-1]])
        # Filter out empty units where start index == end index (no spikes)
        has_spikes = first_spike_indices != indices
        first_spike_indices = first_spike_indices[has_spikes]
        if len(first_spike_indices) == 0:
            return None
        spike_times_data = spike_times_col.target.data
        # In-memory data might be stored as a list which doesn't support numpy operations below
        if isinstance(spike_times_data, list):
            spike_times_data = np.array(spike_times_data)
        if len(spike_times_data) == 0:
            return None
        first_spike_times = spike_times_data[first_spike_indices]
        return float(np.min(first_spike_times))

    def get_duration(self):
        """
        Get the duration from the earliest to the latest spike time across all units.

        Returns
        -------
        float or None
            The duration in seconds, or None if the table is empty, has no
            spike_times column, or has no spike data.

        Notes
        -----
        The duration represents the time span from the earliest spike to the latest
        spike across all units, not the sum of individual unit recording durations.
        This method checks the first and last spike of every unit because units are
        not assumed to be in chronological order (the earliest or latest spike may
        be in any unit).

        Edge cases:

        - Returns None if the table is empty (no units)
        - Returns None if the spike_times column does not exist
        - Returns None if all units have empty spike_times arrays
        - Returns 0.0 if there is only one spike across all units
        """
        if len(self) == 0:
            return None
        if 'spike_times' not in self:
            return None
        spike_times_col = self['spike_times']
        # Cumulative end positions. Cast to a signed integer type: concatenating the
        # Python int 0 with uint64 offsets would otherwise promote to float64, which
        # cannot be used for indexing.
        indices = np.asarray(spike_times_col.data[:]).astype(np.int64)
        # First spike indices: unit 0 starts at 0, unit i starts at indices[i-1]
        first_spike_indices = np.concatenate([[0], indices[:-1]])
        # Last spike indices: unit i ends at indices[i], so last spike is at indices[i] - 1
        last_spike_indices = indices - 1
        # Filter out empty units where start index == end index (no spikes)
        has_spikes = first_spike_indices != indices
        first_spike_indices = first_spike_indices[has_spikes]
        last_spike_indices = last_spike_indices[has_spikes]
        # Same guard as get_starting_time(): with no spikes there is nothing to measure.
        if len(first_spike_indices) == 0:
            return None
        # Combine and deduplicate for efficient reading
        all_indices = np.unique(np.concatenate([first_spike_indices, last_spike_indices]))
        spike_times_data = spike_times_col.target.data
        # In-memory data might be stored as a list which doesn't support numpy operations below
        if isinstance(spike_times_data, list):
            spike_times_data = np.array(spike_times_data)
        if len(spike_times_data) == 0:
            return None
        boundary_spike_times = spike_times_data[all_indices]
        first_spike_time = float(np.min(boundary_spike_times))
        last_spike_time = float(np.max(boundary_spike_times))
        return last_spike_time - first_spike_time
@register_class('FrequencyBandsTable', CORE_NAMESPACE)
class FrequencyBandsTable(DynamicTable):
    """
    Table for describing the bands that DecompositionSeries was generated from.
    """

    __columns__ = (
        {'name': 'band_name', 'description': 'Name of the band, e.g. theta.', 'required': True},
        {'name': 'band_limits', 'description': 'Low and high limit of each band in Hz.', 'required': True},
        {'name': 'band_mean', 'description': 'The mean Gaussian filters, in Hz.', 'required': False},
        {'name': 'band_stdev', 'description': 'The standard deviation Gaussian filters, in Hz.', 'required': False}
    )

    @docval(*get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames', 'target_tables', 'meanings_tables'),
            allow_positional=AllowPositional.WARNING,)
    def __init__(self, **kwargs):
        # The name and description are fixed for this table and are not accepted as arguments.
        kwargs.update(
            name='bands',
            description='Table for describing the bands that DecompositionSeries was generated from.',
        )
        super().__init__(**kwargs)

    @docval(
        {'name': 'band_name', 'type': str, 'doc': 'Name of the band, e.g. theta.'},
        {'name': 'band_limits', 'type': ('array_data', 'data'), 'shape': (2, ),
         'doc': 'Low and high limit of each band in Hz.'},
        {'name': 'band_mean', 'type': float, 'doc': 'The mean Gaussian filters, in Hz.',
         'default': None},
        {'name': 'band_stdev', 'type': float, 'doc': 'The standard deviation Gaussian filters, in Hz.',
         'default': None},
        allow_extra=True
    )
    def add_band(self, **kwargs):
        """Add one frequency band as a new row of this table."""
        super().add_row(**kwargs)
@register_class('DecompositionSeries', CORE_NAMESPACE)
class DecompositionSeries(TimeSeries):
    """
    Stores the product of a spectral analysis.
    """

    __nwbfields__ = ('metric',
                     {'name': 'source_timeseries', 'child': False, 'doc': 'the input TimeSeries from this analysis'},
                     {'name': 'source_channels', 'child': True, 'doc': 'the channels that provided the source data'},
                     {'name': 'bands',
                      'doc': 'the bands that the signal is decomposed into', 'child': True})

    # value used when a DecompositionSeries is read and missing data
    DEFAULT_DATA = np.ndarray(shape=(0, 0, 0), dtype=np.uint8)

    @docval(*get_docval(TimeSeries.__init__, 'name'),  # required
            {'name': 'data', 'type': ('array_data', 'data', TimeSeries),  # required
             'doc': ('The data values. Must be 3D, where the first dimension must be time, the second dimension must '
                     'be channels, and the third dimension must be bands.'),
             'shape': (None, None, None)},
            *get_docval(TimeSeries.__init__, 'description'),
            {'name': 'metric', 'type': str,  # required
             'doc': "metric of analysis. recommended - 'phase', 'amplitude', 'power'"},
            {'name': 'unit', 'type': str, 'doc': 'SI unit of measurement', 'default': 'no unit'},
            {'name': 'bands', 'type': FrequencyBandsTable,
             'doc': 'a table for describing the frequency bands that the signal was decomposed into', 'default': None},
            {'name': 'source_timeseries', 'type': TimeSeries,
             'doc': 'the input TimeSeries from this analysis', 'default': None},
            {'name': 'source_channels', 'type': DynamicTableRegion,
             'doc': ('The channels that provided the source data. In the case of electrical recordings this is '
                     'typically a DynamicTableRegion pointing to the electrodes table at NWBFile.electrodes, '
                     'similar to ElectricalSeries.electrodes.'),
             'default': None},
            *get_docval(TimeSeries.__init__, 'resolution', 'conversion', 'timestamps', 'starting_time', 'rate',
                        'comments', 'control', 'control_description', 'offset'),
            allow_positional=AllowPositional.WARNING,)
    def __init__(self, **kwargs):
        # Remove the arguments specific to this class; the rest go to TimeSeries.
        metric = popargs('metric', kwargs)
        source_timeseries = popargs('source_timeseries', kwargs)
        bands = popargs('bands', kwargs)
        source_channels = popargs('source_channels', kwargs)
        super().__init__(**kwargs)

        self.source_timeseries = source_timeseries
        self.source_channels = source_channels
        no_source_given = self.source_timeseries is None and self.source_channels is None
        if no_source_given:
            warnings.warn("Neither source_timeseries nor source_channels is present in DecompositionSeries. It is "
                          "recommended to indicate the source timeseries if it is present, or else to link to the "
                          "corresponding source_channels. (Optional)")
        self.metric = metric
        # Start from an empty bands table when none was supplied.
        self.bands = FrequencyBandsTable() if bands is None else bands

    @docval(
        {'name': 'band_name', 'type': str, 'doc': 'the name of the frequency band'},
        {'name': 'band_limits', 'type': ('array_data', 'data'),
         'doc': 'low and high frequencies of bandpass filter in Hz'},
        {'name': 'band_mean', 'type': float, 'doc': 'the mean of Gaussian filters in Hz',
         'default': None},
        {'name': 'band_stdev', 'type': float, 'doc': 'the standard deviation of Gaussian filters in Hz',
         'default': None},
        allow_extra=True
    )
    def add_band(self, **kwargs):
        """Add a frequency band to the bands table of this DecompositionSeries."""
        self.bands.add_band(**kwargs)