diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 9853e4dfe7a6b..c474b87c8b36f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -572,6 +572,7 @@ def __hash__(self): "pyspark.sql.tests.test_column", "pyspark.sql.tests.test_conf", "pyspark.sql.tests.test_context", + "pyspark.sql.tests.test_sql_context", "pyspark.sql.tests.test_dataframe", "pyspark.sql.tests.test_collection", "pyspark.sql.tests.test_creation", @@ -1165,6 +1166,8 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_geometrytype", "pyspark.sql.tests.connect.test_parity_datasources", "pyspark.sql.tests.connect.test_parity_errors", + "pyspark.sql.tests.connect.test_connect_context", + "pyspark.sql.tests.connect.test_parity_sql_context", "pyspark.sql.tests.connect.test_parity_catalog", "pyspark.sql.tests.connect.test_parity_conf", "pyspark.sql.tests.connect.test_parity_serde", diff --git a/python/docs/source/reference/pyspark.sql/index.rst b/python/docs/source/reference/pyspark.sql/index.rst index 36618af2de2c2..ce4b2de6278b7 100644 --- a/python/docs/source/reference/pyspark.sql/index.rst +++ b/python/docs/source/reference/pyspark.sql/index.rst @@ -45,3 +45,4 @@ This page gives an overview of all public Spark SQL API. protobuf datasource stateful_processor + legacy diff --git a/python/docs/source/reference/pyspark.sql/legacy.rst b/python/docs/source/reference/pyspark.sql/legacy.rst new file mode 100644 index 0000000000000..aa905dbaa2b80 --- /dev/null +++ b/python/docs/source/reference/pyspark.sql/legacy.rst @@ -0,0 +1,78 @@ +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. 
You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + + +==================== +Legacy Entry Points +==================== +.. currentmodule:: pyspark.sql + +:class:`SQLContext` was the primary entry point for Spark SQL in Spark 1.x. +As of Spark 2.0, it has been replaced by :class:`SparkSession`. +:class:`SQLContext` and :class:`HiveContext` are retained for backward compatibility only. + +.. deprecated:: 3.0.0 + Use :func:`SparkSession.builder.getOrCreate` instead. + +.. note:: + Under Spark Connect, :meth:`SQLContext.registerJavaFunction` and the whole + :class:`HiveContext` are not supported and raise + :class:`~pyspark.errors.PySparkNotImplementedError`, + since they rely on a JVM ``SparkContext`` that does not exist in Connect mode. + +SQLContext +---------- + +.. autosummary:: + :toctree: api/ + + SQLContext + +.. autosummary:: + :toctree: api/ + + SQLContext.getOrCreate + SQLContext.newSession + SQLContext.setConf + SQLContext.getConf + SQLContext.udf + SQLContext.udtf + SQLContext.range + SQLContext.registerFunction + SQLContext.registerJavaFunction + SQLContext.createDataFrame + SQLContext.registerDataFrameAsTable + SQLContext.dropTempTable + SQLContext.createExternalTable + SQLContext.sql + SQLContext.table + SQLContext.tables + SQLContext.tableNames + SQLContext.cacheTable + SQLContext.uncacheTable + SQLContext.clearCache + SQLContext.read + SQLContext.readStream + SQLContext.streams + +HiveContext +----------- + +.. 
autosummary:: + :toctree: api/ + + HiveContext diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 2dd814612a9a3..6e0d4cbcf1ef7 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -2633,3 +2633,32 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient": # Ensure the session ID is correctly set from the response new_client._session_id = response.new_session_id return new_client + + def newSession(self) -> "SparkConnectClient": + """ + Create a new client against the same endpoint with a fresh, independent server-side + session that does NOT inherit any state from this client's session. Unlike + :meth:`clone`, no state (SQL configurations, temporary views, registered functions, + catalog state) is copied over, and no server round-trip is made: the new client is + built from a copy of this client's connection configuration with the session ID + cleared, so a fresh session ID is generated and the server lazily creates an empty + isolated session for it. + + Returns + ------- + SparkConnectClient + A new SparkConnectClient instance bound to a fresh, empty session. + """ + # Reuse the same connection configuration (endpoint, channel options, metadata, + # user) but drop the session ID so the constructor generates a fresh UUID. + new_connection = copy.deepcopy(self._builder) + new_connection._params.pop(ChannelBuilder.PARAM_SESSION_ID, None) + # Only server-side session state is left behind: client-side behavior such as + # registered session hooks and RPC deadlines carries over, as in clone(). 
+ return SparkConnectClient( + connection=new_connection, + user_id=self._user_id, + use_reattachable_execute=self._use_reattachable_execute, + session_hooks=self._session_hooks, + rpc_deadlines=self._rpc_deadlines, + ) diff --git a/python/pyspark/sql/connect/context.py b/python/pyspark/sql/connect/context.py new file mode 100644 index 0000000000000..a4bea6419870d --- /dev/null +++ b/python/pyspark/sql/connect/context.py @@ -0,0 +1,469 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import warnings +from typing import ( + Optional, + Union, + Callable, + Any, + Iterable, + List, + Tuple, + ClassVar, + TYPE_CHECKING, +) + +from pyspark import _NoValue +from pyspark._globals import _NoValueType +from pyspark.errors import PySparkNotImplementedError +from pyspark.sql.dataframe import DataFrame +from pyspark.sql.connect.readwriter import DataFrameReader +from pyspark.sql.connect.streaming.readwriter import DataStreamReader +from pyspark.sql.connect.streaming.query import StreamingQueryManager +from pyspark.sql.types import AtomicType, BooleanType, DataType, StringType, StructField, StructType + +if TYPE_CHECKING: + import numpy as np + import pandas as pd + import pyarrow as pa + from pyspark.sql.connect.session import SparkSession + from pyspark.sql.connect.udf import UDFRegistration + from pyspark.sql.connect.udtf import UDTFRegistration + from pyspark.sql._typing import UserDefinedFunctionLike + +# Internal module - not part of the public PySpark API surface. +# The public SQLContext/HiveContext are in pyspark.sql.context; this module +# is an implementation detail used by the Connect dispatch in that file. + + +class SQLContext: + """The entry point for working with structured data (rows and columns) in Spark, in Spark 1.x. + + As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class + here for backward compatibility. + + This is the Spark Connect-compatible implementation. Unlike the classic implementation, + it wraps a Connect :class:`SparkSession` directly and does not require a + :class:`~pyspark.SparkContext`. + + .. deprecated:: 4.3.0 + Use :func:`SparkSession.builder.getOrCreate()` instead. + + Parameters + ---------- + sparkSession : :class:`SparkSession` + The Connect :class:`SparkSession` to wrap. + """ + + _instantiatedContext: ClassVar[Optional["SQLContext"]] = None + + def __init__(self, sparkSession: "SparkSession") -> None: + warnings.warn( + "Deprecated in 4.3.0. 
Use SparkSession.builder.getOrCreate() instead.", + FutureWarning, + stacklevel=2, + ) + self.sparkSession = sparkSession + if type(self)._instantiatedContext is None: + type(self)._instantiatedContext = self + + @classmethod + def _from_session(cls, sparkSession: "SparkSession") -> "SQLContext": + """Create a new instance without emitting a deprecation warning.""" + ctx = object.__new__(cls) + ctx.sparkSession = sparkSession + return ctx + + @classmethod + def _get_or_create_from_session(cls, sparkSession: "SparkSession") -> "SQLContext": + """Return the cached instance or create one from an active Connect SparkSession. + + Called by the classic :meth:`pyspark.sql.context.SQLContext.getOrCreate` when + running in Spark Connect mode, so users do not need to import from + ``pyspark.sql.connect`` directly. + + Unlike the classic path (which checks ``_sc._jsc is None`` to detect a dead + SparkContext), Connect sessions have no JVM lifecycle sentinel. Instead we + re-create whenever the incoming ``sparkSession`` is not the same object as the + one stored in the cached context, which handles the case where the previous + session was stopped and a new one started. + """ + if ( + cls._instantiatedContext is None + or cls._instantiatedContext.sparkSession is not sparkSession + ): + cls._instantiatedContext = cls._from_session(sparkSession) + return cls._instantiatedContext + + def newSession(self) -> "SQLContext": + """Returns a new SQLContext as a new session, that has separate SQLConf, + registered temporary views and UDFs, but shared table cache. + + .. versionadded:: 4.3.0 + + Notes + ----- + This matches the classic :meth:`pyspark.sql.context.SQLContext.newSession` semantics: + the returned session starts with empty state rather than inheriting this session's + configuration, temporary views, or registered functions. 
+ """ + return self._from_session(self.sparkSession.newSession()) + + def setConf(self, key: str, value: Union[bool, int, str]) -> None: + """Sets the given Spark SQL configuration property. + + .. versionadded:: 4.3.0 + """ + self.sparkSession.conf.set(key, value) + + def getConf( + self, key: str, defaultValue: Union[Optional[str], _NoValueType] = _NoValue + ) -> Optional[str]: + """Returns the value of Spark SQL configuration property for the given key. + + If the key is not set and defaultValue is set, return + defaultValue. If the key is not set and defaultValue is not set, return + the system default value. + + .. versionadded:: 4.3.0 + """ + return self.sparkSession.conf.get(key, defaultValue) + + @property + def udf(self) -> "UDFRegistration": + """Returns a :class:`UDFRegistration` for UDF registration. + + .. versionadded:: 4.3.0 + + Returns + ------- + :class:`UDFRegistration` + """ + return self.sparkSession.udf + + @property + def udtf(self) -> "UDTFRegistration": + """Returns a :class:`UDTFRegistration` for UDTF registration. + + .. versionadded:: 4.3.0 + + Returns + ------- + :class:`UDTFRegistration` + """ + return self.sparkSession.udtf + + def range( + self, + start: int, + end: Optional[int] = None, + step: int = 1, + numPartitions: Optional[int] = None, + ) -> DataFrame: + """Create a :class:`DataFrame` with single :class:`~pyspark.sql.types.LongType` column + named ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with + step value ``step``. + + .. 
versionadded:: 4.3.0 + + Parameters + ---------- + start : int + the start value + end : int, optional + the end value (exclusive) + step : int, optional + the incremental step (default: 1) + numPartitions : int, optional + the number of partitions of the DataFrame + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.range(start, end, step, numPartitions) + + def registerFunction( + self, name: str, f: Callable[..., Any], returnType: Optional[DataType] = None + ) -> "UserDefinedFunctionLike": + """An alias for :func:`spark.udf.register`. + See :meth:`pyspark.sql.UDFRegistration.register`. + + .. versionadded:: 4.3.0 + + .. deprecated:: 4.3.0 + Use :func:`spark.udf.register` instead. + """ + warnings.warn("Deprecated in 4.3.0. Use spark.udf.register instead.", FutureWarning) + return self.sparkSession.udf.register(name, f, returnType) + + def registerJavaFunction( + self, name: str, javaClassName: str, returnType: Optional[DataType] = None + ) -> None: + """Not supported in Spark Connect. + + .. versionadded:: 4.3.0 + """ + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "registerJavaFunction"}, + ) + + def createDataFrame( + self, + data: Union["pd.DataFrame", "np.ndarray", "pa.Table", Iterable[Any]], + schema: Optional[Union[AtomicType, StructType, str, List[str], Tuple[str, ...]]] = None, + samplingRatio: Optional[float] = None, + verifySchema: Optional[bool] = None, + ) -> DataFrame: + """Creates a :class:`DataFrame` from an iterable, a :class:`pandas.DataFrame`, + or a :class:`pyarrow.Table`. + + .. versionadded:: 4.3.0 + + Parameters + ---------- + data : iterable + an iterable of any kind of SQL data representation (:class:`Row`, + :class:`tuple`, ``int``, ``boolean``, etc.), :class:`list`, + :class:`pandas.DataFrame`, or :class:`pyarrow.Table`. 
+ schema : :class:`~pyspark.sql.types.DataType`, str or list, optional + a :class:`~pyspark.sql.types.DataType` or a datatype string or a list/tuple of + column names. + samplingRatio : float, optional + the sample ratio of rows used for inferring the schema. + verifySchema : bool, optional + verify data types of every row against schema. + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema) + + def registerDataFrameAsTable(self, df: DataFrame, tableName: str) -> None: + """Registers the given :class:`DataFrame` as a temporary table in the catalog. + + Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`. + + .. versionadded:: 4.3.0 + """ + df.createOrReplaceTempView(tableName) + + def dropTempTable(self, tableName: str) -> None: + """Remove the temporary table from catalog. + + .. versionadded:: 4.3.0 + """ + self.sparkSession.catalog.dropTempView(tableName) + + def createExternalTable( + self, + tableName: str, + path: Optional[str] = None, + source: Optional[str] = None, + schema: Optional[StructType] = None, + **options: str, + ) -> DataFrame: + """Creates an external table based on the dataset in a data source. + + It returns the DataFrame associated with the external table. + + The data source is specified by the ``source`` and a set of ``options``. + If ``source`` is not specified, the default data source configured by + ``spark.sql.sources.default`` will be used. + + Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and + created external table. + + .. versionadded:: 4.3.0 + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.catalog.createExternalTable( + tableName, path, source, schema, **options + ) + + def sql(self, sqlQuery: str) -> DataFrame: + """Returns a :class:`DataFrame` representing the result of the given query. + + .. 
versionadded:: 4.3.0 + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.sql(sqlQuery) + + def table(self, tableName: str) -> DataFrame: + """Returns the specified table or view as a :class:`DataFrame`. + + .. versionadded:: 4.3.0 + + Returns + ------- + :class:`DataFrame` + """ + return self.sparkSession.table(tableName) + + def tables(self, dbName: Optional[str] = None) -> DataFrame: + """Returns a :class:`DataFrame` containing names of tables in the given database. + + If ``dbName`` is not specified, the current database will be used. + + The returned DataFrame has three columns: ``namespace``, ``tableName`` and + ``isTemporary`` (a column with :class:`~pyspark.sql.types.BooleanType` indicating if a + table is a temporary one or not). + + .. versionadded:: 4.3.0 + + Parameters + ---------- + dbName: str, optional + name of the database to use. + + Returns + ------- + :class:`DataFrame` + """ + schema = StructType( + [ + StructField("namespace", StringType(), nullable=True), + StructField("tableName", StringType(), nullable=True), + StructField("isTemporary", BooleanType(), nullable=False), + ] + ) + # Use catalog.listTables() rather than SHOW TABLES so the column names are always + # (namespace, tableName, isTemporary), matching the classic implementation. + # SHOW TABLES returns "database" vs "namespace" depending on the active catalog. + rows = [ + # Join the full namespace ("a.b") to match classic SHOW TABLES, which emits the + # quoted namespace; keeping only the last part would drop levels under a v2 catalog. + (".".join(t.namespace) if t.namespace else "", t.name, t.isTemporary) + for t in self.sparkSession.catalog.listTables(dbName) + ] + return self.sparkSession.createDataFrame(rows, schema) + + def tableNames(self, dbName: Optional[str] = None) -> List[str]: + """Returns a list of names of tables in the database ``dbName``. + + .. versionadded:: 4.3.0 + + Parameters + ---------- + dbName: str + name of the database to use. 
Default to the current database. + + Returns + ------- + list + list of table names as strings + """ + return [t.name for t in self.sparkSession.catalog.listTables(dbName)] + + def cacheTable(self, tableName: str) -> None: + """Caches the specified table in-memory. + + .. versionadded:: 4.3.0 + """ + self.sparkSession.catalog.cacheTable(tableName) + + def uncacheTable(self, tableName: str) -> None: + """Removes the specified table from the in-memory cache. + + .. versionadded:: 4.3.0 + """ + self.sparkSession.catalog.uncacheTable(tableName) + + def clearCache(self) -> None: + """Removes all cached tables from the in-memory cache. + + .. versionadded:: 4.3.0 + """ + self.sparkSession.catalog.clearCache() + + @property + def read(self) -> DataFrameReader: + """Returns a :class:`DataFrameReader` that can be used to read data + in as a :class:`DataFrame`. + + .. versionadded:: 4.3.0 + + Returns + ------- + :class:`DataFrameReader` + """ + return self.sparkSession.read + + @property + def readStream(self) -> DataStreamReader: + """Returns a :class:`DataStreamReader` that can be used to read data streams + as a streaming :class:`DataFrame`. + + .. versionadded:: 4.3.0 + + Notes + ----- + This API is evolving. + + Returns + ------- + :class:`DataStreamReader` + """ + return self.sparkSession.readStream + + @property + def streams(self) -> StreamingQueryManager: + """Returns a :class:`StreamingQueryManager` that allows managing all the + :class:`~pyspark.sql.streaming.StreamingQuery` instances active on this + context. + + .. versionadded:: 4.3.0 + + Notes + ----- + This API is evolving. + """ + return self.sparkSession.streams + + +class HiveContext(SQLContext): + """Not supported in Spark Connect. + + .. deprecated:: 4.3.0 + Use SparkSession.builder.enableHiveSupport().getOrCreate(). + """ + + # Override to prevent inheriting SQLContext's cached instance, which would cause + # _get_or_create_from_session to skip _from_session and return the wrong type. 
+ _instantiatedContext: ClassVar[Optional["SQLContext"]] = None + + def __init__(self, sparkSession: "SparkSession") -> None: + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "HiveContext"}, + ) + + @classmethod + def _from_session(cls, sparkSession: "SparkSession") -> "SQLContext": + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": "HiveContext"}, + ) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index d538e427e51c2..307eb7ee7fd89 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -19,6 +19,7 @@ import json import threading import os +import sys import warnings from collections.abc import Callable, Sized import functools @@ -959,6 +960,17 @@ def stop(self) -> None: if self is getattr(SparkSession._active_session, "session", None): SparkSession._active_session.session = None + # Only touch the SQLContext cache if the module was ever imported; if no + # SQLContext was created there is nothing to reset, and we avoid importing it. 
+ _connect_context = sys.modules.get("pyspark.sql.connect.context") + if _connect_context is not None: + _ConnectSQLContext = _connect_context.SQLContext + if ( + _ConnectSQLContext._instantiatedContext is not None + and _ConnectSQLContext._instantiatedContext.sparkSession is self + ): + _ConnectSQLContext._instantiatedContext = None + if "SPARK_LOCAL_REMOTE" in os.environ: # When local mode is in use, follow the regular Spark session's # behavior by terminating the Spark Connect server, @@ -1000,7 +1012,7 @@ def streams(self) -> "StreamingQueryManager": streams.__doc__ = PySparkSession.streams.__doc__ def __getattr__(self, name: str) -> Any: - if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession", "sparkContext", "newSession"]: + if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession", "sparkContext"]: raise PySparkAttributeError( errorClass="JVM_ATTRIBUTE_NOT_SUPPORTED", messageParameters={"attr_name": name} ) @@ -1307,6 +1319,32 @@ def cloneSession(self, new_session_id: Optional[str] = None) -> "SparkSession": new_session = object.__new__(SparkSession) new_session._client = cloned_client new_session._session_id = cloned_client._session_id + new_session.release_session_on_close = True + return new_session + + def newSession(self) -> "SparkSession": + """ + Returns a new :class:`SparkSession` as a new session, that has separate SQLConf, + registered temporary views and UDFs, but shared table cache. + + Unlike :meth:`cloneSession`, the returned session starts with empty state: no + configuration, temporary views, registered functions, or catalog state are copied + over from this session. This mirrors the classic + :meth:`pyspark.sql.session.SparkSession.newSession` semantics. + + .. versionadded:: 4.3.0 + + Returns + ------- + :class:`SparkSession` + A new SparkSession bound to a fresh, independent server-side session. + """ + new_client = self._client.newSession() + # Create a new SparkSession bound to the fresh, independent session directly. 
+ new_session = object.__new__(SparkSession) + new_session._client = new_client + new_session._session_id = new_client._session_id + new_session.release_session_on_close = True return new_session diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7145b27f2cf3c..01d8eb5583c93 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -34,6 +34,7 @@ from pyspark import _NoValue from pyspark._globals import _NoValueType +from pyspark.errors import PySparkNotImplementedError, PySparkValueError from pyspark.sql.session import _monkey_patch_RDD, SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader @@ -142,10 +143,13 @@ def _ssql_ctx(self) -> "JavaObject": return self._jsqlContext @classmethod - def getOrCreate(cls: Type["SQLContext"], sc: "SparkContext") -> "SQLContext": + def getOrCreate(cls: Type["SQLContext"], sc: Optional["SparkContext"] = None) -> "SQLContext": """ Get the existing SQLContext or create a new one with given SparkContext. + When running in Spark Connect mode, ``sc`` is not required and the active + :class:`SparkSession` is used automatically. + .. versionadded:: 1.6.0 .. deprecated:: 3.0.0 @@ -153,12 +157,45 @@ def getOrCreate(cls: Type["SQLContext"], sc: "SparkContext") -> "SQLContext": Parameters ---------- - sc : :class:`SparkContext` + sc : :class:`SparkContext`, optional + Required in classic mode; ignored in Spark Connect mode. """ + from pyspark.sql.utils import is_remote + warnings.warn( "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.", FutureWarning, ) + if is_remote(): + from pyspark.sql.connect import context as _connect_context + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + + session = SparkSession._getActiveSessionOrCreate() + # Route to the Connect counterpart so subclasses (e.g. 
HiveContext) are handled + # correctly: the Connect HiveContext._from_session raises PySparkNotImplementedError. + connect_cls = getattr(_connect_context, cls.__name__, None) + if connect_cls is None: + # A user-defined SQLContext subclass has no Connect counterpart. Fail loudly + # instead of silently returning a base Connect SQLContext that would be + # missing the subclass's attributes. + raise PySparkNotImplementedError( + errorClass="NOT_IMPLEMENTED", + messageParameters={"feature": f"{cls.__name__}.getOrCreate in Spark Connect"}, + ) + return cast( + "SQLContext", + connect_cls._get_or_create_from_session(cast(ConnectSparkSession, session)), + ) + if sc is None: + # Not an ``assert`` because asserts are stripped under ``python -O``, which + # would skip the guard and fail later with a cryptic AttributeError on sc._jvm. + raise PySparkValueError( + errorClass="ARGUMENT_REQUIRED", + messageParameters={ + "arg_name": "sc", + "condition": "running in classic (non-Connect) mode", + }, + ) return cls._get_or_create(sc) @classmethod diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 10580bd52e629..5f1bac730c97a 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -440,6 +440,40 @@ def on_execute_plan(self, req): self.assertEqual(inits, 1) self.assertEqual(calls, 2) + def test_session_hook_preserved_after_new_session(self): + calls = 0 + + class TestHook(RemoteSparkSession.Hook): + def __init__(self, _session): + pass + + def on_execute_plan(self, req): + nonlocal calls + calls += 1 + return req + + # Use create() instead of getOrCreate() to avoid picking up a session (and hooks) + # left active by other tests. 
+ session = RemoteSparkSession.builder.remote("sc://foo")._registerHook(TestHook).create() + new_session = session.newSession() + try: + # Client-side behavior carries over to the fresh session, as in clone(). + self.assertEqual(new_session.client._session_hooks, session.client._session_hooks) + self.assertEqual(new_session.client._rpc_deadlines, session.client._rpc_deadlines) + + new_session.client._stub = MockService(new_session.client._session_id) + new_session.client.disable_reattachable_execute() + + # The hook still observes ExecutePlanRequests issued through the new session. + self.assertEqual(calls, 0) + new_session.range(1).collect() + self.assertEqual(calls, 1) + finally: + # Close the clients so their atexit hooks do not try to release sessions + # against the unreachable endpoint at interpreter shutdown. + new_session.client.close() + session.client.close() + def test_custom_operation_id(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) diff --git a/python/pyspark/sql/tests/connect/test_connect_clone_session.py b/python/pyspark/sql/tests/connect/test_connect_clone_session.py index b4c9e4c67c8bb..7cb4009912377 100644 --- a/python/pyspark/sql/tests/connect/test_connect_clone_session.py +++ b/python/pyspark/sql/tests/connect/test_connect_clone_session.py @@ -32,6 +32,10 @@ def test_clone_session_basic(self): # Clone the session cloned_session = self.connect.cloneSession() + # The cloned session bypasses SparkSession.__init__, so make sure it still + # carries the attributes that SparkSession.stop() reads. 
+ self.assertTrue(cloned_session.release_session_on_close) + # Verify the configuration was copied # (if cloning doesn't preserve dynamic configs, use a different approach) cloned_value = cloned_session.sql("SET spark.test.original").collect()[0][1] diff --git a/python/pyspark/sql/tests/connect/test_connect_context.py b/python/pyspark/sql/tests/connect/test_connect_context.py new file mode 100644 index 0000000000000..2857f70ebf2f9 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_connect_context.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#

import warnings
from unittest.mock import patch

from pyspark.errors import PySparkNotImplementedError
from pyspark.sql import HiveContext as ClassicHiveContext
from pyspark.sql import SQLContext as ClassicSQLContext
from pyspark.sql.connect.context import HiveContext, SQLContext
from pyspark.testing.connectutils import ReusedConnectTestCase


class SQLContextConnectTests(ReusedConnectTestCase):
    """Connect-specific SQLContext tests not covered by the parity mixin."""

    def setUp(self) -> None:
        """Clear the Connect ``SQLContext._instantiatedContext`` slot before each test."""
        super().setUp()
        # NOTE(review): presumably this is the cached instance that getOrCreate()
        # reuses -- confirm in pyspark/sql/connect/context.py.
        SQLContext._instantiatedContext = None

    def tearDown(self) -> None:
        """Clear the Connect ``SQLContext._instantiatedContext`` slot after each test."""
        super().tearDown()
        SQLContext._instantiatedContext = None

    def test_init_emits_deprecation_warning(self) -> None:
        """Constructing a Connect SQLContext emits a FutureWarning."""
        with self.assertWarns(FutureWarning):
            SQLContext(self.spark)

    def test_registerJavaFunction_raises(self) -> None:
        """registerJavaFunction() raises PySparkNotImplementedError on a Connect context."""
        with warnings.catch_warnings():
            # The constructor's FutureWarning is not what this test is about.
            warnings.simplefilter("ignore", FutureWarning)
            ctx = SQLContext(self.spark)
        with self.assertRaises(PySparkNotImplementedError):
            ctx.registerJavaFunction("f", "com.example.F")

    def test_hive_context_raises(self) -> None:
        """Constructing a Connect HiveContext raises PySparkNotImplementedError."""
        with self.assertRaises(PySparkNotImplementedError):
            HiveContext(self.spark)

    def test_getOrCreate_emits_deprecation_and_returns_connect_context(self) -> None:
        """SQLContext.getOrCreate() in Connect mode returns a Connect-backed context."""
        with patch("pyspark.sql.utils.is_remote", return_value=True):
            with warnings.catch_warnings(record=True) as caught:
                # "always" defeats the once-per-location filter so the warning is
                # recorded even if an earlier test already triggered it.
                warnings.simplefilter("always")
                ctx = ClassicSQLContext.getOrCreate()
        self.assertTrue(any(issubclass(w.category, FutureWarning) for w in caught))
        self.assertIsInstance(ctx, SQLContext)

    def test_newSession_returns_fresh_state(self) -> None:
        """Connect newSession() returns a fresh, independent session that does NOT inherit
        the parent's state (e.g. temp views), matching classic newSession() semantics."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", FutureWarning)
            ctx = SQLContext(self.spark)
        self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_fresh_view")
        try:
            ctx2 = ctx.newSession()
            try:
                # newSession() starts with empty state, so the parent's temp view is
                # not visible.
                self.assertNotIn("ctx_fresh_view", ctx2.tableNames())
                # The new session bypasses SparkSession.__init__, so make sure it still
                # carries the attributes that SparkSession.stop() reads.
                self.assertTrue(ctx2.sparkSession.release_session_on_close)
            finally:
                # Release only the new server-side session and close its own client
                # channel. We must NOT call ctx2's SparkSession.stop(): under
                # SPARK_LOCAL_REMOTE (the test harness) it terminates the shared local
                # Connect server, breaking the rest of the suite. Leaving the new
                # client open is also wrong -- once tearDownClass stops the shared
                # server, its atexit _on_exit -> _cleanup_ml_cache retries against the
                # dead server and hangs until the test times out.
                client = ctx2.sparkSession.client
                try:
                    client.release_session()
                except Exception:
                    # Deliberate best effort: a failed release must not mask the
                    # test's own outcome.
                    pass
                finally:
                    client.close()
        finally:
            # Outer finally: the view lives in the shared session, so it must be
            # dropped even when newSession() or the client cleanup above raises.
            self.spark.catalog.dropTempView("ctx_fresh_view")

    def test_hive_context_getOrCreate_raises(self) -> None:
        """HiveContext.getOrCreate() in Connect mode raises PySparkNotImplementedError."""
        with patch("pyspark.sql.utils.is_remote", return_value=True):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", FutureWarning)
                with self.assertRaises(PySparkNotImplementedError):
                    ClassicHiveContext.getOrCreate()


if __name__ == "__main__":
    from pyspark.testing import main

    main()

# Next file in this patch (new file):
#   python/pyspark/sql/tests/connect/test_parity_sql_context.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings

from pyspark.sql.connect.context import SQLContext
from pyspark.sql.tests.test_sql_context import SQLContextTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class SQLContextParityTests(SQLContextTestsMixin, ReusedConnectTestCase):
    """Runs the shared SQLContext mixin tests against the Connect ``SQLContext``."""

    def setUp(self) -> None:
        """Run the inherited setUp chain, then clear the Connect context slot."""
        super().setUp()
        SQLContext._instantiatedContext = None

    def tearDown(self) -> None:
        """Run the inherited tearDown chain, then clear the Connect context slot."""
        super().tearDown()
        SQLContext._instantiatedContext = None

    def _make_ctx(self) -> SQLContext:
        """Build a Connect SQLContext over ``self.spark`` with FutureWarning silenced."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", FutureWarning)
            connect_ctx = SQLContext(self.spark)
        return connect_ctx

    def test_newSession_returns_distinct_instance(self) -> None:
        """newSession() hands back a different context object than its parent."""
        parent_ctx = self._make_ctx()
        child_ctx = parent_ctx.newSession()
        try:
            self.assertIsNot(parent_ctx, child_ctx)
        finally:
            # Release only the new server-side session and close its own client
            # channel rather than calling SparkSession.stop(): stop() would
            # terminate the shared local Connect server, while leaving the client
            # open hangs the suite in the client's atexit hook once the server
            # goes away.
            child_client = child_ctx.sparkSession.client
            try:
                child_client.release_session()
            except Exception:
                pass
            child_client.close()


if __name__ == "__main__":
    from pyspark.testing import main

    main()

# Also in this patch: python/pyspark/sql/tests/test_connect_compatibility.py
# removes "newSession" from ``expected_missing_connect_methods`` in
# ``test_spark_session_compatibility``; the set keeps "clearProgressHandlers",
# "copyFromLocalToFs", "registerProgressHandler" and "removeProgressHandler".

# Next file in this patch (new file):
#   python/pyspark/sql/tests/test_sql_context.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import tempfile
import warnings

from pyspark import SQLContext
from pyspark.sql import SparkSession
from pyspark.testing.sqlutils import ReusedSQLTestCase


class SQLContextTestsMixin:
    """Tests for SQLContext that run in both classic and Connect modes.

    Subclasses must implement :meth:`_make_ctx` to return an :class:`SQLContext`
    appropriate for their mode, and must expose ``self.spark`` as a
    :class:`SparkSession`.
    """

    spark: SparkSession

    def _make_ctx(self) -> SQLContext:
        """Return the SQLContext under test; concrete test classes override this."""
        raise NotImplementedError

    def setUp(self) -> None:
        """Run the next setUp in the MRO, then clear ``SQLContext._instantiatedContext``."""
        super().setUp()
        SQLContext._instantiatedContext = None

    def tearDown(self) -> None:
        """Clear ``SQLContext._instantiatedContext``, then run the next tearDown in the MRO."""
        SQLContext._instantiatedContext = None
        super().tearDown()

    def test_setConf_and_getConf(self) -> None:
        """setConf() is visible through getConf(), with and without a default."""
        key = "spark.sql.shuffle.partitions"
        ctx = self._make_ctx()
        # Remember whatever the shared session had so the restore below does not
        # depend on a hard-coded value.
        original = ctx.getConf(key)
        try:
            ctx.setConf(key, "42")
            self.assertEqual(ctx.getConf(key), "42")
            # An explicit default must not override a value that is set.
            self.assertEqual(ctx.getConf(key, "99"), "42")
        finally:
            # Restore even when an assertion fails; the session is shared by the
            # other tests in the class.
            ctx.setConf(key, original)

    def test_range(self) -> None:
        """range(3) yields three rows."""
        ctx = self._make_ctx()
        self.assertEqual(ctx.range(3).count(), 3)

    def test_createDataFrame(self) -> None:
        """createDataFrame() builds a DataFrame from local rows."""
        ctx = self._make_ctx()
        df = ctx.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
        self.assertEqual(df.count(), 2)

    def test_sql(self) -> None:
        """sql() executes a query and returns its rows."""
        ctx = self._make_ctx()
        self.assertEqual(ctx.sql("SELECT 1 AS n").collect()[0].n, 1)

    def test_registerDataFrameAsTable_and_sql(self) -> None:
        """A DataFrame registered as a table can be queried through sql()."""
        ctx = self._make_ctx()
        df = self.spark.createDataFrame([(1, "r1"), (2, "r2")], ["field1", "field2"])
        ctx.registerDataFrameAsTable(df, "ctx_mixin_tbl")
        try:
            result = ctx.sql("SELECT field1 FROM ctx_mixin_tbl ORDER BY field1").collect()
            self.assertEqual([r.field1 for r in result], [1, 2])
        finally:
            ctx.dropTempTable("ctx_mixin_tbl")

    def test_table(self) -> None:
        """table() resolves a temp view created on the underlying session."""
        ctx = self._make_ctx()
        self.spark.createDataFrame([(42,)], ["v"]).createOrReplaceTempView("ctx_mixin_view")
        try:
            self.assertEqual(ctx.table("ctx_mixin_view").collect()[0].v, 42)
        finally:
            self.spark.catalog.dropTempView("ctx_mixin_view")

    def test_tables_contains_registered_view(self) -> None:
        """tables() lists a temp view created on the underlying session."""
        ctx = self._make_ctx()
        self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_tables_view")
        try:
            names = [r.tableName for r in ctx.tables().collect()]
            self.assertIn("ctx_mixin_tables_view", names)
        finally:
            self.spark.catalog.dropTempView("ctx_mixin_tables_view")

    def test_tableNames_contains_registered_view(self) -> None:
        """tableNames() lists a temp view created on the underlying session."""
        ctx = self._make_ctx()
        self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView(
            "ctx_mixin_tablenames_view"
        )
        try:
            self.assertIn("ctx_mixin_tablenames_view", ctx.tableNames())
        finally:
            self.spark.catalog.dropTempView("ctx_mixin_tablenames_view")

    def test_cacheTable_and_uncacheTable(self) -> None:
        """cacheTable() and uncacheTable() both complete on an existing temp view."""
        ctx = self._make_ctx()
        self.spark.createDataFrame([(1,)], ["x"]).createOrReplaceTempView("ctx_mixin_cache_view")
        try:
            ctx.cacheTable("ctx_mixin_cache_view")
            ctx.uncacheTable("ctx_mixin_cache_view")
        finally:
            self.spark.catalog.dropTempView("ctx_mixin_cache_view")

    def test_clearCache(self) -> None:
        """clearCache() completes without raising."""
        self._make_ctx().clearCache()

    def test_newSession_returns_distinct_instance(self) -> None:
        """newSession() hands back a different context object than its parent."""
        ctx = self._make_ctx()
        ctx2 = ctx.newSession()
        self.assertIsNot(ctx, ctx2)

    def test_read_is_available(self) -> None:
        """The ``read`` attribute is not None."""
        self.assertIsNotNone(self._make_ctx().read)

    def test_readStream_is_available(self) -> None:
        """The ``readStream`` attribute is not None."""
        self.assertIsNotNone(self._make_ctx().readStream)

    def test_streams_is_available(self) -> None:
        """The ``streams`` attribute is not None."""
        self.assertIsNotNone(self._make_ctx().streams)

    def test_udf_is_available(self) -> None:
        """The ``udf`` attribute is not None."""
        self.assertIsNotNone(self._make_ctx().udf)

    def test_udtf_is_available(self) -> None:
        """The ``udtf`` attribute is not None."""
        self.assertIsNotNone(self._make_ctx().udtf)

    def test_registerFunction(self) -> None:
        """A Python function registered via registerFunction() is callable from SQL."""
        from pyspark.sql.types import IntegerType

        ctx = self._make_ctx()
        with warnings.catch_warnings():
            # FutureWarning raised by the registration call is not under test here.
            warnings.simplefilter("ignore", FutureWarning)
            ctx.registerFunction("ctx_mixin_double", lambda x: x * 2, IntegerType())
        result = ctx.sql("SELECT ctx_mixin_double(3) AS v").collect()[0].v
        self.assertEqual(result, 6)

    def test_createExternalTable(self) -> None:
        """createExternalTable() exposes parquet files at a path as a table."""
        ctx = self._make_ctx()
        with tempfile.TemporaryDirectory() as tmp:
            path = os.path.join(tmp, "data")
            self.spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"]).write.parquet(path)
            try:
                # Created inside the try so the DROP below also runs when table
                # creation itself fails part-way; IF EXISTS makes that safe.
                df = ctx.createExternalTable("ctx_mixin_ext_tbl", path, "parquet")
                self.assertEqual(df.count(), 2)
            finally:
                self.spark.sql("DROP TABLE IF EXISTS ctx_mixin_ext_tbl")


class SQLContextClassicTests(SQLContextTestsMixin, ReusedSQLTestCase):
    """Runs the shared SQLContext mixin tests against the classic ``SQLContext``."""

    def _make_ctx(self) -> SQLContext:
        """Build a classic SQLContext from ``self.sc`` and ``self.spark``."""
        # Passing a non-None sparkSession means __init__ does not emit a deprecation
        # warning, so no warnings filter is needed here.
        return SQLContext(self.sc, self.spark)


if __name__ == "__main__":
    from pyspark.testing import main

    main()