Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions server/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -1992,6 +1992,10 @@ func (as *mqttAccountSessionManager) processJSAPIReplies(_ *subscription, pc *cl
// No lock held on entry.
func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *client, _ *Account, subject, reply string, rmsg []byte) {
h, m := c.msgParts(rmsg)
// We need to strip the trailing "\r\n".
if l := len(m); l >= LEN_CR_LF {
m = m[:l-LEN_CR_LF]
}
rm, err := mqttDecodeRetainedMessage(h, m)
if err != nil {
return
Expand All @@ -2008,8 +2012,10 @@ func (as *mqttAccountSessionManager) processRetainedMsg(_ *subscription, c *clie
// At this point we either recover from our own server, or process a remote retained message.
seq, _, _ := ackReplyInfo(reply)

// Handle this retained message, no need to copy the bytes.
as.handleRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm, false)
// Handle this retained message. The `rm.Msg` references some buffer owned
// by the caller. handleRetainedMsg() will take care of making a copy of
// `rm.Msg` it `rm` ends-up being stored in the cache.
as.handleRetainedMsg(rm.Subject, &mqttRetainedMsgRef{sseq: seq}, rm)

// If we were recovering (rrmTotal > 0), then check if we are done.
as.mu.Lock()
Expand Down Expand Up @@ -2286,7 +2292,7 @@ func (as *mqttAccountSessionManager) sendJSAPIrequests(s *Server, c *client, acc
// If a message for this topic already existed, the existing record is updated
// with the provided information.
// Lock not held on entry.
func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg, copyBytesToCache bool) {
func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetainedMsgRef, rm *mqttRetainedMsg) {
as.mu.Lock()
defer as.mu.Unlock()
if as.retmsgs == nil {
Expand All @@ -2313,7 +2319,9 @@ func (as *mqttAccountSessionManager) handleRetainedMsg(key string, rf *mqttRetai

// Update the in-memory retained message cache but only for messages
// that are already in the cache, i.e. have been (recently) used.
as.setCachedRetainedMsg(key, rm, true, copyBytesToCache)
// If that is the case, we ask setCachedRetainedMsg() to make a copy
// of rm.Msg bytes slice.
as.setCachedRetainedMsg(key, rm, true, true)
return
}
}
Expand Down Expand Up @@ -3122,7 +3130,17 @@ func (as *mqttAccountSessionManager) getCachedRetainedMsg(subject string) *mqttR
return rm
}

func (as *mqttAccountSessionManager) setCachedRetainedMsg(subject string, rm *mqttRetainedMsg, onlyReplace bool, copyBytesToCache bool) {
// If cache is enabled, the expiration for the `rm` is bumped by
// `mqttRetainedCacheTTL` seconds.
// If `onlyReplace` is true, then the `rm` object is stored in the cache using
// the `subject` key only if there was already an object stored under that key.
// If `copyMsgBytes` is true, then the `rm.Msg` bytes are copied (because it
// references some buffer that is not owned by the caller).
//
// Note: currently `onlyReplace` and `cloneMsgBytes` always have the same
// value (all `true` or all `false`) however we use different booleans to
// better express the intent.
func (as *mqttAccountSessionManager) setCachedRetainedMsg(subject string, rm *mqttRetainedMsg, onlyReplace, copyMsgBytes bool) {
if as.rmsCache == nil || rm == nil {
return
}
Expand All @@ -3132,7 +3150,7 @@ func (as *mqttAccountSessionManager) setCachedRetainedMsg(subject string, rm *mq
return
}
}
if copyBytesToCache {
if copyMsgBytes {
rm.Msg = copyBytes(rm.Msg)
}
as.rmsCache.Store(subject, rm)
Expand Down Expand Up @@ -4409,9 +4427,9 @@ func (c *client) mqttHandlePubRetain() {
rf := &mqttRetainedMsgRef{
sseq: smr.Sequence,
}
// Add/update the map. `true` to copy the payload bytes if needs to
// update rmsCache.
asm.handleRetainedMsg(key, rf, rm, true)
// Add/update the map. The `rm.Msg` bytes slice will be copied if the object
// happens to be stored in the rmsCache.
asm.handleRetainedMsg(key, rf, rm)
} else {
c.mu.Lock()
acc := c.acc
Expand Down
102 changes: 99 additions & 3 deletions server/mqtt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3281,7 +3281,7 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
for _, a := range test.order {
if a.add {
rf := &mqttRetainedMsgRef{sseq: a.seq}
asm.handleRetainedMsg(test.subject, rf, nil, false)
asm.handleRetainedMsg(test.subject, rf, nil)
} else {
asm.handleRetainedMsgDel(test.subject, a.seq)
}
Expand All @@ -3294,7 +3294,7 @@ func TestMQTTRetainedMsgNetworkUpdates(t *testing.T) {
t.Run("clear_"+subject, func(t *testing.T) {
// Now add a new message, which should clear the floor.
rf := &mqttRetainedMsgRef{sseq: 3}
asm.handleRetainedMsg(subject, rf, nil, false)
asm.handleRetainedMsg(subject, rf, nil)
check(t, subject, true, 3, 0)
// Now do a non network delete and make sure it is gone.
asm.handleRetainedMsgDel(subject, 0)
Expand All @@ -3315,7 +3315,7 @@ func TestMQTTRetainedMsgDel(t *testing.T) {
var i uint64
for i = 0; i < 3; i++ {
rf := &mqttRetainedMsgRef{sseq: i}
asm.handleRetainedMsg("subject", rf, nil, false)
asm.handleRetainedMsg("subject", rf, nil)
}
asm.handleRetainedMsgDel("subject", 2)
if asm.sl.count > 0 {
Expand Down Expand Up @@ -3406,6 +3406,102 @@ func TestMQTTRetainedMsgMigration(t *testing.T) {
}
}

func TestMQTTRetainedNoMsgBodyCorruption(t *testing.T) {
f := func() {
o := testMQTTDefaultOptions()
s := testMQTTRunServer(t, o)
defer testMQTTShutdownServer(s)

// Send a retained message.
c, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
defer c.Close()
testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
testMQTTPublish(t, c, r, 0, false, true, "foo/bar", 0, []byte("retained 1"))
testMQTTFlush(t, c, nil, r)

checkRetained := func(msg string) {
t.Helper()
c, r := testMQTTConnect(t, &mqttConnInfo{cleanSess: true}, o.MQTT.Host, o.MQTT.Port)
defer c.Close()
testMQTTCheckConnAck(t, r, mqttConnAckRCConnectionAccepted, false)
testMQTTSub(t, 1, c, r, []*mqttFilter{{filter: "foo/#", qos: 0}}, []byte{0})
testMQTTCheckPubMsg(t, c, r, "foo/bar", mqttPubFlagRetain, []byte(msg))
}
// Subscribe to make it load into the cache.
checkRetained("retained 1")

// Now send another one.
testMQTTPublish(t, c, r, 0, false, true, "foo/bar", 0, []byte("retained 2"))
testMQTTFlush(t, c, nil, r)

// Check it is updated
checkRetained("retained 2")

// Now we will simulate an update coming from another server
// if we were in cluster mode.
nc := natsConnect(t, s.ClientURL())
defer nc.Close()

msg := nats.NewMsg("$MQTT.rmsgs.foo.bar")
msg.Header.Set(mqttNatsRetainedMessageOrigin, "XXXXXXXX")
msg.Header.Set(mqttNatsRetainedMessageTopic, "foo/bar")
msg.Header.Set(mqttNatsRetainedMessageFlags, "1")
msg.Data = []byte("retained 3")

// Have a continuous flow of updates coming in
wg := sync.WaitGroup{}
wg.Add(1)
ch := make(chan struct{})
go func() {
defer wg.Done()
for {
nc.PublishMsg(msg)
select {
case <-ch:
return
default:
}
}
}()

s.mu.RLock()
sm := &s.mqtt.sessmgr
s.mu.RUnlock()
sm.mu.RLock()
as := sm.sessions[globalAccountName]
sm.mu.RUnlock()
require_True(t, as != nil)
as.mu.RLock()
cache := as.rmsCache
as.mu.RUnlock()

// Wait to make sure at least the first update occurs
checkFor(t, time.Second, 10*time.Millisecond, func() error {
v, ok := cache.Load("foo.bar")
if !ok {
return errors.New("not in the cache")
}
rm := v.(*mqttRetainedMsg)
if !bytes.Equal(rm.Msg, []byte("retained 3")) {
return fmt.Errorf("Retained message not updated, got %q", rm.Msg)
}
return nil
})
// Repeat starting a subscription to check the retained message and
// make sure it is not corrupted. With the bug, the payload will at
// the very least contain trailing "\r\n" and possibly be corrupted
// (and the race detector would report a race).
for range 50 {
checkRetained("retained 3")
}
close(ch)
wg.Wait()
}
for range 5 {
f()
}
}

func TestMQTTClusterReplicasCount(t *testing.T) {
for _, test := range []struct {
size int
Expand Down