# coding: utf-8
'''
Implements conditional aggregates.

This code was based on the work of others found on the internet:

1. http://web.archive.org/web/20101115170804/http://www.voteruniverse.com/Members/jlantz/blog/conditional-aggregates-in-django
2. https://code.djangoproject.com/ticket/11305
3. https://groups.google.com/forum/?fromgroups=#!topic/django-users/cjzloTUwmS0
4. https://groups.google.com/forum/?fromgroups=#!topic/django-users/vVprMpsAnPo
'''
from __future__ import unicode_literals
from django.utils import six
import django
from django.db.models.aggregates import Aggregate as DjangoAggregate
from django.db.models.sql.aggregates import Aggregate as DjangoSqlAggregate

# (major, minor) of the installed Django; used below to select
# version-specific method names and call signatures.
VERSION = django.VERSION[:2]


class SqlAggregate(DjangoSqlAggregate):
    '''
    SQL-level aggregate that can be restricted by a condition.

    When a condition is set, the aggregated field is wrapped in
    ``CASE WHEN <condition> THEN <field> ELSE null END`` so that only rows
    matching the condition contribute to the aggregate.
    '''
    conditional_template = '%(function)s(CASE WHEN %(condition)s THEN %(field)s ELSE null END)'

    def __init__(self, col, source=None, is_summary=False, condition=None, **extra):
        '''
        ``condition`` is the object whose ``query.where`` is rendered into
        the CASE expression (``Aggregate.add_to_query`` passes a queryset),
        or ``None`` for an unconditional aggregate. The remaining arguments
        are forwarded unchanged to the Django base class.
        '''
        super(SqlAggregate, self).__init__(col, source, is_summary, **extra)
        self.condition = condition

    def relabel_aliases(self, change_map):
        '''
        Apply ``change_map`` (old alias -> new alias) to this aggregate and
        to the query backing its condition.
        '''
        # The base-class method is only called for Django < 1.7.
        if VERSION < (1, 7):
            super(SqlAggregate, self).relabel_aliases(change_map)

        if self.has_condition:
            # Restrict the map to aliases the condition query actually uses.
            known_aliases = self.condition.query.alias_map
            condition_change_map = dict(
                (old, new) for old, new in change_map.items()
                if old in known_aliases
            )
            self.condition.query.change_aliases(condition_change_map)

    def relabeled_clone(self, change_map):
        '''
        Relabel this aggregate (including its condition) and return the
        clone produced by the base class.
        '''
        self.relabel_aliases(change_map)
        return super(SqlAggregate, self).relabeled_clone(change_map)

    def as_sql(self, qn, connection):
        '''
        Return the SQL for this aggregate, switching to the conditional
        template when a condition is present.
        '''
        if self.has_condition:
            self.sql_template = self.conditional_template
            self.extra['condition'] = self._condition_as_sql(qn, connection)

        return super(SqlAggregate, self).as_sql(qn, connection)

    @property
    def has_condition(self):
        '''``True`` when a condition was supplied.'''
        # Compare against None explicitly.
        # Warning: bool(QuerySet) will hit the database
        return self.condition is not None

    def _condition_as_sql(self, qn, connection):
        '''
        Return the condition's WHERE clause as a single SQL string with its
        parameters interpolated inline.

        NOTE(security): parameters are escaped by hand and embedded in the
        SQL text instead of being passed to the database driver. Only
        booleans and text values are transformed; any other parameter type
        is interpolated as-is via ``%``-formatting.
        '''
        def escape(value):
            # Booleans are rendered as '0' / '1'; being text afterwards,
            # they also go through the quoting step below.
            if isinstance(value, bool):
                value = str(int(value))
            if isinstance(value, six.string_types):
                # Double '%' (e.g. from LIKE patterns) and single quotes.
                value = value.replace('%', '%%').replace("'", "''")
                # Text values are wrapped in single quotes.
                value = "'" + value + "'"
            return value

        sql, params = self.condition.query.where.as_sql(qn, connection)
        return sql % tuple(escape(param) for param in params)


class SqlSum(SqlAggregate):
    '''Conditional ``SUM``.'''
    sql_function = 'SUM'


class SqlCount(SqlAggregate):
    '''Conditional ``COUNT`` with optional ``DISTINCT``.'''
    is_ordinal = True
    sql_function = 'COUNT'
    sql_template = '%(function)s(%(distinct)s%(field)s)'
    conditional_template = '%(function)s(%(distinct)sCASE WHEN %(condition)s THEN %(field)s ELSE null END)'

    def __init__(self, col, distinct=False, **extra):
        '''
        ``distinct`` is turned into the ``'DISTINCT '`` prefix (or an empty
        string) consumed by the ``%(distinct)s`` template slot.
        '''
        distinct_sql = 'DISTINCT ' if distinct else ''
        super(SqlCount, self).__init__(col, distinct=distinct_sql, **extra)


class SqlAvg(SqlAggregate):
    '''Conditional ``AVG``.'''
    is_computed = True
    sql_function = 'AVG'


class SqlMax(SqlAggregate):
    '''Conditional ``MAX``.'''
    sql_function = 'MAX'


class SqlMin(SqlAggregate):
    '''Conditional ``MIN``.'''
    sql_function = 'MIN'


class Aggregate(DjangoAggregate):
    '''
    Base class for the public conditional aggregates.

    ``only`` is a ``Q``-like object (anything exposing ``children``) that
    limits which rows are aggregated. Subclasses set ``sql_klass`` to the
    matching ``SqlAggregate`` subclass.
    '''

    def __init__(self, lookup, only=None, **extra):
        '''
        ``lookup`` and ``extra`` are forwarded to the Django base class;
        ``only`` is stored and turned into a condition in ``add_to_query``.
        '''
        super(Aggregate, self).__init__(lookup, **extra)
        self.only = only
        self.condition = None

    def _get_fields_from_Q(self, q):
        '''
        Return the flat list of leaf children of ``q``, descending
        recursively into every child that itself has ``children``.
        '''
        fields = []
        for child in q.children:
            if hasattr(child, 'children'):
                fields.extend(self._get_fields_from_Q(child))
            else:
                fields.append(child)
        return fields

    def add_to_query(self, query, alias, col, source, is_summary):
        '''
        Register this aggregate on ``query`` under ``alias``.

        When ``only`` is set, a queryset filtered by it becomes the
        aggregate's condition, and the joins needed by every field it
        references are set up and promoted on ``query``.
        '''
        if self.only:
            self.condition = query.model._default_manager.filter(self.only)
            for child in self._get_fields_from_Q(self.only):
                field_list = child[0].split('__')
                # Pop off the last field if it's a query term ('gte',
                # 'contains', 'isnull', etc.). A lone component is always a
                # field name, even if it collides with a query term (e.g. a
                # field called 'year'); popping it would leave no names.
                if len(field_list) > 1 and field_list[-1] in query.query_terms:
                    field_list.pop()
                # setup_joins have different returns in Django 1.5 and 1.6,
                # but the order of what we need remains.
                result = query.setup_joins(field_list, query.model._meta, query.get_initial_alias(), None)
                join_list = result[3]

                # The promotion method's name changes at 1.5 and its
                # argument list at 1.7.
                fname = 'promote_alias_chain' if VERSION < (1, 5) else 'promote_joins'
                args = (join_list, True) if VERSION < (1, 7) else (join_list,)

                promote = getattr(query, fname)
                promote(*args)

        aggregate = self.sql_klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra)
        query.aggregates[alias] = aggregate


class Sum(Aggregate):
    '''Conditional sum aggregate.'''
    name = 'Sum'
    sql_klass = SqlSum


class Count(Aggregate):
    '''Conditional count aggregate.'''
    name = 'Count'
    sql_klass = SqlCount


class Avg(Aggregate):
    '''Conditional average aggregate.'''
    name = 'Avg'
    sql_klass = SqlAvg


class Max(Aggregate):
    '''Conditional maximum aggregate.'''
    name = 'Max'
    sql_klass = SqlMax


class Min(Aggregate):
    '''Conditional minimum aggregate.'''
    name = 'Min'
    sql_klass = SqlMin