From 2e845991ea8693af795055d033cc71337695801e Mon Sep 17 00:00:00 2001 From: Jacob Hayes Date: Fri, 25 Sep 2020 22:54:03 -0400 Subject: [PATCH 1/3] Add failing nested label test case --- test/test_sqlalchemy_bigquery.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_sqlalchemy_bigquery.py b/test/test_sqlalchemy_bigquery.py index 831ee351..28897cf7 100644 --- a/test/test_sqlalchemy_bigquery.py +++ b/test/test_sqlalchemy_bigquery.py @@ -162,10 +162,14 @@ def query(): def query(table): col1 = literal_column("TIMESTAMP_TRUNC(timestamp, DAY)").label("timestamp_label") col2 = func.sum(table.c.integer) + # Test rendering of nested labels. Full expression should render in SELECT, but + # ORDER/GROUP BY should use label only. + col3 = func.sum(func.sum(table.c.integer.label("inner")).label("outer")).over().label('outer') query = ( select([ col1, col2, + col3, ]) .where(col1 < '2017-01-01 00:00:00') .group_by(col1) From 32acaea55307058a2122d846ebbe01da76406462 Mon Sep 17 00:00:00 2001 From: Jacob Hayes Date: Fri, 25 Sep 2020 22:59:29 -0400 Subject: [PATCH 2/3] Fix rendering of nested label expressions --- pybigquery/sqlalchemy_bigquery.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py index 74573458..b8c2deb5 100644 --- a/pybigquery/sqlalchemy_bigquery.py +++ b/pybigquery/sqlalchemy_bigquery.py @@ -183,12 +183,17 @@ def visit_column(self, column, add_to_result_map=None, self.preparer.quote(tablename) + \ "." + name - def visit_label(self, *args, **kwargs): - # Use labels in GROUP BY clause - if len(kwargs) == 0 or len(kwargs) == 1: + def visit_label(self, *args, within_group_by=False, **kwargs): + # Use labels in GROUP BY clause. + # + # Flag set in the group_by_clause method. Works around missing equivalent to + # supports_simple_order_by_label for group by. + if within_group_by: kwargs['render_label_as_label'] = args[0] - result = super(BigQueryCompiler, self).visit_label(*args, **kwargs) - return result + return super(BigQueryCompiler, self).visit_label(*args, **kwargs) + + def group_by_clause(self, select, **kw): + return super(BigQueryCompiler, self).group_by_clause(select, **kw, within_group_by=True) class BigQueryTypeCompiler(GenericTypeCompiler): From efd11fe77c46dee182ac1d0319d2c7ff8a74d6ac Mon Sep 17 00:00:00 2001 From: Tim Swast Date: Wed, 18 Nov 2020 11:57:30 -0600 Subject: [PATCH 3/3] add test cases from issue --- pybigquery/sqlalchemy_bigquery.py | 8 +++++--- test/test_sqlalchemy_bigquery.py | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py index 40c4e693..648308dd 100644 --- a/pybigquery/sqlalchemy_bigquery.py +++ b/pybigquery/sqlalchemy_bigquery.py @@ -186,14 +186,16 @@ def visit_column(self, column, add_to_result_map=None, def visit_label(self, *args, within_group_by=False, **kwargs): # Use labels in GROUP BY clause. # - # Flag set in the group_by_clause method. Works around missing equivalent to - # supports_simple_order_by_label for group by. + # Flag set in the group_by_clause method. Works around missing + # equivalent to supports_simple_order_by_label for group by. if within_group_by: kwargs['render_label_as_label'] = args[0] return super(BigQueryCompiler, self).visit_label(*args, **kwargs) def group_by_clause(self, select, **kw): - return super(BigQueryCompiler, self).group_by_clause(select, **kw, within_group_by=True) + return super(BigQueryCompiler, self).group_by_clause( + select, **kw, within_group_by=True + ) class BigQueryTypeCompiler(GenericTypeCompiler): diff --git a/test/test_sqlalchemy_bigquery.py b/test/test_sqlalchemy_bigquery.py index 82c04ff7..863da132 100644 --- a/test/test_sqlalchemy_bigquery.py +++ b/test/test_sqlalchemy_bigquery.py @@ -301,6 +301,33 @@ def test_group_by(session, table, session_using_test_dataset, table_using_test_d assert len(result) > 0 +def test_nested_labels(engine, table): + col = table.c.integer + exprs = [ + sqlalchemy.func.sum( + sqlalchemy.func.sum(col.label("inner") + ).label("outer")).over(), + sqlalchemy.func.sum( + sqlalchemy.case([[ + sqlalchemy.literal(True), + col.label("inner"), + ]]).label("outer") + ), + sqlalchemy.func.sum( + sqlalchemy.func.sum( + sqlalchemy.case([[ + sqlalchemy.literal(True), col.label("inner") + ]]).label("middle") + ).label("outer") + ).over(), + ] + for expr in exprs: + sql = str(expr.compile(engine)) + assert "inner" not in sql + assert "middle" not in sql + assert "outer" not in sql + + def test_session_query(session, table, session_using_test_dataset, table_using_test_dataset): for session, table in [(session, table), (session_using_test_dataset, table_using_test_dataset)]: col_concat = func.concat(table.c.string).label('concat')