grouping.py

#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

from sqlparse import sql
from sqlparse import tokens as T
from sqlparse.exceptions import SQLParseError
from sqlparse.utils import recurse, imt

# Maximum recursion depth for grouping operations to prevent DoS attacks.
# Set to None to disable the limit (not recommended for untrusted input).
MAX_GROUPING_DEPTH = 100

# Maximum number of tokens to process in one grouping operation to prevent
# DoS attacks.
# Set to None to disable the limit (not recommended for untrusted input).
MAX_GROUPING_TOKENS = 10000
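
# Illustrative sketch (not part of upstream sqlparse): with these limits in
# place, a pathologically nested input is expected to fail fast with
# SQLParseError instead of recursing without bound, e.g.:
#
#     >>> import sqlparse  # doctest: +SKIP
#     >>> sqlparse.parse('(' * 150 + '1' + ')' * 150)  # doctest: +SKIP
#     Traceback (most recent call last):
#         ...
#     SQLParseError: Maximum grouping depth exceeded (100).
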
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
T_STRING = (T.String, T.String.Single, T.String.Symbol)
T_NAME = (T.Name, T.Name.Placeholder)


def _group_matching(tlist, cls, depth=0):
    """Groups Tokens that have beginning and end."""
    if MAX_GROUPING_DEPTH is not None and depth > MAX_GROUPING_DEPTH:
        raise SQLParseError(
            f"Maximum grouping depth exceeded ({MAX_GROUPING_DEPTH})."
        )
    # Limit the number of tokens to prevent DoS attacks
    if MAX_GROUPING_TOKENS is not None \
            and len(tlist.tokens) > MAX_GROUPING_TOKENS:
        raise SQLParseError(
            f"Maximum number of tokens exceeded ({MAX_GROUPING_TOKENS})."
        )
    opens = []
    tidx_offset = 0
    token_list = list(tlist)
    for idx, token in enumerate(token_list):
        tidx = idx - tidx_offset

        if token.is_whitespace:
            # ~50% of tokens will be whitespace. Checking for them early
            # avoids 3 comparisons for those tokens, at the cost of 1 extra
            # comparison for the other ~50%.
            continue

        if token.is_group and not isinstance(token, cls):
            # Check inside previously grouped tokens (e.g. Parenthesis) for
            # groups of a different type (e.g. Case). Ideally this would
            # check for all open/close tokens at once to avoid recursion.
            _group_matching(token, cls, depth + 1)
            continue

        if token.match(*cls.M_OPEN):
            opens.append(tidx)
        elif token.match(*cls.M_CLOSE):
            try:
                open_idx = opens.pop()
            except IndexError:
                # This indicates invalid SQL with unbalanced tokens.
                # Instead of breaking, continue in case other valid
                # groups exist.
                continue
            close_idx = tidx
            tlist.group_tokens(cls, open_idx, close_idx)
            tidx_offset += close_idx - open_idx
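
# Worked example (comment only, derived from the loop above): for the token
# stream "(", "(", "a", ")", ")" the opens stack grows to [0, 1]; the first
# ")" pops 1 and groups tokens 1-3 into the innermost group, and the outer
# ")" then pops 0 and wraps the result. tidx_offset compensates for the
# tokens that each newly created group absorbs.
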
def group_brackets(tlist):
    _group_matching(tlist, sql.SquareBrackets)


def group_parenthesis(tlist):
    _group_matching(tlist, sql.Parenthesis)


def group_case(tlist):
    _group_matching(tlist, sql.Case)


def group_if(tlist):
    _group_matching(tlist, sql.If)


def group_for(tlist):
    _group_matching(tlist, sql.For)


def group_begin(tlist):
    _group_matching(tlist, sql.Begin)
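
# Usage sketch (illustrative; goes through the public sqlparse API, which
# invokes these helpers internally during parsing):
#
#     >>> import sqlparse
#     >>> from sqlparse import sql
#     >>> stmt = sqlparse.parse('select (a + b)')[0]
#     >>> any(isinstance(t, sql.Parenthesis) for t in stmt.tokens)
#     True
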
def group_typecasts(tlist):
    def match(token):
        return token.match(T.Punctuation, '::')

    def valid(token):
        return token is not None

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
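
# Illustrative example: a PostgreSQL-style cast such as "x::int" is folded
# into a single Identifier covering the name, the '::' punctuation, and the
# type:
#
#     >>> import sqlparse
#     >>> stmt = sqlparse.parse('select x::int')[0]
#     >>> stmt.tokens[-1].get_typecast()  # doctest: +SKIP
#     'int'
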
def group_tzcasts(tlist):
    def match(token):
        return token.ttype == T.Keyword.TZCast

    def valid_prev(token):
        return token is not None

    def valid_next(token):
        return token is not None and (
            token.is_whitespace
            or token.match(T.Keyword, 'AS')
            or token.match(*sql.TypedLiteral.M_CLOSE)
        )

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)


def group_typed_literal(tlist):
    # definitely not complete, see e.g.:
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    # https://www.postgresql.org/docs/9.1/datatype-datetime.html
    # https://www.postgresql.org/docs/9.1/functions-datetime.html
    def match(token):
        return imt(token, m=sql.TypedLiteral.M_OPEN)

    def match_to_extend(token):
        return isinstance(token, sql.TypedLiteral)

    def valid_prev(token):
        return token is not None

    def valid_next(token):
        return token is not None and token.match(*sql.TypedLiteral.M_CLOSE)

    def valid_final(token):
        return token is not None and token.match(*sql.TypedLiteral.M_EXTEND)

    def post(tlist, pidx, tidx, nidx):
        return tidx, nidx

    _group(tlist, sql.TypedLiteral, match, valid_prev, valid_next,
           post, extend=False)
    _group(tlist, sql.TypedLiteral, match_to_extend, valid_prev, valid_final,
           post, extend=True)
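
# Illustrative example: a typed literal such as DATE '2020-01-01' becomes a
# sql.TypedLiteral group in the first _group pass; the second pass extends
# an existing TypedLiteral over trailing qualifiers that M_EXTEND matches
# (e.g. interval qualifiers such as DAY TO SECOND).
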
def group_period(tlist):
    def match(token):
        for ttype, value in ((T.Punctuation, '.'),
                             (T.Operator, '->'),
                             (T.Operator, '->>')):
            if token.match(ttype, value):
                return True
        return False

    def valid_prev(token):
        sqlcls = sql.SquareBrackets, sql.Identifier
        ttypes = T.Name, T.String.Symbol
        return imt(token, i=sqlcls, t=ttypes)

    def valid_next(token):
        # issue261, allow invalid next token
        return True

    def post(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        sqlcls = sql.SquareBrackets, sql.Function
        ttypes = T.Name, T.String.Symbol, T.Wildcard, T.String.Single
        next_ = tlist[nidx] if nidx is not None else None
        valid_next = imt(next_, i=sqlcls, t=ttypes)
        return (pidx, nidx) if valid_next else (pidx, tidx)

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
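
# Illustrative examples: dotted names like "schema.table" and JSON access
# operators like "data->'key'" are folded into one Identifier. When the
# token after the period is not a valid continuation, post() groups only up
# to the period itself (issue261).
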
def group_as(tlist):
    def match(token):
        return token.is_keyword and token.normalized == 'AS'

    def valid_prev(token):
        return token.normalized == 'NULL' or not token.is_keyword

    def valid_next(token):
        ttypes = T.DML, T.DDL, T.CTE
        return not imt(token, t=ttypes) and token is not None

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, match, valid_prev, valid_next, post)
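
# Illustrative example: "a AS b" is grouped into an Identifier whose alias
# is exposed through the existing sqlparse API:
#
#     >>> import sqlparse
#     >>> stmt = sqlparse.parse('select a as b')[0]
#     >>> stmt.tokens[-1].get_alias()  # doctest: +SKIP
#     'b'
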
def group_assignment(tlist):
    def match(token):
        return token.match(T.Assignment, ':=')

    def valid(token):
        return token is not None and token.ttype not in (T.Keyword,)

    def post(tlist, pidx, tidx, nidx):
        m_semicolon = T.Punctuation, ';'
        snidx, _ = tlist.token_next_by(m=m_semicolon, idx=nidx)
        nidx = snidx or nidx
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Assignment, match, valid_prev, valid_next, post)
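
# Illustrative example: "foo := 1;" is grouped into a sql.Assignment, and
# post() stretches the group up to the terminating semicolon when one
# follows the right-hand side.
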
def group_comparison(tlist):
    sqlcls = (sql.Parenthesis, sql.Function, sql.Identifier,
              sql.Operation, sql.TypedLiteral)
    ttypes = T_NUMERICAL + T_STRING + T_NAME

    def match(token):
        return token.ttype == T.Operator.Comparison

    def valid(token):
        if imt(token, t=ttypes, i=sqlcls):
            return True
        elif token and token.is_keyword and token.normalized == 'NULL':
            return True
        else:
            return False

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Comparison, match,
           valid_prev, valid_next, post, extend=False)
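
# Illustrative example: "a = 1" (and comparisons against NULL, which the
# valid() helper special-cases) becomes a sql.Comparison group:
#
#     >>> import sqlparse
#     >>> from sqlparse import sql
#     >>> stmt = sqlparse.parse('select * from t where a = 1')[0]
#     >>> any(isinstance(t, sql.Comparison)
#     ...     for t in stmt.tokens[-1].tokens)  # doctest: +SKIP
#     True
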
@recurse(sql.Identifier)
def group_identifier(tlist):
    ttypes = (T.String.Symbol, T.Name)

    tidx, token = tlist.token_next_by(t=ttypes)
    while token:
        tlist.group_tokens(sql.Identifier, tidx, tidx)
        tidx, token = tlist.token_next_by(t=ttypes, idx=tidx)


@recurse(sql.Over)
def group_over(tlist):
    tidx, token = tlist.token_next_by(m=sql.Over.M_OPEN)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if imt(next_, i=sql.Parenthesis, t=T.Name):
            tlist.group_tokens(sql.Over, tidx, nidx)
        tidx, token = tlist.token_next_by(m=sql.Over.M_OPEN, idx=tidx)
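
# Illustrative example: in "row_number() over (partition by x)" the OVER
# keyword and the following parenthesis (or window name) are grouped into a
# sql.Over; group_functions later folds that Over into the surrounding
# sql.Function (see the pipeline order in group() below).
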
def group_arrays(tlist):
    sqlcls = sql.SquareBrackets, sql.Identifier, sql.Function
    ttypes = T.Name, T.String.Symbol

    def match(token):
        return isinstance(token, sql.SquareBrackets)

    def valid_prev(token):
        return imt(token, i=sqlcls, t=ttypes)

    def valid_next(token):
        return True

    def post(tlist, pidx, tidx, nidx):
        return pidx, tidx

    _group(tlist, sql.Identifier, match,
           valid_prev, valid_next, post, extend=True, recurse=False)
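
# Illustrative example: an array subscript such as "col[1]" attaches the
# SquareBrackets group to the preceding identifier (extend=True), so
# "col[1][2]" also collapses into a single Identifier.
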
def group_operator(tlist):
    ttypes = T_NUMERICAL + T_STRING + T_NAME
    sqlcls = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
              sql.Identifier, sql.Operation, sql.TypedLiteral)

    def match(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def valid(token):
        return imt(token, i=sqlcls, t=ttypes) \
            or (token and token.match(
                T.Keyword,
                ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')))

    def post(tlist, pidx, tidx, nidx):
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.Operation, match,
           valid_prev, valid_next, post, extend=False)
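
# Illustrative example: arithmetic such as "a + b" becomes a sql.Operation;
# post() also normalizes the middle token's ttype to T.Operator, since "*"
# is tokenized as T.Wildcard before this pass.
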
def group_identifier_list(tlist):
    m_role = T.Keyword, ('null', 'role')
    sqlcls = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
              sql.IdentifierList, sql.Operation)
    ttypes = (T_NUMERICAL + T_STRING + T_NAME
              + (T.Keyword, T.Comment, T.Wildcard))

    def match(token):
        return token.match(T.Punctuation, ',')

    def valid(token):
        return imt(token, i=sqlcls, m=m_role, t=ttypes)

    def post(tlist, pidx, tidx, nidx):
        return pidx, nidx

    valid_prev = valid_next = valid
    _group(tlist, sql.IdentifierList, match,
           valid_prev, valid_next, post, extend=True)
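
# Illustrative example: comma-separated items such as "a, b, c" collapse
# into a single sql.IdentifierList (extend=True merges successive commas):
#
#     >>> import sqlparse
#     >>> stmt = sqlparse.parse('select a, b, c from t')[0]
#     >>> [str(i) for i in stmt.tokens[2].get_identifiers()]  # doctest: +SKIP
#     ['a', 'b', 'c']
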
@recurse(sql.Comment)
def group_comments(tlist):
    tidx, token = tlist.token_next_by(t=T.Comment)
    while token:
        eidx, end = tlist.token_not_matching(
            lambda tk: imt(tk, t=T.Comment) or tk.is_newline, idx=tidx)
        if end is not None:
            eidx, end = tlist.token_prev(eidx, skip_ws=False)
            tlist.group_tokens(sql.Comment, tidx, eidx)

        tidx, token = tlist.token_next_by(t=T.Comment, idx=tidx)


@recurse(sql.Where)
def group_where(tlist):
    tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN)
    while token:
        eidx, end = tlist.token_next_by(m=sql.Where.M_CLOSE, idx=tidx)

        if end is None:
            end = tlist._groupable_tokens[-1]
        else:
            end = tlist.tokens[eidx - 1]
        # TODO: convert this to eidx instead of end token.
        # I think the above values are len(tlist) and eidx-1.
        eidx = tlist.token_index(end)
        tlist.group_tokens(sql.Where, tidx, eidx)
        tidx, token = tlist.token_next_by(m=sql.Where.M_OPEN, idx=tidx)
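
# Illustrative example: the Where group spans from the WHERE keyword up to,
# but not including, the next closing keyword (one of sql.Where.M_CLOSE,
# e.g. GROUP BY or ORDER BY), or to the end of the statement if none
# follows.
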
@recurse()
def group_aliased(tlist):
    I_ALIAS = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
               sql.Operation, sql.Comparison)

    tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Identifier):
            tlist.group_tokens(sql.Identifier, tidx, nidx, extend=True)
        tidx, token = tlist.token_next_by(i=I_ALIAS, t=T.Number, idx=tidx)


@recurse(sql.Function)
def group_functions(tlist):
    has_create = False
    has_table = False
    has_as = False
    for tmp_token in tlist.tokens:
        if tmp_token.value.upper() == 'CREATE':
            has_create = True
        if tmp_token.value.upper() == 'TABLE':
            has_table = True
        if tmp_token.value == 'AS':
            has_as = True
    if has_create and has_table and not has_as:
        return

    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            over_idx, over = tlist.token_next(nidx)
            if over and isinstance(over, sql.Over):
                eidx = over_idx
            else:
                eidx = nidx
            tlist.group_tokens(sql.Function, tidx, eidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
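
# Illustrative example: a Name token directly followed by a Parenthesis,
# such as "count(x)", becomes a sql.Function (plus a trailing sql.Over when
# present). The CREATE/TABLE check above avoids treating column definitions
# in "CREATE TABLE t (...)" as a function call, unless an AS clause makes
# it a CREATE TABLE ... AS SELECT.
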
@recurse(sql.Identifier)
def group_order(tlist):
    """Group together Identifier and Asc/Desc token"""
    tidx, token = tlist.token_next_by(t=T.Keyword.Order)
    while token:
        pidx, prev_ = tlist.token_prev(tidx)
        if imt(prev_, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, pidx, tidx)
            tidx = pidx
        tidx, token = tlist.token_next_by(t=T.Keyword.Order, idx=tidx)


@recurse()
def align_comments(tlist):
    tidx, token = tlist.token_next_by(i=sql.Comment)
    while token:
        pidx, prev_ = tlist.token_prev(tidx)
        if isinstance(prev_, sql.TokenList):
            tlist.group_tokens(sql.TokenList, pidx, tidx, extend=True)
            tidx = pidx
        tidx, token = tlist.token_next_by(i=sql.Comment, idx=tidx)


def group_values(tlist):
    tidx, token = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    start_idx = tidx
    end_idx = -1
    while token:
        if isinstance(token, sql.Parenthesis):
            end_idx = tidx
        tidx, token = tlist.token_next(tidx)
    if end_idx != -1:
        tlist.group_tokens(sql.Values, start_idx, end_idx, extend=True)
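
# Illustrative example: in "insert into t values (1, 'a'), (2, 'b')" the
# VALUES keyword and every following Parenthesis are grouped into one
# sql.Values spanning from the keyword to the last tuple.
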
def group(stmt):
    for func in [
        group_comments,

        # _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_over,
        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    ]:
        func(stmt)
    return stmt
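
# Pipeline note (illustrative): the order above is significant. For
# example, group_parenthesis must run before group_functions, which looks
# for a Name token followed by an already-built Parenthesis, and
# group_comparison must run before group_identifier_list so that
# "a = 1, b = 2" splits on commas between Comparison groups rather than
# raw tokens. sqlparse.parse() calls group() on each statement when
# grouping is enabled, so users normally never invoke this module directly.
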
def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True,
           depth=0
           ):
    """Groups together tokens that are joined by a middle token. e.g. x < y"""
    if MAX_GROUPING_DEPTH is not None and depth > MAX_GROUPING_DEPTH:
        raise SQLParseError(
            f"Maximum grouping depth exceeded ({MAX_GROUPING_DEPTH})."
        )
    # Limit the number of tokens to prevent DoS attacks
    if MAX_GROUPING_TOKENS is not None \
            and len(tlist.tokens) > MAX_GROUPING_TOKENS:
        raise SQLParseError(
            f"Maximum number of tokens exceeded ({MAX_GROUPING_TOKENS})."
        )
    tidx_offset = 0
    pidx, prev_ = None, None
    token_list = list(tlist)
    for idx, token in enumerate(token_list):
        tidx = idx - tidx_offset
        if tidx < 0:  # tidx shouldn't get negative
            continue

        if token.is_whitespace:
            continue

        if recurse and token.is_group and not isinstance(token, cls):
            _group(token, cls, match, valid_prev, valid_next,
                   post, extend, True, depth + 1)

        if match(token):
            nidx, next_ = tlist.token_next(tidx)
            if prev_ and valid_prev(prev_) and valid_next(next_):
                from_idx, to_idx = post(tlist, pidx, tidx, nidx)
                grp = tlist.group_tokens(cls, from_idx, to_idx, extend=extend)

                tidx_offset += to_idx - from_idx
                pidx, prev_ = from_idx, grp
                continue

        pidx, prev_ = tidx, token