from __future__ import unicode_literals

from six import iteritems, add_metaclass
from six.moves import map

from .exceptions import UnknownDslObject, ValidationException

# values that ObjectBase.__getattr__ returns as-is without storing them
SKIP_VALUES = ('', None)


def _wrap(val, obj_wrapper=None):
    """
    Wrap raw containers for attribute-style access.

    A ``dict`` is passed to ``obj_wrapper`` when one is given, otherwise it
    becomes an ``AttrDict``; a ``list`` becomes an ``AttrList``. Any other
    value is returned unchanged.
    """
    if isinstance(val, dict):
        return AttrDict(val) if obj_wrapper is None else obj_wrapper(val)
    if isinstance(val, list):
        return AttrList(val)
    return val


def _make_dsl_class(base, name, params_def=None, suffix=''):
    """
    Generate a DSL class based on the name of the DSL object and its
    parameters.

    The class name is the title-cased, underscore-stripped ``name`` followed
    by ``suffix``; the new class subclasses ``base`` and carries ``name`` (and
    ``_param_defs`` when ``params_def`` is truthy) as class attributes.
    """
    attrs = {'name': name}
    if params_def:
        attrs['_param_defs'] = params_def
    # str() because unicode_literals is on and type() needs a native str name
    cls_name = str(''.join(s.title() for s in name.split('_')) + suffix)
    return type(cls_name, (base, ), attrs)


class AttrList(object):
    """
    Thin wrapper around a list that wraps items on access (see ``_wrap``)
    and delegates any unknown attribute to the underlying list.
    """

    def __init__(self, l, obj_wrapper=None):
        # make iterables into lists
        if not isinstance(l, list):
            l = list(l)
        self._l_ = l
        self._obj_wrapper = obj_wrapper

    def __repr__(self):
        return repr(self._l_)

    def __eq__(self, other):
        if isinstance(other, AttrList):
            return other._l_ == self._l_
        # make sure we still equal to a list with the same data
        return other == self._l_

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, define it explicitly
        return not self == other

    def __getitem__(self, k):
        l = self._l_[k]
        if isinstance(k, slice):
            # keep the wrapper so items of the slice are wrapped the same way
            return AttrList(l, obj_wrapper=self._obj_wrapper)
        return _wrap(l, self._obj_wrapper)

    def __setitem__(self, k, value):
        self._l_[k] = value

    def __iter__(self):
        return map(lambda i: _wrap(i, self._obj_wrapper), self._l_)

    def __len__(self):
        return len(self._l_)

    def __nonzero__(self):
        return bool(self._l_)
    __bool__ = __nonzero__

    def __getattr__(self, name):
        # anything not defined here (append, extend, ...) goes to the list
        return getattr(self._l_, name)

    def __getstate__(self):
        return (self._l_, self._obj_wrapper)

    def __setstate__(self, state):
        # defined explicitly: without it pickle/copy probe for __setstate__
        # on an empty instance, which sends __getattr__ into unbounded
        # recursion looking for the not-yet-set ``_l_``
        self._l_, self._obj_wrapper = state


class AttrDict(object):
    """
    Helper class to provide attribute like access (read and write) to
    dictionaries. Used to provide a convenient way to access both results and
    nested dsl dicts.
    """

    def __init__(self, d):
        # assign the inner dict manually to prevent __setattr__ from firing
        super(AttrDict, self).__setattr__('_d_', d)

    def __contains__(self, key):
        return key in self._d_

    def __nonzero__(self):
        return bool(self._d_)
    __bool__ = __nonzero__

    def __dir__(self):
        # introspection for auto-complete in IPython etc
        return list(self._d_.keys())

    def __eq__(self, other):
        if isinstance(other, AttrDict):
            return other._d_ == self._d_
        # make sure we still equal to a dict with the same data
        return other == self._d_

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, define it explicitly
        return not self == other

    def __repr__(self):
        r = repr(self._d_)
        # keep the repr short, long dicts are truncated
        if len(r) > 60:
            r = r[:60] + '...}'
        return r

    def __getstate__(self):
        return (self._d_, )

    def __setstate__(self, state):
        # defined explicitly: without it pickle/copy probe for __setstate__
        # on an empty instance, which sends __getattr__ into unbounded
        # recursion looking for the not-yet-set ``_d_``
        super(AttrDict, self).__setattr__('_d_', state[0])

    def __getattr__(self, attr_name):
        try:
            return _wrap(self._d_[attr_name])
        except KeyError:
            raise AttributeError(
                '%r object has no attribute %r' % (self.__class__.__name__, attr_name))

    def __delattr__(self, attr_name):
        try:
            del self._d_[attr_name]
        except KeyError:
            raise AttributeError(
                '%r object has no attribute %r' % (self.__class__.__name__, attr_name))

    def __getitem__(self, key):
        return _wrap(self._d_[key])

    def __setitem__(self, key, value):
        self._d_[key] = value

    def __delitem__(self, key):
        del self._d_[key]

    def __setattr__(self, name, value):
        if name in self._d_ or not hasattr(self.__class__, name):
            self._d_[name] = value
        else:
            # there is an attribute on the class (could be property, ..) - don't add it as field
            super(AttrDict, self).__setattr__(name, value)

    def __iter__(self):
        return iter(self._d_)

    def to_dict(self):
        """ Return the wrapped dict itself (not a copy). """
        return self._d_


class DslMeta(type):
    """
    Base Metaclass for DslBase subclasses that builds a registry of all classes
    for given DslBase subclass (== all the query types for the Query subclass
    of DslBase).

    It then uses the information from that registry (as well as `name` and
    `shortcut` attributes from the base class) to construct any subclass based
    on its name.

    For typical use see `QueryMeta` and `Query` in `elasticsearch_dsl.query`.
    """
    # registry of type name -> shortcut, shared by every class using this meta
    _types = {}

    def __init__(cls, name, bases, attrs):
        super(DslMeta, cls).__init__(name, bases, attrs)
        # skip for DslBase
        if not hasattr(cls, '_type_shortcut'):
            return
        if cls.name is None:
            # abstract base class, register its shortcut
            cls._types[cls._type_name] = cls._type_shortcut
            # and create a registry for subclasses
            if not hasattr(cls, '_classes'):
                cls._classes = {}
        elif cls.name not in cls._classes:
            # normal class, register it
            cls._classes[cls.name] = cls

    @classmethod
    def get_dsl_type(cls, name):
        """
        Return the shortcut registered for type ``name``; raises
        ``UnknownDslObject`` when no such type has been registered.
        """
        try:
            return cls._types[name]
        except KeyError:
            raise UnknownDslObject('DSL type %s does not exist.' % name)


@add_metaclass(DslMeta)
class DslBase(object):
    """
    Base class for all DSL objects - queries, filters, aggregations etc. Wraps
    a dictionary representing the object's json.

    Provides several features:
        - attribute access to the wrapped dictionary (.field instead of ['field'])
        - _clone method returning a deep copy of self
        - to_dict method to serialize into dict (to be sent via elasticsearch-py)
        - basic logical operators (&, | and ~) using a Bool(Filter|Query)
          TODO: move into a class specific for Query/Filter
        - respects the definition of the class and (de)serializes its
          attributes based on the `_param_defs` definition (for example turning
          all values in the `must` attribute into Query objects)
    """
    _param_defs = {}

    @classmethod
    def get_dsl_class(cls, name):
        """
        Return the class registered under ``name`` in ``cls._classes``; raises
        ``UnknownDslObject`` when there is none.
        """
        try:
            return cls._classes[name]
        except KeyError:
            raise UnknownDslObject('DSL class `%s` does not exist in %s.' % (name, cls._type_name))

    def __init__(self, **params):
        self._params = {}
        for pname, pvalue in iteritems(params):
            # keyword arguments cannot contain dots, '__' stands in for '.'
            if '__' in pname:
                pname = pname.replace('__', '.')
            self._setattr(pname, pvalue)

    def _repr_params(self):
        """ Produce a repr of all our parameters to be used in __repr__. """
        return ', '.join(
            '%s=%r' % (n.replace('.', '__'), v)
            for (n, v) in sorted(iteritems(self._params))
            # make sure we don't include empty typed params
            if 'type' not in self._param_defs.get(n, {}) or v
        )

    def __repr__(self):
        return '%s(%s)' % (
            self.__class__.__name__,
            self._repr_params()
        )

    def __eq__(self, other):
        return isinstance(other, self.__class__) and other.to_dict() == self.to_dict()

    def __ne__(self, other):
        return not self == other

    def __setattr__(self, name, value):
        # underscore-prefixed names are real instance attributes, everything
        # else is stored as a parameter
        if name.startswith('_'):
            return super(DslBase, self).__setattr__(name, value)
        return self._setattr(name, value)

    def _setattr(self, name, value):
        """
        Store ``value`` in ``self._params`` under ``name``, first converting
        it through the type shortcut when ``_param_defs`` declares a type.
        """
        # if this attribute has special type assigned to it...
        if name in self._param_defs:
            pinfo = self._param_defs[name]

            if 'type' in pinfo:
                # get the shortcut used to construct this type (query.Q, aggs.A, etc)
                shortcut = self.__class__.get_dsl_type(pinfo['type'])
                if pinfo.get('multi'):
                    if not isinstance(value, (tuple, list)):
                        value = (value, )
                    value = list(map(shortcut, value))

                # dict(name -> DslBase), make sure we pickup all the objs
                elif pinfo.get('hash'):
                    value = dict((k, shortcut(v)) for (k, v) in iteritems(value))

                # single value object, just convert
                else:
                    value = shortcut(value)
        self._params[name] = value

    def __getattr__(self, name):
        if name.startswith('_'):
            raise AttributeError(
                '%r object has no attribute %r' % (self.__class__.__name__, name))

        value = None
        try:
            value = self._params[name]
        except KeyError:
            # compound types should never throw AttributeError and return empty
            # container instead
            if name in self._param_defs:
                pinfo = self._param_defs[name]
                if pinfo.get('multi'):
                    value = self._params.setdefault(name, [])
                elif pinfo.get('hash'):
                    value = self._params.setdefault(name, {})
        if value is None:
            raise AttributeError(
                '%r object has no attribute %r' % (self.__class__.__name__, name))

        # wrap nested dicts in AttrDict for convenient access
        if isinstance(value, dict):
            return AttrDict(value)
        return value

    def to_dict(self):
        """
        Serialize the DSL object to plain dict of the form
        ``{self.name: {param: value, ...}}``.
        """
        d = {}
        for pname, value in iteritems(self._params):
            pinfo = self._param_defs.get(pname)

            # typed param
            if pinfo and 'type' in pinfo:
                # don't serialize empty lists and dicts for typed fields
                if value in ({}, []):
                    continue

                # multi-values are serialized as list of dicts
                if pinfo.get('multi'):
                    value = list(map(lambda x: x.to_dict(), value))

                # squash all the hash values into one dict
                elif pinfo.get('hash'):
                    value = dict((k, v.to_dict()) for k, v in iteritems(value))

                # serialize single values
                else:
                    value = value.to_dict()

            # serialize anything with to_dict method
            elif hasattr(value, 'to_dict'):
                value = value.to_dict()

            d[pname] = value
        return {self.name: d}

    def _clone(self):
        """ Rebuild an equivalent object by feeding to_dict() to the type shortcut. """
        return self._type_shortcut(self.to_dict())


class ObjectBase(AttrDict):
    """
    AttrDict whose fields are (de)serialized and validated through the field
    objects found in ``self._doc_type.mapping``. ``_doc_type`` is not defined
    here and is expected to be supplied by the concrete class.
    """

    def __init__(self, **kwargs):
        m = self._doc_type.mapping
        # deserialize every supplied value whose field asks for coercion
        for k in m:
            if k in kwargs and m[k]._coerce:
                kwargs[k] = m[k].deserialize(kwargs[k])
        super(ObjectBase, self).__init__(kwargs)

    def __getattr__(self, name):
        try:
            return super(ObjectBase, self).__getattr__(name)
        except AttributeError:
            # missing value for a mapped field - fall back to the field's
            # ``empty()`` value if it offers one
            if name in self._doc_type.mapping:
                f = self._doc_type.mapping[name]
                if hasattr(f, 'empty'):
                    value = f.empty()
                    if value not in SKIP_VALUES:
                        # store it and read it back so the caller gets the
                        # stored (wrapped) value
                        setattr(self, name, value)
                        value = getattr(self, name)
                    return value
            raise

    def __setattr__(self, name, value):
        if name in self._doc_type.mapping:
            value = self._doc_type.mapping[name].deserialize(value)
        super(ObjectBase, self).__setattr__(name, value)

    def to_dict(self):
        """
        Return a new dict with the stored values, serialized through their
        field when it has ``_coerce`` set; empty lists, dicts and None are
        left out.
        """
        out = {}
        for k, v in iteritems(self._d_):
            try:
                f = self._doc_type.mapping[k]
                if f._coerce:
                    v = f.serialize(v)
            except KeyError:
                # no field in the mapping for this key, keep the value as is
                pass

            # don't serialize empty values
            # careful not to include numeric zeros
            if v in ([], {}, None):
                continue

            out[k] = v
        return out

    def clean_fields(self):
        """
        Run ``clean`` of every mapped field over the stored value and save the
        result; raises a single ``ValidationException`` carrying the collected
        per-field errors if any field failed.
        """
        errors = {}
        for name in self._doc_type.mapping:
            field = self._doc_type.mapping[name]
            data = self._d_.get(name, None)
            try:
                # save the cleaned value
                data = field.clean(data)
            except ValidationException as e:
                errors.setdefault(name, []).append(e)

            # don't introduce new empty values for fields that were not set
            if name in self._d_ or data not in ([], {}, None):
                self._d_[name] = data

        if errors:
            raise ValidationException(errors)

    def clean(self):
        """ Hook for validation spanning multiple fields; does nothing here. """
        pass

    def full_clean(self):
        """ Run clean_fields() followed by clean(). """
        self.clean_fields()
        self.clean()


def merge(data, new_data):
    """
    Recursively merge ``new_data`` into ``data`` in place.

    Both arguments must be a ``dict`` or an ``AttrDict``, otherwise
    ``ValueError`` is raised. When a key already holds a dict/AttrDict in
    ``data`` the new value is merged into it, any other key is overwritten.
    """
    if not (isinstance(data, (AttrDict, dict))
            and isinstance(new_data, (AttrDict, dict))):
        raise ValueError('You can only merge two dicts! Got %r and %r instead.' % (data, new_data))

    # AttrDict exposes no items()/iteritems(), iterate its inner dict instead
    source = new_data._d_ if isinstance(new_data, AttrDict) else new_data

    for key, value in iteritems(source):
        if key in data and isinstance(data[key], (AttrDict, dict)):
            merge(data[key], value)
        else:
            data[key] = value