feat(hogql): window functions (#15132)

This commit is contained in:
Marius Andra
2023-05-31 12:53:31 +02:00
committed by GitHub
parent 9d5f9bbe30
commit e23d4f562e
17 changed files with 1414 additions and 969 deletions

View File

@@ -142,9 +142,9 @@ jobs:
- name: Check if antlr definitions are up to date
run: |
# Installing a version of ant compatible with what we use in development from homebrew (4.11)
# Installing a version of ant compatible with what we use in development from homebrew (4.13)
# "apt-get install antlr" would install 4.7 which is incompatible with our grammar.
export ANTLR_VERSION=4.11.1
export ANTLR_VERSION=4.13.0
# java version doesn't matter
sudo apt-get install default-jre
mkdir antlr

View File

@@ -449,6 +449,26 @@ class JoinExpr(Expr):
sample: Optional["SampleExpr"] = None
class WindowFrameExpr(Expr):
frame_type: Optional[Literal["CURRENT ROW", "PRECEDING", "FOLLOWING"]] = None
frame_value: Optional[int] = None
class WindowExpr(Expr):
partition_by: Optional[List[Expr]] = None
order_by: Optional[List[OrderExpr]] = None
frame_method: Optional[Literal["ROWS", "RANGE"]] = None
frame_start: Optional[WindowFrameExpr] = None
frame_end: Optional[WindowFrameExpr] = None
class WindowFunction(Expr):
name: str
args: Optional[List[Expr]] = None
over_expr: Optional[WindowExpr] = None
over_identifier: Optional[str] = None
class SelectQuery(Expr):
# :TRICKY: When adding new fields, make sure they're handled in visitor.py and resolver.py
type: Optional[SelectQueryType] = None
@@ -456,6 +476,7 @@ class SelectQuery(Expr):
select: List[Expr]
distinct: Optional[bool] = None
select_from: Optional[JoinExpr] = None
window_exprs: Optional[Dict[str, WindowExpr]] = None
where: Optional[Expr] = None
prewhere: Optional[Expr] = None
having: Optional[Expr] = None

View File

@@ -585,6 +585,29 @@ HOGQL_AGGREGATIONS = {
"uniqHLL12If": (2, None),
"uniqTheta": (1, None),
"uniqThetaIf": (2, None),
"median": 1,
"medianIf": 2,
"medianExact": 1,
"medianExactIf": 2,
"medianExactLow": 1,
"medianExactLowIf": 2,
"medianExactHigh": 1,
"medianExactHighIf": 2,
"medianExactWeighted": 1,
"medianExactWeightedIf": 2,
"medianTiming": 1,
"medianTimingIf": 2,
"medianTimingWeighted": 1,
"medianTimingWeightedIf": 2,
"medianDeterministic": 1,
"medianDeterministicIf": 2,
"medianTDigest": 1,
"medianTDigestIf": 2,
"medianTDigestWeighted": 1,
"medianTDigestWeightedIf": 2,
"medianBFloat16": 1,
"medianBFloat16If": 2,
# TODO: quantile(0.5)(expr) is not supported
# "quantile": 1,
# "quantileIf": 2,
# "quantiles": 1,

View File

@@ -1,4 +1,4 @@
# Generated from HogQLLexer.g4 by ANTLR 4.11.1
# Generated from HogQLLexer.g4 by ANTLR 4.13.0
from antlr4 import *
from io import StringIO
import sys
@@ -1221,7 +1221,7 @@ class HogQLLexer(Lexer):
def __init__(self, input=None, output:TextIO = sys.stdout):
super().__init__(input, output)
self.checkVersion("4.11.1")
self.checkVersion("4.13.0")
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
self._actions = None
self._predicates = None

View File

@@ -15,11 +15,11 @@ selectStmt:
columns=columnExprList
from=fromClause?
arrayJoinClause?
windowClause?
prewhereClause?
where=whereClause?
groupByClause? (WITH (CUBE | ROLLUP))? (WITH TOTALS)?
havingClause?
windowClause?
orderByClause?
limitClause?
settingsClause?
@@ -29,7 +29,7 @@ withClause: WITH withExprList;
topClause: TOP DECIMAL_LITERAL (WITH TIES)?;
fromClause: FROM joinExpr;
arrayJoinClause: (LEFT | INNER)? ARRAY JOIN columnExprList;
windowClause: WINDOW identifier AS LPAREN windowExpr RPAREN;
windowClause: WINDOW identifier AS LPAREN windowExpr RPAREN (COMMA identifier AS LPAREN windowExpr RPAREN)*;
prewhereClause: PREWHERE columnExpr;
whereClause: WHERE columnExpr;
groupByClause: GROUP BY ((CUBE | ROLLUP) LPAREN columnExprList RPAREN | columnExprList);

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
# Generated from HogQLParser.g4 by ANTLR 4.11.1
# Generated from HogQLParser.g4 by ANTLR 4.13.0
from antlr4 import *
if __name__ is not None and "." in __name__:
if "." in __name__:
from .HogQLParser import HogQLParser
else:
from HogQLParser import HogQLParser

View File

@@ -87,7 +87,6 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
return self.visit(ctx.selectStmt() or ctx.selectUnionStmt())
def visitSelectStmt(self, ctx: HogQLParser.SelectStmtContext):
select_query = ast.SelectQuery(
macros=self.visit(ctx.withClause()) if ctx.withClause() else None,
select=self.visit(ctx.columnExprList()) if ctx.columnExprList() else [],
@@ -100,6 +99,12 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
order_by=self.visit(ctx.orderByClause()) if ctx.orderByClause() else None,
)
if ctx.windowClause():
select_query.window_exprs = {}
for index, window_expr in enumerate(ctx.windowClause().windowExpr()):
name = self.visit(ctx.windowClause().identifier()[index])
select_query.window_exprs[name] = self.visit(window_expr)
if ctx.limitClause():
limit_clause = ctx.limitClause()
limit_expr = limit_clause.limitExpr()
@@ -116,8 +121,6 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
raise NotImplementedException(f"Unsupported: SelectStmt.topClause()")
if ctx.arrayJoinClause():
raise NotImplementedException(f"Unsupported: SelectStmt.arrayJoinClause()")
if ctx.windowClause():
raise NotImplementedException(f"Unsupported: SelectStmt.windowClause()")
if ctx.settingsClause():
raise NotImplementedException(f"Unsupported: SelectStmt.settingsClause()")
@@ -292,25 +295,44 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
raise NotImplementedException(f"Unsupported node: SettingExpr")
def visitWindowExpr(self, ctx: HogQLParser.WindowExprContext):
raise NotImplementedException(f"Unsupported node: WindowExpr")
frame = ctx.winFrameClause()
visited_frame = self.visit(frame) if frame else None
expr = ast.WindowExpr(
partition_by=self.visit(ctx.winPartitionByClause()) if ctx.winPartitionByClause() else None,
order_by=self.visit(ctx.winOrderByClause()) if ctx.winOrderByClause() else None,
frame_method="RANGE" if frame and frame.RANGE() else "ROWS" if frame and frame.ROWS() else None,
frame_start=visited_frame[0] if isinstance(visited_frame, tuple) else visited_frame,
frame_end=visited_frame[1] if isinstance(visited_frame, tuple) else None,
)
return expr
def visitWinPartitionByClause(self, ctx: HogQLParser.WinPartitionByClauseContext):
raise NotImplementedException(f"Unsupported node: WinPartitionByClause")
return self.visit(ctx.columnExprList())
def visitWinOrderByClause(self, ctx: HogQLParser.WinOrderByClauseContext):
raise NotImplementedException(f"Unsupported node: WinOrderByClause")
return self.visit(ctx.orderExprList())
def visitWinFrameClause(self, ctx: HogQLParser.WinFrameClauseContext):
raise NotImplementedException(f"Unsupported node: WinFrameClause")
return self.visit(ctx.winFrameExtend())
def visitFrameStart(self, ctx: HogQLParser.FrameStartContext):
raise NotImplementedException(f"Unsupported node: FrameStart")
return self.visit(ctx.winFrameBound())
def visitFrameBetween(self, ctx: HogQLParser.FrameBetweenContext):
raise NotImplementedException(f"Unsupported node: FrameBetween")
return (self.visit(ctx.winFrameBound(0)), self.visit(ctx.winFrameBound(1)))
def visitWinFrameBound(self, ctx: HogQLParser.WinFrameBoundContext):
raise NotImplementedException(f"Unsupported node: WinFrameBound")
if ctx.PRECEDING():
return ast.WindowFrameExpr(
frame_type="PRECEDING",
frame_value=self.visit(ctx.numberLiteral()).value if ctx.numberLiteral() else None,
)
if ctx.FOLLOWING():
return ast.WindowFrameExpr(
frame_type="FOLLOWING",
frame_value=self.visit(ctx.numberLiteral()).value if ctx.numberLiteral() else None,
)
return ast.WindowFrameExpr(frame_type="CURRENT ROW")
def visitExpr(self, ctx: HogQLParser.ExprContext):
return self.visit(ctx.columnExpr())
@@ -479,9 +501,6 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
op=ast.CompareOperationOp.NotEq if ctx.NOT() else ast.CompareOperationOp.Eq,
)
def visitColumnExprWinFunctionTarget(self, ctx: HogQLParser.ColumnExprWinFunctionTargetContext):
raise NotImplementedException(f"Unsupported node: ColumnExprWinFunctionTarget")
def visitColumnExprTrim(self, ctx: HogQLParser.ColumnExprTrimContext):
raise NotImplementedException(f"Unsupported node: ColumnExprTrim")
@@ -557,8 +576,19 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
def visitColumnExprNot(self, ctx: HogQLParser.ColumnExprNotContext):
return ast.Not(expr=self.visit(ctx.columnExpr()))
def visitColumnExprWinFunctionTarget(self, ctx: HogQLParser.ColumnExprWinFunctionTargetContext):
return ast.WindowFunction(
name=self.visit(ctx.identifier(0)),
args=self.visit(ctx.columnExprList()) if ctx.columnExprList() else [],
over_identifier=self.visit(ctx.identifier(1)),
)
def visitColumnExprWinFunction(self, ctx: HogQLParser.ColumnExprWinFunctionContext):
raise NotImplementedException(f"Unsupported node: ColumnExprWinFunction")
return ast.WindowFunction(
name=self.visit(ctx.identifier()),
args=self.visit(ctx.columnExprList()) if ctx.columnExprList() else [],
over_expr=self.visit(ctx.windowExpr()) if ctx.windowExpr() else None,
)
def visitColumnExprIdentifier(self, ctx: HogQLParser.ColumnExprIdentifierContext):
return self.visit(ctx.columnIdentifier())

View File

@@ -174,10 +174,17 @@ class _Printer(Visitor):
next_join = next_join.next_join
columns = [self.visit(column) for column in node.select] if node.select else ["1"]
where = self.visit(where) if where else None
having = self.visit(node.having) if node.having else None
window = (
", ".join(
[f"{self._print_identifier(name)} AS ({self.visit(expr)})" for name, expr in node.window_exprs.items()]
)
if node.window_exprs
else None
)
prewhere = self.visit(node.prewhere) if node.prewhere else None
where = self.visit(where) if where else None
group_by = [self.visit(column) for column in node.group_by] if node.group_by else None
having = self.visit(node.having) if node.having else None
order_by = [self.visit(column) for column in node.order_by] if node.order_by else None
clauses = [
@@ -187,6 +194,7 @@ class _Printer(Visitor):
"WHERE " + where if where else None,
f"GROUP BY {', '.join(group_by)}" if group_by and len(group_by) > 0 else None,
"HAVING " + having if having else None,
"WINDOW " + window if window else None,
f"ORDER BY {', '.join(order_by)}" if order_by and len(order_by) > 0 else None,
]
@@ -433,7 +441,6 @@ class _Printer(Visitor):
translated_args = ", ".join([self.visit(arg) for arg in node.args])
if node.distinct:
translated_args = f"DISTINCT {translated_args}"
return f"{node.name}({translated_args})"
elif node.name in CLICKHOUSE_FUNCTIONS:
@@ -656,6 +663,56 @@ class _Printer(Visitor):
def visit_unknown(self, node: ast.AST):
raise HogQLException(f"Unknown AST node {type(node).__name__}")
def visit_window_expr(self, node: ast.WindowExpr):
strings: List[str] = []
if node.partition_by is not None:
if len(node.partition_by) == 0:
raise HogQLException("PARTITION BY must have at least one argument")
strings.append("PARTITION BY")
for expr in node.partition_by:
strings.append(self.visit(expr))
if node.order_by is not None:
if len(node.order_by) == 0:
raise HogQLException("ORDER BY must have at least one argument")
strings.append("ORDER BY")
for expr in node.order_by:
strings.append(self.visit(expr))
if node.frame_method is not None:
if node.frame_method == "ROWS":
strings.append("ROWS")
elif node.frame_method == "RANGE":
strings.append("RANGE")
else:
raise HogQLException(f"Invalid frame method {node.frame_method}")
if node.frame_start and node.frame_end is None:
strings.append(self.visit(node.frame_start))
elif node.frame_start is not None and node.frame_end is not None:
strings.append("BETWEEN")
strings.append(self.visit(node.frame_start))
strings.append("AND")
strings.append(self.visit(node.frame_end))
else:
raise HogQLException("Frame start and end must be specified together")
return " ".join(strings)
def visit_window_function(self, node: ast.WindowFunction):
over = f"({self.visit(node.over_expr)})" if node.over_expr else self._print_identifier(node.over_identifier)
return f"{self._print_identifier(node.name)}({', '.join(self.visit(expr) for expr in node.args)}) OVER {over}"
def visit_window_frame_expr(self, node: ast.WindowFrameExpr):
if node.frame_type == "PRECEDING":
return f"{int(str(node.frame_value)) if node.frame_value is not None else 'UNBOUNDED'} PRECEDING"
elif node.frame_type == "FOLLOWING":
return f"{int(str(node.frame_value)) if node.frame_value is not None else 'UNBOUNDED'} FOLLOWING"
elif node.frame_type == "CURRENT ROW":
return "CURRENT ROW"
else:
raise HogQLException(f"Invalid frame type {node.frame_type}")
def _last_select(self) -> Optional[ast.SelectQuery]:
"""Find the last SELECT query in the stack."""
for node in reversed(self.stack):

View File

@@ -128,6 +128,9 @@ class Resolver(CloningVisitor):
new_node.limit_with_ties = node.limit_with_ties
new_node.offset = self.visit(node.offset)
new_node.distinct = node.distinct
new_node.window_exprs = (
{name: self.visit(expr) for name, expr in node.window_exprs.items()} if node.window_exprs else None
)
self.scopes.pop()
@@ -251,11 +254,7 @@ class Resolver(CloningVisitor):
raise ResolverException("Alias cannot be empty")
node = super().visit_alias(node)
if not node.expr.type:
raise ResolverException(f"Cannot alias an expression without a type: {node.alias}")
node.type = ast.FieldAliasType(alias=node.alias, type=node.expr.type)
node.type = ast.FieldAliasType(alias=node.alias, type=node.expr.type or ast.UnknownType())
scope.aliases[node.alias] = node.type
return node

View File

@@ -991,3 +991,56 @@ class TestParser(BaseTest):
],
),
)
def test_window_functions(self):
query = "SELECT person.id, min(timestamp) over (PARTITION by person.id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) AS timestamp FROM events"
expr = parse_select(query)
expected = ast.SelectQuery(
select=[
ast.Field(chain=["person", "id"]),
ast.Alias(
alias="timestamp",
expr=ast.WindowFunction(
name="min",
args=[ast.Field(chain=["timestamp"])],
over_expr=ast.WindowExpr(
partition_by=[ast.Field(chain=["person", "id"])],
order_by=[ast.OrderExpr(expr=ast.Field(chain=["timestamp"]), order="DESC")],
frame_method="ROWS",
frame_start=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=None),
frame_end=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=1),
),
),
),
],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
)
self.assertEqual(expr, expected)
def test_window_functions_with_window(self):
query = "SELECT person.id, min(timestamp) over win1 AS timestamp FROM events WINDOW win1 as (PARTITION by person.id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)"
expr = parse_select(query)
expected = ast.SelectQuery(
select=[
ast.Field(chain=["person", "id"]),
ast.Alias(
alias="timestamp",
expr=ast.WindowFunction(
name="min",
args=[ast.Field(chain=["timestamp"])],
over_identifier="win1",
),
),
],
select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
window_exprs={
"win1": ast.WindowExpr(
partition_by=[ast.Field(chain=["person", "id"])],
order_by=[ast.OrderExpr(expr=ast.Field(chain=["timestamp"]), order="DESC")],
frame_method="ROWS",
frame_start=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=None),
frame_end=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=1),
)
},
)
self.assertEqual(expr, expected)

View File

@@ -584,3 +584,19 @@ class TestPrinter(BaseTest):
with self.assertRaises(HogQLException) as error_context:
self._select("SELECT now(), toDateTime(timestamp), toDateTime('2020-02-02') FROM events", context)
self.assertEqual(str(error_context.exception), "Unknown timezone: 'Europe/PostHogLandia'")
def test_window_functions(self):
self.assertEqual(
self._select(
"SELECT distinct_id, min(timestamp) over win1 as timestamp FROM events WINDOW win1 as (PARTITION by distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)"
),
f"SELECT events.distinct_id, min(toTimeZone(events.timestamp, %(hogql_val_0)s)) OVER win1 AS timestamp FROM events WHERE equals(events.team_id, {self.team.pk}) WINDOW win1 AS (PARTITION BY events.distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) LIMIT 10000",
)
def test_window_functions_with_window(self):
self.assertEqual(
self._select(
"SELECT distinct_id, min(timestamp) over win1 as timestamp FROM events WINDOW win1 as (PARTITION by distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)"
),
f"SELECT events.distinct_id, min(toTimeZone(events.timestamp, %(hogql_val_0)s)) OVER win1 AS timestamp FROM events WHERE equals(events.team_id, {self.team.pk}) WINDOW win1 AS (PARTITION BY events.distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) LIMIT 10000",
)

View File

@@ -1,5 +1,6 @@
from uuid import UUID
import pytz
from django.test import override_settings
from django.utils import timezone
from freezegun import freeze_time
@@ -741,3 +742,155 @@ class TestQuery(ClickhouseTestMixin, APIBaseTest):
)
],
)
def test_window_functions_simple(self):
random_uuid = str(UUIDT())
for person in range(5):
distinct_id = f"person_{person}_{random_uuid}"
with freeze_time("2020-01-10 00:00:00"):
_create_person(
properties={"name": f"Person {person}", "random_uuid": random_uuid},
team=self.team,
distinct_ids=[distinct_id],
is_identified=True,
)
_create_event(
distinct_id=distinct_id,
event="random event",
team=self.team,
properties={"character": "Luigi"},
)
flush_persons_and_events()
with freeze_time("2020-01-10 00:10:00"):
_create_event(
distinct_id=distinct_id,
event="random bla",
team=self.team,
properties={"character": "Luigi"},
)
flush_persons_and_events()
with freeze_time("2020-01-10 00:20:00"):
_create_event(
distinct_id=distinct_id,
event="random boo",
team=self.team,
properties={"character": "Luigi"},
)
flush_persons_and_events()
query = f"""
select distinct_id,
timestamp,
event,
groupArray(event) OVER (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS two_before,
groupArray(event) OVER (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING) AS two_after
from events
where timestamp > toDateTime('2020-01-09 00:00:00')
and distinct_id like '%_{random_uuid}'
order by distinct_id, timestamp
"""
response = execute_hogql_query(query, team=self.team)
expected = []
for person in range(5):
expected += [
(
f"person_{person}_{random_uuid}",
datetime.datetime(2020, 1, 10, 00, 00, 00, tzinfo=pytz.UTC),
"random event",
[],
["random bla", "random boo"],
),
(
f"person_{person}_{random_uuid}",
datetime.datetime(2020, 1, 10, 00, 10, 00, tzinfo=pytz.UTC),
"random bla",
["random event"],
["random boo"],
),
(
f"person_{person}_{random_uuid}",
datetime.datetime(2020, 1, 10, 00, 20, 00, tzinfo=pytz.UTC),
"random boo",
["random event", "random bla"],
[],
),
]
self.assertEqual(response.results, expected)
def test_window_functions_with_window(self):
random_uuid = str(UUIDT())
for person in range(5):
distinct_id = f"person_{person}_{random_uuid}"
with freeze_time("2020-01-10 00:00:00"):
_create_person(
properties={"name": f"Person {person}", "random_uuid": random_uuid},
team=self.team,
distinct_ids=[distinct_id],
is_identified=True,
)
_create_event(
distinct_id=distinct_id,
event="random event",
team=self.team,
properties={"character": "Luigi"},
)
flush_persons_and_events()
with freeze_time("2020-01-10 00:10:00"):
_create_event(
distinct_id=distinct_id,
event="random bla",
team=self.team,
properties={"character": "Luigi"},
)
flush_persons_and_events()
with freeze_time("2020-01-10 00:20:00"):
_create_event(
distinct_id=distinct_id,
event="random boo",
team=self.team,
properties={"character": "Luigi"},
)
flush_persons_and_events()
query = f"""
select distinct_id,
timestamp,
event,
groupArray(event) OVER w1 AS two_before,
groupArray(event) OVER w2 AS two_after
from events
where timestamp > toDateTime('2020-01-09 00:00:00')
and distinct_id like '%_{random_uuid}'
window w1 as (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING),
w2 as (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING)
order by distinct_id, timestamp
"""
response = execute_hogql_query(query, team=self.team)
expected = []
for person in range(5):
expected += [
(
f"person_{person}_{random_uuid}",
datetime.datetime(2020, 1, 10, 00, 00, 00, tzinfo=pytz.UTC),
"random event",
[],
["random bla", "random boo"],
),
(
f"person_{person}_{random_uuid}",
datetime.datetime(2020, 1, 10, 00, 10, 00, tzinfo=pytz.UTC),
"random bla",
["random event"],
["random boo"],
),
(
f"person_{person}_{random_uuid}",
datetime.datetime(2020, 1, 10, 00, 20, 00, tzinfo=pytz.UTC),
"random boo",
["random event", "random bla"],
[],
),
]
self.assertEqual(response.results, expected)

View File

@@ -109,6 +109,8 @@ class TraversingVisitor(Visitor):
self.visit(expr)
self.visit(node.limit),
self.visit(node.offset),
for expr in (node.window_exprs or {}).values():
self.visit(expr)
def visit_select_union_query(self, node: ast.SelectUnionQuery):
for expr in node.select_queries:
@@ -199,6 +201,22 @@ class TraversingVisitor(Visitor):
def visit_property_type(self, node: ast.PropertyType):
self.visit(node.field_type)
def visit_window_expr(self, node: ast.WindowExpr):
for expr in node.partition_by or []:
self.visit(expr)
for expr in node.order_by or []:
self.visit(expr)
self.visit(node.frame_start)
self.visit(node.frame_end)
def visit_window_function(self, node: ast.WindowFunction):
for expr in node.args or []:
self.visit(expr)
self.visit(node.over_expr)
def visit_window_frame_expr(self, node: ast.WindowFrameExpr):
pass
class CloningVisitor(Visitor):
"""Visitor that traverses and clones the AST tree. Clears types."""
@@ -339,6 +357,9 @@ class CloningVisitor(Visitor):
limit_with_ties=node.limit_with_ties,
offset=self.visit(node.offset),
distinct=node.distinct,
window_exprs={name: self.visit(expr) for name, expr in node.window_exprs.items()}
if node.window_exprs
else None,
)
def visit_select_union_query(self, node: ast.SelectUnionQuery):
@@ -346,3 +367,29 @@ class CloningVisitor(Visitor):
type=None if self.clear_types else node.type,
select_queries=[self.visit(expr) for expr in node.select_queries],
)
def visit_window_expr(self, node: ast.WindowExpr):
return ast.WindowExpr(
type=None if self.clear_types else node.type,
partition_by=[self.visit(expr) for expr in node.partition_by] if node.partition_by else None,
order_by=[self.visit(expr) for expr in node.order_by] if node.order_by else None,
frame_method=node.frame_method,
frame_start=self.visit(node.frame_start),
frame_end=self.visit(node.frame_end),
)
def visit_window_function(self, node: ast.WindowFunction):
return ast.WindowFunction(
type=None if self.clear_types else node.type,
name=node.name,
args=[self.visit(expr) for expr in node.args] if node.args else None,
over_expr=self.visit(node.over_expr) if node.over_expr else None,
over_identifier=node.over_identifier,
)
def visit_window_frame_expr(self, node: ast.WindowFrameExpr):
return ast.WindowFrameExpr(
type=None if self.clear_types else node.type,
frame_type=node.frame_type,
frame_value=node.frame_value,
)

View File

@@ -4,7 +4,7 @@
# - `pip-compile --rebuild requirements.in`
# - `pip-compile --rebuild requirements-dev.in`
#
antlr4-python3-runtime==4.11.1
antlr4-python3-runtime==4.13.0
amqp==2.6.0
boto3==1.26.66
celery==4.4.7

View File

@@ -12,7 +12,7 @@ amqp==2.6.0
# via
# -r requirements.in
# kombu
antlr4-python3-runtime==4.11.1
antlr4-python3-runtime==4.13.0
# via -r requirements.in
asgiref==3.3.2
# via django