diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index f1f7a07701de16..7f237276617152 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -1606,6 +1606,59 @@ def dummycallback(sock, servername, ctx, cycle=ctx): gc.collect() self.assertIs(wr(), None) + @unittest.skipUnless(support.Py_GIL_DISABLED, + "test is only useful if the GIL is disabled") + @threading_helper.requires_working_threading() + def test_sni_callback_race(self): + # Replacing sni_callback while handshakes are in-flight must not + # crash (use-after-free on the callback in free-threaded builds). + client_ctx, server_ctx, hostname = testing_context() + + server_ctx.sni_callback = lambda *a: None + done = threading.Event() + + def do_handshakes(): + while not done.is_set(): + c_in = ssl.MemoryBIO() + c_out = ssl.MemoryBIO() + s_in = ssl.MemoryBIO() + s_out = ssl.MemoryBIO() + client = client_ctx.wrap_bio( + c_in, c_out, server_hostname=hostname) + server = server_ctx.wrap_bio(s_in, s_out, server_side=True) + for _ in range(50): + try: + client.do_handshake() + except ssl.SSLWantReadError: + pass + except ssl.SSLError: + break + if c_out.pending: + s_in.write(c_out.read()) + try: + server.do_handshake() + except ssl.SSLWantReadError: + pass + except ssl.SSLError: + break + if s_out.pending: + c_in.write(s_out.read()) + + def toggle_callback(): + while not done.is_set(): + server_ctx.sni_callback = lambda *a: None + server_ctx.sni_callback = None + + workers = max(4, (os.cpu_count() or 4) * 2) + threads = [threading.Thread(target=do_handshakes) + for _ in range(workers)] + threads.append(threading.Thread(target=toggle_callback)) + + with threading_helper.catch_threading_exception() as cm: + with threading_helper.start_threads(threads): + done.set() + self.assertIsNone(cm.exc_value) + def test_cert_store_stats(self): ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) self.assertEqual(ctx.cert_store_stats(), diff --git a/Misc/NEWS.d/next/Library/2026-05-18-22-45-54.gh-issue-149816.T68vc_.rst b/Misc/NEWS.d/next/Library/2026-05-18-22-45-54.gh-issue-149816.T68vc_.rst new file mode 100644 index 00000000000000..9996cc7ec0e866 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2026-05-18-22-45-54.gh-issue-149816.T68vc_.rst @@ -0,0 +1 @@ +Fix race condition in :attr:`ssl.SSLContext.sni_callback` diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 3224ca7d0f93b9..35754e566a1528 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -26,6 +26,7 @@ #define OPENSSL_NO_DEPRECATED 1 #include "Python.h" +#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION() #include "pycore_fileutils.h" // _PyIsSelectable_fd() #include "pycore_long.h" // _PyLong_UnsignedLongLong_Converter() #include "pycore_pyerrors.h" // _PyErr_ChainExceptions1() @@ -5153,12 +5154,15 @@ _servername_callback(SSL *s, int *al, void *args) PyObject *result; /* The high-level ssl.SSLSocket object */ PyObject *ssl_socket; + PyObject *sni_cb; const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name); PyGILState_STATE gstate = PyGILState_Ensure(); - if (sslctx->set_sni_cb == NULL) { - /* remove race condition in this the call back while if removing the - * callback is in progress */ + Py_BEGIN_CRITICAL_SECTION(sslctx); + sni_cb = Py_XNewRef(sslctx->set_sni_cb); + Py_END_CRITICAL_SECTION(); + + if (sni_cb == NULL) { PyGILState_Release(gstate); return SSL_TLSEXT_ERR_OK; } @@ -5185,7 +5189,7 @@ _servername_callback(SSL *s, int *al, void *args) goto error; if (servername == NULL) { - result = PyObject_CallFunctionObjArgs(sslctx->set_sni_cb, ssl_socket, + result = PyObject_CallFunctionObjArgs(sni_cb, ssl_socket, Py_None, sslctx, NULL); } else { @@ -5212,7 +5216,7 @@ _servername_callback(SSL *s, int *al, void *args) } Py_DECREF(servername_bytes); result = PyObject_CallFunctionObjArgs( - sslctx->set_sni_cb, ssl_socket, servername_str, + sni_cb, ssl_socket, servername_str, sslctx, NULL); Py_DECREF(servername_str); } @@ -5222,7 +5226,7 @@ _servername_callback(SSL *s, int *al, void *args) PyErr_FormatUnraisable("Exception ignored " "in ssl servername callback " "while calling set SNI callback %R", - sslctx->set_sni_cb); + sni_cb); *al = SSL_AD_HANDSHAKE_FAILURE; ret = SSL_TLSEXT_ERR_ALERT_FATAL; } @@ -5247,11 +5251,13 @@ _servername_callback(SSL *s, int *al, void *args) Py_DECREF(result); } + Py_DECREF(sni_cb); PyGILState_Release(gstate); return ret; error: Py_XDECREF(ssl_socket); + Py_XDECREF(sni_cb); *al = SSL_AD_INTERNAL_ERROR; ret = SSL_TLSEXT_ERR_ALERT_FATAL; PyGILState_Release(gstate); @@ -5301,20 +5307,18 @@ _ssl__SSLContext_sni_callback_set_impl(PySSLContext *self, PyObject *value) "sni_callback cannot be set on TLS_CLIENT context"); return -1; } - Py_CLEAR(self->set_sni_cb); - if (value == Py_None) { + if (!PyCallable_Check(value)) { SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL); - } - else { - if (!PyCallable_Check(value)) { - SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL); - PyErr_SetString(PyExc_TypeError, - "not a callable object"); + Py_CLEAR(self->set_sni_cb); + if (value != Py_None) { + PyErr_SetString(PyExc_TypeError, "not a callable object"); return -1; } - self->set_sni_cb = Py_NewRef(value); - SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback); + } + else { + Py_XSETREF(self->set_sni_cb, Py_NewRef(value)); SSL_CTX_set_tlsext_servername_arg(self->ctx, self); + SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback); } return 0; }