diff --git a/client/request.go b/client/request.go index b7ea8b81..07ec972d 100644 --- a/client/request.go +++ b/client/request.go @@ -278,7 +278,17 @@ DoneChoosingBodySource: if r.pathPattern != "" && r.pathPattern != "/" && r.pathPattern[len(r.pathPattern)-1] == '/' { reinstateSlash = true } - urlPath := path.Join(basePath, r.pathPattern) + + // In case the basePath includes hardcoded query parameters, parse those out before + // constructing the final path. The parameters themselves will be merged with the + // ones set by the client, with the priority given to the latter. + basePathURL, err := url.Parse(basePath) + if err != nil { + return nil, err + } + basePathQueryParams := basePathURL.Query() + + urlPath := path.Join(basePathURL.Path, r.pathPattern) for k, v := range r.pathParams { urlPath = strings.Replace(urlPath, "{"+k+"}", url.PathEscape(v), -1) } @@ -291,6 +301,19 @@ DoneChoosingBodySource: return nil, err } + originalParams := r.GetQueryParams() + + // Merge the query parameters extracted from the basePath with the ones set by + // the client in this struct. In case of conflict, the client wins. + for k, v := range basePathQueryParams { + _, present := originalParams[k] + if !present { + if err = r.SetQueryParam(k, v...); err != nil { + return nil, err + } + } + } + req.URL.RawQuery = r.query.Encode() req.Header = r.header diff --git a/client/request_test.go b/client/request_test.go index 9979fd8c..86d0b0ff 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -517,6 +517,39 @@ func TestBuildRequest_BuildHTTP_EscapedPath(t *testing.T) { } } +func TestBuildRequest_BuildHTTP_BasePathWithParameters(t *testing.T) { + reqWrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, reg strfmt.Registry) error { + _ = req.SetBodyParam(nil) + _ = req.SetQueryParam("hello", "world") + _ = req.SetPathParam("id", "1234") + return nil + }) + r, _ := newRequest("POST", "/flats/{id}/", reqWrtr) + + req, err := r.BuildHTTP(runtime.JSONMime, "/basepath?foo=bar", testProducers, nil) + if assert.NoError(t, err) && assert.NotNil(t, req) { + assert.Equal(t, "world", req.URL.Query().Get("hello")) + assert.Equal(t, "bar", req.URL.Query().Get("foo")) + assert.Equal(t, "/basepath/flats/1234/", req.URL.Path) + } +} + +func TestBuildRequest_BuildHTTP_BasePathWithConflictingParameters(t *testing.T) { + reqWrtr := runtime.ClientRequestWriterFunc(func(req runtime.ClientRequest, reg strfmt.Registry) error { + _ = req.SetBodyParam(nil) + _ = req.SetQueryParam("hello", "world") + _ = req.SetPathParam("id", "1234") + return nil + }) + r, _ := newRequest("POST", "/flats/{id}/", reqWrtr) + + req, err := r.BuildHTTP(runtime.JSONMime, "/basepath?hello=kitty", testProducers, nil) + if assert.NoError(t, err) && assert.NotNil(t, req) { + assert.Equal(t, "world", req.URL.Query().Get("hello")) + assert.Equal(t, "/basepath/flats/1234/", req.URL.Path) + } +} + type testReqFn func(*testing.T, *http.Request) type testRoundTripper struct {