/
customizations_test.go
107 lines (95 loc) · 2.73 KB
/
customizations_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
// +build go1.7
package ec2_test
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"testing"
"github.com/aws/aws-sdk-go/aws"
sdkclient "github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/ec2"
)
// TestCopySnapshotPresignedURL signs a CopySnapshot request created by a
// client configured for us-west-2 with a us-west-1 source region, then
// asserts that the serialized request body carries:
//   - DestinationRegion set to the client's region,
//   - SourceRegion set to the value supplied in the input, and
//   - a PresignedUrl that targets the us-west-1 EC2 endpoint and itself
//     contains DestinationRegion=us-west-2.
//
// It also asserts that building and signing a request from a nil input does
// not panic.
func TestCopySnapshotPresignedURL(t *testing.T) {
	svc := ec2.New(unit.Session, &aws.Config{Region: aws.String("us-west-2")})

	func() {
		defer func() {
			if r := recover(); r != nil {
				t.Fatalf("expect CopySnapshotRequest with nil input not to panic, got %v", r)
			}
		}()
		// Only the absence of a panic is under test here. The error returned
		// by Sign is intentionally discarded: a request built from a nil
		// input is not expected to be a valid, signable request.
		req, _ := svc.CopySnapshotRequest(nil)
		_ = req.Sign()
	}()

	// The second return value is the operation's output placeholder, not an
	// error, so there is nothing to check on it.
	req, _ := svc.CopySnapshotRequest(&ec2.CopySnapshotInput{
		SourceRegion:     aws.String("us-west-1"),
		SourceSnapshotId: aws.String("snap-id"),
	})
	if err := req.Sign(); err != nil {
		t.Fatalf("expect no error signing request, got %v", err)
	}

	b, err := ioutil.ReadAll(req.HTTPRequest.Body)
	if err != nil {
		t.Fatalf("expect no error reading request body, got %v", err)
	}
	q, err := url.ParseQuery(string(b))
	if err != nil {
		t.Fatalf("expect no error parsing request body, got %v", err)
	}
	u, err := url.QueryUnescape(q.Get("PresignedUrl"))
	if err != nil {
		t.Fatalf("expect no error unescaping PresignedUrl, got %v", err)
	}

	if e, a := "us-west-2", q.Get("DestinationRegion"); e != a {
		t.Errorf("expect %v, got %v", e, a)
	}
	if e, a := "us-west-1", q.Get("SourceRegion"); e != a {
		t.Errorf("expect %v, got %v", e, a)
	}

	// The presigned URL must point at the source region's endpoint while
	// naming the destination region in its query string.
	r := regexp.MustCompile(`^https://ec2\.us-west-1\.amazonaws\.com/.+&DestinationRegion=us-west-2`)
	if !r.MatchString(u) {
		t.Errorf("expect %v to match, got %v", r.String(), u)
	}
}
// TestNoCustomRetryerWithMaxRetries runs the ModifyNetworkInterfaceAttribute
// and AssignPrivateIpAddresses operations under several client
// configurations and, via checkRetryerMaxRetries, asserts the max retry
// count reported by each request's Retryer at send time.
func TestNoCustomRetryerWithMaxRetries(t *testing.T) {
	// retryCase pairs a client configuration with the max retry count the
	// request's Retryer is expected to report.
	type retryCase struct {
		Config           aws.Config
		ExpectMaxRetries int
	}

	cases := map[string]retryCase{
		"no options set": {
			ExpectMaxRetries: sdkclient.DefaultRetryerMaxNumRetries,
		},
		"with max retries": {
			Config:           aws.Config{MaxRetries: aws.Int(10)},
			ExpectMaxRetries: 10,
		},
		"With custom retrier": {
			Config: aws.Config{
				Retryer: sdkclient.DefaultRetryer{NumMaxRetries: 10},
			},
			ExpectMaxRetries: 10,
		},
	}

	for label, tc := range cases {
		t.Run(label, func(t *testing.T) {
			// Parameter validation is disabled because both operations are
			// invoked with a nil input.
			baseCfg := &aws.Config{DisableParamValidation: aws.Bool(true)}
			svc := ec2.New(unit.Session, baseCfg, tc.Config.Copy())

			ctx := context.Background()
			verify := checkRetryerMaxRetries(t, tc.ExpectMaxRetries)

			svc.ModifyNetworkInterfaceAttributeWithContext(ctx, nil, verify)
			svc.AssignPrivateIpAddressesWithContext(ctx, nil, verify)
		})
	}
}
// checkRetryerMaxRetries returns a request option that replaces the
// request's Send handlers with a single stub. The stub reports a test error
// when the request's Retryer does not return maxRetries from MaxRetries, and
// then attaches a 200 response with empty headers and an empty body to the
// request.
func checkRetryerMaxRetries(t *testing.T, maxRetries int) func(*request.Request) {
	// stubSend stands in for the real Send handlers: it performs the
	// assertion and supplies a canned response.
	stubSend := func(sent *request.Request) {
		want := maxRetries
		got := sent.Retryer.MaxRetries()
		if want != got {
			t.Errorf("%s, expect %v max retries, got %v", sent.Operation.Name, want, got)
		}

		resp := new(http.Response)
		resp.StatusCode = http.StatusOK
		resp.Header = make(http.Header)
		resp.Body = ioutil.NopCloser(new(bytes.Buffer))
		sent.HTTPResponse = resp
	}

	return func(r *request.Request) {
		// Drop whatever Send handlers are installed so only the stub runs.
		r.Handlers.Send.Clear()
		r.Handlers.Send.PushBack(stubSend)
	}
}