From 8188bb8d4353650ddab6ecdfb5e9cfbceee74891 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 27 Apr 2026 15:00:40 -0700 Subject: [PATCH 01/31] [SPARK-XXXXX][CONNECT][PYTHON] Add SQLContext wrapper for Spark Connect Implements a Connect-compatible SQLContext in pyspark.sql.connect.context that wraps a Connect SparkSession instead of requiring a SparkContext. All SparkSession-delegate methods (sql, table, range, createDataFrame, conf, udf, udtf, read, readStream, streams, catalog ops) are wired up. JVM-only APIs (registerJavaFunction, HiveContext) raise PySparkNotImplementedError. Adds 20 unit tests. Co-authored-by: Isaac --- python/pyspark/sql/connect/context.py | 451 ++++++++++++++++++ .../sql/tests/connect/test_connect_context.py | 187 ++++++++ 2 files changed, 638 insertions(+) create mode 100644 python/pyspark/sql/connect/context.py create mode 100644 python/pyspark/sql/tests/connect/test_connect_context.py diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py new file mode 100644 index 0000000000000..6de50120eed3c --- /dev/null +++ b/python/pyspark/sql/connect/context.py @@ -0,0 +1,451 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import ( + Optional, + Union, + Callable, + Any, + Iterable, + List, + Tuple, + Type, + ClassVar, + TYPE_CHECKING, +) + +from pyspark import _NoValue +from pyspark._globals import _NoValueType +from pyspark.errors import PySparkNotImplementedError +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.readwriter import DataFrameReader +from pyspark.sql.connect.streaming.readwriter import DataStreamReader +from pyspark.sql.connect.streaming.query import StreamingQueryManager +from pyspark.sql.types import Row, AtomicType, DataType, StructType + +if TYPE_CHECKING: + import numpy as np + import pandas as pd + import pyarrow as pa + from pyspark.sql.connect.session import SparkSession + from pyspark.sql.connect.udf import UDFRegistration + from pyspark.sql.connect.udtf import UDTFRegistration + from pyspark.sql._typing import UserDefinedFunctionLike + +__all__ = ["SQLContext", "HiveContext"] + + +class SQLContext: + """The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x. + + As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class + here for backward compatibility. + + This is the Spark Connect-compatible implementation. Unlike the classic implementation, + it wraps a Connect :class:`SparkSession` directly and does not require a + :class:`~pyspark.SparkContext`. + + .. deprecated:: 3.0.0 + Use :func:`SparkSession.builder.getOrCreate()` instead. + + Parameters + ---------- + sparkSession : :class:`SparkSession` + The Connect :class:`SparkSession` to wrap. + """ + + _instantiatedContext: ClassVar[Optional["SQLContext"]] = None + + def __init__(self, sparkSession: "SparkSession") -> None: + warnings.warn( + "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", + FutureWarning, + stacklevel=2, + ) + self.sparkSession = sparkSession + if SQLContext._instantiatedContext is None: + SQLContext._instantiatedContext = self + + @classmethod + def _from_session(cls, sparkSession: "SparkSession") -> "SQLContext": + """Create a new instance without emitting a deprecation warning. + + Used internally by :meth:`newSession` and :meth:`getOrCreate`. + """ + ctx = object.__new__(cls) + ctx.sparkSession = sparkSession + return ctx + + @classmethod + def getOrCreate(cls: Type["SQLContext"], sparkSession: "SparkSession") -> "SQLContext": + """Get the existing SQLContext or create a new one wrapping the given SparkSession. + + .. deprecated:: 3.0.0 + Use :func:`SparkSession.builder.getOrCreate()` instead. + + Parameters + ---------- + sparkSession : :class:`SparkSession` + """ + warnings.warn( + "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", + FutureWarning, + stacklevel=2, + ) + if cls._instantiatedContext is None: + cls._instantiatedContext = cls._from_session(sparkSession) + return cls._instantiatedContext + + def newSession(self) -> "SQLContext": + """Returns a new SQLContext as new session, that has separate SQLConf, + registered temporary views and UDFs, but shared table cache. + + .. versionadded:: 1.6.0 + """ + return self._from_session(self.sparkSession.newSession()) + + def setConf(self, key: str, value: Union[bool, int, str]) -> None: + """Sets the given Spark SQL configuration property. + + .. versionadded:: 1.3.0 + """ + self.sparkSession.conf.set(key, value) + + def getConf( + self, key: str, defaultValue: Union[Optional[str], _NoValueType] = _NoValue + ) -> Optional[str]: + """Returns the value of Spark SQL configuration property for the given key. + + If the key is not set and defaultValue is set, return + defaultValue. If the key is not set and defaultValue is not set, return + the system default value. + + .. versionadded:: 1.3.0 + """ + return self.sparkSession.conf.get(key, defaultValue) + + @property + def udf(self) -> "UDFRegistration": + """Returns a :class:`UDFRegistration` for UDF registration. + + .. versionadded:: 1.3.1 + + Returns + ------- + :class:`UDFRegistration` + """ + return self.sparkSession.udf + + @property + def udtf(self) -> "UDTFRegistration": + """Returns a :class:`UDTFRegistration` for UDTF registration. + + .. versionadded:: 3.5.0 + + Returns + ------- + :class:`UDTFRegistration` + """ + return self.sparkSession.udtf + + def range( + self, + start: int, + end: Optional[int] = None, + step: int = 1, + numPartitions: Optional[int] = None, + ) -> DataFrame: + """Create a :class:`DataFrame` with single :class:`~pyspark.sql.types.LongType` column + named ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + start : int + the start value + end : int, optional + the end value (exclusive) + step : int, optional + the incremental step (default: 1) + numPartitions : int, optional + the number of partitions of the DataFrame + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.range(start, end, step, numPartitions) + + def registerFunction( + self, name: str, f: Callable[..., Any], returnType: Optional[DataType] = None + ) -> "UserDefinedFunctionLike": + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. versionadded:: 1.2.0 + + .. deprecated:: 2.3.0 + Use :func:`spark.udf.register` instead. + """ + warnings.warn("Deprecated in 2.3.0. Use spark.udf.register instead.", FutureWarning) + return self.sparkSession.udf.register(name, f, returnType) + + def registerJavaFunction( + self, name: str, javaClassName: str, returnType: Optional[DataType] = None + ) -> None: + """Not supported in Spark Connect. + + .. versionadded:: 2.1.0 + + .. deprecated:: 2.3.0 + Use :func:`spark.udf.registerJavaFunction` instead. + """ + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "registerJavaFunction"}, + ) + + def createDataFrame( + self, + data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]], + schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None, + samplingRatio: Optional[float] = None, + verifySchema: Optional[bool] = None, + ) -> DataFrame: + """Creates a :class:`DataFrame` from an iterable, a :class:`pandas.DataFrame`, + or a :class:`pyarrow.Table`. + + .. versionadded:: 1.3.0 + + Parameters + ---------- + data : iterable + an iterable of any kind of SQL data representation (:class:`Row`, + :class:`tuple`, ``int``, ``boolean``, etc.), :class:`list`, + :class:`pandas.DataFrame`, or :class:`pyarrow.Table`. + schema : :class:`~pyspark.sql.types.DataType`, str or list, optional + a :class:`~pyspark.sql.types.DataType` or a datatype string or a list of + column names. + samplingRatio : float, optional + the sample ratio of rows used for inferring + verifySchema : bool, optional + verify data types of every row against schema. + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.createDataFrame( # type: ignore[call-overload] + data, schema, samplingRatio, verifySchema + ) + + def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: + """Registers the given :class:`DataFrame` as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. + + .. versionadded:: 1.3.0 + """ + df.createOrReplaceTempView(tableName) + + def dropTempTable(self, tableName: str) -> None: + """Remove the temporary table from catalog. + + .. versionadded:: 1.6.0 + """ + self.sparkSession.catalog.dropTempView(tableName) + + def createExternalTable( + self, + tableName: str, + path: Optional[str] = None, + source: Optional[str] = None, + schema: Optional[StructType] = None, + **options: str, + ) -> DataFrame: + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and + created external table. + + .. versionadded:: 1.3.0 + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.catalog.createExternalTable( + tableName, path, source, schema, **options + ) + + def sql(self, sqlQuery: str) -> DataFrame: + """Returns a :class:`DataFrame` representing the result of the given query. + + .. versionadded:: 1.0.0 + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.sql(sqlQuery) + + def table(self, tableName: str) -> DataFrame: + """Returns the specified table or view as a :class:`DataFrame`. + + .. versionadded:: 1.0.0 + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.table(tableName) + + def tables(self, dbName: Optional[str] = None) -> DataFrame: + """Returns a :class:`DataFrame` containing names of tables in the given database. + + If ``dbName`` is not specified, the current database will be used. + + The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` + (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a table is a + temporary one or not). + + .. versionadded:: 1.3.0 + + Parameters + ---------- + dbName: str, optional + name of the database to use. + + Returns + ------- + :class:`DataFrame` + """ + listed = self.sparkSession.catalog.listTables(dbName) + rows = [ + Row( + namespace=".".join(t.namespace) if t.namespace else "", + tableName=t.name, + isTemporary=t.isTemporary, + ) + for t in listed + ] + return self.sparkSession.createDataFrame(rows) + + def tableNames(self, dbName: Optional[str] = None) -> List[str]: + """Returns a list of names of tables in the database ``dbName``. + + .. versionadded:: 1.3.0 + + Parameters + ---------- + dbName: str + name of the database to use. Default to the current database. + + Returns + ------- + list + list of table names, in string + """ + return [t.name for t in self.sparkSession.catalog.listTables(dbName)] + + def cacheTable(self, tableName: str) -> None: + """Caches the specified table in-memory. + + .. versionadded:: 1.0.0 + """ + self.sparkSession.catalog.cacheTable(tableName) + + def uncacheTable(self, tableName: str) -> None: + """Removes the specified table from the in-memory cache. + + .. versionadded:: 1.0.0 + """ + self.sparkSession.catalog.uncacheTable(tableName) + + def clearCache(self) -> None: + """Removes all cached tables from the in-memory cache. + + .. versionadded:: 1.3.0 + """ + self.sparkSession.catalog.clearCache() + + @property + def read(self) -> DataFrameReader: + """Returns a :class:`DataFrameReader` that can be used to read data + in as a :class:`DataFrame`. + + .. versionadded:: 1.4.0 + + Returns + ------- + :class:`DataFrameReader` + """ + return self.sparkSession.read + + @property + def readStream(self) -> DataStreamReader: + """Returns a :class:`DataStreamReader` that can be used to read data streams + as a streaming :class:`DataFrame`. + + .. versionadded:: 2.0.0 + + Notes + ----- + This API is evolving. + + Returns + ------- + :class:`DataStreamReader` + """ + return self.sparkSession.readStream + + @property + def streams(self) -> StreamingQueryManager: + """Returns a :class:`StreamingQueryManager` that allows managing all the + :class:`~pyspark.sql.streaming.StreamingQuery` StreamingQueries active on + `this` context. + + .. versionadded:: 2.0.0 + + Notes + ----- + This API is evolving. + """ + return self.sparkSession.streams + + +class HiveContext(SQLContext): + """Not supported in Spark Connect. + + .. deprecated:: 2.0.0 + Use SparkSession.builder.enableHiveSupport().getOrCreate(). + """ + + def __init__(self, sparkSession: "SparkSession") -> None: + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "HiveContext"}, + ) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py new file mode 100644 index 0000000000000..f78243e259af3 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -0,0 +1,187 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import warnings + +from pyspark.errors import PySparkNotImplementedError +from pyspark.sql.connect.context import HiveContext, SQLContext +from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.connect.readwriter import DataFrameReader +from pyspark.sql.connect.streaming.readwriter import DataStreamReader +from pyspark.sql.connect.streaming.query import StreamingQueryManager +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class SQLContextConnectTests(ReusedConnectTestCase): + def setUp(self): + super().setUp() + SQLContext._instantiatedContext = None + + def tearDown(self): + SQLContext._instantiatedContext = None + super().tearDown() + + def _make(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + return SQLContext(self.spark) + + def test_init_emits_deprecation_warning(self): + with self.assertWarns(FutureWarning): + SQLContext(self.spark) + + def test_getOrCreate_returns_same_instance(self): + ctx = self._make() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + ctx2 = SQLContext.getOrCreate(self.spark) + self.assertIs(ctx, ctx2) + + def test_getOrCreate_emits_deprecation_warning(self): + SQLContext._instantiatedContext = None + with self.assertWarns(FutureWarning): + SQLContext.getOrCreate(self.spark) + + def test_newSession_returns_new_instance(self): + ctx = self._make() + ctx2 = ctx.newSession() + self.assertIsInstance(ctx2, SQLContext) + self.assertIsNot(ctx, ctx2) + + def test_setConf_and_getConf(self): + ctx = self._make() + ctx.setConf("spark.sql.shuffle.partitions", "42") + self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions"), "42") + self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions", "99"), "42") + ctx.setConf("spark.sql.shuffle.partitions", "200") + + def test_udf_property(self): + from pyspark.sql.connect.udf import UDFRegistration + + ctx = self._make() + self.assertIsInstance(ctx.udf, UDFRegistration) + + def test_udtf_property(self): + from pyspark.sql.connect.udtf import UDTFRegistration + + ctx = self._make() + self.assertIsInstance(ctx.udtf, UDTFRegistration) + + def test_range(self): + ctx = self._make() + df = ctx.range(3) + self.assertIsInstance(df, DataFrame) + self.assertEqual(df.collect(), self.spark.range(3).collect()) + + def test_createDataFrame(self): + ctx = self._make() + df = ctx.createDataFrame([(1, "a"), (2, "b")], ["id", "val"]) + self.assertIsInstance(df, DataFrame) + rows = df.collect() + self.assertEqual(len(rows), 2) + + def test_sql(self): + ctx = self._make() + df = ctx.sql("SELECT 1 AS n") + self.assertIsInstance(df, DataFrame) + self.assertEqual(df.collect()[0].n, 1) + + def test_registerDataFrameAsTable_and_sql(self): + ctx = self._make() + df = self.spark.createDataFrame([(1, "row1"), (2, "row2")], ["field1", "field2"]) + ctx.registerDataFrameAsTable(df, "ctx_test_tbl") + try: + result = ctx.sql("SELECT field1 FROM ctx_test_tbl ORDER BY field1").collect() + self.assertEqual([r.field1 for r in result], [1, 2]) + finally: + ctx.dropTempTable("ctx_test_tbl") + + def test_table(self): + ctx = self._make() + df = self.spark.createDataFrame([(42,)], ["v"]) + df.createOrReplaceTempView("ctx_table_view") + try: + result = ctx.table("ctx_table_view") + self.assertIsInstance(result, DataFrame) + self.assertEqual(result.collect()[0].v, 42) + finally: + self.spark.catalog.dropTempView("ctx_table_view") + + def test_tables_returns_dataframe(self): + ctx = self._make() + df = self.spark.createDataFrame([(1,)], ["x"]) + df.createOrReplaceTempView("ctx_tables_view") + try: + tables_df = ctx.tables() + self.assertIsInstance(tables_df, DataFrame) + names = [r.tableName for r in tables_df.collect()] + self.assertIn("ctx_tables_view", names) + finally: + self.spark.catalog.dropTempView("ctx_tables_view") + + def test_tableNames_returns_list(self): + ctx = self._make() + df = self.spark.createDataFrame([(1,)], ["x"]) + df.createOrReplaceTempView("ctx_tablenames_view") + try: + names = ctx.tableNames() + self.assertIsInstance(names, list) + self.assertIn("ctx_tablenames_view", names) + finally: + self.spark.catalog.dropTempView("ctx_tablenames_view") + + def test_cacheTable_uncacheTable(self): + ctx = self._make() + df = self.spark.createDataFrame([(1,)], ["x"]) + df.createOrReplaceTempView("ctx_cache_view") + try: + ctx.cacheTable("ctx_cache_view") + ctx.uncacheTable("ctx_cache_view") + finally: + self.spark.catalog.dropTempView("ctx_cache_view") + + def test_clearCache(self): + ctx = self._make() + ctx.clearCache() + + def test_read_property(self): + ctx = self._make() + self.assertIsInstance(ctx.read, DataFrameReader) + + def test_readStream_property(self): + ctx = self._make() + self.assertIsInstance(ctx.readStream, DataStreamReader) + + def test_streams_property(self): + ctx = self._make() + self.assertIsInstance(ctx.streams, StreamingQueryManager) + + def test_registerJavaFunction_raises(self): + ctx = self._make() + with self.assertRaises(PySparkNotImplementedError): + ctx.registerJavaFunction("f", "com.example.F") + + def test_hive_context_raises(self): + with self.assertRaises(PySparkNotImplementedError): + HiveContext(self.spark) + + +if __name__ == "__main__": + from pyspark.testing import main + + main() From 63db047b062e85ef72e4a5055e880c4a92ea67b0 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 27 Apr 2026 22:57:42 -0700 Subject: [PATCH 02/31] Address review comments on Connect SQLContext wrapper - Replace manual Row construction in tables() with SHOW TABLES SQL - Fix all versionadded annotations to 4.0.0 (Connect-era version) - Add _get_or_create_from_session() and dispatch in classic SQLContext.getOrCreate() so users need not import from pyspark.sql.connect directly (sc arg made optional) - Extract SQLContextTestsMixin into test_context.py; Connect test now inherits the shared suite via SQLContextParityTests Co-authored-by: Isaac --- python/pyspark/sql/connect/context.py | 76 ++++----- python/pyspark/sql/context.py | 17 +- .../sql/tests/connect/test_connect_context.py | 150 +----------------- python/pyspark/sql/tests/test_context.py | 103 ++++++++++++ 4 files changed, 165 insertions(+), 181 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 6de50120eed3c..ce1a5bdb7e109 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager -from pyspark.sql.types import Row, AtomicType, DataType, StructType +from pyspark.sql.types import AtomicType, DataType, StructType if TYPE_CHECKING: import numpy as np @@ -83,14 +83,23 @@ def __init__(self, sparkSession: "SparkSession") -> None: @classmethod def _from_session(cls, sparkSession: "SparkSession") -> "SQLContext": - """Create a new instance without emitting a deprecation warning. - - Used internally by :meth:`newSession` and :meth:`getOrCreate`. - """ + """Create a new instance without emitting a deprecation warning.""" ctx = object.__new__(cls) ctx.sparkSession = sparkSession return ctx + @classmethod + def _get_or_create_from_session(cls, sparkSession: "SparkSession") -> "SQLContext": + """Return the cached instance or create one from an active Connect SparkSession. + + Called by the classic :meth:`pyspark.sql.context.SQLContext.getOrCreate` when + running in Spark Connect mode, so users do not need to import from + ``pyspark.sql.connect`` directly. + """ + if cls._instantiatedContext is None: + cls._instantiatedContext = cls._from_session(sparkSession) + return cls._instantiatedContext + @classmethod def getOrCreate(cls: Type["SQLContext"], sparkSession: "SparkSession") -> "SQLContext": """Get the existing SQLContext or create a new one wrapping the given SparkSession. @@ -115,14 +124,14 @@ def newSession(self) -> "SQLContext": """Returns a new SQLContext as new session, that has separate SQLConf, registered temporary views and UDFs, but shared table cache. - .. versionadded:: 1.6.0 + .. versionadded:: 4.0.0 """ return self._from_session(self.sparkSession.newSession()) def setConf(self, key: str, value: Union[bool, int, str]) -> None: """Sets the given Spark SQL configuration property. - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 """ self.sparkSession.conf.set(key, value) @@ -135,7 +144,7 @@ def getConf( defaultValue. If the key is not set and defaultValue is not set, return the system default value. - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 """ return self.sparkSession.conf.get(key, defaultValue) @@ -143,7 +152,7 @@ def getConf( def udf(self) -> "UDFRegistration": """Returns a :class:`UDFRegistration` for UDF registration. - .. versionadded:: 1.3.1 + .. versionadded:: 4.0.0 Returns ------- @@ -155,7 +164,7 @@ def udf(self) -> "UDFRegistration": def udtf(self) -> "UDTFRegistration": """Returns a :class:`UDTFRegistration` for UDTF registration. - .. versionadded:: 3.5.0 + .. versionadded:: 4.0.0 Returns ------- @@ -174,7 +183,7 @@ def range( named ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with step value ``step``. - .. versionadded:: 1.4.0 + .. versionadded:: 4.0.0 Parameters ---------- @@ -199,7 +208,7 @@ def registerFunction( """An alias for :func:`spark.udf.register`. See :meth:`pyspark.sql.UDFRegistration.register`. - .. versionadded:: 1.2.0 + .. versionadded:: 4.0.0 .. deprecated:: 2.3.0 Use :func:`spark.udf.register` instead. @@ -212,7 +221,7 @@ def registerJavaFunction( ) -> None: """Not supported in Spark Connect. - .. versionadded:: 2.1.0 + .. versionadded:: 4.0.0 .. deprecated:: 2.3.0 Use :func:`spark.udf.registerJavaFunction` instead. @@ -232,7 +241,7 @@ def createDataFrame( """Creates a :class:`DataFrame` from an iterable, a :class:`pandas.DataFrame`, or a :class:`pyarrow.Table`. - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 Parameters ---------- @@ -261,14 +270,14 @@ def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 """ df.createOrReplaceTempView(tableName) def dropTempTable(self, tableName: str) -> None: """Remove the temporary table from catalog. - .. versionadded:: 1.6.0 + .. versionadded:: 4.0.0 """ self.sparkSession.catalog.dropTempView(tableName) @@ -291,7 +300,7 @@ def createExternalTable( Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 Returns ------- @@ -304,7 +313,7 @@ def createExternalTable( def sql(self, sqlQuery: str) -> DataFrame: """Returns a :class:`DataFrame` representing the result of the given query. - .. versionadded:: 1.0.0 + .. versionadded:: 4.0.0 Returns ------- @@ -315,7 +324,7 @@ def sql(self, sqlQuery: str) -> DataFrame: def table(self, tableName: str) -> DataFrame: """Returns the specified table or view as a :class:`DataFrame`. - .. versionadded:: 1.0.0 + .. versionadded:: 4.0.0 Returns ------- @@ -332,7 +341,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a table is a temporary one or not). - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 Parameters ---------- @@ -343,21 +352,14 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: ------- :class:`DataFrame` """ - listed = self.sparkSession.catalog.listTables(dbName) - rows = [ - Row( - namespace=".".join(t.namespace) if t.namespace else "", - tableName=t.name, - isTemporary=t.isTemporary, - ) - for t in listed - ] - return self.sparkSession.createDataFrame(rows) + if dbName is None: + return self.sparkSession.sql("SHOW TABLES") + return self.sparkSession.sql(f"SHOW TABLES IN {dbName}") def tableNames(self, dbName: Optional[str] = None) -> List[str]: """Returns a list of names of tables in the database ``dbName``. - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 Parameters ---------- @@ -374,21 +376,21 @@ def tableNames(self, dbName: Optional[str] = None) -> List[str]: def cacheTable(self, tableName: str) -> None: """Caches the specified table in-memory. - .. versionadded:: 1.0.0 + .. versionadded:: 4.0.0 """ self.sparkSession.catalog.cacheTable(tableName) def uncacheTable(self, tableName: str) -> None: """Removes the specified table from the in-memory cache. - .. versionadded:: 1.0.0 + .. versionadded:: 4.0.0 """ self.sparkSession.catalog.uncacheTable(tableName) def clearCache(self) -> None: """Removes all cached tables from the in-memory cache. - .. versionadded:: 1.3.0 + .. versionadded:: 4.0.0 """ self.sparkSession.catalog.clearCache() @@ -397,7 +399,7 @@ def read(self) -> DataFrameReader: """Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. - .. versionadded:: 1.4.0 + .. versionadded:: 4.0.0 Returns ------- @@ -410,7 +412,7 @@ def readStream(self) -> DataStreamReader: """Returns a :class:`DataStreamReader` that can be used to read data streams as a streaming :class:`DataFrame`. - .. versionadded:: 2.0.0 + .. versionadded:: 4.0.0 Notes ----- @@ -428,7 +430,7 @@ def streams(self) -> StreamingQueryManager: :class:`~pyspark.sql.streaming.StreamingQuery` StreamingQueries active on `this` context. - .. versionadded:: 2.0.0 + .. versionadded:: 4.0.0 Notes ----- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7145b27f2cf3c..0ae76d833bc85 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -142,10 +142,13 @@ def _ssql_ctx(self) -> "JavaObject": return self._jsqlContext @classmethod - def getOrCreate(cls: Type["SQLContext"], sc: "SparkContext") -> "SQLContext": + def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> "SQLContext": """ Get the existing SQLContext or create a new one with given SparkContext. + When running in Spark Connect mode, ``sc`` is not required and the active + :class:`SparkSession` is used automatically. + .. versionadded:: 1.6.0 .. deprecated:: 3.0.0 @@ -153,12 +156,22 @@ def getOrCreate(cls: Type["SQLContext"], sc: "SparkContext") -> "SQLContext": Parameters ---------- - sc : :class:`SparkContext` + sc : :class:`SparkContext`, optional + Required in classic mode; ignored in Spark Connect mode. """ + from pyspark.util import is_remote_only + warnings.warn( "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", FutureWarning, ) + if is_remote_only(): + from pyspark.sql.connect.context import SQLContext as ConnectSQLContext + + return ConnectSQLContext._get_or_create_from_session( + SparkSession._getActiveSessionOrCreate() + ) + assert sc is not None, "sc is required in classic (non-Connect) mode" return cls._get_or_create(sc) @classmethod diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index f78243e259af3..368a2cb7723ca 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -15,168 +15,34 @@ # limitations under the License. # -import unittest import warnings from pyspark.errors import PySparkNotImplementedError from pyspark.sql.connect.context import HiveContext, SQLContext -from pyspark.sql.connect.dataframe import DataFrame -from pyspark.sql.connect.readwriter import DataFrameReader -from pyspark.sql.connect.streaming.readwriter import DataStreamReader -from pyspark.sql.connect.streaming.query import StreamingQueryManager +from pyspark.sql.tests.test_context import SQLContextTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -class SQLContextConnectTests(ReusedConnectTestCase): - def setUp(self): - super().setUp() - SQLContext._instantiatedContext = None - - def tearDown(self): - SQLContext._instantiatedContext = None - super().tearDown() - - def _make(self): +class SQLContextParityTests(SQLContextTestsMixin, ReusedConnectTestCase): + def _make_ctx(self) -> SQLContext: with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) return SQLContext(self.spark) - def test_init_emits_deprecation_warning(self): + def test_init_emits_deprecation_warning(self) -> None: with self.assertWarns(FutureWarning): SQLContext(self.spark) - def test_getOrCreate_returns_same_instance(self): - ctx = self._make() - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - ctx2 = SQLContext.getOrCreate(self.spark) - self.assertIs(ctx, ctx2) - - def test_getOrCreate_emits_deprecation_warning(self): + def test_getOrCreate_emits_deprecation_warning(self) -> None: SQLContext._instantiatedContext = None with self.assertWarns(FutureWarning): SQLContext.getOrCreate(self.spark) - def test_newSession_returns_new_instance(self): - ctx = self._make() - ctx2 = ctx.newSession() - self.assertIsInstance(ctx2, SQLContext) - self.assertIsNot(ctx, ctx2) - - def test_setConf_and_getConf(self): - ctx = self._make() - ctx.setConf("spark.sql.shuffle.partitions", "42") - self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions"), "42") - self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions", "99"), "42") - ctx.setConf("spark.sql.shuffle.partitions", "200") - - def test_udf_property(self): - from pyspark.sql.connect.udf import UDFRegistration - - ctx = self._make() - self.assertIsInstance(ctx.udf, UDFRegistration) - - def test_udtf_property(self): - from pyspark.sql.connect.udtf import UDTFRegistration - - ctx = self._make() - self.assertIsInstance(ctx.udtf, UDTFRegistration) - - def test_range(self): - ctx = self._make() - df = ctx.range(3) - self.assertIsInstance(df, DataFrame) - self.assertEqual(df.collect(), self.spark.range(3).collect()) - - def test_createDataFrame(self): - ctx = self._make() - df = ctx.createDataFrame([(1, "a"), (2, "b")], ["id", "val"]) - self.assertIsInstance(df, DataFrame) - rows = df.collect() - self.assertEqual(len(rows), 2) - - def test_sql(self): - ctx = self._make() - df = ctx.sql("SELECT 1 AS n") - self.assertIsInstance(df, DataFrame) - self.assertEqual(df.collect()[0].n, 1) - - def test_registerDataFrameAsTable_and_sql(self): - ctx = self._make() - df = self.spark.createDataFrame([(1, "row1"), (2, "row2")], ["field1", "field2"]) - ctx.registerDataFrameAsTable(df, "ctx_test_tbl") - try: - result = ctx.sql("SELECT field1 FROM ctx_test_tbl ORDER BY field1").collect() - self.assertEqual([r.field1 for r in result], [1, 2]) - finally: - ctx.dropTempTable("ctx_test_tbl") - - def test_table(self): - ctx = self._make() - df = self.spark.createDataFrame([(42,)], ["v"]) - df.createOrReplaceTempView("ctx_table_view") - try: - result = ctx.table("ctx_table_view") - self.assertIsInstance(result, DataFrame) - self.assertEqual(result.collect()[0].v, 42) - finally: - self.spark.catalog.dropTempView("ctx_table_view") - - def test_tables_returns_dataframe(self): - ctx = self._make() - df = self.spark.createDataFrame([(1,)], ["x"]) - df.createOrReplaceTempView("ctx_tables_view") - try: - tables_df = ctx.tables() - self.assertIsInstance(tables_df, DataFrame) - names = [r.tableName for r in tables_df.collect()] - self.assertIn("ctx_tables_view", names) - finally: - self.spark.catalog.dropTempView("ctx_tables_view") - - def test_tableNames_returns_list(self): - ctx = self._make() - df = self.spark.createDataFrame([(1,)], ["x"]) - df.createOrReplaceTempView("ctx_tablenames_view") - try: - names = ctx.tableNames() - self.assertIsInstance(names, list) - self.assertIn("ctx_tablenames_view", names) - finally: - self.spark.catalog.dropTempView("ctx_tablenames_view") - - def test_cacheTable_uncacheTable(self): - ctx = self._make() - df = self.spark.createDataFrame([(1,)], ["x"]) - df.createOrReplaceTempView("ctx_cache_view") - try: - ctx.cacheTable("ctx_cache_view") - ctx.uncacheTable("ctx_cache_view") - finally: - self.spark.catalog.dropTempView("ctx_cache_view") - - def test_clearCache(self): - ctx = self._make() - ctx.clearCache() - - def test_read_property(self): - ctx = self._make() - self.assertIsInstance(ctx.read, DataFrameReader) - - def test_readStream_property(self): - ctx = self._make() - self.assertIsInstance(ctx.readStream, DataStreamReader) - - def test_streams_property(self): - ctx = self._make() - self.assertIsInstance(ctx.streams, StreamingQueryManager) - - def test_registerJavaFunction_raises(self): - ctx = self._make() + def test_registerJavaFunction_raises(self) -> None: with self.assertRaises(PySparkNotImplementedError): - ctx.registerJavaFunction("f", "com.example.F") + self._make_ctx().registerJavaFunction("f", "com.example.F") - def test_hive_context_raises(self): + def test_hive_context_raises(self) -> None: with self.assertRaises(PySparkNotImplementedError): HiveContext(self.spark) diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index 2a2dc0cd69ed7..3aaa412ea5e9c 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -171,6 +171,109 @@ def range_frame_match(): reload(window) +class SQLContextTestsMixin: + """Tests for SQLContext that run in both classic and Connect modes. + + Subclasses must implement :meth:`_make_ctx` to return an :class:`SQLContext` + appropriate for their mode, and must expose ``self.spark`` as a + :class:`SparkSession`. + """ + + spark: SparkSession + + def _make_ctx(self) -> SQLContext: + raise NotImplementedError + + def setUp(self) -> None: + SQLContext._instantiatedContext = None + + def tearDown(self) -> None: + SQLContext._instantiatedContext = None + + def test_setConf_and_getConf(self) -> None: + ctx = self._make_ctx() + ctx.setConf("spark.sql.shuffle.partitions", "42") + self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions"), "42") + self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions", "99"), "42") + ctx.setConf("spark.sql.shuffle.partitions", "200") + + def test_range(self) -> None: + ctx = self._make_ctx() + self.assertEqual(ctx.range(3).count(), 3) + + def test_createDataFrame(self) -> None: + ctx = self._make_ctx() + df = ctx.createDataFrame([(1, "a"), (2, "b")], ["id", "val"]) + self.assertEqual(df.count(), 2) + + def test_sql(self) -> None: + ctx = self._make_ctx() + self.assertEqual(ctx.sql("SELECT 1 AS n").collect()[0].n, 1) + + def test_registerDataFrameAsTable_and_sql(self) -> None: + ctx = self._make_ctx() + df = self.spark.createDataFrame([(1, "r1"), (2, "r2")], ["field1", "field2"]) + ctx.registerDataFrameAsTable(df, "ctx_mixin_tbl") + try: + result = ctx.sql("SELECT field1 FROM ctx_mixin_tbl ORDER BY field1").collect() + self.assertEqual([r.field1 for r in result], [1, 2]) + finally: + ctx.dropTempTable("ctx_mixin_tbl") + + def test_table(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(42,)], ["v"]).createOrReplaceTempView("ctx_mixin_view") + try: + self.assertEqual(ctx.table("ctx_mixin_view").collect()[0].v, 42) + finally: + self.spark.catalog.dropTempView("ctx_mixin_view") + + def test_tables_contains_registered_view(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_tables_view") + try: + names = [r.tableName for r in ctx.tables().collect()] + self.assertIn("ctx_mixin_tables_view", names) + finally: + self.spark.catalog.dropTempView("ctx_mixin_tables_view") + + def test_tableNames_contains_registered_view(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView( + "ctx_mixin_tablenames_view" + ) + try: + self.assertIn("ctx_mixin_tablenames_view", ctx.tableNames()) + finally: + self.spark.catalog.dropTempView("ctx_mixin_tablenames_view") + + def test_cacheTable_and_uncacheTable(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_cache_view") + try: + ctx.cacheTable("ctx_mixin_cache_view") + ctx.uncacheTable("ctx_mixin_cache_view") + finally: + self.spark.catalog.dropTempView("ctx_mixin_cache_view") + + def test_clearCache(self) -> None: + self._make_ctx().clearCache() + + def test_newSession_returns_distinct_instance(self) -> None: + ctx = self._make_ctx() + ctx2 = ctx.newSession() + self.assertIsNot(ctx, ctx2) + + def test_read_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().read) + + def test_readStream_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().readStream) + + def test_streams_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().streams) + + class SQLContextTests(unittest.TestCase): def test_get_or_create(self): sc = None From b867ac5530683bff9559ffac99c9efefd1303228 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 28 Apr 2026 14:34:07 -0700 Subject: [PATCH 03/31] Fix mypy errors in Connect SQLContext wrapper - Use base pyspark.sql.dataframe.DataFrame as return type in connect/context.py since SparkSession.sql/range/table/createDataFrame are annotated to return the parent class, not the Connect subclass - Remove now-unnecessary # type: ignore[call-overload] on createDataFrame - Add # type: ignore[return-value, arg-type] to the classic SQLContext.getOrCreate Connect dispatch path where the two SQLContext classes are structurally equivalent but not in the same hierarchy Co-authored-by: Isaac --- python/pyspark/sql/connect/context.py | 4 ++-- python/pyspark/sql/context.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index ce1a5bdb7e109..f8f45ad56d498 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -32,7 +32,7 @@ from pyspark import _NoValue from pyspark._globals import _NoValueType from pyspark.errors import PySparkNotImplementedError -from pyspark.sql.connect.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager @@ -261,7 +261,7 @@ def createDataFrame( ------- :class:`DataFrame` """ - return self.sparkSession.createDataFrame( # type: ignore[call-overload] + return self.sparkSession.createDataFrame( data, schema, samplingRatio, verifySchema ) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 0ae76d833bc85..eeba347d48c35 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -168,8 +168,9 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> if is_remote_only(): from pyspark.sql.connect.context import SQLContext as ConnectSQLContext - return ConnectSQLContext._get_or_create_from_session( - SparkSession._getActiveSessionOrCreate() + session = SparkSession._getActiveSessionOrCreate() + return ConnectSQLContext._get_or_create_from_session( # type: ignore[return-value] + session # type: ignore[arg-type] ) assert sc is not None, "sc is required in classic (non-Connect) mode" return cls._get_or_create(sc) From d605e536453ceb1b2005eba4c6525f79f115aef0 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 30 Apr 2026 16:45:20 -0700 Subject: [PATCH 04/31] Fix CI: register test_connect_context in modules.py and fix ruff formatting Co-authored-by: Isaac --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/connect/context.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 9853e4dfe7a6b..3d6d41f538f3a 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -1165,6 +1165,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_geometrytype", "pyspark.sql.tests.connect.test_parity_datasources", "pyspark.sql.tests.connect.test_parity_errors", + "pyspark.sql.tests.connect.test_connect_context", "pyspark.sql.tests.connect.test_parity_catalog", "pyspark.sql.tests.connect.test_parity_conf", "pyspark.sql.tests.connect.test_parity_serde", diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index f8f45ad56d498..044026e489f9e 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -261,9 +261,7 @@ def createDataFrame( ------- :class:`DataFrame` """ - return self.sparkSession.createDataFrame( - data, schema, samplingRatio, verifySchema - ) + return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema) def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: """Registers the given :class:`DataFrame` as a temporary table in the catalog. From 94047fcf0e52893d81e2713e804a1e540bee4f8f Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 30 Apr 2026 17:03:27 -0700 Subject: [PATCH 05/31] Add API reference docs for SQLContext and HiveContext legacy entry points Co-authored-by: Isaac --- .../source/reference/pyspark.sql/index.rst | 1 + .../source/reference/pyspark.sql/legacy.rst | 72 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 python/docs/source/reference/pyspark.sql/legacy.rst diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst index 36618af2de2c2..ce4b2de6278b7 100644 --- a/python/docs/source/reference/pyspark.sql/index.rst +++ b/python/docs/source/reference/pyspark.sql/index.rst @@ -45,3 +45,4 @@ This page gives an overview of all public Spark SQL API. protobuf datasource stateful_processor + legacy diff --git a/python/docs/source/reference/pyspark.sql/legacy.rst b/python/docs/source/reference/pyspark.sql/legacy.rst new file mode 100644 index 0000000000000..d7692af2823f7 --- /dev/null +++ b/python/docs/source/reference/pyspark.sql/legacy.rst @@ -0,0 +1,72 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +==================== +Legacy Entry Points +==================== +.. currentmodule:: pyspark.sql + +:class:`SQLContext` was the primary entry point for Spark SQL in Spark 1.x. +As of Spark 2.0, it has been replaced by :class:`SparkSession`. +These classes are retained for backward compatibility only. + +.. deprecated:: + Use :func:`SparkSession.builder.getOrCreate` instead. + +SQLContext +---------- + +.. autosummary:: + :toctree: api/ + + SQLContext + +.. autosummary:: + :toctree: api/ + + SQLContext.getOrCreate + SQLContext.newSession + SQLContext.setConf + SQLContext.getConf + SQLContext.udf + SQLContext.udtf + SQLContext.range + SQLContext.registerFunction + SQLContext.registerJavaFunction + SQLContext.createDataFrame + SQLContext.registerDataFrameAsTable + SQLContext.dropTempTable + SQLContext.createExternalTable + SQLContext.sql + SQLContext.table + SQLContext.tables + SQLContext.tableNames + SQLContext.cacheTable + SQLContext.uncacheTable + SQLContext.clearCache + SQLContext.read + SQLContext.readStream + SQLContext.streams + +HiveContext +----------- + +.. autosummary:: + :toctree: api/ + + HiveContext From aa3e4e3f7c154cf55d56576a0ec60c31b89516b8 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 1 May 2026 10:57:28 -0700 Subject: [PATCH 06/31] Fix newSession() in Connect SQLContext to use cloneSession() SparkSession.newSession() is JVM-only and not supported in Spark Connect. Use cloneSession() which is the Connect equivalent. Co-authored-by: Isaac --- python/pyspark/sql/connect/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 044026e489f9e..4d08cd0aa6c9e 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -126,7 +126,7 @@ def newSession(self) -> "SQLContext": .. versionadded:: 4.0.0 """ - return self._from_session(self.sparkSession.newSession()) + return self._from_session(self.sparkSession.cloneSession()) def setConf(self, key: str, value: Union[bool, int, str]) -> None: """Sets the given Spark SQL configuration property. From ee8827614a96a91e52412f5499c49370dfc8d959 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 26 May 2026 12:00:40 -0700 Subject: [PATCH 07/31] Address review comments: fix tables() parity, remove type: ignore, fix setUp/tearDown - tables() in Connect SQLContext now uses catalog.listTables() to always return consistent (namespace, tableName, isTemporary) columns instead of SHOW TABLES whose column names vary across catalogs - Replace type: ignore[return-value/arg-type] in classic getOrCreate() with explicit cast() calls for clarity - Override setUp/tearDown in SQLContextParityTests to reset Connect SQLContext._instantiatedContext between tests, not just the classic one --- python/pyspark/sql/connect/context.py | 23 +++++++++++++------ python/pyspark/sql/context.py | 8 +++++-- .../sql/tests/connect/test_connect_context.py | 8 +++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 4d08cd0aa6c9e..9bfddf9c9e980 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -36,7 +36,7 @@ from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.streaming.readwriter import DataStreamReader from pyspark.sql.connect.streaming.query import StreamingQueryManager -from pyspark.sql.types import AtomicType, DataType, StructType +from pyspark.sql.types import AtomicType, BooleanType, DataType, StringType, StructField, StructType if TYPE_CHECKING: import numpy as np @@ -335,9 +335,9 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: If ``dbName`` is not specified, the current database will be used. - The returned DataFrame has two columns: ``tableName`` and ``isTemporary`` - (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a table is a - temporary one or not). + The returned DataFrame has three columns: ``namespace``, ``tableName`` and + ``isTemporary`` (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a + table is a temporary one or not). .. versionadded:: 4.0.0 @@ -350,9 +350,18 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: ------- :class:`DataFrame` """ - if dbName is None: - return self.sparkSession.sql("SHOW TABLES") - return self.sparkSession.sql(f"SHOW TABLES IN {dbName}") + schema = StructType( + [ + StructField("namespace", StringType(), nullable=True), + StructField("tableName", StringType(), nullable=True), + StructField("isTemporary", BooleanType(), nullable=False), + ] + ) + rows = [ + (t.namespace[0] if t.namespace else "", t.name, t.isTemporary) + for t in self.sparkSession.catalog.listTables(dbName) + ] + return self.sparkSession.createDataFrame(rows, schema) def tableNames(self, dbName: Optional[str] = None) -> List[str]: """Returns a list of names of tables in the database ``dbName``. diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index eeba347d48c35..15fb6c71095f9 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -167,10 +167,14 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> ) if is_remote_only(): from pyspark.sql.connect.context import SQLContext as ConnectSQLContext + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession session = SparkSession._getActiveSessionOrCreate() - return ConnectSQLContext._get_or_create_from_session( # type: ignore[return-value] - session # type: ignore[arg-type] + return cast( + "SQLContext", + ConnectSQLContext._get_or_create_from_session( + cast(ConnectSparkSession, session) + ), ) assert sc is not None, "sc is required in classic (non-Connect) mode" return cls._get_or_create(sc) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index 368a2cb7723ca..f08661607af21 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -24,6 +24,14 @@ class SQLContextParityTests(SQLContextTestsMixin, ReusedConnectTestCase): + def setUp(self) -> None: + super().setUp() + SQLContext._instantiatedContext = None + + def tearDown(self) -> None: + super().tearDown() + SQLContext._instantiatedContext = None + def _make_ctx(self) -> SQLContext: with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) From e4c94a326f84071a4da7d6e82692671631bf5da0 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 26 May 2026 12:05:59 -0700 Subject: [PATCH 08/31] Block HiveContext bypass via getOrCreate by overriding _from_session HiveContext.__init__ raised PySparkNotImplementedError, but HiveContext.getOrCreate(spark) bypassed it because getOrCreate routes through _from_session which uses object.__new__(cls), skipping __init__. Override _from_session in HiveContext to raise unconditionally, closing all bypass paths (getOrCreate, _get_or_create_from_session, newSession). Add test_hive_context_get_or_create_raises to cover the bypass. --- python/pyspark/sql/connect/context.py | 7 +++++++ python/pyspark/sql/tests/connect/test_connect_context.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 9bfddf9c9e980..c1e516d968d48 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -458,3 +458,10 @@ def __init__(self, sparkSession: "SparkSession") -> None: errorClass="NOT_IMPLEMENTED", messageParameters={"feature": "HiveContext"}, ) + + @classmethod + def _from_session(cls, sparkSession: "SparkSession") -> "SQLContext": + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "HiveContext"}, + ) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index f08661607af21..151f301c96de3 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -54,6 +54,10 @@ def test_hive_context_raises(self) -> None: with self.assertRaises(PySparkNotImplementedError): HiveContext(self.spark) + def test_hive_context_get_or_create_raises(self) -> None: + with self.assertRaises(PySparkNotImplementedError): + HiveContext.getOrCreate(self.spark) + if __name__ == "__main__": from pyspark.testing import main From adea4d41c12dfdb05c72b2a2a22f6e504ec85681 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 26 May 2026 15:27:05 -0700 Subject: [PATCH 09/31] Fix ruff format: collapse inner cast onto one line --- python/pyspark/sql/context.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 15fb6c71095f9..999a345b0b989 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -172,9 +172,7 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> session = SparkSession._getActiveSessionOrCreate() return cast( "SQLContext", - ConnectSQLContext._get_or_create_from_session( - cast(ConnectSparkSession, session) - ), + ConnectSQLContext._get_or_create_from_session(cast(ConnectSparkSession, session)), ) assert sc is not None, "sc is required in classic (non-Connect) mode" return cls._get_or_create(sc) From b4329f6622180ee0e7e90a804d7f361adfe6f17a Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 26 May 2026 22:11:03 -0700 Subject: [PATCH 10/31] Trigger CI From b8bb1474441c50fdda663facfab426ba24cd70c5 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 28 May 2026 15:26:47 -0700 Subject: [PATCH 11/31] Address review feedback: remove public getOrCreate from Connect SQLContext, fix HiveContext bypass, reorganize tests - Remove public `getOrCreate` from Connect SQLContext; internal dispatch uses `_get_or_create_from_session` only (fixes Finding #1 / #4) - Fix HiveContext bypass in classic dispatch: route getOrCreate to the Connect counterpart by class name so ConnectHiveContext._from_session raises as expected (fixes Finding #2) - Fix newSession() docstring to accurately describe cloneSession() semantics (fixes Finding #3) - Fix docstring nits: missing article, list/tuple, inferring the schema, table names as strings, streams wording - Add comment explaining catalog.listTables() over SHOW TABLES - Reorganize tests: add test_sql_context.py with mixin + classic runner, test_parity_sql_context.py for Connect parity, slim test_connect_context.py to Connect-specific tests only Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 45 +++--- python/pyspark/sql/context.py | 9 +- .../sql/tests/connect/test_connect_context.py | 24 +-- .../tests/connect/test_parity_sql_context.py | 42 ++++++ python/pyspark/sql/tests/test_context.py | 104 +------------ python/pyspark/sql/tests/test_sql_context.py | 137 ++++++++++++++++++ 6 files changed, 211 insertions(+), 150 deletions(-) create mode 100644 python/pyspark/sql/tests/connect/test_parity_sql_context.py create mode 100644 python/pyspark/sql/tests/test_sql_context.py diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index c1e516d968d48..b110222840178 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -24,7 +24,6 @@ Iterable, List, Tuple, - Type, ClassVar, TYPE_CHECKING, ) @@ -100,31 +99,18 @@ def _get_or_create_from_session(cls, sparkSession: "SparkSession") -> "SQLContex cls._instantiatedContext = cls._from_session(sparkSession) return cls._instantiatedContext - @classmethod - def getOrCreate(cls: Type["SQLContext"], sparkSession: "SparkSession") -> "SQLContext": - """Get the existing SQLContext or create a new one wrapping the given SparkSession. - - .. deprecated:: 3.0.0 - Use :func:`SparkSession.builder.getOrCreate()` instead. - - Parameters - ---------- - sparkSession : :class:`SparkSession` - """ - warnings.warn( - "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", - FutureWarning, - stacklevel=2, - ) - if cls._instantiatedContext is None: - cls._instantiatedContext = cls._from_session(sparkSession) - return cls._instantiatedContext - def newSession(self) -> "SQLContext": - """Returns a new SQLContext as new session, that has separate SQLConf, - registered temporary views and UDFs, but shared table cache. + """Returns a new SQLContext as a new independent server session cloned from this one, + with the current session's configuration, temporary views, and registered functions + copied into it. .. versionadded:: 4.0.0 + + Notes + ----- + Unlike classic :meth:`pyspark.sql.context.SQLContext.newSession`, which returns a fresh + session sharing only the table cache, this uses :meth:`SparkSession.cloneSession` and + inherits the current session's state. """ return self._from_session(self.sparkSession.cloneSession()) @@ -250,10 +236,10 @@ def createDataFrame( :class:`tuple`, ``int``, ``boolean``, etc.), :class:`list`, :class:`pandas.DataFrame`, or :class:`pyarrow.Table`. schema : :class:`~pyspark.sql.types.DataType`, str or list, optional - a :class:`~pyspark.sql.types.DataType` or a datatype string or a list of + a :class:`~pyspark.sql.types.DataType` or a datatype string or a list/tuple of column names. samplingRatio : float, optional - the sample ratio of rows used for inferring + the sample ratio of rows used for inferring the schema. verifySchema : bool, optional verify data types of every row against schema. @@ -357,6 +343,9 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: StructField("isTemporary", BooleanType(), nullable=False), ] ) + # Use catalog.listTables() rather than SHOW TABLES so the column names are always + # (namespace, tableName, isTemporary), matching the classic implementation. + # SHOW TABLES returns "database" vs "namespace" depending on the active catalog. rows = [ (t.namespace[0] if t.namespace else "", t.name, t.isTemporary) for t in self.sparkSession.catalog.listTables(dbName) @@ -376,7 +365,7 @@ def tableNames(self, dbName: Optional[str] = None) -> List[str]: Returns ------- list - list of table names, in string + list of table names as strings """ return [t.name for t in self.sparkSession.catalog.listTables(dbName)] @@ -434,8 +423,8 @@ def readStream(self) -> DataStreamReader: @property def streams(self) -> StreamingQueryManager: """Returns a :class:`StreamingQueryManager` that allows managing all the - :class:`~pyspark.sql.streaming.StreamingQuery` StreamingQueries active on - `this` context. + :class:`~pyspark.sql.streaming.StreamingQuery` instances active on this + context. .. versionadded:: 4.0.0 diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 999a345b0b989..b988ed092f489 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -166,13 +166,18 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> FutureWarning, ) if is_remote_only(): - from pyspark.sql.connect.context import SQLContext as ConnectSQLContext + from pyspark.sql.connect import context as _connect_context from pyspark.sql.connect.session import SparkSession as ConnectSparkSession session = SparkSession._getActiveSessionOrCreate() + # Route to the Connect counterpart so subclasses (e.g. HiveContext) are handled + # correctly: ConnectHiveContext._from_session raises PySparkNotImplementedError. + connect_cls = getattr( + _connect_context, cls.__name__, _connect_context.SQLContext + ) return cast( "SQLContext", - ConnectSQLContext._get_or_create_from_session(cast(ConnectSparkSession, session)), + connect_cls._get_or_create_from_session(cast(ConnectSparkSession, session)), ) assert sc is not None, "sc is required in classic (non-Connect) mode" return cls._get_or_create(sc) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index 151f301c96de3..3477c71602849 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -19,11 +19,12 @@ from pyspark.errors import PySparkNotImplementedError from pyspark.sql.connect.context import HiveContext, SQLContext -from pyspark.sql.tests.test_context import SQLContextTestsMixin from pyspark.testing.connectutils import ReusedConnectTestCase -class SQLContextParityTests(SQLContextTestsMixin, ReusedConnectTestCase): +class SQLContextConnectTests(ReusedConnectTestCase): + """Connect-specific SQLContext tests not covered by the parity mixin.""" + def setUp(self) -> None: super().setUp() SQLContext._instantiatedContext = None @@ -32,32 +33,21 @@ def tearDown(self) -> None: super().tearDown() SQLContext._instantiatedContext = None - def _make_ctx(self) -> SQLContext: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - return SQLContext(self.spark) - def test_init_emits_deprecation_warning(self) -> None: with self.assertWarns(FutureWarning): SQLContext(self.spark) - def test_getOrCreate_emits_deprecation_warning(self) -> None: - SQLContext._instantiatedContext = None - with self.assertWarns(FutureWarning): - SQLContext.getOrCreate(self.spark) - def test_registerJavaFunction_raises(self) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + ctx = SQLContext(self.spark) with self.assertRaises(PySparkNotImplementedError): - self._make_ctx().registerJavaFunction("f", "com.example.F") + ctx.registerJavaFunction("f", "com.example.F") def test_hive_context_raises(self) -> None: with self.assertRaises(PySparkNotImplementedError): HiveContext(self.spark) - def test_hive_context_get_or_create_raises(self) -> None: - with self.assertRaises(PySparkNotImplementedError): - HiveContext.getOrCreate(self.spark) - if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/connect/test_parity_sql_context.py b/python/pyspark/sql/tests/connect/test_parity_sql_context.py new file mode 100644 index 0000000000000..8d9fd9e950c42 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_sql_context.py @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings + +from pyspark.sql.connect.context import SQLContext +from pyspark.sql.tests.test_sql_context import SQLContextTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class SQLContextParityTests(SQLContextTestsMixin, ReusedConnectTestCase): + def setUp(self) -> None: + super().setUp() + SQLContext._instantiatedContext = None + + def tearDown(self) -> None: + super().tearDown() + SQLContext._instantiatedContext = None + + def _make_ctx(self) -> SQLContext: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + return SQLContext(self.spark) + + +if __name__ == "__main__": + from pyspark.testing import main + + main() diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index 3aaa412ea5e9c..9f8afc54fbe9a 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -27,6 +27,7 @@ from pyspark.sql import Row, SparkSession from pyspark.sql.types import StructType, StringType, StructField from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.sql.tests.test_sql_context import SQLContextTestsMixin # noqa: F401 class HiveContextSQLTests(ReusedSQLTestCase): @@ -171,109 +172,6 @@ def range_frame_match(): reload(window) -class SQLContextTestsMixin: - """Tests for SQLContext that run in both classic and Connect modes. - - Subclasses must implement :meth:`_make_ctx` to return an :class:`SQLContext` - appropriate for their mode, and must expose ``self.spark`` as a - :class:`SparkSession`. - """ - - spark: SparkSession - - def _make_ctx(self) -> SQLContext: - raise NotImplementedError - - def setUp(self) -> None: - SQLContext._instantiatedContext = None - - def tearDown(self) -> None: - SQLContext._instantiatedContext = None - - def test_setConf_and_getConf(self) -> None: - ctx = self._make_ctx() - ctx.setConf("spark.sql.shuffle.partitions", "42") - self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions"), "42") - self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions", "99"), "42") - ctx.setConf("spark.sql.shuffle.partitions", "200") - - def test_range(self) -> None: - ctx = self._make_ctx() - self.assertEqual(ctx.range(3).count(), 3) - - def test_createDataFrame(self) -> None: - ctx = self._make_ctx() - df = ctx.createDataFrame([(1, "a"), (2, "b")], ["id", "val"]) - self.assertEqual(df.count(), 2) - - def test_sql(self) -> None: - ctx = self._make_ctx() - self.assertEqual(ctx.sql("SELECT 1 AS n").collect()[0].n, 1) - - def test_registerDataFrameAsTable_and_sql(self) -> None: - ctx = self._make_ctx() - df = self.spark.createDataFrame([(1, "r1"), (2, "r2")], ["field1", "field2"]) - ctx.registerDataFrameAsTable(df, "ctx_mixin_tbl") - try: - result = ctx.sql("SELECT field1 FROM ctx_mixin_tbl ORDER BY field1").collect() - self.assertEqual([r.field1 for r in result], [1, 2]) - finally: - ctx.dropTempTable("ctx_mixin_tbl") - - def test_table(self) -> None: - ctx = self._make_ctx() - self.spark.createDataFrame([(42,)], ["v"]).createOrReplaceTempView("ctx_mixin_view") - try: - self.assertEqual(ctx.table("ctx_mixin_view").collect()[0].v, 42) - finally: - self.spark.catalog.dropTempView("ctx_mixin_view") - - def test_tables_contains_registered_view(self) -> None: - ctx = self._make_ctx() - self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_tables_view") - try: - names = [r.tableName for r in ctx.tables().collect()] - self.assertIn("ctx_mixin_tables_view", names) - finally: - self.spark.catalog.dropTempView("ctx_mixin_tables_view") - - def test_tableNames_contains_registered_view(self) -> None: - ctx = self._make_ctx() - self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView( - "ctx_mixin_tablenames_view" - ) - try: - self.assertIn("ctx_mixin_tablenames_view", ctx.tableNames()) - finally: - self.spark.catalog.dropTempView("ctx_mixin_tablenames_view") - - def test_cacheTable_and_uncacheTable(self) -> None: - ctx = self._make_ctx() - self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_cache_view") - try: - ctx.cacheTable("ctx_mixin_cache_view") - ctx.uncacheTable("ctx_mixin_cache_view") - finally: - self.spark.catalog.dropTempView("ctx_mixin_cache_view") - - def test_clearCache(self) -> None: - self._make_ctx().clearCache() - - def test_newSession_returns_distinct_instance(self) -> None: - ctx = self._make_ctx() - ctx2 = ctx.newSession() - self.assertIsNot(ctx, ctx2) - - def test_read_is_available(self) -> None: - self.assertIsNotNone(self._make_ctx().read) - - def test_readStream_is_available(self) -> None: - self.assertIsNotNone(self._make_ctx().readStream) - - def test_streams_is_available(self) -> None: - self.assertIsNotNone(self._make_ctx().streams) - - class SQLContextTests(unittest.TestCase): def test_get_or_create(self): sc = None diff --git a/python/pyspark/sql/tests/test_sql_context.py b/python/pyspark/sql/tests/test_sql_context.py new file mode 100644 index 0000000000000..cb87bfbc8e689 --- /dev/null +++ b/python/pyspark/sql/tests/test_sql_context.py @@ -0,0 +1,137 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings + +from pyspark import SQLContext +from pyspark.sql import SparkSession +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class SQLContextTestsMixin: + """Tests for SQLContext that run in both classic and Connect modes. + + Subclasses must implement :meth:`_make_ctx` to return an :class:`SQLContext` + appropriate for their mode, and must expose ``self.spark`` as a + :class:`SparkSession`. + """ + + spark: SparkSession + + def _make_ctx(self) -> SQLContext: + raise NotImplementedError + + def setUp(self) -> None: + SQLContext._instantiatedContext = None + + def tearDown(self) -> None: + SQLContext._instantiatedContext = None + + def test_setConf_and_getConf(self) -> None: + ctx = self._make_ctx() + ctx.setConf("spark.sql.shuffle.partitions", "42") + self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions"), "42") + self.assertEqual(ctx.getConf("spark.sql.shuffle.partitions", "99"), "42") + ctx.setConf("spark.sql.shuffle.partitions", "200") + + def test_range(self) -> None: + ctx = self._make_ctx() + self.assertEqual(ctx.range(3).count(), 3) + + def test_createDataFrame(self) -> None: + ctx = self._make_ctx() + df = ctx.createDataFrame([(1, "a"), (2, "b")], ["id", "val"]) + self.assertEqual(df.count(), 2) + + def test_sql(self) -> None: + ctx = self._make_ctx() + self.assertEqual(ctx.sql("SELECT 1 AS n").collect()[0].n, 1) + + def test_registerDataFrameAsTable_and_sql(self) -> None: + ctx = self._make_ctx() + df = self.spark.createDataFrame([(1, "r1"), (2, "r2")], ["field1", "field2"]) + ctx.registerDataFrameAsTable(df, "ctx_mixin_tbl") + try: + result = ctx.sql("SELECT field1 FROM ctx_mixin_tbl ORDER BY field1").collect() + self.assertEqual([r.field1 for r in result], [1, 2]) + finally: + ctx.dropTempTable("ctx_mixin_tbl") + + def test_table(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(42,)], ["v"]).createOrReplaceTempView("ctx_mixin_view") + try: + self.assertEqual(ctx.table("ctx_mixin_view").collect()[0].v, 42) + finally: + self.spark.catalog.dropTempView("ctx_mixin_view") + + def test_tables_contains_registered_view(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_tables_view") + try: + names = [r.tableName for r in ctx.tables().collect()] + self.assertIn("ctx_mixin_tables_view", names) + finally: + self.spark.catalog.dropTempView("ctx_mixin_tables_view") + + def test_tableNames_contains_registered_view(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView( + "ctx_mixin_tablenames_view" + ) + try: + self.assertIn("ctx_mixin_tablenames_view", ctx.tableNames()) + finally: + self.spark.catalog.dropTempView("ctx_mixin_tablenames_view") + + def test_cacheTable_and_uncacheTable(self) -> None: + ctx = self._make_ctx() + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_cache_view") + try: + ctx.cacheTable("ctx_mixin_cache_view") + ctx.uncacheTable("ctx_mixin_cache_view") + finally: + self.spark.catalog.dropTempView("ctx_mixin_cache_view") + + def test_clearCache(self) -> None: + self._make_ctx().clearCache() + + def test_newSession_returns_distinct_instance(self) -> None: + ctx = self._make_ctx() + ctx2 = ctx.newSession() + self.assertIsNot(ctx, ctx2) + + def test_read_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().read) + + def test_readStream_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().readStream) + + def test_streams_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().streams) + + +class SQLContextTests(SQLContextTestsMixin, ReusedSQLTestCase): + def _make_ctx(self) -> SQLContext: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + return SQLContext(self.sc, self.spark) + + +if __name__ == "__main__": + from pyspark.testing import main + + main() From 30f30563d68b635c47b3988a6b0112ba1d7d72f9 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 28 May 2026 16:03:51 -0700 Subject: [PATCH 12/31] Fix HiveContext _instantiatedContext bypass and register missing test modules in CI - Add _instantiatedContext = None override on ConnectHiveContext to prevent Python MRO from finding the parent SQLContext's cached instance; without this, HiveContext.getOrCreate() silently returns a SQLContext when called after SQLContext.getOrCreate() instead of raising PySparkNotImplementedError - Register test_parity_sql_context and test_sql_context in modules.py so the shared mixin tests actually run in CI Co-authored-by: DB Tsai --- dev/sparktestsupport/modules.py | 2 ++ python/pyspark/sql/connect/context.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 3d6d41f538f3a..c474b87c8b36f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -572,6 +572,7 @@ def __hash__(self): "pyspark.sql.tests.test_column", "pyspark.sql.tests.test_conf", "pyspark.sql.tests.test_context", + "pyspark.sql.tests.test_sql_context", "pyspark.sql.tests.test_dataframe", "pyspark.sql.tests.test_collection", "pyspark.sql.tests.test_creation", @@ -1166,6 +1167,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_datasources", "pyspark.sql.tests.connect.test_parity_errors", "pyspark.sql.tests.connect.test_connect_context", + "pyspark.sql.tests.connect.test_parity_sql_context", "pyspark.sql.tests.connect.test_parity_catalog", "pyspark.sql.tests.connect.test_parity_conf", "pyspark.sql.tests.connect.test_parity_serde", diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index b110222840178..96e7e549f6e4d 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -442,6 +442,10 @@ class HiveContext(SQLContext): Use SparkSession.builder.enableHiveSupport().getOrCreate(). """ + # Override to prevent inheriting SQLContext's cached instance, which would cause + # _get_or_create_from_session to skip _from_session and return the wrong type. + _instantiatedContext: ClassVar[Optional["SQLContext"]] = None + def __init__(self, sparkSession: "SparkSession") -> None: raise PySparkNotImplementedError( errorClass="NOT_IMPLEMENTED", From c95a81f8f66e63aceb5f7936ee3420ee4e8fc51c Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 28 May 2026 16:12:30 -0700 Subject: [PATCH 13/31] Fix SQLContext.__init__ to use type(self) for cache, not hardcoded class name Using SQLContext._instantiatedContext in __init__ would set the parent's cache for any subclass that reaches __init__ without overriding it. Using type(self) ensures each class maintains its own singleton correctly. Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 96e7e549f6e4d..05b2b1be2f325 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -77,8 +77,8 @@ def __init__(self, sparkSession: "SparkSession") -> None: stacklevel=2, ) self.sparkSession = sparkSession - if SQLContext._instantiatedContext is None: - SQLContext._instantiatedContext = self + if type(self)._instantiatedContext is None: + type(self)._instantiatedContext = self @classmethod def _from_session(cls, sparkSession: "SparkSession") -> "SQLContext": From bb48a0b152dda8caea90dc130a49f8ccb2eef5d3 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 28 May 2026 16:19:44 -0700 Subject: [PATCH 14/31] Remove misleading deprecated annotation from registerJavaFunction in Connect SQLContext The `.. deprecated:: 2.3.0` note implied spark.udf.registerJavaFunction is a working alternative in Connect, but the method simply raises PySparkNotImplementedError. The annotation is from the classic implementation and should not appear here. Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 05b2b1be2f325..9e0d8d304b76d 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -208,9 +208,6 @@ def registerJavaFunction( """Not supported in Spark Connect. .. versionadded:: 4.0.0 - - .. deprecated:: 2.3.0 - Use :func:`spark.udf.registerJavaFunction` instead. """ raise PySparkNotImplementedError( errorClass="NOT_IMPLEMENTED", From cc15e981f5b8616b4e401f491a5d91f529129688 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 28 May 2026 16:28:47 -0700 Subject: [PATCH 15/31] Fix tables() namespace truncation and stop() cache leak in Connect SQLContext MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit tables(): t.namespace[-1] instead of t.namespace[0] — for multi-level namespaces (e.g. Unity Catalog ['catalog', 'db']), the database is the innermost component (last element), not the first. Using index 0 was silently returning the catalog name instead of the database name. stop(): clear connect.SQLContext._instantiatedContext when the session being stopped is the one wrapped by the cached context. The classic SparkSession.stop() already does this (session.py:2158); the Connect variant was missing the equivalent cleanup, causing getOrCreate() to return a stale context wrapping a closed session after stop(). Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 2 +- python/pyspark/sql/connect/session.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 9e0d8d304b76d..65f80c46d4138 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -344,7 +344,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: # (namespace, tableName, isTemporary), matching the classic implementation. # SHOW TABLES returns "database" vs "namespace" depending on the active catalog. rows = [ - (t.namespace[0] if t.namespace else "", t.name, t.isTemporary) + (t.namespace[-1] if t.namespace else "", t.name, t.isTemporary) for t in self.sparkSession.catalog.listTables(dbName) ] return self.sparkSession.createDataFrame(rows, schema) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index d538e427e51c2..8f45b865ae057 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -959,6 +959,14 @@ def stop(self) -> None: if self is getattr(SparkSession._active_session, "session", None): SparkSession._active_session.session = None + from pyspark.sql.connect.context import SQLContext as _ConnectSQLContext + + if ( + _ConnectSQLContext._instantiatedContext is not None + and _ConnectSQLContext._instantiatedContext.sparkSession is self + ): + _ConnectSQLContext._instantiatedContext = None + if "SPARK_LOCAL_REMOTE" in os.environ: # When local mode is in use, follow the regular Spark session's # behavior by terminating the Spark Connect server, From 631339f0400095e468106db5108c90a9b7e414f9 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 29 May 2026 16:15:28 -0700 Subject: [PATCH 16/31] address review feedback: version tags, __all__, comment fix, getOrCreate tests - Remove __all__ from connect/context.py (internal module, not public API) - Change all versionadded:: 4.0.0 to 5.0.0 in connect/context.py - Fix comment: s/ConnectHiveContext/the Connect HiveContext/ in context.py - Remove dead SQLContextTestsMixin import from test_context.py - Add test_getOrCreate_emits_deprecation_and_returns_connect_context and test_hive_context_getOrCreate_raises to test_connect_context.py, both using unittest.mock.patch to mock is_remote_only Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 48 ++++++++++--------- python/pyspark/sql/context.py | 2 +- .../sql/tests/connect/test_connect_context.py | 20 ++++++++ python/pyspark/sql/tests/test_context.py | 1 - 4 files changed, 46 insertions(+), 25 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 65f80c46d4138..fd8b5e0ee4c07 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -46,7 +46,9 @@ from pyspark.sql.connect.udtf import UDTFRegistration from pyspark.sql._typing import UserDefinedFunctionLike -__all__ = ["SQLContext", "HiveContext"] +# Internal module — not part of the public PySpark API surface. +# The public SQLContext/HiveContext are in pyspark.sql.context; this module +# is an implementation detail used by the Connect dispatch in that file. class SQLContext: @@ -104,7 +106,7 @@ def newSession(self) -> "SQLContext": with the current session's configuration, temporary views, and registered functions copied into it. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Notes ----- @@ -117,7 +119,7 @@ def newSession(self) -> "SQLContext": def setConf(self, key: str, value: Union[bool, int, str]) -> None: """Sets the given Spark SQL configuration property. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ self.sparkSession.conf.set(key, value) @@ -130,7 +132,7 @@ def getConf( defaultValue. If the key is not set and defaultValue is not set, return the system default value. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ return self.sparkSession.conf.get(key, defaultValue) @@ -138,7 +140,7 @@ def getConf( def udf(self) -> "UDFRegistration": """Returns a :class:`UDFRegistration` for UDF registration. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Returns ------- @@ -150,7 +152,7 @@ def udf(self) -> "UDFRegistration": def udtf(self) -> "UDTFRegistration": """Returns a :class:`UDTFRegistration` for UDTF registration. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Returns ------- @@ -169,7 +171,7 @@ def range( named ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with step value ``step``. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Parameters ---------- @@ -194,7 +196,7 @@ def registerFunction( """An alias for :func:`spark.udf.register`. See :meth:`pyspark.sql.UDFRegistration.register`. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 .. deprecated:: 2.3.0 Use :func:`spark.udf.register` instead. @@ -207,7 +209,7 @@ def registerJavaFunction( ) -> None: """Not supported in Spark Connect. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ raise PySparkNotImplementedError( errorClass="NOT_IMPLEMENTED", @@ -224,7 +226,7 @@ def createDataFrame( """Creates a :class:`DataFrame` from an iterable, a :class:`pandas.DataFrame`, or a :class:`pyarrow.Table`. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Parameters ---------- @@ -251,14 +253,14 @@ def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ df.createOrReplaceTempView(tableName) def dropTempTable(self, tableName: str) -> None: """Remove the temporary table from catalog. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ self.sparkSession.catalog.dropTempView(tableName) @@ -281,7 +283,7 @@ def createExternalTable( Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Returns ------- @@ -294,7 +296,7 @@ def createExternalTable( def sql(self, sqlQuery: str) -> DataFrame: """Returns a :class:`DataFrame` representing the result of the given query. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Returns ------- @@ -305,7 +307,7 @@ def sql(self, sqlQuery: str) -> DataFrame: def table(self, tableName: str) -> DataFrame: """Returns the specified table or view as a :class:`DataFrame`. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Returns ------- @@ -322,7 +324,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: ``isTemporary`` (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a table is a temporary one or not). - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Parameters ---------- @@ -352,7 +354,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: def tableNames(self, dbName: Optional[str] = None) -> List[str]: """Returns a list of names of tables in the database ``dbName``. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Parameters ---------- @@ -369,21 +371,21 @@ def tableNames(self, dbName: Optional[str] = None) -> List[str]: def cacheTable(self, tableName: str) -> None: """Caches the specified table in-memory. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ self.sparkSession.catalog.cacheTable(tableName) def uncacheTable(self, tableName: str) -> None: """Removes the specified table from the in-memory cache. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ self.sparkSession.catalog.uncacheTable(tableName) def clearCache(self) -> None: """Removes all cached tables from the in-memory cache. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 """ self.sparkSession.catalog.clearCache() @@ -392,7 +394,7 @@ def read(self) -> DataFrameReader: """Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Returns ------- @@ -405,7 +407,7 @@ def readStream(self) -> DataStreamReader: """Returns a :class:`DataStreamReader` that can be used to read data streams as a streaming :class:`DataFrame`. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Notes ----- @@ -423,7 +425,7 @@ def streams(self) -> StreamingQueryManager: :class:`~pyspark.sql.streaming.StreamingQuery` instances active on this context. - .. versionadded:: 4.0.0 + .. versionadded:: 5.0.0 Notes ----- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b988ed092f489..23da8f7dd61ed 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -171,7 +171,7 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> session = SparkSession._getActiveSessionOrCreate() # Route to the Connect counterpart so subclasses (e.g. HiveContext) are handled - # correctly: ConnectHiveContext._from_session raises PySparkNotImplementedError. + # correctly: the Connect HiveContext._from_session raises PySparkNotImplementedError. connect_cls = getattr( _connect_context, cls.__name__, _connect_context.SQLContext ) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index 3477c71602849..371effd0e1bd3 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -16,8 +16,11 @@ # import warnings +from unittest.mock import patch from pyspark.errors import PySparkNotImplementedError +from pyspark.sql import HiveContext as ClassicHiveContext +from pyspark.sql import SQLContext as ClassicSQLContext from pyspark.sql.connect.context import HiveContext, SQLContext from pyspark.testing.connectutils import ReusedConnectTestCase @@ -48,6 +51,23 @@ def test_hive_context_raises(self) -> None: with self.assertRaises(PySparkNotImplementedError): HiveContext(self.spark) + def test_getOrCreate_emits_deprecation_and_returns_connect_context(self) -> None: + """SQLContext.getOrCreate() in remote-only mode returns a Connect-backed context.""" + with patch("pyspark.util.is_remote_only", return_value=True): + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + ctx = ClassicSQLContext.getOrCreate() + self.assertTrue(any(issubclass(w.category, FutureWarning) for w in caught)) + self.assertIsInstance(ctx, SQLContext) + + def test_hive_context_getOrCreate_raises(self) -> None: + """HiveContext.getOrCreate() in remote-only mode raises PySparkNotImplementedError.""" + with patch("pyspark.util.is_remote_only", return_value=True): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + with self.assertRaises(PySparkNotImplementedError): + ClassicHiveContext.getOrCreate() + if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/sql/tests/test_context.py b/python/pyspark/sql/tests/test_context.py index 9f8afc54fbe9a..2a2dc0cd69ed7 100644 --- a/python/pyspark/sql/tests/test_context.py +++ b/python/pyspark/sql/tests/test_context.py @@ -27,7 +27,6 @@ from pyspark.sql import Row, SparkSession from pyspark.sql.types import StructType, StringType, StructField from pyspark.testing.sqlutils import ReusedSQLTestCase -from pyspark.sql.tests.test_sql_context import SQLContextTestsMixin # noqa: F401 class HiveContextSQLTests(ReusedSQLTestCase): From 4e28c40d2cb130d62472c5f059cbaf666e03d273 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 29 May 2026 16:22:20 -0700 Subject: [PATCH 17/31] fix: versionadded tags should be 4.2.0, not 5.0.0 This PR targets the 4.2 release; 5.0.0 was incorrect. Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 44 +++++++++++++-------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index fd8b5e0ee4c07..bc70bd65be3ef 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -106,7 +106,7 @@ def newSession(self) -> "SQLContext": with the current session's configuration, temporary views, and registered functions copied into it. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Notes ----- @@ -119,7 +119,7 @@ def newSession(self) -> "SQLContext": def setConf(self, key: str, value: Union[bool, int, str]) -> None: """Sets the given Spark SQL configuration property. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ self.sparkSession.conf.set(key, value) @@ -132,7 +132,7 @@ def getConf( defaultValue. If the key is not set and defaultValue is not set, return the system default value. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ return self.sparkSession.conf.get(key, defaultValue) @@ -140,7 +140,7 @@ def getConf( def udf(self) -> "UDFRegistration": """Returns a :class:`UDFRegistration` for UDF registration. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Returns ------- @@ -152,7 +152,7 @@ def udf(self) -> "UDFRegistration": def udtf(self) -> "UDTFRegistration": """Returns a :class:`UDTFRegistration` for UDTF registration. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Returns ------- @@ -171,7 +171,7 @@ def range( named ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with step value ``step``. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Parameters ---------- @@ -196,7 +196,7 @@ def registerFunction( """An alias for :func:`spark.udf.register`. See :meth:`pyspark.sql.UDFRegistration.register`. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 .. deprecated:: 2.3.0 Use :func:`spark.udf.register` instead. @@ -209,7 +209,7 @@ def registerJavaFunction( ) -> None: """Not supported in Spark Connect. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ raise PySparkNotImplementedError( errorClass="NOT_IMPLEMENTED", @@ -226,7 +226,7 @@ def createDataFrame( """Creates a :class:`DataFrame` from an iterable, a :class:`pandas.DataFrame`, or a :class:`pyarrow.Table`. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Parameters ---------- @@ -253,14 +253,14 @@ def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ df.createOrReplaceTempView(tableName) def dropTempTable(self, tableName: str) -> None: """Remove the temporary table from catalog. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ self.sparkSession.catalog.dropTempView(tableName) @@ -283,7 +283,7 @@ def createExternalTable( Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Returns ------- @@ -296,7 +296,7 @@ def createExternalTable( def sql(self, sqlQuery: str) -> DataFrame: """Returns a :class:`DataFrame` representing the result of the given query. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Returns ------- @@ -307,7 +307,7 @@ def sql(self, sqlQuery: str) -> DataFrame: def table(self, tableName: str) -> DataFrame: """Returns the specified table or view as a :class:`DataFrame`. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Returns ------- @@ -324,7 +324,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: ``isTemporary`` (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a table is a temporary one or not). - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Parameters ---------- @@ -354,7 +354,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: def tableNames(self, dbName: Optional[str] = None) -> List[str]: """Returns a list of names of tables in the database ``dbName``. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Parameters ---------- @@ -371,21 +371,21 @@ def tableNames(self, dbName: Optional[str] = None) -> List[str]: def cacheTable(self, tableName: str) -> None: """Caches the specified table in-memory. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ self.sparkSession.catalog.cacheTable(tableName) def uncacheTable(self, tableName: str) -> None: """Removes the specified table from the in-memory cache. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ self.sparkSession.catalog.uncacheTable(tableName) def clearCache(self) -> None: """Removes all cached tables from the in-memory cache. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 """ self.sparkSession.catalog.clearCache() @@ -394,7 +394,7 @@ def read(self) -> DataFrameReader: """Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Returns ------- @@ -407,7 +407,7 @@ def readStream(self) -> DataStreamReader: """Returns a :class:`DataStreamReader` that can be used to read data streams as a streaming :class:`DataFrame`. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Notes ----- @@ -425,7 +425,7 @@ def streams(self) -> StreamingQueryManager: :class:`~pyspark.sql.streaming.StreamingQuery` instances active on this context. - .. versionadded:: 5.0.0 + .. versionadded:: 4.2.0 Notes ----- From dc19a21d248c82ef6e334c7683076ea9f2f10ac6 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sun, 31 May 2026 21:58:54 -0700 Subject: [PATCH 18/31] address review feedback: 4.3.0 version tags, session-cache comment, mixin test coverage - Change all versionadded:: 4.2.0 to 4.3.0 (4.2.0 is in code freeze) - Add note to _get_or_create_from_session explaining why cached instance is not re-validated: Connect has no JVM lifecycle sentinel; cache cleared via stop() hook - Add test_udf_is_available, test_udtf_is_available, test_registerFunction to SQLContextTestsMixin to match what the PR description claims is covered Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 49 +++++++++++--------- python/pyspark/sql/tests/test_sql_context.py | 14 ++++++ 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index bc70bd65be3ef..e854b2a857b2f 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -96,6 +96,11 @@ def _get_or_create_from_session(cls, sparkSession: "SparkSession") -> "SQLContex Called by the classic :meth:`pyspark.sql.context.SQLContext.getOrCreate` when running in Spark Connect mode, so users do not need to import from ``pyspark.sql.connect`` directly. + + Note: unlike the classic path, the cached instance is not re-validated against + the currently-active session. Connect sessions have no JVM lifecycle to inspect + (no ``_sc._jsc`` sentinel); the cache is cleared via the ``stop()`` hook when + the session terminates. """ if cls._instantiatedContext is None: cls._instantiatedContext = cls._from_session(sparkSession) @@ -106,7 +111,7 @@ def newSession(self) -> "SQLContext": with the current session's configuration, temporary views, and registered functions copied into it. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Notes ----- @@ -119,7 +124,7 @@ def newSession(self) -> "SQLContext": def setConf(self, key: str, value: Union[bool, int, str]) -> None: """Sets the given Spark SQL configuration property. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ self.sparkSession.conf.set(key, value) @@ -132,7 +137,7 @@ def getConf( defaultValue. If the key is not set and defaultValue is not set, return the system default value. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ return self.sparkSession.conf.get(key, defaultValue) @@ -140,7 +145,7 @@ def getConf( def udf(self) -> "UDFRegistration": """Returns a :class:`UDFRegistration` for UDF registration. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Returns ------- @@ -152,7 +157,7 @@ def udf(self) -> "UDFRegistration": def udtf(self) -> "UDTFRegistration": """Returns a :class:`UDTFRegistration` for UDTF registration. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Returns ------- @@ -171,7 +176,7 @@ def range( named ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with step value ``step``. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Parameters ---------- @@ -196,7 +201,7 @@ def registerFunction( """An alias for :func:`spark.udf.register`. See :meth:`pyspark.sql.UDFRegistration.register`. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 .. deprecated:: 2.3.0 Use :func:`spark.udf.register` instead. @@ -209,7 +214,7 @@ def registerJavaFunction( ) -> None: """Not supported in Spark Connect. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ raise PySparkNotImplementedError( errorClass="NOT_IMPLEMENTED", @@ -226,7 +231,7 @@ def createDataFrame( """Creates a :class:`DataFrame` from an iterable, a :class:`pandas.DataFrame`, or a :class:`pyarrow.Table`. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Parameters ---------- @@ -253,14 +258,14 @@ def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ df.createOrReplaceTempView(tableName) def dropTempTable(self, tableName: str) -> None: """Remove the temporary table from catalog. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ self.sparkSession.catalog.dropTempView(tableName) @@ -283,7 +288,7 @@ def createExternalTable( Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and created external table. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Returns ------- @@ -296,7 +301,7 @@ def createExternalTable( def sql(self, sqlQuery: str) -> DataFrame: """Returns a :class:`DataFrame` representing the result of the given query. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Returns ------- @@ -307,7 +312,7 @@ def sql(self, sqlQuery: str) -> DataFrame: def table(self, tableName: str) -> DataFrame: """Returns the specified table or view as a :class:`DataFrame`. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Returns ------- @@ -324,7 +329,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: ``isTemporary`` (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a table is a temporary one or not). - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Parameters ---------- @@ -354,7 +359,7 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: def tableNames(self, dbName: Optional[str] = None) -> List[str]: """Returns a list of names of tables in the database ``dbName``. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Parameters ---------- @@ -371,21 +376,21 @@ def tableNames(self, dbName: Optional[str] = None) -> List[str]: def cacheTable(self, tableName: str) -> None: """Caches the specified table in-memory. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ self.sparkSession.catalog.cacheTable(tableName) def uncacheTable(self, tableName: str) -> None: """Removes the specified table from the in-memory cache. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ self.sparkSession.catalog.uncacheTable(tableName) def clearCache(self) -> None: """Removes all cached tables from the in-memory cache. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 """ self.sparkSession.catalog.clearCache() @@ -394,7 +399,7 @@ def read(self) -> DataFrameReader: """Returns a :class:`DataFrameReader` that can be used to read data in as a :class:`DataFrame`. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Returns ------- @@ -407,7 +412,7 @@ def readStream(self) -> DataStreamReader: """Returns a :class:`DataStreamReader` that can be used to read data streams as a streaming :class:`DataFrame`. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Notes ----- @@ -425,7 +430,7 @@ def streams(self) -> StreamingQueryManager: :class:`~pyspark.sql.streaming.StreamingQuery` instances active on this context. - .. versionadded:: 4.2.0 + .. versionadded:: 4.3.0 Notes ----- diff --git a/python/pyspark/sql/tests/test_sql_context.py b/python/pyspark/sql/tests/test_sql_context.py index cb87bfbc8e689..3d5e0f3ad773b 100644 --- a/python/pyspark/sql/tests/test_sql_context.py +++ b/python/pyspark/sql/tests/test_sql_context.py @@ -123,6 +123,20 @@ def test_readStream_is_available(self) -> None: def test_streams_is_available(self) -> None: self.assertIsNotNone(self._make_ctx().streams) + def test_udf_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().udf) + + def test_udtf_is_available(self) -> None: + self.assertIsNotNone(self._make_ctx().udtf) + + def test_registerFunction(self) -> None: + ctx = self._make_ctx() + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + ctx.registerFunction("ctx_mixin_double", lambda x: x * 2) + result = ctx.sql("SELECT ctx_mixin_double(3) AS v").collect()[0].v + self.assertEqual(result, 6) + class SQLContextTests(SQLContextTestsMixin, ReusedSQLTestCase): def _make_ctx(self) -> SQLContext: From 9e37701427c84e54adf82b2254b6044f4d0272b3 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Sun, 31 May 2026 22:59:56 -0700 Subject: [PATCH 19/31] fix: properly validate cached session in _get_or_create_from_session, add createExternalTable test - connect/context.py: re-create the cached context whenever the incoming sparkSession differs from the stored one, mirroring classic _get_or_create's dead-context check (_sc._jsc is None). Connect has no JVM sentinel so we compare session identity instead. - test_sql_context.py: add test_createExternalTable to SQLContextTestsMixin, covering the last untested method claimed in the PR description. Co-authored-by: DB Tsai --- python/pyspark/sql/connect/context.py | 14 +++++++++----- python/pyspark/sql/tests/test_sql_context.py | 13 +++++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index e854b2a857b2f..5b2149ae0fbec 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -97,12 +97,16 @@ def _get_or_create_from_session(cls, sparkSession: "SparkSession") -> "SQLContex running in Spark Connect mode, so users do not need to import from ``pyspark.sql.connect`` directly. - Note: unlike the classic path, the cached instance is not re-validated against - the currently-active session. Connect sessions have no JVM lifecycle to inspect - (no ``_sc._jsc`` sentinel); the cache is cleared via the ``stop()`` hook when - the session terminates. + Unlike the classic path (which checks ``_sc._jsc is None`` to detect a dead + SparkContext), Connect sessions have no JVM lifecycle sentinel. Instead we + re-create whenever the incoming ``sparkSession`` is not the same object as the + one stored in the cached context, which handles the case where the previous + session was stopped and a new one started. """ - if cls._instantiatedContext is None: + if ( + cls._instantiatedContext is None + or cls._instantiatedContext.sparkSession is not sparkSession + ): cls._instantiatedContext = cls._from_session(sparkSession) return cls._instantiatedContext diff --git a/python/pyspark/sql/tests/test_sql_context.py b/python/pyspark/sql/tests/test_sql_context.py index 3d5e0f3ad773b..db8404d88ad87 100644 --- a/python/pyspark/sql/tests/test_sql_context.py +++ b/python/pyspark/sql/tests/test_sql_context.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import tempfile import warnings from pyspark import SQLContext @@ -137,6 +139,17 @@ def test_registerFunction(self) -> None: result = ctx.sql("SELECT ctx_mixin_double(3) AS v").collect()[0].v self.assertEqual(result, 6) + def test_createExternalTable(self) -> None: + ctx = self._make_ctx() + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "data") + self.spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"]).write.parquet(path) + df = ctx.createExternalTable("ctx_mixin_ext_tbl", path, "parquet") + try: + self.assertEqual(df.count(), 2) + finally: + self.spark.sql("DROP TABLE IF EXISTS ctx_mixin_ext_tbl") + class SQLContextTests(SQLContextTestsMixin, ReusedSQLTestCase): def _make_ctx(self) -> SQLContext: From 02fb0b04c3ec150b448b0eff8432626ee41bad38 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 1 Jun 2026 17:29:09 -0700 Subject: [PATCH 20/31] fix: ruff format on context.py getattr, fix registerFunction test return type - Collapse getattr call to a single line to pass ruff format check - Specify IntegerType() in test_registerFunction so the return type is explicit; without it Spark infers StringType and '6' != 6 in Connect mode Co-authored-by: DB Tsai --- python/pyspark/sql/context.py | 4 +--- python/pyspark/sql/tests/test_sql_context.py | 4 +++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 23da8f7dd61ed..9747e57b22e73 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -172,9 +172,7 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> session = SparkSession._getActiveSessionOrCreate() # Route to the Connect counterpart so subclasses (e.g. HiveContext) are handled # correctly: the Connect HiveContext._from_session raises PySparkNotImplementedError. - connect_cls = getattr( - _connect_context, cls.__name__, _connect_context.SQLContext - ) + connect_cls = getattr(_connect_context, cls.__name__, _connect_context.SQLContext) return cast( "SQLContext", connect_cls._get_or_create_from_session(cast(ConnectSparkSession, session)), diff --git a/python/pyspark/sql/tests/test_sql_context.py b/python/pyspark/sql/tests/test_sql_context.py index db8404d88ad87..e74432470a8eb 100644 --- a/python/pyspark/sql/tests/test_sql_context.py +++ b/python/pyspark/sql/tests/test_sql_context.py @@ -132,10 +132,12 @@ def test_udtf_is_available(self) -> None: self.assertIsNotNone(self._make_ctx().udtf) def test_registerFunction(self) -> None: + from pyspark.sql.types import IntegerType + ctx = self._make_ctx() with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) - ctx.registerFunction("ctx_mixin_double", lambda x: x * 2) + ctx.registerFunction("ctx_mixin_double", lambda x: x * 2, IntegerType()) result = ctx.sql("SELECT ctx_mixin_double(3) AS v").collect()[0].v self.assertEqual(result, 6) From b89f9212fa00899adb94ffe72f7b060d22aa28ba Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 3 Jun 2026 22:50:56 -0700 Subject: [PATCH 21/31] address review feedback (viirya): widen getOrCreate dispatch to is_remote(), narrow scope to getOrCreate, fix test super() chain and docs - getOrCreate now dispatches on is_remote() instead of is_remote_only() so a full PySpark install talking to a Connect server (SPARK_REMOTE) is covered, not just remote-only pyspark-client installs. Test patch targets updated to pyspark.sql.utils.is_remote. - Clarify in the PR description/docs that only getOrCreate() is wired to Connect; the SQLContext(...) constructor stays classic-only. - SQLContextTestsMixin.setUp/tearDown now chain super() so ReusedSQLTestCase / ReusedConnectTestCase cleanup runs. - legacy.rst: add version to the .. deprecated:: 3.0.0 directive. - connect/session.py stop(): only reset the SQLContext cache when the context module was already imported (sys.modules guard), avoiding an unconditional import. Co-authored-by: Isaac --- .../source/reference/pyspark.sql/legacy.rst | 2 +- python/pyspark/sql/connect/session.py | 18 +++++++++++------- python/pyspark/sql/context.py | 4 ++-- .../sql/tests/connect/test_connect_context.py | 8 ++++---- python/pyspark/sql/tests/test_sql_context.py | 2 ++ 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/legacy.rst b/python/docs/source/reference/pyspark.sql/legacy.rst index d7692af2823f7..17db5a437e66b 100644 --- a/python/docs/source/reference/pyspark.sql/legacy.rst +++ b/python/docs/source/reference/pyspark.sql/legacy.rst @@ -25,7 +25,7 @@ Legacy Entry Points As of Spark 2.0, it has been replaced by :class:`SparkSession`. These classes are retained for backward compatibility only. -.. deprecated:: +.. deprecated:: 3.0.0 Use :func:`SparkSession.builder.getOrCreate` instead. SQLContext diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 8f45b865ae057..d588f9d3e2294 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -19,6 +19,7 @@ import json import threading import os +import sys import warnings from collections.abc import Callable, Sized import functools @@ -959,13 +960,16 @@ def stop(self) -> None: if self is getattr(SparkSession._active_session, "session", None): SparkSession._active_session.session = None - from pyspark.sql.connect.context import SQLContext as _ConnectSQLContext - - if ( - _ConnectSQLContext._instantiatedContext is not None - and _ConnectSQLContext._instantiatedContext.sparkSession is self - ): - _ConnectSQLContext._instantiatedContext = None + # Only touch the SQLContext cache if the module was ever imported; if no + # SQLContext was created there is nothing to reset, and we avoid importing it. + _connect_context = sys.modules.get("pyspark.sql.connect.context") + if _connect_context is not None: + _ConnectSQLContext = _connect_context.SQLContext + if ( + _ConnectSQLContext._instantiatedContext is not None + and _ConnectSQLContext._instantiatedContext.sparkSession is self + ): + _ConnectSQLContext._instantiatedContext = None if "SPARK_LOCAL_REMOTE" in os.environ: # When local mode is in use, follow the regular Spark session's diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9747e57b22e73..8b1228df2cea4 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -159,13 +159,13 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> sc : :class:`SparkContext`, optional Required in classic mode; ignored in Spark Connect mode. """ - from pyspark.util import is_remote_only + from pyspark.sql.utils import is_remote warnings.warn( "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", FutureWarning, ) - if is_remote_only(): + if is_remote(): from pyspark.sql.connect import context as _connect_context from pyspark.sql.connect.session import SparkSession as ConnectSparkSession diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index 371effd0e1bd3..bb17d9a752f80 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -52,8 +52,8 @@ def test_hive_context_raises(self) -> None: HiveContext(self.spark) def test_getOrCreate_emits_deprecation_and_returns_connect_context(self) -> None: - """SQLContext.getOrCreate() in remote-only mode returns a Connect-backed context.""" - with patch("pyspark.util.is_remote_only", return_value=True): + """SQLContext.getOrCreate() in Connect mode returns a Connect-backed context.""" + with patch("pyspark.sql.utils.is_remote", return_value=True): with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") ctx = ClassicSQLContext.getOrCreate() @@ -61,8 +61,8 @@ def test_getOrCreate_emits_deprecation_and_returns_connect_context(self) -> None self.assertIsInstance(ctx, SQLContext) def test_hive_context_getOrCreate_raises(self) -> None: - """HiveContext.getOrCreate() in remote-only mode raises PySparkNotImplementedError.""" - with patch("pyspark.util.is_remote_only", return_value=True): + """HiveContext.getOrCreate() in Connect mode raises PySparkNotImplementedError.""" + with patch("pyspark.sql.utils.is_remote", return_value=True): with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) with self.assertRaises(PySparkNotImplementedError): diff --git a/python/pyspark/sql/tests/test_sql_context.py b/python/pyspark/sql/tests/test_sql_context.py index e74432470a8eb..8fcf58a6c8d96 100644 --- a/python/pyspark/sql/tests/test_sql_context.py +++ b/python/pyspark/sql/tests/test_sql_context.py @@ -37,10 +37,12 @@ def _make_ctx(self) -> SQLContext: raise NotImplementedError def setUp(self) -> None: + super().setUp() SQLContext._instantiatedContext = None def tearDown(self) -> None: SQLContext._instantiatedContext = None + super().tearDown() def test_setConf_and_getConf(self) -> None: ctx = self._make_ctx() From a0a9c37fca16e4473282a63ab197469216830f76 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 4 Jun 2026 11:23:41 -0700 Subject: [PATCH 22/31] address review feedback (viirya round 2): replace assert with PySparkValueError, dedupe test class name, doc/test polish - getOrCreate: replace `assert sc is not None` with an explicit PySparkValueError (ARGUMENT_REQUIRED). asserts are stripped under `python -O`, which would skip the guard and fail later with a cryptic AttributeError on sc._jvm. - Rename the new mixin-based classic suite SQLContextTests -> SQLContextClassicTests to avoid the duplicate class name with the existing smoke test in test_context.py. - legacy.rst: note that registerJavaFunction and HiveContext are unsupported under Spark Connect and raise PySparkNotImplementedError. - Add test_newSession_inherits_state to the Connect-specific suite, locking in the documented cloneSession inherited-state behavior (temp views copied into the clone), which intentionally differs from classic newSession. Co-authored-by: Isaac --- .../docs/source/reference/pyspark.sql/legacy.rst | 6 ++++++ python/pyspark/sql/context.py | 12 +++++++++++- .../sql/tests/connect/test_connect_context.py | 14 ++++++++++++++ python/pyspark/sql/tests/test_sql_context.py | 2 +- 4 files changed, 32 insertions(+), 2 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/legacy.rst b/python/docs/source/reference/pyspark.sql/legacy.rst index 17db5a437e66b..aa905dbaa2b80 100644 --- a/python/docs/source/reference/pyspark.sql/legacy.rst +++ b/python/docs/source/reference/pyspark.sql/legacy.rst @@ -28,6 +28,12 @@ These classes are retained for backward compatibility only. .. deprecated:: 3.0.0 Use :func:`SparkSession.builder.getOrCreate` instead. +.. note:: + Under Spark Connect, :meth:`SQLContext.registerJavaFunction` and the whole + :class:`HiveContext` are not supported and raise + :class:`~pyspark.errors.PySparkNotImplementedError`, + since they rely on a JVM ``SparkContext`` that does not exist in Connect mode. + SQLContext ---------- diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 8b1228df2cea4..85505ea18e7a5 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -34,6 +34,7 @@ from pyspark import _NoValue from pyspark._globals import _NoValueType +from pyspark.errors import PySparkValueError from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -177,7 +178,16 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> "SQLContext", connect_cls._get_or_create_from_session(cast(ConnectSparkSession, session)), ) - assert sc is not None, "sc is required in classic (non-Connect) mode" + if sc is None: + # Not an ``assert`` because asserts are stripped under ``python -O``, which + # would skip the guard and fail later with a cryptic AttributeError on sc._jvm. + raise PySparkValueError( + errorClass="ARGUMENT_REQUIRED", + messageParameters={ + "arg_name": "sc", + "condition": "running in classic (non-Connect) mode", + }, + ) return cls._get_or_create(sc) @classmethod diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index bb17d9a752f80..f7ba93a45b4f6 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -60,6 +60,20 @@ def test_getOrCreate_emits_deprecation_and_returns_connect_context(self) -> None self.assertTrue(any(issubclass(w.category, FutureWarning) for w in caught)) self.assertIsInstance(ctx, SQLContext) + def test_newSession_inherits_state(self) -> None: + """Connect newSession() clones the parent's state (e.g. temp views) via + cloneSession(), unlike classic newSession() which returns a fresh session.""" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + ctx = SQLContext(self.spark) + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_clone_view") + try: + ctx2 = ctx.newSession() + # cloneSession copies the parent's temp views into the new independent session. + self.assertIn("ctx_clone_view", ctx2.tableNames()) + finally: + self.spark.catalog.dropTempView("ctx_clone_view") + def test_hive_context_getOrCreate_raises(self) -> None: """HiveContext.getOrCreate() in Connect mode raises PySparkNotImplementedError.""" with patch("pyspark.sql.utils.is_remote", return_value=True): diff --git a/python/pyspark/sql/tests/test_sql_context.py b/python/pyspark/sql/tests/test_sql_context.py index 8fcf58a6c8d96..21d935fceb29f 100644 --- a/python/pyspark/sql/tests/test_sql_context.py +++ b/python/pyspark/sql/tests/test_sql_context.py @@ -155,7 +155,7 @@ def test_createExternalTable(self) -> None: self.spark.sql("DROP TABLE IF EXISTS ctx_mixin_ext_tbl") -class SQLContextTests(SQLContextTestsMixin, ReusedSQLTestCase): +class SQLContextClassicTests(SQLContextTestsMixin, ReusedSQLTestCase): def _make_ctx(self) -> SQLContext: with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) From 4bd0da7cd3ac01b4a76435595de91a77a1cbac6f Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 4 Jun 2026 11:41:32 -0700 Subject: [PATCH 23/31] test cleanup from self-review: drop dead warnings filter, stop cloned Connect session - SQLContextClassicTests._make_ctx: remove the warnings.catch_warnings wrapper. Classic SQLContext.__init__ only warns when sparkSession is None, and a non-None session is passed, so the filter was dead code. - test_newSession_inherits_state: stop the independent server-side session created by newSession()/cloneSession() so it does not leak. Only done in the Connect-only test (safe there); the shared mixin test cannot stop it because classic newSession().stop() would tear down the shared SparkContext. Co-authored-by: Isaac --- python/pyspark/sql/tests/connect/test_connect_context.py | 5 +++++ python/pyspark/sql/tests/test_sql_context.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index f7ba93a45b4f6..c011c5ab94519 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -67,11 +67,16 @@ def test_newSession_inherits_state(self) -> None: warnings.simplefilter("ignore", FutureWarning) ctx = SQLContext(self.spark) self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_clone_view") + ctx2 = None try: ctx2 = ctx.newSession() # cloneSession copies the parent's temp views into the new independent session. self.assertIn("ctx_clone_view", ctx2.tableNames()) finally: + # newSession() clones an independent server-side session; stop it so it does + # not leak. (Safe under Connect: it does not touch the shared session.) + if ctx2 is not None: + ctx2.sparkSession.stop() self.spark.catalog.dropTempView("ctx_clone_view") def test_hive_context_getOrCreate_raises(self) -> None: diff --git a/python/pyspark/sql/tests/test_sql_context.py b/python/pyspark/sql/tests/test_sql_context.py index 21d935fceb29f..24bf53c838c01 100644 --- a/python/pyspark/sql/tests/test_sql_context.py +++ b/python/pyspark/sql/tests/test_sql_context.py @@ -157,9 +157,9 @@ def test_createExternalTable(self) -> None: class SQLContextClassicTests(SQLContextTestsMixin, ReusedSQLTestCase): def _make_ctx(self) -> SQLContext: - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - return SQLContext(self.sc, self.spark) + # Passing a non-None sparkSession means __init__ does not emit a deprecation + # warning, so no warnings filter is needed here. + return SQLContext(self.sc, self.spark) if __name__ == "__main__": From 4ab909dd2ddf2dcc1190fd640a004059cf743db9 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Thu, 4 Jun 2026 12:49:34 -0700 Subject: [PATCH 24/31] fix: do not stop cloned Connect session in test (terminates shared local server) The previous commit added ctx2.sparkSession.stop() in test_newSession_inherits_state to clean up the cloned session. Under SPARK_LOCAL_REMOTE (the test harness), Connect stop() terminates the shared local Connect server, which broke every subsequent test in the pyspark-connect module (exit code 19). Revert the stop(); the cloned session is reclaimed when the test server shuts down, matching test_newSession_returns_distinct_instance. Co-authored-by: Isaac --- python/pyspark/sql/tests/connect/test_connect_context.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index c011c5ab94519..4c4a5c4f2b2f1 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -67,16 +67,15 @@ def test_newSession_inherits_state(self) -> None: warnings.simplefilter("ignore", FutureWarning) ctx = SQLContext(self.spark) self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_clone_view") - ctx2 = None try: ctx2 = ctx.newSession() # cloneSession copies the parent's temp views into the new independent session. self.assertIn("ctx_clone_view", ctx2.tableNames()) finally: - # newSession() clones an independent server-side session; stop it so it does - # not leak. (Safe under Connect: it does not touch the shared session.) - if ctx2 is not None: - ctx2.sparkSession.stop() + # Do not stop ctx2's cloned session: under SPARK_LOCAL_REMOTE (the test + # harness) Connect stop() terminates the shared local Connect server, which + # would break the rest of the suite. The cloned session is cleaned up when + # the test server shuts down. self.spark.catalog.dropTempView("ctx_clone_view") def test_hive_context_getOrCreate_raises(self) -> None: From edfd561d78e2d7177099d94d36a227a282b14b43 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 8 Jun 2026 13:09:38 -0700 Subject: [PATCH 25/31] address review feedback (HyukjinKwon): bump .. deprecated:: directives to 4.3.0 The versionadded directives were bumped to 4.3.0 in earlier rounds, but the .. deprecated:: directives still carried the classic historical versions (3.0.0 / 2.3.0 / 2.0.0), which is contradictory since a method cannot be deprecated before the version it was added in. Bump all three deprecated directives and the two matching runtime FutureWarning strings to 4.3.0. Co-authored-by: Isaac --- python/pyspark/sql/connect/context.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 5b2149ae0fbec..e0197df2e7ff1 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -61,7 +61,7 @@ class SQLContext: it wraps a Connect :class:`SparkSession` directly and does not require a :class:`~pyspark.SparkContext`. - .. deprecated:: 3.0.0 + .. deprecated:: 4.3.0 Use :func:`SparkSession.builder.getOrCreate()` instead. Parameters @@ -74,7 +74,7 @@ class SQLContext: def __init__(self, sparkSession: "SparkSession") -> None: warnings.warn( - "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", + "Deprecated in 4.3.0. Use SparkSession.builder.getOrCreate() instead.", FutureWarning, stacklevel=2, ) @@ -207,10 +207,10 @@ def registerFunction( .. versionadded:: 4.3.0 - .. deprecated:: 2.3.0 + .. deprecated:: 4.3.0 Use :func:`spark.udf.register` instead. """ - warnings.warn("Deprecated in 2.3.0. Use spark.udf.register instead.", FutureWarning) + warnings.warn("Deprecated in 4.3.0. Use spark.udf.register instead.", FutureWarning) return self.sparkSession.udf.register(name, f, returnType) def registerJavaFunction( @@ -446,7 +446,7 @@ def streams(self) -> StreamingQueryManager: class HiveContext(SQLContext): """Not supported in Spark Connect. - .. deprecated:: 2.0.0 + .. deprecated:: 4.3.0 Use SparkSession.builder.enableHiveSupport().getOrCreate(). """ From 9fdf721b79ca41e9d7c40a930a28df322d922e83 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 8 Jun 2026 21:06:05 -0700 Subject: [PATCH 26/31] fix: release cloned Connect session in test to avoid atexit hang test_newSession_inherits_state left the cloned Connect session fully un-released. Its client registers its own atexit _on_exit; once tearDownClass stops the shared local server (SPARK_LOCAL_REMOTE), _on_exit -> _cleanup_ml_cache retries execute_command against the dead server and hangs until the 450s module timeout. Release the cloned server-side session and close its own client channel in finally, without calling SparkSession.stop() (which would terminate the shared local server). After close() the atexit cleanup fails fast. Co-authored-by: Isaac --- .../sql/tests/connect/test_connect_context.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index 4c4a5c4f2b2f1..4862ff403e3c3 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -67,15 +67,26 @@ def test_newSession_inherits_state(self) -> None: warnings.simplefilter("ignore", FutureWarning) ctx = SQLContext(self.spark) self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_clone_view") + ctx2 = None try: ctx2 = ctx.newSession() # cloneSession copies the parent's temp views into the new independent session. self.assertIn("ctx_clone_view", ctx2.tableNames()) finally: - # Do not stop ctx2's cloned session: under SPARK_LOCAL_REMOTE (the test - # harness) Connect stop() terminates the shared local Connect server, which - # would break the rest of the suite. The cloned session is cleaned up when - # the test server shuts down. + if ctx2 is not None: + # Release only the cloned server-side session and close its own client + # channel. We must NOT call ctx2's SparkSession.stop(): under + # SPARK_LOCAL_REMOTE (the test harness) it terminates the shared local + # Connect server, breaking the rest of the suite. Leaving the cloned + # client open is also wrong -- once tearDownClass stops the shared + # server, its atexit _on_exit -> _cleanup_ml_cache retries against the + # dead server and hangs until the test times out. + client = ctx2.sparkSession.client + try: + client.release_session() + except Exception: + pass + client.close() self.spark.catalog.dropTempView("ctx_clone_view") def test_hive_context_getOrCreate_raises(self) -> None: From 327fd8018029b7924383c05817fe40e28028dbed Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 9 Jun 2026 14:15:02 -0700 Subject: [PATCH 27/31] address review feedback (hvanhovell): newSession() returns a fresh session, not a clone Connect SQLContext.newSession() previously delegated to SparkSession.cloneSession(), which copies the source session's configuration, temporary views, registered functions, and catalog state into the new session. That diverges from classic newSession() semantics, which return a fresh session with empty state (separate SQLConf / temp views / UDFs, shared table cache). The root cause was that the Connect SparkSession only exposed cloneSession() and had no equivalent of classic newSession(). Add SparkSession.newSession() to the Connect session, mirroring the existing Scala Connect newSession(): rebuild the client against the same endpoint with the session id cleared so a fresh UUID is generated and the server lazily creates an empty isolated session -- no CloneSession RPC and no state copy. SQLContext.newSession() now delegates to the new SparkSession.newSession(), the docstring is restored to the classic contract, and the Connect test asserts the parent's temp views are not visible in the new session. Co-authored-by: Isaac --- python/pyspark/sql/connect/client/core.py | 25 +++++++++++++++++++ python/pyspark/sql/connect/context.py | 13 +++++----- python/pyspark/sql/connect/session.py | 24 ++++++++++++++++++ .../sql/tests/connect/test_connect_context.py | 18 ++++++------- .../sql/tests/test_connect_compatibility.py | 1 - 5 files changed, 64 insertions(+), 17 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 2dd814612a9a3..26724d82ca0a9 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2633,3 +2633,28 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient": # Ensure the session ID is correctly set from the response new_client._session_id = response.new_session_id return new_client + + def newSession(self) -> "SparkConnectClient": + """ + Create a new client against the same endpoint with a fresh, independent server-side + session that does NOT inherit any state from this client's session. Unlike + :meth:`clone`, no state (SQL configurations, temporary views, registered functions, + catalog state) is copied over, and no server round-trip is made: the new client is + built from a copy of this client's connection configuration with the session ID + cleared, so a fresh session ID is generated and the server lazily creates an empty + isolated session for it. + + Returns + ------- + SparkConnectClient + A new SparkConnectClient instance bound to a fresh, empty session. + """ + # Reuse the same connection configuration (endpoint, channel options, metadata, + # user) but drop the session ID so the constructor generates a fresh UUID. + new_connection = copy.deepcopy(self._builder) + new_connection._params.pop(ChannelBuilder.PARAM_SESSION_ID, None) + return SparkConnectClient( + connection=new_connection, + user_id=self._user_id, + use_reattachable_execute=self._use_reattachable_execute, + ) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index e0197df2e7ff1..323efad0c04c0 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -111,19 +111,18 @@ def _get_or_create_from_session(cls, sparkSession: "SparkSession") -> "SQLContex return cls._instantiatedContext def newSession(self) -> "SQLContext": - """Returns a new SQLContext as a new independent server session cloned from this one, - with the current session's configuration, temporary views, and registered functions - copied into it. + """Returns a new SQLContext as a new session, that has separate SQLConf, + registered temporary views and UDFs, but shared table cache. .. versionadded:: 4.3.0 Notes ----- - Unlike classic :meth:`pyspark.sql.context.SQLContext.newSession`, which returns a fresh - session sharing only the table cache, this uses :meth:`SparkSession.cloneSession` and - inherits the current session's state. + This matches the classic :meth:`pyspark.sql.context.SQLContext.newSession` semantics: + the returned session starts with empty state rather than inheriting this session's + configuration, temporary views, or registered functions. """ - return self._from_session(self.sparkSession.cloneSession()) + return self._from_session(self.sparkSession.newSession()) def setConf(self, key: str, value: Union[bool, int, str]) -> None: """Sets the given Spark SQL configuration property. diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index d588f9d3e2294..9f8d5944da448 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -1321,6 +1321,30 @@ def cloneSession(self, new_session_id: Optional[str] = None) -> "SparkSession": new_session._session_id = cloned_client._session_id return new_session + def newSession(self) -> "SparkSession": + """ + Returns a new :class:`SparkSession` as a new session, that has separate SQLConf, + registered temporary views and UDFs, but shared table cache. + + Unlike :meth:`cloneSession`, the returned session starts with empty state: no + configuration, temporary views, registered functions, or catalog state are copied + over from this session. This mirrors the classic + :meth:`pyspark.sql.session.SparkSession.newSession` semantics. + + .. versionadded:: 4.3.0 + + Returns + ------- + :class:`SparkSession` + A new SparkSession bound to a fresh, independent server-side session. + """ + new_client = self._client.newSession() + # Create a new SparkSession bound to the fresh, independent session directly. + new_session = object.__new__(SparkSession) + new_session._client = new_client + new_session._session_id = new_client._session_id + return new_session + SparkSession.__doc__ = PySparkSession.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index 4862ff403e3c3..a27a55b7ab99a 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -60,24 +60,24 @@ def test_getOrCreate_emits_deprecation_and_returns_connect_context(self) -> None self.assertTrue(any(issubclass(w.category, FutureWarning) for w in caught)) self.assertIsInstance(ctx, SQLContext) - def test_newSession_inherits_state(self) -> None: - """Connect newSession() clones the parent's state (e.g. temp views) via - cloneSession(), unlike classic newSession() which returns a fresh session.""" + def test_newSession_returns_fresh_state(self) -> None: + """Connect newSession() returns a fresh, independent session that does NOT inherit + the parent's state (e.g. temp views), matching classic newSession() semantics.""" with warnings.catch_warnings(): warnings.simplefilter("ignore", FutureWarning) ctx = SQLContext(self.spark) - self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_clone_view") + self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_fresh_view") ctx2 = None try: ctx2 = ctx.newSession() - # cloneSession copies the parent's temp views into the new independent session. - self.assertIn("ctx_clone_view", ctx2.tableNames()) + # newSession() starts with empty state, so the parent's temp view is not visible. + self.assertNotIn("ctx_fresh_view", ctx2.tableNames()) finally: if ctx2 is not None: - # Release only the cloned server-side session and close its own client + # Release only the new server-side session and close its own client # channel. We must NOT call ctx2's SparkSession.stop(): under # SPARK_LOCAL_REMOTE (the test harness) it terminates the shared local - # Connect server, breaking the rest of the suite. Leaving the cloned + # Connect server, breaking the rest of the suite. Leaving the new # client open is also wrong -- once tearDownClass stops the shared # server, its atexit _on_exit -> _cleanup_ml_cache retries against the # dead server and hangs until the test times out. @@ -87,7 +87,7 @@ def test_newSession_inherits_state(self) -> None: except Exception: pass client.close() - self.spark.catalog.dropTempView("ctx_clone_view") + self.spark.catalog.dropTempView("ctx_fresh_view") def test_hive_context_getOrCreate_raises(self) -> None: """HiveContext.getOrCreate() in Connect mode raises PySparkNotImplementedError.""" diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index bd49b1f465482..6c1268f927851 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -264,7 +264,6 @@ def test_spark_session_compatibility(self): expected_missing_connect_methods = { "clearProgressHandlers", "copyFromLocalToFs", - "newSession", "registerProgressHandler", "removeProgressHandler", } From fb8931d969f9fa006f272f998ff333c21763f2b4 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 9 Jun 2026 14:21:40 -0700 Subject: [PATCH 28/31] address review feedback (cloud-fan): namespace parity, em-dash, subclass fallback; drop dead newSession guard - connect/context.py tables(): join the full namespace ("a.b") instead of keeping only the last part, matching classic SHOW TABLES which emits the full quoted namespace. Identical for single-part v1 namespaces; correct under multi-level v2 catalog namespaces. - connect/context.py: replace the em-dash in the module comment with an ASCII hyphen (the only non-ASCII character the PR introduced). - classic context.py getOrCreate(): when a user-defined SQLContext subclass has no Connect counterpart, raise PySparkNotImplementedError instead of silently substituting a base Connect SQLContext that would be missing the subclass's attributes. - connect/session.py: drop the now-dead "newSession" entry from the __getattr__ JVM_ATTRIBUTE_NOT_SUPPORTED guard, since SparkSession.newSession() is now a real method (mirroring Scala Connect's newSession()). Co-authored-by: Isaac --- python/pyspark/sql/connect/context.py | 6 ++++-- python/pyspark/sql/connect/session.py | 2 +- python/pyspark/sql/context.py | 12 ++++++++++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py index 323efad0c04c0..a4bea6419870d 100644 --- a/python/pyspark/sql/connect/context.py +++ b/python/pyspark/sql/connect/context.py @@ -46,7 +46,7 @@ from pyspark.sql.connect.udtf import UDTFRegistration from pyspark.sql._typing import UserDefinedFunctionLike -# Internal module — not part of the public PySpark API surface. +# Internal module - not part of the public PySpark API surface. # The public SQLContext/HiveContext are in pyspark.sql.context; this module # is an implementation detail used by the Connect dispatch in that file. @@ -354,7 +354,9 @@ def tables(self, dbName: Optional[str] = None) -> DataFrame: # (namespace, tableName, isTemporary), matching the classic implementation. # SHOW TABLES returns "database" vs "namespace" depending on the active catalog. rows = [ - (t.namespace[-1] if t.namespace else "", t.name, t.isTemporary) + # Join the full namespace ("a.b") to match classic SHOW TABLES, which emits the + # quoted namespace; keeping only the last part would drop levels under a v2 catalog. + (".".join(t.namespace) if t.namespace else "", t.name, t.isTemporary) for t in self.sparkSession.catalog.listTables(dbName) ] return self.sparkSession.createDataFrame(rows, schema) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 9f8d5944da448..8c99240d84521 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -1012,7 +1012,7 @@ def streams(self) -> "StreamingQueryManager": streams.__doc__ = PySparkSession.streams.__doc__ def __getattr__(self, name: str) -> Any: - if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession", "sparkContext", "newSession"]: + if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession", "sparkContext"]: raise PySparkAttributeError( errorClass="JVM_ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": name} ) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 85505ea18e7a5..01d8eb5583c93 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -34,7 +34,7 @@ from pyspark import _NoValue from pyspark._globals import _NoValueType -from pyspark.errors import PySparkValueError +from pyspark.errors import PySparkNotImplementedError, PySparkValueError from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -173,7 +173,15 @@ def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> session = SparkSession._getActiveSessionOrCreate() # Route to the Connect counterpart so subclasses (e.g. HiveContext) are handled # correctly: the Connect HiveContext._from_session raises PySparkNotImplementedError. - connect_cls = getattr(_connect_context, cls.__name__, _connect_context.SQLContext) + connect_cls = getattr(_connect_context, cls.__name__, None) + if connect_cls is None: + # A user-defined SQLContext subclass has no Connect counterpart. Fail loudly + # instead of silently returning a base Connect SQLContext that would be + # missing the subclass's attributes. + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": f"{cls.__name__}.getOrCreate in Spark Connect"}, + ) return cast( "SQLContext", connect_cls._get_or_create_from_session(cast(ConnectSparkSession, session)), From 5df84a4189fa7ffec31dacc36194a7e59bdc3abd Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 9 Jun 2026 15:13:11 -0700 Subject: [PATCH 29/31] fix: set release_session_on_close on sessions created via object.__new__ newSession() and cloneSession() bypass SparkSession.__init__, so the release_session_on_close attribute read unconditionally by stop() was never set, making stop() raise AttributeError on such sessions. Set it explicitly in both paths and assert it in the regression tests (the local-remote harness cannot safely call stop() on a child session). Co-authored-by: Isaac --- python/pyspark/sql/connect/session.py | 2 ++ .../pyspark/sql/tests/connect/test_connect_clone_session.py | 4 ++++ python/pyspark/sql/tests/connect/test_connect_context.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 8c99240d84521..307eb7ee7fd89 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -1319,6 +1319,7 @@ def cloneSession(self, new_session_id: Optional[str] = None) -> "SparkSession": new_session = object.__new__(SparkSession) new_session._client = cloned_client new_session._session_id = cloned_client._session_id + new_session.release_session_on_close = True return new_session def newSession(self) -> "SparkSession": @@ -1343,6 +1344,7 @@ def newSession(self) -> "SparkSession": new_session = object.__new__(SparkSession) new_session._client = new_client new_session._session_id = new_client._session_id + new_session.release_session_on_close = True return new_session diff --git a/python/pyspark/sql/tests/connect/test_connect_clone_session.py b/python/pyspark/sql/tests/connect/test_connect_clone_session.py index b4c9e4c67c8bb..7cb4009912377 100644 --- a/python/pyspark/sql/tests/connect/test_connect_clone_session.py +++ b/python/pyspark/sql/tests/connect/test_connect_clone_session.py @@ -32,6 +32,10 @@ def test_clone_session_basic(self): # Clone the session cloned_session = self.connect.cloneSession() + # The cloned session bypasses SparkSession.__init__, so make sure it still + # carries the attributes that SparkSession.stop() reads. + self.assertTrue(cloned_session.release_session_on_close) + # Verify the configuration was copied # (if cloning doesn't preserve dynamic configs, use a different approach) cloned_value = cloned_session.sql("SET spark.test.original").collect()[0][1] diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py index a27a55b7ab99a..2857f70ebf2f9 100644 --- a/python/pyspark/sql/tests/connect/test_connect_context.py +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -72,6 +72,9 @@ def test_newSession_returns_fresh_state(self) -> None: ctx2 = ctx.newSession() # newSession() starts with empty state, so the parent's temp view is not visible. self.assertNotIn("ctx_fresh_view", ctx2.tableNames()) + # The new session bypasses SparkSession.__init__, so make sure it still + # carries the attributes that SparkSession.stop() reads. + self.assertTrue(ctx2.sparkSession.release_session_on_close) finally: if ctx2 is not None: # Release only the new server-side session and close its own client From 2b7fd64491f241764d1d7348348c1e7671f9ffd2 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Tue, 9 Jun 2026 15:13:20 -0700 Subject: [PATCH 30/31] fix: release leaked Connect client in parity newSession test In SQLContextParityTests, the inherited test_newSession_returns_distinct_instance created a new SparkConnectClient (with its atexit hook) that was never released or closed, reproducing the hang-after-tearDownClass failure mode documented in test_connect_context.py. Override the test in the Connect parity subclass to release the server-side session and close the client. Co-authored-by: Isaac --- .../tests/connect/test_parity_sql_context.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/pyspark/sql/tests/connect/test_parity_sql_context.py b/python/pyspark/sql/tests/connect/test_parity_sql_context.py index 8d9fd9e950c42..04491a209a91c 100644 --- a/python/pyspark/sql/tests/connect/test_parity_sql_context.py +++ b/python/pyspark/sql/tests/connect/test_parity_sql_context.py @@ -35,6 +35,24 @@ def _make_ctx(self) -> SQLContext: warnings.simplefilter("ignore", FutureWarning) return SQLContext(self.spark) + def test_newSession_returns_distinct_instance(self) -> None: + ctx = self._make_ctx() + ctx2 = ctx.newSession() + try: + self.assertIsNot(ctx, ctx2) + finally: + # Release only the new server-side session and close its own client channel, + # not SparkSession.stop(). See SQLContextConnectTests.test_newSession_returns + # _fresh_state in test_connect_context.py for why: stop() would terminate the + # shared local Connect server, while leaving the client open hangs the suite + # in the client's atexit hook once the server goes away. + client = ctx2.sparkSession.client + try: + client.release_session() + except Exception: + pass + client.close() + if __name__ == "__main__": from pyspark.testing import main From 392a4aeecb9b20359e510e78264de9c092852dfd Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Wed, 10 Jun 2026 18:18:01 -0700 Subject: [PATCH 31/31] address review feedback (Codex): preserve session hooks and RPC deadlines in newSession() SparkConnectClient.newSession() built the fresh client from only connection, user_id, and use_reattachable_execute, silently dropping registered session hooks and custom RPC deadlines, while the adjacent clone() path preserves both. Pass session_hooks and rpc_deadlines through, and add a mock-based test asserting a registered hook still observes ExecutePlanRequests issued through a session created via newSession(). Co-authored-by: Isaac --- python/pyspark/sql/connect/client/core.py | 4 +++ .../sql/tests/connect/client/test_client.py | 34 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 26724d82ca0a9..6e0d4cbcf1ef7 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2653,8 +2653,12 @@ def newSession(self) -> "SparkConnectClient": # user) but drop the session ID so the constructor generates a fresh UUID. new_connection = copy.deepcopy(self._builder) new_connection._params.pop(ChannelBuilder.PARAM_SESSION_ID, None) + # Only server-side session state is left behind: client-side behavior such as + # registered session hooks and RPC deadlines carries over, as in clone(). return SparkConnectClient( connection=new_connection, user_id=self._user_id, use_reattachable_execute=self._use_reattachable_execute, + session_hooks=self._session_hooks, + rpc_deadlines=self._rpc_deadlines, ) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 10580bd52e629..5f1bac730c97a 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -440,6 +440,40 @@ def on_execute_plan(self, req): self.assertEqual(inits, 1) self.assertEqual(calls, 2) + def test_session_hook_preserved_after_new_session(self): + calls = 0 + + class TestHook(RemoteSparkSession.Hook): + def __init__(self, _session): + pass + + def on_execute_plan(self, req): + nonlocal calls + calls += 1 + return req + + # Use create() instead of getOrCreate() to avoid picking up a session (and hooks) + # left active by other tests. + session = RemoteSparkSession.builder.remote("sc://foo")._registerHook(TestHook).create() + new_session = session.newSession() + try: + # Client-side behavior carries over to the fresh session, as in clone(). + self.assertEqual(new_session.client._session_hooks, session.client._session_hooks) + self.assertEqual(new_session.client._rpc_deadlines, session.client._rpc_deadlines) + + new_session.client._stub = MockService(new_session.client._session_id) + new_session.client.disable_reattachable_execute() + + # The hook still observes ExecutePlanRequests issued through the new session. + self.assertEqual(calls, 0) + new_session.range(1).collect() + self.assertEqual(calls, 1) + finally: + # Close the clients so their atexit hooks do not try to release sessions + # against the unreachable endpoint at interpreter shutdown. + new_session.client.close() + session.client.close() + def test_custom_operation_id(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id)