diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md
index 725a47dca..7adf224d3 100644
--- a/pkg-py/CHANGELOG.md
+++ b/pkg-py/CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### New features
+* Added `PolarsLazySource` to support Polars LazyFrames as data sources. Data stays lazy until the render boundary, enabling efficient handling of large datasets. Pass a `polars.LazyFrame` directly to `QueryChat()` and queries will be executed lazily via Polars' SQLContext.
+
* Added support for Gradio, Dash, and Streamlit web frameworks in addition to Shiny. Import from the new submodules:
* `from querychat.gradio import QueryChat`
* `from querychat.dash import QueryChat`
diff --git a/pkg-py/docs/data-sources.qmd b/pkg-py/docs/data-sources.qmd
index a9a7a5446..19f5aed08 100644
--- a/pkg-py/docs/data-sources.qmd
+++ b/pkg-py/docs/data-sources.qmd
@@ -3,11 +3,12 @@ title: Data Sources
lightbox: true
---
-`querychat` supports several different data sources, including:
+`querychat` supports several different data sources, including:
1. Any [narwhals-compatible](https://narwhals-dev.github.io/narwhals/) data frame.
-2. Any [SQLAlchemy](https://www.sqlalchemy.org/) database.
-3. A custom [DataSource](reference/types.DataSource.qmd) interface/protocol.
+2. Any [Polars LazyFrame](https://docs.pola.rs/user-guide/lazy/using/), for efficient handling of large datasets.
+3. Any [SQLAlchemy](https://www.sqlalchemy.org/) database.
+4. A custom [DataSource](reference/types.DataSource.qmd) interface/protocol.
The sections below describe how to use each type of data source with `querychat`.
@@ -63,7 +64,68 @@ app = qc.app()
:::
-If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `narwhals.DataFrame`. Call `.to_native()` on the result to get the underlying pandas or polars DataFrame.
+If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `narwhals.DataFrame` (or `narwhals.LazyFrame` for lazy sources). Call `.to_native()` on the result to get the underlying pandas or polars object.
+
+## Polars LazyFrames {#lazy-frames}
+
+For large datasets, you can use [Polars LazyFrames](https://docs.pola.rs/user-guide/lazy/using/) to keep data on disk until it's actually needed. This is particularly useful when:
+
+- Your dataset is too large to fit comfortably in memory
+- You only need filtered or aggregated subsets of the data
+- You want faster startup times for your application
+
+With lazy evaluation, data stays on disk and queries are optimized by Polars before execution. Only the final results are loaded into memory.
+
+```{.python filename="lazy-app.py"}
+import polars as pl
+from querychat import QueryChat
+
+# Scan a large parquet file (doesn't load data yet!)
+lf = pl.scan_parquet("large_dataset.parquet")
+
+# Pass the LazyFrame directly to QueryChat
+qc = QueryChat(lf, "sales")
+app = qc.app()
+```
+
+::: {.callout-tip}
+### Why use lazy evaluation?
+
+The lazy approach can be significantly faster for large datasets because:
+
+- **Deferred loading**: Data stays on disk until actually needed, so startup is nearly instant
+- **Query optimization**: Polars optimizes the query plan before execution, potentially skipping unnecessary columns and rows
+- **Reduced memory**: Only the filtered/aggregated results are loaded into memory, not the entire dataset
+
+This is especially beneficial when users typically query small subsets of a large dataset.
+:::
+
+You can create LazyFrames from various sources:
+
+```python
+# From parquet (most efficient)
+lf = pl.scan_parquet("data.parquet")
+
+# From CSV
+lf = pl.scan_csv("data.csv")
+
+# From multiple files
+lf = pl.scan_parquet("data/*.parquet")
+
+# From an existing DataFrame
+df = pl.read_csv("data.csv")
+lf = df.lazy()
+```
+
+When using a LazyFrame source, the `df()` method returns a `narwhals.LazyFrame`. Call `.collect()` to materialize the results when needed:
+
+```python
+# Get the lazy result
+result_lazy = qc.df()
+
+# Materialize when ready
+result_df = result_lazy.collect()
+```
## Databases
diff --git a/pkg-py/src/querychat/_dash.py b/pkg-py/src/querychat/_dash.py
index 226fa4bed..fbad259bc 100644
--- a/pkg-py/src/querychat/_dash.py
+++ b/pkg-py/src/querychat/_dash.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Literal, Optional, cast
+import narwhals.stable.v1 as nw
from chatlas import Turn
from ._dash_ui import IDs, card_ui, chat_container_ui, chat_messages_ui
@@ -374,6 +375,9 @@ def update_display(state_data: AppStateDict, reset_clicks):
sql_code = f"```sql\n{state.get_display_sql()}\n```"
df = state.get_current_data()
+ # Collect if lazy before accessing .to_pandas() or .shape
+ if isinstance(df, nw.LazyFrame):
+ df = df.collect()
display_df = df.to_pandas()
table_data = display_df.to_dict("records")
@@ -404,8 +408,11 @@ def update_display(state_data: AppStateDict, reset_clicks):
)
def export_csv(n_clicks: int, state_data: AppStateDict):
state = deserialize_state(state_data)
- df = state.get_current_data().to_pandas()
- return send_data_frame(df.to_csv, "querychat_data.csv", index=False)
+ df = state.get_current_data()
+ # Collect if lazy before converting to pandas
+ if isinstance(df, nw.LazyFrame):
+ df = df.collect()
+ return send_data_frame(df.to_pandas().to_csv, "querychat_data.csv", index=False)
def register_chat_callbacks(
diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py
index 563775026..65ae90a5f 100644
--- a/pkg-py/src/querychat/_datasource.py
+++ b/pkg-py/src/querychat/_datasource.py
@@ -1,7 +1,8 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Literal, Union
import duckdb
import narwhals.stable.v1 as nw
@@ -12,13 +13,39 @@
from ._utils import check_query
if TYPE_CHECKING:
+ import polars as pl
from sqlalchemy.engine import Connection, Engine
+DataOrLazyFrame = Union[nw.DataFrame, nw.LazyFrame]
+
class MissingColumnsError(ValueError):
"""Raised when a query result is missing required columns."""
+@dataclass
+class ColumnMeta:
+ """Metadata for a single column in a schema."""
+
+ name: str
+ """Column name."""
+
+ sql_type: str
+ """SQL type name (e.g., 'INTEGER', 'TEXT', 'DATE')."""
+
+ kind: Literal["numeric", "text", "date", "other"]
+ """Column category for determining what stats to collect."""
+
+ min_val: Any = None
+ """Minimum value for numeric/date columns."""
+
+ max_val: Any = None
+ """Maximum value for numeric/date columns."""
+
+ categories: list[str] = field(default_factory=list)
+ """Unique values for text columns below the categorical threshold."""
+
+
class DataSource(ABC):
"""
An abstract class defining the interface for data sources used by QueryChat.
@@ -58,7 +85,7 @@ def get_schema(self, *, categorical_threshold: int) -> str:
...
@abstractmethod
- def execute_query(self, query: str) -> nw.DataFrame:
+ def execute_query(self, query: str) -> DataOrLazyFrame:
"""
Execute SQL query and return results as DataFrame.
@@ -104,7 +131,7 @@ def test_query(
...
@abstractmethod
- def get_data(self) -> nw.DataFrame:
+ def get_data(self) -> DataOrLazyFrame:
"""
Return the unfiltered data as a DataFrame.
@@ -143,12 +170,12 @@ def __init__(self, df: nw.DataFrame, table_name: str):
Parameters
----------
df
- The DataFrame to wrap (pandas, polars, or any narwhals-compatible frame)
+ A narwhals DataFrame
table_name
Name of the table in SQL queries
"""
- self._df = nw.from_native(df) if not isinstance(df, nw.DataFrame) else df
+ self._df = df
self.table_name = table_name
self._conn = duckdb.connect(database=":memory:")
@@ -635,3 +662,229 @@ def cleanup(self) -> None:
"""
if self._engine:
self._engine.dispose()
+
+
+class PolarsLazySource(DataSource):
+ """
+ A DataSource implementation for Polars LazyFrames.
+
+ Keeps data lazy throughout the query pipeline. Results from execute_query()
+ are LazyFrames that can be chained with additional operations before
+ collecting.
+ """
+
+ table_name: str
+
+ def __init__(self, lf: nw.LazyFrame, table_name: str):
+ """
+ Initialize with a narwhals LazyFrame wrapping a Polars LazyFrame.
+
+ Parameters
+ ----------
+ lf
+ A narwhals LazyFrame (wrapping a Polars LazyFrame)
+ table_name
+ Name of the table in SQL queries
+
+ """
+ import polars as pl
+
+ self.table_name = table_name
+
+ # Get native Polars LazyFrame for SQLContext
+ self._lf: pl.LazyFrame = lf.to_native()
+ if not isinstance(self._lf, pl.LazyFrame):
+ raise TypeError(f"Expected Polars LazyFrame, got {type(self._lf).__name__}")
+
+ self._ctx = pl.SQLContext({table_name: self._lf})
+
+ # Cache schema (no data collection needed)
+ self._schema = self._lf.collect_schema()
+ self._colnames = list(self._schema.keys())
+
+ def get_db_type(self) -> str:
+ """Get the database type."""
+ return "Polars"
+
+ def get_schema(self, *, categorical_threshold: int) -> str:
+ """Generate schema information from LazyFrame using lazy aggregates."""
+ # Build column metadata (classification happens here)
+ columns = [
+ self._make_column_meta(name, dtype) for name, dtype in self._schema.items()
+ ]
+
+ # Add stats to the metadata and format schema string
+ self._add_column_stats(columns, self._lf, categorical_threshold)
+ return self._format_schema(self.table_name, columns)
+
+ def execute_query(self, query: str) -> nw.LazyFrame:
+ """
+ Execute SQL query and return results as LazyFrame.
+
+ Parameters
+ ----------
+ query
+ SQL query to execute
+
+ Returns
+ -------
+ :
+ Query results as a narwhals LazyFrame
+
+ """
+ check_query(query)
+ result = self._ctx.execute(query)
+ return nw.from_native(result)
+
+ def test_query(
+ self, query: str, *, require_all_columns: bool = False
+ ) -> nw.DataFrame:
+ """
+ Test SQL query validity by executing and collecting one row.
+
+ Parameters
+ ----------
+ query
+ SQL query to test
+ require_all_columns
+ If True, validates that result includes all original table columns
+
+ Returns
+ -------
+ :
+ Query results as a narwhals DataFrame with at most one row
+
+ """
+ check_query(query)
+
+ test_sql = f"SELECT * FROM ({query}) AS subquery LIMIT 1"
+ lf = self._ctx.execute(test_sql)
+
+ # Actually collect to catch runtime errors (e.g., division by zero)
+ result = nw.from_native(lf.collect())
+
+ if require_all_columns:
+ result_columns = set(result.columns)
+ missing = set(self._colnames) - result_columns
+ if missing:
+ missing_list = ", ".join(f"'{c}'" for c in sorted(missing))
+ original_list = ", ".join(f"'{c}'" for c in self._colnames)
+ raise MissingColumnsError(
+ f"Query result missing required columns: {missing_list}. "
+ f"The query must return all original table columns. "
+ f"Original columns: {original_list}"
+ )
+
+ return result
+
+ def get_data(self) -> nw.LazyFrame:
+ """
+ Return the unfiltered data as a LazyFrame.
+
+ Returns
+ -------
+ :
+ The original LazyFrame
+
+ """
+ return nw.from_native(self._lf)
+
+ def cleanup(self) -> None:
+ """Clean up resources (no-op for Polars)."""
+
+ @staticmethod
+ def _make_column_meta(name: str, dtype: pl.DataType) -> ColumnMeta:
+ import polars as pl
+
+ if dtype.is_numeric():
+ kind = "numeric"
+ sql_type = "INTEGER" if dtype.is_integer() else "FLOAT"
+ elif dtype == pl.String:
+ kind = "text"
+ sql_type = "TEXT"
+ elif dtype == pl.Date:
+ kind = "date"
+ sql_type = "DATE"
+ elif dtype == pl.Datetime:
+ kind = "date"
+ sql_type = "TIMESTAMP"
+ elif dtype == pl.Boolean:
+ kind = "other"
+ sql_type = "BOOLEAN"
+ elif dtype == pl.Time:
+ kind = "other"
+ sql_type = "TIME"
+ else:
+ kind = "other"
+ sql_type = "TEXT"
+
+ return ColumnMeta(name=name, sql_type=sql_type, kind=kind)
+
+ @staticmethod
+ def _add_column_stats(
+ columns: list[ColumnMeta],
+ lf: pl.LazyFrame,
+ categorical_threshold: int,
+ ) -> None:
+ import polars as pl
+
+ # Build aggregation expressions based on column kinds
+ agg_exprs: list[pl.Expr] = []
+ for col in columns:
+ if col.kind in ("numeric", "date"):
+ agg_exprs.append(pl.col(col.name).min().alias(f"{col.name}__min"))
+ agg_exprs.append(pl.col(col.name).max().alias(f"{col.name}__max"))
+ elif col.kind == "text":
+ agg_exprs.append(
+ pl.col(col.name).n_unique().alias(f"{col.name}__nunique")
+ )
+
+ if not agg_exprs:
+ return
+
+ # First scan: collect all aggregate statistics
+ stats = lf.select(agg_exprs).collect().row(0, named=True)
+
+ # Add min/max for numeric/date columns
+ for col in columns:
+ if col.kind in ("numeric", "date"):
+ col.min_val = stats.get(f"{col.name}__min")
+ col.max_val = stats.get(f"{col.name}__max")
+
+ # Find text columns that qualify as categorical
+ categorical_cols = [
+ col
+ for col in columns
+ if col.kind == "text"
+ and (nunique := stats.get(f"{col.name}__nunique"))
+ and nunique <= categorical_threshold
+ ]
+
+ if not categorical_cols:
+ return
+
+ # Second scan: batch collect unique values for all categorical columns
+ unique_exprs = [
+ pl.col(col.name).drop_nulls().unique().implode().alias(col.name)
+ for col in categorical_cols
+ ]
+ unique_row = lf.select(unique_exprs).collect().row(0, named=True)
+
+ for col in categorical_cols:
+ col.categories = unique_row[col.name]
+
+ @staticmethod
+ def _format_schema(table_name: str, columns: list[ColumnMeta]) -> str:
+ """Format column metadata into schema string."""
+ lines = [f"Table: {table_name}", "Columns:"]
+
+ for col in columns:
+ lines.append(f"- {col.name} ({col.sql_type})")
+
+ if col.kind in ("numeric", "date"):
+ lines.append(f" Range: {col.min_val} to {col.max_val}")
+ elif col.categories:
+ cats = ", ".join(f"'{v}'" for v in col.categories)
+ lines.append(f" Categorical values: {cats}")
+
+ return "\n".join(lines)
diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py
index c446388bc..a6d874101 100644
--- a/pkg-py/src/querychat/_gradio.py
+++ b/pkg-py/src/querychat/_gradio.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Optional
+import narwhals.stable.v1 as nw
from gradio.context import Context
from ._querychat_base import TOOL_GROUPS, QueryChatBase
@@ -294,6 +295,9 @@ def update_displays(state_dict: AppStateDict):
)
df = self.df(state_dict)
+ # Collect if lazy before accessing .shape
+ if isinstance(df, nw.LazyFrame):
+ df = df.collect()
data_info_parts = []
if error:
diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py
index 65995cbab..fd0d8fad6 100644
--- a/pkg-py/src/querychat/_querychat_base.py
+++ b/pkg-py/src/querychat/_querychat_base.py
@@ -12,7 +12,12 @@
import narwhals.stable.v1 as nw
import sqlalchemy
-from ._datasource import DataFrameSource, DataSource, SQLAlchemySource
+from ._datasource import (
+ DataFrameSource,
+ DataSource,
+ PolarsLazySource,
+ SQLAlchemySource,
+)
from ._shiny_module import GREETING_PROMPT
from ._system_prompt import QueryChatSystemPrompt
from ._utils import MISSING, MISSING_TYPE
@@ -174,13 +179,28 @@ def normalize_data_source(
return data_source
if isinstance(data_source, sqlalchemy.Engine):
return SQLAlchemySource(data_source, table_name)
+
src = nw.from_native(data_source, pass_through=True)
+
if isinstance(src, nw.DataFrame):
return DataFrameSource(src, table_name)
+
if isinstance(src, nw.LazyFrame):
- raise NotImplementedError(
- "LazyFrame data sources are not yet supported (they will be soon)."
+ native = src.to_native()
+ try:
+ import polars as pl
+
+ if isinstance(native, pl.LazyFrame):
+ return PolarsLazySource(src, table_name)
+ except ImportError:
+ pass
+ raise TypeError(
+ f"Unsupported LazyFrame backend: {type(native).__name__} from {type(native).__module__}. "
+ "Currently only Polars LazyFrames are supported. "
+ "If you believe this type should be supported, please open an issue at "
+ "https://github.com/posit-dev/querychat/issues"
)
+
raise TypeError(
f"Unsupported data source type: {type(data_source)}. "
"If you believe this type should be supported, please open an issue at "
diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py
index 35945f4e2..c2bb4c82b 100644
--- a/pkg-py/src/querychat/_querychat_core.py
+++ b/pkg-py/src/querychat/_querychat_core.py
@@ -31,9 +31,7 @@
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator
- import narwhals.stable.v1 as nw
-
- from ._datasource import DataSource
+ from ._datasource import DataOrLazyFrame, DataSource
ClientFactory = Callable[
@@ -72,7 +70,7 @@ def _client_factory(
"""Create a chat client with dashboard callbacks."""
return self.client(update_dashboard=update_cb, reset_dashboard=reset_cb) # type: ignore[attr-defined]
- def df(self, state: AppStateDict | None) -> nw.DataFrame:
+ def df(self, state: AppStateDict | None) -> DataOrLazyFrame:
"""
Get the current DataFrame from state.
@@ -85,6 +83,7 @@ def df(self, state: AppStateDict | None) -> nw.DataFrame:
-------
:
The filtered data if a SQL query is active, otherwise the full dataset.
+ Returns a LazyFrame if the data source is lazy.
"""
sql = state.get("sql") if state else None
@@ -210,7 +209,7 @@ def reset_dashboard(self) -> None:
self.title = None
self.error = None
- def get_current_data(self) -> nw.DataFrame:
+ def get_current_data(self) -> DataOrLazyFrame:
"""Get current data, falling back to default if query fails."""
if self.sql:
try:
diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py
index 2886b74e7..a8cbd7f2b 100644
--- a/pkg-py/src/querychat/_shiny.py
+++ b/pkg-py/src/querychat/_shiny.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Literal, Optional, overload
+import narwhals.stable.v1 as nw
from shiny.express._stub_session import ExpressStubSession
from shiny.session import get_current_session
from shinychat import output_markdown_stream
@@ -16,10 +17,11 @@
from pathlib import Path
import chatlas
- import narwhals.stable.v1 as nw
import sqlalchemy
from narwhals.stable.v1.typing import IntoFrame
+ from ._datasource import DataOrLazyFrame
+
class QueryChat(QueryChatBase):
"""
@@ -239,7 +241,11 @@ def _():
@render.data_frame
def dt():
- return vals.df()
+ df = vals.df()
+ # Collect if lazy
+ if isinstance(df, nw.LazyFrame):
+ df = df.collect()
+ return df
@render.ui
def sql_output():
@@ -605,16 +611,16 @@ def ui(self, *, id: Optional[str] = None, **kwargs):
"""
return mod_ui(id or self.id, **kwargs)
- def df(self) -> nw.DataFrame:
+ def df(self) -> DataOrLazyFrame:
"""
Reactively read the current filtered data frame that is in effect.
Returns
-------
:
- The current filtered data frame as a narwhals DataFrame. If no query
- has been set, this will return the unfiltered data frame from the
- data source.
+ The current filtered data frame as a narwhals DataFrame or LazyFrame.
+ If the data source is lazy, returns a LazyFrame. If no query has been
+ set, this will return the unfiltered data from the data source.
"""
return self._vals.df()
diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py
index 80c798a69..7d10e4ec2 100644
--- a/pkg-py/src/querychat/_shiny_module.py
+++ b/pkg-py/src/querychat/_shiny_module.py
@@ -17,12 +17,11 @@
if TYPE_CHECKING:
from collections.abc import Callable
- import narwhals.stable.v1 as nw
from shiny.bookmark import BookmarkState, RestoreState
from shiny import Inputs, Outputs, Session
- from ._datasource import DataSource
+ from ._datasource import DataOrLazyFrame, DataSource
from .types import UpdateDashboardData
ReactiveString = reactive.Value[str]
@@ -62,9 +61,10 @@ class ServerValues:
Attributes
----------
df
- A reactive Calc that returns the current filtered data frame. If no SQL
- query has been set, this returns the unfiltered data from the data source.
- Call it like `.df()` to reactively read the current data frame.
+ A reactive Calc that returns the current filtered data frame or lazy frame.
+ If the data source is lazy, returns a LazyFrame. If no SQL query has been
+ set, this returns the unfiltered data from the data source.
+ Call it like `.df()` to reactively read the current data.
sql
A reactive Value containing the current SQL query string. Access the value
by calling `.sql()`, or set it with `.sql.set("SELECT ...")`.
@@ -81,7 +81,7 @@ class ServerValues:
"""
- df: Callable[[], nw.DataFrame]
+ df: Callable[[], DataOrLazyFrame]
sql: ReactiveStringOrNone
title: ReactiveStringOrNone
client: chatlas.Chat
diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py
index c2453bc34..a0f1ea611 100644
--- a/pkg-py/src/querychat/_streamlit.py
+++ b/pkg-py/src/querychat/_streamlit.py
@@ -4,6 +4,8 @@
from typing import TYPE_CHECKING, Optional
+import narwhals.stable.v1 as nw
+
from ._querychat_base import TOOL_GROUPS, QueryChatBase
from ._querychat_core import (
GREETING_PROMPT,
@@ -17,10 +19,11 @@
from pathlib import Path
import chatlas
- import narwhals.stable.v1 as nw
import sqlalchemy
from narwhals.stable.v1.typing import IntoFrame
+ from ._datasource import DataOrLazyFrame
+
class QueryChat(QueryChatBase):
"""
@@ -181,8 +184,8 @@ def ui(self) -> None:
st.rerun()
- def df(self) -> nw.DataFrame:
- """Get the current filtered data frame."""
+ def df(self) -> DataOrLazyFrame:
+ """Get the current filtered data frame (or LazyFrame if data source is lazy)."""
return self._get_state().get_current_data()
def sql(self) -> str | None:
@@ -235,6 +238,9 @@ def _render_main_content(self) -> None:
st.subheader("Data view")
df = state.get_current_data()
+ # Collect if lazy before accessing .shape or displaying
+ if isinstance(df, nw.LazyFrame):
+ df = df.collect()
if state.error:
st.error(state.error)
st.dataframe(df, use_container_width=True, height=400, hide_index=True)
diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py
index 64b3755c2..1eb7c2052 100644
--- a/pkg-py/src/querychat/_utils.py
+++ b/pkg-py/src/querychat/_utils.py
@@ -7,6 +7,19 @@
from typing import TYPE_CHECKING, Literal, Optional
import narwhals.stable.v1 as nw
+from great_tables import GT
+
+if TYPE_CHECKING:
+ from ._datasource import DataOrLazyFrame
+
+
+class MISSING_TYPE: # noqa: N801
+ """
+ A singleton representing a missing value.
+ """
+
+
+MISSING = MISSING_TYPE()
class UnsafeQueryError(ValueError):
@@ -81,19 +94,6 @@ def check_query(query: str) -> None:
)
-if TYPE_CHECKING:
- from narwhals.stable.v1.typing import IntoFrame
-
-
-class MISSING_TYPE: # noqa: N801
- """
- A singleton representing a missing value.
- """
-
-
-MISSING = MISSING_TYPE()
-
-
@contextmanager
def temp_env_vars(env_vars: dict[str, Optional[str]]):
"""
@@ -196,7 +196,7 @@ def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> b
return action != "reset"
-def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
+def df_to_html(df: DataOrLazyFrame, maxrows: int = 5) -> str:
"""
Convert a DataFrame to a Bootstrap-styled HTML table for display in chat.
@@ -213,38 +213,16 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
HTML string representation of the table
"""
- ndf = nw.from_native(df)
-
- if isinstance(ndf, (nw.LazyFrame, nw.DataFrame)):
- df_short = ndf.lazy().head(maxrows).collect()
- nrow_full = ndf.lazy().select(nw.len()).collect().item()
- else:
- raise TypeError(
- "Must be able to convert `df` into a Narwhals DataFrame or LazyFrame",
- )
-
- # Build simple Bootstrap-styled HTML table
- columns = df_short.columns
- rows = df_short.rows()
-
- # Table header
- header_cells = "".join(f"<th>{col}</th>" for col in columns)
- header = f"<thead><tr>{header_cells}</tr></thead>"
-
- # Table body
- body_rows = []
- for row in rows:
- cells = "".join(f"<td>{val}</td>" for val in row)
- body_rows.append(f"<tr>{cells}</tr>")
- body = f"<tbody>{''.join(body_rows)}</tbody>"
+ if isinstance(df, nw.DataFrame):
+ df = df.lazy()
- # Use Bootstrap table classes
- table_html = f'<table class="table table-striped table-sm">{header}{body}</table>'
+ df_short = df.head(maxrows).collect()
+ gt_tbl = GT(df_short.to_native())
+ table_html = gt_tbl.as_raw_html(make_page=False)
# Add note about truncated rows if needed
- if len(df_short) != nrow_full:
- rows_notice = f"\n\n*(Showing {maxrows} of {nrow_full} rows)*\n"
- else:
- rows_notice = ""
+ nrow_full = df.select(nw.len()).collect().item()
+ if nrow_full > maxrows:
+ table_html += f"\n\n*(Showing {maxrows} of {nrow_full} rows)*\n"
- return table_html + rows_notice
+ return table_html
diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py
index df344a268..a1a46b005 100644
--- a/pkg-py/src/querychat/tools.py
+++ b/pkg-py/src/querychat/tools.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any, Protocol, TypedDict, runtime_checkable
import chevron
+import narwhals.stable.v1 as nw
from chatlas import ContentToolResult, Tool
from shinychat.types import ToolResultDisplay
@@ -232,6 +233,8 @@ def query(query: str, _intent: str = "") -> ContentToolResult:
try:
result_df = data_source.execute_query(query)
+ if isinstance(result_df, nw.LazyFrame):
+ result_df = result_df.collect()
value = result_df.rows(named=True)
# Format table results
diff --git a/pkg-py/tests/test_dataframe_source.py b/pkg-py/tests/test_dataframe_source.py
index 7602583af..53f442b8f 100644
--- a/pkg-py/tests/test_dataframe_source.py
+++ b/pkg-py/tests/test_dataframe_source.py
@@ -40,11 +40,6 @@ def narwhals_df(pandas_df):
class TestDataFrameSourceInit:
"""Tests for DataFrameSource initialization."""
- def test_init_with_pandas_dataframe(self, pandas_df):
- """Test that DataFrameSource accepts a pandas DataFrame."""
- source = DataFrameSource(pandas_df, "test_table")
- assert source.table_name == "test_table"
-
def test_init_with_narwhals_dataframe(self, narwhals_df):
"""Test that DataFrameSource accepts a narwhals DataFrame."""
source = DataFrameSource(narwhals_df, "test_table")
@@ -52,37 +47,37 @@ def test_init_with_narwhals_dataframe(self, narwhals_df):
@pytest.mark.skipif(not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed")
def test_init_with_polars_dataframe(self):
- """Test that DataFrameSource accepts a polars DataFrame."""
+ """Test that DataFrameSource accepts a narwhals-wrapped polars DataFrame."""
polars_df = pl.DataFrame(
{
"id": [1, 2, 3],
"name": ["Alice", "Bob", "Charlie"],
}
)
- source = DataFrameSource(polars_df, "test_table")
+ source = DataFrameSource(nw.from_native(polars_df), "test_table")
assert source.table_name == "test_table"
class TestDataFrameSourceExecuteQuery:
"""Tests for DataFrameSource.execute_query method."""
- def test_execute_query_returns_narwhals_dataframe(self, pandas_df):
+ def test_execute_query_returns_narwhals_dataframe(self, narwhals_df):
"""Test that execute_query returns a narwhals DataFrame."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.execute_query("SELECT * FROM employees")
assert isinstance(result, nw.DataFrame)
- def test_execute_query_select_all(self, pandas_df):
+ def test_execute_query_select_all(self, narwhals_df):
"""Test SELECT * query."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.execute_query("SELECT * FROM employees")
assert result.shape == (5, 5)
assert set(result.columns) == {"id", "name", "age", "salary", "department"}
- def test_execute_query_with_filter(self, pandas_df):
+ def test_execute_query_with_filter(self, narwhals_df):
"""Test query with WHERE clause."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.execute_query(
"SELECT * FROM employees WHERE department = 'Engineering'"
)
@@ -91,9 +86,9 @@ def test_execute_query_with_filter(self, pandas_df):
departments = result["department"].unique().to_list()
assert departments == ["Engineering"]
- def test_execute_query_with_aggregation(self, pandas_df):
+ def test_execute_query_with_aggregation(self, narwhals_df):
"""Test query with aggregation."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.execute_query(
"SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department"
)
@@ -102,17 +97,17 @@ def test_execute_query_with_aggregation(self, pandas_df):
assert "department" in result.columns
assert "avg_salary" in result.columns
- def test_execute_query_select_columns(self, pandas_df):
+ def test_execute_query_select_columns(self, narwhals_df):
"""Test selecting specific columns."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.execute_query("SELECT name, age FROM employees")
assert result.shape == (5, 2)
assert list(result.columns) == ["name", "age"]
- def test_execute_query_order_by(self, pandas_df):
+ def test_execute_query_order_by(self, narwhals_df):
"""Test query with ORDER BY clause."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.execute_query(
"SELECT name, age FROM employees ORDER BY age DESC"
)
@@ -120,9 +115,9 @@ def test_execute_query_order_by(self, pandas_df):
ages = result["age"].to_list()
assert ages == sorted(ages, reverse=True)
- def test_execute_query_empty_result(self, pandas_df):
+ def test_execute_query_empty_result(self, narwhals_df):
"""Test query that returns no rows."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.execute_query(
"SELECT * FROM employees WHERE age > 100"
)
@@ -134,27 +129,27 @@ def test_execute_query_empty_result(self, pandas_df):
class TestDataFrameSourceGetData:
"""Tests for DataFrameSource.get_data method."""
- def test_get_data_returns_narwhals_dataframe(self, pandas_df):
+ def test_get_data_returns_narwhals_dataframe(self, narwhals_df):
"""Test that get_data returns a narwhals DataFrame."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.get_data()
assert isinstance(result, nw.DataFrame)
- def test_get_data_returns_full_dataset(self, pandas_df):
+ def test_get_data_returns_full_dataset(self, narwhals_df):
"""Test that get_data returns all rows."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.get_data()
- assert result.shape == pandas_df.shape
- assert set(result.columns) == set(pandas_df.columns)
+ assert result.shape == narwhals_df.shape
+ assert set(result.columns) == set(narwhals_df.columns)
- def test_get_data_preserves_data(self, pandas_df):
+ def test_get_data_preserves_data(self, narwhals_df):
"""Test that get_data preserves data values."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
result = source.get_data()
# Check that the data matches
- original_names = sorted(pandas_df["name"].tolist())
+ original_names = sorted(narwhals_df["name"].to_list())
result_names = sorted(result["name"].to_list())
assert original_names == result_names
@@ -162,25 +157,25 @@ def test_get_data_preserves_data(self, pandas_df):
class TestDataFrameSourceGetSchema:
"""Tests for DataFrameSource.get_schema method."""
- def test_get_schema_includes_table_name(self, pandas_df):
+ def test_get_schema_includes_table_name(self, narwhals_df):
"""Test that schema includes table name."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
schema = source.get_schema(categorical_threshold=10)
assert "Table: employees" in schema
assert "Columns:" in schema
- def test_get_schema_includes_all_columns(self, pandas_df):
+ def test_get_schema_includes_all_columns(self, narwhals_df):
"""Test that schema includes all columns."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
schema = source.get_schema(categorical_threshold=10)
- for col in pandas_df.columns:
+ for col in narwhals_df.columns:
assert f"- {col} (" in schema
- def test_get_schema_numeric_ranges(self, pandas_df):
+ def test_get_schema_numeric_ranges(self, narwhals_df):
"""Test that numeric columns include range information."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
schema = source.get_schema(categorical_threshold=10)
# Age should have range
@@ -188,9 +183,9 @@ def test_get_schema_numeric_ranges(self, pandas_df):
# Salary should have range
assert "Range: 50000.0 to 70000.0" in schema
- def test_get_schema_categorical_values(self, pandas_df):
+ def test_get_schema_categorical_values(self, narwhals_df):
"""Test that categorical columns show unique values."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
schema = source.get_schema(categorical_threshold=10)
# Department has only 2 unique values, should be categorical
@@ -198,9 +193,9 @@ def test_get_schema_categorical_values(self, pandas_df):
assert "'Engineering'" in schema
assert "'Sales'" in schema
- def test_get_schema_respects_threshold(self, pandas_df):
+ def test_get_schema_respects_threshold(self, narwhals_df):
"""Test that categorical_threshold is respected."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
# With threshold 1, no columns should be categorical
schema_low = source.get_schema(categorical_threshold=1)
@@ -218,18 +213,18 @@ def test_get_schema_respects_threshold(self, pandas_df):
class TestDataFrameSourceDbType:
"""Tests for DataFrameSource.get_db_type method."""
- def test_get_db_type_returns_duckdb(self, pandas_df):
+ def test_get_db_type_returns_duckdb(self, narwhals_df):
"""Test that get_db_type returns 'DuckDB'."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
assert source.get_db_type() == "DuckDB"
class TestDataFrameSourceCleanup:
"""Tests for DataFrameSource.cleanup method."""
- def test_cleanup_closes_connection(self, pandas_df):
+ def test_cleanup_closes_connection(self, narwhals_df):
"""Test that cleanup closes the DuckDB connection."""
- source = DataFrameSource(pandas_df, "employees")
+ source = DataFrameSource(narwhals_df, "employees")
# Should work before cleanup
result = source.execute_query("SELECT * FROM employees LIMIT 1")
@@ -249,13 +244,15 @@ class TestDataFrameSourceWithPolars:
@pytest.fixture
def polars_df(self):
- """Create a sample polars DataFrame."""
- return pl.DataFrame(
- {
- "id": [1, 2, 3],
- "name": ["Alice", "Bob", "Charlie"],
- "value": [10.5, 20.5, 30.5],
- }
+ """Create a sample narwhals-wrapped polars DataFrame."""
+ return nw.from_native(
+ pl.DataFrame(
+ {
+ "id": [1, 2, 3],
+ "name": ["Alice", "Bob", "Charlie"],
+ "value": [10.5, 20.5, 30.5],
+ }
+ )
)
def test_execute_query_with_polars(self, polars_df):
diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py
index 8ba00feec..913d9a424 100644
--- a/pkg-py/tests/test_datasource.py
+++ b/pkg-py/tests/test_datasource.py
@@ -2,6 +2,7 @@
import tempfile
from pathlib import Path
+import narwhals.stable.v1 as nw
import pandas as pd
import pytest
from querychat._datasource import DataFrameSource, SQLAlchemySource
@@ -339,12 +340,14 @@ def test_test_query_empty_result(test_db_engine):
def test_test_query_dataframe_source():
"""Test that test_query works with DataFrameSource."""
# Create test DataFrame
- test_df = pd.DataFrame(
- {
- "id": [1, 2, 3, 4, 5],
- "name": ["a", "b", "c", "d", "e"],
- "value": [10, 20, 30, 40, 50],
- }
+ test_df = nw.from_native(
+ pd.DataFrame(
+ {
+ "id": [1, 2, 3, 4, 5],
+ "name": ["a", "b", "c", "d", "e"],
+ "value": [10, 20, 30, 40, 50],
+ }
+ )
)
source = DataFrameSource(test_df, "test_table")
@@ -468,12 +471,14 @@ def test_check_query_escape_hatch_does_not_enable_always_blocked(monkeypatch):
def test_check_query_integrated_into_execute_query():
"""Test that check_query is integrated into execute_query()."""
- test_df = pd.DataFrame(
- {
- "id": [1, 2, 3],
- "name": ["a", "b", "c"],
- "value": [10, 20, 30],
- }
+ test_df = nw.from_native(
+ pd.DataFrame(
+ {
+ "id": [1, 2, 3],
+ "name": ["a", "b", "c"],
+ "value": [10, 20, 30],
+ }
+ )
)
source = DataFrameSource(test_df, "test_table")
diff --git a/pkg-py/tests/test_df_to_html.py b/pkg-py/tests/test_df_to_html.py
index 323146153..de5d7f693 100644
--- a/pkg-py/tests/test_df_to_html.py
+++ b/pkg-py/tests/test_df_to_html.py
@@ -2,6 +2,7 @@
import tempfile
from pathlib import Path
+import narwhals.stable.v1 as nw
import pandas as pd
import pytest
from querychat._datasource import DataFrameSource, SQLAlchemySource
@@ -11,14 +12,16 @@
@pytest.fixture
def sample_dataframe():
- """Create a sample pandas DataFrame for testing."""
- return pd.DataFrame(
- {
- "id": [1, 2, 3, 4, 5],
- "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"],
- "age": [25, 30, 35, 28, 32],
- "salary": [50000, 60000, 70000, 55000, 65000],
- },
+ """Create a sample narwhals DataFrame for testing."""
+ return nw.from_native(
+ pd.DataFrame(
+ {
+ "id": [1, 2, 3, 4, 5],
+ "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"],
+ "age": [25, 30, 35, 28, 32],
+ "salary": [50000, 60000, 70000, 55000, 65000],
+ },
+ )
)
diff --git a/pkg-py/tests/test_init_with_pandas.py b/pkg-py/tests/test_init_with_pandas.py
index 1179a25c4..2ee976794 100644
--- a/pkg-py/tests/test_init_with_pandas.py
+++ b/pkg-py/tests/test_init_with_pandas.py
@@ -65,7 +65,7 @@ def test_init_with_narwhals_dataframe():
def test_init_with_narwhals_lazyframe_raises():
- """Test that QueryChat() raises TypeError for LazyFrames."""
+ """Test that QueryChat() raises TypeError for non-Polars LazyFrames."""
pdf = pd.DataFrame(
{
"id": [1, 2, 3],
@@ -75,7 +75,8 @@ def test_init_with_narwhals_lazyframe_raises():
)
nw_lazy = nw.from_native(pdf).lazy()
- with pytest.raises(NotImplementedError, match="LazyFrame"):
+ # Non-Polars LazyFrames (e.g., pandas-backed) are not supported
+ with pytest.raises(TypeError, match="Unsupported LazyFrame backend"):
QueryChat(
data_source=nw_lazy,
table_name="test_table",
diff --git a/pkg-py/tests/test_polars_lazy_source.py b/pkg-py/tests/test_polars_lazy_source.py
new file mode 100644
index 000000000..eccc0a802
--- /dev/null
+++ b/pkg-py/tests/test_polars_lazy_source.py
@@ -0,0 +1,225 @@
+"""Tests for the PolarsLazySource class."""
+
+import narwhals.stable.v1 as nw
+import polars as pl
+import pytest
+
+
+@pytest.fixture
+def polars_lazy_df():
+ """Create a sample Polars LazyFrame."""
+ return pl.LazyFrame(
+ {
+ "id": [1, 2, 3, 4, 5],
+ "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"],
+ "age": [25, 30, 35, 28, 32],
+ "salary": [50000.0, 60000.0, 70000.0, 55000.0, 65000.0],
+ "department": ["Engineering", "Sales", "Engineering", "Sales", "Engineering"],
+ }
+ )
+
+
+class TestPolarsLazySourceInit:
+ """Tests for PolarsLazySource initialization."""
+
+ def test_init_accepts_narwhals_lazyframe(self, polars_lazy_df):
+ """Test that PolarsLazySource accepts a narwhals LazyFrame."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "test_table")
+ assert source.table_name == "test_table"
+
+ def test_get_db_type_returns_polars(self, polars_lazy_df):
+ """Test that get_db_type returns 'Polars'."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ assert source.get_db_type() == "Polars"
+
+
+class TestPolarsLazySourceExecuteQuery:
+ """Tests for PolarsLazySource.execute_query method."""
+
+ def test_execute_query_returns_narwhals_lazyframe(self, polars_lazy_df):
+ """Test that execute_query returns a narwhals LazyFrame."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ result = source.execute_query("SELECT * FROM employees")
+ assert isinstance(result, nw.LazyFrame)
+
+ def test_execute_query_select_all(self, polars_lazy_df):
+ """Test SELECT * query."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ result = source.execute_query("SELECT * FROM employees")
+
+ # Collect to verify results
+ collected = result.collect()
+ assert collected.shape == (5, 5)
+ assert set(collected.columns) == {"id", "name", "age", "salary", "department"}
+
+ def test_execute_query_with_filter(self, polars_lazy_df):
+ """Test query with WHERE clause."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ result = source.execute_query(
+ "SELECT * FROM employees WHERE department = 'Engineering'"
+ )
+
+ collected = result.collect()
+ assert collected.shape == (3, 5)
+
+ def test_execute_query_with_aggregation(self, polars_lazy_df):
+ """Test query with aggregation."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ result = source.execute_query(
+ "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department"
+ )
+
+ collected = result.collect()
+ assert collected.shape == (2, 2)
+ assert "department" in collected.columns
+ assert "avg_salary" in collected.columns
+
+
+class TestPolarsLazySourceGetData:
+ """Tests for PolarsLazySource.get_data method."""
+
+ def test_get_data_returns_narwhals_lazyframe(self, polars_lazy_df):
+ """Test that get_data returns a narwhals LazyFrame."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ result = source.get_data()
+ assert isinstance(result, nw.LazyFrame)
+
+ def test_get_data_returns_original_lazyframe(self, polars_lazy_df):
+ """Test that get_data returns the original LazyFrame."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ result = source.get_data()
+
+ # The underlying native Polars LazyFrame should be the same
+ assert result.to_native() is polars_lazy_df
+
+
+class TestPolarsLazySourceGetSchema:
+ """Tests for PolarsLazySource.get_schema method."""
+
+ def test_get_schema_includes_table_name(self, polars_lazy_df):
+ """Test that schema includes table name."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ assert "Table: employees" in schema
+ assert "Columns:" in schema
+
+ def test_get_schema_includes_all_columns(self, polars_lazy_df):
+ """Test that schema includes all columns."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ for col in ["id", "name", "age", "salary", "department"]:
+ assert f"- {col} (" in schema
+
+ def test_get_schema_numeric_ranges(self, polars_lazy_df):
+ """Test that numeric columns include range information."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ # Age should have range
+ assert "Range: 25 to 35" in schema
+ # Salary should have range
+ assert "Range: 50000.0 to 70000.0" in schema
+
+ def test_get_schema_categorical_values(self, polars_lazy_df):
+ """Test that categorical columns show unique values."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ schema = source.get_schema(categorical_threshold=10)
+
+ # Department has only 2 unique values, should be categorical
+ assert "Categorical values:" in schema
+ assert "'Engineering'" in schema
+ assert "'Sales'" in schema
+
+
+class TestPolarsLazySourceTestQuery:
+ """Tests for PolarsLazySource.test_query method."""
+
+ def test_test_query_returns_dataframe(self, polars_lazy_df):
+ """Test that test_query returns a collected DataFrame (not LazyFrame)."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ result = source.test_query("SELECT * FROM employees")
+ # test_query collects to catch runtime errors, so returns DataFrame
+ assert isinstance(result, nw.DataFrame)
+ assert len(result) <= 1
+
+ def test_test_query_require_all_columns_passes(self, polars_lazy_df):
+ """Test that test_query passes when all columns present."""
+ from querychat._datasource import PolarsLazySource
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+ # Should not raise
+ result = source.test_query(
+ "SELECT * FROM employees", require_all_columns=True
+ )
+ assert isinstance(result, nw.DataFrame)
+
+ def test_test_query_require_all_columns_fails(self, polars_lazy_df):
+ """Test that test_query raises when columns missing."""
+ from querychat._datasource import (
+ MissingColumnsError,
+ PolarsLazySource,
+ )
+
+ nw_lf = nw.from_native(polars_lazy_df)
+ source = PolarsLazySource(nw_lf, "employees")
+
+ with pytest.raises(MissingColumnsError):
+ source.test_query(
+ "SELECT name, age FROM employees", require_all_columns=True
+ )
+
+ def test_test_query_catches_runtime_errors(self):
+ """Test that test_query catches runtime errors by actually executing."""
+ from querychat._datasource import PolarsLazySource
+
+ # Create LazyFrame with string column that can't be cast to integer
+ lf = pl.LazyFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
+ nw_lf = nw.from_native(lf)
+ source = PolarsLazySource(nw_lf, "test_table")
+
+ # This query fails at runtime when trying to cast strings to integers
+ # test_query should catch this because it actually executes (collects) the query
+ with pytest.raises(pl.exceptions.InvalidOperationError):
+ source.test_query("SELECT CAST(b AS INTEGER) FROM test_table")
diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py
index 75acf6755..c5e872fd6 100644
--- a/pkg-py/tests/test_querychat.py
+++ b/pkg-py/tests/test_querychat.py
@@ -1,8 +1,11 @@
import os
+import narwhals.stable.v1 as nw
import pandas as pd
+import polars as pl
import pytest
from querychat import QueryChat
+from querychat._datasource import PolarsLazySource
@pytest.fixture(autouse=True)
@@ -87,3 +90,32 @@ def test_querychat_client_has_system_prompt(sample_df):
# (needed for methods like generate_greeting() that use _client directly)
assert qc._client.system_prompt is not None
assert "test_table" in qc._client.system_prompt
+
+
+def test_querychat_with_polars_lazyframe():
+ """Test that QueryChat accepts a Polars LazyFrame."""
+ lf = pl.LazyFrame(
+ {
+ "id": [1, 2, 3],
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ }
+ )
+
+ qc = QueryChat(
+ data_source=lf,
+ table_name="test_table",
+ greeting="Hello!",
+ )
+
+ # Should have created a PolarsLazySource
+ assert isinstance(qc.data_source, PolarsLazySource)
+
+ # Query should return a LazyFrame
+ result = qc.data_source.execute_query("SELECT * FROM test_table WHERE id = 2")
+ assert isinstance(result, nw.LazyFrame)
+
+ # Collect to verify
+ collected = result.collect()
+ assert len(collected) == 1
+ assert collected.item(0, "name") == "Bob"
diff --git a/pkg-py/tests/test_system_prompt.py b/pkg-py/tests/test_system_prompt.py
index 627bb6126..4d6919660 100644
--- a/pkg-py/tests/test_system_prompt.py
+++ b/pkg-py/tests/test_system_prompt.py
@@ -3,6 +3,7 @@
import tempfile
from pathlib import Path
+import narwhals.stable.v1 as nw
import pandas as pd
import pytest
from querychat._datasource import DataFrameSource
@@ -12,12 +13,14 @@
@pytest.fixture
def sample_data_source():
"""Create a sample DataFrameSource for testing."""
- df = pd.DataFrame(
- {
- "id": [1, 2, 3],
- "name": ["Alice", "Bob", "Charlie"],
- "age": [25, 30, 35],
- }
+ df = nw.from_native(
+ pd.DataFrame(
+ {
+ "id": [1, 2, 3],
+ "name": ["Alice", "Bob", "Charlie"],
+ "age": [25, 30, 35],
+ }
+ )
)
return DataFrameSource(df, "test_table")
diff --git a/pyproject.toml b/pyproject.toml
index d2d0f6ceb..cca646f8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,7 @@ packages = ["pkg-py/src/querychat"]
include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"]
[dependency-groups]
-dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0"]
+dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0"]
docs = ["quartodoc>=0.11.1", "nbformat", "nbclient", "ipykernel"]
examples = [
"openai",