Bootstrap particle filter for Python¶
Welcome to the pypfilt documentation. This package implements a bootstrap particle filter that can be used for recursive Bayesian estimation and forecasting.
If there is a system or process that can be:
- Described (modelled) with mathematical equations; and
- Measured repeatedly in some (noisy) way,
then you can use pypfilt to estimate the state (and parameters) of this system.
The Getting Started guide shows how to estimate the size of prey and predator species populations, and how to generate forecasts that predict the future sizes of these populations.

Example forecasts of the prey population \(x(t)\) and the predator population \(y(t)\), generated at different times and using noisy observations of both populations.
Getting Started¶
This guide assumes that you have already installed the pypfilt package.
Lotka-Volterra (predator-prey) equations¶
Here we will show how to generate forecasts for the (continuous)
Lotka-Volterra equations, which describe the dynamics of biological systems
in which two species interact (one predator, one prey).
The source code is provided in pypfilt.examples.predation.
class pypfilt.examples.predation.LotkaVolterra¶
An implementation of the (continuous) Lotka-Volterra equations.
\[\begin{split}\frac{dx}{dt} &= \alpha x - \beta xy \\ \frac{dy}{dt} &= \delta xy - \gamma y\end{split}\]
Symbols and their meanings:
- \(x(t)\) – The size of the prey population (1,000s).
- \(y(t)\) – The size of the predator population (1,000s).
- \(\alpha\) – Exponential growth rate in the absence of predators.
- \(\beta\) – The rate at which prey suffer from predation.
- \(\delta\) – The predator growth rate, driven by predation.
- \(\gamma\) – Exponential decay rate of the predator population.
All of the state variables and parameters are stored in the particle:
\[\mathbf{x_t} = [x, y, \alpha, \beta, \delta, \gamma]^T\]
This class also provides a method for generating noisy observations from a known ground truth:
obs(sdev, x0, y0, alpha, beta, gamma, delta, t_max, seed=42)¶
Parameters:
- sdev – The standard deviation of the observation error.
- x0 – The initial size of the prey population.
- y0 – The initial size of the predator population.
- alpha – The true value of the model parameter alpha.
- beta – The true value of the model parameter beta.
- gamma – The true value of the model parameter gamma.
- delta – The true value of the model parameter delta.
- t_max – The simulation duration.
- seed – The seed for the observation PRNG.
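To get a feel for these dynamics before involving the particle filter, the equations above can be integrated directly. The following is a minimal sketch in pure Python, assuming a fixed-step fourth-order Runge-Kutta scheme (the pypfilt example itself uses scipy.integrate.odeint). The names lv_rhs, lv_simulate and lv_invariant are hypothetical helpers and are not part of pypfilt; the parameter values are the ground truth used later in this guide.

```python
import math


def lv_rhs(x, y, alpha, beta, delta, gamma):
    """Return (dx/dt, dy/dt) for the Lotka-Volterra equations."""
    dx = alpha * x - beta * x * y
    dy = delta * x * y - gamma * y
    return dx, dy


def lv_simulate(x0, y0, alpha, beta, delta, gamma, t_max, dt=0.01):
    """Integrate with fixed-step RK4 and return a list of (t, x, y)."""
    def rhs(x, y):
        return lv_rhs(x, y, alpha, beta, delta, gamma)

    x, y = x0, y0
    out = [(0.0, x, y)]
    for i in range(int(round(t_max / dt))):
        k1x, k1y = rhs(x, y)
        k2x, k2y = rhs(x + 0.5 * dt * k1x, y + 0.5 * dt * k1y)
        k3x, k3y = rhs(x + 0.5 * dt * k2x, y + 0.5 * dt * k2y)
        k4x, k4y = rhs(x + dt * k3x, y + dt * k3y)
        x, y = (x + dt * (k1x + 2 * k2x + 2 * k3x + k4x) / 6,
                y + dt * (k1y + 2 * k2y + 2 * k3y + k4y) / 6)
        out.append(((i + 1) * dt, x, y))
    return out


def lv_invariant(x, y, alpha, beta, delta, gamma):
    """A quantity that exact solutions of these equations conserve."""
    return delta * x - gamma * math.log(x) + beta * y - alpha * math.log(y)


# Ground truth used in this guide.
truth = dict(alpha=2 / 3, beta=4 / 3, delta=1, gamma=1)
trajectory = lv_simulate(0.9, 0.25, t_max=15, **truth)
t_end, x_end, y_end = trajectory[-1]
print('t = {:.1f}: x = {:.3f}, y = {:.3f}'.format(t_end, x_end, y_end))
```

Because exact solutions keep lv_invariant constant, comparing its value at the start and end of the trajectory is a cheap sanity check on the integration step size.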
Example outputs¶
The rest of this “Getting Started” guide will demonstrate how to generate forecasts and produce plots like those shown below.

Forecasts produced by the LotkaVolterra model, using noisy observations generated by this same model (LotkaVolterra.obs()) and a known ground truth.

The posterior parameter distributions for the LotkaVolterra model, using the noisy observations shown in the forecasts above.
Running the forecasts¶
Model estimations and subsequent forecasts are generated by pypfilt.forecast(), which takes the following arguments:
- A parameter dictionary, including the system model and the observation model;
- The start and end of the simulation period;
- Any number of observation streams;
- The times at which forecasts should be generated;
- A summary object to calculate relevant statistics; and
- The output file.
One critical parameter is the number of particles to use (px_count).
With too few particles, it is highly likely that none of the particles will be
in good agreement with the observations (“particle degeneracy”).
With too many particles, the computational cost will be very high and the
forecasts will take a long time to complete.
In the example below, 1,000 particles are used (see the call to make_params()).
import logging
import pypfilt
import pypfilt.summary

def forecast(data_file):
    """Run a suite of forecasts against generated observations."""
    logger = logging.getLogger(__name__)
    logger.info('Preparing the forecast simulations')
    # Define the simulation period and forecasting times.
    t0 = 0.0
    t1 = 15.0
    fs_times = [1.0, 3.0, 5.0, 7.0, 9.0]
    params = make_params(px_count=1000, seed=42, obs_sdev=0.2)
    # Generate noisy observations.
    obs = params['model'].obs(params['obs']['sdev'], x0=0.9, y0=0.25,
                              alpha=2/3, beta=4/3, gamma=1, delta=1, t_max=t1)
    # Define the summary tables to be saved to disk.
    summary = pypfilt.summary.HDF5(params, obs, first_day=True)
    summary.add_tables(
        pypfilt.summary.ModelCIs(probs=[0, 50, 95]),
        pypfilt.summary.Obs())
    # Run the forecast simulations.
    pypfilt.forecast(params, t0, t1, [obs], fs_times, summary, data_file)
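The trade-off in the number of particles can be made concrete with the effective sample size, a common diagnostic for particle degeneracy. The following is a minimal sketch in pure Python; the helper names are hypothetical, and nothing here is claimed to be how pypfilt itself measures degeneracy.

```python
import math


def normalise_log_weights(log_weights):
    """Convert unnormalised log-weights into weights that sum to one."""
    # Subtract the maximum before exponentiating, to avoid underflow.
    max_lw = max(log_weights)
    weights = [math.exp(lw - max_lw) for lw in log_weights]
    total = sum(weights)
    return [w / total for w in weights]


def effective_sample_size(weights):
    """Return 1 / sum(w_i ** 2) for normalised weights."""
    return 1.0 / sum(w * w for w in weights)


# Equal weights: every particle contributes.
ess_equal = effective_sample_size([0.25, 0.25, 0.25, 0.25])
# One particle agrees far better with the data than the others, so the
# ensemble behaves as if it contained a single particle.
ess_degenerate = effective_sample_size(
    normalise_log_weights([0.0, -50.0, -50.0, -50.0]))
print(ess_equal, ess_degenerate)
```

An effective sample size close to the particle count suggests the ensemble is healthy; a value close to one suggests that more particles (or broader priors) may be needed.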
When generating forecasts on a regular basis (e.g., daily or weekly, in response to new or updated observations) the particle states can be saved to disk to greatly improve the speed with which the forecasts are generated. This is enabled by defining a cache file:
# If the location is not an absolute path, it is defined
# relative to the output directory, params['out_dir'].
params['hist']['cache_file'] = 'cache.hdf5'
Plotting the results¶
Plotting the forecast results is a two-step process; first, the results must be read from the output file and massaged into the appropriate form, then the plots themselves must be constructed. The first step is illustrated here:
import logging
import h5py
import numpy as np
import pypfilt
import pypfilt.summary

def plot(data_file, png=True, pdf=True):
    """Load the forecast outputs and plot them."""
    logger = logging.getLogger(__name__)
    logger.info('Loading outputs from {}'.format(data_file))
    # Use the 'Agg' backend so that plots can be generated non-interactively.
    import matplotlib
    matplotlib.use('Agg')
    # File names for the generated plots.
    fs_pdf = 'predation_forecasts.pdf'
    fs_png = 'predation_forecasts.png'
    pp_pdf = 'predation_params.pdf'
    pp_png = 'predation_params.png'
    # Read in the model credible intervals and the observations.
    # Open the file read-only, since nothing is written back to it.
    with h5py.File(data_file, 'r') as f:
        cints = f['/data/model_cints'][()]
        obs = f['/data/obs'][()]
    # Convert serialised values into more convenient data types.
    convs = pypfilt.summary.default_converters(pypfilt.Scalar())
    cints = pypfilt.summary.convert_cols(cints, convs)
    obs = pypfilt.summary.convert_cols(obs, convs)
    # Separate the observations of the two populations.
    x_obs = obs[obs['unit'] == b'x']
    y_obs = obs[obs['unit'] == b'y']
    # Separate the credible intervals for the population sizes from the
    # credible intervals for the model parameters.
    var_mask = np.logical_or(cints['name'] == b'x',
                             cints['name'] == b'y')
    state_cints = cints[var_mask]
    param_cints = cints[np.logical_not(var_mask)]
    # Only keep the population sizes from each forecast.
    fs_mask = state_cints['fs_date'] < max(state_cints['date'])
    state_cints = state_cints[fs_mask]
    # Only keep the model parameter posteriors from the estimation run.
    est_mask = param_cints['fs_date'] == max(param_cints['date'])
    param_cints = param_cints[est_mask]
    # Plot the population forecasts.
    pdf_file = fs_pdf if pdf else None
    png_file = fs_png if png else None
    plot_forecasts(state_cints, x_obs, y_obs, pdf_file, png_file)
    # Plot the model parameter posterior distributions.
    pdf_file = pp_pdf if pdf else None
    png_file = pp_png if png else None
    plot_params(param_cints, pdf_file, png_file)
The pypfilt.plot module provides functions for plotting observations
and credible intervals, and classes for constructing figures with sub-plots.
These are highlighted in the following two functions, which were used to
produce the figures shown at the top of this guide.
def plot_forecasts(state_cints, x_obs, y_obs, pdf_file=None, png_file=None):
    """Plot the population predictions at each forecasting date."""
    logger = logging.getLogger(__name__)
    with pypfilt.plot.apply_style():
        plot = pypfilt.plot.Grid(
            state_cints, 'Time', 'Population Size (1,000s)',
            ('fs_date', 'Forecast @ t = {:0.0f}'),
            ('name', lambda bs: '{}(t)'.format(pypfilt.text.to_unicode(bs))))
        plot.expand_x_lims('date')
        plot.expand_y_lims('ymax')
        for (ax, df) in plot.subplots():
            ax.axhline(y=0, xmin=0, xmax=1,
                       linewidth=1, linestyle='--', color='k')
            hs = pypfilt.plot.cred_ints(ax, df, 'date', 'prob')
            if df['name'][0] == b'x':
                df_obs = x_obs
            else:
                df_obs = y_obs
            past_obs = df_obs[df_obs['date'] <= df['fs_date'][0]]
            future_obs = df_obs[df_obs['date'] > df['fs_date'][0]]
            hs.extend(pypfilt.plot.observations(ax, past_obs,
                                                label='Past observations'))
            hs.extend(pypfilt.plot.observations(ax, future_obs,
                                                future=True,
                                                label='Future observations'))
            plot.add_to_legend(hs)
            # Adjust the axis limits and the number of ticks.
            ax.set_xlim(left=0)
            ax.locator_params(axis='x', nbins=4)
            ax.set_ylim(bottom=-0.2)
            ax.locator_params(axis='y', nbins=4)
        plot.legend(loc='upper center', ncol=5)
        if pdf_file:
            logger.info('Plotting to {}'.format(pdf_file))
            plot.save(pdf_file, format='pdf', width=10, height=5)
        if png_file:
            logger.info('Plotting to {}'.format(png_file))
            plot.save(png_file, format='png', width=10, height=5)

def plot_params(param_cints, pdf_file=None, png_file=None):
    """Plot the parameter posteriors over the estimation run."""
    logger = logging.getLogger(__name__)
    with pypfilt.plot.apply_style():
        plot = pypfilt.plot.Wrap(
            param_cints, 'Time', 'Value',
            ('name', lambda bs: '$\\{}$'.format(pypfilt.text.to_unicode(bs))),
            nr=1)
        plot.expand_y_lims('ymax')
        for (ax, df) in plot.subplots():
            hs = pypfilt.plot.cred_ints(ax, df, 'date', 'prob')
            if df['name'][0] == b'alpha':
                y_true = 2/3
            elif df['name'][0] == b'beta':
                y_true = 4/3
            elif df['name'][0] == b'gamma':
                y_true = 1
            elif df['name'][0] == b'delta':
                y_true = 1
            hs.append(ax.axhline(y=y_true, xmin=0, xmax=1, label='True value',
                                 linewidth=1, linestyle='--', color='k'))
            plot.add_to_legend(hs)
        plot.legend(loc='upper center', ncol=5)
        if pdf_file:
            logger.info('Plotting to {}'.format(pdf_file))
            plot.save(pdf_file, format='pdf', width=10, height=3)
        if png_file:
            logger.info('Plotting to {}'.format(png_file))
            plot.save(png_file, format='png', width=10, height=3)
Observations¶
Observations are represented as dictionaries that have the following keys:
{'date': ..., # When the observation was made (number, date, etc)
'value': 200, # The numerical quantity that was measured
'unit': 'Some measure', # A description of the measurement units
'period': 7, # The observation period, in days
'source': 'Some system', # A description of the data source
}
An observation stream is represented as a chronologically sorted list of
observations (oldest first).
The particle filter accepts any number of observation streams, which must be
provided as a list (i.e., a list of observation lists); see forecast() and run().
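As an illustration of this structure, the sketch below builds two observation streams and combines them into the list-of-lists form described above. The helper make_stream is hypothetical (it is not part of pypfilt) and the observation values are arbitrary numbers chosen only to show the layout.

```python
def make_stream(unit, dated_values, source='example'):
    """Build one observation stream from (date, value) pairs."""
    stream = [
        {'date': date, 'value': value, 'unit': unit,
         'period': 1, 'source': source}
        for (date, value) in dated_values
    ]
    # Streams must be in chronological order, oldest first.
    stream.sort(key=lambda o: o['date'])
    return stream


# Arbitrary illustrative values, deliberately supplied out of order.
x_stream = make_stream('x', [(2, 0.71), (1, 0.95), (3, 0.52)])
y_stream = make_stream('y', [(1, 0.31), (2, 0.36), (3, 0.29)])
# The particle filter expects a list of streams, even if there is only one.
streams = [x_stream, y_stream]
```

Note that the predator-prey example in this guide takes a different but equally valid approach: it places the observations of both populations in a single stream and distinguishes them by their 'unit'.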
Observation models¶
For simplicity, we assume that both the prey and predator populations — \(x(t)\) and \(y(t)\) — are directly observed, and that the observation error is distributed normally with zero mean and a known standard deviation.
import numpy as np
import scipy.stats

def log_llhd(params, obs_list, curr, prev_dict, weights):
    """Calculate the observation log-likelihoods for each particle."""
    # The expected observations are x(t) and y(t).
    x_dist = scipy.stats.norm(loc=curr[..., 0], scale=params['obs']['sdev'])
    y_dist = scipy.stats.norm(loc=curr[..., 1], scale=params['obs']['sdev'])
    # Calculate the log-likelihood of each observation in turn.
    log_llhd = np.zeros(curr.shape[:-1])
    for o in obs_list:
        if o['unit'] == 'x':
            log_llhd += x_dist.logpdf(o['value'])
        elif o['unit'] == 'y':
            log_llhd += y_dist.logpdf(o['value'])
        else:
            raise ValueError('invalid observation')
    return log_llhd
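To make the calculation explicit, the sketch below evaluates the same normal log-likelihood for a single particle in pure Python, without SciPy. The function names and the observation values are hypothetical and serve only to illustrate the arithmetic.

```python
import math


def normal_logpdf(value, mean, sdev):
    """Log-density of a normal distribution, evaluated at `value`."""
    z = (value - mean) / sdev
    return -0.5 * z * z - math.log(sdev) - 0.5 * math.log(2 * math.pi)


def particle_log_llhd(x, y, obs_list, sdev):
    """Sum the log-likelihoods of all observations for one particle."""
    total = 0.0
    for o in obs_list:
        if o['unit'] == 'x':
            total += normal_logpdf(o['value'], x, sdev)
        elif o['unit'] == 'y':
            total += normal_logpdf(o['value'], y, sdev)
        else:
            raise ValueError('invalid observation')
    return total


# Arbitrary illustrative observations of both populations.
obs_list = [{'unit': 'x', 'value': 0.9}, {'unit': 'y', 'value': 0.3}]
near = particle_log_llhd(0.9, 0.25, obs_list, sdev=0.2)
far = particle_log_llhd(1.5, 0.25, obs_list, sdev=0.2)
print(near > far)
```

A particle whose state lies closer to the observations receives a larger log-likelihood, and therefore a larger weight when the particle filter updates the ensemble.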
The observation model must be stored in params['log_llhd_fn'].
Note that the argument prev_dict can be used to obtain the state vectors
at the beginning of an observation period.
This is useful for situations where the observation depends on the change
in the state vector over the observation period.
def log_llhd(params, obs_list, curr, prev_dict, weights):
    # Obtain the state vectors two time units ago.
    # This is only valid if an observation has a period of 2.
    prev_state = prev_dict[2]
    dx = curr[..., 0] - prev_state[..., 0]
    ...
Parameters¶
Particle filter parameters are provided by default_params().
At a minimum, the simulation parameters must define the model, the time scale,
and the observation model.
For reproducibility, it is also advisable to set the PRNG seed.
def make_params(px_count, seed, obs_sdev):
    """Define the default simulation parameters for this model."""
    model = LotkaVolterra()
    time_scale = pypfilt.Scalar()
    params = pypfilt.default_params(model, time_scale, px_count=px_count)
    # Use one time-step per unit time, odeint will interpolate as needed.
    params['steps_per_unit'] = 1
    params['log_llhd_fn'] = log_llhd
    params['obs'] = {'sdev': obs_sdev}
    # Set the PRNG seed.
    params['resample']['prng_seed'] = seed
    # Write output to the working directory.
    params['out_dir'] = '.'
    params['tmp_dir'] = '.'
    return params
System models¶
The model of the underlying system must inherit from pypfilt.Model.
Here is the predator-prey model from pypfilt.examples.predation:
import numpy as np
import scipy.integrate
import pypfilt

class LotkaVolterra(pypfilt.Model):
    """An implementation of the (continuous) Lotka-Volterra equations."""
    def init(self, params, vec):
        """Initialise a matrix of state vectors."""
        # Select x(0), y(0), and the parameters according to the priors.
        rnd = params['resample']['rnd']
        size = vec[..., 0].shape
        vec[..., 0] = params['prior']['x'](rnd, size)
        vec[..., 1] = params['prior']['y'](rnd, size)
        vec[..., 2] = params['prior']['alpha'](rnd, size)
        vec[..., 3] = params['prior']['beta'](rnd, size)
        vec[..., 4] = params['prior']['gamma'](rnd, size)
        vec[..., 5] = params['prior']['delta'](rnd, size)
    def state_size(self):
        """Return the size of the state vector."""
        return 6
    def priors(self, params):
        """Return a dictionary of model priors."""
        return {
            'x': lambda r, size=None: r.uniform(0.5, 1.5, size=size),
            'y': lambda r, size=None: r.uniform(0.2, 0.4, size=size),
            'alpha': lambda r, size=None: r.uniform(0.6, 0.8, size=size),
            'beta': lambda r, size=None: r.uniform(1.2, 1.4, size=size),
            'gamma': lambda r, size=None: r.uniform(0.9, 1.1, size=size),
            'delta': lambda r, size=None: r.uniform(0.9, 1.1, size=size),
        }
    def d_dt(self, xt, t):
        """Calculate the derivatives of x(t) and y(t)."""
        # Restore the 2D shape of the flattened state matrix.
        xt = xt.reshape((-1, 6))
        x, y = xt[..., 0], xt[..., 1]
        d_dt = np.zeros(xt.shape)
        # Calculate dx/dt and dy/dt.
        # Column 4 scales the x*y (predator growth) term and column 5 scales
        # the y (predator decay) term; the parameter derivatives remain zero.
        d_dt[..., 0] = xt[..., 2] * x - xt[..., 3] * x * y
        d_dt[..., 1] = xt[..., 4] * x * y - xt[..., 5] * y
        # Flatten the 2D derivatives matrix.
        return d_dt.reshape(-1)
    def update(self, params, t, dt, is_fs, prev, curr):
        """Perform a single time-step."""
        # Use scalar time, so that ``t + dt`` is well-defined.
        t = params['time'].to_scalar(t)
        # The state matrix must be flattened for odeint.
        xt = scipy.integrate.odeint(self.d_dt, prev.reshape(-1),
                                    [t, t + dt])[1]
        # Restore the 2D shape of the flattened state matrix.
        curr[:] = xt.reshape(curr.shape)
    def describe(self):
        """Describe each component of the state vector."""
        return [
            # Restrict x(t), y(t) to [0, 10^5], don't allow regularisation.
            ('x', False, 0, 1e5),
            ('y', False, 0, 1e5),
            # Restrict parameters to [0, 2], allow regularisation.
            ('alpha', True, 0, 2),
            ('beta', True, 0, 2),
            ('gamma', True, 0, 2),
            ('delta', True, 0, 2),
        ]
    def obs(self, sdev, x0, y0, alpha, beta, gamma, delta, t_max, seed=42):
        """Generate noisy observations from a known ground truth."""
        # Make the priors reflect the known ground truth.
        rnd = np.random.RandomState(seed)
        obs_params = {
            'resample': {
                'rnd': rnd,
            },
            'prior': {
                'x': lambda r, size=None: x0 * np.ones(size),
                'y': lambda r, size=None: y0 * np.ones(size),
                'alpha': lambda r, size=None: alpha * np.ones(size),
                'beta': lambda r, size=None: beta * np.ones(size),
                'gamma': lambda r, size=None: gamma * np.ones(size),
                'delta': lambda r, size=None: delta * np.ones(size),
            },
        }
        # Simulate a single particle.
        xt_init = np.zeros((1, self.state_size()))
        self.init(obs_params, xt_init)
        xt = scipy.integrate.odeint(self.d_dt, xt_init.reshape(-1),
                                    range(int(np.ceil(t_max + 1))))[1:]
        # Observe both populations once per time unit.
        obs = []
        for (ix, x) in enumerate(xt):
            obs.append({'date': ix + 1, 'period': 1, 'unit': 'x',
                        'value': rnd.normal(x[0], sdev),
                        'source': 'noisy_obs()'})
            obs.append({'date': ix + 1, 'period': 1, 'unit': 'y',
                        'value': rnd.normal(x[1], sdev),
                        'source': 'noisy_obs()'})
        return obs
Summary objects¶
Simulations typically comprise a large number of both particles and time steps, and so it is generally preferable to record statistics that summarise the particles rather than to store the entire state history of each simulation.
This functionality is provided by pypfilt.summary.HDF5, which allows
any number of summary tables to be recorded. Once all of the estimation and
forecasting simulations have been performed, save_forecasts() will save the
results to disk.
The example shown in Running the forecasts (and repeated below; the relevant
lines are those that create summary and add tables to it) demonstrates how to
record fixed-probability central credible intervals for the state variables
and model parameters with the ModelCIs table, and the observations with the
Obs table.
These are the same tables that were used to produce the plots shown in
Example outputs.
def forecast(data_file):
    """Run a suite of forecasts against generated observations."""
    logger = logging.getLogger(__name__)
    logger.info('Preparing the forecast simulations')
    # Define the simulation period and forecasting times.
    t0 = 0.0
    t1 = 15.0
    fs_times = [1.0, 3.0, 5.0, 7.0, 9.0]
    params = make_params(px_count=1000, seed=42, obs_sdev=0.2)
    # Generate noisy observations.
    obs = params['model'].obs(params['obs']['sdev'], x0=0.9, y0=0.25,
                              alpha=2/3, beta=4/3, gamma=1, delta=1, t_max=t1)
    # Define the summary tables to be saved to disk.
    summary = pypfilt.summary.HDF5(params, obs, first_day=True)
    summary.add_tables(
        pypfilt.summary.ModelCIs(probs=[0, 50, 95]),
        pypfilt.summary.Obs())
    # Run the forecast simulations.
    pypfilt.forecast(params, t0, t1, [obs], fs_times, summary, data_file)
Installation¶
The requirements for pypfilt are:
- NumPy 1.8 or newer;
- SciPy 0.11 or newer;
- h5py 2.2 or newer; and
- matplotlib 1.5 or newer (optional, see Plotting).
Installing required packages¶
Recommended method¶
The simplest way to install these packages (particularly on Windows) is to
use Anaconda, which automatically installs them all by default.
You can also use a package manager, such as apt-get (Debian, Ubuntu),
yum (Red Hat Enterprise Linux, CentOS), dnf (Fedora), or Homebrew (OS X).
Binary installation on Linux and OS X¶
On Linux and OS X, you should be able to install binary versions of the required packages (“wheels”) and avoid lengthy compilation times:
pip install --only-binary :all: 'numpy>=1.8' 'scipy>=0.11' 'h5py>=2.2'
# Optional: install matplotlib to use the pypfilt.plot module.
pip install --only-binary :all: 'matplotlib>=1.5'
This is best done within a virtual environment (see the source installation instructions, below).
Source installation on Linux and OS X¶
Warning
Installing from source on Windows is effectively impossible, due to the dependencies of h5py.
Note
If you are using Python 3, you will most likely need to substitute “python3” for “python” in all of the package names listed here.
Alternatively, these packages can be manually installed in a Virtual Environment, by using virtualenv. This requires the following development tools:
- C and Fortran compilers (typically gcc and gfortran).
  - Debian: sudo apt-get install gcc gfortran.
  - Red Hat Enterprise Linux and CentOS: sudo yum install gcc gcc-gfortran.
  - Fedora: sudo dnf install gcc gcc-gfortran.
  - OS X: Install Command Line Tools for Xcode (instructions).
- Python header files.
  - Debian: sudo apt-get install python-dev.
  - Red Hat Enterprise Linux and CentOS: sudo yum install python-devel.
  - Fedora: sudo dnf install python-devel.
  - OS X: brew install python (see why installing a separate version of Python is a good idea).
- Linear algebra libraries (typically ATLAS and LAPACK, or MKL, or ACML).
Then install virtualenv and the libhdf5 development files:
# For Debian and Debian-based distributions such as Ubuntu.
sudo apt-get install virtualenv libhdf5-dev
# For Red Hat Enterprise Linux and CentOS.
sudo yum install python-virtualenv hdf5-devel
# For Fedora.
sudo dnf install python-virtualenv hdf5-devel
# For OS X.
brew install python homebrew/science/hdf5
pip install virtualenv
Then create a virtual environment (called venv-pypfilt in the following example):
# Create and activate the virtual environment.
virtualenv venv-pypfilt
source venv-pypfilt/bin/activate
# Upgrade pip, setuptools and wheel to the latest versions.
pip install --upgrade pip
pip install --upgrade setuptools wheel
# Install NumPy before SciPy, and Cython before h5py.
pip install 'numpy>=1.8' 'Cython>=0.17'
pip install 'scipy>=0.11'
# Note: may need to identify the directory that contains `include/hdf5.h`.
# For example, for 64-bit Debian and Debian-based distributions:
# export HDF5_DIR=/usr/lib/x86_64-linux-gnu/hdf5/serial
pip install 'h5py>=2.2'
# Optional: install matplotlib to use the pypfilt.plot module.
pip install 'matplotlib>=1.5'
Note
In order to install h5py, you may need to identify the directory
that contains include/hdf5.h by defining the HDF5_DIR environment
variable (see the comments in the code block above).
You can search for include/hdf5.h by running find -L /usr -name hdf5.h.
On Red Hat Enterprise Linux, CentOS, and Fedora, this file is located at
/usr/include/hdf5.h and there is no need to define HDF5_DIR.
On OS X, this file is located at /usr/local/include/hdf5.h when using
Homebrew, and there should be no need to define HDF5_DIR.
Installing pypfilt¶
Once the required packages have been installed (see instructions, above), you
can clone the pypfilt repository and install it in the venv-pypfilt virtual environment:
# Activate the virtual environment.
source venv-pypfilt/bin/activate
# Clone the pypfilt repository.
git clone https://bitbucket.org/robmoss/particle-filter-for-python.git
# Install pypfilt in the virtual environment.
cd particle-filter-for-python
python setup.py install
If you are not using a virtual environment, and you don’t have permission to install pypfilt system-wide, you can install the package locally:
# Clone the pypfilt repository.
git clone https://bitbucket.org/robmoss/particle-filter-for-python.git
# Install pypfilt in the user's "site-packages" directory.
cd particle-filter-for-python
python setup.py install --user
Building the documentation¶
If you want to build the documentation locally, you will need to install Sphinx 1.3 or newer, and the Read the Docs Sphinx Theme.
These can be installed through a package manager:
# For Debian and Debian-based distributions such as Ubuntu.
sudo apt-get install python-sphinx python-sphinx-rtd-theme
# For Red Hat Enterprise Linux and CentOS.
sudo yum install python-sphinx python-sphinx_rtd_theme
# For Fedora.
sudo dnf install python-sphinx python-sphinx_rtd_theme
# For OS X.
brew install sphinx
pip install sphinx_rtd_theme
Alternatively, they can be installed in the venv-pypfilt virtual environment:
# Activate the virtual environment.
source venv-pypfilt/bin/activate
pip install 'Sphinx>=1.3' sphinx_rtd_theme
You can then build the documentation from the pypfilt repository; the output
will be written to the doc/build/html directory:
python setup.py build_sphinx
API documentation¶
Generating a series of forecasts¶
Model estimation and forecasting is provided as a single function:
pypfilt.forecast(params, start, end, streams, dates, summary, filename)¶
Generate forecasts from various dates during a simulation.
Parameters:
- params (dict) – The simulation parameters.
- start – The start of the simulation period.
- end – The (exclusive) end of the simulation period.
- streams – A list of observation streams.
- dates – The dates at which forecasts should be generated.
- summary – An object that generates summaries of each simulation.
- filename – The output file to generate (can be None).
Returns: The simulation state for each forecast date.
This function returns a dictionary that contains the following keys:
- 'obs': a (flattened) list of every observation;
- 'complete': the simulation state obtained by assimilating every observation; and
- datetime.datetime instances: the simulation state obtained for each forecast, identified by the forecasting date.
The simulation states are generated by pypfilt.run() and contain the following keys:
- 'params': the simulation parameters;
- 'summary': the dictionary of summary statistics; and
- 'hist': the matrix of particle state vectors, including individual particle weights (hist[..., -2]) and the index of each particle at the previous time-step (hist[..., -1]), since these can change due to resampling. The matrix has dimensions \(N_{Steps} \times N_{Particles} \times (N_{SV} + 2)\) for state vectors of size \(N_{SV}\).
Note: if max_days > 0 was passed to pypfilt.default_params(), only a fraction of the entire simulation period will be available.
Particle filter parameters¶
Default values for the particle filter parameters are provided:
pypfilt.default_params(model, time_scale, max_days=0, px_count=0)¶
The default particle filter parameters.
Memory usage can reach extreme levels with a large number of particles, and so it may be necessary to keep only a sliding window of the entire particle history matrix in memory.
Parameters:
- model – The system model.
- time_scale – The simulation time scale.
- max_days – The number of contiguous days that must be kept in memory (e.g., the largest observation period).
- px_count – The number of particles.
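To see why a sliding window can be necessary, the size of the particle history matrix can be estimated from the \(N_{Steps} \times N_{Particles} \times (N_{SV} + 2)\) dimensions given for pypfilt.forecast() above. The following is a rough sketch that assumes each value is stored as an 8-byte float; the function history_bytes is a hypothetical helper (not part of pypfilt), and the second scenario is invented purely for illustration.

```python
def history_bytes(n_steps, n_particles, state_size, bytes_per_value=8):
    """Approximate size of the particle history matrix, in bytes."""
    # Each particle stores its state vector plus two extra columns
    # (the particle weight and the index of its parent particle).
    return n_steps * n_particles * (state_size + 2) * bytes_per_value


# Roughly the predator-prey example: 15 steps, 1,000 particles, 6 columns.
small = history_bytes(15, 1000, 6)
# A hypothetical larger run: 365 steps, 100,000 particles, 20 columns.
large = history_bytes(365, 100000, 20)
print('{:.2f} MB, {:.2f} GB'.format(small / 1e6, large / 1e9))
# prints "0.96 MB, 6.42 GB"
```

Under these assumptions the example in this guide fits comfortably in memory, whereas the larger scenario would benefit from setting max_days so that only the most recent part of the history is retained.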
The bootstrap particle filter¶
The bootstrap particle filter is exposed as a single-step function, which will update particle weights and perform resampling as necessary:
pypfilt.step(params, hist, hist_ix, step_num, when, step_obs, max_back, is_fs)¶
Perform a single time-step for every particle.
Parameters:
- params – The simulation parameters.
- hist – The particle history matrix.
- hist_ix – The index of the current time-step in the history matrix.
- step_num – The time-step number.
- when – The current simulation time.
- step_obs – The list of observations for this time-step.
- max_back – The number of time-steps into the past when the most recent resampling occurred; must be either a positive integer or None (no limit).
- is_fs – Indicate whether this is a forecasting simulation (i.e., no observations). For deterministic models it is useful to add some random noise when estimating, to allow identical particles to differ in their behaviour, but this is not desirable when forecasting.
Returns: True if resampling was performed, otherwise False.
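The "resample as necessary" part of a bootstrap particle filter step can be sketched as follows. This is an illustration in pure Python and not pypfilt's implementation: it assumes systematic resampling (one common scheme among several) triggered when the effective sample size falls below a fraction of the particle count, and the function names and the threshold value of 0.5 are hypothetical choices made for this example.

```python
import random


def systematic_resample(weights, rng):
    """Return parent indices chosen by systematic resampling."""
    n = len(weights)
    total = sum(weights)
    start = rng.random() / n  # A single random offset in [0, 1/n).
    parents = []
    j = 0
    cum = weights[0] / total
    for k in range(n):
        target = start + k / n
        # Advance to the particle whose cumulative weight covers the target.
        while cum <= target and j < n - 1:
            j += 1
            cum += weights[j] / total
        parents.append(j)
    return parents


def resample_if_needed(particles, weights, rng, threshold=0.5):
    """Resample when the effective sample size is below threshold * N.

    Returns (particles, normalised weights, whether resampling occurred).
    """
    n = len(particles)
    total = sum(weights)
    norm = [w / total for w in weights]
    ess = 1.0 / sum(w * w for w in norm)
    if ess >= threshold * n:
        return particles, norm, False
    parents = systematic_resample(norm, rng)
    # After resampling, every surviving particle has equal weight.
    return [particles[j] for j in parents], [1.0 / n] * n, True


rng = random.Random(42)
particles = ['a', 'b', 'c', 'd']
new_particles, new_weights, resampled = resample_if_needed(
    particles, [0.0, 0.0, 1.0, 0.0], rng)
print(new_particles, resampled)
```

Resampling only when the weights have become uneven avoids discarding particle diversity unnecessarily, which is one reason a step function might report whether resampling was performed.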
Running a single simulation¶
pypfilt.run(params, start, end, streams, summary, state=None, save_when=None, save_to=None)¶
Run the particle filter against any number of data streams.
Parameters:
- params (dict) – The simulation parameters.
- start – The start of the simulation period.
- end – The (exclusive) end of the simulation period.
- streams – A list of observation streams.
- summary – An object that generates summaries of each simulation.
- state – A previous simulation state as returned by, e.g., this function.
- save_when – Times at which to save the particle history matrix.
- save_to – The filename for saving the particle history matrix.
Returns: The resulting simulation state: a dictionary that contains the simulation parameters ('params'), the particle history matrix ('hist'), and the summary statistics ('summary').
Simulation models¶
All simulation models should derive from the following base class:
class pypfilt.Model¶
The base class for simulation models, which defines the minimal set of methods that are required.
init(params, vec)¶
Initialise a matrix of state vectors.
Parameters:
- params – Simulation parameters.
- vec – An uninitialised \(P \times S\) matrix of state vectors, for \(P\) particles and state vectors of length \(S\) (as defined by state_size()). To set, e.g., the first element of each state vector to \(1\), you can use an ellipsis slice: vec[..., 0] = 1.
state_size()¶
Return the size of the state vector.
priors(params)¶
Return a dictionary of model parameter priors. Each key must identify a parameter by name. Each value must be a function that returns samples from the associated prior distribution, and should have the following form:
lambda r, size=None: r.uniform(1.0, 2.0, size=size)
Here, the argument r is a PRNG instance and size specifies the output shape (by default, a single value).
Parameters: params – Simulation parameters.
update(params, step_date, dt, is_fs, prev, curr)¶
Perform a single time-step.
Parameters:
- params – Simulation parameters.
- step_date – The date and time of the current time-step.
- dt – The time-step size (days).
- is_fs – Indicates whether this is a forecasting simulation.
- prev – The state before the time-step.
- curr – The state after the time-step (destructively updated).
describe()¶
Describe each component of the state vector with a tuple of the form (name, smooth, min, max), where name is a descriptive name for the variable/parameter, smooth is a boolean that indicates whether the parameter admits continuous sampling (e.g., post-regularisation), and min and max define the (inclusive) range of valid values. These tuples must be in the same order as the state vector itself.
stat_info()¶
Describe each statistic that can be calculated by this model as a (name, stat_fn) tuple, where name is a string that identifies the statistic and stat_fn is a function that calculates the value of the statistic.
is_valid(hist)¶
Identify particles whose state and parameters can be inspected. By default, this function returns True for all particles. Override this function to ensure that inchoate particles are correctly ignored.
Weighted statistics¶
The pypfilt.stats module provides functions for calculating weighted
statistics across particle populations.
pypfilt.stats.cov_wt(x, wt, cor=False)¶
Estimate the weighted covariance matrix, based on a NumPy pull request.
Equivalent to cov.wt(x, wt, cor, center=TRUE, method="unbiased") as provided by the stats package for R.
Parameters:
- x – A 2-D array; columns represent variables and rows represent observations.
- wt – A 1-D array of observation weights.
- cor – Whether to return a correlation matrix instead of a covariance matrix.
Returns: The covariance matrix (or correlation matrix, if cor=True).
pypfilt.stats.avg_var_wt(x, weights, biased=True)¶
Return the weighted average and variance (based on a Stack Overflow answer).
Parameters:
- x – The data points.
- weights – The normalised weights.
- biased – Use a biased variance estimator.
Returns: A tuple that contains the weighted average and weighted variance.
pypfilt.stats.qtl_wt(x, weights, probs)¶
Equivalent to wtd.quantile(x, weights, probs, normwt=TRUE) as provided by the Hmisc package for R.
Parameters:
- x – The numerical data.
- weights – The weight of each data point.
- probs – The quantile(s) to compute.
Returns: The array of weighted quantiles.
pypfilt.stats.cred_wt(x, weights, creds)¶
Calculate weighted credible intervals.
Parameters:
- x – The numerical data.
- weights – The weight of each data point.
- creds (List(int)) – The credible interval(s) to compute (0..100, where 0 represents the median and 100 the entire range).
Returns: A dictionary that maps credible intervals to the lower and upper interval bounds.
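The idea behind weighted quantiles and central credible intervals can be illustrated with a short pure-Python sketch. It is a simplification and not the pypfilt.stats implementation: it picks the smallest data point whose cumulative weight reaches the requested probability and performs no interpolation, so its results may differ slightly from qtl_wt() and cred_wt(). The function names are hypothetical.

```python
def weighted_quantile(x, weights, prob):
    """Return the smallest value whose cumulative weight reaches `prob`."""
    pairs = sorted(zip(x, weights))
    total = sum(w for (_, w) in pairs)
    cum = 0.0
    for (value, w) in pairs:
        cum += w / total
        if cum >= prob:
            return value
    # Guard against rounding error when prob is (close to) 1.
    return pairs[-1][0]


def credible_intervals(x, weights, creds):
    """Map each credible interval (0..100) to its (lower, upper) bounds."""
    out = {}
    for cred in creds:
        # Probability mass excluded in each tail; 0 yields the median twice
        # and 100 yields the entire range.
        tail = (100 - cred) / 200.0
        out[cred] = (weighted_quantile(x, weights, tail),
                     weighted_quantile(x, weights, 1.0 - tail))
    return out


values = [5, 1, 4, 2, 3]
equal_weights = [1, 1, 1, 1, 1]
cis = credible_intervals(values, equal_weights, [0, 50, 100])
print(cis)
```

With unequal weights the intervals shift towards the heavily weighted particles, which is why summary statistics for a particle ensemble should always take the weights into account.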
Simulation metadata¶
Every simulation data file should include metadata that documents the
simulation parameters and working environment.
The Metadata class provides the means for generating such metadata:
-
class
pypfilt.summary.
Metadata
¶ Document the simulation parameters and system environment for a set of simulations. A black-list (
ignore_dict
) defines which members of the parameters dictionary will be excluded from this metadata, seefilter()
for details.-
build
(params, pkgs=None)¶ Construct a metadata dictionary that documents the simulation parameters and system environment. Note that this should be generated at the start of the simulation, and that the git metadata will only be valid if the working directory is located within a git repository.
Parameters: - params – The simulation parameters.
- pkgs – A dictionary that maps package names to modules that define appropriate __version__ attributes, used to record the versions of additional relevant packages (see the example below).
By default, the versions of pypfilt, h5py, numpy and scipy are recorded. The following example demonstrates how to also record the installed version of the epifx package:
import epifx
import pypfilt.summary
params = ...
meta = pypfilt.summary.Metadata()
metadata = meta.build(params, {'epifx': epifx})
-
filter
(values, ignore, encode_fn)¶ Recursively filter items from a dictionary, used to remove parameters from the metadata dictionary that, e.g., have no meaningful representation.
Parameters: - values – The original dictionary.
- ignore – A dictionary that specifies which values to ignore.
- encode_fn – A function that encodes the remaining values (see encode()).
For example, to ignore ['px_range'], ['resample']['rnd'], and 'expect_fn' and 'log_llhd_fn' for every observation system when using epifx:
m = pypfilt.summary.Metadata()
ignore = {
    'px_range': None,
    'resample': {'rnd': None},
    # Note the use of ``None`` to match any key under 'obs'.
    'obs': {None: {'expect_fn': None, 'log_llhd_fn': None}}
}
m.filter(params, ignore, m.encode)
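The recursive filtering described above can be approximated with a short standalone function. This is a sketch under assumptions, not the pypfilt implementation: `filter_dict` is a hypothetical name, and the rule that a `None` key acts as a wildcard only when no exact key matches is an assumption.

```python
_KEEP = object()  # Sentinel: no ignore rule applies to this key.


def filter_dict(values, ignore, encode_fn):
    """Return a filtered, encoded copy of a nested dictionary (sketch)."""
    kept = {}
    for key, value in values.items():
        # Prefer an exact match; fall back to the ``None`` wildcard.
        rule = ignore.get(key, ignore.get(None, _KEEP))
        if rule is None:
            continue  # This item is ignored entirely.
        sub_ignore = rule if isinstance(rule, dict) else {}
        if isinstance(value, dict):
            kept[key] = filter_dict(value, sub_ignore, encode_fn)
        else:
            kept[key] = encode_fn(value)
    return kept


# Hypothetical parameters, for illustration only.
params = {
    'px_range': (0, 1),
    'size': 100,
    'resample': {'rnd': object(), 'threshold': 0.25},
    'obs': {'sys_a': {'expect_fn': len, 'log_llhd_fn': len, 'bg': 3}},
}
ignore = {
    'px_range': None,
    'resample': {'rnd': None},
    'obs': {None: {'expect_fn': None, 'log_llhd_fn': None}},
}
filtered = filter_dict(params, ignore, str)
```

Only 'size', 'resample'/'threshold' and 'obs'/'sys_a'/'bg' survive, each passed through the encoding function.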
-
encode
(value)¶ Encode values in a form suitable for serialisation in HDF5 files.
- Integer values are converted to numpy.int32 values.
- Floating-point values and arrays retain their data type.
- All other (i.e., non-numerical) values are converted to UTF-8 strings.
-
object_name
(obj)¶ Return the fully qualified name of the object as a byte string.
-
priors
(params)¶ Return a dictionary that describes the model parameter priors.
Each key identifies a parameter (by name); the corresponding value is a byte string representation of the prior distribution, which is typically a numpy.random.RandomState method call. For example:
{'R0': b'random.uniform(1.0, 2.0)', 'gamma': b'(1 / random.uniform(1.0, 3.0))'}
-
pkg_version
(module)¶ Attempt to obtain the version of a Python module.
-
git_data
()¶ Record the status of the git repository within which the working directory is located (if such a repository exists).
-
run_cmd
(args, all_lines=False, err_val=u'')¶ Run a command and return the (Unicode) output. By default, only the first line is returned; set all_lines=True to receive all of the output as a list of Unicode strings. If the command returns a non-zero exit status, return err_val instead.
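The behaviour described for run_cmd() can be sketched with the standard subprocess module. This is an illustration rather than the pypfilt code: `run_cmd_sketch` is a hypothetical name, and returning err_val when the command cannot be started, or an empty string when there is no output, are assumptions.

```python
import subprocess
import sys


def run_cmd_sketch(args, all_lines=False, err_val=''):
    """Run a command and return its output as text (sketch)."""
    try:
        output = subprocess.check_output(args, stderr=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError):
        # Non-zero exit status, or (ASSUMPTION) the command failed to start.
        return err_val
    lines = output.decode('utf-8').splitlines()
    if all_lines:
        return lines
    # ASSUMPTION: return an empty string if there was no output at all.
    return lines[0] if lines else ''


cmd = [sys.executable, '-c', 'print("a"); print("b")']
first_line = run_cmd_sketch(cmd)
every_line = run_cmd_sketch(cmd, all_lines=True)
failure = run_cmd_sketch([sys.executable, '-c', 'import sys; sys.exit(3)'],
                         err_val='failed')
```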
-
Summary data files¶
The HDF5
class encapsulates the process of calculating and recording
summary statistics for each simulation.
-
class
pypfilt.summary.
HDF5
(params, obs_list, meta=None, first_day=False, only_fs=False)¶ Save tables of summary statistics to an HDF5 file.
Parameters: - params – The simulation parameters.
- obs_list – A list of all observations.
- meta – The simulation metadata; by default the output of Metadata.build() is used.
- first_day – If False (the default) statistics are calculated from the date of the first observation. If True, statistics are calculated from the very beginning of the simulation period.
- only_fs – If False (the default) statistics are calculated for the initial estimation simulation and for forecasting simulations. If True, statistics are only calculated for forecasting simulations.
-
add_tables
(*tables)¶ Add summary statistic tables that will be included in the output file.
-
save_forecasts
(fs, filename)¶ Save forecast summaries to disk in the HDF5 binary data format.
This function creates the following datasets that summarise the estimation and forecasting outputs:
'data/TABLE'
for each table.
The provided metadata will be recorded under 'meta/'.
If dataset creation timestamps are enabled, two simulations that produce identical outputs will not result in identical files. Timestamps will be disabled where possible (requires h5py >= 2.2):
'hdf5_track_times': Presence of creation timestamps.
Parameters: - fs – Simulation outputs, as returned by pypfilt.forecast().
- filename – The filename to which the data will be written.
Summary statistic tables¶
Summary statistics are stored in tables, each of which comprises a set of named columns and a specific number of rows.
The Table class¶
To calculate a summary statistic, you need to define a subclass of the
Table
class and provide implementations of each method.
-
class
pypfilt.
Table
(name)¶ The base class for summary statistic tables.
Tables are used to record rows of summary statistics as a simulation progresses.
Parameters: name – the name of the table in the output file. -
dtype
(params, obs_list)¶ Return the column names and data types, represented as a list of (name, data type) tuples. See the NumPy documentation for details.
Parameters: - params – The simulation parameters.
- obs_list – A list of all observations.
-
n_rows
(start_date, end_date, n_days, n_sys, forecasting)¶ Return the number of rows required for a single simulation.
Parameters: - start_date – The date at which the simulation starts.
- end_date – The date at which the simulation ends.
- n_days – The number of days for which the simulation runs.
- n_sys – The number of observation systems (i.e., data sources).
- forecasting – True if this is a forecasting simulation, otherwise False.
-
add_rows
(hist, weights, fs_date, dates, obs_types, insert_fn)¶ Record rows of summary statistics for some portion of a simulation.
Parameters: - hist – The particle history matrix.
- weights – The weight of each particle at each date in the simulation window; it has dimensions (d, p) for d days and p particles.
- fs_date – The forecasting date; if this is not a forecasting simulation, this is the date at which the simulation ends.
- dates – A list of (datetime, ix, hist_ix) tuples that identify each day in the simulation window, the index of that day in the simulation window, and the index of that day in the particle history matrix.
- obs_types – A set of (unit, period) tuples that identify each observation system from which observations have been taken.
- insert_fn – A function that inserts one or more rows into the underlying data table; see the examples below.
The row insertion function can be used as follows:
# Insert a single row, represented as a tuple.
insert_fn((x, y, z))
# Insert multiple rows, represented as a list of tuples.
insert_fn([(x0, y0, z0), (x1, y1, z1)], n=2)
-
finished
(hist, weights, fs_date, dates, obs_types, insert_fn)¶ Record rows of summary statistics at the end of a simulation.
The parameters are as per add_rows().
Derived classes should only implement this method if rows must be recorded by this method; the provided method does nothing.
-
monitors
()¶ Return a list of monitors required by this Table.
Derived classes should implement this method if they require one or more monitors; the provided method returns an empty list.
-
Predefined statistics¶
The following derived classes are provided to calculate basic summary statistics of any generic simulation model.
-
class
pypfilt.summary.
ModelCIs
(probs=None, name=u'model_cints')¶ Calculate fixed-probability central credible intervals for all state variables and model parameters.
Parameters: - probs – an array of probabilities that define the size of each central credible interval. The default value is numpy.uint8([0, 50, 90, 95, 99, 100]).
- name – the name of the table in the output file.
-
class
pypfilt.summary.
ParamCovar
(name=u'param_covar')¶ Calculate the covariance between all pairs of model parameters during each simulation.
Parameters: name – the name of the table in the output file.
Utility functions¶
The following column types are provided for convenience when defining custom
Table
subclasses.
-
pypfilt.summary.
dtype_unit
(obs_list, name=u'unit')¶ The dtype for columns that store observation units.
-
pypfilt.summary.
dtype_period
(name=u'period')¶ The dtype for columns that store observation periods.
-
pypfilt.summary.
dtype_value
(value, name=u'value')¶ The dtype for columns that store observation values.
-
pypfilt.summary.
dtype_names_to_str
(dtypes, encoding=u'utf-8')¶ Ensure that dtype field names are native strings, as required by NumPy. Unicode strings are not valid field names in Python 2, and this can cause problems when using Unicode string literals.
Parameters: - dtypes – A list of fields where each field is either a string, or a tuple of length 2 or 3 (see the NumPy docs for details).
- encoding – The encoding for converting Unicode strings to native strings in Python 2.
Returns: A list of fields, where each field name is a native string (
str
type).Raises: ValueError – If a name cannot be converted to a native string.
The following functions are provided for converting column types in structured arrays.
-
pypfilt.summary.
convert_cols
(data, converters)¶ Convert columns in a structured array from one type to another.
Parameters: - data – The input structured array.
- converters – A dictionary that maps (unicode) column names to
(convert_fn, new_dtype)
tuples, which contain a conversion function and define the output dtype.
Returns: A new structured array.
-
pypfilt.summary.
default_converters
(time_scale)¶ Return a dictionary for converting the 'fs_date' and 'date' columns from (see convert_cols()).
Retrospective statistics¶
In some cases, the Table model is not sufficiently flexible, since it assumes that statistics can be calculated during the course of a simulation. For some statistics, it may be necessary to observe the entire simulation before the statistics can be calculated.
In this case, you need to define a subclass of the Monitor class, which will observe (“monitor”) each simulation and, upon completion of each simulation, can calculate the necessary summary statistics.
Note that a Table subclass is also required to define the table columns, the number of rows, and to record each row at the end of the simulation.
-
class
pypfilt.
Monitor
¶ The base class for simulation monitors.
Monitors are used to calculate quantities that:
- Are used by multiple Tables (i.e., avoiding repeated computation); or
- Require a complete simulation for calculation (as distinct from Tables, which incrementally record rows as a simulation progresses).
The quantities calculated by a Monitor can then be recorded by Table.add_rows() and/or Table.finished().
-
prepare
(params, obs_list)¶ Perform any required preparation prior to a set of simulations.
Parameters: - params – The simulation parameters.
- obs_list – A list of all observations.
-
begin_sim
(start_date, end_date, n_days, n_sys, forecasting)¶ Perform any required preparation at the start of a simulation.
Parameters: - start_date – The date at which the simulation starts.
- end_date – The date at which the simulation ends.
- n_days – The number of days for which the simulation runs.
- n_sys – The number of observation systems (i.e., data sources).
- forecasting – True if this is a forecasting simulation, otherwise False.
-
monitor
(hist, weights, fs_date, dates, obs_types)¶ Monitor the simulation progress.
Parameters: - hist – The particle history matrix.
- weights – The weight of each particle at each date in the simulation window; it has dimensions (d, p) for d days and p particles.
- fs_date – The forecasting date; if this is not a forecasting simulation, this is the date at which the simulation ends.
- dates – A list of (datetime, ix, hist_ix) tuples that identify each day in the simulation window, the index of that day in the simulation window, and the index of that day in the particle history matrix.
- obs_types – A set of (unit, period) tuples that identify each observation system from which observations have been taken.
-
end_sim
(hist, weights, fs_date, dates, obs_types)¶ Finalise the data as required for the relevant summary statistics.
The parameters are as per monitor().
Derived classes should only implement this method if finalisation of the monitored data is required; the provided method does nothing.
-
load_state
(grp)¶ Load the monitor state from a cache file.
Parameters: grp – The h5py Group object from which to load the state.
-
save_state
(grp)¶ Save the monitor state to a cache file.
Parameters: grp – The h5py Group object in which to save the state.
Tables and Monitors¶
The methods of each Table and Monitor will be called in the following sequence by the HDF5 summary class:
Before any simulations are performed:
Table.dtype()
Monitor.prepare()
In addition to defining the column types for each Table, this allows objects to store the simulation parameters and observations.
At the start of each simulation:
Monitor.begin_sim()
Table.n_rows()
This notifies each Monitor and each Table of the simulation period, the number of observation systems (i.e., data sources), and whether it is a forecasting simulation (where no resampling will take place).
During each simulation:
Monitor.monitor()
Table.add_rows()
This provides a portion of the simulation period for analysis by each Monitor and each Table. Because all of the Monitor.monitor() methods are called before the Table.add_rows() methods, tables can interrogate monitors to obtain any quantities of interest that are calculated by Monitor.monitor().
At the end of each simulation:
Monitor.end_sim()
Table.finished()
This allows each Monitor and each Table to perform any final calculations once the simulation has completed. Because all of the Monitor.end_sim() methods are called before the Table.finished() methods, tables can interrogate monitors to obtain any quantities of interest that are calculated by Monitor.end_sim().
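The calling sequence above can be traced with a small standalone sketch. The `Recorder` stand-ins and the `run_summary` driver are hypothetical and take no arguments; they only log method names, and the ordering within the first two stages is assumed to follow the order in which the methods are listed.

```python
class Recorder:
    """Stand-in for a Table or Monitor that logs each method call."""

    def __init__(self, kind, log):
        self.kind = kind
        self.log = log

    def __getattr__(self, method):
        # Only reached for names that are not real attributes.
        def record(*args, **kwargs):
            self.log.append('{}.{}'.format(self.kind, method))
        return record


def run_summary(tables, monitors, n_sims, n_windows):
    """Hypothetical driver that mimics the documented call order."""
    for table in tables:
        table.dtype()
    for monitor in monitors:
        monitor.prepare()
    for _ in range(n_sims):
        for monitor in monitors:
            monitor.begin_sim()
        for table in tables:
            table.n_rows()
        for _ in range(n_windows):
            # All monitors run before any table records rows.
            for monitor in monitors:
                monitor.monitor()
            for table in tables:
                table.add_rows()
        for monitor in monitors:
            monitor.end_sim()
        for table in tables:
            table.finished()


log = []
run_summary([Recorder('Table', log)], [Recorder('Monitor', log)],
            n_sims=1, n_windows=2)
```

Inspecting `log` afterwards shows each Monitor call preceding the matching Table call within every stage of a simulation.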
Time scales¶
Two pre-defined simulation time scales are provided.
-
class
pypfilt.
Scalar
(np_dtype=None)¶ A dimensionless time scale.
-
__init__
(np_dtype=None)¶ Parameters: np_dtype – The data type used for serialisation; the default is np.float64.
-
set_period
(start, end, steps_per_unit)¶ Define the simulation period and time-step size.
Parameters: - start (float) – The start of the simulation period.
- end (float) – The end of the simulation period.
- steps_per_unit (int) – The number of time-steps per time unit.
Raises: ValueError – if start and/or end are not floats, or if steps_per_unit is not a positive integer.
-
with_observations
(*streams)¶ Return a generator that yields a sequence of tuples that contain: the time-step number, the current time, and a list of observations.
Parameters: streams – Any number of observation streams (each of which is assumed to be sorted chronologically).
-
with_observations_from_time
(start, *streams)¶ Return a generator that yields a sequence of tuples that contain: the time-step number, the current time, and a list of observations.
Parameters: - start – The starting time (set to None to use the start of the simulation period).
- streams – Any number of observation streams (each of which is assumed to be sorted chronologically).
-
-
class
pypfilt.
Datetime
(fmt=None)¶ A datetime scale where the time unit is days.
-
__init__
(fmt=None)¶ Parameters: fmt – The format string used to serialise datetime objects; the default is '%Y-%m-%d %H:%M:%S'.
-
set_period
(start, end, steps_per_unit)¶ Define the simulation period and time-step size.
Parameters: - start (datetime.datetime) – The start of the simulation period.
- end (datetime.datetime) – The end of the simulation period.
- steps_per_unit (int) – The number of time-steps per day.
Raises: ValueError – if start and/or end are not datetime.datetime instances, or if steps_per_unit is not a positive integer.
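To make the meaning of steps_per_unit concrete for a scale whose time unit is days, the following standalone sketch divides each day into equal time-steps. The helper `datetime_steps` is hypothetical and is not part of pypfilt; numbering the first step 1, one time-step after the start, is the convention this sketch adopts.

```python
import datetime


def datetime_steps(start, end, steps_per_unit):
    """Yield (step number, time) tuples, with the time unit being days."""
    step_size = datetime.timedelta(days=1) / steps_per_unit
    n_steps = (end - start) // step_size
    for step in range(1, n_steps + 1):
        yield step, start + step * step_size


dt_steps = list(datetime_steps(datetime.datetime(2017, 1, 1),
                               datetime.datetime(2017, 1, 3),
                               steps_per_unit=2))
# Serialise a time using the default format string given above.
first_label = dt_steps[0][1].strftime('%Y-%m-%d %H:%M:%S')
```

Two days at two steps per day gives four time-steps, twelve hours apart.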
-
with_observations
(*streams)¶ Return a generator that yields a sequence of tuples that contain: the time-step number, the current time, and a list of observations.
Parameters: streams – Any number of observation streams (each of which is assumed to be sorted chronologically).
-
with_observations_from_time
(start, *streams)¶ Return a generator that yields a sequence of tuples that contain: the time-step number, the current time, and a list of observations.
Parameters: - start – The starting time (set to None to use the start of the simulation period).
- streams – Any number of observation streams (each of which is assumed to be sorted chronologically).
-
Custom time scales¶
If neither of the above time scales is suitable, you can define a custom time scale, which should derive from the following base class and define the methods listed here:
-
class
pypfilt.time.
Time
¶ The base class for simulation time scales, which defines the minimal set of methods that are required.
-
dtype
(name)¶ Define the dtype for columns that store times.
-
native_dtype
()¶ Define the Python type used to represent times in NumPy arrays.
-
is_instance
(value)¶ Return whether
value
is an instance of the native time type.
-
to_dtype
(time)¶ Convert from time to a dtype value.
-
from_dtype
(dval)¶ Convert from a dtype value to time.
-
to_unicode
(time)¶ Convert from time to a Unicode string.
This is used to define group names in HDF5 files, and for logging.
-
steps
()¶ Return a generator that yields a sequence of time-step numbers and times (represented as tuples) that span the simulation period.
The first time-step should be numbered 1 and occur at a time that is one time-step after the beginning of the simulation period.
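The numbering convention for steps() can be illustrated with a scalar time scale. This is a standalone sketch; `scalar_steps` and `step_to_scalar` are hypothetical helpers, not pypfilt methods.

```python
def scalar_steps(start, end, steps_per_unit):
    """Yield (step number, time) tuples that span the simulation period."""
    n_steps = int(round((end - start) * steps_per_unit))
    for step in range(1, n_steps + 1):
        # Step 1 occurs one time-step after the start of the period.
        yield step, start + step / steps_per_unit


def step_to_scalar(step, steps_per_unit):
    """The time-step number divided by the number of steps per time unit."""
    return step / steps_per_unit


steps = list(scalar_steps(0.0, 2.0, steps_per_unit=4))
```

Here `steps` begins with `(1, 0.25)` and ends with `(8, 2.0)`, so the start of the period itself is never yielded.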
-
step_count
()¶ Return the number of time-steps required for the simulation period.
-
step_of
(time)¶ Return the time-step number that corresponds to the specified time.
-
add_scalar
(time, scalar)¶ Add a scalar quantity to the specified time.
-
time_of_obs
(obs)¶ Return the time associated with an observation.
-
to_scalar
(time)¶ Convert the specified time into a scalar quantity, defined as the time-step number divided by the number of time-steps per time unit.
-
Plotting¶
Several plotting routines, built on top of matplotlib, are provided in the pypfilt.plot module (matplotlib must be installed in order to use this module).
To generate plots non-interactively (i.e., without having a window appear) use the 'Agg' backend:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
See the matplotlib FAQ for more details.
Styles and colour palettes¶
-
pypfilt.plot.
default_style
()¶ The style sheet provided by pypfilt.
-
pypfilt.plot.
apply_style
(*args, **kwds)¶ Temporarily apply a style sheet.
Parameters: style – The style sheet to apply (default: default_style()).
with apply_style():
    make_plots()
-
pypfilt.plot.
n_colours
(name, n)¶ Extract a fixed number of colours from a colour map.
Parameters: - name – The colour map name (or a matplotlib.colors.Colormap instance).
- n – The number of colours required.
colours = n_colours('Blues', 3)
-
pypfilt.plot.
brewer_qual
(name)¶ Qualitative palettes from the ColorBrewer project: 'Accent', 'Dark2', 'Paired', 'Pastel1', 'Pastel2', 'Set1', 'Set2', 'Set3'.
Raises: ValueError – if the palette name is invalid.
-
pypfilt.plot.
colour_iter
(col, palette, reverse=False)¶ Iterate over the unique (sorted) values in an array, returning a (value, colour) tuple for each of the values.
Parameters: - col – The column of (unsorted, repeated) values.
- palette – The colour map name or a list of colours.
- reverse – Whether to sort the values in ascending (default) or descending order.
Plotting functions¶
-
pypfilt.plot.
cred_ints
(ax, data, x, ci, palette=u'Blues', **kwargs)¶ Plot credible intervals as shaded regions.
Parameters: - ax – The plot axes.
- data – The NumPy array containing the credible intervals.
- x – The name of the x-axis column.
- ci – The name of the credible interval column.
- palette – The colour map name or a list of colours.
- **kwargs – Extra arguments to pass to Axes.plot and Axes.fill_between.
Returns: A list of the series that were plotted.
-
pypfilt.plot.
observations
(ax, data, label=u'Observations', future=False, **kwargs)¶ Plot observed values.
Parameters: - ax – The plot axes.
- data – The NumPy array containing the observation data.
- label – The label for the observation data.
- future – Whether the observations occur after the forecasting date.
- **kwargs – Extra arguments to pass to Axes.plot.
Returns: A list of the series that were plotted.
-
pypfilt.plot.
series
(ax, data, x, y, scales, legend_cols=True, **kwargs)¶ Add multiple series to a single plot, each of which is styled according to values in other columns.
Parameters: - ax – The axes on which to draw the line series.
- data – The structured array that contains the data to plot.
- x – The name of the column that corresponds to the x-axis.
- y – The name of the column that corresponds to the y-axis.
- scales – A list of “scales” to apply to each line series; each scale is a tuple (column, kwarg, kwvals, label_fmt) where:
column is the name of a column in data;
kwarg is the name of a keyword argument passed to plot();
kwvals is a list of values that the keyword argument will take; and
label_fmt is a format string for the legend keys.
- legend_cols – Whether to show each scale in a separate column.
- **kwargs – Extra arguments to pass to Axes.plot.
Returns: A list of the series that were plotted.
scales = [
    # Colour lines according to the dispersion parameter.
    ('disp', 'color', brewer_qual('Set1', False), r'$k = {:.0f}$'),
    # Vary line style according to the background signal.
    ('bg_obs', 'linestyle', ['-', '--', ':'], r'$bg_{{obs}} = {}$'),
]
series(ax, data, 'x_col', 'y_col', scales)
Faceted plots¶
This package provides a base class (Plot) for plots that comprise any number of subplots, and three subclasses for specific types of plots:
- Wrap for plots where a single variable identifies each subplot.
- Grid for plots where two variables are used to identify each subplot.
- Single for single plots.
The key method of these classes is Plot.subplots(), which returns an iterator that yields (axes, data) tuples for each subplot. By looping over these tuples, one set of plotting commands can be used to generate all of the subplots.
For examples, see the plot_forecasts() and plot_params() functions in the Plotting the results section of Getting Started.
-
class
pypfilt.plot.
Plot
(**kwargs)¶ The base class for plots that comprise multiple subplots.
Parameters: **kwargs – Extra arguments to pass to pyplot.subplots.
Variables: - fig – The matplotlib.figure.Figure instance for the plot.
- axs – The \(M \times N\) array of matplotlib.axes.Axes instances for each of the sub-plots (\(M\) rows and \(N\) columns).
-
subplots
()¶ Return an iterator that yields
(axes, data)
tuples for each subplot.
-
add_to_legend
(objs, replace=False)¶ Add plot objects to the list of items to show in the figure legend.
Parameters: replace – Whether to ignore objects which share a label with any object already in this list (default) or to replace such objects (set to True).
-
legend
(**kwargs)¶ Add a figure legend that lists the objects registered with
add_to_legend()
.Parameters: **kwargs – Extra arguments to pass to Figure.legend.
-
set_xlabel
(text, dy, **kwargs)¶ Add an x-axis label that is centred across all subplots.
Parameters: - text – The label text.
- dy – The vertical position of the label.
- **kwargs – Extra arguments to pass to Figure.text.
-
set_ylabel
(text, dx, **kwargs)¶ Add a y-axis label that is centred across all subplots.
Parameters: - text – The label text.
- dx – The horizontal position of the label.
- **kwargs – Extra arguments to pass to Figure.text.
-
expand_x_lims
(xs, pad_frac=0.05, pad_abs=None)¶ Increase the range of the x-axis, relative to the plot data.
Parameters: - xs – The x-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
expand_y_lims
(ys, pad_frac=0.05, pad_abs=None)¶ Increase the range of the y-axis, relative to the plot data.
Parameters: - ys – The y-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
scale_x_date
(lbl_fmt, day=None, month=None, year=None)¶ Use a datetime scale to locate and label the x-axis ticks.
Parameters: - lbl_fmt – The strftime() format string for tick labels.
- day – Locate ticks at every N days.
- month – Locate ticks at every N months.
- year – Locate ticks at every N years.
Raises: ValueError – unless exactly one of day, month, and year is specified.
-
scale_y_date
(lbl_fmt, day=None, month=None, year=None)¶ Use a datetime scale to locate and label the y-axis ticks.
Parameters: - lbl_fmt – The strftime() format string for tick labels.
- day – Locate ticks at every N days.
- month – Locate ticks at every N months.
- year – Locate ticks at every N years.
Raises: ValueError – unless exactly one of day, month, and year is specified.
-
save
(filename, format, width, height, **kwargs)¶ Save the plot to disk (a thin wrapper for savefig).
Parameters: - filename – The output filename or a Python file-like object.
- format – The output format.
- width – The figure width in inches.
- height – The figure height in inches.
- **kwargs – Extra arguments for savefig; the defaults are transparent=True and bbox_inches='tight'.
-
class
pypfilt.plot.
Wrap
(data, xlbl, ylbl, fac, nr=None, nc=None, **kwargs)¶ Faceted plots similar to those produced by ggplot2’s facet_wrap().
Parameters: - data – The NumPy array containing the data to plot.
- xlbl – The label for the x-axis.
- ylbl – The label for the y-axis.
- fac – The faceting variable, represented as a tuple (column_name, label_fmt) where column_name is the name of a column in data and label_fmt is the format string for facet labels or a function that returns the facet label.
- nr – The number of rows; exactly one of nr and nc must be specified.
- nc – The number of columns; exactly one of nr and nc must be specified.
- **kwargs – Extra arguments for Plot.
Raises: ValueError – if nr and nc are both None or are both specified.
-
expand_x_lims
(col, pad_frac=0.05, pad_abs=None)¶ Increase the range of the x-axis, relative to the plot data.
Parameters: - col – The column name for the x-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
expand_y_lims
(col, pad_frac=0.05, pad_abs=None)¶ Increase the range of the y-axis, relative to the plot data.
Parameters: - col – The column name for the y-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
subplots
(hide_axes=False, dx=0.055, dy=0.025)¶ Return an iterator that yields (axes, data) tuples for each subplot.
Parameters: - hide_axes – Whether to hide x axes that are not on the bottom edge, and y axes that are not on the left edge, of the figure.
- dx – The horizontal location for the y-axis label.
- dy – The vertical location for the x-axis label.
-
class
pypfilt.plot.
Grid
(data, xlbl, ylbl, xfac, yfac, **kwargs)¶ Faceted plots similar to those produced by ggplot2’s facet_grid().
Parameters: - data – The NumPy array containing the data to plot.
- xlbl – The label for the x-axis.
- ylbl – The label for the y-axis.
- xfac – The horizontal faceting variable, represented as a tuple (column_name, label_fmt) where column_name is the name of a column in data and label_fmt is the format string for facet labels or a function that returns the facet label.
- yfac – The vertical faceting variable (see xfac).
- **kwargs – Extra arguments for Plot.
-
expand_x_lims
(col, pad_frac=0.05, pad_abs=None)¶ Increase the range of the x-axis, relative to the plot data.
Parameters: - col – The column name for the x-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
expand_y_lims
(col, pad_frac=0.05, pad_abs=None)¶ Increase the range of the y-axis, relative to the plot data.
Parameters: - col – The column name for the y-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
subplots
(hide_axes=False, dx=0.055, dy=0.025)¶ Return an iterator that yields (axes, data) tuples for each subplot.
Parameters: - hide_axes – Whether to hide x axes that are not on the bottom edge, and y axes that are not on the left edge, of the figure.
- dx – The horizontal location for the y-axis label.
- dy – The vertical location for the x-axis label.
For consistency, a class is also provided for single plots.
-
class
pypfilt.plot.
Single
(data, xlbl, ylbl, **kwargs)¶ Faceted plots that contain only one sub-plot; i.e., a single plot that provides the same methods as faceted plots that contain many sub-plots.
Parameters: - data – The NumPy array containing the data to plot.
- xlbl – The label for the x-axis.
- ylbl – The label for the y-axis.
- **kwargs – Extra arguments for
Plot
.
-
expand_x_lims
(col, pad_frac=0.05, pad_abs=None)¶ Increase the range of the x-axis, relative to the plot data.
Parameters: - col – The column name for the x-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
expand_y_lims
(col, pad_frac=0.05, pad_abs=None)¶ Increase the range of the y-axis, relative to the plot data.
Parameters: - col – The column name for the y-axis data.
- pad_frac – The fractional increase in range.
- pad_abs – The absolute increase in range.
-
subplots
(hide_axes=False, dx=0.055, dy=0.025)¶ Return an iterator that yields (axes, data) tuples for each subplot.
Parameters: - hide_axes – Whether to hide x axes that are not on the bottom edge, and y axes that are not on the left edge, of the figure.
- dx – The horizontal location for the y-axis label.
- dy – The vertical location for the x-axis label.
Unicode and byte strings¶
The pypfilt package simultaneously supports Python 2.7 and Python 3.x, and is intended to behave identically regardless of the Python version. It is assumed that the following Python 3 features are enabled in Python 2.7:
from __future__ import absolute_import, division, print_function
from __future__ import unicode_literals
Importantly, among the differences between Python 2.7 and Python 3.x, the native str type is a byte string in Python 2 and a Unicode string in Python 3. This means that, e.g., the str() built-in function returns byte strings in Python 2 and Unicode strings in Python 3.
Guidelines for working with text¶
As per the Unicode HOWTO for Python 2 and Python 3:
Tip
Software should only work with Unicode strings internally, decoding the input data as soon as possible and encoding the output only at the end (the “Unicode sandwich”).
To that end, adhere to the following guidelines:
Use Unicode strings and Unicode literals everywhere. In Python 2, this means placing the following at the top of every file:
from __future__ import unicode_literals
If you have non-ASCII characters in a Python source file (e.g., in Unicode literals such as 'α'), you need to declare the file encoding at the top of the file:
# -*- coding: utf-8 -*-
Encode Unicode text into UTF-8 when writing to disk:
import codecs
# Note: in Python 3, the open() built-in accepts an encoding argument
with codecs.open(filename, 'wb', encoding='utf-8') as f:
    f.write(unicode_string)
Decode UTF-8 bytes into Unicode text when reading from disk:
import codecs
# Note: in Python 3, the open() built-in accepts an encoding argument
with codecs.open(filename, 'rb', encoding='utf-8') as f:
    unicode_lines = f.read().splitlines()
Note that NumPy functions such as loadtxt and genfromtxt cannot reliably handle non-ASCII text (e.g., see NumPy issues #3184, #4543, #4600, #4939), and should only be used with ASCII files:
import codecs
import numpy as np
with codecs.open(filename, encoding='ascii') as f:
    data = np.loadtxt(f, ...)
Use the
'S'
(byte string) data type when storing text in NumPy arrays. Encode Unicode text into UTF-8 when storing text, and decode UTF-8 bytes when reading text:>>> from __future__ import unicode_literals >>> import numpy as np >>> xs = np.empty(3, dtype='S20') >>> xs[0] = 'abc'.encode('utf-8') >>> xs[1] = '« äëïöü »'.encode('utf-8') >>> xs[2] = 'ç'.encode('utf-8') >>> print(max(len(x) for x in xs)) 16 >>> for x in xs: >>> print(x.decode('utf-8')) abc « äëïöü » ç
NumPy has a Unicode data type ('U'), but it is not supported by h5py (and is platform-specific).
Note that h5py object names (i.e., groups and datasets) are exclusively Unicode and are stored as bytes, so byte strings will be used as-is and Unicode strings will be encoded using UTF-8.
Use Unicode strings and literals when encoding to and decoding from JSON:
import codecs
import json

# Write UTF-8 bytes rather than '\uXXXX' escape sequences.
with codecs.open(filename, 'wb', encoding='utf-8') as f:
    json.dump(json_data, f, ensure_ascii=False)
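As a rough illustration of the guidelines above, the following sketch applies the "Unicode sandwich" in Python 3 only, where the open() built-in accepts an encoding argument. The file names and sample values are made up for the example and are not part of pypfilt.

```python
import json
import os
import tempfile

# Hypothetical sample data; any Unicode text would do.
text = '« äëïöü »\nα'
json_data = {'symbol': 'α'}

tmp_dir = tempfile.mkdtemp()
text_file = os.path.join(tmp_dir, 'example.txt')
json_file = os.path.join(tmp_dir, 'example.json')

# Encode to UTF-8 only at the boundary, when writing to disk.
with open(text_file, 'w', encoding='utf-8') as f:
    f.write(text)
with open(json_file, 'w', encoding='utf-8') as f:
    json.dump(json_data, f, ensure_ascii=False)

# Decode from UTF-8 as soon as the data is read back.
with open(text_file, 'r', encoding='utf-8') as f:
    unicode_lines = f.read().splitlines()
with open(json_file, 'r', encoding='utf-8') as f:
    loaded = json.load(f)

# Read the JSON file as raw bytes to inspect what was stored on disk.
with open(json_file, 'rb') as f:
    raw_bytes = f.read()
```

Everything between the write and read steps works with Unicode strings only; the raw bytes are inspected here purely to show that the non-ASCII character was stored as UTF-8 rather than as an escape sequence.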
Functions for working with text¶
The pypfilt.text
module provides functions for converting between Unicode
strings and byte strings, which behave identically in Python 2 and Python 3.
pypfilt.text.to_unicode(value, encoding=u'utf-8')¶
Convert a value into a Unicode string.
- If the value is a Unicode string, no conversion is performed.
- If the value is a byte string, it is decoded according to the provided encoding.
- If the value is neither a Unicode string nor a byte string, it is first converted into a string (by the str() built-in function) and then decoded if necessary.
pypfilt.text.to_bytes(value, encoding=u'utf-8')¶
Convert a value into a byte string.
- If the value is a Unicode string, it is encoded according to the provided encoding.
- If the value is a byte string, no conversion is performed.
- If the value is neither a Unicode string nor a byte string, it is first converted into a string (by the str() built-in function) and then encoded if necessary.
It also provides functions for determining whether a value is a Unicode string or a byte string, although this should generally be known in advance.
pypfilt.text.is_unicode(value)¶
Return True if the value is a Unicode string.
pypfilt.text.is_bytes(value)¶
Return True if the value is a byte string.
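The conversion rules described in this section can be sketched as follows for Python 3 only. This is an illustrative approximation, not the pypfilt.text implementation; the names to_unicode_sketch and to_bytes_sketch are hypothetical.

```python
def to_unicode_sketch(value, encoding='utf-8'):
    """Return the value as a Unicode string (Python 3 sketch)."""
    if isinstance(value, str):
        # Already a Unicode string: no conversion.
        return value
    if isinstance(value, bytes):
        # Byte string: decode with the provided encoding.
        return value.decode(encoding)
    # Anything else: str() already returns Unicode text in Python 3.
    return str(value)


def to_bytes_sketch(value, encoding='utf-8'):
    """Return the value as a byte string (Python 3 sketch)."""
    if isinstance(value, bytes):
        # Already a byte string: no conversion.
        return value
    if isinstance(value, str):
        # Unicode string: encode with the provided encoding.
        return value.encode(encoding)
    # Anything else: convert to text first, then encode.
    return str(value).encode(encoding)


alpha_text = to_unicode_sketch(b'\xce\xb1')
number_bytes = to_bytes_sketch(1.5)
```

A Python 2 version would also need to handle the case where str() returns a byte string, which is why the descriptions above say "if necessary".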
Change Log¶
0.5.2 (2017-05-05)¶
- Bug fix: make pypfilt.examples a valid Python module.
- Bug fix: fix the Lotka-Volterra model in pypfilt.examples.predation to work correctly with scalar and non-scalar time scales.
0.5.1 (2017-04-28)¶
- Bug fix: correctly generate summaries for the case where no table rows will be generated. This bug was introduced in pypfilt 0.5.0 (commit 8a0a614).
0.5.0 (2017-04-26)¶
- Breaking change: the base model class has been renamed to pypfilt.Model.
- Breaking change: the base model class has been simplified; the state_info, param_info, and param_bounds methods have been replaced by a single method, describe. This method also defines, for each element of the state vector, whether that element can be sampled continuously (e.g., by the post-regularised filter).
- Breaking change: pypfilt.summary.HDF5 no longer creates a table of observations if no such table has been defined, since it may be desirable to store observations in multiple tables (e.g., grouped by source or observation unit). To retain the previous behaviour, add the new observations table pypfilt.summary.Obs to the summary object.
- Breaking change: particle weights are now passed as an additional argument to the log-likelihood function. Previously, the log-likelihood function was inspected to determine whether it accepted an extra argument (a nasty hack).
- Bug fix: avoid raising an exception when regularise_or_fail is False (this was the intended behaviour in previous versions).
- Bug fix: ensure that pypfilt.summary.obs_table correctly encodes the observation source and units.
- Bug fix: correct an off-by-one error in pypfilt.stats.qtl_wt that caused the weighted quantiles to be calculated incorrectly. The calculation error was inversely proportional to the number of particles and should be negligible for any reasonable number of particles (e.g., \(\ge 10^3\)).
- Enhancement: custom simulation time scales are supported. Two time scales are provided (pypfilt.Datetime and pypfilt.Scalar) and additional time scales can be implemented by inheriting from pypfilt.time.Time.
- Enhancement: allow likelihoods to depend on past states by setting params['last_n_periods'] to N > 1, so that the current observation period can be compared to previous observation periods.
- Enhancement: monitor states are now cached and restored, allowing them to calculate statistics over the combined estimation and forecasting runs. This means that, e.g., peak times and sizes are correctly reported even if they occurred prior to the forecasting date.
- Enhancement: add conversion functions for manipulating individual columns in structured arrays.
- Enhancement: plotting functions are provided by a new module, pypfilt.plot (adding an optional dependency on matplotlib).
- Enhancement: provide a base class for simulation metadata (pypfilt.summary.Metadata).
- Enhancement: the (continuous) Lotka-Volterra equations are provided as an example in pypfilt.examples.predation and act as the example system in the documentation.
- Enhancement: pypfilt.summary.dtype_names_to_str now also accepts fields as a list of field names (i.e., strings).
- Enhancement: test cases for several modules are now provided in ./tests and can be run with tox.
- Enhancement: document how to install required packages as wheels, avoiding lengthy compilation times.
- Enhancement: document the release process and provide instructions for uploading packages to PyPI.
0.4.3 (2016-09-16)¶
- Bug fix: correct the basic resampling method. Previously, random samples were drawn from the unit interval and were erroneously assumed to be in sorted order (as is the case for the stratified and deterministic methods).
- Enhancement: automatically convert Unicode field names to native strings when using Python 2, to prevent NumPy from throwing a TypeError, as may occur when using from __future__ import unicode_literals. This functionality is provided by pypfilt.summary.dtype_names_to_str.
- Enhancement: ensure that temporary files are deleted when the simulation process is terminated by the SIGTERM signal. Previously, they were only deleted upon normal termination (as noted in the atexit documentation).
- Enhancement: consistently separate Unicode strings from bytes, and provide utility functions in the pypfilt.text module.
- Enhancement: forecast from the most recent known-good cached state, avoiding the estimation pass whenever possible.
- Enhancement: allow the observation table to be generated externally. This means that users can include additional columns as needed.
- Enhancement: separate the calculation of log-likelihoods from the adjustment of particle weights, resulting in the new function pypfilt.log_llhd_of.
- Enhancement: provide particle weights to the log-likelihood function, if the log-likelihood function accepts an extra argument. This has no impact on existing log-likelihood functions.
- Enhancement: by default, allow simulations to continue if regularisation fails. This behaviour can be changed:
  params['resample']['regularise_or_fail'] = True
0.4.2 (2016-06-16)¶
- Breaking change: pypfilt.forecast will raise an exception if no forecasting dates are provided.
- Add installation instructions for Red Hat Enterprise Linux, Fedora, and Mac OS X (using Homebrew).
0.4.1 (2016-04-26)¶
- Enhancement: allow forecasts to resume from cached states, greatly improving the speed with which forecasts can be generated when new or updated observations become available. This is enabled by defining a cache file:
  params['hist']['cache_file'] = 'cache.hdf5'
- Enhancement: add option to restrict summary statistics to forecasting simulations, ignoring the initial estimation run. This is enabled by passing only_fs=True as an argument to the pypfilt.summary.HDF5 constructor.
0.4.0 (2016-04-22)¶
- Breaking change: require models to define default parameter bounds by implementing the param_bounds method.
- Enhancement: offer the post-regularised particle filter (post-RPF) as an alternative means of avoiding particle impoverishment (as opposed to incorporating stochastic noise into the model equations). This is enabled by setting:
  params['resample']['regularisation'] = True
  See the example script (./doc/example/run.py) for a demonstration.
- Improved documentation for pypfilt.model.Base and summary statistics.
- Add documentation for installing in a virtual environment.
0.3.0 (2016-02-23)¶
- This release includes a complete overhaul of simulation metadata and summary statistics. See ./doc/example/run.py for an overview of these changes.
- Breaking change: decrease the default resampling threshold from 75% to 25%.
- Breaking change: define base classes for summary statistics and output.
- Breaking change: define a base class for simulation models.
- Breaking change: collate the resampling and history matrix parameters to reduce clutter.
- Breaking change: move pypfilt.metadata_priors to pypfilt.summary.
- Bug fix: prevent stats.cov_wt from mutating the history matrix.
- Bug fix: ensure that the time-step mapping behaves as documented.
- Bug fix: ensure that state vector slices have correct dimensions.
- Enhancement: ensure that forecasting dates lie within the simulation period.
- Performance improvement: vectorise the history matrix initialisation.
- Host the documentation at Read The Docs.
0.2.0 (2015-11-16)¶
- Notify models whether the current simulation is a forecast (i.e., if there are no observations). This allows deterministic models to add noise when estimating, to allow identical particles to differ in their behaviour, and to avoid doing so when forecasting. Note that this is a breaking change, as it alters the parameters passed to the model update function.
- Simplify the API for running a single simulation; pypfilt.set_limits has been removed and pypfilt.Time is not included in the API documentation, on the grounds that users should not need to make use of this class.
- Greater use of NumPy array functions, removing the dependency on six >= 1.7.
- Minor corrections to the example script (./doc/example/run.py).
0.1.2 (2015-06-08)¶
- Avoid error messages if no logging handler is configured by the application.
- Use a relative path for the output directory. This makes simulation metadata easier to reproduce, since the absolute path of the output directory is no longer included in the output file.
- Build a universal wheel via python setup.py bdist_wheel, which supports both Python 2 and Python 3.
0.1.1 (2015-06-01)¶
- Make the output directory a simulation parameter (out_dir) so that it can be changed without affecting the working directory, and vice versa.
0.1.0 (2015-05-29)¶
- Initial release.
Contributing to pypfilt¶
As an open source project, pypfilt welcomes contributions of many forms.
Examples of contributions include:
- Code patches
- Documentation improvements
- Bug reports and patch reviews
Testing with tox¶
The pypfilt testing suite uses the pytest framework, and uses the tox automation tool to run the tests under Python 2 and Python 3.
The test cases are contained in the ./tests directory.
To run all tests using all of the Python versions defined in tox.ini, run:
tox
The tox.ini contents are shown below, and include targets that check whether the documentation in ./doc builds correctly with Python 2 and with Python 3.
#
# Configuration file for tox, used to automate test activities.
#
# https://tox.readthedocs.io/en/latest/
#
# This configuration file defines four test environments:
#
# py27-test: Run the test cases in ./tests/ using Python 2.7.
# py35-test: Run the test cases in ./tests/ using Python 3.5.
# py27-docs: Build the package documentation using Python 2.7.
# py35-docs: Build the package documentation using Python 3.5.
#
# To perform each of these test activities, run:
#
# tox
#
[tox]
envlist = py{27,35}-{test,docs}
#
# Define common settings.
#
# * Cache installed wheels to accelerate environment creation.
# * Ensure tests are run against the installed package.
# * Add test-specific package dependencies.
#
[base]
pkg = pypfilt
wheels = {homedir}/.cache/pip/wheels
pytest = {envbindir}/py.test --cov={envsitepackagesdir}/{[base]pkg} --capture=no
install_command=pip install -f {[base]wheels} {opts} {packages}
deps =
    wheel>=0.29
    pytest
    pytest-cov
    hypothesis>=3.7
#
# Define environment-specific settings.
#
# * The documentation builds are performed in the ./doc directory.
# * The documentation builds depend on Sphinx and associated packages.
# * The test cases depend on the testing packages defined in [base].
# * Python 3.5 tests issue errors about comparing bytes and strings (-bb).
#
[testenv]
changedir =
    docs: doc
deps =
    test: {[base]deps}
    docs: sphinx>=1.4
    docs: sphinx-rtd-theme>=0.1.9
    docs: sphinxcontrib-inlinesyntaxhighlight>=0.2
commands =
    py27-test: {envpython} {[base]pytest} {posargs}
    py35-test: {envpython} -bb {[base]pytest} {posargs}
    docs: sphinx-build -W -b html -d {envtmpdir}/doctrees . {envtmpdir}/html
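The -bb option used for the Python 3.5 test environment turns comparisons between bytes and strings into errors. As a small illustration of why such comparisons are worth catching, a byte string never equals a Unicode string in Python 3, even when both hold the same text. This sketch is not part of the pypfilt test suite, and the variable names are made up.

```python
# A byte string, e.g. as read from a NumPy array with the 'S' data type.
stored = 'abc'.encode('utf-8')
# The Unicode string it is expected to match.
expected = 'abc'

# Mixed comparison: False by default in Python 3, and an error under -bb.
mixed_match = stored == expected

# Decoding first compares text with text, which is the intended check.
decoded_match = stored.decode('utf-8') == expected

# Encoding the Unicode string compares bytes with bytes.
encoded_match = stored == expected.encode('utf-8')
```

Without -b or -bb the mixed comparison fails silently, so a test that relies on it can pass or fail for the wrong reason.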
Release process¶
Feature development takes place on the “master” branch. Periodically, a release is created by increasing the version number and tagging the relevant commit with the new version number.
- Update the version number according to the versioning scheme.
  - Update the version number in doc/conf.py. The full version must always be updated; the short (X.Y) version does not need to be updated if the version number is being increased from X.Y.Z to X.Y.Z+1.
  - Update the version number in pypfilt/version.py.
  - Update the version number in setup.py.
- Describe the changes at the top of NEWS.rst under a heading of the form X.Y.Z (YYYY-MM-DD), which identifies the new version number and the date on which this version was released.
- Commit these changes; set the commit message to Release pypfilt X.Y.Z.
- Tag this commit X.Y.Z.
- Push this commit and the new tag upstream.
Publishing to PyPI¶
These instructions are based on the Python Packaging User Guide.
Ensure that twine is installed:
pip install twine
Define the PyPI server(s) in .pypirc:
[distutils]
index-servers =
    pypi
    pypitest
[pypi]
repository=https://upload.pypi.org/legacy/
[pypitest]
repository=https://testpypi.python.org/pypi
Build the wheel ./dist/pypfilt-X.Y.Z-py2.py3-none-any.whl:
python setup.py bdist_wheel
Upload this wheel to the PyPI test server, so that any problems can be identified and fixed:
twine upload -r pypitest dist/pypfilt-X.Y.Z-py2.py3-none-any.whl
Then upload this wheel to PyPI:
twine upload dist/pypfilt-X.Y.Z-py2.py3-none-any.whl