Skip to content

Commit

Permalink
PR: Context for asychIO in Dial
Browse files Browse the repository at this point in the history
Signed-off-by: Hamza El-Saawy <hamzaelsaawy@microsoft.com>
  • Loading branch information
helsaawy committed Jul 5, 2022
1 parent a615ab2 commit 0997ca6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 24 deletions.
38 changes: 24 additions & 14 deletions file.go
Expand Up @@ -4,6 +4,7 @@
package winio

import (
"context"
"errors"
"io"
"runtime"
Expand Down Expand Up @@ -178,11 +179,26 @@ func ioCompletionProcessor(h syscall.Handle) {
}
}

// todo: helsaawy - create an asyncIO version that takes a context
// asyncIoContext is similar to asyncIoDeadline, but takes a context.Context instead of a deadlineHandler
func (f *win32File) asyncIoContext(ctx context.Context, c *ioOperation, bytes uint32, err error) (int, error) {
return f.asyncIo(c, ctx.Done(), bytes, err)
}

// asyncIoDeadline processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed, or until the deadlineHandler expires
func (f *win32File) asyncIoDeadline(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
var ch timeoutChan
if d != nil {
d.channelLock.Lock()
ch = d.channel
d.channelLock.Unlock()
}
return f.asyncIo(c, ch, bytes, err)
}

// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
// asyncIo is similar to asyncIoDeadline and asyncIOContext, but instead takes a
// <- chan struct{} parameter to cancel the operation
func (f *win32File) asyncIo(c *ioOperation, ch <-chan struct{}, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING {
return int(bytes), err
}
Expand All @@ -191,13 +207,6 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
cancelIoEx(f.handle, &c.o)
}

var timeout timeoutChan
if d != nil {
d.channelLock.Lock()
timeout = d.channel
d.channelLock.Unlock()
}

var r ioResult
select {
case r = <-c.ch:
Expand All @@ -211,7 +220,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
var bytes, flags uint32
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
}
case <-timeout:
case <-ch:
cancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
Expand All @@ -223,6 +232,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
// runtime.KeepAlive is needed, as c is passed via native
// code to ioCompletionProcessor, c must remain alive
// until the channel read is complete.
// todo: should the *ioOperation be (de)allocated via win32 heap functions instead?
runtime.KeepAlive(c)
return int(r.bytes), err
}
Expand All @@ -241,7 +251,7 @@ func (f *win32File) Read(b []byte) (int, error) {

var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
n, err := f.asyncIoDeadline(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)

// Handle EOF conditions.
Expand All @@ -268,7 +278,7 @@ func (f *win32File) Write(b []byte) (int, error) {

var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
n, err := f.asyncIoDeadline(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
}
Expand Down
13 changes: 4 additions & 9 deletions hvsock.go
Expand Up @@ -243,7 +243,7 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) {

var bytes uint32
err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /*rxdatalen*/, addrlen, addrlen, &bytes, &c.o)
if _, err = l.sock.asyncIo(c, nil, bytes, err); err != nil {
if _, err = l.sock.asyncIoDeadline(c, nil, bytes, err); err != nil {
return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
}

Expand Down Expand Up @@ -348,7 +348,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock
0, // sendDataLen
&bytes,
(*windows.Overlapped)(unsafe.Pointer(&c.o)))
_, err = sock.asyncIo(c, nil, bytes, err)
_, err = sock.asyncIoContext(ctx, c, bytes, err)
if i < d.Retries && canRedial(err) {
if err = d.redialWait(ctx); err == nil {
continue
Expand Down Expand Up @@ -379,11 +379,6 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock
}
conn.local.fromRaw(&sal)

// one last check for timeout, since asyncIO doesnt check the context
if err = ctx.Err(); err != nil {
return nil, conn.opErr(op, err)
}

conn.sock = sock
sock = nil

Expand Down Expand Up @@ -445,7 +440,7 @@ func (conn *HvsockConn) Read(b []byte) (int, error) {
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var flags, bytes uint32
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
n, err := conn.sock.asyncIoDeadline(c, &conn.sock.readDeadline, bytes, err)
if err != nil {
if eno := windows.Errno(0); errors.As(err, &eno) {
err = os.NewSyscallError("wsarecv", eno)
Expand Down Expand Up @@ -479,7 +474,7 @@ func (conn *HvsockConn) write(b []byte) (int, error) {
buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))}
var bytes uint32
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
n, err := conn.sock.asyncIoDeadline(c, &conn.sock.writeDeadline, bytes, err)
if err != nil {
if eno := windows.Errno(0); errors.As(err, &eno) {
err = os.NewSyscallError("wsasend", eno)
Expand Down
2 changes: 1 addition & 1 deletion pipe.go
Expand Up @@ -476,7 +476,7 @@ func connectPipe(p *win32File) error {
defer p.wg.Done()

err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
_, err = p.asyncIoDeadline(c, nil, 0, err)
if err != nil && err != cERROR_PIPE_CONNECTED {
return err
}
Expand Down

0 comments on commit 0997ca6

Please sign in to comment.