Skip to content

Commit

Permalink
Add support for additional AMQP URI query parameters
Browse files Browse the repository at this point in the history
https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs.

This commit adds support for the following parameters:
auth_mechanism
heartbeat
connection_timeout
channel_max

Fix default value check when setting SASL authentication from URI

Add documentation for added query parameters

Add support for additional AMQP URI query parameters

https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs.

This commit adds support for the following parameters:
auth_mechanism
heartbeat
connection_timeout
channel_max

Fix default value check when setting SASL authentication from URI

Fix ChannelMax type mismatch

Use URI heartbeat

Bump versions on Windows
  • Loading branch information
vilius-g authored and lukebakken committed Mar 13, 2024
1 parent a2fcd5b commit 4172682
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 19 deletions.
4 changes: 2 additions & 2 deletions .ci/versions.json
@@ -1,4 +1,4 @@
{
"erlang": "26.1.1",
"rabbitmq": "3.12.6"
"erlang": "26.2.2",
"rabbitmq": "3.13.0"
}
45 changes: 38 additions & 7 deletions connection.go
Expand Up @@ -157,8 +157,7 @@ func DefaultDial(connectionTimeout time.Duration) func(network, addr string) (ne
// scheme. It is equivalent to calling DialTLS(amqp, nil).
func Dial(url string) (*Connection, error) {
return DialConfig(url, Config{
Heartbeat: defaultHeartbeat,
Locale: defaultLocale,
Locale: defaultLocale,
})
}

Expand All @@ -169,7 +168,6 @@ func Dial(url string) (*Connection, error) {
// DialTLS uses the provided tls.Config when encountering an amqps:// scheme.
func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
return DialConfig(url, Config{
Heartbeat: defaultHeartbeat,
TLSClientConfig: amqps,
Locale: defaultLocale,
})
Expand All @@ -186,7 +184,6 @@ func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
// amqps:// scheme.
func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
return DialConfig(url, Config{
Heartbeat: defaultHeartbeat,
TLSClientConfig: amqps,
SASL: []Authentication{&ExternalAuth{}},
})
Expand All @@ -195,7 +192,9 @@ func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
// DialConfig accepts a string in the AMQP URI format and a configuration for
// the transport and connection setup, returning a new Connection. Defaults to
// a server heartbeat interval of 10 seconds and sets the initial read deadline
// to 30 seconds.
// to 30 seconds. The heartbeat interval specified in the AMQP URI takes precedence
// over the value specified in the config. To disable heartbeats, you must use
// the AMQP URI and set heartbeat=0 there.
func DialConfig(url string, config Config) (*Connection, error) {
var err error
var conn net.Conn
Expand All @@ -206,18 +205,50 @@ func DialConfig(url string, config Config) (*Connection, error) {
}

if config.SASL == nil {
config.SASL = []Authentication{uri.PlainAuth()}
if uri.AuthMechanism != nil {
for _, identifier := range uri.AuthMechanism {
switch strings.ToUpper(identifier) {
case "PLAIN":
config.SASL = append(config.SASL, uri.PlainAuth())
case "AMQPLAIN":
config.SASL = append(config.SASL, uri.AMQPlainAuth())
case "EXTERNAL":
config.SASL = append(config.SASL, &ExternalAuth{})
default:
return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier)
}
}
} else {
config.SASL = []Authentication{uri.PlainAuth()}
}
}

if config.Vhost == "" {
config.Vhost = uri.Vhost
}

if uri.Heartbeat.hasValue {
config.Heartbeat = uri.Heartbeat.value
} else {
if config.Heartbeat == 0 {
config.Heartbeat = defaultHeartbeat
}
}

if config.ChannelMax == 0 {
config.ChannelMax = uri.ChannelMax
}

connectionTimeout := defaultConnectionTimeout
if uri.ConnectionTimeout != 0 {
connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond
}

addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10))

dialer := config.Dial
if dialer == nil {
dialer = DefaultDial(defaultConnectionTimeout)
dialer = DefaultDial(connectionTimeout)
}

conn, err = dialer("tcp", addr)
Expand Down
13 changes: 13 additions & 0 deletions types.go
Expand Up @@ -553,3 +553,16 @@ type bodyFrame struct {
}

func (f *bodyFrame) channel() uint16 { return f.ChannelId }

type heartbeatDuration struct {
value time.Duration
hasValue bool
}

func newHeartbeatDurationFromSeconds(s int) heartbeatDuration {
v := time.Duration(s) * time.Second
return heartbeatDuration{
value: v,
hasValue: true,
}
}
54 changes: 44 additions & 10 deletions uri.go
Expand Up @@ -7,6 +7,7 @@ package amqp091

import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
Expand All @@ -32,16 +33,20 @@ var defaultURI = URI{

// URI represents a parsed AMQP URI string.
type URI struct {
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
AuthMechanism []string
Heartbeat heartbeatDuration
ConnectionTimeout int
ChannelMax uint16
}

// ParseURI attempts to parse the given AMQP URI according to the spec.
Expand All @@ -62,6 +67,10 @@ type URI struct {
// keyfile: <path/to/client_key.pem>
// cacertfile: <path/to/ca.pem>
// server_name_indication: <server name>
// auth_mechanism: <one or more: plain, amqplain, external>
// heartbeat: <seconds (integer)>
// connection_timeout: <milliseconds (integer)>
// channel_max: <max number of channels (integer)>
//
// If cacertfile is not provided, system CA certificates will be used.
// Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided.
Expand Down Expand Up @@ -134,6 +143,31 @@ func ParseURI(uri string) (URI, error) {
builder.KeyFile = params.Get("keyfile")
builder.CACertFile = params.Get("cacertfile")
builder.ServerName = params.Get("server_name_indication")
builder.AuthMechanism = params["auth_mechanism"]

if params.Has("heartbeat") {
value, err := strconv.Atoi(params.Get("heartbeat"))
if err != nil {
return builder, fmt.Errorf("heartbeat is not an integer: %v", err)
}
builder.Heartbeat = newHeartbeatDurationFromSeconds(value)
}

if params.Has("connection_timeout") {
value, err := strconv.Atoi(params.Get("connection_timeout"))
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ConnectionTimeout = value
}

if params.Has("channel_max") {
value, err := strconv.ParseUint(params.Get("channel_max"), 10, 16)
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ChannelMax = uint16(value)
}

return builder, nil
}
Expand Down
25 changes: 25 additions & 0 deletions uri_test.go
Expand Up @@ -6,7 +6,9 @@
package amqp091

import (
"reflect"
"testing"
"time"
)

// Test matrix defined on http://www.rabbitmq.com/uri-spec.html
Expand Down Expand Up @@ -388,3 +390,26 @@ func TestURITLSConfig(t *testing.T) {
t.Fatal("Server name not set")
}
}

func TestURIParameters(t *testing.T) {
url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8"
uri, err := ParseURI(url)
if err != nil {
t.Fatal("Could not parse")
}
if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) {
t.Fatal("AuthMechanism not set")
}
if !uri.Heartbeat.hasValue {
t.Fatal("Heartbeat not set")
}
if uri.Heartbeat.value != time.Duration(2)*time.Second {
t.Fatal("Heartbeat not set")
}
if uri.ConnectionTimeout != 5000 {
t.Fatal("ConnectionTimeout not set")
}
if uri.ChannelMax != 8 {
t.Fatal("ChannelMax name not set")
}
}

0 comments on commit 4172682

Please sign in to comment.