Skip to content

Commit faf2c0a

Browse files
committed
Properly encrypt RTP packets with padding
1 parent 3d58eea commit faf2c0a

File tree

8 files changed

+130
-24
lines changed

8 files changed

+130
-24
lines changed

session_srtp.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,20 @@ func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error)
150150
ibuf := bufferpool.Get()
151151
defer bufferpool.Put(ibuf)
152152

153+
buf := ibuf.([]byte) // nolint:forcetypeassert
154+
headerLen, marshalSize := rtp.HeaderAndPacketMarshalSize(header, payload) // nolint:staticcheck
155+
if len(buf) < marshalSize+20 {
156+
// The buffer is too small, so we need to allocate a new one. Add 20 bytes for auth tag like
157+
// for bufferpool above.
158+
buf = make([]byte, marshalSize+20)
159+
}
160+
_, err := rtp.MarshalPacketTo(buf, header, payload) // nolint:staticcheck
161+
if err != nil {
162+
return 0, err
163+
}
164+
153165
s.session.localContextMutex.Lock()
154-
encrypted, err := s.localContext.encryptRTP(ibuf.([]byte), header, payload) //nolint:forcetypeassert
166+
encrypted, err := s.localContext.encryptRTP(buf, header, headerLen, buf[:marshalSize])
155167
s.session.localContextMutex.Unlock()
156168

157169
if err != nil {

session_srtp_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,51 @@ func encryptSRTP(context *Context, pkt *rtp.Packet) ([]byte, error) {
363363

364364
return encrypted, nil
365365
}
366+
367+
func TestSessionSRTPPacketWithPadding(t *testing.T) {
368+
lim := test.TimeOut(time.Second * 5)
369+
defer lim.Stop()
370+
371+
report := test.CheckRoutines(t)
372+
defer report()
373+
374+
const (
375+
testSSRC = 5000
376+
rtpHeaderSize = 12
377+
paddingSize = 5
378+
authTagLen = 10 // For AES_CM_128_HMAC_SHA1_80, the auth tag length is 10 bytes.
379+
)
380+
testPayload := []byte{0x00, 0x01, 0x03, 0x04}
381+
readBuffer := make([]byte, rtpHeaderSize+paddingSize+len(testPayload))
382+
aSession, bSession := buildSessionSRTPPair(t)
383+
384+
aWriteStream, err := aSession.OpenWriteStream()
385+
assert.NoError(t, err)
386+
387+
writeBytes, err := aWriteStream.WriteRTP(&rtp.Header{SSRC: testSSRC, Padding: true, PaddingSize: paddingSize},
388+
append([]byte{}, testPayload...))
389+
assert.NoError(t, err)
390+
assert.Equalf(t, rtpHeaderSize+paddingSize+len(testPayload)+authTagLen, writeBytes,
391+
"WriteRTP should return the size of the packet including padding, exp(%v) actual(%v)",
392+
rtpHeaderSize+paddingSize+len(testPayload)+authTagLen, writeBytes)
393+
394+
bReadStream, ssrc, err := bSession.AcceptStream()
395+
assert.NoError(t, err)
396+
assert.Equalf(t, uint32(testSSRC), ssrc, "SSRC mismatch during accept exp(%v) actual(%v)", testSSRC, ssrc)
397+
398+
readBytes, err := bReadStream.Read(readBuffer)
399+
assert.NoError(t, err)
400+
assert.Equal(t, rtpHeaderSize+paddingSize+len(testPayload), readBytes,
401+
"Read should return the size of the packet including padding, exp(%v) actual(%v)",
402+
rtpHeaderSize+paddingSize+len(testPayload), readBytes)
403+
404+
var rtpPacket rtp.Packet
405+
err = rtpPacket.Unmarshal(readBuffer[:readBytes])
406+
assert.NoError(t, err)
407+
assert.Equal(t, rtpPacket.Padding, true)
408+
assert.Equal(t, rtpPacket.PaddingSize, byte(paddingSize))
409+
assert.Equal(t, rtpPacket.Payload, testPayload)
410+
411+
assert.NoError(t, aSession.Close())
412+
assert.NoError(t, bSession.Close())
413+
}

srtp.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,14 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) (
113113
return nil, err
114114
}
115115

116-
return c.encryptRTP(dst, header, plaintext[headerLen:])
116+
return c.encryptRTP(dst, header, headerLen, plaintext)
117117
}
118118

119119
// encryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided.
120120
// If the dst buffer does not have the capacity, a new one will be allocated and returned.
121121
// Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload.
122-
func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) {
122+
func (c *Context) encryptRTP(dst []byte, header *rtp.Header, headerLen int, plaintext []byte,
123+
) (ciphertext []byte, err error) {
123124
s := c.getSRTPSSRCState(header.SSRC)
124125
roc, diff, ovf := s.nextRolloverCount(header.SequenceNumber)
125126
if ovf {
@@ -136,7 +137,7 @@ func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ci
136137
rocInPacket = true
137138
}
138139

139-
return c.cipher.encryptRTP(dst, header, payload, roc, rocInPacket)
140+
return c.cipher.encryptRTP(dst, header, headerLen, plaintext, roc, rocInPacket)
140141
}
141142

142143
func (c *Context) hasROCInPacket(header *rtp.Header, authTagLen int) (bool, int) {

srtp_cipher.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type srtpCipher interface {
1717
AEADAuthTagLen() (int, error)
1818
getRTCPIndex([]byte) uint32
1919

20-
encryptRTP([]byte, *rtp.Header, []byte, uint32, bool) ([]byte, error)
20+
encryptRTP([]byte, *rtp.Header, int, []byte, uint32, bool) ([]byte, error)
2121
encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error)
2222

2323
decryptRTP([]byte, []byte, *rtp.Header, int, uint32, bool) ([]byte, error)

srtp_cipher_aead_aes_gcm.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,14 @@ func newSrtpCipherAeadAesGcm(
9191
func (s *srtpCipherAeadAesGcm) encryptRTP(
9292
dst []byte,
9393
header *rtp.Header,
94-
payload []byte,
94+
headerLen int,
95+
plaintext []byte,
9596
roc uint32,
9697
rocInAuthTag bool,
9798
) (ciphertext []byte, err error) {
99+
payload := plaintext[headerLen:]
100+
payloadLen := len(payload)
101+
98102
// Grow the given buffer to fit the output.
99103
authTagLen, err := s.AEADAuthTagLen()
100104
if err != nil {
@@ -106,18 +110,21 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(
106110
dstLen += 4
107111
}
108112
dst = growBufferSize(dst, dstLen)
113+
sameBuffer := isSameBuffer(dst, plaintext)
109114

110-
n, err := header.MarshalTo(dst)
111-
if err != nil {
112-
return nil, err
115+
// Copy the header unencrypted.
116+
if !sameBuffer {
117+
copy(dst, plaintext[:headerLen])
113118
}
114119

115120
iv := s.rtpInitializationVector(header, roc)
116121
if s.srtpEncrypted {
117-
s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n])
122+
s.srtpCipher.Seal(dst[headerLen:headerLen], iv[:], payload, dst[:headerLen])
118123
} else {
119-
clearLen := n + len(payload)
120-
copy(dst[n:], payload)
124+
clearLen := headerLen + payloadLen
125+
if !sameBuffer {
126+
copy(dst[headerLen:], payload)
127+
}
121128
s.srtpCipher.Seal(dst[clearLen:clearLen], iv[:], nil, dst[:clearLen])
122129
}
123130

srtp_cipher_aes_cm_hmac_sha1.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,33 +104,37 @@ func newSrtpCipherAesCmHmacSha1(
104104
func (s *srtpCipherAesCmHmacSha1) encryptRTP(
105105
dst []byte,
106106
header *rtp.Header,
107-
payload []byte,
107+
headerLen int,
108+
plaintext []byte,
108109
roc uint32,
109110
rocInAuthTag bool,
110111
) (ciphertext []byte, err error) {
112+
payload := plaintext[headerLen:]
113+
payloadLen := len(payload)
114+
111115
// Grow the given buffer to fit the output.
112116
authTagLen, err := s.AuthTagRTPLen()
113117
if err != nil {
114118
return nil, err
115119
}
116-
dst = growBufferSize(dst, header.MarshalSize()+len(payload)+len(s.mki)+authTagLen)
120+
dst = growBufferSize(dst, headerLen+payloadLen+len(s.mki)+authTagLen)
121+
sameBuffer := isSameBuffer(dst, plaintext)
117122

118123
// Copy the header unencrypted.
119-
n, err := header.MarshalTo(dst)
120-
if err != nil {
121-
return nil, err
124+
if !sameBuffer {
125+
copy(dst, plaintext[:headerLen])
122126
}
123127

124128
// Encrypt the payload
125129
if s.srtpEncrypted {
126130
counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt)
127-
if err = xorBytesCTR(s.srtpBlock, counter[:], dst[n:], payload); err != nil {
131+
if err = xorBytesCTR(s.srtpBlock, counter[:], dst[headerLen:], payload); err != nil {
128132
return nil, err
129133
}
130-
} else {
131-
copy(dst[n:], payload)
134+
} else if !sameBuffer {
135+
copy(dst[headerLen:], payload)
132136
}
133-
n += len(payload)
137+
n := headerLen + payloadLen
134138

135139
// Generate the auth tag.
136140
authTag, err := s.generateSrtpAuthTag(dst[:n], roc, rocInAuthTag)

srtp_cipher_test.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,18 @@ func TestSrtpCipher(t *testing.T) {
591591
assert.NoError(t, err)
592592
assert.Equal(t, testCase.encryptedRTPPacket, actualEncrypted)
593593
})
594+
595+
t.Run("Same buffer", func(t *testing.T) {
596+
buffer := make([]byte, 0, 1000)
597+
src, dst := buffer, buffer
598+
src = append(src, testCase.decryptedRTPPacket...)
599+
assert.True(t, isSameBuffer(dst, src))
600+
601+
actualEncrypted, err := ctx.EncryptRTP(dst, src, nil)
602+
assert.NoError(t, err)
603+
assert.Equal(t, testCase.encryptedRTPPacket, actualEncrypted)
604+
assert.True(t, isSameBuffer(actualEncrypted, src))
605+
})
594606
})
595607

596608
t.Run("Decrypt RTP", func(t *testing.T) {

util.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
package srtp
55

6-
import "bytes"
6+
import (
7+
"bytes"
8+
"unsafe"
9+
)
710

8-
// Grow the buffer size to the given number of bytes.
11+
// growBufferSize grows the buffer size to the given number of bytes.
912
func growBufferSize(buf []byte, size int) []byte {
1013
if size <= cap(buf) {
1114
return buf[:size]
@@ -17,7 +20,7 @@ func growBufferSize(buf []byte, size int) []byte {
1720
return buf2
1821
}
1922

20-
// Check if buffers match, if not allocate a new buffer and return it.
23+
// allocateIfMismatch checks if buffers match, if not allocates a new buffer and returns it.
2124
func allocateIfMismatch(dst, src []byte) []byte {
2225
if dst == nil {
2326
dst = make([]byte, len(src))
@@ -35,3 +38,22 @@ func allocateIfMismatch(dst, src []byte) []byte {
3538

3639
return dst
3740
}
41+
42+
// isSameBuffer returns true if slices a and b share the same underlying buffer.
43+
func isSameBuffer(a, b []byte) bool {
44+
// If both are nil, they are technically the same (no buffer)
45+
if a == nil && b == nil {
46+
return true
47+
}
48+
49+
// If either is nil, or both have 0 capacity, they can't share backing buffer
50+
if cap(a) == 0 || cap(b) == 0 {
51+
return false
52+
}
53+
54+
// Create a slice of length 1 from each if possible
55+
aPtr := unsafe.Pointer(&a[:1][0]) // nolint:gosec
56+
bPtr := unsafe.Pointer(&b[:1][0]) // nolint:gosec
57+
58+
return aPtr == bPtr
59+
}

0 commit comments

Comments
 (0)