From 9d568fac83ae6627f3a68129bdba0c1300a382a5 Mon Sep 17 00:00:00 2001 From: Sergey Ponomarev Date: Sun, 1 Jan 2023 14:07:05 +0200 Subject: [PATCH] client.go Simplify default UA logic The getClientName() checks if !NoDefaultUserAgentHeader then returns the Client.Name field. But it also saves it to atomic field clientName. This is not needed and logic can be simplified. Previously the clientName vas a byte slice that was copied from c.Name and cached. See 02e0722fb73c6237818b8e5af55957eb919a7334 Fix #1458 --- client.go | 57 +++++++++++++++++++------------------------------- client_test.go | 2 +- strings.go | 2 +- 3 files changed, 23 insertions(+), 38 deletions(-) diff --git a/client.go b/client.go index 263908ca12..23bfa573a9 100644 --- a/client.go +++ b/client.go @@ -788,7 +788,6 @@ type HostClient struct { // Connection pool strategy. Can be either LIFO or FIFO (default). ConnPoolStrategy ConnPoolStrategyType - clientName atomic.Value lastUseTime uint32 connsLock sync.Mutex @@ -1327,9 +1326,14 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error) userAgentOld := req.Header.UserAgent() if len(userAgentOld) == 0 { - req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...) + userAgent := c.Name + if userAgent == "" && !c.NoDefaultUserAgentHeader { + userAgent = defaultUserAgent + } + if userAgent != "" { + req.Header.userAgent = append(req.Header.userAgent[:], userAgent...) + } } - if c.Transport != nil { err := c.Transport(req, resp) return err == nil, err @@ -1990,21 +1994,6 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig * return conn, nil } -func (c *HostClient) getClientName() []byte { - v := c.clientName.Load() - var clientName []byte - if v == nil { - clientName = []byte(c.Name) - if len(clientName) == 0 && !c.NoDefaultUserAgentHeader { - clientName = defaultUserAgent - } - c.clientName.Store(clientName) - } else { - clientName = v.([]byte) - } - return clientName -} - // AddMissingPort adds a port to a host if it is missing. // A literal IPv6 address in hostport must be enclosed in square // brackets, as in "[::1]:80", "[::1%lo0]:80". @@ -2318,7 +2307,6 @@ type pipelineConnClient struct { tlsConfigLock sync.Mutex tlsConfig *tls.Config - clientName atomic.Value } type pipelineWork struct { @@ -2389,7 +2377,13 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t userAgentOld := req.Header.UserAgent() if len(userAgentOld) == 0 { - req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...) + userAgent := c.Name + if userAgent == "" && !c.NoDefaultUserAgentHeader { + userAgent = defaultUserAgent + } + if userAgent != "" { + req.Header.userAgent = append(req.Header.userAgent[:], userAgent...) + } } w := c.acquirePipelineWork(timeout) @@ -2490,7 +2484,13 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error { userAgentOld := req.Header.UserAgent() if len(userAgentOld) == 0 { - req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...) + userAgent := c.Name + if userAgent == "" && !c.NoDefaultUserAgentHeader { + userAgent = defaultUserAgent + } + if userAgent != "" { + req.Header.userAgent = append(req.Header.userAgent[:], userAgent...) + } } w := c.acquirePipelineWork(0) @@ -2886,19 +2886,4 @@ func (c *pipelineConnClient) PendingRequests() int { return n } -func (c *pipelineConnClient) getClientName() []byte { - v := c.clientName.Load() - var clientName []byte - if v == nil { - clientName = []byte(c.Name) - if len(clientName) == 0 && !c.NoDefaultUserAgentHeader { - clientName = defaultUserAgent - } - c.clientName.Store(clientName) - } else { - clientName = v.([]byte) - } - return clientName -} - var errPipelineConnStopped = errors.New("pipeline connection has been stopped") diff --git a/client_test.go b/client_test.go index b5c0b7f9ca..aa53461de5 100644 --- a/client_test.go +++ b/client_test.go @@ -724,7 +724,7 @@ func TestClientDefaultUserAgent(t *testing.T) { if err != nil { t.Fatal(err) } - if userAgentSeen != string(defaultUserAgent) { + if userAgentSeen != defaultUserAgent { t.Fatalf("User-Agent defers %q != %q", userAgentSeen, defaultUserAgent) } } diff --git a/strings.go b/strings.go index 370e307989..dd7a827ba3 100644 --- a/strings.go +++ b/strings.go @@ -2,7 +2,7 @@ package fasthttp var ( defaultServerName = []byte("fasthttp") - defaultUserAgent = []byte("fasthttp") + defaultUserAgent = "fasthttp" defaultContentType = []byte("text/plain; charset=utf-8") )