diff --git a/hvsock.go b/hvsock.go index 1bd34a35..a1064987 100644 --- a/hvsock.go +++ b/hvsock.go @@ -5,6 +5,7 @@ package winio import ( "context" + "errors" "fmt" "io" "net" @@ -437,6 +438,10 @@ func canRedial(err error) bool { } func (conn *HvsockConn) opErr(op string, err error) error { + // translate from "file closed" to "socket closed" + if errors.Is(err, ErrFileClosed) { + err = sockets.ErrSocketClosed + } return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err} } @@ -451,8 +456,8 @@ func (conn *HvsockConn) Read(b []byte) (int, error) { err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) if err != nil { - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsarecv", err) + if eno := windows.Errno(0); errors.As(err, &eno) { + err = os.NewSyscallError("wsarecv", eno) } return 0, conn.opErr("read", err) } else if n == 0 { @@ -485,8 +490,8 @@ func (conn *HvsockConn) write(b []byte) (int, error) { err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) if err != nil { - if _, ok := err.(syscall.Errno); ok { - err = os.NewSyscallError("wsasend", err) + if eno := windows.Errno(0); errors.As(err, &eno) { + err = os.NewSyscallError("wsasend", eno) } return 0, conn.opErr("write", err) } @@ -505,11 +510,16 @@ func (conn *HvsockConn) IsClosed() bool { // shutdown disables sending or receiving on a socket func (conn *HvsockConn) shutdown(how int) error { if conn.IsClosed() { - return ErrFileClosed + return sockets.ErrSocketClosed } err := syscall.Shutdown(conn.sock.handle, how) if err != nil { + // If the connection was closed, shutdowns fail with "not connected" + if errors.Is(err, windows.WSAENOTCONN) || + errors.Is(err, windows.WSAESHUTDOWN) { + err = sockets.ErrSocketClosed + } return os.NewSyscallError("shutdown", err) } return nil diff --git a/pkg/sockets/rawaddr.go b/pkg/sockets/rawaddr.go index 3680cbdd..80d9ecb2 100644 --- a/pkg/sockets/rawaddr.go +++ b/pkg/sockets/rawaddr.go @@ -21,7 +21,7 @@ var ( // https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 type RawSockaddr interface { // Sockaddr returns a pointer to the RawSockaddr and the length of the struct. - Sockaddr() (ptr unsafe.Pointer, len int32, err error) + Sockaddr() (unsafe.Pointer, int32, error) // FromBytes populates the RawsockAddr with the data in the byte array. // Implementers should check the buffer is correctly sized and the address family @@ -30,12 +30,12 @@ type RawSockaddr interface { FromBytes([]byte) error } -func validateSockAddr(ptr unsafe.Pointer, len int32) error { +func validateSockAddr(ptr unsafe.Pointer, n int32) error { if ptr == nil { return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer) } - if len < 1 { - return fmt.Errorf("buffer size %d < 1: %w", len, ErrBufferSize) + if n < 1 { + return fmt.Errorf("buffer size %d < 1: %w", n, ErrBufferSize) } return nil } diff --git a/pkg/sockets/sockets.go b/pkg/sockets/sockets.go index ffe8a87c..b707da03 100644 --- a/pkg/sockets/sockets.go +++ b/pkg/sockets/sockets.go @@ -21,6 +21,8 @@ import ( const socketError = uintptr(^uint32(0)) +var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed) + // CloseWriter is a connection that can disable writing to itself. type CloseWriter interface { net.Conn @@ -125,7 +127,6 @@ func (f *runtimeFunc) Load() error { ) }) return f.err - } var ( @@ -141,9 +142,8 @@ var ( ) func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) error { - err := connectExFunc.Load() - if err != nil { - return fmt.Errorf("failed to load ConnectEx function pointer: %e", err) + if err := connectExFunc.Load(); err != nil { + return fmt.Errorf("failed to load ConnectEx function pointer: %w", err) } ptr, n, err := rsa.Sockaddr() if err != nil {