Skip to content

Commit 70d14d8

Browse files
vilius-glukebakken
authored andcommitted
Add support for additional AMQP URI query parameters
https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs. This commit adds support for the following parameters: auth_mechanism heartbeat connection_timeout channel_max Fix default value check when setting SASL authentication from URI Add documentation for added query parameters Add support for additional AMQP URI query parameters https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs. This commit adds support for the following parameters: auth_mechanism heartbeat connection_timeout channel_max Fix default value check when setting SASL authentication from URI Fix ChannelMax type mismatch Use URI heartbeat Bump versions on Windows
1 parent a2fcd5b commit 70d14d8

File tree

5 files changed

+118
-18
lines changed

5 files changed

+118
-18
lines changed

.ci/versions.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
2-
"erlang": "26.1.1",
3-
"rabbitmq": "3.12.6"
2+
"erlang": "26.2.2",
3+
"rabbitmq": "3.13.0"
44
}

connection.go

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,7 @@ func DefaultDial(connectionTimeout time.Duration) func(network, addr string) (ne
157157
// scheme. It is equivalent to calling DialTLS(amqp, nil).
158158
func Dial(url string) (*Connection, error) {
159159
return DialConfig(url, Config{
160-
Heartbeat: defaultHeartbeat,
161-
Locale: defaultLocale,
160+
Locale: defaultLocale,
162161
})
163162
}
164163

@@ -169,7 +168,6 @@ func Dial(url string) (*Connection, error) {
169168
// DialTLS uses the provided tls.Config when encountering an amqps:// scheme.
170169
func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
171170
return DialConfig(url, Config{
172-
Heartbeat: defaultHeartbeat,
173171
TLSClientConfig: amqps,
174172
Locale: defaultLocale,
175173
})
@@ -186,7 +184,6 @@ func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
186184
// amqps:// scheme.
187185
func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
188186
return DialConfig(url, Config{
189-
Heartbeat: defaultHeartbeat,
190187
TLSClientConfig: amqps,
191188
SASL: []Authentication{&ExternalAuth{}},
192189
})
@@ -206,18 +203,50 @@ func DialConfig(url string, config Config) (*Connection, error) {
206203
}
207204

208205
if config.SASL == nil {
209-
config.SASL = []Authentication{uri.PlainAuth()}
206+
if uri.AuthMechanism != nil {
207+
for _, identifier := range uri.AuthMechanism {
208+
switch strings.ToUpper(identifier) {
209+
case "PLAIN":
210+
config.SASL = append(config.SASL, uri.PlainAuth())
211+
case "AMQPLAIN":
212+
config.SASL = append(config.SASL, uri.AMQPlainAuth())
213+
case "EXTERNAL":
214+
config.SASL = append(config.SASL, &ExternalAuth{})
215+
default:
216+
return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier)
217+
}
218+
}
219+
} else {
220+
config.SASL = []Authentication{uri.PlainAuth()}
221+
}
210222
}
211223

212224
if config.Vhost == "" {
213225
config.Vhost = uri.Vhost
214226
}
215227

228+
if uri.Heartbeat.hasValue {
229+
config.Heartbeat = uri.Heartbeat.value
230+
} else {
231+
if config.Heartbeat == 0 {
232+
config.Heartbeat = defaultHeartbeat
233+
}
234+
}
235+
236+
if config.ChannelMax == 0 {
237+
config.ChannelMax = uri.ChannelMax
238+
}
239+
240+
connectionTimeout := defaultConnectionTimeout
241+
if uri.ConnectionTimeout != 0 {
242+
connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond
243+
}
244+
216245
addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10))
217246

218247
dialer := config.Dial
219248
if dialer == nil {
220-
dialer = DefaultDial(defaultConnectionTimeout)
249+
dialer = DefaultDial(connectionTimeout)
221250
}
222251

223252
conn, err = dialer("tcp", addr)

types.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,3 +553,16 @@ type bodyFrame struct {
553553
}
554554

555555
func (f *bodyFrame) channel() uint16 { return f.ChannelId }
556+
557+
type heartbeatDuration struct {
558+
value time.Duration
559+
hasValue bool
560+
}
561+
562+
func newHeartbeatDurationFromSeconds(s int) heartbeatDuration {
563+
v := time.Duration(s) * time.Second
564+
return heartbeatDuration{
565+
value: v,
566+
hasValue: true,
567+
}
568+
}

uri.go

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package amqp091
77

88
import (
99
"errors"
10+
"fmt"
1011
"net"
1112
"net/url"
1213
"strconv"
@@ -32,16 +33,20 @@ var defaultURI = URI{
3233

3334
// URI represents a parsed AMQP URI string.
3435
type URI struct {
35-
Scheme string
36-
Host string
37-
Port int
38-
Username string
39-
Password string
40-
Vhost string
41-
CertFile string // client TLS auth - path to certificate (PEM)
42-
CACertFile string // client TLS auth - path to CA certificate (PEM)
43-
KeyFile string // client TLS auth - path to private key (PEM)
44-
ServerName string // client TLS auth - server name
36+
Scheme string
37+
Host string
38+
Port int
39+
Username string
40+
Password string
41+
Vhost string
42+
CertFile string // client TLS auth - path to certificate (PEM)
43+
CACertFile string // client TLS auth - path to CA certificate (PEM)
44+
KeyFile string // client TLS auth - path to private key (PEM)
45+
ServerName string // client TLS auth - server name
46+
AuthMechanism []string
47+
Heartbeat heartbeatDuration
48+
ConnectionTimeout int
49+
ChannelMax uint16
4550
}
4651

4752
// ParseURI attempts to parse the given AMQP URI according to the spec.
@@ -62,6 +67,10 @@ type URI struct {
6267
// keyfile: <path/to/client_key.pem>
6368
// cacertfile: <path/to/ca.pem>
6469
// server_name_indication: <server name>
70+
// auth_mechanism: <one or more: plain, amqplain, external>
71+
// heartbeat: <seconds (integer)>
72+
// connection_timeout: <milliseconds (integer)>
73+
// channel_max: <max number of channels (integer)>
6574
//
6675
// If cacertfile is not provided, system CA certificates will be used.
6776
// Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided.
@@ -134,6 +143,31 @@ func ParseURI(uri string) (URI, error) {
134143
builder.KeyFile = params.Get("keyfile")
135144
builder.CACertFile = params.Get("cacertfile")
136145
builder.ServerName = params.Get("server_name_indication")
146+
builder.AuthMechanism = params["auth_mechanism"]
147+
148+
if params.Has("heartbeat") {
149+
value, err := strconv.Atoi(params.Get("heartbeat"))
150+
if err != nil {
151+
return builder, fmt.Errorf("heartbeat is not an integer: %v", err)
152+
}
153+
builder.Heartbeat = newHeartbeatDurationFromSeconds(value)
154+
}
155+
156+
if params.Has("connection_timeout") {
157+
value, err := strconv.Atoi(params.Get("connection_timeout"))
158+
if err != nil {
159+
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
160+
}
161+
builder.ConnectionTimeout = value
162+
}
163+
164+
if params.Has("channel_max") {
165+
value, err := strconv.ParseUint(params.Get("channel_max"), 10, 16)
166+
if err != nil {
167+
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
168+
}
169+
builder.ChannelMax = uint16(value)
170+
}
137171

138172
return builder, nil
139173
}

uri_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package amqp091
77

88
import (
9+
"reflect"
910
"testing"
1011
)
1112

@@ -388,3 +389,26 @@ func TestURITLSConfig(t *testing.T) {
388389
t.Fatal("Server name not set")
389390
}
390391
}
392+
393+
func TestURIParameters(t *testing.T) {
394+
url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8"
395+
uri, err := ParseURI(url)
396+
if err != nil {
397+
t.Fatal("Could not parse")
398+
}
399+
if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) {
400+
t.Fatal("AuthMechanism not set")
401+
}
402+
if !uri.Heartbeat.hasValue {
403+
t.Fatal("Heartbeat not set")
404+
}
405+
if uri.Heartbeat.value != 2 {
406+
t.Fatal("Heartbeat not set")
407+
}
408+
if uri.ConnectionTimeout != 5000 {
409+
t.Fatal("ConnectionTimeout not set")
410+
}
411+
if uri.ChannelMax != 8 {
412+
t.Fatal("ChannelMax name not set")
413+
}
414+
}

0 commit comments

Comments
 (0)