Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python-package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ packages = ["src/querychat"]
include = ["src/querychat", "LICENSE", "README.md"]

[tool.uv]
dev-dependencies = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"]
dev-dependencies = [
"ruff>=0.6.5",
"pyright>=1.1.401",
"tox-uv>=1.11.4",
"pytest>=8.4.0",
]

[tool.ruff]
src = ["src/querychat"]
Expand Down
144 changes: 104 additions & 40 deletions python-package/src/querychat/datasource.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

from typing import ClassVar, Protocol
from typing import TYPE_CHECKING, ClassVar, Protocol

import duckdb
import narwhals as nw
import pandas as pd
from sqlalchemy import inspect, text
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.sql import sqltypes

if TYPE_CHECKING:
from sqlalchemy.engine import Connection, Engine


class DataSource(Protocol):
db_engine: ClassVar[str]
Expand Down Expand Up @@ -176,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str):
if not inspector.has_table(table_name):
raise ValueError(f"Table '{table_name}' not found in database")

def get_schema(self, *, categorical_threshold: int) -> str:
def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912
"""
Generate schema information from database table.

Expand All @@ -189,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str:

schema = [f"Table: {self._table_name}", "Columns:"]

# Build a single query to get all column statistics
select_parts = []
numeric_columns = []
text_columns = []

for col in columns:
# Get SQL type name
sql_type = self._get_sql_type_name(col["type"])
column_info = [f"- {col['name']} ({sql_type})"]
col_name = col["name"]

# For numeric columns, try to get range
# Check if column is numeric
if isinstance(
col["type"],
(
Expand All @@ -206,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str:
sqltypes.DateTime,
sqltypes.BigInteger,
sqltypes.SmallInteger,
# sqltypes.Interval,
),
):
try:
query = text(
f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}",
)
with self._get_connection() as conn:
result = conn.execute(query).fetchone()
if result and result[0] is not None and result[1] is not None:
column_info.append(f" Range: {result[0]} to {result[1]}")
except Exception:
pass # Skip range info if query fails

# For string/text columns, check if categorical
numeric_columns.append(col_name)
select_parts.extend(
[
f"MIN({col_name}) as {col_name}_min",
f"MAX({col_name}) as {col_name}_max",
],
)

# Check if column is text/string
elif isinstance(
col["type"],
(sqltypes.String, sqltypes.Text, sqltypes.Enum),
):
try:
count_query = text(
f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}",
)
text_columns.append(col_name)
select_parts.append(
f"COUNT(DISTINCT {col_name}) as {col_name}_distinct_count",
)

# Execute single query to get all statistics
column_stats = {}
if select_parts:
try:
stats_query = text(
f"SELECT {', '.join(select_parts)} FROM {self._table_name}", # noqa: S608
)
with self._get_connection() as conn:
result = conn.execute(stats_query).fetchone()
if result:
# Convert result to dict for easier access
column_stats = dict(zip(result._fields, result))
except Exception: # noqa: S110
pass # Fall back to no statistics if query fails

# Get categorical values for text columns that are below threshold
categorical_values = {}
text_cols_to_query = []
for col_name in text_columns:
distinct_count_key = f"{col_name}_distinct_count"
if (
distinct_count_key in column_stats
and column_stats[distinct_count_key]
and column_stats[distinct_count_key] <= categorical_threshold
):
text_cols_to_query.append(col_name)

# Get categorical values in a single query if needed
if text_cols_to_query:
try:
# Build UNION query for all categorical columns
union_parts = [
f"SELECT '{col_name}' as column_name, {col_name} as value " # noqa: S608
f"FROM {self._table_name} WHERE {col_name} IS NOT NULL "
f"GROUP BY {col_name}"
for col_name in text_cols_to_query
]

if union_parts:
categorical_query = text(" UNION ALL ".join(union_parts))
with self._get_connection() as conn:
distinct_count = conn.execute(count_query).scalar()
if distinct_count and distinct_count <= categorical_threshold:
values_query = text(
f"SELECT DISTINCT {col['name']} FROM {self._table_name} "
f"WHERE {col['name']} IS NOT NULL",
)
values = [
str(row[0])
for row in conn.execute(values_query).fetchall()
]
values_str = ", ".join([f"'{v}'" for v in values])
column_info.append(f" Categorical values: {values_str}")
except Exception:
pass # Skip categorical info if query fails
results = conn.execute(categorical_query).fetchall()
for row in results:
col_name, value = row
if col_name not in categorical_values:
categorical_values[col_name] = []
categorical_values[col_name].append(str(value))
except Exception: # noqa: S110
pass # Skip categorical values if query fails

# Build schema description using collected statistics
for col in columns:
col_name = col["name"]
sql_type = self._get_sql_type_name(col["type"])
column_info = [f"- {col_name} ({sql_type})"]

# Add range info for numeric columns
if col_name in numeric_columns:
min_key = f"{col_name}_min"
max_key = f"{col_name}_max"
if (
min_key in column_stats
and max_key in column_stats
and column_stats[min_key] is not None
and column_stats[max_key] is not None
):
column_info.append(
f" Range: {column_stats[min_key]} to {column_stats[max_key]}",
)

# Add categorical values for text columns
elif col_name in categorical_values:
values = categorical_values[col_name]
# Remove duplicates and sort
unique_values = sorted(set(values))
values_str = ", ".join([f"'{v}'" for v in unique_values])
column_info.append(f" Categorical values: {values_str}")

schema.extend(column_info)

Expand Down Expand Up @@ -271,9 +335,9 @@ def get_data(self) -> pd.DataFrame:
The complete dataset as a pandas DataFrame

"""
return self.execute_query(f"SELECT * FROM {self._table_name}")
return self.execute_query(f"SELECT * FROM {self._table_name}") # noqa: S608

def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str:
def _get_sql_type_name(self, type_: sqltypes.TypeEngine) -> str: # noqa: PLR0911
"""Convert SQLAlchemy type to SQL type name."""
if isinstance(type_, sqltypes.Integer):
return "INTEGER"
Expand Down
14 changes: 6 additions & 8 deletions python-package/src/querychat/querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,12 @@ def __getitem__(self, key: str) -> Any:
backwards compatibility only; new code should use the attributes
directly instead.
"""
if key == "chat":
return self.chat
elif key == "sql":
return self.sql
elif key == "title":
return self.title
elif key == "df":
return self.df
return {
"chat": self.chat,
"sql": self.sql,
"title": self.title,
"df": self.df,
}.get(key)


def system_prompt(
Expand Down
Empty file.
Loading
Loading