import collections
import pandas as pd
import numpy as np
import warnings
from openpathsampling.progress import SimpleProgress
try:
from collections import abc
except ImportError:
import collections as abc
# based on http://stackoverflow.com/a/3387975
class TransformedDict(abc.MutableMapping):
    """A dictionary that applies an arbitrary key-altering function before
    accessing the keys

    Keys are stored internally under ``hash_function(key)``. Any two input
    keys that produce the same hash are treated as effectively identical;
    the first such key seen is remembered as the "representative" and is
    what iteration yields, which also allows later rehashing.
    """

    def __init__(self, hash_function, *args, **kwargs):
        self.hash_function = hash_function
        # maps hash -> stored value
        self.store = dict()
        # maps hash -> first original key seen for that hash
        self.hash_representatives = dict()
        # MutableMapping.update routes everything through __setitem__
        self.update(dict(*args, **kwargs))

    def __getitem__(self, key):
        return self.store[self.hash_function(key)]

    def __setitem__(self, key, value):
        hashed = self.hash_function(key)
        # keep the first original key as the representative for this hash
        self.hash_representatives.setdefault(hashed, key)
        self.store[hashed] = value

    def __delitem__(self, key):
        hashed = self.hash_function(key)
        del self.store[hashed]
        del self.hash_representatives[hashed]

    def __iter__(self):
        # iterate over representative (original) keys, not raw hashes
        return iter(self.hash_representatives.values())

    def __len__(self):
        return len(self.store)

    def rehash(self, new_hash):
        """Create a new TransformedDict with this data and new hash.

        It is up to the user to ensure that the mapping from the old hash
        to the new is a function (i.e., each entry from the old hash can
        be mapped directly onto the new hash).

        For example, this is used to map from a snapshot's coordinates to
        a collective variable based on the coordinates. However, if the
        original hash was based on coordinates, but the new hash included
        velocities, the resulting mapping would be invalid. It is up to
        the user to avoid such invalid remappings.
        """
        remapped = {self.hash_representatives[hashed]: value
                    for hashed, value in self.store.items()}
        return TransformedDict(new_hash, remapped)
class SnapshotByCoordinateDict(TransformedDict):
    """TransformedDict keyed on a snapshot's coordinate bytes.

    Two snapshots with byte-identical coordinate arrays map to the same
    entry, which gives a unique key per configuration for shooting point
    analysis (e.g., committor analysis).
    """

    def __init__(self, *args, **kwargs):
        super(SnapshotByCoordinateDict, self).__init__(
            lambda snapshot: snapshot.coordinates.tobytes(),
            *args, **kwargs
        )
class ShootingPointAnalysisError(AssertionError):
    """Base exception for errors raised during shooting point analysis."""
    # TODO this should inherit from a different Error type in OPS 2.0
    pass
class NoFramesInStateError(ShootingPointAnalysisError):
    """Raised when no trial trajectory endpoint falls in any given state."""
    pass
class MoreStatesThanFramesError(ShootingPointAnalysisError):
    """Raised when endpoints register in more states than there are frames
    (typically means the state volumes overlap)."""
    pass
class ShootingPointAnalysis(SimpleProgress, SnapshotByCoordinateDict):
    """
    Container and methods for shooting point analysis.

    This is especially useful for analyzing committors, which is
    automatically done on a per-configuration basis, and can also be done
    as a histogram.

    Parameters
    ----------
    steps : iterable of :class:`.MCStep` or None
        input MC steps to analyze; if None, no analysis performed
    states : list of :class:`.Volume`
        volumes to consider as states for the analysis. For pandas output,
        these volumes must be named.
    error_if_no_state : bool, default True
        boolean flag to error on steps that don't end in one of the states
    """
    def __init__(self, steps, states, error_if_no_state=True):
        super(ShootingPointAnalysis, self).__init__()
        self.states = states
        self.error_if_no_state = error_if_no_state
        if steps:
            self.analyze(steps)

    def analyze(self, steps):
        """Analyze a list of steps, adding to internal results.

        Parameters
        ----------
        steps : iterable of :class:`.MCStep` or None
            MC steps to analyze

        Raises
        ------
        NoFramesInStateError
            if a step's trial endpoints are in no state and
            ``error_if_no_state`` is True; otherwise a warning is issued
        """
        for step in self.progress(steps):
            try:
                self.analyze_single_step(step)
            except NoFramesInStateError as err:
                if self.error_if_no_state:
                    addition = ("\nTo disable this error set "
                                "'error_if_no_state=False'")
                    raise type(err)(str(err) + addition)
                else:
                    warnings.warn(str(err))

    def analyze_single_step(self, step):
        """
        Analyzes final states from a path sampling step. Adds to internal
        results.

        Parameters
        ----------
        step : :class:`.MCStep`
            the step to analyze and add to this analysis

        Returns
        -------
        list of :class:`.Volume`
            the states which are identified as new final states from this
            move

        Raises
        ------
        NoFramesInStateError
            if neither trial endpoint is inside any of the states
        MoreStatesThanFramesError
            if the trial endpoints register in more states than there are
            endpoints (overlapping state definitions)
        """
        key = self.step_key(step)
        if key is not None:
            details = step.change.canonical.details
            trial_traj = step.change.canonical.trials[0].trajectory
            init_traj = details.initial_trajectory
            # only endpoints that are new in the trial trajectory count
            # (endpoints shared with the initial trajectory are excluded)
            test_points = [s for s in [trial_traj[0], trial_traj[-1]]
                           if s not in [init_traj[0], init_traj[-1]]]

            total = collections.Counter(
                {state: sum([int(state(pt)) for pt in test_points])
                 for state in self.states}
            )
            total_count = sum(total.values())
            if total_count == 0:
                err = ("Step "+str(step.mccycle)+" has a trajectory without "
                       "endpoints in any of the states.")
                raise NoFramesInStateError(err)

            if total_count > len(test_points):
                # fixed typo in message: "trail" -> "trial"
                err = ("The " + str(len(test_points)) +
                       " end points of the trial trajectory from step " +
                       str(step.mccycle) +
                       " found " + str(total_count) + " stable states."
                       "\n Are you sure your states don't overlap?"
                       )
                raise MoreStatesThanFramesError(err)

            # accumulate counts per shooting-point configuration
            try:
                self[key] += total
            except KeyError:
                self[key] = total
        else:
            total = {}

        return [s for s in total.keys() if total[s] > 0]

    @staticmethod
    def step_key(step):
        """
        Returns the key we use for hashing (the shooting snapshot).

        Parameters
        ----------
        step : :class:`.MCStep`
            the step to extract a shooting point from

        Returns
        -------
        :class:`.Snapshot` or None
            the shooting snapshot, or None if this step is not a shooting
            move.
        """
        key = None
        try:
            change = step.change.canonical
            details = change.details
            shooting_snap = details.shooting_snapshot
        except AttributeError:
            # wrong kind of move (no shooting_snapshot)
            pass
        except IndexError:
            # very wrong kind of move (no trials!)
            pass
        else:
            # easy to change how we define the key
            key = shooting_snap
        return key

    @classmethod
    def from_individual_runs(cls, run_results, states=None):
        """Build shooting point analysis from pairs of shooting point to
        final state.

        Parameters
        ----------
        run_results : list of 2-tuples (:class:`.Snapshot`, :class:`.Volume`)
            the first element in each pair is the shooting point, the second
            is the final volume
        states : list of :class:`.Volume` or None
            states to use for the analysis; default None infers them from
            the final volumes in ``run_results``

        Returns
        -------
        :class:`.ShootingPointAnalysis`
            the analysis object built from the individual runs
        """
        if states is None:
            states = set(s[1] for s in run_results)
        # use cls (not the hard-coded class name) so subclasses work
        analyzer = cls(None, states)
        for step in run_results:
            key = step[0]
            total = collections.Counter({step[1]: 1})
            try:
                analyzer[key] += total
            except KeyError:
                analyzer[key] = total
        return analyzer

    def committor(self, state, label_function=None):
        """Calculate the (point-by-point) committor.

        This is for the point-by-point (per-configuration) committor, not
        for histograms. See `committor_histogram` for the histogram version.

        Parameters
        ----------
        state : :class:`.Volume`
            the committor is 1.0 if 100% of shots enter this state
        label_function : callable
            the keys for the dictionary that is returned are
            `label_function(snapshot)`; default `None` gives the snapshot as
            key.

        Returns
        -------
        dict :
            mapping labels given by label_function to the committor value
        """
        if label_function is None:
            label_function = lambda s: s
        results = {}
        for k in self:
            out_key = label_function(k)
            counter_k = self[k]
            # fraction of all state hits (for this configuration) that
            # landed in the requested state
            committor = float(counter_k[state]) / sum([counter_k[s]
                                                       for s in self.states])
            results[out_key] = committor
        return results

    @staticmethod
    def _get_key_dim(key):
        """Return the dimensionality (1 or 2) of a histogram key.

        Raises
        ------
        RuntimeError
            if the key is not 1D or 2D
        """
        try:
            ndim = len(key)
        except TypeError:
            # scalar keys have no len(); treat as 1D
            ndim = 1
        if ndim > 2 or ndim < 1:
            err = ("Histogram key dimension {0} > 2 or {0} < 1 "
                   "(key: {1})").format(ndim, key)
            raise RuntimeError(err)
        return ndim

    def committor_histogram(self, new_hash, state, bins=10):
        """Calculate the histogrammed version of the committor.

        Parameters
        ----------
        new_hash : callable
            values are histogrammed in bins based on new_hash(snapshot)
        state : :class:`.Volume`
            the committor is 1.0 if 100% of shots enter this state
        bins : see numpy.histogram
            bins input to numpy.histogram

        Returns
        -------
        tuple :
            hist, bins like numpy.histogram, where hist is the histogram
            count and bins is the bins output from numpy.histogram. 2-tuple
            in the case of 1D histogram, 3-tuple in the case of 2D histogram
        """
        rehashed = self.rehash(new_hash)
        r_store = rehashed.store
        count_all = {k: sum(r_store[k].values()) for k in r_store}
        count_state = {k: r_store[k][state] for k in r_store}
        ndim = self._get_key_dim(list(r_store.keys())[0])
        if ndim == 1:
            (all_hist, b) = np.histogram(list(count_all.keys()),
                                         weights=list(count_all.values()),
                                         bins=bins)
            (state_hist, b) = np.histogram(list(count_state.keys()),
                                           weights=list(count_state.values()),
                                           bins=bins)
            b_list = [b]
        elif ndim == 2:
            (all_hist, b_x, b_y) = np.histogram2d(
                x=[k[0] for k in count_all],
                y=[k[1] for k in count_all],
                weights=list(count_all.values()),
                bins=bins
            )
            (state_hist, b_x, b_y) = np.histogram2d(
                x=[k[0] for k in count_state],
                y=[k[1] for k in count_state],
                weights=list(count_state.values()),
                bins=bins
            )
            b_list = [b_x, b_y]
        # if all_hist is 0, state_hist is NaN: ignore warning, return NaN
        with np.errstate(divide='ignore', invalid='ignore'):
            state_frac = np.true_divide(state_hist, all_hist)
        return tuple([state_frac] + b_list)

    def to_pandas(self, label_function=None):
        """
        Pandas dataframe. Row for each configuration, column for each state.

        Parameters
        ----------
        label_function : callable
            takes snapshot, returns index to use for pandas.DataFrame

        Returns
        -------
        :class:`pandas.DataFrame`
            counts per configuration (rows) and state (columns); columns
            are named by the states' ``.name`` attributes
        """
        transposed = pd.DataFrame(self.store).transpose().to_dict()
        df = pd.DataFrame(transposed)
        df.columns = [s.name for s in transposed.keys()]
        if label_function is None:
            df.index = range(len(df.index))
        else:
            # TODO: is ordering guaranteed here?
            df.index = [label_function(self.hash_representatives[k])
                        for k in self.store]
        return df