Find Cheaper University Textbooks
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

396 lines
13 KiB

from __future__ import unicode_literals
from six import iteritems, add_metaclass
from six.moves import map
from .exceptions import UnknownDslObject, ValidationException
# Values ObjectBase.__getattr__ treats as "no default": when a mapped field's
# empty() result equals one of these it is returned but not stored on the object.
SKIP_VALUES = ('', None)
def _wrap(val, obj_wrapper=None):
    """
    Wrap ``val`` for attribute-style access.

    A dict becomes an ``AttrDict``, or ``obj_wrapper(val)`` when a wrapper is
    supplied; a list becomes an ``AttrList``; any other value is returned
    unchanged.
    """
    if isinstance(val, dict):
        if obj_wrapper is not None:
            return obj_wrapper(val)
        return AttrDict(val)
    return AttrList(val) if isinstance(val, list) else val
def _make_dsl_class(base, name, params_def=None, suffix=''):
    """
    Create a subclass of ``base`` for the DSL object called ``name``.

    The class name is ``name`` converted from snake_case to TitleCase with
    ``suffix`` appended (``'match_all'`` with ``suffix='Query'`` gives
    ``MatchAllQuery``). The new class gets a ``name`` attribute and, when
    ``params_def`` is truthy, a ``_param_defs`` attribute holding it.
    """
    title_parts = [part.title() for part in name.split('_')]
    class_name = str(''.join(title_parts) + suffix)

    namespace = {'name': name}
    if params_def:
        namespace['_param_defs'] = params_def

    return type(class_name, (base,), namespace)
class AttrList(object):
    """
    List wrapper whose items are returned through ``_wrap``: nested dicts come
    back as ``AttrDict`` (or ``obj_wrapper(item)`` when a wrapper was given) and
    nested lists as ``AttrList``. Attributes not defined here are looked up on
    the wrapped list.
    """
    def __init__(self, l, obj_wrapper=None):
        # make iterables into lists
        if not isinstance(l, list):
            l = list(l)
        self._l_ = l
        self._obj_wrapper = obj_wrapper

    def __repr__(self):
        return repr(self._l_)

    def __eq__(self, other):
        if isinstance(other, AttrList):
            return other._l_ == self._l_
        # make sure we still equal to a list with the same data
        return other == self._l_

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent
        return not self == other

    def __getitem__(self, k):
        l = self._l_[k]
        if isinstance(k, slice):
            # keep the wrapper so items of the slice are wrapped like ours
            return AttrList(l, obj_wrapper=self._obj_wrapper)
        return _wrap(l, self._obj_wrapper)

    def __setitem__(self, k, value):
        self._l_[k] = value

    def __iter__(self):
        return map(lambda i: _wrap(i, self._obj_wrapper), self._l_)

    def __len__(self):
        return len(self._l_)

    def __nonzero__(self):
        return bool(self._l_)
    __bool__ = __nonzero__

    def __getattr__(self, name):
        # `_l_` can only be missing on an instance created without __init__
        # (copy and pickle do this, then probe attributes such as
        # __setstate__). Reading self._l_ below would re-enter this method
        # forever, so report it as a plain missing attribute.
        if name == '_l_':
            raise AttributeError(
                '%r object has no attribute %r' % (self.__class__.__name__, name))
        return getattr(self._l_, name)
class AttrDict(object):
"""
Helper class to provide attribute like access (read and write) to
dictionaries. Used to provide a convenient way to access both results and
nested dsl dicts.
"""
def __init__(self, d):
# assign the inner dict manually to prevent __setattr__ from firing
super(AttrDict, self).__setattr__('_d_', d)
def __contains__(self, key):
return key in self._d_
def __nonzero__(self):
return bool(self._d_)
__bool__ = __nonzero__
def __dir__(self):
# introspection for auto-complete in IPython etc
return list(self._d_.keys())
def __eq__(self, other):
if isinstance(other, AttrDict):
return other._d_ == self._d_
# make sure we still equal to a dict with the same data
return other == self._d_
def __repr__(self):
r = repr(self._d_)
if len(r) > 60:
r = r[:60] + '...}'
return r
def __getattr__(self, attr_name):
try:
return _wrap(self._d_[attr_name])
except KeyError:
raise AttributeError(
'%r object has no attribute %r' % (self.__class__.__name__, attr_name))
def __delattr__(self, attr_name):
try:
del self._d_[attr_name]
except KeyError:
raise AttributeError(
'%r object has no attribute %r' % (self.__class__.__name__, attr_name))
def __getitem__(self, key):
return _wrap(self._d_[key])
def __setitem__(self, key, value):
self._d_[key] = value
def __delitem__(self, key):
del self._d_[key]
def __setattr__(self, name, value):
if name in self._d_ or not hasattr(self.__class__, name):
self._d_[name] = value
else:
# there is an attribute on the class (could be property, ..) - don't add it as field
super(AttrDict, self).__setattr__(name, value)
def __iter__(self):
return iter(self._d_)
def to_dict(self):
return self._d_
class DslMeta(type):
    """
    Base metaclass for DslBase subclasses that keeps a registry of all classes
    of a given DslBase subclass (== all the query types for the Query subclass
    of DslBase).

    * ``_types`` (shared by every class using this metaclass) maps an abstract
      base's ``_type_name`` to its ``_type_shortcut``.
    * An abstract base (has ``_type_shortcut``, ``name`` is None) registers its
      shortcut and gets its own ``_classes`` registry unless it already
      inherits one.
    * A concrete class (``name`` set) is recorded in the inherited
      ``_classes`` registry under ``name``; the first class to claim a name
      keeps it.

    For typical use see `QueryMeta` and `Query` in `elasticsearch_dsl.query`.
    """
    _types = {}

    def __init__(cls, name, bases, attrs):
        super(DslMeta, cls).__init__(name, bases, attrs)

        # Classes without a shortcut (e.g. DslBase itself) are not registered.
        if not hasattr(cls, '_type_shortcut'):
            return

        if cls.name is not None:
            # Concrete DSL class: register it unless the name is already taken.
            registry = cls._classes
            if cls.name not in registry:
                registry[cls.name] = cls
            return

        # Abstract base: publish its shortcut under its type name ...
        cls._types[cls._type_name] = cls._type_shortcut
        # ... and give it a registry for its subclasses unless it inherits one.
        if not hasattr(cls, '_classes'):
            cls._classes = {}

    @classmethod
    def get_dsl_type(cls, name):
        """
        Return the shortcut registered for the DSL type ``name``.

        :raises UnknownDslObject: if no abstract base registered that type.
        """
        if name in cls._types:
            return cls._types[name]
        raise UnknownDslObject('DSL type %s does not exist.' % name)
@add_metaclass(DslMeta)
class DslBase(object):
    """
    Base class for all DSL objects - queries, filters, aggregations etc. Wraps
    a dictionary representing the object's json.

    Provides several features:
        - attribute access to the wrapped dictionary (.field instead of ['field'])
        - _clone method returning a copy of self rebuilt from ``to_dict()``
        - to_dict method to serialize into dict (to be sent via elasticsearch-py)
        - TODO: basic logical operators (&, | and ~) using a Bool(Filter|Query);
          they are not defined in this class and should move into a class
          specific for Query/Filter
        - respects the definition of the class and (de)serializes its
          attributes based on the `_param_defs` definition (for example turning
          all values in the `must` attribute into Query objects)
    """
    # Per-parameter metadata keyed by parameter name. Keys read in this class:
    # 'type' (DSL type name resolved with get_dsl_type), 'multi' (value is a
    # list) and 'hash' (value is a dict of name -> object).
    _param_defs = {}

    @classmethod
    def get_dsl_class(cls, name):
        """
        Return the registered subclass called ``name``.

        :raises UnknownDslObject: if no class with that name was registered.
        """
        try:
            return cls._classes[name]
        except KeyError:
            raise UnknownDslObject('DSL class `%s` does not exist in %s.' % (name, cls._type_name))

    def __init__(self, **params):
        self._params = {}
        for pname, pvalue in iteritems(params):
            # keyword arguments cannot contain dots, so `a__b` stands for `a.b`
            if '__' in pname:
                pname = pname.replace('__', '.')
            self._setattr(pname, pvalue)

    def _repr_params(self):
        """ Produce a repr of all our parameters to be used in __repr__. """
        return ', '.join(
            '%s=%r' % (n.replace('.', '__'), v)
            for (n, v) in sorted(iteritems(self._params))
            # make sure we don't include empty typed params
            if 'type' not in self._param_defs.get(n, {}) or v
        )

    def __repr__(self):
        return '%s(%s)' % (
            self.__class__.__name__,
            self._repr_params()
        )

    def __eq__(self, other):
        return isinstance(other, self.__class__) and other.to_dict() == self.to_dict()

    def __ne__(self, other):
        return not self == other

    def __setattr__(self, name, value):
        # private names are real instance attributes; everything else is a param
        if name.startswith('_'):
            return super(DslBase, self).__setattr__(name, value)
        return self._setattr(name, value)

    def _setattr(self, name, value):
        """
        Store ``value`` as parameter ``name``.

        When ``_param_defs`` declares a ``type`` for the parameter the value is
        converted with that type's shortcut first: ``multi`` params become a
        list of converted values (a single value is wrapped in a list),
        ``hash`` params a dict of converted values, and anything else a single
        converted value.
        """
        # if this attribute has special type assigned to it...
        if name in self._param_defs:
            pinfo = self._param_defs[name]
            if 'type' in pinfo:
                # get the shortcut used to construct this type (query.Q, aggs.A, etc)
                shortcut = self.__class__.get_dsl_type(pinfo['type'])
                if pinfo.get('multi'):
                    if not isinstance(value, (tuple, list)):
                        value = (value, )
                    value = list(map(shortcut, value))
                # dict(name -> DslBase), make sure we pickup all the objs
                elif pinfo.get('hash'):
                    value = dict((k, shortcut(v)) for (k, v) in iteritems(value))
                # single value object, just convert
                else:
                    value = shortcut(value)
        self._params[name] = value

    def __getattr__(self, name):
        """
        Return parameter ``name``; dict values are wrapped in ``AttrDict``.

        A missing ``multi``/``hash`` parameter is created as an empty list/dict
        instead of raising. Any other missing (or private) name raises
        AttributeError. A parameter explicitly set to None is returned as None.
        """
        if name.startswith('_'):
            raise AttributeError(
                '%r object has no attribute %r' % (self.__class__.__name__, name))

        if name in self._params:
            value = self._params[name]
        else:
            # compound types should never throw AttributeError and return empty
            # container instead
            pinfo = self._param_defs.get(name, {})
            if pinfo.get('multi'):
                value = self._params.setdefault(name, [])
            elif pinfo.get('hash'):
                value = self._params.setdefault(name, {})
            else:
                raise AttributeError(
                    '%r object has no attribute %r' % (self.__class__.__name__, name))

        # wrap nested dicts in AttrDict for convenient access
        if isinstance(value, dict):
            return AttrDict(value)
        return value

    def to_dict(self):
        """
        Serialize the DSL object to a plain dict ``{self.name: {param: value}}``.

        Typed params that are an empty list/dict are left out; typed values and
        any other value with a ``to_dict`` method are serialized recursively.
        """
        d = {}
        for pname, value in iteritems(self._params):
            pinfo = self._param_defs.get(pname)

            # typed param
            if pinfo and 'type' in pinfo:
                # don't serialize empty lists and dicts for typed fields
                if value in ({}, []):
                    continue

                # multi-values are serialized as list of dicts
                if pinfo.get('multi'):
                    value = list(map(lambda x: x.to_dict(), value))

                # squash all the hash values into one dict
                elif pinfo.get('hash'):
                    value = dict((k, v.to_dict()) for k, v in iteritems(value))

                # serialize single values
                else:
                    value = value.to_dict()

            # serialize anything with to_dict method
            elif hasattr(value, 'to_dict'):
                value = value.to_dict()

            d[pname] = value
        return {self.name: d}

    def _clone(self):
        """Return a new object built by passing ``to_dict()`` to the type's shortcut."""
        return self._type_shortcut(self.to_dict())
class ObjectBase(AttrDict):
    """
    AttrDict whose fields are described by ``self._doc_type.mapping``.

    Values assigned as attributes to mapped fields go through the field's
    ``deserialize``; in ``__init__`` only fields whose ``_coerce`` flag is set
    are deserialized, and ``to_dict`` runs ``serialize`` for those same fields.
    """
    def __init__(self, **kwargs):
        m = self._doc_type.mapping
        for k in m:
            if k in kwargs and m[k]._coerce:
                kwargs[k] = m[k].deserialize(kwargs[k])
        super(ObjectBase, self).__init__(kwargs)

    def __getattr__(self, name):
        """
        Return field ``name``, falling back to the mapped field's ``empty()``.

        The ``empty()`` value is stored on the object (and re-read through
        normal attribute access) unless it is one of ``SKIP_VALUES``, in which
        case it is returned without being stored. Unmapped names, and mapped
        fields without ``empty``, re-raise the original AttributeError.
        """
        try:
            return super(ObjectBase, self).__getattr__(name)
        except AttributeError:
            if name in self._doc_type.mapping:
                f = self._doc_type.mapping[name]
                if hasattr(f, 'empty'):
                    value = f.empty()
                    if value not in SKIP_VALUES:
                        setattr(self, name, value)
                        value = getattr(self, name)
                    return value
            raise

    def __setattr__(self, name, value):
        if name in self._doc_type.mapping:
            value = self._doc_type.mapping[name].deserialize(value)
        super(ObjectBase, self).__setattr__(name, value)

    def to_dict(self):
        """
        Return a new plain dict of the stored values.

        Mapped fields with ``_coerce`` set are passed through ``serialize``;
        values equal to ``[]``, ``{}`` or None are omitted.
        """
        out = {}
        for k, v in iteritems(self._d_):
            # only the mapping lookup may raise KeyError for "not a mapped
            # field"; a KeyError from serialize() must not be swallowed
            try:
                f = self._doc_type.mapping[k]
            except KeyError:
                pass
            else:
                if f._coerce:
                    v = f.serialize(v)

            # don't serialize empty values
            # careful not to drop numeric zeros: 0 is not equal to any of these
            if v in ([], {}, None):
                continue

            out[k] = v
        return out

    def clean_fields(self):
        """
        Run each mapped field's ``clean`` over the stored value.

        Cleaned values are written back when the field was already present or
        the cleaned value is not ``[]``/``{}``/None.

        :raises ValidationException: with a dict of field name -> list of
            errors, if any field failed to clean.
        """
        errors = {}
        for name in self._doc_type.mapping:
            field = self._doc_type.mapping[name]
            data = self._d_.get(name, None)
            try:
                # save the cleaned value
                data = field.clean(data)
            except ValidationException as e:
                errors.setdefault(name, []).append(e)

            if name in self._d_ or data not in ([], {}, None):
                self._d_[name] = data

        if errors:
            raise ValidationException(errors)

    def clean(self):
        """Hook for object-level validation; does nothing by default."""
        pass

    def full_clean(self):
        """Run ``clean_fields`` and then ``clean``."""
        self.clean_fields()
        self.clean()
def merge(data, new_data):
    """
    Recursively merge ``new_data`` into ``data`` in place.

    Keys whose values are dicts (or ``AttrDict``) on both sides are merged
    recursively; in every other case the value from ``new_data`` replaces the
    one in ``data``.

    :raises ValueError: if either argument is not a dict or ``AttrDict``.
    """
    if not (isinstance(data, (AttrDict, dict))
            and isinstance(new_data, (AttrDict, dict))):
        raise ValueError('You can only merge two dicts! Got %r and %r instead.' % (data, new_data))

    # AttrDict defines no items() method (unknown attributes are looked up as
    # keys of the wrapped dict), so iterate the plain dict it wraps instead
    if isinstance(new_data, AttrDict):
        new_data = new_data._d_

    for key, value in iteritems(new_data):
        # only recurse when both sides are mappings; otherwise the new value
        # (e.g. a scalar replacing a dict) simply overwrites the old one
        if (key in data and isinstance(data[key], (AttrDict, dict))
                and isinstance(value, (AttrDict, dict))):
            merge(data[key], value)
        else:
            data[key] = value