diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index 725a47dca..7adf224d3 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features +* Added `PolarsLazySource` to support Polars LazyFrames as data sources. Data stays lazy until the render boundary, enabling efficient handling of large datasets. Pass a `polars.LazyFrame` directly to `QueryChat()` and queries will be executed lazily via Polars' SQLContext. + * Added support for Gradio, Dash, and Streamlit web frameworks in addition to Shiny. Import from the new submodules: * `from querychat.gradio import QueryChat` * `from querychat.dash import QueryChat` diff --git a/pkg-py/docs/data-sources.qmd b/pkg-py/docs/data-sources.qmd index a9a7a5446..19f5aed08 100644 --- a/pkg-py/docs/data-sources.qmd +++ b/pkg-py/docs/data-sources.qmd @@ -3,11 +3,12 @@ title: Data Sources lightbox: true --- -`querychat` supports several different data sources, including: +`querychat` supports several different data sources, including: 1. Any [narwhals-compatible](https://narwhals-dev.github.io/narwhals/) data frame. -2. Any [SQLAlchemy](https://www.sqlalchemy.org/) database. -3. A custom [DataSource](reference/types.DataSource.qmd) interface/protocol. +2. Polars LazyFrames for efficient handling of large datasets. +3. Any [SQLAlchemy](https://www.sqlalchemy.org/) database. +4. A custom [DataSource](reference/types.DataSource.qmd) interface/protocol. The sections below describe how to use each type of data source with `querychat`. @@ -63,7 +64,68 @@ app = qc.app() ::: -If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `narwhals.DataFrame`. Call `.to_native()` on the result to get the underlying pandas or polars DataFrame. +If you're [building an app](build.qmd), note you can read the queried data frame reactively using the `df()` method, which returns a `narwhals.DataFrame` (or `narwhals.LazyFrame` for lazy sources). Call `.to_native()` on the result to get the underlying pandas or polars DataFrame. + +## Polars LazyFrames {#lazy-frames} + +For large datasets, you can use [Polars LazyFrames](https://docs.pola.rs/user-guide/lazy/using/) to keep data on disk until it's actually needed. This is particularly useful when: + +- Your dataset is too large to fit comfortably in memory +- You only need filtered or aggregated subsets of the data +- You want faster startup times for your application + +With lazy evaluation, data stays on disk and queries are optimized by Polars before execution. Only the final results are loaded into memory. + +```{.python filename="lazy-app.py"} +import polars as pl +from querychat import QueryChat + +# Scan a large parquet file (doesn't load data yet!) +lf = pl.scan_parquet("large_dataset.parquet") + +# Pass the LazyFrame directly to QueryChat +qc = QueryChat(lf, "sales") +app = qc.app() +``` + +::: {.callout-tip} +### Why use lazy evaluation? + +The lazy approach can be significantly faster for large datasets because: + +- **Deferred loading**: Data stays on disk until actually needed, so startup is nearly instant +- **Query optimization**: Polars optimizes the query plan before execution, potentially skipping unnecessary columns and rows +- **Reduced memory**: Only the filtered/aggregated results are loaded into memory, not the entire dataset + +This is especially beneficial when users typically query small subsets of a large dataset. +::: + +You can create LazyFrames from various sources: + +```python +# From parquet (most efficient) +lf = pl.scan_parquet("data.parquet") + +# From CSV +lf = pl.scan_csv("data.csv") + +# From multiple files +lf = pl.scan_parquet("data/*.parquet") + +# From an existing DataFrame +df = pl.read_csv("data.csv") +lf = df.lazy() +``` + +When using a LazyFrame source, the `df()` method returns a `narwhals.LazyFrame`. Call `.collect()` to materialize the results when needed: + +```python +# Get the lazy result +result_lazy = qc.df() + +# Materialize when ready +result_df = result_lazy.collect() +``` ## Databases diff --git a/pkg-py/src/querychat/_dash.py b/pkg-py/src/querychat/_dash.py index 226fa4bed..fbad259bc 100644 --- a/pkg-py/src/querychat/_dash.py +++ b/pkg-py/src/querychat/_dash.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal, Optional, cast +import narwhals.stable.v1 as nw from chatlas import Turn from ._dash_ui import IDs, card_ui, chat_container_ui, chat_messages_ui @@ -374,6 +375,9 @@ def update_display(state_data: AppStateDict, reset_clicks): sql_code = f"```sql\n{state.get_display_sql()}\n```" df = state.get_current_data() + # Collect if lazy before accessing .to_pandas() or .shape + if isinstance(df, nw.LazyFrame): + df = df.collect() display_df = df.to_pandas() table_data = display_df.to_dict("records") @@ -404,8 +408,11 @@ def update_display(state_data: AppStateDict, reset_clicks): ) def export_csv(n_clicks: int, state_data: AppStateDict): state = deserialize_state(state_data) - df = state.get_current_data().to_pandas() - return send_data_frame(df.to_csv, "querychat_data.csv", index=False) + df = state.get_current_data() + # Collect if lazy before converting to pandas + if isinstance(df, nw.LazyFrame): + df = df.collect() + return send_data_frame(df.to_pandas().to_csv, "querychat_data.csv", index=False) def register_chat_callbacks( diff --git a/pkg-py/src/querychat/_datasource.py b/pkg-py/src/querychat/_datasource.py index 563775026..65ae90a5f 100644 --- a/pkg-py/src/querychat/_datasource.py +++ b/pkg-py/src/querychat/_datasource.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Union import duckdb import narwhals.stable.v1 as nw @@ -12,13 +13,39 @@ from ._utils import check_query if TYPE_CHECKING: + import polars as pl from sqlalchemy.engine import Connection, Engine +DataOrLazyFrame = Union[nw.DataFrame, nw.LazyFrame] + class MissingColumnsError(ValueError): """Raised when a query result is missing required columns.""" +@dataclass +class ColumnMeta: + """Metadata for a single column in a schema.""" + + name: str + """Column name.""" + + sql_type: str + """SQL type name (e.g., 'INTEGER', 'TEXT', 'DATE').""" + + kind: Literal["numeric", "text", "date", "other"] + """Column category for determining what stats to collect.""" + + min_val: Any = None + """Minimum value for numeric/date columns.""" + + max_val: Any = None + """Maximum value for numeric/date columns.""" + + categories: list[str] = field(default_factory=list) + """Unique values for text columns below the categorical threshold.""" + + class DataSource(ABC): """ An abstract class defining the interface for data sources used by QueryChat. @@ -58,7 +85,7 @@ def get_schema(self, *, categorical_threshold: int) -> str: ... @abstractmethod - def execute_query(self, query: str) -> nw.DataFrame: + def execute_query(self, query: str) -> DataOrLazyFrame: """ Execute SQL query and return results as DataFrame. @@ -104,7 +131,7 @@ def test_query( ... @abstractmethod - def get_data(self) -> nw.DataFrame: + def get_data(self) -> DataOrLazyFrame: """ Return the unfiltered data as a DataFrame. @@ -143,12 +170,12 @@ def __init__(self, df: nw.DataFrame, table_name: str): Parameters ---------- df - The DataFrame to wrap (pandas, polars, or any narwhals-compatible frame) + A narwhals DataFrame table_name Name of the table in SQL queries """ - self._df = nw.from_native(df) if not isinstance(df, nw.DataFrame) else df + self._df = df self.table_name = table_name self._conn = duckdb.connect(database=":memory:") @@ -635,3 +662,229 @@ def cleanup(self) -> None: """ if self._engine: self._engine.dispose() + + +class PolarsLazySource(DataSource): + """ + A DataSource implementation for Polars LazyFrames. + + Keeps data lazy throughout the query pipeline. Results from execute_query() + are LazyFrames that can be chained with additional operations before + collecting. + """ + + table_name: str + + def __init__(self, lf: nw.LazyFrame, table_name: str): + """ + Initialize with a narwhals LazyFrame wrapping a Polars LazyFrame. + + Parameters + ---------- + lf + A narwhals LazyFrame (wrapping a Polars LazyFrame) + table_name + Name of the table in SQL queries + + """ + import polars as pl + + self.table_name = table_name + + # Get native Polars LazyFrame for SQLContext + self._lf: pl.LazyFrame = lf.to_native() + if not isinstance(self._lf, pl.LazyFrame): + raise TypeError(f"Expected Polars LazyFrame, got {type(self._lf).__name__}") + + self._ctx = pl.SQLContext({table_name: self._lf}) + + # Cache schema (no data collection needed) + self._schema = self._lf.collect_schema() + self._colnames = list(self._schema.keys()) + + def get_db_type(self) -> str: + """Get the database type.""" + return "Polars" + + def get_schema(self, *, categorical_threshold: int) -> str: + """Generate schema information from LazyFrame using lazy aggregates.""" + # Build column metadata (classification happens here) + columns = [ + self._make_column_meta(name, dtype) for name, dtype in self._schema.items() + ] + + # Add stats to the metadata and format schema string + self._add_column_stats(columns, self._lf, categorical_threshold) + return self._format_schema(self.table_name, columns) + + def execute_query(self, query: str) -> nw.LazyFrame: + """ + Execute SQL query and return results as LazyFrame. + + Parameters + ---------- + query + SQL query to execute + + Returns + ------- + : + Query results as a narwhals LazyFrame + + """ + check_query(query) + result = self._ctx.execute(query) + return nw.from_native(result) + + def test_query( + self, query: str, *, require_all_columns: bool = False + ) -> nw.DataFrame: + """ + Test SQL query validity by executing and collecting one row. + + Parameters + ---------- + query + SQL query to test + require_all_columns + If True, validates that result includes all original table columns + + Returns + ------- + : + Query results as a narwhals DataFrame with at most one row + + """ + check_query(query) + + test_sql = f"SELECT * FROM ({query}) AS subquery LIMIT 1" + lf = self._ctx.execute(test_sql) + + # Actually collect to catch runtime errors (e.g., division by zero) + result = nw.from_native(lf.collect()) + + if require_all_columns: + result_columns = set(result.columns) + missing = set(self._colnames) - result_columns + if missing: + missing_list = ", ".join(f"'{c}'" for c in sorted(missing)) + original_list = ", ".join(f"'{c}'" for c in self._colnames) + raise MissingColumnsError( + f"Query result missing required columns: {missing_list}. " + f"The query must return all original table columns. " + f"Original columns: {original_list}" + ) + + return result + + def get_data(self) -> nw.LazyFrame: + """ + Return the unfiltered data as a LazyFrame. + + Returns + ------- + : + The original LazyFrame + + """ + return nw.from_native(self._lf) + + def cleanup(self) -> None: + """Clean up resources (no-op for Polars).""" + + @staticmethod + def _make_column_meta(name: str, dtype: pl.DataType) -> ColumnMeta: + import polars as pl + + if dtype.is_numeric(): + kind = "numeric" + sql_type = "INTEGER" if dtype.is_integer() else "FLOAT" + elif dtype == pl.String: + kind = "text" + sql_type = "TEXT" + elif dtype == pl.Date: + kind = "date" + sql_type = "DATE" + elif dtype == pl.Datetime: + kind = "date" + sql_type = "TIMESTAMP" + elif dtype == pl.Boolean: + kind = "other" + sql_type = "BOOLEAN" + elif dtype == pl.Time: + kind = "other" + sql_type = "TIME" + else: + kind = "other" + sql_type = "TEXT" + + return ColumnMeta(name=name, sql_type=sql_type, kind=kind) + + @staticmethod + def _add_column_stats( + columns: list[ColumnMeta], + lf: pl.LazyFrame, + categorical_threshold: int, + ) -> None: + import polars as pl + + # Build aggregation expressions based on column kinds + agg_exprs: list[pl.Expr] = [] + for col in columns: + if col.kind in ("numeric", "date"): + agg_exprs.append(pl.col(col.name).min().alias(f"{col.name}__min")) + agg_exprs.append(pl.col(col.name).max().alias(f"{col.name}__max")) + elif col.kind == "text": + agg_exprs.append( + pl.col(col.name).n_unique().alias(f"{col.name}__nunique") + ) + + if not agg_exprs: + return + + # First scan: collect all aggregate statistics + stats = lf.select(agg_exprs).collect().row(0, named=True) + + # Add min/max for numeric/date columns + for col in columns: + if col.kind in ("numeric", "date"): + col.min_val = stats.get(f"{col.name}__min") + col.max_val = stats.get(f"{col.name}__max") + + # Find text columns that qualify as categorical + categorical_cols = [ + col + for col in columns + if col.kind == "text" + and (nunique := stats.get(f"{col.name}__nunique")) + and nunique <= categorical_threshold + ] + + if not categorical_cols: + return + + # Second scan: batch collect unique values for all categorical columns + unique_exprs = [ + pl.col(col.name).drop_nulls().unique().implode().alias(col.name) + for col in categorical_cols + ] + unique_row = lf.select(unique_exprs).collect().row(0, named=True) + + for col in categorical_cols: + col.categories = unique_row[col.name] + + @staticmethod + def _format_schema(table_name: str, columns: list[ColumnMeta]) -> str: + """Format column metadata into schema string.""" + lines = [f"Table: {table_name}", "Columns:"] + + for col in columns: + lines.append(f"- {col.name} ({col.sql_type})") + + if col.kind in ("numeric", "date"): + lines.append(f" Range: {col.min_val} to {col.max_val}") + elif col.categories: + cats = ", ".join(f"'{v}'" for v in col.categories) + lines.append(f" Categorical values: {cats}") + + return "\n".join(lines) diff --git a/pkg-py/src/querychat/_gradio.py b/pkg-py/src/querychat/_gradio.py index c446388bc..a6d874101 100644 --- a/pkg-py/src/querychat/_gradio.py +++ b/pkg-py/src/querychat/_gradio.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional +import narwhals.stable.v1 as nw from gradio.context import Context from ._querychat_base import TOOL_GROUPS, QueryChatBase @@ -294,6 +295,9 @@ def update_displays(state_dict: AppStateDict): ) df = self.df(state_dict) + # Collect if lazy before accessing .shape + if isinstance(df, nw.LazyFrame): + df = df.collect() data_info_parts = [] if error: diff --git a/pkg-py/src/querychat/_querychat_base.py b/pkg-py/src/querychat/_querychat_base.py index 65995cbab..fd0d8fad6 100644 --- a/pkg-py/src/querychat/_querychat_base.py +++ b/pkg-py/src/querychat/_querychat_base.py @@ -12,7 +12,12 @@ import narwhals.stable.v1 as nw import sqlalchemy -from ._datasource import DataFrameSource, DataSource, SQLAlchemySource +from ._datasource import ( + DataFrameSource, + DataSource, + PolarsLazySource, + SQLAlchemySource, +) from ._shiny_module import GREETING_PROMPT from ._system_prompt import QueryChatSystemPrompt from ._utils import MISSING, MISSING_TYPE @@ -174,13 +179,28 @@ def normalize_data_source( return data_source if isinstance(data_source, sqlalchemy.Engine): return SQLAlchemySource(data_source, table_name) + src = nw.from_native(data_source, pass_through=True) + if isinstance(src, nw.DataFrame): return DataFrameSource(src, table_name) + if isinstance(src, nw.LazyFrame): - raise NotImplementedError( - "LazyFrame data sources are not yet supported (they will be soon)." + native = src.to_native() + try: + import polars as pl + + if isinstance(native, pl.LazyFrame): + return PolarsLazySource(src, table_name) + except ImportError: + pass + raise TypeError( + f"Unsupported LazyFrame backend: {type(native).__name__} from {type(native).__module__}. " + "Currently only Polars LazyFrames are supported. " + "If you believe this type should be supported, please open an issue at " + "https://github.com/posit-dev/querychat/issues" ) + raise TypeError( f"Unsupported data source type: {type(data_source)}. " "If you believe this type should be supported, please open an issue at " diff --git a/pkg-py/src/querychat/_querychat_core.py b/pkg-py/src/querychat/_querychat_core.py index 35945f4e2..c2bb4c82b 100644 --- a/pkg-py/src/querychat/_querychat_core.py +++ b/pkg-py/src/querychat/_querychat_core.py @@ -31,9 +31,7 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator, Iterator - import narwhals.stable.v1 as nw - - from ._datasource import DataSource + from ._datasource import DataOrLazyFrame, DataSource ClientFactory = Callable[ @@ -72,7 +70,7 @@ def _client_factory( """Create a chat client with dashboard callbacks.""" return self.client(update_dashboard=update_cb, reset_dashboard=reset_cb) # type: ignore[attr-defined] - def df(self, state: AppStateDict | None) -> nw.DataFrame: + def df(self, state: AppStateDict | None) -> DataOrLazyFrame: """ Get the current DataFrame from state. @@ -85,6 +83,7 @@ def df(self, state: AppStateDict | None) -> nw.DataFrame: ------- : The filtered data if a SQL query is active, otherwise the full dataset. + Returns a LazyFrame if the data source is lazy. """ sql = state.get("sql") if state else None @@ -210,7 +209,7 @@ def reset_dashboard(self) -> None: self.title = None self.error = None - def get_current_data(self) -> nw.DataFrame: + def get_current_data(self) -> DataOrLazyFrame: """Get current data, falling back to default if query fails.""" if self.sql: try: diff --git a/pkg-py/src/querychat/_shiny.py b/pkg-py/src/querychat/_shiny.py index 2886b74e7..a8cbd7f2b 100644 --- a/pkg-py/src/querychat/_shiny.py +++ b/pkg-py/src/querychat/_shiny.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Literal, Optional, overload +import narwhals.stable.v1 as nw from shiny.express._stub_session import ExpressStubSession from shiny.session import get_current_session from shinychat import output_markdown_stream @@ -16,10 +17,11 @@ from pathlib import Path import chatlas - import narwhals.stable.v1 as nw import sqlalchemy from narwhals.stable.v1.typing import IntoFrame + from ._datasource import DataOrLazyFrame + class QueryChat(QueryChatBase): """ @@ -239,7 +241,11 @@ def _(): @render.data_frame def dt(): - return vals.df() + df = vals.df() + # Collect if lazy + if isinstance(df, nw.LazyFrame): + df = df.collect() + return df @render.ui def sql_output(): @@ -605,16 +611,16 @@ def ui(self, *, id: Optional[str] = None, **kwargs): """ return mod_ui(id or self.id, **kwargs) - def df(self) -> nw.DataFrame: + def df(self) -> DataOrLazyFrame: """ Reactively read the current filtered data frame that is in effect. Returns ------- : - The current filtered data frame as a narwhals DataFrame. If no query - has been set, this will return the unfiltered data frame from the - data source. + The current filtered data frame as a narwhals DataFrame or LazyFrame. + If the data source is lazy, returns a LazyFrame. If no query has been + set, this will return the unfiltered data from the data source. """ return self._vals.df() diff --git a/pkg-py/src/querychat/_shiny_module.py b/pkg-py/src/querychat/_shiny_module.py index 80c798a69..7d10e4ec2 100644 --- a/pkg-py/src/querychat/_shiny_module.py +++ b/pkg-py/src/querychat/_shiny_module.py @@ -17,12 +17,11 @@ if TYPE_CHECKING: from collections.abc import Callable - import narwhals.stable.v1 as nw from shiny.bookmark import BookmarkState, RestoreState from shiny import Inputs, Outputs, Session - from ._datasource import DataSource + from ._datasource import DataOrLazyFrame, DataSource from .types import UpdateDashboardData ReactiveString = reactive.Value[str] @@ -62,9 +61,10 @@ class ServerValues: Attributes ---------- df - A reactive Calc that returns the current filtered data frame. If no SQL - query has been set, this returns the unfiltered data from the data source. - Call it like `.df()` to reactively read the current data frame. + A reactive Calc that returns the current filtered data frame or lazy frame. + If the data source is lazy, returns a LazyFrame. If no SQL query has been + set, this returns the unfiltered data from the data source. + Call it like `.df()` to reactively read the current data. sql A reactive Value containing the current SQL query string. Access the value by calling `.sql()`, or set it with `.sql.set("SELECT ...")`. @@ -81,7 +81,7 @@ class ServerValues: """ - df: Callable[[], nw.DataFrame] + df: Callable[[], DataOrLazyFrame] sql: ReactiveStringOrNone title: ReactiveStringOrNone client: chatlas.Chat diff --git a/pkg-py/src/querychat/_streamlit.py b/pkg-py/src/querychat/_streamlit.py index c2453bc34..a0f1ea611 100644 --- a/pkg-py/src/querychat/_streamlit.py +++ b/pkg-py/src/querychat/_streamlit.py @@ -4,6 +4,8 @@ from typing import TYPE_CHECKING, Optional +import narwhals.stable.v1 as nw + from ._querychat_base import TOOL_GROUPS, QueryChatBase from ._querychat_core import ( GREETING_PROMPT, @@ -17,10 +19,11 @@ from pathlib import Path import chatlas - import narwhals.stable.v1 as nw import sqlalchemy from narwhals.stable.v1.typing import IntoFrame + from ._datasource import DataOrLazyFrame + class QueryChat(QueryChatBase): """ @@ -181,8 +184,8 @@ def ui(self) -> None: st.rerun() - def df(self) -> nw.DataFrame: - """Get the current filtered data frame.""" + def df(self) -> DataOrLazyFrame: + """Get the current filtered data frame (or LazyFrame if data source is lazy).""" return self._get_state().get_current_data() def sql(self) -> str | None: @@ -235,6 +238,9 @@ def _render_main_content(self) -> None: st.subheader("Data view") df = state.get_current_data() + # Collect if lazy before accessing .shape or displaying + if isinstance(df, nw.LazyFrame): + df = df.collect() if state.error: st.error(state.error) st.dataframe(df, use_container_width=True, height=400, hide_index=True) diff --git a/pkg-py/src/querychat/_utils.py b/pkg-py/src/querychat/_utils.py index 64b3755c2..1eb7c2052 100644 --- a/pkg-py/src/querychat/_utils.py +++ b/pkg-py/src/querychat/_utils.py @@ -7,6 +7,19 @@ from typing import TYPE_CHECKING, Literal, Optional import narwhals.stable.v1 as nw +from great_tables import GT + +if TYPE_CHECKING: + from ._datasource import DataOrLazyFrame + + +class MISSING_TYPE: # noqa: N801 + """ + A singleton representing a missing value. + """ + + +MISSING = MISSING_TYPE() class UnsafeQueryError(ValueError): @@ -81,19 +94,6 @@ def check_query(query: str) -> None: ) -if TYPE_CHECKING: - from narwhals.stable.v1.typing import IntoFrame - - -class MISSING_TYPE: # noqa: N801 - """ - A singleton representing a missing value. - """ - - -MISSING = MISSING_TYPE() - - @contextmanager def temp_env_vars(env_vars: dict[str, Optional[str]]): """ @@ -196,7 +196,7 @@ def querychat_tool_starts_open(action: Literal["update", "query", "reset"]) -> b return action != "reset" -def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: +def df_to_html(df: DataOrLazyFrame, maxrows: int = 5) -> str: """ Convert a DataFrame to a Bootstrap-styled HTML table for display in chat. @@ -213,38 +213,16 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: HTML string representation of the table """ - ndf = nw.from_native(df) - - if isinstance(ndf, (nw.LazyFrame, nw.DataFrame)): - df_short = ndf.lazy().head(maxrows).collect() - nrow_full = ndf.lazy().select(nw.len()).collect().item() - else: - raise TypeError( - "Must be able to convert `df` into a Narwhals DataFrame or LazyFrame", - ) - - # Build simple Bootstrap-styled HTML table - columns = df_short.columns - rows = df_short.rows() - - # Table header - header_cells = "".join(f"{col}" for col in columns) - header = f"{header_cells}" - - # Table body - body_rows = [] - for row in rows: - cells = "".join(f"{val}" for val in row) - body_rows.append(f"{cells}") - body = f"{''.join(body_rows)}" + if isinstance(df, nw.DataFrame): + df = df.lazy() - # Use Bootstrap table classes - table_html = f'{header}{body}
' + df_short = df.head(maxrows).collect() + gt_tbl = GT(df_short.to_native()) + table_html = gt_tbl.as_raw_html(make_page=False) # Add note about truncated rows if needed - if len(df_short) != nrow_full: - rows_notice = f"\n\n*(Showing {maxrows} of {nrow_full} rows)*\n" - else: - rows_notice = "" + nrow_full = df.select(nw.len()).collect().item() + if nrow_full > maxrows: + table_html += f"\n\n*(Showing {maxrows} of {nrow_full} rows)*\n" - return table_html + rows_notice + return table_html diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index df344a268..a1a46b005 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Protocol, TypedDict, runtime_checkable import chevron +import narwhals.stable.v1 as nw from chatlas import ContentToolResult, Tool from shinychat.types import ToolResultDisplay @@ -232,6 +233,8 @@ def query(query: str, _intent: str = "") -> ContentToolResult: try: result_df = data_source.execute_query(query) + if isinstance(result_df, nw.LazyFrame): + result_df = result_df.collect() value = result_df.rows(named=True) # Format table results diff --git a/pkg-py/tests/test_dataframe_source.py b/pkg-py/tests/test_dataframe_source.py index 7602583af..53f442b8f 100644 --- a/pkg-py/tests/test_dataframe_source.py +++ b/pkg-py/tests/test_dataframe_source.py @@ -40,11 +40,6 @@ def narwhals_df(pandas_df): class TestDataFrameSourceInit: """Tests for DataFrameSource initialization.""" - def test_init_with_pandas_dataframe(self, pandas_df): - """Test that DataFrameSource accepts a pandas DataFrame.""" - source = DataFrameSource(pandas_df, "test_table") - assert source.table_name == "test_table" - def test_init_with_narwhals_dataframe(self, narwhals_df): """Test that DataFrameSource accepts a narwhals DataFrame.""" source = DataFrameSource(narwhals_df, "test_table") @@ -52,37 +47,37 @@ def test_init_with_narwhals_dataframe(self, narwhals_df): @pytest.mark.skipif(not HAS_POLARS_WITH_PYARROW, reason="polars or pyarrow not installed") def test_init_with_polars_dataframe(self): - """Test that DataFrameSource accepts a polars DataFrame.""" + """Test that DataFrameSource accepts a narwhals-wrapped polars DataFrame.""" polars_df = pl.DataFrame( { "id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], } ) - source = DataFrameSource(polars_df, "test_table") + source = DataFrameSource(nw.from_native(polars_df), "test_table") assert source.table_name == "test_table" class TestDataFrameSourceExecuteQuery: """Tests for DataFrameSource.execute_query method.""" - def test_execute_query_returns_narwhals_dataframe(self, pandas_df): + def test_execute_query_returns_narwhals_dataframe(self, narwhals_df): """Test that execute_query returns a narwhals DataFrame.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.execute_query("SELECT * FROM employees") assert isinstance(result, nw.DataFrame) - def test_execute_query_select_all(self, pandas_df): + def test_execute_query_select_all(self, narwhals_df): """Test SELECT * query.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.execute_query("SELECT * FROM employees") assert result.shape == (5, 5) assert set(result.columns) == {"id", "name", "age", "salary", "department"} - def test_execute_query_with_filter(self, pandas_df): + def test_execute_query_with_filter(self, narwhals_df): """Test query with WHERE clause.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.execute_query( "SELECT * FROM employees WHERE department = 'Engineering'" ) @@ -91,9 +86,9 @@ def test_execute_query_with_filter(self, pandas_df): departments = result["department"].unique().to_list() assert departments == ["Engineering"] - def test_execute_query_with_aggregation(self, pandas_df): + def test_execute_query_with_aggregation(self, narwhals_df): """Test query with aggregation.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.execute_query( "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department" ) @@ -102,17 +97,17 @@ def test_execute_query_with_aggregation(self, pandas_df): assert "department" in result.columns assert "avg_salary" in result.columns - def test_execute_query_select_columns(self, pandas_df): + def test_execute_query_select_columns(self, narwhals_df): """Test selecting specific columns.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.execute_query("SELECT name, age FROM employees") assert result.shape == (5, 2) assert list(result.columns) == ["name", "age"] - def test_execute_query_order_by(self, pandas_df): + def test_execute_query_order_by(self, narwhals_df): """Test query with ORDER BY clause.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.execute_query( "SELECT name, age FROM employees ORDER BY age DESC" ) @@ -120,9 +115,9 @@ def test_execute_query_order_by(self, pandas_df): ages = result["age"].to_list() assert ages == sorted(ages, reverse=True) - def test_execute_query_empty_result(self, pandas_df): + def test_execute_query_empty_result(self, narwhals_df): """Test query that returns no rows.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.execute_query( "SELECT * FROM employees WHERE age > 100" ) @@ -134,27 +129,27 @@ def test_execute_query_empty_result(self, pandas_df): class TestDataFrameSourceGetData: """Tests for DataFrameSource.get_data method.""" - def test_get_data_returns_narwhals_dataframe(self, pandas_df): + def test_get_data_returns_narwhals_dataframe(self, narwhals_df): """Test that get_data returns a narwhals DataFrame.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.get_data() assert isinstance(result, nw.DataFrame) - def test_get_data_returns_full_dataset(self, pandas_df): + def test_get_data_returns_full_dataset(self, narwhals_df): """Test that get_data returns all rows.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.get_data() - assert result.shape == pandas_df.shape - assert set(result.columns) == set(pandas_df.columns) + assert result.shape == narwhals_df.shape + assert set(result.columns) == set(narwhals_df.columns) - def test_get_data_preserves_data(self, pandas_df): + def test_get_data_preserves_data(self, narwhals_df): """Test that get_data preserves data values.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") result = source.get_data() # Check that the data matches - original_names = sorted(pandas_df["name"].tolist()) + original_names = sorted(narwhals_df["name"].to_list()) result_names = sorted(result["name"].to_list()) assert original_names == result_names @@ -162,25 +157,25 @@ def test_get_data_preserves_data(self, pandas_df): class TestDataFrameSourceGetSchema: """Tests for DataFrameSource.get_schema method.""" - def test_get_schema_includes_table_name(self, pandas_df): + def test_get_schema_includes_table_name(self, narwhals_df): """Test that schema includes table name.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") schema = source.get_schema(categorical_threshold=10) assert "Table: employees" in schema assert "Columns:" in schema - def test_get_schema_includes_all_columns(self, pandas_df): + def test_get_schema_includes_all_columns(self, narwhals_df): """Test that schema includes all columns.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") schema = source.get_schema(categorical_threshold=10) - for col in pandas_df.columns: + for col in narwhals_df.columns: assert f"- {col} (" in schema - def test_get_schema_numeric_ranges(self, pandas_df): + def test_get_schema_numeric_ranges(self, narwhals_df): """Test that numeric columns include range information.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") schema = source.get_schema(categorical_threshold=10) # Age should have range @@ -188,9 +183,9 @@ def test_get_schema_numeric_ranges(self, pandas_df): # Salary should have range assert "Range: 50000.0 to 70000.0" in schema - def test_get_schema_categorical_values(self, pandas_df): + def test_get_schema_categorical_values(self, narwhals_df): """Test that categorical columns show unique values.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") schema = source.get_schema(categorical_threshold=10) # Department has only 2 unique values, should be categorical @@ -198,9 +193,9 @@ def test_get_schema_categorical_values(self, pandas_df): assert "'Engineering'" in schema assert "'Sales'" in schema - def test_get_schema_respects_threshold(self, pandas_df): + def test_get_schema_respects_threshold(self, narwhals_df): """Test that categorical_threshold is respected.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") # With threshold 1, no columns should be categorical schema_low = source.get_schema(categorical_threshold=1) @@ -218,18 +213,18 @@ def test_get_schema_respects_threshold(self, pandas_df): class TestDataFrameSourceDbType: """Tests for DataFrameSource.get_db_type method.""" - def test_get_db_type_returns_duckdb(self, pandas_df): + def test_get_db_type_returns_duckdb(self, narwhals_df): """Test that get_db_type returns 'DuckDB'.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") assert source.get_db_type() == "DuckDB" class TestDataFrameSourceCleanup: """Tests for DataFrameSource.cleanup method.""" - def test_cleanup_closes_connection(self, pandas_df): + def test_cleanup_closes_connection(self, narwhals_df): """Test that cleanup closes the DuckDB connection.""" - source = DataFrameSource(pandas_df, "employees") + source = DataFrameSource(narwhals_df, "employees") # Should work before cleanup result = source.execute_query("SELECT * FROM employees LIMIT 1") @@ -249,13 +244,15 @@ class TestDataFrameSourceWithPolars: @pytest.fixture def polars_df(self): - """Create a sample polars DataFrame.""" - return pl.DataFrame( - { - "id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - "value": [10.5, 20.5, 30.5], - } + """Create a sample narwhals-wrapped polars DataFrame.""" + return nw.from_native( + pl.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "value": [10.5, 20.5, 30.5], + } + ) ) def test_execute_query_with_polars(self, polars_df): diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py index 8ba00feec..913d9a424 100644 --- a/pkg-py/tests/test_datasource.py +++ b/pkg-py/tests/test_datasource.py @@ -2,6 +2,7 @@ import tempfile from pathlib import Path +import narwhals.stable.v1 as nw import pandas as pd import pytest from querychat._datasource import DataFrameSource, SQLAlchemySource @@ -339,12 +340,14 @@ def test_test_query_empty_result(test_db_engine): def test_test_query_dataframe_source(): """Test that test_query works with DataFrameSource.""" # Create test DataFrame - test_df = pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "name": ["a", "b", "c", "d", "e"], - "value": [10, 20, 30, 40, 50], - } + test_df = nw.from_native( + pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["a", "b", "c", "d", "e"], + "value": [10, 20, 30, 40, 50], + } + ) ) source = DataFrameSource(test_df, "test_table") @@ -468,12 +471,14 @@ def test_check_query_escape_hatch_does_not_enable_always_blocked(monkeypatch): def test_check_query_integrated_into_execute_query(): """Test that check_query is integrated into execute_query().""" - test_df = pd.DataFrame( - { - "id": [1, 2, 3], - "name": ["a", "b", "c"], - "value": [10, 20, 30], - } + test_df = nw.from_native( + pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["a", "b", "c"], + "value": [10, 20, 30], + } + ) ) source = DataFrameSource(test_df, "test_table") diff --git a/pkg-py/tests/test_df_to_html.py b/pkg-py/tests/test_df_to_html.py index 323146153..de5d7f693 100644 --- a/pkg-py/tests/test_df_to_html.py +++ b/pkg-py/tests/test_df_to_html.py @@ -2,6 +2,7 @@ import tempfile from pathlib import Path +import narwhals.stable.v1 as nw import pandas as pd import pytest from querychat._datasource import DataFrameSource, SQLAlchemySource @@ -11,14 +12,16 @@ @pytest.fixture def sample_dataframe(): - """Create a sample pandas DataFrame for testing.""" - return pd.DataFrame( - { - "id": [1, 2, 3, 4, 5], - "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"], - "age": [25, 30, 35, 28, 32], - "salary": [50000, 60000, 70000, 55000, 65000], - }, + """Create a sample narwhals DataFrame for testing.""" + return nw.from_native( + pd.DataFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"], + "age": [25, 30, 35, 28, 32], + "salary": [50000, 60000, 70000, 55000, 65000], + }, + ) ) diff --git a/pkg-py/tests/test_init_with_pandas.py b/pkg-py/tests/test_init_with_pandas.py index 1179a25c4..2ee976794 100644 --- a/pkg-py/tests/test_init_with_pandas.py +++ b/pkg-py/tests/test_init_with_pandas.py @@ -65,7 +65,7 @@ def test_init_with_narwhals_dataframe(): def test_init_with_narwhals_lazyframe_raises(): - """Test that QueryChat() raises TypeError for LazyFrames.""" + """Test that QueryChat() raises TypeError for non-Polars LazyFrames.""" pdf = pd.DataFrame( { "id": [1, 2, 3], @@ -75,7 +75,8 @@ def test_init_with_narwhals_lazyframe_raises(): ) nw_lazy = nw.from_native(pdf).lazy() - with pytest.raises(NotImplementedError, match="LazyFrame"): + # Non-Polars LazyFrames (e.g., pandas-backed) are not supported + with pytest.raises(TypeError, match="Unsupported LazyFrame backend"): QueryChat( data_source=nw_lazy, table_name="test_table", diff --git a/pkg-py/tests/test_polars_lazy_source.py b/pkg-py/tests/test_polars_lazy_source.py new file mode 100644 index 000000000..eccc0a802 --- /dev/null +++ b/pkg-py/tests/test_polars_lazy_source.py @@ -0,0 +1,225 @@ +"""Tests for the PolarsLazySource class.""" + +import narwhals.stable.v1 as nw +import polars as pl +import pytest + + +@pytest.fixture +def polars_lazy_df(): + """Create a sample Polars LazyFrame.""" + return pl.LazyFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"], + "age": [25, 30, 35, 28, 32], + "salary": [50000.0, 60000.0, 70000.0, 55000.0, 65000.0], + "department": ["Engineering", "Sales", "Engineering", "Sales", "Engineering"], + } + ) + + +class TestPolarsLazySourceInit: + """Tests for PolarsLazySource initialization.""" + + def test_init_accepts_narwhals_lazyframe(self, polars_lazy_df): + """Test that PolarsLazySource accepts a narwhals LazyFrame.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "test_table") + assert source.table_name == "test_table" + + def test_get_db_type_returns_polars(self, polars_lazy_df): + """Test that get_db_type returns 'Polars'.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + assert source.get_db_type() == "Polars" + + +class TestPolarsLazySourceExecuteQuery: + """Tests for PolarsLazySource.execute_query method.""" + + def test_execute_query_returns_narwhals_lazyframe(self, polars_lazy_df): + """Test that execute_query returns a narwhals LazyFrame.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + result = source.execute_query("SELECT * FROM employees") + assert isinstance(result, nw.LazyFrame) + + def test_execute_query_select_all(self, polars_lazy_df): + """Test SELECT * query.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + result = source.execute_query("SELECT * FROM employees") + + # Collect to verify results + collected = result.collect() + assert collected.shape == (5, 5) + assert set(collected.columns) == {"id", "name", "age", "salary", "department"} + + def test_execute_query_with_filter(self, polars_lazy_df): + """Test query with WHERE clause.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + result = source.execute_query( + "SELECT * FROM employees WHERE department = 'Engineering'" + ) + + collected = result.collect() + assert collected.shape == (3, 5) + + def test_execute_query_with_aggregation(self, polars_lazy_df): + """Test query with aggregation.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + result = source.execute_query( + "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department" + ) + + collected = result.collect() + assert collected.shape == (2, 2) + assert "department" in collected.columns + assert "avg_salary" in collected.columns + + +class TestPolarsLazySourceGetData: + """Tests for PolarsLazySource.get_data method.""" + + def test_get_data_returns_narwhals_lazyframe(self, polars_lazy_df): + """Test that get_data returns a narwhals LazyFrame.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + result = source.get_data() + assert isinstance(result, nw.LazyFrame) + + def test_get_data_returns_original_lazyframe(self, polars_lazy_df): + """Test that get_data returns the original LazyFrame.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + result = source.get_data() + + # The underlying native Polars LazyFrame should be the same + assert result.to_native() is polars_lazy_df + + +class TestPolarsLazySourceGetSchema: + """Tests for PolarsLazySource.get_schema method.""" + + def test_get_schema_includes_table_name(self, polars_lazy_df): + """Test that schema includes table name.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + schema = source.get_schema(categorical_threshold=10) + + assert "Table: employees" in schema + assert "Columns:" in schema + + def test_get_schema_includes_all_columns(self, polars_lazy_df): + """Test that schema includes all columns.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + schema = source.get_schema(categorical_threshold=10) + + for col in ["id", "name", "age", "salary", "department"]: + assert f"- {col} (" in schema + + def test_get_schema_numeric_ranges(self, polars_lazy_df): + """Test that numeric columns include range information.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + schema = source.get_schema(categorical_threshold=10) + + # Age should have range + assert "Range: 25 to 35" in schema + # Salary should have range + assert "Range: 50000.0 to 70000.0" in schema + + def test_get_schema_categorical_values(self, polars_lazy_df): + """Test that categorical columns show unique values.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + schema = source.get_schema(categorical_threshold=10) + + # Department has only 2 unique values, should be categorical + assert "Categorical values:" in schema + assert "'Engineering'" in schema + assert "'Sales'" in schema + + +class TestPolarsLazySourceTestQuery: + """Tests for PolarsLazySource.test_query method.""" + + def test_test_query_returns_dataframe(self, polars_lazy_df): + """Test that test_query returns a collected DataFrame (not LazyFrame).""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + result = source.test_query("SELECT * FROM employees") + # test_query collects to catch runtime errors, so returns DataFrame + assert isinstance(result, nw.DataFrame) + assert len(result) <= 1 + + def test_test_query_require_all_columns_passes(self, polars_lazy_df): + """Test that test_query passes when all columns present.""" + from querychat._datasource import PolarsLazySource + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + # Should not raise + result = source.test_query( + "SELECT * FROM employees", require_all_columns=True + ) + assert isinstance(result, nw.DataFrame) + + def test_test_query_require_all_columns_fails(self, polars_lazy_df): + """Test that test_query raises when columns missing.""" + from querychat._datasource import ( + MissingColumnsError, + PolarsLazySource, + ) + + nw_lf = nw.from_native(polars_lazy_df) + source = PolarsLazySource(nw_lf, "employees") + + with pytest.raises(MissingColumnsError): + source.test_query( + "SELECT name, age FROM employees", require_all_columns=True + ) + + def test_test_query_catches_runtime_errors(self): + """Test that test_query catches runtime errors by actually executing.""" + from querychat._datasource import PolarsLazySource + + # Create LazyFrame with string column that can't be cast to integer + lf = pl.LazyFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + nw_lf = nw.from_native(lf) + source = PolarsLazySource(nw_lf, "test_table") + + # This query fails at runtime when trying to cast strings to integers + # test_query should catch this because it actually executes (collects) the query + with pytest.raises(pl.exceptions.InvalidOperationError): + source.test_query("SELECT CAST(b AS INTEGER) FROM test_table") diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py index 75acf6755..c5e872fd6 100644 --- a/pkg-py/tests/test_querychat.py +++ b/pkg-py/tests/test_querychat.py @@ -1,8 +1,11 @@ import os +import narwhals.stable.v1 as nw import pandas as pd +import polars as pl import pytest from querychat import QueryChat +from querychat._datasource import PolarsLazySource @pytest.fixture(autouse=True) @@ -87,3 +90,32 @@ def test_querychat_client_has_system_prompt(sample_df): # (needed for methods like generate_greeting() that use _client directly) assert qc._client.system_prompt is not None assert "test_table" in qc._client.system_prompt + + +def test_querychat_with_polars_lazyframe(): + """Test that QueryChat accepts a Polars LazyFrame.""" + lf = pl.LazyFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + } + ) + + qc = QueryChat( + data_source=lf, + table_name="test_table", + greeting="Hello!", + ) + + # Should have created a PolarsLazySource + assert isinstance(qc.data_source, PolarsLazySource) + + # Query should return a LazyFrame + result = qc.data_source.execute_query("SELECT * FROM test_table WHERE id = 2") + assert isinstance(result, nw.LazyFrame) + + # Collect to verify + collected = result.collect() + assert len(collected) == 1 + assert collected.item(0, "name") == "Bob" diff --git a/pkg-py/tests/test_system_prompt.py b/pkg-py/tests/test_system_prompt.py index 627bb6126..4d6919660 100644 --- a/pkg-py/tests/test_system_prompt.py +++ b/pkg-py/tests/test_system_prompt.py @@ -3,6 +3,7 @@ import tempfile from pathlib import Path +import narwhals.stable.v1 as nw import pandas as pd import pytest from querychat._datasource import DataFrameSource @@ -12,12 +13,14 @@ @pytest.fixture def sample_data_source(): """Create a sample DataFrameSource for testing.""" - df = pd.DataFrame( - { - "id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - "age": [25, 30, 35], - } + df = nw.from_native( + pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + } + ) ) return DataFrameSource(df, "test_table") diff --git a/pyproject.toml b/pyproject.toml index d2d0f6ceb..cca646f8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ packages = ["pkg-py/src/querychat"] include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"] [dependency-groups] -dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0"] +dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0", "polars>=1.0.0"] docs = ["quartodoc>=0.11.1", "nbformat", "nbclient", "ipykernel"] examples = [ "openai",