from openpathsampling.netcdfplus.base import StorableObject
from .object import ObjectStore
import logging
# Module-level logger for this store implementation.
logger = logging.getLogger(__name__)
# Logger on the shared 'openpathsampling.initialization' channel; it is not
# used in the visible part of this module.
init_log = logging.getLogger('openpathsampling.initialization')
class VariableStore(ObjectStore):
    """Store whose objects are rebuilt from a fixed set of stored variables.

    Each object of ``content_class`` is saved by writing one storage
    variable per name in ``var_names`` and is loaded again by passing the
    stored values positionally to ``content_class(...)``. The base
    ``ObjectStore`` is created with ``json=False``.
    """

    def __init__(self, content_class, var_names):
        """
        Parameters
        ----------
        content_class : type
            The class of the stored objects; must be a subclass of
            ``StorableObject``.
        var_names : list of str
            Names of the stored variables. They are reordered to follow the
            signature reported by ``content_class.args()`` (its first entry,
            presumably ``self``, is skipped). Only the leading run of
            signature arguments that appear in ``var_names`` is kept;
            everything from the first argument missing in ``var_names`` on
            is not stored.

        Raises
        ------
        ValueError
            If ``content_class`` is not a subclass of ``StorableObject``.
        """
        # TODO: determine var_names automatically from content_class
        # problem is that some decorators, e.g. using delayed loader
        # hide the actual __init__ signature and so we cannot determine
        # what variables to store. Could be 2.0

        # Validate before initializing the base store so that an unusable
        # content_class fails fast.
        if not issubclass(content_class, StorableObject):
            raise ValueError(('Content_class %s must be subclassed from '
                              'StorableObject') % content_class.__name__)

        super(VariableStore, self).__init__(
            content_class,
            json=False
        )

        # Drop the first entry of the signature (presumably ``self``).
        var_names_class = self.content_class.args()[1:]

        # backwards compatibility. Reorder the stored var_names to comply
        # with the signature. Optional variables need to be at the end!!
        var_names_new = []
        for name in var_names_class:
            if name not in var_names:
                # Arguments from here on are treated as optional and are
                # not stored.
                break
            var_names_new.append(name)

        logger.info(
            'Creates VariableStore with variables %s and instantiated with %s',
            var_names_new, var_names
        )

        self.var_names = var_names_new
        # Becomes True only after cache_all() performed a full load
        # (part=None); explicit partial loads leave it False.
        self._cached_all = False

    def to_dict(self):
        """Return this store's constructor arguments as a dict.

        Returns
        -------
        dict
            ``{'content_class': ..., 'var_names': ...}``
        """
        return {
            'content_class': self.content_class,
            'var_names': self.var_names
        }

    def _save(self, obj, idx):
        """Write each variable in ``self.var_names`` for ``obj`` at ``idx``."""
        for var in self.var_names:
            self.write(var, idx, obj)

    def _load(self, idx):
        """Rebuild the object stored at ``idx``.

        The stored values are read in ``self.var_names`` order and passed
        positionally to ``content_class``.
        """
        args = [self.vars[var][idx] for var in self.var_names]
        return self.content_class(*args)

    def initialize(self):
        """Initialize the underlying store via the base class."""
        super(VariableStore, self).initialize()

        # Add here the stores to be imported
        # self.create_variable('name', 'var_type')

    def all(self):
        """Load all objects into the cache via ``cache_all`` and return self."""
        self.cache_all()
        return self

    def cache_all(self, part=None):
        """Load all samples as fast as possible into the cache

        Parameters
        ----------
        part : list of int or `None`
            If `None` (default) all samples will be loaded. Otherwise the
            list of indices in `part` will be loaded into the cache

        Notes
        -----
        At most ``self.cache.size[0]`` objects are loaded; a negative size
        means no limit. Once a full load (``part=None``) has happened, later
        calls return immediately. Loading an explicit ``part`` does not mark
        the store as fully cached, so a later full load still runs.
        """
        if self._cached_all:
            # A full load already happened; nothing left to do.
            return

        max_length = self.cache.size[0]
        if max_length < 0:
            max_length = len(self)

        load_all = part is None
        if load_all:
            part = range(min(len(self), max_length))
        else:
            # Unique, sorted indices, truncated to what the cache can hold.
            part = sorted(set(part))[:max_length]

        if not part:
            return

        # Read each variable for all requested indices at once, in
        # self.var_names order (which matches the constructor signature even
        # if the variables were saved in a different order), then regroup the
        # columns into one argument tuple per index.
        data = zip(*[self.vars[var][part] for var in self.var_names])
        for idx, values in zip(part, data):
            self.add_to_cache(idx, values)

        if load_all:
            self._cached_all = True

    def add_to_cache(self, idx, data):
        """Build an object from ``data`` and store it in the cache at ``idx``.

        ``data`` holds the constructor arguments in ``self.var_names``
        order. Nothing happens if ``idx`` is already cached.
        """
        if idx in self.cache:
            return

        obj = self.content_class(*data)
        # NOTE(review): presumably associates obj with index idx in this
        # store; confirm against ObjectStore._get_id.
        self._get_id(idx, obj)
        self.cache[idx] = obj