expressions.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880
  1. import copy
  2. import datetime
  3. import functools
  4. import inspect
  5. import warnings
  6. from collections import defaultdict
  7. from decimal import Decimal
  8. from uuid import UUID
  9. from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
  10. from django.db import DatabaseError, NotSupportedError, connection
  11. from django.db.models import fields
  12. from django.db.models.constants import LOOKUP_SEP
  13. from django.db.models.query_utils import Q
  14. from django.utils.deconstruct import deconstructible
  15. from django.utils.deprecation import RemovedInDjango50Warning
  16. from django.utils.functional import cached_property
  17. from django.utils.hashable import make_hashable
  18. class SQLiteNumericMixin:
  19. """
  20. Some expressions with output_field=DecimalField() must be cast to
  21. numeric to be properly filtered.
  22. """
  23. def as_sqlite(self, compiler, connection, **extra_context):
  24. sql, params = self.as_sql(compiler, connection, **extra_context)
  25. try:
  26. if self.output_field.get_internal_type() == "DecimalField":
  27. sql = "CAST(%s AS NUMERIC)" % sql
  28. except FieldError:
  29. pass
  30. return sql, params
  31. class Combinable:
  32. """
  33. Provide the ability to combine one or two objects with
  34. some connector. For example F('foo') + F('bar').
  35. """
  36. # Arithmetic connectors
  37. ADD = "+"
  38. SUB = "-"
  39. MUL = "*"
  40. DIV = "/"
  41. POW = "^"
  42. # The following is a quoted % operator - it is quoted because it can be
  43. # used in strings that also have parameter substitution.
  44. MOD = "%%"
  45. # Bitwise operators - note that these are generated by .bitand()
  46. # and .bitor(), the '&' and '|' are reserved for boolean operator
  47. # usage.
  48. BITAND = "&"
  49. BITOR = "|"
  50. BITLEFTSHIFT = "<<"
  51. BITRIGHTSHIFT = ">>"
  52. BITXOR = "#"
  53. def _combine(self, other, connector, reversed):
  54. if not hasattr(other, "resolve_expression"):
  55. # everything must be resolvable to an expression
  56. other = Value(other)
  57. if reversed:
  58. return CombinedExpression(other, connector, self)
  59. return CombinedExpression(self, connector, other)
  60. #############
  61. # OPERATORS #
  62. #############
  63. def __neg__(self):
  64. return self._combine(-1, self.MUL, False)
  65. def __add__(self, other):
  66. return self._combine(other, self.ADD, False)
  67. def __sub__(self, other):
  68. return self._combine(other, self.SUB, False)
  69. def __mul__(self, other):
  70. return self._combine(other, self.MUL, False)
  71. def __truediv__(self, other):
  72. return self._combine(other, self.DIV, False)
  73. def __mod__(self, other):
  74. return self._combine(other, self.MOD, False)
  75. def __pow__(self, other):
  76. return self._combine(other, self.POW, False)
  77. def __and__(self, other):
  78. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  79. return Q(self) & Q(other)
  80. raise NotImplementedError(
  81. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  82. )
  83. def bitand(self, other):
  84. return self._combine(other, self.BITAND, False)
  85. def bitleftshift(self, other):
  86. return self._combine(other, self.BITLEFTSHIFT, False)
  87. def bitrightshift(self, other):
  88. return self._combine(other, self.BITRIGHTSHIFT, False)
  89. def __xor__(self, other):
  90. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  91. return Q(self) ^ Q(other)
  92. raise NotImplementedError(
  93. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  94. )
  95. def bitxor(self, other):
  96. return self._combine(other, self.BITXOR, False)
  97. def __or__(self, other):
  98. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  99. return Q(self) | Q(other)
  100. raise NotImplementedError(
  101. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  102. )
  103. def bitor(self, other):
  104. return self._combine(other, self.BITOR, False)
  105. def __radd__(self, other):
  106. return self._combine(other, self.ADD, True)
  107. def __rsub__(self, other):
  108. return self._combine(other, self.SUB, True)
  109. def __rmul__(self, other):
  110. return self._combine(other, self.MUL, True)
  111. def __rtruediv__(self, other):
  112. return self._combine(other, self.DIV, True)
  113. def __rmod__(self, other):
  114. return self._combine(other, self.MOD, True)
  115. def __rpow__(self, other):
  116. return self._combine(other, self.POW, True)
  117. def __rand__(self, other):
  118. raise NotImplementedError(
  119. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  120. )
  121. def __ror__(self, other):
  122. raise NotImplementedError(
  123. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  124. )
  125. def __rxor__(self, other):
  126. raise NotImplementedError(
  127. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  128. )
  129. def __invert__(self):
  130. return NegatedExpression(self)
  131. class BaseExpression:
  132. """Base class for all query expressions."""
  133. empty_result_set_value = NotImplemented
  134. # aggregate specific fields
  135. is_summary = False
  136. _output_field_resolved_to_none = False
  137. # Can the expression be used in a WHERE clause?
  138. filterable = True
  139. # Can the expression can be used as a source expression in Window?
  140. window_compatible = False
  141. def __init__(self, output_field=None):
  142. if output_field is not None:
  143. self.output_field = output_field
  144. def __getstate__(self):
  145. state = self.__dict__.copy()
  146. state.pop("convert_value", None)
  147. return state
  148. def get_db_converters(self, connection):
  149. return (
  150. []
  151. if self.convert_value is self._convert_value_noop
  152. else [self.convert_value]
  153. ) + self.output_field.get_db_converters(connection)
  154. def get_source_expressions(self):
  155. return []
  156. def set_source_expressions(self, exprs):
  157. assert not exprs
  158. def _parse_expressions(self, *expressions):
  159. return [
  160. arg
  161. if hasattr(arg, "resolve_expression")
  162. else (F(arg) if isinstance(arg, str) else Value(arg))
  163. for arg in expressions
  164. ]
  165. def as_sql(self, compiler, connection):
  166. """
  167. Responsible for returning a (sql, [params]) tuple to be included
  168. in the current query.
  169. Different backends can provide their own implementation, by
  170. providing an `as_{vendor}` method and patching the Expression:
  171. ```
  172. def override_as_sql(self, compiler, connection):
  173. # custom logic
  174. return super().as_sql(compiler, connection)
  175. setattr(Expression, 'as_' + connection.vendor, override_as_sql)
  176. ```
  177. Arguments:
  178. * compiler: the query compiler responsible for generating the query.
  179. Must have a compile method, returning a (sql, [params]) tuple.
  180. Calling compiler(value) will return a quoted `value`.
  181. * connection: the database connection used for the current query.
  182. Return: (sql, params)
  183. Where `sql` is a string containing ordered sql parameters to be
  184. replaced with the elements of the list `params`.
  185. """
  186. raise NotImplementedError("Subclasses must implement as_sql()")
  187. @cached_property
  188. def contains_aggregate(self):
  189. return any(
  190. expr and expr.contains_aggregate for expr in self.get_source_expressions()
  191. )
  192. @cached_property
  193. def contains_over_clause(self):
  194. return any(
  195. expr and expr.contains_over_clause for expr in self.get_source_expressions()
  196. )
  197. @cached_property
  198. def contains_column_references(self):
  199. return any(
  200. expr and expr.contains_column_references
  201. for expr in self.get_source_expressions()
  202. )
  203. @cached_property
  204. def contains_subquery(self):
  205. return any(
  206. expr and (getattr(expr, "subquery", False) or expr.contains_subquery)
  207. for expr in self.get_source_expressions()
  208. )
  209. def resolve_expression(
  210. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  211. ):
  212. """
  213. Provide the chance to do any preprocessing or validation before being
  214. added to the query.
  215. Arguments:
  216. * query: the backend query implementation
  217. * allow_joins: boolean allowing or denying use of joins
  218. in this query
  219. * reuse: a set of reusable joins for multijoins
  220. * summarize: a terminal aggregate clause
  221. * for_save: whether this expression about to be used in a save or update
  222. Return: an Expression to be added to the query.
  223. """
  224. c = self.copy()
  225. c.is_summary = summarize
  226. c.set_source_expressions(
  227. [
  228. expr.resolve_expression(query, allow_joins, reuse, summarize)
  229. if expr
  230. else None
  231. for expr in c.get_source_expressions()
  232. ]
  233. )
  234. return c
  235. @property
  236. def conditional(self):
  237. return isinstance(self.output_field, fields.BooleanField)
  238. @property
  239. def field(self):
  240. return self.output_field
  241. @cached_property
  242. def output_field(self):
  243. """Return the output type of this expressions."""
  244. output_field = self._resolve_output_field()
  245. if output_field is None:
  246. self._output_field_resolved_to_none = True
  247. raise FieldError("Cannot resolve expression type, unknown output_field")
  248. return output_field
  249. @cached_property
  250. def _output_field_or_none(self):
  251. """
  252. Return the output field of this expression, or None if
  253. _resolve_output_field() didn't return an output type.
  254. """
  255. try:
  256. return self.output_field
  257. except FieldError:
  258. if not self._output_field_resolved_to_none:
  259. raise
  260. def _resolve_output_field(self):
  261. """
  262. Attempt to infer the output type of the expression.
  263. As a guess, if the output fields of all source fields match then simply
  264. infer the same type here.
  265. If a source's output field resolves to None, exclude it from this check.
  266. If all sources are None, then an error is raised higher up the stack in
  267. the output_field property.
  268. """
  269. # This guess is mostly a bad idea, but there is quite a lot of code
  270. # (especially 3rd party Func subclasses) that depend on it, we'd need a
  271. # deprecation path to fix it.
  272. sources_iter = (
  273. source for source in self.get_source_fields() if source is not None
  274. )
  275. for output_field in sources_iter:
  276. for source in sources_iter:
  277. if not isinstance(output_field, source.__class__):
  278. raise FieldError(
  279. "Expression contains mixed types: %s, %s. You must "
  280. "set output_field."
  281. % (
  282. output_field.__class__.__name__,
  283. source.__class__.__name__,
  284. )
  285. )
  286. return output_field
  287. @staticmethod
  288. def _convert_value_noop(value, expression, connection):
  289. return value
  290. @cached_property
  291. def convert_value(self):
  292. """
  293. Expressions provide their own converters because users have the option
  294. of manually specifying the output_field which may be a different type
  295. from the one the database returns.
  296. """
  297. field = self.output_field
  298. internal_type = field.get_internal_type()
  299. if internal_type == "FloatField":
  300. return (
  301. lambda value, expression, connection: None
  302. if value is None
  303. else float(value)
  304. )
  305. elif internal_type.endswith("IntegerField"):
  306. return (
  307. lambda value, expression, connection: None
  308. if value is None
  309. else int(value)
  310. )
  311. elif internal_type == "DecimalField":
  312. return (
  313. lambda value, expression, connection: None
  314. if value is None
  315. else Decimal(value)
  316. )
  317. return self._convert_value_noop
  318. def get_lookup(self, lookup):
  319. return self.output_field.get_lookup(lookup)
  320. def get_transform(self, name):
  321. return self.output_field.get_transform(name)
  322. def relabeled_clone(self, change_map):
  323. clone = self.copy()
  324. clone.set_source_expressions(
  325. [
  326. e.relabeled_clone(change_map) if e is not None else None
  327. for e in self.get_source_expressions()
  328. ]
  329. )
  330. return clone
  331. def replace_expressions(self, replacements):
  332. if replacement := replacements.get(self):
  333. return replacement
  334. clone = self.copy()
  335. source_expressions = clone.get_source_expressions()
  336. clone.set_source_expressions(
  337. [
  338. expr.replace_expressions(replacements) if expr else None
  339. for expr in source_expressions
  340. ]
  341. )
  342. return clone
  343. def get_refs(self):
  344. refs = set()
  345. for expr in self.get_source_expressions():
  346. refs |= expr.get_refs()
  347. return refs
  348. def copy(self):
  349. return copy.copy(self)
  350. def prefix_references(self, prefix):
  351. clone = self.copy()
  352. clone.set_source_expressions(
  353. [
  354. F(f"{prefix}{expr.name}")
  355. if isinstance(expr, F)
  356. else expr.prefix_references(prefix)
  357. for expr in self.get_source_expressions()
  358. ]
  359. )
  360. return clone
  361. def get_group_by_cols(self):
  362. if not self.contains_aggregate:
  363. return [self]
  364. cols = []
  365. for source in self.get_source_expressions():
  366. cols.extend(source.get_group_by_cols())
  367. return cols
  368. def get_source_fields(self):
  369. """Return the underlying field types used by this aggregate."""
  370. return [e._output_field_or_none for e in self.get_source_expressions()]
  371. def asc(self, **kwargs):
  372. return OrderBy(self, **kwargs)
  373. def desc(self, **kwargs):
  374. return OrderBy(self, descending=True, **kwargs)
  375. def reverse_ordering(self):
  376. return self
  377. def flatten(self):
  378. """
  379. Recursively yield this expression and all subexpressions, in
  380. depth-first order.
  381. """
  382. yield self
  383. for expr in self.get_source_expressions():
  384. if expr:
  385. if hasattr(expr, "flatten"):
  386. yield from expr.flatten()
  387. else:
  388. yield expr
  389. def select_format(self, compiler, sql, params):
  390. """
  391. Custom format for select clauses. For example, EXISTS expressions need
  392. to be wrapped in CASE WHEN on Oracle.
  393. """
  394. if hasattr(self.output_field, "select_format"):
  395. return self.output_field.select_format(compiler, sql, params)
  396. return sql, params
  397. @deconstructible
  398. class Expression(BaseExpression, Combinable):
  399. """An expression that can be combined with other expressions."""
  400. @cached_property
  401. def identity(self):
  402. constructor_signature = inspect.signature(self.__init__)
  403. args, kwargs = self._constructor_args
  404. signature = constructor_signature.bind_partial(*args, **kwargs)
  405. signature.apply_defaults()
  406. arguments = signature.arguments.items()
  407. identity = [self.__class__]
  408. for arg, value in arguments:
  409. if isinstance(value, fields.Field):
  410. if value.name and value.model:
  411. value = (value.model._meta.label, value.name)
  412. else:
  413. value = type(value)
  414. else:
  415. value = make_hashable(value)
  416. identity.append((arg, value))
  417. return tuple(identity)
  418. def __eq__(self, other):
  419. if not isinstance(other, Expression):
  420. return NotImplemented
  421. return other.identity == self.identity
  422. def __hash__(self):
  423. return hash(self.identity)
  424. # Type inference for CombinedExpression.output_field.
  425. # Missing items will result in FieldError, by design.
  426. #
  427. # The current approach for NULL is based on lowest common denominator behavior
  428. # i.e. if one of the supported databases is raising an error (rather than
  429. # return NULL) for `val <op> NULL`, then Django raises FieldError.
  430. NoneType = type(None)
  431. _connector_combinations = [
  432. # Numeric operations - operands of same type.
  433. {
  434. connector: [
  435. (fields.IntegerField, fields.IntegerField, fields.IntegerField),
  436. (fields.FloatField, fields.FloatField, fields.FloatField),
  437. (fields.DecimalField, fields.DecimalField, fields.DecimalField),
  438. ]
  439. for connector in (
  440. Combinable.ADD,
  441. Combinable.SUB,
  442. Combinable.MUL,
  443. # Behavior for DIV with integer arguments follows Postgres/SQLite,
  444. # not MySQL/Oracle.
  445. Combinable.DIV,
  446. Combinable.MOD,
  447. Combinable.POW,
  448. )
  449. },
  450. # Numeric operations - operands of different type.
  451. {
  452. connector: [
  453. (fields.IntegerField, fields.DecimalField, fields.DecimalField),
  454. (fields.DecimalField, fields.IntegerField, fields.DecimalField),
  455. (fields.IntegerField, fields.FloatField, fields.FloatField),
  456. (fields.FloatField, fields.IntegerField, fields.FloatField),
  457. ]
  458. for connector in (
  459. Combinable.ADD,
  460. Combinable.SUB,
  461. Combinable.MUL,
  462. Combinable.DIV,
  463. Combinable.MOD,
  464. )
  465. },
  466. # Bitwise operators.
  467. {
  468. connector: [
  469. (fields.IntegerField, fields.IntegerField, fields.IntegerField),
  470. ]
  471. for connector in (
  472. Combinable.BITAND,
  473. Combinable.BITOR,
  474. Combinable.BITLEFTSHIFT,
  475. Combinable.BITRIGHTSHIFT,
  476. Combinable.BITXOR,
  477. )
  478. },
  479. # Numeric with NULL.
  480. {
  481. connector: [
  482. (field_type, NoneType, field_type),
  483. (NoneType, field_type, field_type),
  484. ]
  485. for connector in (
  486. Combinable.ADD,
  487. Combinable.SUB,
  488. Combinable.MUL,
  489. Combinable.DIV,
  490. Combinable.MOD,
  491. Combinable.POW,
  492. )
  493. for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
  494. },
  495. # Date/DateTimeField/DurationField/TimeField.
  496. {
  497. Combinable.ADD: [
  498. # Date/DateTimeField.
  499. (fields.DateField, fields.DurationField, fields.DateTimeField),
  500. (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
  501. (fields.DurationField, fields.DateField, fields.DateTimeField),
  502. (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
  503. # DurationField.
  504. (fields.DurationField, fields.DurationField, fields.DurationField),
  505. # TimeField.
  506. (fields.TimeField, fields.DurationField, fields.TimeField),
  507. (fields.DurationField, fields.TimeField, fields.TimeField),
  508. ],
  509. },
  510. {
  511. Combinable.SUB: [
  512. # Date/DateTimeField.
  513. (fields.DateField, fields.DurationField, fields.DateTimeField),
  514. (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
  515. (fields.DateField, fields.DateField, fields.DurationField),
  516. (fields.DateField, fields.DateTimeField, fields.DurationField),
  517. (fields.DateTimeField, fields.DateField, fields.DurationField),
  518. (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
  519. # DurationField.
  520. (fields.DurationField, fields.DurationField, fields.DurationField),
  521. # TimeField.
  522. (fields.TimeField, fields.DurationField, fields.TimeField),
  523. (fields.TimeField, fields.TimeField, fields.DurationField),
  524. ],
  525. },
  526. ]
  527. _connector_combinators = defaultdict(list)
  528. def register_combinable_fields(lhs, connector, rhs, result):
  529. """
  530. Register combinable types:
  531. lhs <connector> rhs -> result
  532. e.g.
  533. register_combinable_fields(
  534. IntegerField, Combinable.ADD, FloatField, FloatField
  535. )
  536. """
  537. _connector_combinators[connector].append((lhs, rhs, result))
  538. for d in _connector_combinations:
  539. for connector, field_types in d.items():
  540. for lhs, rhs, result in field_types:
  541. register_combinable_fields(lhs, connector, rhs, result)
  542. @functools.lru_cache(maxsize=128)
  543. def _resolve_combined_type(connector, lhs_type, rhs_type):
  544. combinators = _connector_combinators.get(connector, ())
  545. for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
  546. if issubclass(lhs_type, combinator_lhs_type) and issubclass(
  547. rhs_type, combinator_rhs_type
  548. ):
  549. return combined_type
  550. class CombinedExpression(SQLiteNumericMixin, Expression):
  551. def __init__(self, lhs, connector, rhs, output_field=None):
  552. super().__init__(output_field=output_field)
  553. self.connector = connector
  554. self.lhs = lhs
  555. self.rhs = rhs
  556. def __repr__(self):
  557. return "<{}: {}>".format(self.__class__.__name__, self)
  558. def __str__(self):
  559. return "{} {} {}".format(self.lhs, self.connector, self.rhs)
  560. def get_source_expressions(self):
  561. return [self.lhs, self.rhs]
  562. def set_source_expressions(self, exprs):
  563. self.lhs, self.rhs = exprs
  564. def _resolve_output_field(self):
  565. # We avoid using super() here for reasons given in
  566. # Expression._resolve_output_field()
  567. combined_type = _resolve_combined_type(
  568. self.connector,
  569. type(self.lhs._output_field_or_none),
  570. type(self.rhs._output_field_or_none),
  571. )
  572. if combined_type is None:
  573. raise FieldError(
  574. f"Cannot infer type of {self.connector!r} expression involving these "
  575. f"types: {self.lhs.output_field.__class__.__name__}, "
  576. f"{self.rhs.output_field.__class__.__name__}. You must set "
  577. f"output_field."
  578. )
  579. return combined_type()
  580. def as_sql(self, compiler, connection):
  581. expressions = []
  582. expression_params = []
  583. sql, params = compiler.compile(self.lhs)
  584. expressions.append(sql)
  585. expression_params.extend(params)
  586. sql, params = compiler.compile(self.rhs)
  587. expressions.append(sql)
  588. expression_params.extend(params)
  589. # order of precedence
  590. expression_wrapper = "(%s)"
  591. sql = connection.ops.combine_expression(self.connector, expressions)
  592. return expression_wrapper % sql, expression_params
  593. def resolve_expression(
  594. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  595. ):
  596. lhs = self.lhs.resolve_expression(
  597. query, allow_joins, reuse, summarize, for_save
  598. )
  599. rhs = self.rhs.resolve_expression(
  600. query, allow_joins, reuse, summarize, for_save
  601. )
  602. if not isinstance(self, (DurationExpression, TemporalSubtraction)):
  603. try:
  604. lhs_type = lhs.output_field.get_internal_type()
  605. except (AttributeError, FieldError):
  606. lhs_type = None
  607. try:
  608. rhs_type = rhs.output_field.get_internal_type()
  609. except (AttributeError, FieldError):
  610. rhs_type = None
  611. if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
  612. return DurationExpression(
  613. self.lhs, self.connector, self.rhs
  614. ).resolve_expression(
  615. query,
  616. allow_joins,
  617. reuse,
  618. summarize,
  619. for_save,
  620. )
  621. datetime_fields = {"DateField", "DateTimeField", "TimeField"}
  622. if (
  623. self.connector == self.SUB
  624. and lhs_type in datetime_fields
  625. and lhs_type == rhs_type
  626. ):
  627. return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
  628. query,
  629. allow_joins,
  630. reuse,
  631. summarize,
  632. for_save,
  633. )
  634. c = self.copy()
  635. c.is_summary = summarize
  636. c.lhs = lhs
  637. c.rhs = rhs
  638. return c
  639. class DurationExpression(CombinedExpression):
  640. def compile(self, side, compiler, connection):
  641. try:
  642. output = side.output_field
  643. except FieldError:
  644. pass
  645. else:
  646. if output.get_internal_type() == "DurationField":
  647. sql, params = compiler.compile(side)
  648. return connection.ops.format_for_duration_arithmetic(sql), params
  649. return compiler.compile(side)
  650. def as_sql(self, compiler, connection):
  651. if connection.features.has_native_duration_field:
  652. return super().as_sql(compiler, connection)
  653. connection.ops.check_expression_support(self)
  654. expressions = []
  655. expression_params = []
  656. sql, params = self.compile(self.lhs, compiler, connection)
  657. expressions.append(sql)
  658. expression_params.extend(params)
  659. sql, params = self.compile(self.rhs, compiler, connection)
  660. expressions.append(sql)
  661. expression_params.extend(params)
  662. # order of precedence
  663. expression_wrapper = "(%s)"
  664. sql = connection.ops.combine_duration_expression(self.connector, expressions)
  665. return expression_wrapper % sql, expression_params
  666. def as_sqlite(self, compiler, connection, **extra_context):
  667. sql, params = self.as_sql(compiler, connection, **extra_context)
  668. if self.connector in {Combinable.MUL, Combinable.DIV}:
  669. try:
  670. lhs_type = self.lhs.output_field.get_internal_type()
  671. rhs_type = self.rhs.output_field.get_internal_type()
  672. except (AttributeError, FieldError):
  673. pass
  674. else:
  675. allowed_fields = {
  676. "DecimalField",
  677. "DurationField",
  678. "FloatField",
  679. "IntegerField",
  680. }
  681. if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
  682. raise DatabaseError(
  683. f"Invalid arguments for operator {self.connector}."
  684. )
  685. return sql, params
  686. class TemporalSubtraction(CombinedExpression):
  687. output_field = fields.DurationField()
  688. def __init__(self, lhs, rhs):
  689. super().__init__(lhs, self.SUB, rhs)
  690. def as_sql(self, compiler, connection):
  691. connection.ops.check_expression_support(self)
  692. lhs = compiler.compile(self.lhs)
  693. rhs = compiler.compile(self.rhs)
  694. return connection.ops.subtract_temporals(
  695. self.lhs.output_field.get_internal_type(), lhs, rhs
  696. )
  697. @deconstructible(path="django.db.models.F")
  698. class F(Combinable):
  699. """An object capable of resolving references to existing query objects."""
  700. def __init__(self, name):
  701. """
  702. Arguments:
  703. * name: the name of the field this expression references
  704. """
  705. self.name = name
  706. def __repr__(self):
  707. return "{}({})".format(self.__class__.__name__, self.name)
  708. def resolve_expression(
  709. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  710. ):
  711. return query.resolve_ref(self.name, allow_joins, reuse, summarize)
  712. def replace_expressions(self, replacements):
  713. return replacements.get(self, self)
  714. def asc(self, **kwargs):
  715. return OrderBy(self, **kwargs)
  716. def desc(self, **kwargs):
  717. return OrderBy(self, descending=True, **kwargs)
  718. def __eq__(self, other):
  719. return self.__class__ == other.__class__ and self.name == other.name
  720. def __hash__(self):
  721. return hash(self.name)
  722. def copy(self):
  723. return copy.copy(self)
  724. class ResolvedOuterRef(F):
  725. """
  726. An object that contains a reference to an outer query.
  727. In this case, the reference to the outer query has been resolved because
  728. the inner query has been used as a subquery.
  729. """
  730. contains_aggregate = False
  731. contains_over_clause = False
  732. def as_sql(self, *args, **kwargs):
  733. raise ValueError(
  734. "This queryset contains a reference to an outer query and may "
  735. "only be used in a subquery."
  736. )
  737. def resolve_expression(self, *args, **kwargs):
  738. col = super().resolve_expression(*args, **kwargs)
  739. if col.contains_over_clause:
  740. raise NotSupportedError(
  741. f"Referencing outer query window expression is not supported: "
  742. f"{self.name}."
  743. )
  744. # FIXME: Rename possibly_multivalued to multivalued and fix detection
  745. # for non-multivalued JOINs (e.g. foreign key fields). This should take
  746. # into account only many-to-many and one-to-many relationships.
  747. col.possibly_multivalued = LOOKUP_SEP in self.name
  748. return col
  749. def relabeled_clone(self, relabels):
  750. return self
  751. def get_group_by_cols(self):
  752. return []
  753. class OuterRef(F):
  754. contains_aggregate = False
  755. contains_over_clause = False
  756. def resolve_expression(self, *args, **kwargs):
  757. if isinstance(self.name, self.__class__):
  758. return self.name
  759. return ResolvedOuterRef(self.name)
  760. def relabeled_clone(self, relabels):
  761. return self
  762. @deconstructible(path="django.db.models.Func")
  763. class Func(SQLiteNumericMixin, Expression):
  764. """An SQL function call."""
  765. function = None
  766. template = "%(function)s(%(expressions)s)"
  767. arg_joiner = ", "
  768. arity = None # The number of arguments the function accepts.
  769. def __init__(self, *expressions, output_field=None, **extra):
  770. if self.arity is not None and len(expressions) != self.arity:
  771. raise TypeError(
  772. "'%s' takes exactly %s %s (%s given)"
  773. % (
  774. self.__class__.__name__,
  775. self.arity,
  776. "argument" if self.arity == 1 else "arguments",
  777. len(expressions),
  778. )
  779. )
  780. super().__init__(output_field=output_field)
  781. self.source_expressions = self._parse_expressions(*expressions)
  782. self.extra = extra
  783. def __repr__(self):
  784. args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
  785. extra = {**self.extra, **self._get_repr_options()}
  786. if extra:
  787. extra = ", ".join(
  788. str(key) + "=" + str(val) for key, val in sorted(extra.items())
  789. )
  790. return "{}({}, {})".format(self.__class__.__name__, args, extra)
  791. return "{}({})".format(self.__class__.__name__, args)
  792. def _get_repr_options(self):
  793. """Return a dict of extra __init__() options to include in the repr."""
  794. return {}
  795. def get_source_expressions(self):
  796. return self.source_expressions
  797. def set_source_expressions(self, exprs):
  798. self.source_expressions = exprs
  799. def resolve_expression(
  800. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  801. ):
  802. c = self.copy()
  803. c.is_summary = summarize
  804. for pos, arg in enumerate(c.source_expressions):
  805. c.source_expressions[pos] = arg.resolve_expression(
  806. query, allow_joins, reuse, summarize, for_save
  807. )
  808. return c
  809. def as_sql(
  810. self,
  811. compiler,
  812. connection,
  813. function=None,
  814. template=None,
  815. arg_joiner=None,
  816. **extra_context,
  817. ):
  818. connection.ops.check_expression_support(self)
  819. sql_parts = []
  820. params = []
  821. for arg in self.source_expressions:
  822. try:
  823. arg_sql, arg_params = compiler.compile(arg)
  824. except EmptyResultSet:
  825. empty_result_set_value = getattr(
  826. arg, "empty_result_set_value", NotImplemented
  827. )
  828. if empty_result_set_value is NotImplemented:
  829. raise
  830. arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
  831. except FullResultSet:
  832. arg_sql, arg_params = compiler.compile(Value(True))
  833. sql_parts.append(arg_sql)
  834. params.extend(arg_params)
  835. data = {**self.extra, **extra_context}
  836. # Use the first supplied value in this order: the parameter to this
  837. # method, a value supplied in __init__()'s **extra (the value in
  838. # `data`), or the value defined on the class.
  839. if function is not None:
  840. data["function"] = function
  841. else:
  842. data.setdefault("function", self.function)
  843. template = template or data.get("template", self.template)
  844. arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
  845. data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
  846. return template % data, params
  847. def copy(self):
  848. copy = super().copy()
  849. copy.source_expressions = self.source_expressions[:]
  850. copy.extra = self.extra.copy()
  851. return copy
  852. @deconstructible(path="django.db.models.Value")
  853. class Value(SQLiteNumericMixin, Expression):
  854. """Represent a wrapped value as a node within an expression."""
  855. # Provide a default value for `for_save` in order to allow unresolved
  856. # instances to be compiled until a decision is taken in #25425.
  857. for_save = False
  858. def __init__(self, value, output_field=None):
  859. """
  860. Arguments:
  861. * value: the value this expression represents. The value will be
  862. added into the sql parameter list and properly quoted.
  863. * output_field: an instance of the model field type that this
  864. expression will return, such as IntegerField() or CharField().
  865. """
  866. super().__init__(output_field=output_field)
  867. self.value = value
  868. def __repr__(self):
  869. return f"{self.__class__.__name__}({self.value!r})"
  870. def as_sql(self, compiler, connection):
  871. connection.ops.check_expression_support(self)
  872. val = self.value
  873. output_field = self._output_field_or_none
  874. if output_field is not None:
  875. if self.for_save:
  876. val = output_field.get_db_prep_save(val, connection=connection)
  877. else:
  878. val = output_field.get_db_prep_value(val, connection=connection)
  879. if hasattr(output_field, "get_placeholder"):
  880. return output_field.get_placeholder(val, compiler, connection), [val]
  881. if val is None:
  882. # cx_Oracle does not always convert None to the appropriate
  883. # NULL type (like in case expressions using numbers), so we
  884. # use a literal SQL NULL
  885. return "NULL", []
  886. return "%s", [val]
  887. def resolve_expression(
  888. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  889. ):
  890. c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  891. c.for_save = for_save
  892. return c
  893. def get_group_by_cols(self):
  894. return []
  895. def _resolve_output_field(self):
  896. if isinstance(self.value, str):
  897. return fields.CharField()
  898. if isinstance(self.value, bool):
  899. return fields.BooleanField()
  900. if isinstance(self.value, int):
  901. return fields.IntegerField()
  902. if isinstance(self.value, float):
  903. return fields.FloatField()
  904. if isinstance(self.value, datetime.datetime):
  905. return fields.DateTimeField()
  906. if isinstance(self.value, datetime.date):
  907. return fields.DateField()
  908. if isinstance(self.value, datetime.time):
  909. return fields.TimeField()
  910. if isinstance(self.value, datetime.timedelta):
  911. return fields.DurationField()
  912. if isinstance(self.value, Decimal):
  913. return fields.DecimalField()
  914. if isinstance(self.value, bytes):
  915. return fields.BinaryField()
  916. if isinstance(self.value, UUID):
  917. return fields.UUIDField()
  918. @property
  919. def empty_result_set_value(self):
  920. return self.value
  921. class RawSQL(Expression):
  922. def __init__(self, sql, params, output_field=None):
  923. if output_field is None:
  924. output_field = fields.Field()
  925. self.sql, self.params = sql, params
  926. super().__init__(output_field=output_field)
  927. def __repr__(self):
  928. return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
  929. def as_sql(self, compiler, connection):
  930. return "(%s)" % self.sql, self.params
  931. def get_group_by_cols(self):
  932. return [self]
  933. def resolve_expression(
  934. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  935. ):
  936. # Resolve parents fields used in raw SQL.
  937. if query.model:
  938. for parent in query.model._meta.get_parent_list():
  939. for parent_field in parent._meta.local_fields:
  940. _, column_name = parent_field.get_attname_column()
  941. if column_name.lower() in self.sql.lower():
  942. query.resolve_ref(
  943. parent_field.name, allow_joins, reuse, summarize
  944. )
  945. break
  946. return super().resolve_expression(
  947. query, allow_joins, reuse, summarize, for_save
  948. )
  949. class Star(Expression):
  950. def __repr__(self):
  951. return "'*'"
  952. def as_sql(self, compiler, connection):
  953. return "*", []
  954. class Col(Expression):
  955. contains_column_references = True
  956. possibly_multivalued = False
  957. def __init__(self, alias, target, output_field=None):
  958. if output_field is None:
  959. output_field = target
  960. super().__init__(output_field=output_field)
  961. self.alias, self.target = alias, target
  962. def __repr__(self):
  963. alias, target = self.alias, self.target
  964. identifiers = (alias, str(target)) if alias else (str(target),)
  965. return "{}({})".format(self.__class__.__name__, ", ".join(identifiers))
  966. def as_sql(self, compiler, connection):
  967. alias, column = self.alias, self.target.column
  968. identifiers = (alias, column) if alias else (column,)
  969. sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
  970. return sql, []
  971. def relabeled_clone(self, relabels):
  972. if self.alias is None:
  973. return self
  974. return self.__class__(
  975. relabels.get(self.alias, self.alias), self.target, self.output_field
  976. )
  977. def get_group_by_cols(self):
  978. return [self]
  979. def get_db_converters(self, connection):
  980. if self.target == self.output_field:
  981. return self.output_field.get_db_converters(connection)
  982. return self.output_field.get_db_converters(
  983. connection
  984. ) + self.target.get_db_converters(connection)
  985. class Ref(Expression):
  986. """
  987. Reference to column alias of the query. For example, Ref('sum_cost') in
  988. qs.annotate(sum_cost=Sum('cost')) query.
  989. """
  990. def __init__(self, refs, source):
  991. super().__init__()
  992. self.refs, self.source = refs, source
  993. def __repr__(self):
  994. return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
  995. def get_source_expressions(self):
  996. return [self.source]
  997. def set_source_expressions(self, exprs):
  998. (self.source,) = exprs
  999. def resolve_expression(
  1000. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1001. ):
  1002. # The sub-expression `source` has already been resolved, as this is
  1003. # just a reference to the name of `source`.
  1004. return self
  1005. def get_refs(self):
  1006. return {self.refs}
  1007. def relabeled_clone(self, relabels):
  1008. clone = self.copy()
  1009. clone.source = self.source.relabeled_clone(relabels)
  1010. return clone
  1011. def as_sql(self, compiler, connection):
  1012. return connection.ops.quote_name(self.refs), []
  1013. def get_group_by_cols(self):
  1014. return [self]
  1015. class ExpressionList(Func):
  1016. """
  1017. An expression containing multiple expressions. Can be used to provide a
  1018. list of expressions as an argument to another expression, like a partition
  1019. clause.
  1020. """
  1021. template = "%(expressions)s"
  1022. def __init__(self, *expressions, **extra):
  1023. if not expressions:
  1024. raise ValueError(
  1025. "%s requires at least one expression." % self.__class__.__name__
  1026. )
  1027. super().__init__(*expressions, **extra)
  1028. def __str__(self):
  1029. return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
  1030. def as_sqlite(self, compiler, connection, **extra_context):
  1031. # Casting to numeric is unnecessary.
  1032. return self.as_sql(compiler, connection, **extra_context)
  1033. def get_group_by_cols(self):
  1034. group_by_cols = []
  1035. for partition in self.get_source_expressions():
  1036. group_by_cols.extend(partition.get_group_by_cols())
  1037. return group_by_cols
  1038. class OrderByList(Func):
  1039. template = "ORDER BY %(expressions)s"
  1040. def __init__(self, *expressions, **extra):
  1041. expressions = (
  1042. (
  1043. OrderBy(F(expr[1:]), descending=True)
  1044. if isinstance(expr, str) and expr[0] == "-"
  1045. else expr
  1046. )
  1047. for expr in expressions
  1048. )
  1049. super().__init__(*expressions, **extra)
  1050. def as_sql(self, *args, **kwargs):
  1051. if not self.source_expressions:
  1052. return "", ()
  1053. return super().as_sql(*args, **kwargs)
  1054. def get_group_by_cols(self):
  1055. group_by_cols = []
  1056. for order_by in self.get_source_expressions():
  1057. group_by_cols.extend(order_by.get_group_by_cols())
  1058. return group_by_cols
  1059. @deconstructible(path="django.db.models.ExpressionWrapper")
  1060. class ExpressionWrapper(SQLiteNumericMixin, Expression):
  1061. """
  1062. An expression that can wrap another expression so that it can provide
  1063. extra context to the inner expression, such as the output_field.
  1064. """
  1065. def __init__(self, expression, output_field):
  1066. super().__init__(output_field=output_field)
  1067. self.expression = expression
  1068. def set_source_expressions(self, exprs):
  1069. self.expression = exprs[0]
  1070. def get_source_expressions(self):
  1071. return [self.expression]
  1072. def get_group_by_cols(self):
  1073. if isinstance(self.expression, Expression):
  1074. expression = self.expression.copy()
  1075. expression.output_field = self.output_field
  1076. return expression.get_group_by_cols()
  1077. # For non-expressions e.g. an SQL WHERE clause, the entire
  1078. # `expression` must be included in the GROUP BY clause.
  1079. return super().get_group_by_cols()
  1080. def as_sql(self, compiler, connection):
  1081. return compiler.compile(self.expression)
  1082. def __repr__(self):
  1083. return "{}({})".format(self.__class__.__name__, self.expression)
  1084. class NegatedExpression(ExpressionWrapper):
  1085. """The logical negation of a conditional expression."""
  1086. def __init__(self, expression):
  1087. super().__init__(expression, output_field=fields.BooleanField())
  1088. def __invert__(self):
  1089. return self.expression.copy()
  1090. def as_sql(self, compiler, connection):
  1091. try:
  1092. sql, params = super().as_sql(compiler, connection)
  1093. except EmptyResultSet:
  1094. features = compiler.connection.features
  1095. if not features.supports_boolean_expr_in_select_clause:
  1096. return "1=1", ()
  1097. return compiler.compile(Value(True))
  1098. ops = compiler.connection.ops
  1099. # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
  1100. # to be compared to another expression unless they're wrapped in a CASE
  1101. # WHEN.
  1102. if not ops.conditional_expression_supported_in_where_clause(self.expression):
  1103. return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
  1104. return f"NOT {sql}", params
  1105. def resolve_expression(
  1106. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1107. ):
  1108. resolved = super().resolve_expression(
  1109. query, allow_joins, reuse, summarize, for_save
  1110. )
  1111. if not getattr(resolved.expression, "conditional", False):
  1112. raise TypeError("Cannot negate non-conditional expressions.")
  1113. return resolved
  1114. def select_format(self, compiler, sql, params):
  1115. # Wrap boolean expressions with a CASE WHEN expression if a database
  1116. # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
  1117. # GROUP BY list.
  1118. expression_supported_in_where_clause = (
  1119. compiler.connection.ops.conditional_expression_supported_in_where_clause
  1120. )
  1121. if (
  1122. not compiler.connection.features.supports_boolean_expr_in_select_clause
  1123. # Avoid double wrapping.
  1124. and expression_supported_in_where_clause(self.expression)
  1125. ):
  1126. sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
  1127. return sql, params
  1128. @deconstructible(path="django.db.models.When")
  1129. class When(Expression):
  1130. template = "WHEN %(condition)s THEN %(result)s"
  1131. # This isn't a complete conditional expression, must be used in Case().
  1132. conditional = False
  1133. def __init__(self, condition=None, then=None, **lookups):
  1134. if lookups:
  1135. if condition is None:
  1136. condition, lookups = Q(**lookups), None
  1137. elif getattr(condition, "conditional", False):
  1138. condition, lookups = Q(condition, **lookups), None
  1139. if condition is None or not getattr(condition, "conditional", False) or lookups:
  1140. raise TypeError(
  1141. "When() supports a Q object, a boolean expression, or lookups "
  1142. "as a condition."
  1143. )
  1144. if isinstance(condition, Q) and not condition:
  1145. raise ValueError("An empty Q() can't be used as a When() condition.")
  1146. super().__init__(output_field=None)
  1147. self.condition = condition
  1148. self.result = self._parse_expressions(then)[0]
  1149. def __str__(self):
  1150. return "WHEN %r THEN %r" % (self.condition, self.result)
  1151. def __repr__(self):
  1152. return "<%s: %s>" % (self.__class__.__name__, self)
  1153. def get_source_expressions(self):
  1154. return [self.condition, self.result]
  1155. def set_source_expressions(self, exprs):
  1156. self.condition, self.result = exprs
  1157. def get_source_fields(self):
  1158. # We're only interested in the fields of the result expressions.
  1159. return [self.result._output_field_or_none]
  1160. def resolve_expression(
  1161. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1162. ):
  1163. c = self.copy()
  1164. c.is_summary = summarize
  1165. if hasattr(c.condition, "resolve_expression"):
  1166. c.condition = c.condition.resolve_expression(
  1167. query, allow_joins, reuse, summarize, False
  1168. )
  1169. c.result = c.result.resolve_expression(
  1170. query, allow_joins, reuse, summarize, for_save
  1171. )
  1172. return c
  1173. def as_sql(self, compiler, connection, template=None, **extra_context):
  1174. connection.ops.check_expression_support(self)
  1175. template_params = extra_context
  1176. sql_params = []
  1177. condition_sql, condition_params = compiler.compile(self.condition)
  1178. template_params["condition"] = condition_sql
  1179. result_sql, result_params = compiler.compile(self.result)
  1180. template_params["result"] = result_sql
  1181. template = template or self.template
  1182. return template % template_params, (
  1183. *sql_params,
  1184. *condition_params,
  1185. *result_params,
  1186. )
  1187. def get_group_by_cols(self):
  1188. # This is not a complete expression and cannot be used in GROUP BY.
  1189. cols = []
  1190. for source in self.get_source_expressions():
  1191. cols.extend(source.get_group_by_cols())
  1192. return cols
  1193. @deconstructible(path="django.db.models.Case")
  1194. class Case(SQLiteNumericMixin, Expression):
  1195. """
  1196. An SQL searched CASE expression:
  1197. CASE
  1198. WHEN n > 0
  1199. THEN 'positive'
  1200. WHEN n < 0
  1201. THEN 'negative'
  1202. ELSE 'zero'
  1203. END
  1204. """
  1205. template = "CASE %(cases)s ELSE %(default)s END"
  1206. case_joiner = " "
  1207. def __init__(self, *cases, default=None, output_field=None, **extra):
  1208. if not all(isinstance(case, When) for case in cases):
  1209. raise TypeError("Positional arguments must all be When objects.")
  1210. super().__init__(output_field)
  1211. self.cases = list(cases)
  1212. self.default = self._parse_expressions(default)[0]
  1213. self.extra = extra
  1214. def __str__(self):
  1215. return "CASE %s, ELSE %r" % (
  1216. ", ".join(str(c) for c in self.cases),
  1217. self.default,
  1218. )
  1219. def __repr__(self):
  1220. return "<%s: %s>" % (self.__class__.__name__, self)
  1221. def get_source_expressions(self):
  1222. return self.cases + [self.default]
  1223. def set_source_expressions(self, exprs):
  1224. *self.cases, self.default = exprs
  1225. def resolve_expression(
  1226. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1227. ):
  1228. c = self.copy()
  1229. c.is_summary = summarize
  1230. for pos, case in enumerate(c.cases):
  1231. c.cases[pos] = case.resolve_expression(
  1232. query, allow_joins, reuse, summarize, for_save
  1233. )
  1234. c.default = c.default.resolve_expression(
  1235. query, allow_joins, reuse, summarize, for_save
  1236. )
  1237. return c
  1238. def copy(self):
  1239. c = super().copy()
  1240. c.cases = c.cases[:]
  1241. return c
  1242. def as_sql(
  1243. self, compiler, connection, template=None, case_joiner=None, **extra_context
  1244. ):
  1245. connection.ops.check_expression_support(self)
  1246. if not self.cases:
  1247. return compiler.compile(self.default)
  1248. template_params = {**self.extra, **extra_context}
  1249. case_parts = []
  1250. sql_params = []
  1251. default_sql, default_params = compiler.compile(self.default)
  1252. for case in self.cases:
  1253. try:
  1254. case_sql, case_params = compiler.compile(case)
  1255. except EmptyResultSet:
  1256. continue
  1257. except FullResultSet:
  1258. default_sql, default_params = compiler.compile(case.result)
  1259. break
  1260. case_parts.append(case_sql)
  1261. sql_params.extend(case_params)
  1262. if not case_parts:
  1263. return default_sql, default_params
  1264. case_joiner = case_joiner or self.case_joiner
  1265. template_params["cases"] = case_joiner.join(case_parts)
  1266. template_params["default"] = default_sql
  1267. sql_params.extend(default_params)
  1268. template = template or template_params.get("template", self.template)
  1269. sql = template % template_params
  1270. if self._output_field_or_none is not None:
  1271. sql = connection.ops.unification_cast_sql(self.output_field) % sql
  1272. return sql, sql_params
  1273. def get_group_by_cols(self):
  1274. if not self.cases:
  1275. return self.default.get_group_by_cols()
  1276. return super().get_group_by_cols()
  1277. class Subquery(BaseExpression, Combinable):
  1278. """
  1279. An explicit subquery. It may contain OuterRef() references to the outer
  1280. query which will be resolved when it is applied to that query.
  1281. """
  1282. template = "(%(subquery)s)"
  1283. contains_aggregate = False
  1284. empty_result_set_value = None
  1285. subquery = True
  1286. def __init__(self, queryset, output_field=None, **extra):
  1287. # Allow the usage of both QuerySet and sql.Query objects.
  1288. self.query = getattr(queryset, "query", queryset).clone()
  1289. self.query.subquery = True
  1290. self.extra = extra
  1291. super().__init__(output_field)
  1292. def get_source_expressions(self):
  1293. return [self.query]
  1294. def set_source_expressions(self, exprs):
  1295. self.query = exprs[0]
  1296. def _resolve_output_field(self):
  1297. return self.query.output_field
  1298. def copy(self):
  1299. clone = super().copy()
  1300. clone.query = clone.query.clone()
  1301. return clone
  1302. @property
  1303. def external_aliases(self):
  1304. return self.query.external_aliases
  1305. def get_external_cols(self):
  1306. return self.query.get_external_cols()
  1307. def as_sql(self, compiler, connection, template=None, **extra_context):
  1308. connection.ops.check_expression_support(self)
  1309. template_params = {**self.extra, **extra_context}
  1310. subquery_sql, sql_params = self.query.as_sql(compiler, connection)
  1311. template_params["subquery"] = subquery_sql[1:-1]
  1312. template = template or template_params.get("template", self.template)
  1313. sql = template % template_params
  1314. return sql, sql_params
  1315. def get_group_by_cols(self):
  1316. return self.query.get_group_by_cols(wrapper=self)
  1317. class Exists(Subquery):
  1318. template = "EXISTS(%(subquery)s)"
  1319. output_field = fields.BooleanField()
  1320. empty_result_set_value = False
  1321. def __init__(self, queryset, **kwargs):
  1322. super().__init__(queryset, **kwargs)
  1323. self.query = self.query.exists()
  1324. def select_format(self, compiler, sql, params):
  1325. # Wrap EXISTS() with a CASE WHEN expression if a database backend
  1326. # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
  1327. # BY list.
  1328. if not compiler.connection.features.supports_boolean_expr_in_select_clause:
  1329. sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
  1330. return sql, params
  1331. @deconstructible(path="django.db.models.OrderBy")
  1332. class OrderBy(Expression):
  1333. template = "%(expression)s %(ordering)s"
  1334. conditional = False
  1335. def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
  1336. if nulls_first and nulls_last:
  1337. raise ValueError("nulls_first and nulls_last are mutually exclusive")
  1338. if nulls_first is False or nulls_last is False:
  1339. # When the deprecation ends, replace with:
  1340. # raise ValueError(
  1341. # "nulls_first and nulls_last values must be True or None."
  1342. # )
  1343. warnings.warn(
  1344. "Passing nulls_first=False or nulls_last=False is deprecated, use None "
  1345. "instead.",
  1346. RemovedInDjango50Warning,
  1347. stacklevel=2,
  1348. )
  1349. self.nulls_first = nulls_first
  1350. self.nulls_last = nulls_last
  1351. self.descending = descending
  1352. if not hasattr(expression, "resolve_expression"):
  1353. raise ValueError("expression must be an expression type")
  1354. self.expression = expression
  1355. def __repr__(self):
  1356. return "{}({}, descending={})".format(
  1357. self.__class__.__name__, self.expression, self.descending
  1358. )
  1359. def set_source_expressions(self, exprs):
  1360. self.expression = exprs[0]
  1361. def get_source_expressions(self):
  1362. return [self.expression]
  1363. def as_sql(self, compiler, connection, template=None, **extra_context):
  1364. template = template or self.template
  1365. if connection.features.supports_order_by_nulls_modifier:
  1366. if self.nulls_last:
  1367. template = "%s NULLS LAST" % template
  1368. elif self.nulls_first:
  1369. template = "%s NULLS FIRST" % template
  1370. else:
  1371. if self.nulls_last and not (
  1372. self.descending and connection.features.order_by_nulls_first
  1373. ):
  1374. template = "%%(expression)s IS NULL, %s" % template
  1375. elif self.nulls_first and not (
  1376. not self.descending and connection.features.order_by_nulls_first
  1377. ):
  1378. template = "%%(expression)s IS NOT NULL, %s" % template
  1379. connection.ops.check_expression_support(self)
  1380. expression_sql, params = compiler.compile(self.expression)
  1381. placeholders = {
  1382. "expression": expression_sql,
  1383. "ordering": "DESC" if self.descending else "ASC",
  1384. **extra_context,
  1385. }
  1386. params *= template.count("%(expression)s")
  1387. return (template % placeholders).rstrip(), params
  1388. def as_oracle(self, compiler, connection):
  1389. # Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped
  1390. # in a CASE WHEN.
  1391. if connection.ops.conditional_expression_supported_in_where_clause(
  1392. self.expression
  1393. ):
  1394. copy = self.copy()
  1395. copy.expression = Case(
  1396. When(self.expression, then=True),
  1397. default=False,
  1398. )
  1399. return copy.as_sql(compiler, connection)
  1400. return self.as_sql(compiler, connection)
  1401. def get_group_by_cols(self):
  1402. cols = []
  1403. for source in self.get_source_expressions():
  1404. cols.extend(source.get_group_by_cols())
  1405. return cols
  1406. def reverse_ordering(self):
  1407. self.descending = not self.descending
  1408. if self.nulls_first:
  1409. self.nulls_last = True
  1410. self.nulls_first = None
  1411. elif self.nulls_last:
  1412. self.nulls_first = True
  1413. self.nulls_last = None
  1414. return self
  1415. def asc(self):
  1416. self.descending = False
  1417. def desc(self):
  1418. self.descending = True
  class Window(SQLiteNumericMixin, Expression):
      """
      An expression evaluated over a window: ``<expression> OVER (<window>)``.

      The window clause is assembled from an optional PARTITION BY list, an
      optional ORDER BY list, and an optional frame (a WindowFrame subclass
      such as RowRange or ValueRange).
      """

      template = "%(expression)s OVER (%(window)s)"
      # Although the main expression may either be an aggregate or an
      # expression with an aggregate function, the GROUP BY that will
      # be introduced in the query as a result is not desired.
      contains_aggregate = False
      contains_over_clause = True

      def __init__(
          self,
          expression,
          partition_by=None,
          order_by=None,
          frame=None,
          output_field=None,
      ):
          """
          :param expression: the windowed expression; it must expose a truthy
              ``window_compatible`` attribute.
          :param partition_by: a single item or a list/tuple of items, wrapped
              in an ExpressionList.
          :param order_by: a string field reference, an expression, or a
              list/tuple of them, wrapped in an OrderByList.
          :param frame: optional frame expression compiled after ORDER BY.
          :param output_field: explicit output field; otherwise resolved from
              ``expression`` (see _resolve_output_field()).
          :raises ValueError: if ``expression`` isn't window-compatible or
              ``order_by`` has an unsupported type.
          """
          self.partition_by = partition_by
          self.order_by = order_by
          self.frame = frame

          if not getattr(expression, "window_compatible", False):
              raise ValueError(
                  "Expression '%s' isn't compatible with OVER clauses."
                  % expression.__class__.__name__
              )

          if self.partition_by is not None:
              if not isinstance(self.partition_by, (tuple, list)):
                  self.partition_by = (self.partition_by,)
              self.partition_by = ExpressionList(*self.partition_by)

          if self.order_by is not None:
              if isinstance(self.order_by, (list, tuple)):
                  self.order_by = OrderByList(*self.order_by)
              elif isinstance(self.order_by, (BaseExpression, str)):
                  self.order_by = OrderByList(self.order_by)
              else:
                  raise ValueError(
                      "Window.order_by must be either a string reference to a "
                      "field, an expression, or a list or tuple of them."
                  )
          super().__init__(output_field=output_field)
          self.source_expression = self._parse_expressions(expression)[0]

      def _resolve_output_field(self):
          return self.source_expression.output_field

      def get_source_expressions(self):
          return [self.source_expression, self.partition_by, self.order_by, self.frame]

      def set_source_expressions(self, exprs):
          self.source_expression, self.partition_by, self.order_by, self.frame = exprs

      def as_sql(self, compiler, connection, template=None):
          """
          Compile to ``<expression> OVER (<window clauses>)``.

          The window clauses are PARTITION BY, ORDER BY and the frame, in that
          order, and each is emitted only when present. Parameters are
          returned as a tuple in the same order as the SQL.

          :raises NotSupportedError: if the backend doesn't support OVER.
          """
          connection.ops.check_expression_support(self)
          if not connection.features.supports_over_clause:
              raise NotSupportedError("This backend does not support window expressions.")
          expr_sql, params = compiler.compile(self.source_expression)
          window_sql, window_params = [], ()

          if self.partition_by is not None:
              sql_expr, sql_params = self.partition_by.as_sql(
                  compiler=compiler,
                  connection=connection,
                  template="PARTITION BY %(expressions)s",
              )
              window_sql.append(sql_expr)
              window_params += tuple(sql_params)

          if self.order_by is not None:
              order_sql, order_params = compiler.compile(self.order_by)
              window_sql.append(order_sql)
              window_params += tuple(order_params)

          if self.frame:
              frame_sql, frame_params = compiler.compile(self.frame)
              window_sql.append(frame_sql)
              window_params += tuple(frame_params)

          template = template or self.template
          return (
              template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
              (*params, *window_params),
          )

      def as_sqlite(self, compiler, connection):
          if isinstance(self.output_field, fields.DecimalField):
              # Casting to numeric must be outside of the window expression.
              # The copy's inner expression is given a FloatField so that it
              # is not cast itself; SQLiteNumericMixin.as_sqlite() (defined
              # elsewhere) presumably applies the cast around the whole
              # window. TODO confirm.
              copy = self.copy()
              source_expressions = copy.get_source_expressions()
              source_expressions[0].output_field = fields.FloatField()
              copy.set_source_expressions(source_expressions)
              return super(Window, copy).as_sqlite(compiler, connection)
          return self.as_sql(compiler, connection)

      def __str__(self):
          # Emit only the clauses that are present, separated by single
          # spaces so adjacent clauses don't run together.
          window_clauses = []
          if self.partition_by:
              window_clauses.append("PARTITION BY " + str(self.partition_by))
          if self.order_by:
              window_clauses.append(str(self.order_by))
          if self.frame:
              window_clauses.append(str(self.frame))
          return "{} OVER ({})".format(
              str(self.source_expression), " ".join(window_clauses)
          )

      def __repr__(self):
          return "<%s: %s>" % (self.__class__.__name__, self)

      def get_group_by_cols(self):
          """
          Return the GROUP BY columns of the PARTITION BY and ORDER BY lists.

          The windowed expression itself contributes nothing, consistent with
          ``contains_aggregate = False`` above.
          """
          group_by_cols = []
          if self.partition_by:
              group_by_cols.extend(self.partition_by.get_group_by_cols())
          if self.order_by is not None:
              group_by_cols.extend(self.order_by.get_group_by_cols())
          return group_by_cols
  class WindowFrame(Expression):
      """
      Model the frame clause in window expressions.

      The concrete frame types RowRange (ROWS) and ValueRange (RANGE) are
      subclasses that supply ``frame_type`` and ``window_frame_start_end()``.
      SQL rendering is shared here, and validation of the bounds is delegated
      to the backend's ``connection.ops``. Providing an end for a frame is
      optional: the default is UNBOUNDED FOLLOWING, i.e. the last row of the
      partition.
      """

      template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"

      def __init__(self, start=None, end=None):
          # Bounds are wrapped in Value() so they travel as source expressions.
          self.start = Value(start)
          self.end = Value(end)

      def set_source_expressions(self, exprs):
          self.start, self.end = exprs

      def get_source_expressions(self):
          return [self.start, self.end]

      def as_sql(self, compiler, connection):
          """
          Render the frame clause.

          The bounds are inlined into the SQL text by the backend, so no
          parameters are returned.
          """
          connection.ops.check_expression_support(self)
          start, end = self.window_frame_start_end(
              connection, self.start.value, self.end.value
          )
          return (
              self.template
              % {
                  "frame_type": self.frame_type,
                  "start": start,
                  "end": end,
              },
              [],
          )

      def __repr__(self):
          return "<%s: %s>" % (self.__class__.__name__, self)

      def get_group_by_cols(self):
          return []

      def __str__(self):
          """
          Describe the frame using the module-level ``connection``'s keywords.

          Unlike as_sql(), the bounds are not validated by the backend. Each
          bound is rendered from its sign:
          - negative -> ``N PRECEDING``
          - zero -> ``CURRENT ROW``
          - positive -> ``N FOLLOWING``
          - None -> unbounded (UNBOUNDED PRECEDING for the start,
            UNBOUNDED FOLLOWING for the end)
          """
          return self.template % {
              "frame_type": self.frame_type,
              "start": self._describe_bound(
                  self.start.value, connection.ops.UNBOUNDED_PRECEDING
              ),
              "end": self._describe_bound(
                  self.end.value, connection.ops.UNBOUNDED_FOLLOWING
              ),
          }

      @staticmethod
      def _describe_bound(value, unbounded):
          """Return the textual form of a single frame bound for __str__()."""
          if value is None:
              return unbounded
          if value < 0:
              return "%d %s" % (abs(value), connection.ops.PRECEDING)
          if value == 0:
              return connection.ops.CURRENT_ROW
          return "%d %s" % (value, connection.ops.FOLLOWING)

      def window_frame_start_end(self, connection, start, end):
          """Return the (start, end) SQL fragments; implemented by subclasses."""
          raise NotImplementedError("Subclasses must implement window_frame_start_end().")
  class RowRange(WindowFrame):
      """A ``ROWS`` window frame; the backend renders its bounds."""

      frame_type = "ROWS"

      def window_frame_start_end(self, connection, start, end):
          backend_ops = connection.ops
          return backend_ops.window_frame_rows_start_end(start, end)
  class ValueRange(WindowFrame):
      """A ``RANGE`` window frame; the backend renders its bounds."""

      frame_type = "RANGE"

      def window_frame_start_end(self, connection, start, end):
          backend_ops = connection.ops
          return backend_ops.window_frame_range_start_end(start, end)