mirror of
https://github.com/BillyOutlast/posthog.git
synced 2026-02-04 03:01:23 +01:00
feat(hogql): window functions (#15132)
This commit is contained in:
4
.github/workflows/ci-backend.yml
vendored
4
.github/workflows/ci-backend.yml
vendored
@@ -142,9 +142,9 @@ jobs:
|
||||
|
||||
- name: Check if antlr definitions are up to date
|
||||
run: |
|
||||
# Installing a version of ant compatible with what we use in development from homebrew (4.11)
|
||||
# Installing a version of antlr compatible with what we use in development from homebrew (4.13)
|
||||
# "apt-get install antlr" would install 4.7 which is incompatible with our grammar.
|
||||
export ANTLR_VERSION=4.11.1
|
||||
export ANTLR_VERSION=4.13.0
|
||||
# java version doesn't matter
|
||||
sudo apt-get install default-jre
|
||||
mkdir antlr
|
||||
|
||||
@@ -449,6 +449,26 @@ class JoinExpr(Expr):
|
||||
sample: Optional["SampleExpr"] = None
|
||||
|
||||
|
||||
class WindowFrameExpr(Expr):
    """A single bound of a window frame, e.g. `1 PRECEDING` or `CURRENT ROW`."""

    # Which kind of bound this is; None when unset.
    frame_type: Optional[Literal["CURRENT ROW", "PRECEDING", "FOLLOWING"]] = None
    # Offset for PRECEDING / FOLLOWING bounds. The printer renders a missing
    # value as UNBOUNDED.
    frame_value: Optional[int] = None
|
||||
|
||||
|
||||
class WindowExpr(Expr):
    """A window specification: optional PARTITION BY, ORDER BY and frame."""

    # Expressions listed after PARTITION BY; None when the clause is absent.
    partition_by: Optional[List[Expr]] = None
    # Ordering expressions listed after ORDER BY; None when the clause is absent.
    order_by: Optional[List[OrderExpr]] = None
    # Frame unit; None when the window has no frame clause.
    frame_method: Optional[Literal["ROWS", "RANGE"]] = None
    # Sole bound of a single-bound frame, or the first bound of a BETWEEN frame.
    frame_start: Optional[WindowFrameExpr] = None
    # Second bound of a BETWEEN ... AND ... frame; None for a single-bound frame.
    frame_end: Optional[WindowFrameExpr] = None
|
||||
|
||||
|
||||
class WindowFunction(Expr):
    """A function call evaluated over a window: `name(args) OVER ...`."""

    # Name of the function being called.
    name: str
    # Call arguments; the parser produces an empty list for a zero-argument call.
    args: Optional[List[Expr]] = None
    # Inline window specification, for `OVER (<spec>)`.
    over_expr: Optional[WindowExpr] = None
    # Name of a window declared in the query's WINDOW clause, for `OVER <name>`.
    over_identifier: Optional[str] = None
|
||||
|
||||
|
||||
class SelectQuery(Expr):
|
||||
# :TRICKY: When adding new fields, make sure they're handled in visitor.py and resolver.py
|
||||
type: Optional[SelectQueryType] = None
|
||||
@@ -456,6 +476,7 @@ class SelectQuery(Expr):
|
||||
select: List[Expr]
|
||||
distinct: Optional[bool] = None
|
||||
select_from: Optional[JoinExpr] = None
|
||||
window_exprs: Optional[Dict[str, WindowExpr]] = None
|
||||
where: Optional[Expr] = None
|
||||
prewhere: Optional[Expr] = None
|
||||
having: Optional[Expr] = None
|
||||
|
||||
@@ -585,6 +585,29 @@ HOGQL_AGGREGATIONS = {
|
||||
"uniqHLL12If": (2, None),
|
||||
"uniqTheta": (1, None),
|
||||
"uniqThetaIf": (2, None),
|
||||
"median": 1,
|
||||
"medianIf": 2,
|
||||
"medianExact": 1,
|
||||
"medianExactIf": 2,
|
||||
"medianExactLow": 1,
|
||||
"medianExactLowIf": 2,
|
||||
"medianExactHigh": 1,
|
||||
"medianExactHighIf": 2,
|
||||
"medianExactWeighted": 1,
|
||||
"medianExactWeightedIf": 2,
|
||||
"medianTiming": 1,
|
||||
"medianTimingIf": 2,
|
||||
"medianTimingWeighted": 1,
|
||||
"medianTimingWeightedIf": 2,
|
||||
"medianDeterministic": 1,
|
||||
"medianDeterministicIf": 2,
|
||||
"medianTDigest": 1,
|
||||
"medianTDigestIf": 2,
|
||||
"medianTDigestWeighted": 1,
|
||||
"medianTDigestWeightedIf": 2,
|
||||
"medianBFloat16": 1,
|
||||
"medianBFloat16If": 2,
|
||||
# TODO: quantile(0.5)(expr) is not supported
|
||||
# "quantile": 1,
|
||||
# "quantileIf": 2,
|
||||
# "quantiles": 1,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Generated from HogQLLexer.g4 by ANTLR 4.11.1
|
||||
# Generated from HogQLLexer.g4 by ANTLR 4.13.0
|
||||
from antlr4 import *
|
||||
from io import StringIO
|
||||
import sys
|
||||
@@ -1221,7 +1221,7 @@ class HogQLLexer(Lexer):
|
||||
|
||||
def __init__(self, input=None, output:TextIO = sys.stdout):
|
||||
super().__init__(input, output)
|
||||
self.checkVersion("4.11.1")
|
||||
self.checkVersion("4.13.0")
|
||||
self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache())
|
||||
self._actions = None
|
||||
self._predicates = None
|
||||
|
||||
@@ -15,11 +15,11 @@ selectStmt:
|
||||
columns=columnExprList
|
||||
from=fromClause?
|
||||
arrayJoinClause?
|
||||
windowClause?
|
||||
prewhereClause?
|
||||
where=whereClause?
|
||||
groupByClause? (WITH (CUBE | ROLLUP))? (WITH TOTALS)?
|
||||
havingClause?
|
||||
windowClause?
|
||||
orderByClause?
|
||||
limitClause?
|
||||
settingsClause?
|
||||
@@ -29,7 +29,7 @@ withClause: WITH withExprList;
|
||||
topClause: TOP DECIMAL_LITERAL (WITH TIES)?;
|
||||
fromClause: FROM joinExpr;
|
||||
arrayJoinClause: (LEFT | INNER)? ARRAY JOIN columnExprList;
|
||||
windowClause: WINDOW identifier AS LPAREN windowExpr RPAREN;
|
||||
windowClause: WINDOW identifier AS LPAREN windowExpr RPAREN (COMMA identifier AS LPAREN windowExpr RPAREN)*;
|
||||
prewhereClause: PREWHERE columnExpr;
|
||||
whereClause: WHERE columnExpr;
|
||||
groupByClause: GROUP BY ((CUBE | ROLLUP) LPAREN columnExprList RPAREN | columnExprList);
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,6 @@
|
||||
# Generated from HogQLParser.g4 by ANTLR 4.11.1
|
||||
# Generated from HogQLParser.g4 by ANTLR 4.13.0
|
||||
from antlr4 import *
|
||||
if __name__ is not None and "." in __name__:
|
||||
if "." in __name__:
|
||||
from .HogQLParser import HogQLParser
|
||||
else:
|
||||
from HogQLParser import HogQLParser
|
||||
|
||||
@@ -87,7 +87,6 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
|
||||
return self.visit(ctx.selectStmt() or ctx.selectUnionStmt())
|
||||
|
||||
def visitSelectStmt(self, ctx: HogQLParser.SelectStmtContext):
|
||||
|
||||
select_query = ast.SelectQuery(
|
||||
macros=self.visit(ctx.withClause()) if ctx.withClause() else None,
|
||||
select=self.visit(ctx.columnExprList()) if ctx.columnExprList() else [],
|
||||
@@ -100,6 +99,12 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
|
||||
order_by=self.visit(ctx.orderByClause()) if ctx.orderByClause() else None,
|
||||
)
|
||||
|
||||
if ctx.windowClause():
|
||||
select_query.window_exprs = {}
|
||||
for index, window_expr in enumerate(ctx.windowClause().windowExpr()):
|
||||
name = self.visit(ctx.windowClause().identifier()[index])
|
||||
select_query.window_exprs[name] = self.visit(window_expr)
|
||||
|
||||
if ctx.limitClause():
|
||||
limit_clause = ctx.limitClause()
|
||||
limit_expr = limit_clause.limitExpr()
|
||||
@@ -116,8 +121,6 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
|
||||
raise NotImplementedException(f"Unsupported: SelectStmt.topClause()")
|
||||
if ctx.arrayJoinClause():
|
||||
raise NotImplementedException(f"Unsupported: SelectStmt.arrayJoinClause()")
|
||||
if ctx.windowClause():
|
||||
raise NotImplementedException(f"Unsupported: SelectStmt.windowClause()")
|
||||
if ctx.settingsClause():
|
||||
raise NotImplementedException(f"Unsupported: SelectStmt.settingsClause()")
|
||||
|
||||
@@ -292,25 +295,44 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
|
||||
raise NotImplementedException(f"Unsupported node: SettingExpr")
|
||||
|
||||
def visitWindowExpr(self, ctx: HogQLParser.WindowExprContext):
    """Build an `ast.WindowExpr` from a parsed window specification.

    PARTITION BY, ORDER BY and the frame clause are all optional; parts
    that are absent are left as None on the returned node.
    """
    frame = ctx.winFrameClause()
    # Visiting the frame clause yields a single bound, or a (start, end)
    # tuple when the frame was written as BETWEEN ... AND ...
    visited_frame = self.visit(frame) if frame else None
    has_both_bounds = isinstance(visited_frame, tuple)

    frame_method = None
    if frame and frame.RANGE():
        frame_method = "RANGE"
    elif frame and frame.ROWS():
        frame_method = "ROWS"

    partition_clause = ctx.winPartitionByClause()
    order_clause = ctx.winOrderByClause()
    return ast.WindowExpr(
        partition_by=self.visit(partition_clause) if partition_clause else None,
        order_by=self.visit(order_clause) if order_clause else None,
        frame_method=frame_method,
        frame_start=visited_frame[0] if has_both_bounds else visited_frame,
        frame_end=visited_frame[1] if has_both_bounds else None,
    )
|
||||
|
||||
def visitWinPartitionByClause(self, ctx: HogQLParser.WinPartitionByClauseContext):
    """Return the converted column expressions of a PARTITION BY clause."""
    return self.visit(ctx.columnExprList())
|
||||
|
||||
def visitWinOrderByClause(self, ctx: HogQLParser.WinOrderByClauseContext):
    """Return the converted ordering expressions of a window's ORDER BY clause."""
    return self.visit(ctx.orderExprList())
|
||||
|
||||
def visitWinFrameClause(self, ctx: HogQLParser.WinFrameClauseContext):
    """Return the converted frame extent of a ROWS / RANGE clause.

    The result is whatever visiting the extent produces: a single bound,
    or a (start, end) tuple for a BETWEEN ... AND ... frame.
    """
    return self.visit(ctx.winFrameExtend())
|
||||
|
||||
def visitFrameStart(self, ctx: HogQLParser.FrameStartContext):
    """Return the converted bound of a single-bound frame."""
    return self.visit(ctx.winFrameBound())
|
||||
|
||||
def visitFrameBetween(self, ctx: HogQLParser.FrameBetweenContext):
    """Return the two converted bounds of a BETWEEN frame as a (start, end) tuple."""
    start_bound = self.visit(ctx.winFrameBound(0))
    end_bound = self.visit(ctx.winFrameBound(1))
    return (start_bound, end_bound)
|
||||
|
||||
def visitWinFrameBound(self, ctx: HogQLParser.WinFrameBoundContext):
    """Convert one frame bound into an `ast.WindowFrameExpr`.

    PRECEDING and FOLLOWING bounds carry the numeric offset when one was
    written and None otherwise; anything else is treated as CURRENT ROW.
    """
    if ctx.PRECEDING():
        frame_type = "PRECEDING"
    elif ctx.FOLLOWING():
        frame_type = "FOLLOWING"
    else:
        return ast.WindowFrameExpr(frame_type="CURRENT ROW")

    number = ctx.numberLiteral()
    return ast.WindowFrameExpr(
        frame_type=frame_type,
        # No number literal means the bound had no explicit offset.
        frame_value=self.visit(number).value if number else None,
    )
|
||||
|
||||
def visitExpr(self, ctx: HogQLParser.ExprContext):
|
||||
return self.visit(ctx.columnExpr())
|
||||
@@ -479,9 +501,6 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
|
||||
op=ast.CompareOperationOp.NotEq if ctx.NOT() else ast.CompareOperationOp.Eq,
|
||||
)
|
||||
|
||||
def visitColumnExprWinFunctionTarget(self, ctx: HogQLParser.ColumnExprWinFunctionTargetContext):
|
||||
raise NotImplementedException(f"Unsupported node: ColumnExprWinFunctionTarget")
|
||||
|
||||
def visitColumnExprTrim(self, ctx: HogQLParser.ColumnExprTrimContext):
|
||||
raise NotImplementedException(f"Unsupported node: ColumnExprTrim")
|
||||
|
||||
@@ -557,8 +576,19 @@ class HogQLParseTreeConverter(ParseTreeVisitor):
|
||||
def visitColumnExprNot(self, ctx: HogQLParser.ColumnExprNotContext):
    """Wrap the converted operand of a NOT expression in `ast.Not`."""
    operand = self.visit(ctx.columnExpr())
    return ast.Not(expr=operand)
|
||||
|
||||
def visitColumnExprWinFunctionTarget(self, ctx: HogQLParser.ColumnExprWinFunctionTargetContext):
    """Convert `fn(args) OVER window_name` into an `ast.WindowFunction`.

    The first identifier is the function name and the second is the name
    of the referenced window.
    """
    function_name = self.visit(ctx.identifier(0))
    arg_list = ctx.columnExprList()
    arguments = self.visit(arg_list) if arg_list else []
    window_name = self.visit(ctx.identifier(1))
    return ast.WindowFunction(name=function_name, args=arguments, over_identifier=window_name)
|
||||
|
||||
def visitColumnExprWinFunction(self, ctx: HogQLParser.ColumnExprWinFunctionContext):
    """Convert `fn(args) OVER (<window spec>)` into an `ast.WindowFunction`.

    A call without arguments gets an empty argument list; a missing
    window specification leaves `over_expr` as None.
    """
    arg_list = ctx.columnExprList()
    window = ctx.windowExpr()
    return ast.WindowFunction(
        name=self.visit(ctx.identifier()),
        args=self.visit(arg_list) if arg_list else [],
        over_expr=self.visit(window) if window else None,
    )
|
||||
|
||||
def visitColumnExprIdentifier(self, ctx: HogQLParser.ColumnExprIdentifierContext):
|
||||
return self.visit(ctx.columnIdentifier())
|
||||
|
||||
@@ -174,10 +174,17 @@ class _Printer(Visitor):
|
||||
next_join = next_join.next_join
|
||||
|
||||
columns = [self.visit(column) for column in node.select] if node.select else ["1"]
|
||||
where = self.visit(where) if where else None
|
||||
having = self.visit(node.having) if node.having else None
|
||||
window = (
|
||||
", ".join(
|
||||
[f"{self._print_identifier(name)} AS ({self.visit(expr)})" for name, expr in node.window_exprs.items()]
|
||||
)
|
||||
if node.window_exprs
|
||||
else None
|
||||
)
|
||||
prewhere = self.visit(node.prewhere) if node.prewhere else None
|
||||
where = self.visit(where) if where else None
|
||||
group_by = [self.visit(column) for column in node.group_by] if node.group_by else None
|
||||
having = self.visit(node.having) if node.having else None
|
||||
order_by = [self.visit(column) for column in node.order_by] if node.order_by else None
|
||||
|
||||
clauses = [
|
||||
@@ -187,6 +194,7 @@ class _Printer(Visitor):
|
||||
"WHERE " + where if where else None,
|
||||
f"GROUP BY {', '.join(group_by)}" if group_by and len(group_by) > 0 else None,
|
||||
"HAVING " + having if having else None,
|
||||
"WINDOW " + window if window else None,
|
||||
f"ORDER BY {', '.join(order_by)}" if order_by and len(order_by) > 0 else None,
|
||||
]
|
||||
|
||||
@@ -433,7 +441,6 @@ class _Printer(Visitor):
|
||||
translated_args = ", ".join([self.visit(arg) for arg in node.args])
|
||||
if node.distinct:
|
||||
translated_args = f"DISTINCT {translated_args}"
|
||||
|
||||
return f"{node.name}({translated_args})"
|
||||
|
||||
elif node.name in CLICKHOUSE_FUNCTIONS:
|
||||
@@ -656,6 +663,56 @@ class _Printer(Visitor):
|
||||
def visit_unknown(self, node: ast.AST):
    """Reject a node type the printer has no visit method for.

    Raises:
        HogQLException: always, naming the node's class.
    """
    node_kind = type(node).__name__
    raise HogQLException(f"Unknown AST node {node_kind}")
|
||||
|
||||
def visit_window_expr(self, node: ast.WindowExpr):
    """Print the inside of a window specification (without the parentheses).

    Produces, in order and each only when present: `PARTITION BY a, b`,
    `ORDER BY x, y`, and the frame (`ROWS|RANGE <bound>` or
    `ROWS|RANGE BETWEEN <start> AND <end>`). Returns an empty string for
    a window with no parts.

    Raises:
        HogQLException: if PARTITION BY or ORDER BY is an empty list, if
            the frame method is not ROWS/RANGE, or if a frame method is
            given without a usable start bound.
    """
    strings: List[str] = []

    if node.partition_by is not None:
        if len(node.partition_by) == 0:
            raise HogQLException("PARTITION BY must have at least one argument")
        # The expressions form a comma-separated list; joining them with
        # spaces would print invalid SQL as soon as there is more than one.
        partition_columns = ", ".join(self.visit(expr) for expr in node.partition_by)
        strings.append(f"PARTITION BY {partition_columns}")

    if node.order_by is not None:
        if len(node.order_by) == 0:
            raise HogQLException("ORDER BY must have at least one argument")
        order_columns = ", ".join(self.visit(expr) for expr in node.order_by)
        strings.append(f"ORDER BY {order_columns}")

    # Frame bounds are only printed when a frame method is set.
    if node.frame_method is not None:
        if node.frame_method == "ROWS":
            strings.append("ROWS")
        elif node.frame_method == "RANGE":
            strings.append("RANGE")
        else:
            raise HogQLException(f"Invalid frame method {node.frame_method}")

        if node.frame_start and node.frame_end is None:
            # Single-bound frame, e.g. `ROWS UNBOUNDED PRECEDING`.
            strings.append(self.visit(node.frame_start))
        elif node.frame_start is not None and node.frame_end is not None:
            strings.append("BETWEEN")
            strings.append(self.visit(node.frame_start))
            strings.append("AND")
            strings.append(self.visit(node.frame_end))
        else:
            raise HogQLException("Frame start and end must be specified together")

    return " ".join(strings)
|
||||
|
||||
def visit_window_function(self, node: ast.WindowFunction):
    """Print a window function call: `name(args) OVER (<spec>)` or `name(args) OVER <window>`.

    An inline specification (`over_expr`) takes precedence over a named
    window (`over_identifier`). When neither is set the call is printed
    with an empty window, `OVER ()`.
    """
    # `args` may be None rather than an empty list (for example after the
    # node has been cloned), so treat a missing list as "no arguments".
    printed_args = ", ".join(self.visit(expr) for expr in node.args or [])

    if node.over_expr:
        over = f"({self.visit(node.over_expr)})"
    elif node.over_identifier:
        over = self._print_identifier(node.over_identifier)
    else:
        over = "()"

    return f"{self._print_identifier(node.name)}({printed_args}) OVER {over}"
|
||||
|
||||
def visit_window_frame_expr(self, node: ast.WindowFrameExpr):
    """Print one frame bound: `CURRENT ROW`, `<n> PRECEDING` or `<n> FOLLOWING`.

    A PRECEDING / FOLLOWING bound without a value prints as UNBOUNDED.

    Raises:
        HogQLException: if the frame type is none of the three known kinds.
    """
    if node.frame_type == "CURRENT ROW":
        return "CURRENT ROW"

    for direction in ("PRECEDING", "FOLLOWING"):
        if node.frame_type == direction:
            # Round-trip through int() so only an integer literal reaches the SQL.
            offset = int(str(node.frame_value)) if node.frame_value is not None else "UNBOUNDED"
            return f"{offset} {direction}"

    raise HogQLException(f"Invalid frame type {node.frame_type}")
|
||||
|
||||
def _last_select(self) -> Optional[ast.SelectQuery]:
|
||||
"""Find the last SELECT query in the stack."""
|
||||
for node in reversed(self.stack):
|
||||
|
||||
@@ -128,6 +128,9 @@ class Resolver(CloningVisitor):
|
||||
new_node.limit_with_ties = node.limit_with_ties
|
||||
new_node.offset = self.visit(node.offset)
|
||||
new_node.distinct = node.distinct
|
||||
new_node.window_exprs = (
|
||||
{name: self.visit(expr) for name, expr in node.window_exprs.items()} if node.window_exprs else None
|
||||
)
|
||||
|
||||
self.scopes.pop()
|
||||
|
||||
@@ -251,11 +254,7 @@ class Resolver(CloningVisitor):
|
||||
raise ResolverException("Alias cannot be empty")
|
||||
|
||||
node = super().visit_alias(node)
|
||||
|
||||
if not node.expr.type:
|
||||
raise ResolverException(f"Cannot alias an expression without a type: {node.alias}")
|
||||
|
||||
node.type = ast.FieldAliasType(alias=node.alias, type=node.expr.type)
|
||||
node.type = ast.FieldAliasType(alias=node.alias, type=node.expr.type or ast.UnknownType())
|
||||
scope.aliases[node.alias] = node.type
|
||||
return node
|
||||
|
||||
|
||||
@@ -991,3 +991,56 @@ class TestParser(BaseTest):
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
def test_window_functions(self):
    """An inline `OVER (...)` specification parses into `WindowFunction.over_expr`."""
    parsed = parse_select(
        "SELECT person.id, min(timestamp) over (PARTITION by person.id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) AS timestamp FROM events"
    )

    inline_window = ast.WindowExpr(
        partition_by=[ast.Field(chain=["person", "id"])],
        order_by=[ast.OrderExpr(expr=ast.Field(chain=["timestamp"]), order="DESC")],
        frame_method="ROWS",
        # UNBOUNDED PRECEDING is represented by a bound without a value.
        frame_start=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=None),
        frame_end=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=1),
    )
    windowed_min = ast.WindowFunction(
        name="min",
        args=[ast.Field(chain=["timestamp"])],
        over_expr=inline_window,
    )
    expected = ast.SelectQuery(
        select=[
            ast.Field(chain=["person", "id"]),
            ast.Alias(alias="timestamp", expr=windowed_min),
        ],
        select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
    )
    self.assertEqual(parsed, expected)
|
||||
|
||||
def test_window_functions_with_window(self):
    """A named window parses into `over_identifier` plus a `window_exprs` entry."""
    parsed = parse_select(
        "SELECT person.id, min(timestamp) over win1 AS timestamp FROM events WINDOW win1 as (PARTITION by person.id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)"
    )

    # The function only references the window by name ...
    windowed_min = ast.WindowFunction(
        name="min",
        args=[ast.Field(chain=["timestamp"])],
        over_identifier="win1",
    )
    # ... while its definition is stored on the query, keyed by that name.
    win1 = ast.WindowExpr(
        partition_by=[ast.Field(chain=["person", "id"])],
        order_by=[ast.OrderExpr(expr=ast.Field(chain=["timestamp"]), order="DESC")],
        frame_method="ROWS",
        frame_start=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=None),
        frame_end=ast.WindowFrameExpr(frame_type="PRECEDING", frame_value=1),
    )
    expected = ast.SelectQuery(
        select=[
            ast.Field(chain=["person", "id"]),
            ast.Alias(alias="timestamp", expr=windowed_min),
        ],
        select_from=ast.JoinExpr(table=ast.Field(chain=["events"])),
        window_exprs={"win1": win1},
    )
    self.assertEqual(parsed, expected)
|
||||
|
||||
@@ -584,3 +584,19 @@ class TestPrinter(BaseTest):
|
||||
with self.assertRaises(HogQLException) as error_context:
|
||||
self._select("SELECT now(), toDateTime(timestamp), toDateTime('2020-02-02') FROM events", context)
|
||||
self.assertEqual(str(error_context.exception), "Unknown timezone: 'Europe/PostHogLandia'")
|
||||
|
||||
def test_window_functions(self):
    """A window function referencing a WINDOW clause prints both parts."""
    # NOTE(review): this query and expectation are identical to
    # test_window_functions_with_window; an inline `OVER (...)` variant
    # would give this test distinct coverage.
    query = "SELECT distinct_id, min(timestamp) over win1 as timestamp FROM events WINDOW win1 as (PARTITION by distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)"
    printed = self._select(query)
    expected = f"SELECT events.distinct_id, min(toTimeZone(events.timestamp, %(hogql_val_0)s)) OVER win1 AS timestamp FROM events WHERE equals(events.team_id, {self.team.pk}) WINDOW win1 AS (PARTITION BY events.distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) LIMIT 10000"
    self.assertEqual(printed, expected)
|
||||
|
||||
def test_window_functions_with_window(self):
    """A named window is printed as `OVER win1` plus a trailing WINDOW clause."""
    query = "SELECT distinct_id, min(timestamp) over win1 as timestamp FROM events WINDOW win1 as (PARTITION by distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)"
    printed = self._select(query)
    # The WINDOW clause is expected after WHERE and before LIMIT.
    expected = f"SELECT events.distinct_id, min(toTimeZone(events.timestamp, %(hogql_val_0)s)) OVER win1 AS timestamp FROM events WHERE equals(events.team_id, {self.team.pk}) WINDOW win1 AS (PARTITION BY events.distinct_id ORDER BY timestamp DESC ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING) LIMIT 10000"
    self.assertEqual(printed, expected)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from uuid import UUID
|
||||
|
||||
import pytz
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from freezegun import freeze_time
|
||||
@@ -741,3 +742,155 @@ class TestQuery(ClickhouseTestMixin, APIBaseTest):
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def test_window_functions_simple(self):
|
||||
random_uuid = str(UUIDT())
|
||||
for person in range(5):
|
||||
distinct_id = f"person_{person}_{random_uuid}"
|
||||
with freeze_time("2020-01-10 00:00:00"):
|
||||
_create_person(
|
||||
properties={"name": f"Person {person}", "random_uuid": random_uuid},
|
||||
team=self.team,
|
||||
distinct_ids=[distinct_id],
|
||||
is_identified=True,
|
||||
)
|
||||
_create_event(
|
||||
distinct_id=distinct_id,
|
||||
event="random event",
|
||||
team=self.team,
|
||||
properties={"character": "Luigi"},
|
||||
)
|
||||
flush_persons_and_events()
|
||||
with freeze_time("2020-01-10 00:10:00"):
|
||||
_create_event(
|
||||
distinct_id=distinct_id,
|
||||
event="random bla",
|
||||
team=self.team,
|
||||
properties={"character": "Luigi"},
|
||||
)
|
||||
flush_persons_and_events()
|
||||
with freeze_time("2020-01-10 00:20:00"):
|
||||
_create_event(
|
||||
distinct_id=distinct_id,
|
||||
event="random boo",
|
||||
team=self.team,
|
||||
properties={"character": "Luigi"},
|
||||
)
|
||||
flush_persons_and_events()
|
||||
|
||||
query = f"""
|
||||
select distinct_id,
|
||||
timestamp,
|
||||
event,
|
||||
groupArray(event) OVER (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING) AS two_before,
|
||||
groupArray(event) OVER (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING) AS two_after
|
||||
from events
|
||||
where timestamp > toDateTime('2020-01-09 00:00:00')
|
||||
and distinct_id like '%_{random_uuid}'
|
||||
order by distinct_id, timestamp
|
||||
"""
|
||||
response = execute_hogql_query(query, team=self.team)
|
||||
|
||||
expected = []
|
||||
for person in range(5):
|
||||
expected += [
|
||||
(
|
||||
f"person_{person}_{random_uuid}",
|
||||
datetime.datetime(2020, 1, 10, 00, 00, 00, tzinfo=pytz.UTC),
|
||||
"random event",
|
||||
[],
|
||||
["random bla", "random boo"],
|
||||
),
|
||||
(
|
||||
f"person_{person}_{random_uuid}",
|
||||
datetime.datetime(2020, 1, 10, 00, 10, 00, tzinfo=pytz.UTC),
|
||||
"random bla",
|
||||
["random event"],
|
||||
["random boo"],
|
||||
),
|
||||
(
|
||||
f"person_{person}_{random_uuid}",
|
||||
datetime.datetime(2020, 1, 10, 00, 20, 00, tzinfo=pytz.UTC),
|
||||
"random boo",
|
||||
["random event", "random bla"],
|
||||
[],
|
||||
),
|
||||
]
|
||||
self.assertEqual(response.results, expected)
|
||||
|
||||
def test_window_functions_with_window(self):
|
||||
random_uuid = str(UUIDT())
|
||||
for person in range(5):
|
||||
distinct_id = f"person_{person}_{random_uuid}"
|
||||
with freeze_time("2020-01-10 00:00:00"):
|
||||
_create_person(
|
||||
properties={"name": f"Person {person}", "random_uuid": random_uuid},
|
||||
team=self.team,
|
||||
distinct_ids=[distinct_id],
|
||||
is_identified=True,
|
||||
)
|
||||
_create_event(
|
||||
distinct_id=distinct_id,
|
||||
event="random event",
|
||||
team=self.team,
|
||||
properties={"character": "Luigi"},
|
||||
)
|
||||
flush_persons_and_events()
|
||||
with freeze_time("2020-01-10 00:10:00"):
|
||||
_create_event(
|
||||
distinct_id=distinct_id,
|
||||
event="random bla",
|
||||
team=self.team,
|
||||
properties={"character": "Luigi"},
|
||||
)
|
||||
flush_persons_and_events()
|
||||
with freeze_time("2020-01-10 00:20:00"):
|
||||
_create_event(
|
||||
distinct_id=distinct_id,
|
||||
event="random boo",
|
||||
team=self.team,
|
||||
properties={"character": "Luigi"},
|
||||
)
|
||||
flush_persons_and_events()
|
||||
|
||||
query = f"""
|
||||
select distinct_id,
|
||||
timestamp,
|
||||
event,
|
||||
groupArray(event) OVER w1 AS two_before,
|
||||
groupArray(event) OVER w2 AS two_after
|
||||
from events
|
||||
where timestamp > toDateTime('2020-01-09 00:00:00')
|
||||
and distinct_id like '%_{random_uuid}'
|
||||
window w1 as (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 2 PRECEDING AND 1 PRECEDING),
|
||||
w2 as (PARTITION BY distinct_id ORDER BY timestamp ASC ROWS BETWEEN 1 FOLLOWING AND 2 FOLLOWING)
|
||||
order by distinct_id, timestamp
|
||||
"""
|
||||
response = execute_hogql_query(query, team=self.team)
|
||||
|
||||
expected = []
|
||||
for person in range(5):
|
||||
expected += [
|
||||
(
|
||||
f"person_{person}_{random_uuid}",
|
||||
datetime.datetime(2020, 1, 10, 00, 00, 00, tzinfo=pytz.UTC),
|
||||
"random event",
|
||||
[],
|
||||
["random bla", "random boo"],
|
||||
),
|
||||
(
|
||||
f"person_{person}_{random_uuid}",
|
||||
datetime.datetime(2020, 1, 10, 00, 10, 00, tzinfo=pytz.UTC),
|
||||
"random bla",
|
||||
["random event"],
|
||||
["random boo"],
|
||||
),
|
||||
(
|
||||
f"person_{person}_{random_uuid}",
|
||||
datetime.datetime(2020, 1, 10, 00, 20, 00, tzinfo=pytz.UTC),
|
||||
"random boo",
|
||||
["random event", "random bla"],
|
||||
[],
|
||||
),
|
||||
]
|
||||
self.assertEqual(response.results, expected)
|
||||
|
||||
@@ -109,6 +109,8 @@ class TraversingVisitor(Visitor):
|
||||
self.visit(expr)
|
||||
self.visit(node.limit),
|
||||
self.visit(node.offset),
|
||||
for expr in (node.window_exprs or {}).values():
|
||||
self.visit(expr)
|
||||
|
||||
def visit_select_union_query(self, node: ast.SelectUnionQuery):
|
||||
for expr in node.select_queries:
|
||||
@@ -199,6 +201,22 @@ class TraversingVisitor(Visitor):
|
||||
def visit_property_type(self, node: ast.PropertyType):
|
||||
self.visit(node.field_type)
|
||||
|
||||
def visit_window_expr(self, node: ast.WindowExpr):
    """Traverse a window specification: partition and order expressions, then both frame bounds."""
    for clause_name in ("partition_by", "order_by"):
        # A clause that is absent (None) contributes nothing to visit.
        for child in getattr(node, clause_name) or []:
            self.visit(child)
    # Bounds are passed to visit() even when they are None.
    for bound in (node.frame_start, node.frame_end):
        self.visit(bound)
|
||||
|
||||
def visit_window_function(self, node: ast.WindowFunction):
    """Traverse a window function: each argument first, then its inline window."""
    arguments = node.args or []
    for argument in arguments:
        self.visit(argument)
    # The inline window is passed to visit() even when it is None.
    self.visit(node.over_expr)
|
||||
|
||||
def visit_window_frame_expr(self, node: ast.WindowFrameExpr):
    """Do nothing: a frame bound is not traversed any further."""
    return None
|
||||
|
||||
|
||||
class CloningVisitor(Visitor):
|
||||
"""Visitor that traverses and clones the AST tree. Clears types."""
|
||||
@@ -339,6 +357,9 @@ class CloningVisitor(Visitor):
|
||||
limit_with_ties=node.limit_with_ties,
|
||||
offset=self.visit(node.offset),
|
||||
distinct=node.distinct,
|
||||
window_exprs={name: self.visit(expr) for name, expr in node.window_exprs.items()}
|
||||
if node.window_exprs
|
||||
else None,
|
||||
)
|
||||
|
||||
def visit_select_union_query(self, node: ast.SelectUnionQuery):
|
||||
@@ -346,3 +367,29 @@ class CloningVisitor(Visitor):
|
||||
type=None if self.clear_types else node.type,
|
||||
select_queries=[self.visit(expr) for expr in node.select_queries],
|
||||
)
|
||||
|
||||
def visit_window_expr(self, node: ast.WindowExpr):
    """Clone a window specification, visiting its expressions and frame bounds.

    The type is dropped when `clear_types` is set. Empty or missing
    PARTITION BY / ORDER BY lists are cloned as None.
    """
    cloned_type = None if self.clear_types else node.type
    cloned_partition = [self.visit(expr) for expr in node.partition_by] if node.partition_by else None
    cloned_order = [self.visit(expr) for expr in node.order_by] if node.order_by else None
    cloned_start = self.visit(node.frame_start)
    cloned_end = self.visit(node.frame_end)
    return ast.WindowExpr(
        type=cloned_type,
        partition_by=cloned_partition,
        order_by=cloned_order,
        frame_method=node.frame_method,
        frame_start=cloned_start,
        frame_end=cloned_end,
    )
|
||||
|
||||
def visit_window_function(self, node: ast.WindowFunction):
    """Clone a window function call, visiting its arguments and inline window.

    The type is dropped when `clear_types` is set. The name and the
    referenced window identifier are copied as-is.
    """
    # Compare against None rather than truthiness: a zero-argument call
    # carries an empty list, and collapsing that to None would make the
    # clone differ from the node it was cloned from.
    cloned_args = [self.visit(expr) for expr in node.args] if node.args is not None else None
    return ast.WindowFunction(
        type=None if self.clear_types else node.type,
        name=node.name,
        args=cloned_args,
        over_expr=self.visit(node.over_expr) if node.over_expr else None,
        over_identifier=node.over_identifier,
    )
|
||||
|
||||
def visit_window_frame_expr(self, node: ast.WindowFrameExpr):
    """Clone a frame bound, copying its kind and value; the type is dropped when `clear_types` is set."""
    cloned_type = None if self.clear_types else node.type
    return ast.WindowFrameExpr(type=cloned_type, frame_type=node.frame_type, frame_value=node.frame_value)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# - `pip-compile --rebuild requirements.in`
|
||||
# - `pip-compile --rebuild requirements-dev.in`
|
||||
#
|
||||
antlr4-python3-runtime==4.11.1
|
||||
antlr4-python3-runtime==4.13.0
|
||||
amqp==2.6.0
|
||||
boto3==1.26.66
|
||||
celery==4.4.7
|
||||
|
||||
@@ -12,7 +12,7 @@ amqp==2.6.0
|
||||
# via
|
||||
# -r requirements.in
|
||||
# kombu
|
||||
antlr4-python3-runtime==4.11.1
|
||||
antlr4-python3-runtime==4.13.0
|
||||
# via -r requirements.in
|
||||
asgiref==3.3.2
|
||||
# via django
|
||||
|
||||
Reference in New Issue
Block a user