Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor assertions #2301

Merged
merged 1 commit into from Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 25 additions & 33 deletions echo_test.go
Expand Up @@ -491,16 +491,14 @@ func TestEchoURL(t *testing.T) {
g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile)

assertion := assert.New(t)

assertion.Equal("/static/file", e.URL(static))
assertion.Equal("/users/:id", e.URL(getUser))
assertion.Equal("/users/1", e.URL(getUser, "1"))
assertion.Equal("/users/1", e.URL(getUser, "1"))
assertion.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assertion.Equal("/documents/*", e.URL(getAny))
assertion.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assertion.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
assert.Equal(t, "/static/file", e.URL(static))
assert.Equal(t, "/users/:id", e.URL(getUser))
assert.Equal(t, "/users/1", e.URL(getUser, "1"))
assert.Equal(t, "/users/1", e.URL(getUser, "1"))
assert.Equal(t, "/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal(t, "/documents/*", e.URL(getAny))
assert.Equal(t, "/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal(t, "/group/users/1/files/1", e.URL(getFile, "1", "1"))
}

func TestEchoRoutes(t *testing.T) {
Expand Down Expand Up @@ -607,8 +605,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) {
}

func TestEchoHost(t *testing.T) {
assertion := assert.New(t)

okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) }
teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) }
acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) }
Expand Down Expand Up @@ -703,8 +699,8 @@ func TestEchoHost(t *testing.T) {

e.ServeHTTP(rec, req)

assertion.Equal(tc.expectStatus, rec.Code)
assertion.Equal(tc.expectBody, rec.Body.String())
assert.Equal(t, tc.expectStatus, rec.Code)
assert.Equal(t, tc.expectBody, rec.Body.String())
})
}
}
Expand Down Expand Up @@ -1429,8 +1425,6 @@ func TestEchoListenerNetworkInvalid(t *testing.T) {
}

func TestEchoReverse(t *testing.T) {
assert := assert.New(t)

e := New()
dummyHandler := func(Context) error { return nil }

Expand All @@ -1440,33 +1434,31 @@ func TestEchoReverse(t *testing.T) {
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"

assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))

assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
assert.Equal(t, "/static", e.Reverse("/static"))
assert.Equal(t, "/static", e.Reverse("/static", "missing param"))
assert.Equal(t, "/static/*", e.Reverse("/static/*"))
assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt"))

assert.Equal(t, "/params/:foo", e.Reverse("/params/:foo"))
assert.Equal(t, "/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal(t, "/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal(t, "/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal(t, "/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal(t, "/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

func TestEchoReverseHandleHostProperly(t *testing.T) {
assert := assert.New(t)

dummyHandler := func(Context) error { return nil }

e := New()
h := e.Host("the_host")
h.GET("/static", dummyHandler).Name = "/static"
h.GET("/static/*", dummyHandler).Name = "/static/*"

assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal(t, "/static", e.Reverse("/static"))
assert.Equal(t, "/static", e.Reverse("/static", "missing param"))
assert.Equal(t, "/static/*", e.Reverse("/static/*"))
assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
}

func TestEcho_ListenerAddr(t *testing.T) {
Expand Down
18 changes: 8 additions & 10 deletions middleware/basic_auth_test.go
Expand Up @@ -26,12 +26,10 @@ func TestBasicAuth(t *testing.T) {
return c.String(http.StatusOK, "test")
})

assert := assert.New(t)

// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, h(c))

h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Expand All @@ -44,34 +42,34 @@ func TestBasicAuth(t *testing.T) {
// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, h(c))

// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, h(c))

// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))

// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusBadRequest, he.Code)
assert.Equal(t, http.StatusBadRequest, he.Code)

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(t, http.StatusUnauthorized, he.Code)

// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(t, http.StatusUnauthorized, he.Code)
}
12 changes: 5 additions & 7 deletions middleware/body_dump_test.go
Expand Up @@ -33,13 +33,11 @@ func TestBodyDump(t *testing.T) {
responseBody = string(resBody)
})

assert := assert.New(t)

if assert.NoError(mw(h)(c)) {
assert.Equal(requestBody, hw)
assert.Equal(responseBody, hw)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.String())
if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
assert.Equal(t, responseBody, hw)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.String())
}

// Must set default skipper
Expand Down
18 changes: 8 additions & 10 deletions middleware/body_limit_test.go
Expand Up @@ -25,26 +25,24 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body))
}

assert := assert.New(t)

// Based on content length (within limit)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.Bytes())
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}

// Based on content length (overlimit)
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)

// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("Hello, World!", rec.Body.String())
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, World!", rec.Body.String())
}

// Based on content read (overlimit)
Expand All @@ -53,7 +51,7 @@ func TestBodyLimit(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}

func TestBodyLimitReader(t *testing.T) {
Expand Down
26 changes: 12 additions & 14 deletions middleware/compress_test.go
Expand Up @@ -26,24 +26,22 @@ func TestGzip(t *testing.T) {
})
h(c)

assert := assert.New(t)

assert.Equal("test", rec.Body.String())
assert.Equal(t, "test", rec.Body.String())

// Gzip
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body)
if assert.NoError(err) {
if assert.NoError(t, err) {
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "test", buf.String())
}

chunkBuf := make([]byte, 5)
Expand All @@ -63,21 +61,21 @@ func TestGzip(t *testing.T) {
c.Response().Flush()

// Read the first part of the data
assert.True(rec.Flushed)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.True(t, rec.Flushed)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r.Reset(rec.Body)

_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
assert.NoError(t, err)
assert.Equal(t, "test\n", string(chunkBuf))

// Write and flush the second part of the data
c.Response().Write([]byte("test\n"))
c.Response().Flush()

_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
assert.NoError(t, err)
assert.Equal(t, "test\n", string(chunkBuf))

// Write the final part of the data and return
c.Response().Write([]byte("test"))
Expand All @@ -87,7 +85,7 @@ func TestGzip(t *testing.T) {
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "test", buf.String())
}

func TestGzipNoContent(t *testing.T) {
Expand Down
18 changes: 8 additions & 10 deletions middleware/decompress_test.go
Expand Up @@ -28,8 +28,7 @@ func TestDecompress(t *testing.T) {
})
h(c)

assert := assert.New(t)
assert.Equal("test", rec.Body.String())
assert.Equal(t, "test", rec.Body.String())

// Decompress
body := `{"name": "echo"}`
Expand All @@ -39,10 +38,10 @@ func TestDecompress(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}

func TestDecompressDefaultConfig(t *testing.T) {
Expand All @@ -57,8 +56,7 @@ func TestDecompressDefaultConfig(t *testing.T) {
})
h(c)

assert := assert.New(t)
assert.Equal("test", rec.Body.String())
assert.Equal(t, "test", rec.Body.String())

// Decompress
body := `{"name": "echo"}`
Expand All @@ -68,10 +66,10 @@ func TestDecompressDefaultConfig(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}

func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
Expand Down
12 changes: 5 additions & 7 deletions middleware/jwt_test.go
Expand Up @@ -348,8 +348,6 @@ func TestJWTConfig(t *testing.T) {
}

func TestJWTwithKID(t *testing.T) {
test := assert.New(t)

e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
Expand Down Expand Up @@ -417,19 +415,19 @@ func TestJWTwithKID(t *testing.T) {
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
test.Equal(tc.expErrCode, he.Code, tc.info)
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
continue
}

h := JWTWithConfig(tc.config)(handler)
if test.NoError(h(c), tc.info) {
if assert.NoError(t, h(c), tc.info) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
test.Equal(claims["name"], "John Doe", tc.info)
assert.Equal(t, claims["name"], "John Doe", tc.info)
case *jwtCustomClaims:
test.Equal(claims.Name, "John Doe", tc.info)
test.Equal(claims.Admin, true, tc.info)
assert.Equal(t, claims.Name, "John Doe", tc.info)
assert.Equal(t, claims.Admin, true, tc.info)
default:
panic("unexpected type of claims")
}
Expand Down