diff --git a/conncheck.go b/conncheck.go index cc47aa559..70e9925f6 100644 --- a/conncheck.go +++ b/conncheck.go @@ -19,35 +19,36 @@ import ( var errUnexpectedRead = errors.New("unexpected read from socket") -func connCheck(c net.Conn) error { - var ( - n int - err error - buff [1]byte - ) - - sconn, ok := c.(syscall.Conn) +func connCheck(conn net.Conn) error { + var sysErr error + + sysConn, ok := conn.(syscall.Conn) if !ok { return nil } - rc, err := sconn.SyscallConn() + rawConn, err := sysConn.SyscallConn() if err != nil { return err } - rerr := rc.Read(func(fd uintptr) bool { - n, err = syscall.Read(int(fd), buff[:]) + + err = rawConn.Read(func(fd uintptr) bool { + var buf [1]byte + n, err := syscall.Read(int(fd), buf[:]) + switch { + case n == 0 && err == nil: + sysErr = io.EOF + case n > 0: + sysErr = errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + sysErr = nil + default: + sysErr = err + } return true }) - switch { - case rerr != nil: - return rerr - case n == 0 && err == nil: - return io.EOF - case n > 0: - return errUnexpectedRead - case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: - return nil - default: + if err != nil { return err } + + return sysErr } diff --git a/conncheck_dummy.go b/conncheck_dummy.go index fd01f64c9..4888288aa 100644 --- a/conncheck_dummy.go +++ b/conncheck_dummy.go @@ -12,6 +12,6 @@ package mysql import "net" -func connCheck(c net.Conn) error { +func connCheck(conn net.Conn) error { return nil }