From 0997ca6eb4f1cf2488c2b6014628b87114d9b301 Mon Sep 17 00:00:00 2001 From: Hamza El-Saawy Date: Tue, 5 Jul 2022 18:51:21 -0400 Subject: [PATCH] PR: Context for asychIO in Dial Signed-off-by: Hamza El-Saawy --- file.go | 38 ++++++++++++++++++++++++-------------- hvsock.go | 13 ++++--------- pipe.go | 2 +- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/file.go b/file.go index 1b870350..02f4902b 100644 --- a/file.go +++ b/file.go @@ -4,6 +4,7 @@ package winio import ( + "context" "errors" "io" "runtime" @@ -178,11 +179,26 @@ func ioCompletionProcessor(h syscall.Handle) { } } -// todo: helsaawy - create an asyncIO version that takes a context +// asyncIoContext is similar to asyncIoDeadline, but takes a context.Context instead of a deadlineHandler +func (f *win32File) asyncIoContext(ctx context.Context, c *ioOperation, bytes uint32, err error) (int, error) { + return f.asyncIo(c, ctx.Done(), bytes, err) +} + +// asyncIoDeadline processes the return value from ReadFile or WriteFile, blocking until +// the operation has actually completed, or until the deadlineHandler expires +func (f *win32File) asyncIoDeadline(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { + var ch timeoutChan + if d != nil { + d.channelLock.Lock() + ch = d.channel + d.channelLock.Unlock() + } + return f.asyncIo(c, ch, bytes, err) +} -// asyncIo processes the return value from ReadFile or WriteFile, blocking until -// the operation has actually completed. -func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { +// asyncIo is similar to asyncIoDeadline and asyncIOContext, but instead takes a +// <- chan struct{} parameter to cancel the operation +func (f *win32File) asyncIo(c *ioOperation, ch <-chan struct{}, bytes uint32, err error) (int, error) { if err != syscall.ERROR_IO_PENDING { return int(bytes), err } @@ -191,13 +207,6 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er cancelIoEx(f.handle, &c.o) } - var timeout timeoutChan - if d != nil { - d.channelLock.Lock() - timeout = d.channel - d.channelLock.Unlock() - } - var r ioResult select { case r = <-c.ch: @@ -211,7 +220,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er var bytes, flags uint32 err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) } - case <-timeout: + case <-ch: cancelIoEx(f.handle, &c.o) r = <-c.ch err = r.err @@ -223,6 +232,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er // runtime.KeepAlive is needed, as c is passed via native // code to ioCompletionProcessor, c must remain alive // until the channel read is complete. + // todo: should the *ioOperation be (de)allocated via win32 heap functions instead? runtime.KeepAlive(c) return int(r.bytes), err } @@ -241,7 +251,7 @@ func (f *win32File) Read(b []byte) (int, error) { var bytes uint32 err = syscall.ReadFile(f.handle, b, &bytes, &c.o) - n, err := f.asyncIo(c, &f.readDeadline, bytes, err) + n, err := f.asyncIoDeadline(c, &f.readDeadline, bytes, err) runtime.KeepAlive(b) // Handle EOF conditions. @@ -268,7 +278,7 @@ func (f *win32File) Write(b []byte) (int, error) { var bytes uint32 err = syscall.WriteFile(f.handle, b, &bytes, &c.o) - n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) + n, err := f.asyncIoDeadline(c, &f.writeDeadline, bytes, err) runtime.KeepAlive(b) return n, err } diff --git a/hvsock.go b/hvsock.go index 8872f471..b90c6e1c 100644 --- a/hvsock.go +++ b/hvsock.go @@ -243,7 +243,7 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) { var bytes uint32 err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /*rxdatalen*/, addrlen, addrlen, &bytes, &c.o) - if _, err = l.sock.asyncIo(c, nil, bytes, err); err != nil { + if _, err = l.sock.asyncIoDeadline(c, nil, bytes, err); err != nil { return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) } @@ -348,7 +348,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock 0, // sendDataLen &bytes, (*windows.Overlapped)(unsafe.Pointer(&c.o))) - _, err = sock.asyncIo(c, nil, bytes, err) + _, err = sock.asyncIoContext(ctx, c, bytes, err) if i < d.Retries && canRedial(err) { if err = d.redialWait(ctx); err == nil { continue @@ -379,11 +379,6 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock } conn.local.fromRaw(&sal) - // one last check for timeout, since asyncIO doesnt check the context - if err = ctx.Err(); err != nil { - return nil, conn.opErr(op, err) - } - conn.sock = sock sock = nil @@ -445,7 +440,7 @@ func (conn *HvsockConn) Read(b []byte) (int, error) { buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} var flags, bytes uint32 err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) - n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) + n, err := conn.sock.asyncIoDeadline(c, &conn.sock.readDeadline, bytes, err) if err != nil { if eno := windows.Errno(0); errors.As(err, &eno) { err = os.NewSyscallError("wsarecv", eno) @@ -479,7 +474,7 @@ func (conn *HvsockConn) write(b []byte) (int, error) { buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} var bytes uint32 err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) - n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) + n, err := conn.sock.asyncIoDeadline(c, &conn.sock.writeDeadline, bytes, err) if err != nil { if eno := windows.Errno(0); errors.As(err, &eno) { err = os.NewSyscallError("wsasend", eno) diff --git a/pipe.go b/pipe.go index 1acb2014..5cefabb5 100644 --- a/pipe.go +++ b/pipe.go @@ -476,7 +476,7 @@ func connectPipe(p *win32File) error { defer p.wg.Done() err = connectNamedPipe(p.handle, &c.o) - _, err = p.asyncIo(c, nil, 0, err) + _, err = p.asyncIoDeadline(c, nil, 0, err) if err != nil && err != cERROR_PIPE_CONNECTED { return err }