diff --git a/client.go b/client.go index c2b8ecb6..5f2d6e5a 100644 --- a/client.go +++ b/client.go @@ -20,6 +20,7 @@ package agollo import ( "container/list" "errors" + "strings" "github.com/apolloconfig/agollo/v4/agcache" "github.com/apolloconfig/agollo/v4/agcache/memory" @@ -158,19 +159,40 @@ func (c *internalClient) GetConfigAndInit(namespace string) *storage.Config { return nil } - config := c.cache.GetConfig(namespace) + cfg := c.cache.GetConfig(namespace) - if config == nil { + if cfg == nil { //sync config apolloConfig := syncApolloConfig.SyncWithNamespace(namespace, c.getAppConfig) if apolloConfig != nil { - c.cache.UpdateApolloConfig(apolloConfig, c.getAppConfig) + c.SyncAndUpdate(namespace, apolloConfig) + } + } + + cfg = c.cache.GetConfig(namespace) + + return cfg +} + +func (c *internalClient) SyncAndUpdate(namespace string, apolloConfig *config.ApolloConfig) { + // update appConfig only if namespace does not exist yet + namespaces := strings.Split(c.appConfig.NamespaceName, ",") + exists := false + for _, n := range namespaces { + if n == namespace { + exists = true + break } } + if !exists { + c.appConfig.NamespaceName += "," + namespace + } - config = c.cache.GetConfig(namespace) + // update notification + c.appConfig.GetNotificationsMap().UpdateNotify(namespace, 0) - return config + // update cache + c.cache.UpdateApolloConfig(apolloConfig, c.getAppConfig) } // GetConfigCache 根据namespace获取apollo配置的缓存 diff --git a/client_test.go b/client_test.go index 41ffc31d..8a2d7d00 100644 --- a/client_test.go +++ b/client_test.go @@ -365,17 +365,40 @@ func TestGetConfigAndInitValNotNil(t *testing.T) { AppID: "testID", NamespaceName: "testNotFound", }, - Configurations: map[string]interface{}{"testKey": "testValue"}, + Configurations: map[string]interface{}{"testKey": "testUpdatedValue"}, } }) - defer patch.Reset() client := createMockApolloConfig(120) cf := client.GetConfig("testNotFound") Assert(t, cf, NotNilVal()) - // cache should be updated + + // appConfig notificationsMap appConfig should be updated + Assert(t, client.appConfig.GetNotificationsMap().GetNotify("testNotFound"), Equal(int64(0))) + + // cache should be updated with new configuration Assert(t, client.cache.GetConfig("testNotFound"), NotNilVal()) - Assert(t, client.cache.GetConfig("testNotFound").GetValue("testKey"), Equal("testValue")) + Assert(t, client.cache.GetConfig("testNotFound").GetValue("testKey"), Equal("testUpdatedValue")) + Assert(t, client.appConfig.NamespaceName, Equal("application,testNotFound")) + patch.Reset() + + // second replace + patch1 := gomonkey.ApplyMethod(reflect.TypeOf(apc), "SyncWithNamespace", func(_ *remote.AbsApolloConfig, namespace string, appConfigFunc func() config.AppConfig) *config.ApolloConfig { + return &config.ApolloConfig{ + ApolloConnConfig: config.ApolloConnConfig{ + AppID: "testID", + NamespaceName: "testNotFound1", + }, + Configurations: map[string]interface{}{"testKey": "testUpdatedValue"}, + } + }) + defer patch1.Reset() + client.appConfig.NamespaceName = "testNotFound1" + cf1 := client.GetConfig("testNotFound1") + Assert(t, cf1, NotNilVal()) + Assert(t, client.cache.GetConfig("testNotFound1"), NotNilVal()) + // appConfig namespace existed, should not be appended + Assert(t, client.appConfig.NamespaceName, Equal("testNotFound1")) } func TestGetConfigAndInitValNil(t *testing.T) { diff --git a/component/notify/componet_notify_test.go b/component/notify/componet_notify_test.go index ec3b04a6..bb5c4efc 100644 --- a/component/notify/componet_notify_test.go +++ b/component/notify/componet_notify_test.go @@ -102,3 +102,21 @@ func getTestAppConfig() *config.AppConfig { appConfig.Init() return appConfig } + +func TestConfigComponent_SetAppConfig_UpdatesAppConfigCorrectly(t *testing.T) { + expectedAppConfig := getTestAppConfig() + c := &ConfigComponent{} + // set appConfigFunc + c.SetAppConfig(func() config.AppConfig { + return *expectedAppConfig + }) + + // appConfig should be equal + Assert(t, c.appConfigFunc(), Equal(*expectedAppConfig)) + + // appConfig value is be replaced + expectedAppConfig.AppID = "test1" + expectedAppConfig.NamespaceName = expectedAppConfig.NamespaceName + config.Comma + "abc" + Assert(t, c.appConfigFunc().AppID, Equal("test1")) + Assert(t, c.appConfigFunc().NamespaceName, Equal("application,abc")) +} diff --git a/component/remote/abs.go b/component/remote/abs.go index 801a4177..1952a809 100644 --- a/component/remote/abs.go +++ b/component/remote/abs.go @@ -18,7 +18,6 @@ package remote import ( - "strconv" "time" "github.com/apolloconfig/agollo/v4/component/log" @@ -47,12 +46,7 @@ func (a *AbsApolloConfig) SyncWithNamespace(namespace string, appConfigFunc func IsRetry: true, } if appConfig.SyncServerTimeout > 0 { - duration, err := time.ParseDuration(strconv.Itoa(appConfig.SyncServerTimeout) + "s") - if err != nil { - log.Errorf("parse sync server timeout %s fail, error:%v", appConfig.SyncServerTimeout, err) - return nil - } - c.Timeout = duration + c.Timeout = time.Duration(appConfig.SyncServerTimeout) * time.Second } callback := a.remoteApollo.CallBack(namespace) diff --git a/component/remote/async_test.go b/component/remote/async_test.go index dcb428e4..c7d645a9 100644 --- a/component/remote/async_test.go +++ b/component/remote/async_test.go @@ -146,7 +146,7 @@ func initMockNotifyAndConfigServerWithTwoErrResponse() *httptest.Server { return runMockConfigServer(handlerMap, onlynormaltworesponse) } -//run mock config server +// run mock config server func runMockConfigServer(handlerMap map[string]func(http.ResponseWriter, *http.Request), notifyHandler func(http.ResponseWriter, *http.Request)) *httptest.Server { appConfig := env.InitFileConfig() @@ -177,11 +177,11 @@ func initNotifications() *config.AppConfig { return appConfig } -//Error response -//will hold 5s and keep response 404 -func runErrorResponse() *httptest.Server { +// Error response with status and body +func runErrorResponse(status int, body []byte) *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) + w.WriteHeader(status) + _, _ = w.Write(body) })) return ts @@ -281,10 +281,26 @@ func TestGetRemoteConfig(t *testing.T) { } func TestErrorGetRemoteConfig(t *testing.T) { + tests := []struct { + name string + status int + errContained string + }{ + {name: "500", status: http.StatusInternalServerError, errContained: "over Max Retry Still Error"}, + {name: "404", status: http.StatusNotFound, errContained: "Connect Apollo Server Fail"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testErrorGetRemoteConfig(t, tt.status, tt.errContained) + }) + } +} + +func testErrorGetRemoteConfig(t *testing.T, status int, errContained string) { //clear initNotifications() appConfig := initNotifications() - server1 := runErrorResponse() + server1 := runErrorResponse(status, nil) appConfig.IP = server1.URL server.SetNextTryConnTime(appConfig.GetHost(), 0) @@ -303,7 +319,7 @@ func TestErrorGetRemoteConfig(t *testing.T) { t.Log("remoteConfigs:", remoteConfigs) t.Log("remoteConfigs size:", len(remoteConfigs)) - Assert(t, "over Max Retry Still Error", Equal(err.Error())) + Assert(t, err.Error(), StartWith(errContained)) } func TestCreateApolloConfigWithJson(t *testing.T) { diff --git a/component/serverlist/sync.go b/component/serverlist/sync.go index 9bb0c402..a554239c 100644 --- a/component/serverlist/sync.go +++ b/component/serverlist/sync.go @@ -19,7 +19,6 @@ package serverlist import ( "encoding/json" - "strconv" "time" "github.com/apolloconfig/agollo/v4/env/server" @@ -66,10 +65,10 @@ func (s *SyncServerIPListComponent) Start() { } } -//SyncServerIPList sync ip list from server -//then -//1.update agcache -//2.store in disk +// SyncServerIPList sync ip list from server +// then +// 1.update agcache +// 2.store in disk func SyncServerIPList(appConfigFunc func() config.AppConfig) (map[string]*config.ServerInfo, error) { if appConfigFunc == nil { panic("can not find apollo config!please confirm!") @@ -80,12 +79,8 @@ func SyncServerIPList(appConfigFunc func() config.AppConfig) (map[string]*config AppID: appConfig.AppID, Secret: appConfig.Secret, } - if appConfigFunc().SyncServerTimeout > 0 { - duration, err := time.ParseDuration(strconv.Itoa(appConfigFunc().SyncServerTimeout) + "s") - if err != nil { - return nil, err - } - c.Timeout = duration + if appConfig.SyncServerTimeout > 0 { + c.Timeout = time.Duration(appConfig.SyncServerTimeout) * time.Second } serverMap, err := http.Request(appConfig.GetServicesConfigURL(), c, &http.CallBack{ SuccessCallBack: SyncServerIPListSuccessCallBack, diff --git a/env/config/config.go b/env/config/config.go index cea37e49..ebfde508 100644 --- a/env/config/config.go +++ b/env/config/config.go @@ -29,17 +29,17 @@ import ( var ( defaultNotificationID = int64(-1) - comma = "," + Comma = "," ) -//File 读写配置文件 +// File 读写配置文件 type File interface { Load(fileName string, unmarshal func([]byte) (interface{}, error)) (interface{}, error) Write(content interface{}, configPath string) error } -//AppConfig 配置文件 +// AppConfig 配置文件 type AppConfig struct { AppID string `json:"appId"` Cluster string `json:"cluster"` @@ -56,7 +56,7 @@ type AppConfig struct { currentConnApolloConfig *CurrentApolloConfig } -//ServerInfo 服务器信息 +// ServerInfo 服务器信息 type ServerInfo struct { AppName string `json:"appName"` InstanceID string `json:"instanceId"` @@ -64,19 +64,19 @@ type ServerInfo struct { IsDown bool `json:"-"` } -//GetIsBackupConfig whether backup config after fetch config from apollo -//false : no -//true : yes (default) +// GetIsBackupConfig whether backup config after fetch config from apollo +// false : no +// true : yes (default) func (a *AppConfig) GetIsBackupConfig() bool { return a.IsBackupConfig } -//GetBackupConfigPath GetBackupConfigPath +// GetBackupConfigPath GetBackupConfigPath func (a *AppConfig) GetBackupConfigPath() string { return a.BackupConfigPath } -//GetHost GetHost +// GetHost GetHost func (a *AppConfig) GetHost() string { u, err := url.Parse(a.IP) if err != nil { @@ -108,10 +108,10 @@ func (a *AppConfig) initAllNotifications(callback func(namespace string)) { } } -//SplitNamespaces 根据namespace字符串分割后,并执行callback函数 +// SplitNamespaces 根据namespace字符串分割后,并执行callback函数 func SplitNamespaces(namespacesStr string, callback func(namespace string)) sync.Map { namespaces := sync.Map{} - split := strings.Split(namespacesStr, comma) + split := strings.Split(namespacesStr, Comma) for _, namespace := range split { if callback != nil { callback(namespace) @@ -126,7 +126,7 @@ func (a *AppConfig) GetNotificationsMap() *notificationsMap { return a.notificationsMap } -//GetServicesConfigURL 获取服务器列表url +// GetServicesConfigURL 获取服务器列表url func (a *AppConfig) GetServicesConfigURL() string { return fmt.Sprintf("%sservices/config?appId=%s&ip=%s", a.GetHost(), diff --git a/protocol/auth/sign/sign.go b/protocol/auth/sign/sign.go index 20bdf2dd..e729637a 100644 --- a/protocol/auth/sign/sign.go +++ b/protocol/auth/sign/sign.go @@ -22,6 +22,7 @@ import ( "crypto/sha1" "encoding/base64" "fmt" + "hash" "net/url" "strconv" "time" @@ -37,6 +38,17 @@ const ( question = "?" ) +var ( + h = sha1.New +) + +// SetHash set hash function +func SetHash(f func() hash.Hash) func() hash.Hash { + o := h + h = f + return o +} + // AuthSignature apollo 授权 type AuthSignature struct { } @@ -63,7 +75,7 @@ func (t *AuthSignature) HTTPHeaders(url string, appID string, secret string) map func signString(stringToSign string, accessKeySecret string) string { key := []byte(accessKeySecret) - mac := hmac.New(sha1.New, key) + mac := hmac.New(h, key) mac.Write([]byte(stringToSign)) return base64.StdEncoding.EncodeToString(mac.Sum(nil)) } diff --git a/protocol/auth/sign/sign_test.go b/protocol/auth/sign/sign_test.go index 520790ad..9cc09913 100644 --- a/protocol/auth/sign/sign_test.go +++ b/protocol/auth/sign/sign_test.go @@ -18,8 +18,10 @@ package sign import ( - . "github.com/tevid/gohamcrest" + "crypto/sha256" "testing" + + . "github.com/tevid/gohamcrest" ) const ( @@ -33,6 +35,13 @@ func TestSignString(t *testing.T) { Assert(t, s, Equal("mcS95GXa7CpCjIfrbxgjKr0lRu8=")) } +func TestSetHash(t *testing.T) { + o := SetHash(sha256.New) + defer func() { SetHash(o) }() + s := signString(rawURL, secret) + Assert(t, s, Equal("XeIN8X6lAoujl6i88icVreaMYlBXeDco348545DkQDY=")) +} + func TestUrl2PathWithQuery(t *testing.T) { pathWithQuery := url2PathWithQuery(rawURL) diff --git a/protocol/http/request.go b/protocol/http/request.go index 1c4a00e1..51bb4b41 100644 --- a/protocol/http/request.go +++ b/protocol/http/request.go @@ -79,7 +79,7 @@ func getDefaultTransport(insecureSkipVerify bool) *http.Transport { return defaultTransport } -//CallBack 请求回调函数 +// CallBack 请求回调函数 type CallBack struct { SuccessCallBack func([]byte, CallBack) (interface{}, error) NotModifyCallBack func() error @@ -87,7 +87,7 @@ type CallBack struct { Namespace string } -//Request 建立网络请求 +// Request 建立网络请求 func Request(requestURL string, connectionConfig *env.ConnectConfig, callBack *CallBack) (interface{}, error) { client := &http.Client{} //如有设置自定义超时时间即使用 @@ -175,6 +175,9 @@ func Request(requestURL string, connectionConfig *env.ConnectConfig, callBack *C return nil, callBack.NotModifyCallBack() } return nil, nil + case http.StatusBadRequest, http.StatusUnauthorized, http.StatusNotFound, http.StatusMethodNotAllowed: + log.Errorf("Connect Apollo Server Fail, url:%s, StatusCode:%d", requestURL, res.StatusCode) + return nil, errors.New(fmt.Sprintf("Connect Apollo Server Fail, StatusCode:%d", res.StatusCode)) default: log.Errorf("Connect Apollo Server Fail, url:%s, StatusCode:%d", requestURL, res.StatusCode) // if error then sleep @@ -190,7 +193,7 @@ func Request(requestURL string, connectionConfig *env.ConnectConfig, callBack *C return nil, err } -//RequestRecovery 可以恢复的请求 +// RequestRecovery 可以恢复的请求 func RequestRecovery(appConfig config.AppConfig, connectConfig *env.ConnectConfig, callBack *CallBack) (interface{}, error) { diff --git a/protocol/http/request_server_test.go b/protocol/http/request_server_test.go index 74998578..71e3ca03 100644 --- a/protocol/http/request_server_test.go +++ b/protocol/http/request_server_test.go @@ -50,10 +50,10 @@ var servicesResponseStr = `[{ var normalBackupConfigCount = 0 -//Normal response -//First request will hold 5s and response http.StatusNotModified -//Second request will hold 5s and response http.StatusNotModified -//Second request will response [{"namespaceName":"application","notificationId":3}] +// Normal response +// First request will hold 5s and response http.StatusNotModified +// Second request will hold 5s and response http.StatusNotModified +// Second request will response [{"namespaceName":"application","notificationId":3}] func runNormalBackupConfigResponse() *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { normalBackupConfigCount++ @@ -84,7 +84,7 @@ func runNormalBackupConfigResponseWithHTTPS() *httptest.Server { return ts } -//wait long time then response +// wait long time then response func runLongTimeResponse() *httptest.Server { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(10 * time.Second) @@ -94,3 +94,11 @@ func runLongTimeResponse() *httptest.Server { return ts } + +func runStatusCodeResponse(status int) *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(status) + })) + + return ts +} diff --git a/protocol/http/request_test.go b/protocol/http/request_test.go index 66e2b30f..43774492 100644 --- a/protocol/http/request_test.go +++ b/protocol/http/request_test.go @@ -19,6 +19,7 @@ package http import ( "fmt" + "net/http" "net/url" "testing" "time" @@ -133,6 +134,43 @@ func TestCustomTimeout(t *testing.T) { Assert(t, o, NilVal()) } +func TestFailFastStatusCode(t *testing.T) { + time.Sleep(1 * time.Second) + + tests := []struct { + name string + status int + }{ + {name: "400", status: http.StatusBadRequest}, + {name: "401", status: http.StatusUnauthorized}, + {name: "404", status: http.StatusNotFound}, + {name: "405", status: http.StatusMethodNotAllowed}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testFailFastStatusCode(t, tt.status) + }) + } +} + +func testFailFastStatusCode(t *testing.T, status int) { + server := runStatusCodeResponse(status) + appConfig := getTestAppConfig() + appConfig.IP = server.URL + + startTime := time.Now().Unix() + _, err := RequestRecovery(*appConfig, &env.ConnectConfig{ + URI: getConfigURLSuffix(appConfig, appConfig.NamespaceName), + IsRetry: true, + }, &CallBack{ + SuccessCallBack: nil, + }) + duration := time.Now().Unix() - startTime + + Assert(t, err, NotNilVal()) + Assert(t, int64(0), Equal(duration)) +} + func mockIPList(t *testing.T, appConfigFunc func() config.AppConfig) { time.Sleep(1 * time.Second) @@ -159,7 +197,7 @@ func getConfigURLSuffix(config *config.AppConfig, namespaceName string) string { utils.GetInternal()) } -//SyncServerIPListSuccessCallBack 同步服务器列表成功后的回调 +// SyncServerIPListSuccessCallBack 同步服务器列表成功后的回调 func SyncServerIPListSuccessCallBack(responseBody []byte, callback CallBack) (o interface{}, err error) { log.Debugf("get all server info: %s", string(responseBody))