Skip to content

Commit

Permalink
Add support for additional AMQP URI query parameters
Browse files Browse the repository at this point in the history
https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs.

This commit adds support for the following parameters:
auth_mechanism
heartbeat
connection_timeout
channel_max

Fix default value check when setting SASL authentication from URI

Add documentation for added query parameters

Add support for additional AMQP URI query parameters

https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs.

This commit adds support for the following parameters:
auth_mechanism
heartbeat
connection_timeout
channel_max

Fix default value check when setting SASL authentication from URI

Fix ChannelMax type mismatch
  • Loading branch information
vilius-g authored and lukebakken committed Mar 6, 2024
1 parent a2fcd5b commit 9044e89
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 12 deletions.
32 changes: 30 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,46 @@ func DialConfig(url string, config Config) (*Connection, error) {
}

if config.SASL == nil {
config.SASL = []Authentication{uri.PlainAuth()}
if uri.AuthMechanism != nil {
for _, identifier := range uri.AuthMechanism {
switch strings.ToUpper(identifier) {
case "PLAIN":
config.SASL = append(config.SASL, uri.PlainAuth())
case "AMQPLAIN":
config.SASL = append(config.SASL, uri.AMQPlainAuth())
case "EXTERNAL":
config.SASL = append(config.SASL, &ExternalAuth{})
default:
return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier)
}
}
} else {
config.SASL = []Authentication{uri.PlainAuth()}
}
}

if config.Vhost == "" {
config.Vhost = uri.Vhost
}

if config.Heartbeat == 0 {
config.Heartbeat = time.Duration(uri.Heartbeat) * time.Second
}

if config.ChannelMax == 0 {
config.ChannelMax = uri.ChannelMax
}

connectionTimeout := defaultConnectionTimeout
if uri.ConnectionTimeout != 0 {
connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond
}

addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10))

dialer := config.Dial
if dialer == nil {
dialer = DefaultDial(defaultConnectionTimeout)
dialer = DefaultDial(connectionTimeout)
}

conn, err = dialer("tcp", addr)
Expand Down
54 changes: 44 additions & 10 deletions uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package amqp091

import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
Expand All @@ -32,16 +33,20 @@ var defaultURI = URI{

// URI represents a parsed AMQP URI string.
type URI struct {
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
AuthMechanism []string
Heartbeat int
ConnectionTimeout int
ChannelMax uint16
}

// ParseURI attempts to parse the given AMQP URI according to the spec.
Expand All @@ -62,6 +67,10 @@ type URI struct {
// keyfile: <path/to/client_key.pem>
// cacertfile: <path/to/ca.pem>
// server_name_indication: <server name>
// auth_mechanism: <one or more: plain, amqplain, external>
// heartbeat: <seconds (integer)>
// connection_timeout: <milliseconds (integer)>
// channel_max: <max number of channels (integer)>
//
// If cacertfile is not provided, system CA certificates will be used.
// Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided.
Expand Down Expand Up @@ -134,6 +143,31 @@ func ParseURI(uri string) (URI, error) {
builder.KeyFile = params.Get("keyfile")
builder.CACertFile = params.Get("cacertfile")
builder.ServerName = params.Get("server_name_indication")
builder.AuthMechanism = params["auth_mechanism"]

if params.Has("heartbeat") {
value, err := strconv.Atoi(params.Get("heartbeat"))
if err != nil {
return builder, fmt.Errorf("heartbeat is not an integer: %v", err)
}
builder.Heartbeat = value
}

if params.Has("connection_timeout") {
value, err := strconv.Atoi(params.Get("connection_timeout"))
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ConnectionTimeout = value
}

if params.Has("channel_max") {
value, err := strconv.ParseUint(params.Get("channel_max"), 10, 16)
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ChannelMax = uint16(value)
}

return builder, nil
}
Expand Down
21 changes: 21 additions & 0 deletions uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package amqp091

import (
"reflect"
"testing"
)

Expand Down Expand Up @@ -388,3 +389,23 @@ func TestURITLSConfig(t *testing.T) {
t.Fatal("Server name not set")
}
}

func TestURIParameters(t *testing.T) {
url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8"
uri, err := ParseURI(url)
if err != nil {
t.Fatal("Could not parse")
}
if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) {
t.Fatal("AuthMechanism not set")
}
if uri.Heartbeat != 2 {
t.Fatal("Heartbeat not set")
}
if uri.ConnectionTimeout != 5000 {
t.Fatal("ConnectionTimeout not set")
}
if uri.ChannelMax != 8 {
t.Fatal("ChannelMax name not set")
}
}

0 comments on commit 9044e89

Please sign in to comment.