Skip to content

Commit

Permalink
Break notify into submethods to create the session then create the pu…
Browse files Browse the repository at this point in the history
…blish input to send. Check we populate a region for all requests.

This reverts commit 4c2a5f1.

Signed-off-by: Tyler Reid <tyler.reid@grafana.com>
  • Loading branch information
Tyler Reid committed Jul 9, 2021
1 parent 72b368c commit f53d0ab
Showing 1 changed file with 65 additions and 41 deletions.
106 changes: 65 additions & 41 deletions notify/sns/sns.go
Expand Up @@ -62,20 +62,45 @@ func New(c *config.SNSConfig, t *template.Template, l log.Logger, httpOpts ...co

func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, error) {
var (
err error
data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger)
tmpl = notify.TmplText(n.tmpl, data, &err)
creds *credentials.Credentials = nil
err error
data = notify.GetTemplateData(ctx, n.tmpl, alert, n.logger)
tmpl = notify.TmplText(n.tmpl, data, &err)
)
if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" {
creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "")

client, err := createSNSClient(n, tmpl)
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
}

attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes))
for k, v := range n.conf.Attributes {
attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))}
publishInput, err := createPublishInput(ctx, n, tmpl)
if err != nil {
return true, err
}

publishOutput, err := client.Publish(publishInput)
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
}

level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber)

return false, nil
}

func createSNSClient(n *Notifier, tmpl func(string) string) (*sns.SNS, error) {
var creds *credentials.Credentials = nil
// If there are provided sigV4 credentials we want to use those to create a session.
if n.conf.Sigv4.AccessKey != "" && n.conf.Sigv4.SecretKey != "" {
creds = credentials.NewStaticCredentials(n.conf.Sigv4.AccessKey, string(n.conf.Sigv4.SecretKey), "")
}
sess, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{
Region: aws.String(n.conf.Sigv4.Region),
Expand All @@ -84,11 +109,7 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
Profile: n.conf.Sigv4.Profile,
})
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
return nil, err
}

if n.conf.Sigv4.RoleARN != "" {
Expand All @@ -105,32 +126,37 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
Profile: n.conf.Sigv4.Profile,
})
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
return nil, err
}
}
creds = stscreds.NewCredentials(stsSess, n.conf.Sigv4.RoleARN)
}
// Max message size for a message in a SNS publish request is 256KB, except for SMS messages where the limit is 1600 characters/runes.
messageSizeLimit := 256 * 1024
// Use our generated session with credentials to create the SNS Client.
client := sns.New(sess, &aws.Config{Credentials: creds})
publishInput := &sns.PublishInput{}
// We will always need a region to be set by either the local config or the environment.
if aws.StringValue(sess.Config.Region) == "" {
return nil, fmt.Errorf("region not configured in sns.sigv4.region or in default credentials chain")
}
return client, nil
}

func createPublishInput(ctx context.Context, n *Notifier, tmpl func(string) string) (*sns.PublishInput, error) {
publishInput := &sns.PublishInput{}
messageAttributes := createMessageAttributes(n, tmpl)
// Max message size for a message in a SNS publish request is 256KB, except for SMS messages where the limit is 1600 characters/runes.
messageSizeLimit := 256 * 1024
if n.conf.TopicARN != "" {
topicTmpl := tmpl(n.conf.TopicARN)
publishInput.SetTopicArn(topicTmpl)

if n.isFifo == nil {
// If we are using a topic ARN it could be a FIFO topic specified by the topic postfix .fifo.
n.isFifo = aws.Bool(n.conf.TopicARN[len(n.conf.TopicARN)-5:] == ".fifo")
}
if *n.isFifo {
// Deduplication key and Message Group ID are only added if it's a FIFO SNS Topic.
key, err := notify.ExtractGroupKey(ctx)
if err != nil {
return false, err
return nil, err
}
publishInput.SetMessageDeduplicationId(key.Hash())
publishInput.SetMessageGroupId(key.Hash())
Expand All @@ -143,36 +169,25 @@ func (n *Notifier) Notify(ctx context.Context, alert ...*types.Alert) (bool, err
}
if n.conf.TargetARN != "" {
publishInput.SetTargetArn(tmpl(n.conf.TargetARN))

}

messageToSend, isTrunc, err := validateAndTruncateMessage(tmpl(n.conf.Message), messageSizeLimit)
if err != nil {
return false, err
return nil, err
}
if isTrunc {
attributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}
// If we truncated the message we need to add a message attribute showing that it was truncated.
messageAttributes["truncated"] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String("true")}
}

publishInput.SetMessage(messageToSend)
publishInput.SetMessageAttributes(messageAttributes)

if n.conf.Subject != "" {
publishInput.SetSubject(tmpl(n.conf.Subject))
}

publishInput.SetMessageAttributes(attributes)

publishOutput, err := client.Publish(publishInput)
if err != nil {
if e, ok := err.(awserr.RequestFailure); ok {
return n.retrier.Check(e.StatusCode(), strings.NewReader(e.Message()))
} else {
return true, err
}
}

level.Debug(n.logger).Log("msg", "SNS message successfully published", "message_id", publishOutput.MessageId, "sequence number", publishOutput.SequenceNumber)

return false, nil
return publishInput, nil
}

func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (string, bool, error) {
Expand All @@ -187,3 +202,12 @@ func validateAndTruncateMessage(message string, maxMessageSizeInBytes int) (stri
copy(truncated, message)
return string(truncated), true, nil
}

func createMessageAttributes(n *Notifier, tmpl func(string) string) map[string]*sns.MessageAttributeValue {
// Convert the given attributes map into the AWS Message Attributes Format
attributes := make(map[string]*sns.MessageAttributeValue, len(n.conf.Attributes))
for k, v := range n.conf.Attributes {
attributes[tmpl(k)] = &sns.MessageAttributeValue{DataType: aws.String("String"), StringValue: aws.String(tmpl(v))}
}
return attributes
}

0 comments on commit f53d0ab

Please sign in to comment.