Skip to content

Commit 8fe976d

Browse files
authored
Upgrade SplitHTTP Transport (#3462)
* move to paths instead of querystrings * permit early data on serverside * early data for the client, fix context cancellation
1 parent c1a7602 commit 8fe976d

3 files changed

Lines changed: 193 additions & 73 deletions

File tree

transport/internet/splithttp/dialer.go

Lines changed: 77 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/xtls/xray-core/common/buf"
1717
"github.com/xtls/xray-core/common/net"
1818
"github.com/xtls/xray-core/common/session"
19+
"github.com/xtls/xray-core/common/signal/done"
1920
"github.com/xtls/xray-core/common/signal/semaphore"
2021
"github.com/xtls/xray-core/common/uuid"
2122
"github.com/xtls/xray-core/transport/internet"
@@ -44,18 +45,6 @@ var (
4445
globalDialerAccess sync.Mutex
4546
)
4647

47-
func destroyHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) {
48-
globalDialerAccess.Lock()
49-
defer globalDialerAccess.Unlock()
50-
51-
if globalDialerMap == nil {
52-
globalDialerMap = make(map[dialerConf]reusedClient)
53-
}
54-
55-
delete(globalDialerMap, dialerConf{dest, streamSettings})
56-
57-
}
58-
5948
func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) reusedClient {
6049
globalDialerAccess.Lock()
6150
defer globalDialerAccess.Unlock()
@@ -77,15 +66,15 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
7766
}
7867

7968
dialContext := func(ctxInner context.Context) (net.Conn, error) {
80-
conn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
69+
conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings)
8170
if err != nil {
8271
return nil, err
8372
}
8473

8574
if gotlsConfig != nil {
8675
if fingerprint := tls.GetFingerprint(tlsConfig.Fingerprint); fingerprint != nil {
8776
conn = tls.UClient(conn, gotlsConfig, fingerprint)
88-
if err := conn.(*tls.UConn).HandshakeContext(ctx); err != nil {
77+
if err := conn.(*tls.UConn).HandshakeContext(ctxInner); err != nil {
8978
return nil, err
9079
}
9180
} else {
@@ -171,49 +160,73 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
171160

172161
var remoteAddr gonet.Addr
173162
var localAddr gonet.Addr
163+
// this is done when the TCP/UDP connection to the server was established,
164+
// and we can unblock the Dial function and print correct net addresses in
165+
// logs
166+
gotConn := done.New()
174167

175-
trace := &httptrace.ClientTrace{
176-
GotConn: func(connInfo httptrace.GotConnInfo) {
177-
remoteAddr = connInfo.Conn.RemoteAddr()
178-
localAddr = connInfo.Conn.LocalAddr()
179-
},
180-
}
168+
var downResponse io.ReadCloser
169+
gotDownResponse := done.New()
181170

182171
sessionIdUuid := uuid.New()
183172
sessionId := sessionIdUuid.String()
184173

185-
req, err := http.NewRequestWithContext(
186-
httptrace.WithClientTrace(ctx, trace),
187-
"GET",
188-
requestURL.String()+"?session="+sessionId,
189-
nil,
190-
)
191-
if err != nil {
192-
return nil, err
193-
}
174+
go func() {
175+
trace := &httptrace.ClientTrace{
176+
GotConn: func(connInfo httptrace.GotConnInfo) {
177+
remoteAddr = connInfo.Conn.RemoteAddr()
178+
localAddr = connInfo.Conn.LocalAddr()
179+
gotConn.Close()
180+
},
181+
}
194182

195-
req.Header = transportConfiguration.GetRequestHeader()
196-
197-
downResponse, err := httpClient.download.Do(req)
198-
if err != nil {
199-
// workaround for various connection pool related issues, mostly around
200-
// HTTP/1.1. if the http client ever fails to send a request, we simply
201-
// delete it entirely.
202-
// in HTTP/1.1, it was observed that pool connections would immediately
203-
// fail with "context canceled" if the previous http response body was
204-
// not explicitly BOTH drained and closed. at the same time, sometimes
205-
// the draining itself takes forever and causes more problems.
206-
// see also https://github.com/golang/go/issues/60240
207-
destroyHTTPClient(ctx, dest, streamSettings)
208-
return nil, newError("failed to send download http request, destroying client").Base(err)
209-
}
183+
// in case we hit an error, we want to unblock this part
184+
defer gotConn.Close()
210185

211-
if downResponse.StatusCode != 200 {
212-
downResponse.Body.Close()
213-
return nil, newError("invalid status code on download:", downResponse.Status)
214-
}
186+
req, err := http.NewRequestWithContext(
187+
httptrace.WithClientTrace(context.WithoutCancel(ctx), trace),
188+
"GET",
189+
requestURL.String()+sessionId,
190+
nil,
191+
)
192+
if err != nil {
193+
newError("failed to construct download http request").Base(err).WriteToLog()
194+
gotDownResponse.Close()
195+
return
196+
}
197+
198+
req.Header = transportConfiguration.GetRequestHeader()
199+
200+
response, err := httpClient.download.Do(req)
201+
gotConn.Close()
202+
if err != nil {
203+
newError("failed to send download http request").Base(err).WriteToLog()
204+
gotDownResponse.Close()
205+
return
206+
}
207+
208+
if response.StatusCode != 200 {
209+
response.Body.Close()
210+
newError("invalid status code on download:", response.Status).WriteToLog()
211+
gotDownResponse.Close()
212+
return
213+
}
214+
215+
// skip "ok" response
216+
trashHeader := []byte{0, 0}
217+
_, err = io.ReadFull(response.Body, trashHeader)
218+
if err != nil {
219+
response.Body.Close()
220+
newError("failed to read initial response").Base(err).WriteToLog()
221+
gotDownResponse.Close()
222+
return
223+
}
215224

216-
uploadUrl := requestURL.String() + "?session=" + sessionId + "&seq="
225+
downResponse = response.Body
226+
gotDownResponse.Close()
227+
}()
228+
229+
uploadUrl := requestURL.String() + sessionId + "/"
217230

218231
uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize))
219232

@@ -266,7 +279,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
266279
for i := 0; i < 5; i++ {
267280
uploadConn = httpClient.uploadRawPool.Get()
268281
if uploadConn == nil {
269-
uploadConn, err = httpClient.dialUploadConn(ctx)
282+
uploadConn, err = httpClient.dialUploadConn(context.WithoutCancel(ctx))
270283
if err != nil {
271284
newError("failed to connect upload").Base(err).WriteToLog()
272285
uploadPipeReader.Interrupt()
@@ -293,21 +306,27 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
293306
}
294307
}()
295308

296-
// skip "ok" response
297-
trashHeader := []byte{0, 0}
298-
_, err = io.ReadFull(downResponse.Body, trashHeader)
299-
if err != nil {
300-
downResponse.Body.Close()
301-
return nil, newError("failed to read initial response")
302-
}
309+
// we want to block Dial until we know the remote address of the server,
310+
// for logging purposes
311+
<-gotConn.Wait()
303312

304313
// necessary in order to send larger chunks in upload
305314
bufferedUploadPipeWriter := buf.NewBufferedWriter(uploadPipeWriter)
306315
bufferedUploadPipeWriter.SetBuffered(false)
307316

317+
lazyDownload := &LazyReader{
318+
CreateReader: func() (io.ReadCloser, error) {
319+
<-gotDownResponse.Wait()
320+
if downResponse == nil {
321+
return nil, newError("downResponse failed")
322+
}
323+
return downResponse, nil
324+
},
325+
}
326+
308327
conn := splitConn{
309328
writer: bufferedUploadPipeWriter,
310-
reader: downResponse.Body,
329+
reader: lazyDownload,
311330
remoteAddr: remoteAddr,
312331
localAddr: localAddr,
313332
}

transport/internet/splithttp/hub.go

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
gonet "net"
88
"net/http"
99
"strconv"
10+
"strings"
1011
"sync"
1112
"time"
1213

@@ -28,20 +29,65 @@ type requestHandler struct {
2829
localAddr gonet.TCPAddr
2930
}
3031

32+
type httpSession struct {
33+
uploadQueue *UploadQueue
34+
// for as long as the GET request is not opened by the client, this will be
35+
// open ("undone"), and the session may be expired within a certain TTL.
36+
// after the client connects, this becomes "done" and the session lives as
37+
// long as the GET request.
38+
isFullyConnected *done.Instance
39+
}
40+
41+
func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
42+
shouldReap := done.New()
43+
go func() {
44+
time.Sleep(30 * time.Second)
45+
shouldReap.Close()
46+
}()
47+
48+
select {
49+
case <-isFullyConnected.Wait():
50+
return
51+
case <-shouldReap.Wait():
52+
h.sessions.Delete(sessionId)
53+
}
54+
}
55+
56+
func (h *requestHandler) upsertSession(sessionId string) *httpSession {
57+
currentSessionAny, ok := h.sessions.Load(sessionId)
58+
if ok {
59+
return currentSessionAny.(*httpSession)
60+
}
61+
62+
s := &httpSession{
63+
uploadQueue: NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())),
64+
isFullyConnected: done.New(),
65+
}
66+
67+
h.sessions.Store(sessionId, s)
68+
go h.maybeReapSession(s.isFullyConnected, sessionId)
69+
return s
70+
}
71+
3172
func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
3273
if len(h.host) > 0 && request.Host != h.host {
3374
newError("failed to validate host, request:", request.Host, ", config:", h.host).WriteToLog()
3475
writer.WriteHeader(http.StatusNotFound)
3576
return
3677
}
37-
if request.URL.Path != h.path {
78+
79+
if !strings.HasPrefix(request.URL.Path, h.path) {
3880
newError("failed to validate path, request:", request.URL.Path, ", config:", h.path).WriteToLog()
3981
writer.WriteHeader(http.StatusNotFound)
4082
return
4183
}
4284

43-
queryString := request.URL.Query()
44-
sessionId := queryString.Get("session")
85+
sessionId := ""
86+
subpath := strings.Split(request.URL.Path[len(h.path):], "/")
87+
if len(subpath) > 0 {
88+
sessionId = subpath[0]
89+
}
90+
4591
if sessionId == "" {
4692
newError("no sessionid on request:", request.URL.Path).WriteToLog()
4793
writer.WriteHeader(http.StatusBadRequest)
@@ -60,15 +106,14 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
60106
}
61107
}
62108

109+
currentSession := h.upsertSession(sessionId)
110+
63111
if request.Method == "POST" {
64-
uploadQueue, ok := h.sessions.Load(sessionId)
65-
if !ok {
66-
newError("sessionid does not exist").WriteToLog()
67-
writer.WriteHeader(http.StatusBadRequest)
68-
return
112+
seq := ""
113+
if len(subpath) > 1 {
114+
seq = subpath[1]
69115
}
70116

71-
seq := queryString.Get("seq")
72117
if seq == "" {
73118
newError("no seq on request:", request.URL.Path).WriteToLog()
74119
writer.WriteHeader(http.StatusBadRequest)
@@ -89,7 +134,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
89134
return
90135
}
91136

92-
err = uploadQueue.(*UploadQueue).Push(Packet{
137+
err = currentSession.uploadQueue.Push(Packet{
93138
Payload: payload,
94139
Seq: seqInt,
95140
})
@@ -107,10 +152,9 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
107152
panic("expected http.ResponseWriter to be an http.Flusher")
108153
}
109154

110-
uploadQueue := NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads()))
111-
112-
h.sessions.Store(sessionId, uploadQueue)
113-
// the connection is finished, clean up map
155+
// after GET is done, the connection is finished. disable automatic
156+
// session reaping, and handle it in defer
157+
currentSession.isFullyConnected.Close()
114158
defer h.sessions.Delete(sessionId)
115159

116160
// magic header instructs nginx + apache to not buffer response body
@@ -130,7 +174,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
130174
downloadDone: downloadDone,
131175
responseFlusher: responseFlusher,
132176
},
133-
reader: uploadQueue,
177+
reader: currentSession.uploadQueue,
134178
remoteAddr: remoteAddr,
135179
}
136180

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package splithttp
2+
3+
import (
4+
"io"
5+
"sync"
6+
)
7+
8+
type LazyReader struct {
9+
readerSync sync.Mutex
10+
CreateReader func() (io.ReadCloser, error)
11+
reader io.ReadCloser
12+
readerError error
13+
}
14+
15+
func (r *LazyReader) getReader() (io.ReadCloser, error) {
16+
r.readerSync.Lock()
17+
defer r.readerSync.Unlock()
18+
if r.reader != nil {
19+
return r.reader, nil
20+
}
21+
22+
if r.readerError != nil {
23+
return nil, r.readerError
24+
}
25+
26+
reader, err := r.CreateReader()
27+
if err != nil {
28+
r.readerError = err
29+
return nil, err
30+
}
31+
32+
r.reader = reader
33+
return reader, nil
34+
}
35+
36+
func (r *LazyReader) Read(b []byte) (int, error) {
37+
reader, err := r.getReader()
38+
if err != nil {
39+
return 0, err
40+
}
41+
n, err := reader.Read(b)
42+
return n, err
43+
}
44+
45+
func (r *LazyReader) Close() error {
46+
r.readerSync.Lock()
47+
defer r.readerSync.Unlock()
48+
49+
var err error
50+
if r.reader != nil {
51+
err = r.reader.Close()
52+
r.reader = nil
53+
r.readerError = newError("closed reader")
54+
}
55+
56+
return err
57+
}

0 commit comments

Comments
 (0)