diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py index e074c82d..648308dd 100644 --- a/pybigquery/sqlalchemy_bigquery.py +++ b/pybigquery/sqlalchemy_bigquery.py @@ -183,12 +183,19 @@ def visit_column(self, column, add_to_result_map=None, self.preparer.quote(tablename) + \ "." + name - def visit_label(self, *args, **kwargs): - # Use labels in GROUP BY clause - if len(kwargs) == 0 or len(kwargs) == 1: + def visit_label(self, *args, within_group_by=False, **kwargs): + # Use labels in GROUP BY clause. + # + # Flag set in the group_by_clause method. Works around missing + # equivalent to supports_simple_order_by_label for group by. + if within_group_by: kwargs['render_label_as_label'] = args[0] - result = super(BigQueryCompiler, self).visit_label(*args, **kwargs) - return result + return super(BigQueryCompiler, self).visit_label(*args, **kwargs) + + def group_by_clause(self, select, **kw): + return super(BigQueryCompiler, self).group_by_clause( + select, **kw, within_group_by=True + ) class BigQueryTypeCompiler(GenericTypeCompiler): diff --git a/test/test_sqlalchemy_bigquery.py b/test/test_sqlalchemy_bigquery.py index 2ed998eb..863da132 100644 --- a/test/test_sqlalchemy_bigquery.py +++ b/test/test_sqlalchemy_bigquery.py @@ -162,10 +162,14 @@ def query(): def query(table): col1 = literal_column("TIMESTAMP_TRUNC(timestamp, DAY)").label("timestamp_label") col2 = func.sum(table.c.integer) + # Test rendering of nested labels. Full expression should render in SELECT, but + # ORDER/GROUP BY should use label only. + col3 = func.sum(func.sum(table.c.integer.label("inner")).label("outer")).over().label('outer') query = ( select([ col1, col2, + col3, ]) .where(col1 < '2017-01-01 00:00:00') .group_by(col1) @@ -297,6 +301,33 @@ def test_group_by(session, table, session_using_test_dataset, table_using_test_d assert len(result) > 0 +def test_nested_labels(engine, table): + col = table.c.integer + exprs = [ + sqlalchemy.func.sum( + sqlalchemy.func.sum(col.label("inner") + ).label("outer")).over(), + sqlalchemy.func.sum( + sqlalchemy.case([[ + sqlalchemy.literal(True), + col.label("inner"), + ]]).label("outer") + ), + sqlalchemy.func.sum( + sqlalchemy.func.sum( + sqlalchemy.case([[ + sqlalchemy.literal(True), col.label("inner") + ]]).label("middle") + ).label("outer") + ).over(), + ] + for expr in exprs: + sql = str(expr.compile(engine)) + assert "inner" not in sql + assert "middle" not in sql + assert "outer" not in sql + + def test_session_query(session, table, session_using_test_dataset, table_using_test_dataset): for session, table in [(session, table), (session_using_test_dataset, table_using_test_dataset)]: col_concat = func.concat(table.c.string).label('concat')