Skip to content

Commit

Permalink
refactor assertions (#2301)
Browse files Browse the repository at this point in the history
  • Loading branch information
aldas committed Oct 12, 2022
1 parent 4c44305 commit 1d5f335
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 91 deletions.
58 changes: 25 additions & 33 deletions echo_test.go
Expand Up @@ -491,16 +491,14 @@ func TestEchoURL(t *testing.T) {
g := e.Group("/group")
g.GET("/users/:uid/files/:fid", getFile)

assertion := assert.New(t)

assertion.Equal("/static/file", e.URL(static))
assertion.Equal("/users/:id", e.URL(getUser))
assertion.Equal("/users/1", e.URL(getUser, "1"))
assertion.Equal("/users/1", e.URL(getUser, "1"))
assertion.Equal("/documents/foo.txt", e.URL(getAny, "foo.txt"))
assertion.Equal("/documents/*", e.URL(getAny))
assertion.Equal("/group/users/1/files/:fid", e.URL(getFile, "1"))
assertion.Equal("/group/users/1/files/1", e.URL(getFile, "1", "1"))
assert.Equal(t, "/static/file", e.URL(static))
assert.Equal(t, "/users/:id", e.URL(getUser))
assert.Equal(t, "/users/1", e.URL(getUser, "1"))
assert.Equal(t, "/users/1", e.URL(getUser, "1"))
assert.Equal(t, "/documents/foo.txt", e.URL(getAny, "foo.txt"))
assert.Equal(t, "/documents/*", e.URL(getAny))
assert.Equal(t, "/group/users/1/files/:fid", e.URL(getFile, "1"))
assert.Equal(t, "/group/users/1/files/1", e.URL(getFile, "1", "1"))
}

func TestEchoRoutes(t *testing.T) {
Expand Down Expand Up @@ -607,8 +605,6 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) {
}

func TestEchoHost(t *testing.T) {
assertion := assert.New(t)

okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) }
teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) }
acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) }
Expand Down Expand Up @@ -703,8 +699,8 @@ func TestEchoHost(t *testing.T) {

e.ServeHTTP(rec, req)

assertion.Equal(tc.expectStatus, rec.Code)
assertion.Equal(tc.expectBody, rec.Body.String())
assert.Equal(t, tc.expectStatus, rec.Code)
assert.Equal(t, tc.expectBody, rec.Body.String())
})
}
}
Expand Down Expand Up @@ -1429,8 +1425,6 @@ func TestEchoListenerNetworkInvalid(t *testing.T) {
}

func TestEchoReverse(t *testing.T) {
assert := assert.New(t)

e := New()
dummyHandler := func(Context) error { return nil }

Expand All @@ -1440,33 +1434,31 @@ func TestEchoReverse(t *testing.T) {
e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux"
e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*"

assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))

assert.Equal("/params/:foo", e.Reverse("/params/:foo"))
assert.Equal("/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal("/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal("/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal("/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal("/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
assert.Equal(t, "/static", e.Reverse("/static"))
assert.Equal(t, "/static", e.Reverse("/static", "missing param"))
assert.Equal(t, "/static/*", e.Reverse("/static/*"))
assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt"))

assert.Equal(t, "/params/:foo", e.Reverse("/params/:foo"))
assert.Equal(t, "/params/one", e.Reverse("/params/:foo", "one"))
assert.Equal(t, "/params/:foo/bar/:qux", e.Reverse("/params/:foo/bar/:qux"))
assert.Equal(t, "/params/one/bar/:qux", e.Reverse("/params/:foo/bar/:qux", "one"))
assert.Equal(t, "/params/one/bar/two", e.Reverse("/params/:foo/bar/:qux", "one", "two"))
assert.Equal(t, "/params/one/bar/two/three", e.Reverse("/params/:foo/bar/:qux/*", "one", "two", "three"))
}

func TestEchoReverseHandleHostProperly(t *testing.T) {
assert := assert.New(t)

dummyHandler := func(Context) error { return nil }

e := New()
h := e.Host("the_host")
h.GET("/static", dummyHandler).Name = "/static"
h.GET("/static/*", dummyHandler).Name = "/static/*"

assert.Equal("/static", e.Reverse("/static"))
assert.Equal("/static", e.Reverse("/static", "missing param"))
assert.Equal("/static/*", e.Reverse("/static/*"))
assert.Equal("/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
assert.Equal(t, "/static", e.Reverse("/static"))
assert.Equal(t, "/static", e.Reverse("/static", "missing param"))
assert.Equal(t, "/static/*", e.Reverse("/static/*"))
assert.Equal(t, "/static/foo.txt", e.Reverse("/static/*", "foo.txt"))
}

func TestEcho_ListenerAddr(t *testing.T) {
Expand Down
18 changes: 8 additions & 10 deletions middleware/basic_auth_test.go
Expand Up @@ -26,12 +26,10 @@ func TestBasicAuth(t *testing.T) {
return c.String(http.StatusOK, "test")
})

assert := assert.New(t)

// Valid credentials
auth := basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, h(c))

h = BasicAuthWithConfig(BasicAuthConfig{
Skipper: nil,
Expand All @@ -44,34 +42,34 @@ func TestBasicAuth(t *testing.T) {
// Valid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, h(c))

// Case-insensitive header scheme
auth = strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, auth)
assert.NoError(h(c))
assert.NoError(t, h(c))

// Invalid credentials
auth = basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))
req.Header.Set(echo.HeaderAuthorization, auth)
he := h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))
assert.Equal(t, http.StatusUnauthorized, he.Code)
assert.Equal(t, basic+` realm="someRealm"`, res.Header().Get(echo.HeaderWWWAuthenticate))

// Invalid base64 string
auth = basic + " invalidString"
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusBadRequest, he.Code)
assert.Equal(t, http.StatusBadRequest, he.Code)

// Missing Authorization header
req.Header.Del(echo.HeaderAuthorization)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(t, http.StatusUnauthorized, he.Code)

// Invalid Authorization header
auth = base64.StdEncoding.EncodeToString([]byte("invalid"))
req.Header.Set(echo.HeaderAuthorization, auth)
he = h(c).(*echo.HTTPError)
assert.Equal(http.StatusUnauthorized, he.Code)
assert.Equal(t, http.StatusUnauthorized, he.Code)
}
12 changes: 5 additions & 7 deletions middleware/body_dump_test.go
Expand Up @@ -33,13 +33,11 @@ func TestBodyDump(t *testing.T) {
responseBody = string(resBody)
})

assert := assert.New(t)

if assert.NoError(mw(h)(c)) {
assert.Equal(requestBody, hw)
assert.Equal(responseBody, hw)
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.String())
if assert.NoError(t, mw(h)(c)) {
assert.Equal(t, requestBody, hw)
assert.Equal(t, responseBody, hw)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.String())
}

// Must set default skipper
Expand Down
18 changes: 8 additions & 10 deletions middleware/body_limit_test.go
Expand Up @@ -25,26 +25,24 @@ func TestBodyLimit(t *testing.T) {
return c.String(http.StatusOK, string(body))
}

assert := assert.New(t)

// Based on content length (within limit)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal(hw, rec.Body.Bytes())
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, hw, rec.Body.Bytes())
}

// Based on content length (overlimit)
he := BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)

// Based on content read (within limit)
req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw))
req.ContentLength = -1
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
if assert.NoError(BodyLimit("2M")(h)(c)) {
assert.Equal(http.StatusOK, rec.Code)
assert.Equal("Hello, World!", rec.Body.String())
if assert.NoError(t, BodyLimit("2M")(h)(c)) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Hello, World!", rec.Body.String())
}

// Based on content read (overlimit)
Expand All @@ -53,7 +51,7 @@ func TestBodyLimit(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
he = BodyLimit("2B")(h)(c).(*echo.HTTPError)
assert.Equal(http.StatusRequestEntityTooLarge, he.Code)
assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code)
}

func TestBodyLimitReader(t *testing.T) {
Expand Down
26 changes: 12 additions & 14 deletions middleware/compress_test.go
Expand Up @@ -26,24 +26,22 @@ func TestGzip(t *testing.T) {
})
h(c)

assert := assert.New(t)

assert.Equal("test", rec.Body.String())
assert.Equal(t, "test", rec.Body.String())

// Gzip
req = httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme)
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain)
r, err := gzip.NewReader(rec.Body)
if assert.NoError(err) {
if assert.NoError(t, err) {
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "test", buf.String())
}

chunkBuf := make([]byte, 5)
Expand All @@ -63,21 +61,21 @@ func TestGzip(t *testing.T) {
c.Response().Flush()

// Read the first part of the data
assert.True(rec.Flushed)
assert.Equal(gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
assert.True(t, rec.Flushed)
assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding))
r.Reset(rec.Body)

_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
assert.NoError(t, err)
assert.Equal(t, "test\n", string(chunkBuf))

// Write and flush the second part of the data
c.Response().Write([]byte("test\n"))
c.Response().Flush()

_, err = io.ReadFull(r, chunkBuf)
assert.NoError(err)
assert.Equal("test\n", string(chunkBuf))
assert.NoError(t, err)
assert.Equal(t, "test\n", string(chunkBuf))

// Write the final part of the data and return
c.Response().Write([]byte("test"))
Expand All @@ -87,7 +85,7 @@ func TestGzip(t *testing.T) {
buf := new(bytes.Buffer)
defer r.Close()
buf.ReadFrom(r)
assert.Equal("test", buf.String())
assert.Equal(t, "test", buf.String())
}

func TestGzipNoContent(t *testing.T) {
Expand Down
18 changes: 8 additions & 10 deletions middleware/decompress_test.go
Expand Up @@ -28,8 +28,7 @@ func TestDecompress(t *testing.T) {
})
h(c)

assert := assert.New(t)
assert.Equal("test", rec.Body.String())
assert.Equal(t, "test", rec.Body.String())

// Decompress
body := `{"name": "echo"}`
Expand All @@ -39,10 +38,10 @@ func TestDecompress(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}

func TestDecompressDefaultConfig(t *testing.T) {
Expand All @@ -57,8 +56,7 @@ func TestDecompressDefaultConfig(t *testing.T) {
})
h(c)

assert := assert.New(t)
assert.Equal("test", rec.Body.String())
assert.Equal(t, "test", rec.Body.String())

// Decompress
body := `{"name": "echo"}`
Expand All @@ -68,10 +66,10 @@ func TestDecompressDefaultConfig(t *testing.T) {
rec = httptest.NewRecorder()
c = e.NewContext(req, rec)
h(c)
assert.Equal(GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding))
b, err := ioutil.ReadAll(req.Body)
assert.NoError(err)
assert.Equal(body, string(b))
assert.NoError(t, err)
assert.Equal(t, body, string(b))
}

func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) {
Expand Down
12 changes: 5 additions & 7 deletions middleware/jwt_test.go
Expand Up @@ -348,8 +348,6 @@ func TestJWTConfig(t *testing.T) {
}

func TestJWTwithKID(t *testing.T) {
test := assert.New(t)

e := echo.New()
handler := func(c echo.Context) error {
return c.String(http.StatusOK, "test")
Expand Down Expand Up @@ -417,19 +415,19 @@ func TestJWTwithKID(t *testing.T) {
if tc.expErrCode != 0 {
h := JWTWithConfig(tc.config)(handler)
he := h(c).(*echo.HTTPError)
test.Equal(tc.expErrCode, he.Code, tc.info)
assert.Equal(t, tc.expErrCode, he.Code, tc.info)
continue
}

h := JWTWithConfig(tc.config)(handler)
if test.NoError(h(c), tc.info) {
if assert.NoError(t, h(c), tc.info) {
user := c.Get("user").(*jwt.Token)
switch claims := user.Claims.(type) {
case jwt.MapClaims:
test.Equal(claims["name"], "John Doe", tc.info)
assert.Equal(t, claims["name"], "John Doe", tc.info)
case *jwtCustomClaims:
test.Equal(claims.Name, "John Doe", tc.info)
test.Equal(claims.Admin, true, tc.info)
assert.Equal(t, claims.Name, "John Doe", tc.info)
assert.Equal(t, claims.Admin, true, tc.info)
default:
panic("unexpected type of claims")
}
Expand Down

0 comments on commit 1d5f335

Please sign in to comment.