Skip to content

Commit acc0a63

Browse files
[FIXED] MQTT: Retained message body possibly corrupt in cluster mode (#7622)
2 parents 05e659c + c6353e6 commit acc0a63

File tree

2 files changed

+126
-12
lines changed

2 files changed

+126
-12
lines changed

server/mqtt.go

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1992,6 +1992,10 @@ func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *cl
19921992
// No lock held on entry.
19931993
func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
19941994
h, m := c.msgParts(rmsg)
1995+
// We need to strip the trailing "\r\n".
1996+
if l := len(m); l >= LEN_CR_LF {
1997+
m = m[:l-LEN_CR_LF]
1998+
}
19951999
rm, err := mqttDecodeRetainedMessage(h, m)
19962000
if err != nil {
19972001
return
@@ -2008,8 +2012,10 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie
20082012
// At this point we either recover from our own server, or process a remote retained message.
20092013
seq, _, _ := ackReplyInfo(reply)
20102014

2011-
// Handle this retained message, no need to copy the bytes.
2012-
as.handleRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm, false)
2015+
// Handle this retained message. The `rm.Msg` references some buffer owned
2016+
// by the caller. handleRetainedMsg() will take care of making a copy of
2017+
// `rm.Msg` if `rm` ends up being stored in the cache.
2018+
as.handleRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm)
20132019

20142020
// If we were recovering (rrmTotal > 0), then check if we are done.
20152021
as.mu.Lock()
@@ -2286,7 +2292,7 @@ func (as *mqttAccountSessionManager) sendJSAPIrequests(s *Server, c *client, acc
22862292
// If a message for this topic already existed, the existing record is updated
22872293
// with the provided information.
22882294
// Lock not held on entry.
2289-
func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg, copyBytesToCache bool) {
2295+
func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg) {
22902296
as.mu.Lock()
22912297
defer as.mu.Unlock()
22922298
if as.retmsgs == nil {
@@ -2313,7 +2319,9 @@ func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetai
23132319

23142320
// Update the in-memory retained message cache but only for messages
23152321
// that are already in the cache, i.e. have been (recently) used.
2316-
as.setCachedRetainedMsg(key, rm, true, copyBytesToCache)
2322+
// If that is the case, we ask setCachedRetainedMsg() to make a copy
2323+
// of the rm.Msg byte slice.
2324+
as.setCachedRetainedMsg(key, rm, true, true)
23172325
return
23182326
}
23192327
}
@@ -3122,7 +3130,17 @@ func (as *mqttAccountSessionManager) getCachedRetainedMsg(subject string) *mqttR
31223130
return rm
31233131
}
31243132

3125-
func (as *mqttAccountSessionManager) setCachedRetainedMsg(subject string, rm *mqttRetainedMsg, onlyReplace bool, copyBytesToCache bool) {
3133+
// If cache is enabled, the expiration for the `rm` is bumped by
3134+
// `mqttRetainedCacheTTL` seconds.
3135+
// If `onlyReplace` is true, then the `rm` object is stored in the cache using
3136+
// the `subject` key only if there was already an object stored under that key.
3137+
// If `copyMsgBytes` is true, then the `rm.Msg` bytes are copied (because it
3138+
// references some buffer that is not owned by the caller).
3139+
//
3140+
// Note: currently `onlyReplace` and `copyMsgBytes` always have the same
3141+
// value (all `true` or all `false`) however we use different booleans to
3142+
// better express the intent.
3143+
func (as *mqttAccountSessionManager) setCachedRetainedMsg(subject string, rm *mqttRetainedMsg, onlyReplace, copyMsgBytes bool) {
31263144
if as.rmsCache == nil || rm == nil {
31273145
return
31283146
}
@@ -3132,7 +3150,7 @@ func (as *mqttAccountSessionManager) setCachedRetainedMsg(subject string, rm *mq
31323150
return
31333151
}
31343152
}
3135-
if copyBytesToCache {
3153+
if copyMsgBytes {
31363154
rm.Msg = copyBytes(rm.Msg)
31373155
}
31383156
as.rmsCache.Store(subject, rm)
@@ -4409,9 +4427,9 @@ func (c *client) mqttHandlePubRetain() {
44094427
rf := &mqttRetainedMsgRef{
44104428
sseq: smr.Sequence,
44114429
}
4412-
// Add/update the map. `true` to copy the payload bytes if needs to
4413-
// update rmsCache.
4414-
asm.handleRetainedMsg(key, rf, rm, true)
4430+
// Add/update the map. The `rm.Msg` bytes slice will be copied if the object
4431+
// happens to be stored in the rmsCache.
4432+
asm.handleRetainedMsg(key, rf, rm)
44154433
} else {
44164434
c.mu.Lock()
44174435
acc := c.acc

server/mqtt_test.go

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,7 +3281,7 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
32813281
for _, a := range test.order {
32823282
if a.add {
32833283
rf := &mqttRetainedMsgRef{sseq: a.seq}
3284-
asm.handleRetainedMsg(test.subject, rf, nil, false)
3284+
asm.handleRetainedMsg(test.subject, rf, nil)
32853285
} else {
32863286
asm.handleRetainedMsgDel(test.subject, a.seq)
32873287
}
@@ -3294,7 +3294,7 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
32943294
t.Run("clear_"+subject, func(t *testing.T) {
32953295
// Now add a new message, which should clear the floor.
32963296
rf := &mqttRetainedMsgRef{sseq: 3}
3297-
asm.handleRetainedMsg(subject, rf, nil, false)
3297+
asm.handleRetainedMsg(subject, rf, nil)
32983298
check(t, subject, true, 3, 0)
32993299
// Now do a non network delete and make sure it is gone.
33003300
asm.handleRetainedMsgDel(subject, 0)
@@ -3315,7 +3315,7 @@ func TestMQTTRetainedMsgDel(t *testing.T) {
33153315
var i uint64
33163316
for i = 0; i < 3; i++ {
33173317
rf := &mqttRetainedMsgRef{sseq: i}
3318-
asm.handleRetainedMsg("subject", rf, nil, false)
3318+
asm.handleRetainedMsg("subject", rf, nil)
33193319
}
33203320
asm.handleRetainedMsgDel("subject", 2)
33213321
if asm.sl.count > 0 {
@@ -3406,6 +3406,102 @@ func TestMQTTRetainedMsgMigration(t *testing.T) {
34063406
}
34073407
}
34083408

3409+
func TestMQTTRetainedNoMsgBodyCorruption(t *testing.T) {
3410+
f := func() {
3411+
o := testMQTTDefaultOptions()
3412+
s := testMQTTRunServer(t, o)
3413+
defer testMQTTShutdownServer(s)
3414+
3415+
// Send a retained message.
3416+
c, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
3417+
defer c.Close()
3418+
testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
3419+
testMQTTPublish(t, c, r, 0, false, true, "foo/bar", 0, []byte("retained 1"))
3420+
testMQTTFlush(t, c, nil, r)
3421+
3422+
checkRetained := func(msg string) {
3423+
t.Helper()
3424+
c, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
3425+
defer c.Close()
3426+
testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
3427+
testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo/#", qos: 0}}, []byte{0})
3428+
testMQTTCheckPubMsg(t, c, r, "foo/bar", mqttPubFlagRetain, []byte(msg))
3429+
}
3430+
// Subscribe to make it load into the cache.
3431+
checkRetained("retained 1")
3432+
3433+
// Now send another one.
3434+
testMQTTPublish(t, c, r, 0, false, true, "foo/bar", 0, []byte("retained 2"))
3435+
testMQTTFlush(t, c, nil, r)
3436+
3437+
// Check it is updated
3438+
checkRetained("retained 2")
3439+
3440+
// Now we will simulate an update coming from another server
3441+
// if we were in cluster mode.
3442+
nc := natsConnect(t, s.ClientURL())
3443+
defer nc.Close()
3444+
3445+
msg := nats.NewMsg("$MQTT.rmsgs.foo.bar")
3446+
msg.Header.Set(mqttNatsRetainedMessageOrigin, "XXXXXXXX")
3447+
msg.Header.Set(mqttNatsRetainedMessageTopic, "foo/bar")
3448+
msg.Header.Set(mqttNatsRetainedMessageFlags, "1")
3449+
msg.Data = []byte("retained 3")
3450+
3451+
// Have a continuous flow of updates coming in
3452+
wg := sync.WaitGroup{}
3453+
wg.Add(1)
3454+
ch := make(chan struct{})
3455+
go func() {
3456+
defer wg.Done()
3457+
for {
3458+
nc.PublishMsg(msg)
3459+
select {
3460+
case <-ch:
3461+
return
3462+
default:
3463+
}
3464+
}
3465+
}()
3466+
3467+
s.mu.RLock()
3468+
sm := &s.mqtt.sessmgr
3469+
s.mu.RUnlock()
3470+
sm.mu.RLock()
3471+
as := sm.sessions[globalAccountName]
3472+
sm.mu.RUnlock()
3473+
require_True(t, as != nil)
3474+
as.mu.RLock()
3475+
cache := as.rmsCache
3476+
as.mu.RUnlock()
3477+
3478+
// Wait to make sure at least the first update occurs
3479+
checkFor(t, time.Second, 10*time.Millisecond, func() error {
3480+
v, ok := cache.Load("foo.bar")
3481+
if !ok {
3482+
return errors.New("not in the cache")
3483+
}
3484+
rm := v.(*mqttRetainedMsg)
3485+
if !bytes.Equal(rm.Msg, []byte("retained 3")) {
3486+
return fmt.Errorf("Retained message not updated, got %q", rm.Msg)
3487+
}
3488+
return nil
3489+
})
3490+
// Repeat starting a subscription to check the retained message and
3491+
// make sure it is not corrupted. With the bug, the payload will at
3492+
// the very least contain trailing "\r\n" and possibly be corrupted
3493+
// (and the race detector would report a race).
3494+
for range 50 {
3495+
checkRetained("retained 3")
3496+
}
3497+
close(ch)
3498+
wg.Wait()
3499+
}
3500+
for range 5 {
3501+
f()
3502+
}
3503+
}
3504+
34093505
func TestMQTTClusterReplicasCount(t *testing.T) {
34103506
for _, test := range []struct {
34113507
size int

0 commit comments

Comments
 (0)