diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e80cf536..2797ea89 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,14 +27,14 @@ jobs: --modules-download-mode=readonly go-generate: - name: Validate Go Generate + name: Go Generate runs-on: windows-2019 steps: - uses: actions/checkout@v3 - uses: actions/setup-go@v3 with: go-version: ${{ env.GO_VERSION }} - - name: Verify generated files + - name: Run go generate shell: pwsh run: | Write-Output "::group::go generate" @@ -44,7 +44,9 @@ jobs: Write-Output "::error title=Go Generate::Error running go generate." exit $LASTEXITCODE } - + - name: Diff + shell: pwsh + run: | git add -N . Write-Output "::group::git diff" git diff --stat --exit-code diff --git a/backup.go b/backup.go index 6b3f121f..53fbac7e 100644 --- a/backup.go +++ b/backup.go @@ -8,11 +8,12 @@ import ( "errors" "fmt" "io" - "io/ioutil" "os" "runtime" "syscall" "unicode/utf16" + + "golang.org/x/sys/windows" ) //sys backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead @@ -25,7 +26,7 @@ const ( BackupAlternateData BackupLink BackupPropertyData - BackupObjectId + BackupObjectId //revive:disable-line:var-naming ID, not Id BackupReparseData BackupSparseBlock BackupTxfsData @@ -35,14 +36,16 @@ const ( StreamSparseAttributes = uint32(8) ) +//nolint:revive // var-naming: ALL_CAPS const ( - WRITE_DAC = 0x40000 - WRITE_OWNER = 0x80000 - ACCESS_SYSTEM_SECURITY = 0x1000000 + WRITE_DAC = windows.WRITE_DAC + WRITE_OWNER = windows.WRITE_DAC + ACCESS_SYSTEM_SECURITY = windows.ACCESS_SYSTEM_SECURITY ) // BackupHeader represents a backup stream of a file. type BackupHeader struct { + //revive:disable-next-line:var-naming ID, not Id Id uint32 // The backup stream ID Attributes uint32 // Stream attributes Size int64 // The size of the stream in bytes @@ -50,8 +53,8 @@ type BackupHeader struct { Offset int64 // The offset of the stream in the file (for BackupSparseBlock only). } -type win32StreamId struct { - StreamId uint32 +type win32StreamID struct { + StreamID uint32 Attributes uint32 Size uint64 NameSize uint32 @@ -72,7 +75,7 @@ func NewBackupStreamReader(r io.Reader) *BackupStreamReader { // Next returns the next backup stream and prepares for calls to Read(). It skips the remainder of the current stream if // it was not completely read. func (r *BackupStreamReader) Next() (*BackupHeader, error) { - if r.bytesLeft > 0 { + if r.bytesLeft > 0 { //nolint:nestif // todo: flatten this if s, ok := r.r.(io.Seeker); ok { // Make sure Seek on io.SeekCurrent sometimes succeeds // before trying the actual seek. @@ -83,16 +86,16 @@ func (r *BackupStreamReader) Next() (*BackupHeader, error) { r.bytesLeft = 0 } } - if _, err := io.Copy(ioutil.Discard, r); err != nil { + if _, err := io.Copy(io.Discard, r); err != nil { return nil, err } } - var wsi win32StreamId + var wsi win32StreamID if err := binary.Read(r.r, binary.LittleEndian, &wsi); err != nil { return nil, err } hdr := &BackupHeader{ - Id: wsi.StreamId, + Id: wsi.StreamID, Attributes: wsi.Attributes, Size: int64(wsi.Size), } @@ -103,7 +106,7 @@ func (r *BackupStreamReader) Next() (*BackupHeader, error) { } hdr.Name = syscall.UTF16ToString(name) } - if wsi.StreamId == BackupSparseBlock { + if wsi.StreamID == BackupSparseBlock { if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil { return nil, err } @@ -148,8 +151,8 @@ func (w *BackupStreamWriter) WriteHeader(hdr *BackupHeader) error { return fmt.Errorf("missing %d bytes", w.bytesLeft) } name := utf16.Encode([]rune(hdr.Name)) - wsi := win32StreamId{ - StreamId: hdr.Id, + wsi := win32StreamID{ + StreamID: hdr.Id, Attributes: hdr.Attributes, Size: uint64(hdr.Size), NameSize: uint32(len(name) * 2), @@ -204,7 +207,7 @@ func (r *BackupFileReader) Read(b []byte) (int, error) { var bytesRead uint32 err := backupRead(syscall.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx) if err != nil { - return 0, &os.PathError{"BackupRead", r.f.Name(), err} + return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err} } runtime.KeepAlive(r.f) if bytesRead == 0 { @@ -217,7 +220,7 @@ func (r *BackupFileReader) Read(b []byte) (int, error) { // the underlying file. func (r *BackupFileReader) Close() error { if r.ctx != 0 { - backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx) + _ = backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx) runtime.KeepAlive(r.f) r.ctx = 0 } @@ -243,7 +246,7 @@ func (w *BackupFileWriter) Write(b []byte) (int, error) { var bytesWritten uint32 err := backupWrite(syscall.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx) if err != nil { - return 0, &os.PathError{"BackupWrite", w.f.Name(), err} + return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err} } runtime.KeepAlive(w.f) if int(bytesWritten) != len(b) { @@ -256,7 +259,7 @@ func (w *BackupFileWriter) Write(b []byte) (int, error) { // close the underlying file. func (w *BackupFileWriter) Close() error { if w.ctx != 0 { - backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx) + _ = backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx) runtime.KeepAlive(w.f) w.ctx = 0 } @@ -272,7 +275,13 @@ func OpenForBackup(path string, access uint32, share uint32, createmode uint32) if err != nil { return nil, err } - h, err := syscall.CreateFile(&winPath[0], access, share, nil, createmode, syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT, 0) + h, err := syscall.CreateFile(&winPath[0], + access, + share, + nil, + createmode, + syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT, + 0) if err != nil { err = &os.PathError{Op: "open", Path: path, Err: err} return nil, err diff --git a/backup_test.go b/backup_test.go index eaa8a2e9..afd26c02 100644 --- a/backup_test.go +++ b/backup_test.go @@ -5,7 +5,6 @@ package winio import ( "io" - "io/ioutil" "os" "syscall" "testing" @@ -14,7 +13,7 @@ import ( var testFileName string func TestMain(m *testing.M) { - f, err := ioutil.TempFile("", "tmp") + f, err := os.CreateTemp("", "tmp") if err != nil { panic(err) } @@ -62,7 +61,7 @@ func TestBackupRead(t *testing.T) { defer f.Close() r := NewBackupFileReader(f, false) defer r.Close() - b, err := ioutil.ReadAll(r) + b, err := io.ReadAll(r) if err != nil { t.Fatal(err) } @@ -90,7 +89,7 @@ func TestBackupStreamRead(t *testing.T) { gotAltData := false for { hdr, err := br.Next() - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { @@ -105,7 +104,7 @@ func TestBackupStreamRead(t *testing.T) { if hdr.Name != "" { t.Fatalf("unexpected name %s", hdr.Name) } - b, err := ioutil.ReadAll(br) + b, err := io.ReadAll(br) if err != nil { t.Fatal(err) } @@ -120,7 +119,7 @@ func TestBackupStreamRead(t *testing.T) { if hdr.Name != ":ads.txt:$DATA" { t.Fatalf("incorrect name %s", hdr.Name) } - b, err := ioutil.ReadAll(br) + b, err := io.ReadAll(br) if err != nil { t.Fatal(err) } @@ -176,7 +175,7 @@ func TestBackupStreamWrite(t *testing.T) { f.Close() - b, err := ioutil.ReadFile(testFileName) + b, err := os.ReadFile(testFileName) if err != nil { t.Fatal(err) } @@ -184,7 +183,7 @@ func TestBackupStreamWrite(t *testing.T) { t.Fatalf("wrong data %v", b) } - b, err = ioutil.ReadFile(testFileName + ":ads.txt") + b, err = os.ReadFile(testFileName + ":ads.txt") if err != nil { t.Fatal(err) } @@ -202,11 +201,11 @@ func makeSparseFile() error { defer f.Close() const ( - FSCTL_SET_SPARSE = 0x000900c4 - FSCTL_SET_ZERO_DATA = 0x000980c8 + fsctlSetSparse = 0x000900c4 + fsctlSetZeroData = 0x000980c8 ) - err = syscall.DeviceIoControl(syscall.Handle(f.Fd()), FSCTL_SET_SPARSE, nil, 0, nil, 0, nil, nil) + err = syscall.DeviceIoControl(syscall.Handle(f.Fd()), fsctlSetSparse, nil, 0, nil, 0, nil, nil) if err != nil { return err } @@ -246,7 +245,7 @@ func TestBackupSparseFile(t *testing.T) { br := NewBackupStreamReader(r) for { hdr, err := br.Next() - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { diff --git a/backuptar/noop.go b/backuptar/doc.go similarity index 81% rename from backuptar/noop.go rename to backuptar/doc.go index d39eccf0..965d52ab 100644 --- a/backuptar/noop.go +++ b/backuptar/doc.go @@ -1,4 +1,3 @@ -// +build !windows // This file only exists to allow go get on non-Windows platforms. package backuptar diff --git a/backuptar/strconv.go b/backuptar/strconv.go index 34160966..455fd798 100644 --- a/backuptar/strconv.go +++ b/backuptar/strconv.go @@ -1,3 +1,5 @@ +//go:build windows + package backuptar import ( diff --git a/backuptar/tar.go b/backuptar/tar.go index 038caff3..6b3b0cd5 100644 --- a/backuptar/tar.go +++ b/backuptar/tar.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package backuptar @@ -7,7 +8,6 @@ import ( "encoding/base64" "fmt" "io" - "io/ioutil" "path/filepath" "strconv" "strings" @@ -18,17 +18,18 @@ import ( "golang.org/x/sys/windows" ) +//nolint:deadcode,varcheck // keep unused constants for potential future use const ( - c_ISUID = 04000 // Set uid - c_ISGID = 02000 // Set gid - c_ISVTX = 01000 // Save text (sticky bit) - c_ISDIR = 040000 // Directory - c_ISFIFO = 010000 // FIFO - c_ISREG = 0100000 // Regular file - c_ISLNK = 0120000 // Symbolic link - c_ISBLK = 060000 // Block special file - c_ISCHR = 020000 // Character special file - c_ISSOCK = 0140000 // Socket + cISUID = 0004000 // Set uid + cISGID = 0002000 // Set gid + cISVTX = 0001000 // Save text (sticky bit) + cISDIR = 0040000 // Directory + cISFIFO = 0010000 // FIFO + cISREG = 0100000 // Regular file + cISLNK = 0120000 // Symbolic link + cISBLK = 0060000 // Block special file + cISCHR = 0020000 // Character special file + cISSOCK = 0140000 // Socket ) const ( @@ -44,7 +45,7 @@ const ( // zeroReader is an io.Reader that always returns 0s. type zeroReader struct{} -func (zr zeroReader) Read(b []byte) (int, error) { +func (zeroReader) Read(b []byte) (int, error) { for i := range b { b[i] = 0 } @@ -55,7 +56,7 @@ func copySparse(t *tar.Writer, br *winio.BackupStreamReader) error { curOffset := int64(0) for { bhdr, err := br.Next() - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } if err != nil { @@ -71,8 +72,8 @@ func copySparse(t *tar.Writer, br *winio.BackupStreamReader) error { } // archive/tar does not support writing sparse files // so just write zeroes to catch up to the current offset. - if _, err := io.CopyN(t, zeroReader{}, bhdr.Offset-curOffset); err != nil { - return fmt.Errorf("seek to offset %d: %s", bhdr.Offset, err) + if _, err = io.CopyN(t, zeroReader{}, bhdr.Offset-curOffset); err != nil { + return fmt.Errorf("seek to offset %d: %w", bhdr.Offset, err) } if bhdr.Size == 0 { // A sparse block with size = 0 is used to mark the end of the sparse blocks. @@ -106,7 +107,7 @@ func BasicInfoHeader(name string, size int64, fileInfo *winio.FileBasicInfo) *ta hdr.PAXRecords[hdrCreationTime] = formatPAXTime(time.Unix(0, fileInfo.CreationTime.Nanoseconds())) if (fileInfo.FileAttributes & syscall.FILE_ATTRIBUTE_DIRECTORY) != 0 { - hdr.Mode |= c_ISDIR + hdr.Mode |= cISDIR hdr.Size = 0 hdr.Typeflag = tar.TypeDir } @@ -138,9 +139,7 @@ func SecurityDescriptorFromTarHeader(hdr *tar.Header) ([]byte, error) { // ExtendedAttributesFromTarHeader reads the EAs associated with the header of the // current file from the tar header and returns it as a byte slice. func ExtendedAttributesFromTarHeader(hdr *tar.Header) ([]byte, error) { - var eas []winio.ExtendedAttribute - var eadata []byte - var err error + var eas []winio.ExtendedAttribute //nolint:prealloc // len(eas) <= len(hdr.PAXRecords); prealloc is wasteful for k, v := range hdr.PAXRecords { if !strings.HasPrefix(k, hdrEaPrefix) { continue @@ -154,13 +153,15 @@ func ExtendedAttributesFromTarHeader(hdr *tar.Header) ([]byte, error) { Value: data, }) } + var eaData []byte + var err error if len(eas) != 0 { - eadata, err = winio.EncodeExtendedAttributes(eas) + eaData, err = winio.EncodeExtendedAttributes(eas) if err != nil { return nil, err } } - return eadata, nil + return eaData, nil } // EncodeReparsePointFromTarHeader reads the ReparsePoint structure from the tar header @@ -181,11 +182,9 @@ func EncodeReparsePointFromTarHeader(hdr *tar.Header) []byte { // // The additional Win32 metadata is: // -// MSWINDOWS.fileattr: The Win32 file attributes, as a decimal value -// -// MSWINDOWS.rawsd: The Win32 security descriptor, in raw binary format -// -// MSWINDOWS.mountpoint: If present, this is a mount point and not a symlink, even though the type is '2' (symlink) +// - MSWINDOWS.fileattr: The Win32 file attributes, as a decimal value +// - MSWINDOWS.rawsd: The Win32 security descriptor, in raw binary format +// - MSWINDOWS.mountpoint: If present, this is a mount point and not a symlink, even though the type is '2' (symlink) func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size int64, fileInfo *winio.FileBasicInfo) error { name = filepath.ToSlash(name) hdr := BasicInfoHeader(name, size, fileInfo) @@ -208,7 +207,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size var dataHdr *winio.BackupHeader for dataHdr == nil { bhdr, err := br.Next() - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { @@ -216,21 +215,21 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size } switch bhdr.Id { case winio.BackupData: - hdr.Mode |= c_ISREG + hdr.Mode |= cISREG if !readTwice { dataHdr = bhdr } case winio.BackupSecurity: - sd, err := ioutil.ReadAll(br) + sd, err := io.ReadAll(br) if err != nil { return err } hdr.PAXRecords[hdrRawSecurityDescriptor] = base64.StdEncoding.EncodeToString(sd) case winio.BackupReparseData: - hdr.Mode |= c_ISLNK + hdr.Mode |= cISLNK hdr.Typeflag = tar.TypeSymlink - reparseBuffer, err := ioutil.ReadAll(br) + reparseBuffer, _ := io.ReadAll(br) rp, err := winio.DecodeReparsePoint(reparseBuffer) if err != nil { return err @@ -241,7 +240,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size hdr.Linkname = rp.Target case winio.BackupEaData: - eab, err := ioutil.ReadAll(br) + eab, err := io.ReadAll(br) if err != nil { return err } @@ -275,7 +274,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size } for dataHdr == nil { bhdr, err := br.Next() - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { @@ -310,7 +309,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size // range of the file containing the range contents. Finally there is a sparse block stream with // size = 0 and offset = . - if dataHdr != nil { + if dataHdr != nil { //nolint:nestif // todo: reduce nesting complexity // A data stream was found. Copy the data. // We assume that we will either have a data stream size > 0 XOR have sparse block streams. if dataHdr.Size > 0 || (dataHdr.Attributes&winio.StreamSparseAttributes) == 0 { @@ -318,13 +317,13 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size return fmt.Errorf("%s: mismatch between file size %d and header size %d", name, size, dataHdr.Size) } if _, err = io.Copy(t, br); err != nil { - return fmt.Errorf("%s: copying contents from data stream: %s", name, err) + return fmt.Errorf("%s: copying contents from data stream: %w", name, err) } } else if size > 0 { // As of a recent OS change, BackupRead now returns a data stream for empty sparse files. // These files have no sparse block streams, so skip the copySparse call if file size = 0. if err = copySparse(t, br); err != nil { - return fmt.Errorf("%s: copying contents from sparse block stream: %s", name, err) + return fmt.Errorf("%s: copying contents from sparse block stream: %w", name, err) } } } @@ -334,7 +333,7 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size // been written. In practice, this means that we don't get EA or TXF metadata. for { bhdr, err := br.Next() - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { @@ -342,35 +341,30 @@ func WriteTarFileFromBackupStream(t *tar.Writer, r io.Reader, name string, size } switch bhdr.Id { case winio.BackupAlternateData: - altName := bhdr.Name - if strings.HasSuffix(altName, ":$DATA") { - altName = altName[:len(altName)-len(":$DATA")] - } - if (bhdr.Attributes & winio.StreamSparseAttributes) == 0 { - hdr = &tar.Header{ - Format: hdr.Format, - Name: name + altName, - Mode: hdr.Mode, - Typeflag: tar.TypeReg, - Size: bhdr.Size, - ModTime: hdr.ModTime, - AccessTime: hdr.AccessTime, - ChangeTime: hdr.ChangeTime, - } - err = t.WriteHeader(hdr) - if err != nil { - return err - } - _, err = io.Copy(t, br) - if err != nil { - return err - } - - } else { + if (bhdr.Attributes & winio.StreamSparseAttributes) != 0 { // Unsupported for now, since the size of the alternate stream is not present // in the backup stream until after the data has been read. return fmt.Errorf("%s: tar of sparse alternate data streams is unsupported", name) } + altName := strings.TrimSuffix(bhdr.Name, ":$DATA") + hdr = &tar.Header{ + Format: hdr.Format, + Name: name + altName, + Mode: hdr.Mode, + Typeflag: tar.TypeReg, + Size: bhdr.Size, + ModTime: hdr.ModTime, + AccessTime: hdr.AccessTime, + ChangeTime: hdr.ChangeTime, + } + err = t.WriteHeader(hdr) + if err != nil { + return err + } + _, err = io.Copy(t, br) + if err != nil { + return err + } case winio.BackupEaData, winio.BackupLink, winio.BackupPropertyData, winio.BackupObjectId, winio.BackupTxfsData: // ignore these streams default: @@ -412,7 +406,7 @@ func FileInfoFromHeader(hdr *tar.Header) (name string, size int64, fileInfo *win } fileInfo.CreationTime = windows.NsecToFiletime(creationTime.UnixNano()) } - return + return name, size, fileInfo, err } // WriteBackupStreamFromTarFile writes a Win32 backup stream from the current tar file. Since this function may process multiple @@ -473,7 +467,6 @@ func WriteBackupStreamFromTarFile(w io.Writer, t *tar.Reader, hdr *tar.Header) ( if err != nil { return nil, err } - } if hdr.Typeflag == tar.TypeReg || hdr.Typeflag == tar.TypeRegA { diff --git a/backuptar/tar_test.go b/backuptar/tar_test.go index cc2cbf09..8984c596 100644 --- a/backuptar/tar_test.go +++ b/backuptar/tar_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package backuptar @@ -6,7 +7,6 @@ import ( "archive/tar" "bytes" "io" - "io/ioutil" "os" "path/filepath" "reflect" @@ -25,8 +25,7 @@ func ensurePresent(t *testing.T, m map[string]string, keys ...string) { } func setSparse(t *testing.T, f *os.File) { - const FSCTL_SET_SPARSE uint32 = 0x900c4 - if err := windows.DeviceIoControl(windows.Handle(f.Fd()), FSCTL_SET_SPARSE, nil, 0, nil, 0, nil, nil); err != nil { + if err := windows.DeviceIoControl(windows.Handle(f.Fd()), windows.FSCTL_SET_SPARSE, nil, 0, nil, 0, nil, nil); err != nil { t.Fatal(err) } } @@ -68,10 +67,12 @@ func compareReaders(t *testing.T, rActual io.Reader, rExpected io.Reader) { func TestRoundTrip(t *testing.T) { // Each test case is a name mapped to a function which must create a file and return its path. // The test then round-trips that file through backuptar, and validates the output matches the input. + // + //nolint:gosec // G306: Expect WriteFile permissions to be 0600 or less for name, setup := range map[string]func(*testing.T) string{ "normalFile": func(t *testing.T) string { path := filepath.Join(t.TempDir(), "foo.txt") - if err := ioutil.WriteFile(path, []byte("testing 1 2 3\n"), 0644); err != nil { + if err := os.WriteFile(path, []byte("testing 1 2 3\n"), 0644); err != nil { t.Fatal(err) } return path diff --git a/doc.go b/doc.go new file mode 100644 index 00000000..1f5bfe2d --- /dev/null +++ b/doc.go @@ -0,0 +1,22 @@ +// This package provides utilities for efficiently performing Win32 IO operations in Go. +// Currently, this package is provides support for genreal IO and management of +// - named pipes +// - files +// - [Hyper-V sockets] +// +// This code is similar to Go's [net] package, and uses IO completion ports to avoid +// blocking IO on system threads, allowing Go to reuse the thread to schedule other goroutines. +// +// This limits support to Windows Vista and newer operating systems. +// +// Additionally, this package provides support for: +// - creating and managing GUIDs +// - writing to [ETW] +// - opening and manageing VHDs +// - parsing [Windows Image files] +// - auto-generating Win32 API code +// +// [Hyper-V sockets]: https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service +// [ETW]: https://docs.microsoft.com/en-us/windows-hardware/drivers/devtest/event-tracing-for-windows--etw- +// [Windows Image files]: https://docs.microsoft.com/en-us/windows-hardware/manufacture/desktop/work-with-windows-images +package winio diff --git a/ea.go b/ea.go index 4051c1b3..e104dbdf 100644 --- a/ea.go +++ b/ea.go @@ -33,7 +33,7 @@ func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) { err = binary.Read(bytes.NewReader(b), binary.LittleEndian, &info) if err != nil { err = errInvalidEaBuffer - return + return ea, nb, err } nameOffset := fileFullEaInformationSize @@ -43,7 +43,7 @@ func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) { nextOffset := int(info.NextEntryOffset) if valueLen+valueOffset > len(b) || nextOffset < 0 || nextOffset > len(b) { err = errInvalidEaBuffer - return + return ea, nb, err } ea.Name = string(b[nameOffset : nameOffset+nameLen]) @@ -52,7 +52,7 @@ func parseEa(b []byte) (ea ExtendedAttribute, nb []byte, err error) { if info.NextEntryOffset != 0 { nb = b[info.NextEntryOffset:] } - return + return ea, nb, err } // DecodeExtendedAttributes decodes a list of EAs from a FILE_FULL_EA_INFORMATION @@ -67,7 +67,7 @@ func DecodeExtendedAttributes(b []byte) (eas []ExtendedAttribute, err error) { eas = append(eas, ea) b = nb } - return + return eas, err } func writeEa(buf *bytes.Buffer, ea *ExtendedAttribute, last bool) error { diff --git a/ea_test.go b/ea_test.go index 3e32076f..5591edb0 100644 --- a/ea_test.go +++ b/ea_test.go @@ -4,7 +4,6 @@ package winio import ( - "io/ioutil" "os" "reflect" "syscall" @@ -18,7 +17,8 @@ var ( {Name: "fizz", Value: []byte("buzz")}, } - testEasEncoded = []byte{16, 0, 0, 0, 0, 3, 3, 0, 102, 111, 111, 0, 98, 97, 114, 0, 0, 0, 0, 0, 0, 4, 4, 0, 102, 105, 122, 122, 0, 98, 117, 122, 122, 0, 0, 0} + testEasEncoded = []byte{16, 0, 0, 0, 0, 3, 3, 0, 102, 111, 111, 0, 98, 97, 114, 0, 0, + 0, 0, 0, 0, 4, 4, 0, 102, 105, 122, 122, 0, 98, 117, 122, 122, 0, 0, 0} testEasNotPadded = testEasEncoded[0 : len(testEasEncoded)-3] testEasTruncated = testEasEncoded[0:20] ) @@ -76,7 +76,7 @@ func Test_NilEasEncodeAndDecodeAsNil(t *testing.T) { // Test_SetFileEa makes sure that the test buffer is actually parsable by NtSetEaFile. func Test_SetFileEa(t *testing.T) { - f, err := ioutil.TempFile("", "winio") + f, err := os.CreateTemp("", "winio") if err != nil { t.Fatal(err) } @@ -85,7 +85,10 @@ func Test_SetFileEa(t *testing.T) { ntdll := syscall.MustLoadDLL("ntdll.dll") ntSetEaFile := ntdll.MustFindProc("NtSetEaFile") var iosb [2]uintptr - r, _, _ := ntSetEaFile.Call(f.Fd(), uintptr(unsafe.Pointer(&iosb[0])), uintptr(unsafe.Pointer(&testEasEncoded[0])), uintptr(len(testEasEncoded))) + r, _, _ := ntSetEaFile.Call(f.Fd(), + uintptr(unsafe.Pointer(&iosb[0])), + uintptr(unsafe.Pointer(&testEasEncoded[0])), + uintptr(len(testEasEncoded))) if r != 0 { t.Fatalf("NtSetEaFile failed with %08x", r) } diff --git a/file.go b/file.go index a88a1269..175a99d3 100644 --- a/file.go +++ b/file.go @@ -11,6 +11,8 @@ import ( "sync/atomic" "syscall" "time" + + "golang.org/x/sys/windows" ) //sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx @@ -24,6 +26,8 @@ type atomicBool int32 func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } + +//revive:disable-next-line:predeclared Keep "new" to maintain consistency with "atomic" pkg func (b *atomicBool) swap(new bool) bool { var newInt int32 if new { @@ -32,11 +36,6 @@ func (b *atomicBool) swap(new bool) bool { return atomic.SwapInt32((*int32)(b), newInt) == 1 } -const ( - cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1 - cFILE_SKIP_SET_EVENT_ON_HANDLE = 2 -) - var ( ErrFileClosed = errors.New("file has already been closed") ErrTimeout = &timeoutError{} @@ -53,19 +52,19 @@ type timeoutChan chan struct{} var ioInitOnce sync.Once var ioCompletionPort syscall.Handle -// ioResult contains the result of an asynchronous IO operation +// ioResult contains the result of an asynchronous IO operation. type ioResult struct { bytes uint32 err error } -// ioOperation represents an outstanding asynchronous Win32 IO +// ioOperation represents an outstanding asynchronous Win32 IO. type ioOperation struct { o syscall.Overlapped ch chan ioResult } -func initIo() { +func initIO() { h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff) if err != nil { panic(err) @@ -94,15 +93,15 @@ type deadlineHandler struct { timedout atomicBool } -// makeWin32File makes a new win32File from an existing file handle +// makeWin32File makes a new win32File from an existing file handle. func makeWin32File(h syscall.Handle) (*win32File, error) { f := &win32File{handle: h} - ioInitOnce.Do(initIo) + ioInitOnce.Do(initIO) _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) if err != nil { return nil, err } - err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE) + err = setFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) if err != nil { return nil, err } @@ -121,14 +120,14 @@ func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) { return f, nil } -// closeHandle closes the resources associated with a Win32 handle +// closeHandle closes the resources associated with a Win32 handle. func (f *win32File) closeHandle() { f.wgLock.Lock() // Atomically set that we are closing, releasing the resources only once. if !f.closing.swap(true) { f.wgLock.Unlock() // cancel all IO and wait for it to complete - cancelIoEx(f.handle, nil) + _ = cancelIoEx(f.handle, nil) f.wg.Wait() // at this point, no new IO can start syscall.Close(f.handle) @@ -144,14 +143,14 @@ func (f *win32File) Close() error { return nil } -// IsClosed checks if the file has been closed +// IsClosed checks if the file has been closed. func (f *win32File) IsClosed() bool { return f.closing.isSet() } -// prepareIo prepares for a new IO operation. +// prepareIO prepares for a new IO operation. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. -func (f *win32File) prepareIo() (*ioOperation, error) { +func (f *win32File) prepareIO() (*ioOperation, error) { f.wgLock.RLock() if f.closing.isSet() { f.wgLock.RUnlock() @@ -164,7 +163,7 @@ func (f *win32File) prepareIo() (*ioOperation, error) { return c, nil } -// ioCompletionProcessor processes completed async IOs forever +// ioCompletionProcessor processes completed async IOs forever. func ioCompletionProcessor(h syscall.Handle) { for { var bytes uint32 @@ -180,15 +179,15 @@ func ioCompletionProcessor(h syscall.Handle) { // todo: helsaawy - create an asyncIO version that takes a context -// asyncIo processes the return value from ReadFile or WriteFile, blocking until +// asyncIO processes the return value from ReadFile or WriteFile, blocking until // the operation has actually completed. -func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { - if err != syscall.ERROR_IO_PENDING { +func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { + if err != syscall.ERROR_IO_PENDING { //nolint:errorlint // err is Errno return int(bytes), err } if f.closing.isSet() { - cancelIoEx(f.handle, &c.o) + _ = cancelIoEx(f.handle, &c.o) } var timeout timeoutChan @@ -202,7 +201,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er select { case r = <-c.ch: err = r.err - if err == syscall.ERROR_OPERATION_ABORTED { + if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno if f.closing.isSet() { err = ErrFileClosed } @@ -212,10 +211,10 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) } case <-timeout: - cancelIoEx(f.handle, &c.o) + _ = cancelIoEx(f.handle, &c.o) r = <-c.ch err = r.err - if err == syscall.ERROR_OPERATION_ABORTED { + if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno err = ErrTimeout } } @@ -230,7 +229,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er // Read reads from a file handle. func (f *win32File) Read(b []byte) (int, error) { - c, err := f.prepareIo() + c, err := f.prepareIO() if err != nil { return 0, err } @@ -242,13 +241,13 @@ func (f *win32File) Read(b []byte) (int, error) { var bytes uint32 err = syscall.ReadFile(f.handle, b, &bytes, &c.o) - n, err := f.asyncIo(c, &f.readDeadline, bytes, err) + n, err := f.asyncIO(c, &f.readDeadline, bytes, err) runtime.KeepAlive(b) // Handle EOF conditions. if err == nil && n == 0 && len(b) != 0 { return 0, io.EOF - } else if err == syscall.ERROR_BROKEN_PIPE { + } else if err == syscall.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno return 0, io.EOF } else { return n, err @@ -257,7 +256,7 @@ func (f *win32File) Read(b []byte) (int, error) { // Write writes to a file handle. func (f *win32File) Write(b []byte) (int, error) { - c, err := f.prepareIo() + c, err := f.prepareIO() if err != nil { return 0, err } @@ -269,7 +268,7 @@ func (f *win32File) Write(b []byte) (int, error) { var bytes uint32 err = syscall.WriteFile(f.handle, b, &bytes, &c.o) - n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) + n, err := f.asyncIO(c, &f.writeDeadline, bytes, err) runtime.KeepAlive(b) return n, err } diff --git a/fileinfo.go b/fileinfo.go index 350c1096..702950e7 100644 --- a/fileinfo.go +++ b/fileinfo.go @@ -15,13 +15,18 @@ import ( type FileBasicInfo struct { CreationTime, LastAccessTime, LastWriteTime, ChangeTime windows.Filetime FileAttributes uint32 - pad uint32 // padding + _ uint32 // padding } // GetFileBasicInfo retrieves times and attributes for a file. func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) { bi := &FileBasicInfo{} - if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil { + if err := windows.GetFileInformationByHandleEx( + windows.Handle(f.Fd()), + windows.FileBasicInfo, + (*byte)(unsafe.Pointer(bi)), + uint32(unsafe.Sizeof(*bi)), + ); err != nil { return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} } runtime.KeepAlive(f) @@ -30,7 +35,12 @@ func GetFileBasicInfo(f *os.File) (*FileBasicInfo, error) { // SetFileBasicInfo sets times and attributes for a file. func SetFileBasicInfo(f *os.File, bi *FileBasicInfo) error { - if err := windows.SetFileInformationByHandle(windows.Handle(f.Fd()), windows.FileBasicInfo, (*byte)(unsafe.Pointer(bi)), uint32(unsafe.Sizeof(*bi))); err != nil { + if err := windows.SetFileInformationByHandle( + windows.Handle(f.Fd()), + windows.FileBasicInfo, + (*byte)(unsafe.Pointer(bi)), + uint32(unsafe.Sizeof(*bi)), + ); err != nil { return &os.PathError{Op: "SetFileInformationByHandle", Path: f.Name(), Err: err} } runtime.KeepAlive(f) @@ -49,7 +59,10 @@ type FileStandardInfo struct { // GetFileStandardInfo retrieves ended information for the file. func GetFileStandardInfo(f *os.File) (*FileStandardInfo, error) { si := &FileStandardInfo{} - if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileStandardInfo, (*byte)(unsafe.Pointer(si)), uint32(unsafe.Sizeof(*si))); err != nil { + if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), + windows.FileStandardInfo, + (*byte)(unsafe.Pointer(si)), + uint32(unsafe.Sizeof(*si))); err != nil { return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} } runtime.KeepAlive(f) @@ -66,7 +79,12 @@ type FileIDInfo struct { // GetFileID retrieves the unique (volume, file ID) pair for a file. func GetFileID(f *os.File) (*FileIDInfo, error) { fileID := &FileIDInfo{} - if err := windows.GetFileInformationByHandleEx(windows.Handle(f.Fd()), windows.FileIdInfo, (*byte)(unsafe.Pointer(fileID)), uint32(unsafe.Sizeof(*fileID))); err != nil { + if err := windows.GetFileInformationByHandleEx( + windows.Handle(f.Fd()), + windows.FileIdInfo, + (*byte)(unsafe.Pointer(fileID)), + uint32(unsafe.Sizeof(*fileID)), + ); err != nil { return nil, &os.PathError{Op: "GetFileInformationByHandleEx", Path: f.Name(), Err: err} } runtime.KeepAlive(f) diff --git a/fileinfo_test.go b/fileinfo_test.go index 8c97f506..bdb87edd 100644 --- a/fileinfo_test.go +++ b/fileinfo_test.go @@ -4,7 +4,6 @@ package winio import ( - "io/ioutil" "os" "testing" @@ -45,7 +44,7 @@ func checkFileStandardInfo(t *testing.T, current, expected *FileStandardInfo) { } func TestGetFileStandardInfo_File(t *testing.T) { - f, err := ioutil.TempFile("", "tst") + f, err := os.CreateTemp("", "tst") if err != nil { t.Fatal(err) } @@ -107,12 +106,7 @@ func TestGetFileStandardInfo_File(t *testing.T) { } func TestGetFileStandardInfo_Directory(t *testing.T) { - tempDir, err := ioutil.TempDir("", "tst") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - + tempDir := t.TempDir() // os.Open returns the Search Handle, not the Directory Handle // See https://github.com/golang/go/issues/13738 f, err := OpenForBackup(tempDir, windows.GENERIC_READ, 0, windows.OPEN_EXISTING) diff --git a/hvsock.go b/hvsock.go index 8b56f96b..52f1c280 100644 --- a/hvsock.go +++ b/hvsock.go @@ -20,7 +20,7 @@ import ( "github.com/Microsoft/go-winio/pkg/guid" ) -const afHvSock = 34 // AF_HYPERV +const afHVSock = 34 // AF_HYPERV // Well known Service and VM IDs //https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards @@ -30,7 +30,7 @@ func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000 return guid.GUID{} } -// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions +// HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions. func HvsockGUIDBroadcast() guid.GUID { //ffffffff-ffff-ffff-ffff-ffffffffffff return guid.GUID{ Data1: 0xffffffff, @@ -51,8 +51,8 @@ func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838 } // HvsockGUIDSiloHost is the address of a silo's host partition: -// - The silo host of a hosted silo is the utility VM. -// - The silo host of a silo on a physical host is the physical host. +// - The silo host of a hosted silo is the utility VM. +// - The silo host of a silo on a physical host is the physical host. func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568 return guid.GUID{ Data1: 0x36bd0c5c, @@ -74,10 +74,10 @@ func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd // HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition. // Listening on this VmId accepts connection from: -// - Inside silos: silo host partition. -// - Inside hosted silo: host of the VM. -// - Inside VM: VM host. -// - Physical host: Not supported. +// - Inside silos: silo host partition. +// - Inside hosted silo: host of the VM. +// - Inside VM: VM host. +// - Physical host: Not supported. func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878 return guid.GUID{ Data1: 0xa42e7cda, @@ -87,7 +87,7 @@ func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878 } } -// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol +// hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol. func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3 return guid.GUID{ Data2: 0xfacb, @@ -112,7 +112,7 @@ type rawHvsockAddr struct { var _ socket.RawSockaddr = &rawHvsockAddr{} // Network returns the address's network name, "hvsock". -func (addr *HvsockAddr) Network() string { +func (*HvsockAddr) Network() string { return "hvsock" } @@ -129,7 +129,7 @@ func VsockServiceID(port uint32) guid.GUID { func (addr *HvsockAddr) raw() rawHvsockAddr { return rawHvsockAddr{ - Family: afHvSock, + Family: afHVSock, VMID: addr.VMID, ServiceID: addr.ServiceID, } @@ -143,12 +143,12 @@ func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) { // Sockaddr returns a pointer to and the size of this struct. // // Implements the [socket.RawSockaddr] interface, and allows use in -// [socket.Bind()] and [socket.ConnectEx()] +// [socket.Bind] and [socket.ConnectEx]. func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) { return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil } -// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()` +// Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`. func (r *rawHvsockAddr) FromBytes(b []byte) error { n := int(unsafe.Sizeof(rawHvsockAddr{})) @@ -157,8 +157,8 @@ func (r *rawHvsockAddr) FromBytes(b []byte) error { } copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n]) - if r.Family != afHvSock { - return fmt.Errorf("got %d, want %d: %w", r.Family, afHvSock, socket.ErrAddrFamily) + if r.Family != afHVSock { + return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily) } return nil @@ -180,8 +180,8 @@ type HvsockConn struct { var _ net.Conn = &HvsockConn{} -func newHvSocket() (*win32File, error) { - fd, err := syscall.Socket(afHvSock, syscall.SOCK_STREAM, 1) +func newHVSocket() (*win32File, error) { + fd, err := syscall.Socket(afHVSock, syscall.SOCK_STREAM, 1) if err != nil { return nil, os.NewSyscallError("socket", err) } @@ -197,7 +197,7 @@ func newHvSocket() (*win32File, error) { // ListenHvsock listens for connections on the specified hvsock address. func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) { l := &HvsockListener{addr: *addr} - sock, err := newHvSocket() + sock, err := newHVSocket() if err != nil { return nil, l.opErr("listen", err) } @@ -224,7 +224,7 @@ func (l *HvsockListener) Addr() net.Addr { // Accept waits for the next connection and returns it. func (l *HvsockListener) Accept() (_ net.Conn, err error) { - sock, err := newHvSocket() + sock, err := newHVSocket() if err != nil { return nil, l.opErr("accept", err) } @@ -233,20 +233,21 @@ func (l *HvsockListener) Accept() (_ net.Conn, err error) { sock.Close() } }() - c, err := l.sock.prepareIo() + c, err := l.sock.prepareIO() if err != nil { return nil, l.opErr("accept", err) } defer l.sock.wg.Done() - // AcceptEx, per documentation, requires an extra 16 bytes per address: + // AcceptEx, per documentation, requires an extra 16 bytes per address. + // // https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{})) var addrbuf [addrlen * 2]byte var bytes uint32 err = syscall.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /*rxdatalen*/, addrlen, addrlen, &bytes, &c.o) - if _, err = l.sock.asyncIo(c, nil, bytes, err); err != nil { + if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil { return nil, l.opErr("accept", os.NewSyscallError("acceptex", err)) } @@ -294,7 +295,7 @@ type HvsockDialer struct { // Dial the Hyper-V socket at addr. // -// See (*HvsockDialer).Dial for more information. +// See [HvsockDialer.Dial] for more information. func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { return (&HvsockDialer{}).Dial(ctx, addr) } @@ -302,6 +303,7 @@ func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { // Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful. // Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between // retries. +// // Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx. func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) { op := "dial" @@ -321,7 +323,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock return nil, conn.opErr(op, err) } - sock, err := newHvSocket() + sock, err := newHVSocket() if err != nil { return nil, conn.opErr(op, err) } @@ -337,7 +339,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock return nil, conn.opErr(op, os.NewSyscallError("bind", err)) } - c, err := sock.prepareIo() + c, err := sock.prepareIO() if err != nil { return nil, conn.opErr(op, err) } @@ -351,7 +353,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock 0, // sendDataLen &bytes, (*windows.Overlapped)(unsafe.Pointer(&c.o))) - _, err = sock.asyncIo(c, nil, bytes, err) + _, err = sock.asyncIO(c, nil, bytes, err) if i < d.Retries && canRedial(err) { if err = d.redialWait(ctx); err == nil { continue @@ -382,7 +384,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock } conn.local.fromRaw(&sal) - // one last check for timeout, since asyncIO doesnt check the context + // one last check for timeout, since asyncIO doesn't check the context if err = ctx.Err(); err != nil { return nil, conn.opErr(op, err) } @@ -393,7 +395,7 @@ func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *Hvsock return conn, nil } -// redialWait waits before attempting to redial, resetting the timer as appropriate +// redialWait waits before attempting to redial, resetting the timer as appropriate. func (d *HvsockDialer) redialWait(ctx context.Context) (err error) { if d.RetryWait == 0 { return nil @@ -419,9 +421,9 @@ func (d *HvsockDialer) redialWait(ctx context.Context) (err error) { return ctx.Err() } -// assumes error is a plain, unwrapped syscall.Errno provided by direct syscall +// assumes error is a plain, unwrapped syscall.Errno provided by direct syscall. func canRedial(err error) bool { - // nolint:errorlint + //nolint:errorlint // guaranteed to be an Errno switch err { case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT, windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL: @@ -440,7 +442,7 @@ func (conn *HvsockConn) opErr(op string, err error) error { } func (conn *HvsockConn) Read(b []byte) (int, error) { - c, err := conn.sock.prepareIo() + c, err := conn.sock.prepareIO() if err != nil { return 0, conn.opErr("read", err) } @@ -448,7 +450,7 @@ func (conn *HvsockConn) Read(b []byte) (int, error) { buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} var flags, bytes uint32 err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil) - n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err) + n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err) if err != nil { var eno windows.Errno if errors.As(err, &eno) { @@ -475,7 +477,7 @@ func (conn *HvsockConn) Write(b []byte) (int, error) { } func (conn *HvsockConn) write(b []byte) (int, error) { - c, err := conn.sock.prepareIo() + c, err := conn.sock.prepareIO() if err != nil { return 0, conn.opErr("write", err) } @@ -483,7 +485,7 @@ func (conn *HvsockConn) write(b []byte) (int, error) { buf := syscall.WSABuf{Buf: &b[0], Len: uint32(len(b))} var bytes uint32 err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil) - n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err) + n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err) if err != nil { var eno windows.Errno if errors.As(err, &eno) { @@ -503,7 +505,7 @@ func (conn *HvsockConn) IsClosed() bool { return conn.sock.IsClosed() } -// shutdown disables sending or receiving on a socket +// shutdown disables sending or receiving on a socket. func (conn *HvsockConn) shutdown(how int) error { if conn.IsClosed() { return socket.ErrSocketClosed diff --git a/internal/socket/socket.go b/internal/socket/socket.go index f4134df4..7f7993de 100644 --- a/internal/socket/socket.go +++ b/internal/socket/socket.go @@ -109,7 +109,7 @@ func (f *runtimeFunc) Load() error { var ( // todo: add `AcceptEx` and `GetAcceptExSockaddrs` - WSAID_CONNECTEX = guid.GUID{ //nolint:revive,stylecheck + WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS Data1: 0x25a207b9, Data2: 0xddf3, Data3: 0x4660, @@ -119,7 +119,14 @@ var ( connectExFunc = runtimeFunc{id: WSAID_CONNECTEX} ) -func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) error { +func ConnectEx( + fd windows.Handle, + rsa RawSockaddr, + sendBuf *byte, + sendDataLen uint32, + bytesSent *uint32, + overlapped *windows.Overlapped, +) error { if err := connectExFunc.Load(); err != nil { return fmt.Errorf("failed to load ConnectEx function pointer: %w", err) } @@ -139,9 +146,28 @@ func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen ui // [out] LPDWORD lpdwBytesSent, // [in] LPOVERLAPPED lpOverlapped // ) -func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) (err error) { + +func connectEx( + s windows.Handle, + name unsafe.Pointer, + namelen int32, + sendBuf *byte, + sendDataLen uint32, + bytesSent *uint32, + overlapped *windows.Overlapped, +) (err error) { // todo: after upgrading to 1.18, switch from syscall.Syscall9 to syscall.SyscallN - r1, _, e1 := syscall.Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0) + r1, _, e1 := syscall.Syscall9(connectExFunc.addr, + 7, + uintptr(s), + uintptr(name), + uintptr(namelen), + uintptr(unsafe.Pointer(sendBuf)), + uintptr(sendDataLen), + uintptr(unsafe.Pointer(bytesSent)), + uintptr(unsafe.Pointer(overlapped)), + 0, + 0) if r1 == 0 { if e1 != 0 { err = error(e1) @@ -149,5 +175,5 @@ func connectEx(s windows.Handle, name unsafe.Pointer, namelen int32, sendBuf *by err = syscall.EINVAL } } - return + return err } diff --git a/pipe.go b/pipe.go index 1acb2014..ca6e38fc 100644 --- a/pipe.go +++ b/pipe.go @@ -14,6 +14,8 @@ import ( "syscall" "time" "unsafe" + + "golang.org/x/sys/windows" ) //sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe @@ -22,10 +24,10 @@ import ( //sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo //sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc -//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile -//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb -//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U -//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl +//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) = ntdll.NtCreateNamedPipeFile +//sys rtlNtStatusToDosError(status ntStatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb +//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) = ntdll.RtlDosPathNameToNtPathName_U +//sys rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) = ntdll.RtlDefaultNpAcl type ioStatusBlock struct { Status, Information uintptr @@ -52,41 +54,19 @@ type securityDescriptor struct { Control uint16 Owner uintptr Group uintptr - Sacl uintptr - Dacl uintptr + Sacl uintptr //revive:disable-line:var-naming SACL, not Sacl + Dacl uintptr //revive:disable-line:var-naming DACL, not Dacl } -type ntstatus int32 +type ntStatus int32 -func (status ntstatus) Err() error { +func (status ntStatus) Err() error { if status >= 0 { return nil } return rtlNtStatusToDosError(status) } -const ( - cERROR_PIPE_BUSY = syscall.Errno(231) - cERROR_NO_DATA = syscall.Errno(232) - cERROR_PIPE_CONNECTED = syscall.Errno(535) - cERROR_SEM_TIMEOUT = syscall.Errno(121) - - cSECURITY_SQOS_PRESENT = 0x100000 - cSECURITY_ANONYMOUS = 0 - - cPIPE_TYPE_MESSAGE = 4 - - cPIPE_READMODE_MESSAGE = 2 - - cFILE_OPEN = 1 - cFILE_CREATE = 2 - - cFILE_PIPE_MESSAGE_TYPE = 1 - cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2 - - cSE_DACL_PRESENT = 4 -) - var ( // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed. ErrPipeListenerClosed = net.ErrClosed @@ -116,9 +96,10 @@ func (f *win32Pipe) RemoteAddr() net.Addr { } func (f *win32Pipe) SetDeadline(t time.Time) error { - f.SetReadDeadline(t) - f.SetWriteDeadline(t) - return nil + if err := f.SetReadDeadline(t); err != nil { + return err + } + return f.SetWriteDeadline(t) } // CloseWrite closes the write side of a message pipe in byte mode. @@ -157,14 +138,14 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) { return 0, io.EOF } n, err := f.win32File.Read(b) - if err == io.EOF { + if err == io.EOF { //nolint:errorlint // If this was the result of a zero-byte read, then // it is possible that the read was due to a zero-size // message. Since we are simulating CloseWrite with a // zero-byte message, ensure that all future Read() calls // also return EOF. f.readEOF = true - } else if err == syscall.ERROR_MORE_DATA { + } else if err == syscall.ERROR_MORE_DATA { //nolint:errorlint // err is Errno // ERROR_MORE_DATA indicates that the pipe's read mode is message mode // and the message still has more bytes. Treat this as a success, since // this package presents all named pipes as byte streams. @@ -173,7 +154,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) { return n, err } -func (s pipeAddress) Network() string { +func (pipeAddress) Network() string { return "pipe" } @@ -184,16 +165,21 @@ func (s pipeAddress) String() string { // tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. func tryDialPipe(ctx context.Context, path *string, access uint32) (syscall.Handle, error) { for { - select { case <-ctx.Done(): return syscall.Handle(0), ctx.Err() default: - h, err := createFile(*path, access, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) + h, err := createFile(*path, + access, + 0, + nil, + syscall.OPEN_EXISTING, + windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, + 0) if err == nil { return h, nil } - if err != cERROR_PIPE_BUSY { + if err != windows.ERROR_PIPE_BUSY { //nolint:errorlint // err is Errno return h, &os.PathError{Err: err, Op: "open", Path: *path} } // Wait 10 msec and try again. This is a rather simplistic @@ -213,9 +199,10 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) { } else { absTimeout = time.Now().Add(2 * time.Second) } - ctx, _ := context.WithDeadline(context.Background(), absTimeout) + ctx, cancel := context.WithDeadline(context.Background(), absTimeout) + defer cancel() conn, err := DialPipeContext(ctx, path) - if err == context.DeadlineExceeded { + if errors.Is(err, context.DeadlineExceeded) { return nil, ErrTimeout } return conn, err @@ -251,7 +238,7 @@ func DialPipeAccess(ctx context.Context, path string, access uint32) (net.Conn, // If the pipe is in message mode, return a message byte pipe, which // supports CloseWrite(). - if flags&cPIPE_TYPE_MESSAGE != 0 { + if flags&windows.PIPE_TYPE_MESSAGE != 0 { return &win32MessageBytePipe{ win32Pipe: win32Pipe{win32File: f, path: path}, }, nil @@ -283,7 +270,11 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy oa.Length = unsafe.Sizeof(oa) var ntPath unicodeString - if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { + if err := rtlDosPathNameToNtPathName(&path16[0], + &ntPath, + 0, + 0, + ).Err(); err != nil { return 0, &os.PathError{Op: "open", Path: path, Err: err} } defer localFree(ntPath.Buffer) @@ -292,8 +283,8 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy // The security descriptor is only needed for the first pipe. if first { if sd != nil { - len := uint32(len(sd)) - sdb := localAlloc(0, len) + l := uint32(len(sd)) + sdb := localAlloc(0, l) defer localFree(sdb) copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd) oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb)) @@ -301,28 +292,28 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy // Construct the default named pipe security descriptor. var dacl uintptr if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { - return 0, fmt.Errorf("getting default named pipe ACL: %s", err) + return 0, fmt.Errorf("getting default named pipe ACL: %w", err) } defer localFree(dacl) sdb := &securityDescriptor{ Revision: 1, - Control: cSE_DACL_PRESENT, + Control: windows.SE_DACL_PRESENT, Dacl: dacl, } oa.SecurityDescriptor = sdb } } - typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS) + typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) if c.MessageMode { - typ |= cFILE_PIPE_MESSAGE_TYPE + typ |= windows.FILE_PIPE_MESSAGE_TYPE } - disposition := uint32(cFILE_OPEN) + disposition := uint32(windows.FILE_OPEN) access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE) if first { - disposition = cFILE_CREATE + disposition = windows.FILE_CREATE // By not asking for read or write access, the named pipe file system // will put this pipe into an initially disconnected state, blocking // client connections until the next call with first == false. @@ -335,7 +326,20 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy h syscall.Handle iosb ioStatusBlock ) - err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() + err = ntCreateNamedPipeFile(&h, + access, + &oa, + &iosb, + syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, + disposition, + 0, + typ, + 0, + 0, + 0xffffffff, + uint32(c.InputBufferSize), + uint32(c.OutputBufferSize), + &timeout).Err() if err != nil { return 0, &os.PathError{Op: "open", Path: path, Err: err} } @@ -380,7 +384,7 @@ func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) { p.Close() p = nil err = <-ch - if err == nil || err == ErrFileClosed { + if err == nil || err == ErrFileClosed { //nolint:errorlint // err is Errno err = ErrPipeListenerClosed } } @@ -402,12 +406,12 @@ func (l *win32PipeListener) listenerRoutine() { p, err = l.makeConnectedServerPipe() // If the connection was immediately closed by the client, try // again. - if err != cERROR_NO_DATA { + if err != windows.ERROR_NO_DATA { //nolint:errorlint // err is Errno break } } responseCh <- acceptResponse{p, err} - closed = err == ErrPipeListenerClosed + closed = err == ErrPipeListenerClosed //nolint:errorlint // err is Errno } } syscall.Close(l.firstHandle) @@ -469,15 +473,15 @@ func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { } func connectPipe(p *win32File) error { - c, err := p.prepareIo() + c, err := p.prepareIO() if err != nil { return err } defer p.wg.Done() err = connectNamedPipe(p.handle, &c.o) - _, err = p.asyncIo(c, nil, 0, err) - if err != nil && err != cERROR_PIPE_CONNECTED { + _, err = p.asyncIO(c, nil, 0, err) + if err != nil && err != windows.ERROR_PIPE_CONNECTED { //nolint:errorlint // err is Errno return err } return nil diff --git a/pipe_test.go b/pipe_test.go index aa0eb77c..c07c06be 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -7,14 +7,16 @@ import ( "bufio" "bytes" "context" + "errors" "io" "net" - "os" "sync" "syscall" "testing" "time" "unsafe" + + "golang.org/x/sys/windows" ) var testPipeName = `\\.\pipe\winiotestpipe` @@ -23,7 +25,7 @@ var aLongTimeAgo = time.Unix(1, 0) func TestDialUnknownFailsImmediately(t *testing.T) { _, err := DialPipe(testPipeName, nil) - if err.(*os.PathError).Err != syscall.ENOENT { + if !errors.Is(err, syscall.ENOENT) { t.Fatalf("expected ENOENT got %v", err) } } @@ -34,9 +36,9 @@ func TestDialListenerTimesOut(t *testing.T) { t.Fatal(err) } defer l.Close() - var d = time.Duration(10 * time.Millisecond) + var d = 10 * time.Millisecond _, err = DialPipe(testPipeName, &d) - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Fatalf("expected ErrTimeout, got %v", err) } } @@ -47,10 +49,11 @@ func TestDialContextListenerTimesOut(t *testing.T) { t.Fatal(err) } defer l.Close() - var d = time.Duration(10 * time.Millisecond) - ctx, _ := context.WithTimeout(context.Background(), d) + var d = 10 * time.Millisecond + ctx, cancel := context.WithTimeout(context.Background(), d) + defer cancel() _, err = DialPipeContext(ctx, testPipeName) - if err != context.DeadlineExceeded { + if !errors.Is(err, context.DeadlineExceeded) { t.Fatalf("expected context.DeadlineExceeded, got %v", err) } } @@ -70,7 +73,7 @@ func TestDialListenerGetsCancelled(t *testing.T) { time.Sleep(time.Millisecond * 30) cancel() err = <-ch - if err != context.Canceled { + if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) } } @@ -85,7 +88,7 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { } defer l.Close() _, err = DialPipe(testPipeName, nil) - if err.(*os.PathError).Err != syscall.ERROR_ACCESS_DENIED { + if !errors.Is(err, syscall.ERROR_ACCESS_DENIED) { t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) } } @@ -93,7 +96,7 @@ func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { func getConnection(cfg *PipeConfig) (client net.Conn, server net.Conn, err error) { l, err := ListenPipe(testPipeName, cfg) if err != nil { - return + return nil, nil, err } defer l.Close() @@ -109,18 +112,16 @@ func getConnection(cfg *PipeConfig) (client net.Conn, server net.Conn, err error c, err := DialPipe(testPipeName, nil) if err != nil { - return + return client, server, err } r := <-ch if err = r.err; err != nil { c.Close() - return + return nil, nil, err } - client = c - server = r.c - return + return c, r.c, nil } func TestReadTimeout(t *testing.T) { @@ -131,11 +132,11 @@ func TestReadTimeout(t *testing.T) { defer c.Close() defer s.Close() - c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + _ = c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) buf := make([]byte, 10) _, err = c.Read(buf) - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Fatalf("expected ErrTimeout, got %v", err) } } @@ -216,7 +217,7 @@ func TestCloseAbortsListen(t *testing.T) { l.Close() err = <-ch - if err != ErrPipeListenerClosed { + if !errors.Is(err, ErrPipeListenerClosed) { t.Fatalf("expected ErrPipeListenerClosed, got %v", err) } } @@ -275,7 +276,7 @@ func TestCloseWriteEOF(t *testing.T) { b := make([]byte, 10) _, err = s.Read(b) - if err != io.EOF { + if !errors.Is(err, io.EOF) { t.Fatal(err) } } @@ -287,7 +288,7 @@ func TestAcceptAfterCloseFails(t *testing.T) { } l.Close() _, err = l.Accept() - if err != ErrPipeListenerClosed { + if !errors.Is(err, ErrPipeListenerClosed) { t.Fatalf("expected ErrPipeListenerClosed, got %v", err) } } @@ -299,7 +300,7 @@ func TestDialTimesOutByDefault(t *testing.T) { } defer l.Close() _, err = DialPipe(testPipeName, nil) - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Fatalf("expected ErrTimeout, got %v", err) } } @@ -316,7 +317,8 @@ func TestTimeoutPendingRead(t *testing.T) { go func() { s, err := l.Accept() if err != nil { - t.Fatal(err) + t.Error(err) + return } time.Sleep(1 * time.Second) s.Close() @@ -337,11 +339,11 @@ func TestTimeoutPendingRead(t *testing.T) { }() time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline - client.SetReadDeadline(aLongTimeAgo) + _ = client.SetReadDeadline(aLongTimeAgo) select { case err = <-clientErr: - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Fatalf("expected ErrTimeout, got %v", err) } case <-time.After(100 * time.Millisecond): @@ -363,7 +365,8 @@ func TestTimeoutPendingWrite(t *testing.T) { go func() { s, err := l.Accept() if err != nil { - t.Fatal(err) + t.Error(err) + return } time.Sleep(1 * time.Second) s.Close() @@ -383,11 +386,11 @@ func TestTimeoutPendingWrite(t *testing.T) { }() time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline - client.SetWriteDeadline(aLongTimeAgo) + _ = client.SetWriteDeadline(aLongTimeAgo) select { case err = <-clientErr: - if err != ErrTimeout { + if !errors.Is(err, ErrTimeout) { t.Fatalf("expected ErrTimeout, got %v", err) } case <-time.After(100 * time.Millisecond): @@ -419,13 +422,14 @@ func TestEchoWithMessaging(t *testing.T) { // server echo conn, e := l.Accept() if e != nil { - t.Fatal(e) + t.Error(err) + return } defer conn.Close() time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent - io.Copy(conn, conn) - conn.(CloseWriter).CloseWrite() + _, _ = io.Copy(conn, conn) + _ = conn.(CloseWriter).CloseWrite() close(listenerDone) }() timeout := 1 * time.Second @@ -440,10 +444,12 @@ func TestEchoWithMessaging(t *testing.T) { bytes := make([]byte, 2) n, e := client.Read(bytes) if e != nil { - t.Fatal(e) + t.Error(err) + return } if n != 2 { - t.Fatalf("expected 2 bytes, got %v", n) + t.Errorf("expected 2 bytes, got %v", n) + return } close(clientDone) }() @@ -459,7 +465,7 @@ func TestEchoWithMessaging(t *testing.T) { if n != 2 { t.Fatalf("expected 2 bytes, got %v", n) } - client.(CloseWriter).CloseWrite() + _ = client.(CloseWriter).CloseWrite() <-listenerDone <-clientDone } @@ -473,12 +479,13 @@ func TestConnectRace(t *testing.T) { go func() { for { s, err := l.Accept() - if err == ErrPipeListenerClosed { + if errors.Is(err, ErrPipeListenerClosed) { return } if err != nil { - t.Fatal(err) + t.Error(err) + return } s.Close() } @@ -510,11 +517,13 @@ func TestMessageReadMode(t *testing.T) { defer wg.Done() s, err := l.Accept() if err != nil { - t.Fatal(err) + t.Error(err) + return } _, err = s.Write(msg) if err != nil { - t.Fatal(err) + t.Error(err) + return } s.Close() }() @@ -528,7 +537,7 @@ func TestMessageReadMode(t *testing.T) { setNamedPipeHandleState := syscall.NewLazyDLL("kernel32.dll").NewProc("SetNamedPipeHandleState") p := c.(*win32MessageBytePipe) - mode := uint32(cPIPE_READMODE_MESSAGE) + mode := uint32(windows.PIPE_READMODE_MESSAGE) if s, _, err := setNamedPipeHandleState.Call(uintptr(p.handle), uintptr(unsafe.Pointer(&mode)), 0, 0); s == 0 { t.Fatal(err) } @@ -537,7 +546,7 @@ func TestMessageReadMode(t *testing.T) { var vmsg []byte for { n, err := c.Read(ch) - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { diff --git a/pkg/etw/doc.go b/pkg/etw/doc.go new file mode 100644 index 00000000..888def77 --- /dev/null +++ b/pkg/etw/doc.go @@ -0,0 +1,8 @@ +// Package etw provides support for TraceLogging-based ETW (Event Tracing +// for Windows). TraceLogging is a format of ETW events that are self-describing +// (the event contains information on its own schema). This allows them to be +// decoded without needing a separate manifest with event information. The +// implementation here is based on the information found in +// TraceLoggingProvider.h in the Windows SDK, which implements TraceLogging as a +// set of C macros. +package etw diff --git a/pkg/etw/eventdata.go b/pkg/etw/eventdata.go index abf16803..a6354754 100644 --- a/pkg/etw/eventdata.go +++ b/pkg/etw/eventdata.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package etw @@ -14,60 +15,60 @@ type eventData struct { buffer bytes.Buffer } -// bytes returns the raw binary data containing the event data. The returned +// toBytes returns the raw binary data containing the event data. The returned // value is not copied from the internal buffer, so it can be mutated by the // eventData object after it is returned. -func (ed *eventData) bytes() []byte { +func (ed *eventData) toBytes() []byte { return ed.buffer.Bytes() } // writeString appends a string, including the null terminator, to the buffer. func (ed *eventData) writeString(data string) { - ed.buffer.WriteString(data) - ed.buffer.WriteByte(0) + _, _ = ed.buffer.WriteString(data) + _ = ed.buffer.WriteByte(0) } // writeInt8 appends a int8 to the buffer. func (ed *eventData) writeInt8(value int8) { - ed.buffer.WriteByte(uint8(value)) + _ = ed.buffer.WriteByte(uint8(value)) } // writeInt16 appends a int16 to the buffer. func (ed *eventData) writeInt16(value int16) { - binary.Write(&ed.buffer, binary.LittleEndian, value) + _ = binary.Write(&ed.buffer, binary.LittleEndian, value) } // writeInt32 appends a int32 to the buffer. func (ed *eventData) writeInt32(value int32) { - binary.Write(&ed.buffer, binary.LittleEndian, value) + _ = binary.Write(&ed.buffer, binary.LittleEndian, value) } // writeInt64 appends a int64 to the buffer. func (ed *eventData) writeInt64(value int64) { - binary.Write(&ed.buffer, binary.LittleEndian, value) + _ = binary.Write(&ed.buffer, binary.LittleEndian, value) } // writeUint8 appends a uint8 to the buffer. func (ed *eventData) writeUint8(value uint8) { - ed.buffer.WriteByte(value) + _ = ed.buffer.WriteByte(value) } // writeUint16 appends a uint16 to the buffer. func (ed *eventData) writeUint16(value uint16) { - binary.Write(&ed.buffer, binary.LittleEndian, value) + _ = binary.Write(&ed.buffer, binary.LittleEndian, value) } // writeUint32 appends a uint32 to the buffer. func (ed *eventData) writeUint32(value uint32) { - binary.Write(&ed.buffer, binary.LittleEndian, value) + _ = binary.Write(&ed.buffer, binary.LittleEndian, value) } // writeUint64 appends a uint64 to the buffer. func (ed *eventData) writeUint64(value uint64) { - binary.Write(&ed.buffer, binary.LittleEndian, value) + _ = binary.Write(&ed.buffer, binary.LittleEndian, value) } // writeFiletime appends a FILETIME to the buffer. func (ed *eventData) writeFiletime(value syscall.Filetime) { - binary.Write(&ed.buffer, binary.LittleEndian, value) + _ = binary.Write(&ed.buffer, binary.LittleEndian, value) } diff --git a/pkg/etw/eventdatadescriptor.go b/pkg/etw/eventdatadescriptor.go index 8b0ad481..9cbef490 100644 --- a/pkg/etw/eventdatadescriptor.go +++ b/pkg/etw/eventdatadescriptor.go @@ -1,3 +1,5 @@ +//go:build windows + package etw import ( @@ -13,11 +15,11 @@ const ( ) type eventDataDescriptor struct { - ptr ptr64 - size uint32 - dataType eventDataDescriptorType - reserved1 uint8 - reserved2 uint16 + ptr ptr64 + size uint32 + dataType eventDataDescriptorType + _ uint8 + _ uint16 } func newEventDataDescriptor(dataType eventDataDescriptorType, buffer []byte) eventDataDescriptor { diff --git a/pkg/etw/eventdescriptor.go b/pkg/etw/eventdescriptor.go index cc41f159..0dd11b45 100644 --- a/pkg/etw/eventdescriptor.go +++ b/pkg/etw/eventdescriptor.go @@ -1,3 +1,5 @@ +//go:build windows + package etw // Channel represents the ETW logging channel that is used. It can be used by @@ -45,6 +47,8 @@ const ( ) // EventDescriptor represents various metadata for an ETW event. +// +//nolint:structcheck // task is currently unused type eventDescriptor struct { id uint16 version uint8 @@ -70,6 +74,8 @@ func newEventDescriptor() *eventDescriptor { // should uniquely identify the other event metadata (contained in // EventDescriptor, and field metadata). Only the lower 24 bits of this value // are relevant. +// +//nolint:unused // keep for future use func (ed *eventDescriptor) identity() uint32 { return (uint32(ed.version) << 16) | uint32(ed.id) } @@ -78,6 +84,8 @@ func (ed *eventDescriptor) identity() uint32 { // should uniquely identify the other event metadata (contained in // EventDescriptor, and field metadata). Only the lower 24 bits of this value // are relevant. +// +//nolint:unused // keep for future use func (ed *eventDescriptor) setIdentity(identity uint32) { ed.id = uint16(identity) ed.version = uint8(identity >> 16) diff --git a/pkg/etw/eventmetadata.go b/pkg/etw/eventmetadata.go index 6fdc126c..a2e151a5 100644 --- a/pkg/etw/eventmetadata.go +++ b/pkg/etw/eventmetadata.go @@ -1,3 +1,5 @@ +//go:build windows + package etw import ( @@ -10,6 +12,8 @@ type inType byte // Various inType definitions for TraceLogging. These must match the definitions // found in TraceLoggingProvider.h in the Windows SDK. +// +//nolint:deadcode,varcheck // keep unused constants for potential future use const ( inTypeNull inType = iota inTypeUnicodeString @@ -47,6 +51,8 @@ type outType byte // Various outType definitions for TraceLogging. These must match the // definitions found in TraceLoggingProvider.h in the Windows SDK. +// +//nolint:deadcode,varcheck // keep unused constants for potential future use const ( // outTypeDefault indicates that the default formatting for the inType will // be used by the event decoder. @@ -81,11 +87,11 @@ type eventMetadata struct { buffer bytes.Buffer } -// bytes returns the raw binary data containing the event metadata. Before being +// toBytes returns the raw binary data containing the event metadata. Before being // returned, the current size of the buffer is written to the start of the // buffer. The returned value is not copied from the internal buffer, so it can // be mutated by the eventMetadata object after it is returned. -func (em *eventMetadata) bytes() []byte { +func (em *eventMetadata) toBytes() []byte { // Finalize the event metadata buffer by filling in the buffer length at the // beginning. binary.LittleEndian.PutUint16(em.buffer.Bytes(), uint16(em.buffer.Len())) @@ -95,7 +101,7 @@ func (em *eventMetadata) bytes() []byte { // writeEventHeader writes the metadata for the start of an event to the buffer. // This specifies the event name and tags. func (em *eventMetadata) writeEventHeader(name string, tags uint32) { - binary.Write(&em.buffer, binary.LittleEndian, uint16(0)) // Length placeholder + _ = binary.Write(&em.buffer, binary.LittleEndian, uint16(0)) // Length placeholder em.writeTags(tags) em.buffer.WriteString(name) em.buffer.WriteByte(0) // Null terminator for name @@ -118,7 +124,7 @@ func (em *eventMetadata) writeFieldInner(name string, inType inType, outType out } if arrSize != 0 { - binary.Write(&em.buffer, binary.LittleEndian, arrSize) + _ = binary.Write(&em.buffer, binary.LittleEndian, arrSize) } } @@ -151,13 +157,17 @@ func (em *eventMetadata) writeTags(tags uint32) { } // writeField writes the metadata for a simple field to the buffer. +// +//nolint:unparam // tags is currently always 0, may change in the future func (em *eventMetadata) writeField(name string, inType inType, outType outType, tags uint32) { em.writeFieldInner(name, inType, outType, tags, 0) } // writeArray writes the metadata for an array field to the buffer. The number // of elements in the array must be written as a uint16 in the event data, -// immediately preceeding the event data. +// immediately preceding the event data. +// +//nolint:unparam // tags is currently always 0, may change in the future func (em *eventMetadata) writeArray(name string, inType inType, outType outType, tags uint32) { em.writeFieldInner(name, inType|inTypeArray, outType, tags, 0) } @@ -165,6 +175,8 @@ func (em *eventMetadata) writeArray(name string, inType inType, outType outType, // writeCountedArray writes the metadata for an array field to the buffer. The // size of a counted array is fixed, and the size is written into the metadata // directly. +// +//nolint:unused // keep for future use func (em *eventMetadata) writeCountedArray(name string, count uint16, inType inType, outType outType, tags uint32) { em.writeFieldInner(name, inType|inTypeCountedArray, outType, tags, count) } diff --git a/pkg/etw/eventopt.go b/pkg/etw/eventopt.go index eaace688..73403220 100644 --- a/pkg/etw/eventopt.go +++ b/pkg/etw/eventopt.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package etw diff --git a/pkg/etw/fieldopt.go b/pkg/etw/fieldopt.go index b5ea80a4..b769c896 100644 --- a/pkg/etw/fieldopt.go +++ b/pkg/etw/fieldopt.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package etw @@ -481,7 +482,7 @@ func SmartField(name string, v interface{}) FieldOpt { case reflect.Int32: return SmartField(name, int32(rv.Int())) case reflect.Int64: - return SmartField(name, int64(rv.Int())) + return SmartField(name, int64(rv.Int())) //nolint:unconvert // make look consistent case reflect.Uint: return SmartField(name, uint(rv.Uint())) case reflect.Uint8: @@ -491,13 +492,13 @@ func SmartField(name string, v interface{}) FieldOpt { case reflect.Uint32: return SmartField(name, uint32(rv.Uint())) case reflect.Uint64: - return SmartField(name, uint64(rv.Uint())) + return SmartField(name, uint64(rv.Uint())) //nolint:unconvert // make look consistent case reflect.Uintptr: return SmartField(name, uintptr(rv.Uint())) case reflect.Float32: return SmartField(name, float32(rv.Float())) case reflect.Float64: - return SmartField(name, float64(rv.Float())) + return SmartField(name, float64(rv.Float())) //nolint:unconvert // make look consistent case reflect.String: return SmartField(name, rv.String()) case reflect.Struct: @@ -509,6 +510,9 @@ func SmartField(name string, v interface{}) FieldOpt { } } return Struct(name, fields...) + case reflect.Array, reflect.Chan, reflect.Complex128, reflect.Complex64, + reflect.Func, reflect.Interface, reflect.Invalid, reflect.Map, reflect.Ptr, + reflect.Slice, reflect.UnsafePointer: } } diff --git a/pkg/etw/newprovider.go b/pkg/etw/newprovider.go index 581ef595..3669b4f7 100644 --- a/pkg/etw/newprovider.go +++ b/pkg/etw/newprovider.go @@ -1,3 +1,4 @@ +//go:build windows && (amd64 || arm64 || 386) // +build windows // +build amd64 arm64 386 @@ -45,18 +46,18 @@ func NewProviderWithOptions(name string, options ...ProviderOpt) (provider *Prov trait := &bytes.Buffer{} if opts.group != (guid.GUID{}) { - binary.Write(trait, binary.LittleEndian, uint16(0)) // Write empty size for buffer (update later) - binary.Write(trait, binary.LittleEndian, uint8(1)) // EtwProviderTraitTypeGroup - traitArray := opts.group.ToWindowsArray() // Append group guid + _ = binary.Write(trait, binary.LittleEndian, uint16(0)) // Write empty size for buffer (update later) + _ = binary.Write(trait, binary.LittleEndian, uint8(1)) // EtwProviderTraitTypeGroup + traitArray := opts.group.ToWindowsArray() // Append group guid trait.Write(traitArray[:]) binary.LittleEndian.PutUint16(trait.Bytes(), uint16(trait.Len())) // Update size } metadata := &bytes.Buffer{} - binary.Write(metadata, binary.LittleEndian, uint16(0)) // Write empty size for buffer (to update later) + _ = binary.Write(metadata, binary.LittleEndian, uint16(0)) // Write empty size for buffer (to update later) metadata.WriteString(name) metadata.WriteByte(0) // Null terminator for name - trait.WriteTo(metadata) // Add traits if applicable + _, _ = trait.WriteTo(metadata) // Add traits if applicable binary.LittleEndian.PutUint16(metadata.Bytes(), uint16(metadata.Len())) // Update the size at the beginning of the buffer provider.metadata = metadata.Bytes() @@ -64,8 +65,8 @@ func NewProviderWithOptions(name string, options ...ProviderOpt) (provider *Prov provider.handle, eventInfoClassProviderSetTraits, uintptr(unsafe.Pointer(&provider.metadata[0])), - uint32(len(provider.metadata))); err != nil { - + uint32(len(provider.metadata)), + ); err != nil { return nil, err } diff --git a/pkg/etw/newprovider_unsupported.go b/pkg/etw/newprovider_unsupported.go index 5a05c134..e0057cfe 100644 --- a/pkg/etw/newprovider_unsupported.go +++ b/pkg/etw/newprovider_unsupported.go @@ -1,5 +1,5 @@ -// +build windows -// +build arm +//go:build windows && arm +// +build windows,arm package etw diff --git a/pkg/etw/provider.go b/pkg/etw/provider.go index a5b90d03..8174bff1 100644 --- a/pkg/etw/provider.go +++ b/pkg/etw/provider.go @@ -1,9 +1,10 @@ +//go:build windows // +build windows package etw import ( - "crypto/sha1" + "crypto/sha1" //nolint:gosec // not used for secure application "encoding/binary" "strings" "unicode/utf16" @@ -27,7 +28,7 @@ type Provider struct { keywordAll uint64 } -// String returns the `provider`.ID as a string +// String returns the `provider`.ID as a string. func (provider *Provider) String() string { if provider == nil { return "" @@ -54,6 +55,7 @@ const ( type eventInfoClass uint32 +//nolint:deadcode,varcheck // keep unused constants for potential future use const ( eventInfoClassProviderBinaryTrackInfo eventInfoClass = iota eventInfoClassProviderSetReserved1 @@ -65,10 +67,19 @@ const ( // enable/disable notifications from ETW. type EnableCallback func(guid.GUID, ProviderState, Level, uint64, uint64, uintptr) -func providerCallback(sourceID guid.GUID, state ProviderState, level Level, matchAnyKeyword uint64, matchAllKeyword uint64, filterData uintptr, i uintptr) { +func providerCallback( + sourceID guid.GUID, + state ProviderState, + level Level, + matchAnyKeyword uint64, + matchAllKeyword uint64, + filterData uintptr, + i uintptr, +) { provider := providers.getProvider(uint(i)) switch state { + case ProviderStateCaptureState: case ProviderStateDisable: provider.enabled = false case ProviderStateEnable: @@ -90,17 +101,22 @@ func providerCallback(sourceID guid.GUID, state ProviderState, level Level, matc // // The algorithm is roughly the RFC 4122 algorithm for a V5 UUID, but differs in // the following ways: -// - The input name is first upper-cased, UTF16-encoded, and converted to -// big-endian. -// - No variant is set on the result UUID. -// - The result UUID is treated as being in little-endian format, rather than -// big-endian. +// - The input name is first upper-cased, UTF16-encoded, and converted to +// big-endian. +// - No variant is set on the result UUID. +// - The result UUID is treated as being in little-endian format, rather than +// big-endian. func providerIDFromName(name string) guid.GUID { - buffer := sha1.New() - namespace := guid.GUID{0x482C2DB2, 0xC390, 0x47C8, [8]byte{0x87, 0xF8, 0x1A, 0x15, 0xBF, 0xC1, 0x30, 0xFB}} + buffer := sha1.New() //nolint:gosec // not used for secure application + namespace := guid.GUID{ + Data1: 0x482C2DB2, + Data2: 0xC390, + Data3: 0x47C8, + Data4: [8]byte{0x87, 0xF8, 0x1A, 0x15, 0xBF, 0xC1, 0x30, 0xFB}, + } namespaceBytes := namespace.ToArray() buffer.Write(namespaceBytes[:]) - binary.Write(buffer, binary.BigEndian, utf16.Encode([]rune(strings.ToUpper(name)))) + _ = binary.Write(buffer, binary.BigEndian, utf16.Encode([]rune(strings.ToUpper(name)))) sum := buffer.Sum(nil) sum[7] = (sum[7] & 0xf) | 0x50 @@ -117,25 +133,24 @@ type providerOpts struct { } // ProviderOpt allows the caller to specify provider options to -// NewProviderWithOptions +// NewProviderWithOptions. type ProviderOpt func(*providerOpts) -// WithCallback is used to provide a callback option to NewProviderWithOptions +// WithCallback is used to provide a callback option to NewProviderWithOptions. func WithCallback(callback EnableCallback) ProviderOpt { return func(opts *providerOpts) { opts.callback = callback } } -// WithID is used to provide a provider ID option to NewProviderWithOptions +// WithID is used to provide a provider ID option to NewProviderWithOptions. func WithID(id guid.GUID) ProviderOpt { return func(opts *providerOpts) { opts.id = id } } -// WithGroup is used to provide a provider group option to -// NewProviderWithOptions +// WithGroup is used to provide a provider group option to NewProviderWithOptions. func WithGroup(group guid.GUID) ProviderOpt { return func(opts *providerOpts) { opts.group = group @@ -237,11 +252,17 @@ func (provider *Provider) WriteEvent(name string, eventOpts []EventOpt, fieldOpt // event metadata (e.g. for the name) so we don't need to do this check for // the metadata. dataBlobs := [][]byte{} - if len(ed.bytes()) > 0 { - dataBlobs = [][]byte{ed.bytes()} + if len(ed.toBytes()) > 0 { + dataBlobs = [][]byte{ed.toBytes()} } - return provider.writeEventRaw(options.descriptor, options.activityID, options.relatedActivityID, [][]byte{em.bytes()}, dataBlobs) + return provider.writeEventRaw( + options.descriptor, + options.activityID, + options.relatedActivityID, + [][]byte{em.toBytes()}, + dataBlobs, + ) } // writeEventRaw writes a single ETW event from the provider. This function is @@ -257,17 +278,24 @@ func (provider *Provider) writeEventRaw( relatedActivityID guid.GUID, metadataBlobs [][]byte, dataBlobs [][]byte) error { - dataDescriptorCount := uint32(1 + len(metadataBlobs) + len(dataBlobs)) dataDescriptors := make([]eventDataDescriptor, 0, dataDescriptorCount) - dataDescriptors = append(dataDescriptors, newEventDataDescriptor(eventDataDescriptorTypeProviderMetadata, provider.metadata)) + dataDescriptors = append(dataDescriptors, + newEventDataDescriptor(eventDataDescriptorTypeProviderMetadata, provider.metadata)) for _, blob := range metadataBlobs { - dataDescriptors = append(dataDescriptors, newEventDataDescriptor(eventDataDescriptorTypeEventMetadata, blob)) + dataDescriptors = append(dataDescriptors, + newEventDataDescriptor(eventDataDescriptorTypeEventMetadata, blob)) } for _, blob := range dataBlobs { - dataDescriptors = append(dataDescriptors, newEventDataDescriptor(eventDataDescriptorTypeUserData, blob)) + dataDescriptors = append(dataDescriptors, + newEventDataDescriptor(eventDataDescriptorTypeUserData, blob)) } - return eventWriteTransfer(provider.handle, descriptor, (*windows.GUID)(&activityID), (*windows.GUID)(&relatedActivityID), dataDescriptorCount, &dataDescriptors[0]) + return eventWriteTransfer(provider.handle, + descriptor, + (*windows.GUID)(&activityID), + (*windows.GUID)(&relatedActivityID), + dataDescriptorCount, + &dataDescriptors[0]) } diff --git a/pkg/etw/provider_test.go b/pkg/etw/provider_test.go index 6b7c5ef8..1a98a1af 100644 --- a/pkg/etw/provider_test.go +++ b/pkg/etw/provider_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package etw diff --git a/pkg/etw/providerglobal.go b/pkg/etw/providerglobal.go index ce3d3057..0a1d90dd 100644 --- a/pkg/etw/providerglobal.go +++ b/pkg/etw/providerglobal.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package etw @@ -14,7 +15,6 @@ type providerMap struct { m map[uint]*Provider i uint lock sync.Mutex - once sync.Once } var providers = providerMap{ @@ -50,5 +50,7 @@ func (p *providerMap) getProvider(index uint) *Provider { return p.m[index] } +//todo: combine these into struct, so that "globalProviderCallback" is guaranteed to be initialized through method access + var providerCallbackOnce sync.Once var globalProviderCallback uintptr diff --git a/pkg/etw/ptr64_32.go b/pkg/etw/ptr64_32.go index d1a76125..26c9f194 100644 --- a/pkg/etw/ptr64_32.go +++ b/pkg/etw/ptr64_32.go @@ -1,3 +1,5 @@ +//go:build windows && (386 || arm) +// +build windows // +build 386 arm package etw diff --git a/pkg/etw/ptr64_64.go b/pkg/etw/ptr64_64.go index b86c8f2b..1524c643 100644 --- a/pkg/etw/ptr64_64.go +++ b/pkg/etw/ptr64_64.go @@ -1,3 +1,5 @@ +//go:build windows && (amd64 || arm64) +// +build windows // +build amd64 arm64 package etw diff --git a/tools/etw-provider-gen/noop.go b/pkg/etw/sample/main_other.go similarity index 71% rename from tools/etw-provider-gen/noop.go rename to pkg/etw/sample/main_other.go index bf98ceb4..7c3e856c 100644 --- a/tools/etw-provider-gen/noop.go +++ b/pkg/etw/sample/main_other.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package main diff --git a/pkg/etw/sample/sample.go b/pkg/etw/sample/main_windows.go similarity index 99% rename from pkg/etw/sample/sample.go rename to pkg/etw/sample/main_windows.go index fd315965..c4e4e858 100644 --- a/pkg/etw/sample/sample.go +++ b/pkg/etw/sample/main_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows // Shows a sample usage of the ETW logging package. diff --git a/pkg/etw/etw.go b/pkg/etw/syscall.go similarity index 72% rename from pkg/etw/etw.go rename to pkg/etw/syscall.go index 0bf6e1aa..16f3bb13 100644 --- a/pkg/etw/etw.go +++ b/pkg/etw/syscall.go @@ -1,13 +1,8 @@ -// Package etw provides support for TraceLogging-based ETW (Event Tracing -// for Windows). TraceLogging is a format of ETW events that are self-describing -// (the event contains information on its own schema). This allows them to be -// decoded without needing a separate manifest with event information. The -// implementation here is based on the information found in -// TraceLoggingProvider.h in the Windows SDK, which implements TraceLogging as a -// set of C macros. +//go:build windows + package etw -//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go etw.go +//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go syscall.go //sys eventRegister(providerId *windows.GUID, callback uintptr, callbackContext uintptr, providerHandle *providerHandle) (win32err error) = advapi32.EventRegister diff --git a/pkg/etw/wrapper_32.go b/pkg/etw/wrapper_32.go index 6867a1f8..14c49984 100644 --- a/pkg/etw/wrapper_32.go +++ b/pkg/etw/wrapper_32.go @@ -1,3 +1,4 @@ +//go:build windows && (386 || arm) // +build windows // +build 386 arm diff --git a/pkg/etw/wrapper_64.go b/pkg/etw/wrapper_64.go index fe83df2b..8cfe2e8c 100644 --- a/pkg/etw/wrapper_64.go +++ b/pkg/etw/wrapper_64.go @@ -1,3 +1,4 @@ +//go:build windows && (amd64 || arm64) // +build windows // +build amd64 arm64 @@ -19,7 +20,6 @@ func eventWriteTransfer( relatedActivityID *windows.GUID, dataDescriptorCount uint32, dataDescriptors *eventDataDescriptor) (win32err error) { - return eventWriteTransfer_64( providerHandle, descriptor, @@ -34,7 +34,6 @@ func eventSetInformation( class eventInfoClass, information uintptr, length uint32) (win32err error) { - return eventSetInformation_64( providerHandle, class, @@ -46,7 +45,21 @@ func eventSetInformation( // for provider notifications. Because Go has trouble with callback arguments of // different size, it has only pointer-sized arguments, which are then cast to // the appropriate types when calling providerCallback. -func providerCallbackAdapter(sourceID *guid.GUID, state uintptr, level uintptr, matchAnyKeyword uintptr, matchAllKeyword uintptr, filterData uintptr, i uintptr) uintptr { - providerCallback(*sourceID, ProviderState(state), Level(level), uint64(matchAnyKeyword), uint64(matchAllKeyword), filterData, i) +func providerCallbackAdapter( + sourceID *guid.GUID, + state uintptr, + level uintptr, + matchAnyKeyword uintptr, + matchAllKeyword uintptr, + filterData uintptr, + i uintptr, +) uintptr { + providerCallback(*sourceID, + ProviderState(state), + Level(level), + uint64(matchAnyKeyword), + uint64(matchAllKeyword), + filterData, + i) return 0 } diff --git a/pkg/etwlogrus/hook.go b/pkg/etwlogrus/hook.go index 8cf7baa1..76f6239a 100644 --- a/pkg/etwlogrus/hook.go +++ b/pkg/etwlogrus/hook.go @@ -17,7 +17,7 @@ const defaultEventName = "LogrusEntry" // ErrNoProvider is returned when a hook is created without a provider being configured. var ErrNoProvider = errors.New("no ETW registered provider") -// HookOpt is an option to change the behavior of the Logrus ETW hook +// HookOpt is an option to change the behavior of the Logrus ETW hook. type HookOpt func(*Hook) error // Hook is a Logrus hook which logs received events to ETW. @@ -30,16 +30,16 @@ type Hook struct { getEventsOpts func(*logrus.Entry) []etw.EventOpt } -// NewHook registers a new ETW provider and returns a hook to log from it. The -// provider will be closed when the hook is closed. +// NewHook registers a new ETW provider and returns a hook to log from it. +// The provider will be closed when the hook is closed. func NewHook(providerName string, opts ...HookOpt) (*Hook, error) { opts = append(opts, WithNewETWProvider(providerName)) return NewHookFromOpts(opts...) } -// NewHookFromProvider creates a new hook based on an existing ETW provider. The -// provider will not be closed when the hook is closed. +// NewHookFromProvider creates a new hook based on an existing ETW provider. +// The provider will not be closed when the hook is closed. func NewHookFromProvider(provider *etw.Provider, opts ...HookOpt) (*Hook, error) { opts = append(opts, WithExistingETWProvider(provider)) @@ -73,7 +73,7 @@ func (h *Hook) validate() error { // Levels returns the set of levels that this hook wants to receive log entries // for. -func (h *Hook) Levels() []logrus.Level { +func (*Hook) Levels() []logrus.Level { return logrus.AllLevels } @@ -142,7 +142,7 @@ func (h *Hook) Fire(e *logrus.Entry) error { // as a session listening for the event having no available space in its // buffers). Therefore, we don't return the error from WriteEvent, as it is // just noise in many cases. - h.provider.WriteEvent(name, opts, fields) + _ = h.provider.WriteEvent(name, opts, fields) return nil } diff --git a/pkg/etwlogrus/hook_test.go b/pkg/etwlogrus/hook_test.go index 6f3e7ae7..f6e24bde 100644 --- a/pkg/etwlogrus/hook_test.go +++ b/pkg/etwlogrus/hook_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package etwlogrus diff --git a/pkg/etwlogrus/opts.go b/pkg/etwlogrus/opts.go index ea3f2e4f..499fca87 100644 --- a/pkg/etwlogrus/opts.go +++ b/pkg/etwlogrus/opts.go @@ -44,7 +44,7 @@ func WithGetName(f func(*logrus.Entry) string) HookOpt { } } -// WithAdditionalEventOpts allows additional ETW event properties (keywords, tags, etc.) to be specified +// WithEventOpts allows additional ETW event properties (keywords, tags, etc.) to be specified. func WithEventOpts(f func(*logrus.Entry) []etw.EventOpt) HookOpt { return func(h *Hook) error { h.getEventsOpts = f diff --git a/pkg/fs/fs_windows.go b/pkg/fs/fs_windows.go index 53bc797e..7a73aafb 100644 --- a/pkg/fs/fs_windows.go +++ b/pkg/fs/fs_windows.go @@ -27,5 +27,5 @@ func GetFileSystemType(path string) (fsType string, err error) { drive += `\` err = windows.GetVolumeInformation(windows.StringToUTF16Ptr(drive), nil, 0, nil, nil, nil, &buf[0], size) fsType = windows.UTF16ToString(buf) - return + return fsType, err } diff --git a/pkg/fs/fs_windows_test.go b/pkg/fs/fs_windows_test.go index 512b2352..f5e4c6ea 100644 --- a/pkg/fs/fs_windows_test.go +++ b/pkg/fs/fs_windows_test.go @@ -1,6 +1,7 @@ package fs import ( + "errors" "os" "testing" ) @@ -18,7 +19,7 @@ func TestGetFSTypeOfKnownDrive(t *testing.T) { func TestGetFSTypeOfInvalidPath(t *testing.T) { _, err := GetFileSystemType("7:\\") - if err != ErrInvalidPath { + if !errors.Is(err, ErrInvalidPath) { t.Fatalf("Expected `ErrInvalidPath`, got %v", err) } } diff --git a/pkg/guid/guid.go b/pkg/guid/guid.go index 6e5526bd..7e8e60b1 100644 --- a/pkg/guid/guid.go +++ b/pkg/guid/guid.go @@ -7,7 +7,7 @@ package guid import ( "crypto/rand" - "crypto/sha1" + "crypto/sha1" //nolint:gosec // not used for secure application "encoding" "encoding/binary" "fmt" @@ -59,7 +59,7 @@ func NewV4() (GUID, error) { // big-endian UTF16 stream of bytes. If that is desired, the string can be // encoded as such before being passed to this function. func NewV5(namespace GUID, name []byte) (GUID, error) { - b := sha1.New() + b := sha1.New() //nolint:gosec // not used for secure application namespaceBytes := namespace.ToArray() b.Write(namespaceBytes[:]) b.Write(name) diff --git a/pkg/process/process.go b/pkg/process/process.go index 85c70fe6..873d24e9 100644 --- a/pkg/process/process.go +++ b/pkg/process/process.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package process @@ -32,6 +33,8 @@ func EnumProcesses() ([]uint32, error) { // ProcessMemoryCountersEx is the PROCESS_MEMORY_COUNTERS_EX struct from // Windows: // https://docs.microsoft.com/en-us/windows/win32/api/psapi/ns-psapi-process_memory_counters_ex +// +//nolint:revive // process.ProcessMemoryCountersEx stutters, too late to change it type ProcessMemoryCountersEx struct { Cb uint32 PageFaultCount uint32 @@ -75,7 +78,7 @@ func QueryFullProcessImageName(process windows.Handle, flags uint32) (string, er for { b := make([]uint16, bufferSize) err := queryFullProcessImageName(process, flags, &b[0], &bufferSize) - if err == windows.ERROR_INSUFFICIENT_BUFFER { + if err == windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno bufferSize = bufferSize * 2 continue } @@ -84,5 +87,4 @@ func QueryFullProcessImageName(process windows.Handle, flags uint32) (string, er } return windows.UTF16ToString(b[:bufferSize]), nil } - } diff --git a/pkg/security/grantvmgroupaccess.go b/pkg/security/grantvmgroupaccess.go index 60292078..6df87b74 100644 --- a/pkg/security/grantvmgroupaccess.go +++ b/pkg/security/grantvmgroupaccess.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package security @@ -20,6 +21,7 @@ type ( trusteeForm uint32 trusteeType uint32 + //nolint:structcheck // structcheck thinks fields are unused, but the are used to pass data to OS explicitAccess struct { accessPermissions accessMask accessMode accessMode @@ -27,6 +29,7 @@ type ( trustee trustee } + //nolint:structcheck,unused // structcheck thinks fields are unused, but the are used to pass data to OS trustee struct { multipleTrustee *trustee multipleTrusteeOperation int32 @@ -44,6 +47,7 @@ const ( desiredAccessReadControl desiredAccess = 0x20000 desiredAccessWriteDac desiredAccess = 0x40000 + //cspell:disable-next-line gvmga = "GrantVmGroupAccess:" inheritModeNoInheritance inheritMode = 0x0 @@ -56,9 +60,9 @@ const ( shareModeRead shareMode = 0x1 shareModeWrite shareMode = 0x2 - sidVmGroup = "S-1-5-83-0" + sidVMGroup = "S-1-5-83-0" - trusteeFormIsSid trusteeForm = 0 + trusteeFormIsSID trusteeForm = 0 trusteeTypeWellKnownGroup trusteeType = 5 ) @@ -67,6 +71,8 @@ const ( // include Grant ACE entries for the VM Group SID. This is a golang re- // implementation of the same function in vmcompute, just not exported in // RS5. Which kind of sucks. Sucks a lot :/ +// +//revive:disable-next-line:var-naming VM, not Vm func GrantVmGroupAccess(name string) error { // Stat (to determine if `name` is a directory). s, err := os.Stat(name) @@ -79,7 +85,7 @@ func GrantVmGroupAccess(name string) error { if err != nil { return err // Already wrapped } - defer syscall.CloseHandle(fd) + defer syscall.CloseHandle(fd) //nolint:errcheck // Get the current DACL and Security Descriptor. Must defer LocalFree on success. ot := objectTypeFileObject @@ -89,7 +95,7 @@ func GrantVmGroupAccess(name string) error { if err := getSecurityInfo(fd, uint32(ot), uint32(si), nil, nil, &origDACL, nil, &sd); err != nil { return fmt.Errorf("%s GetSecurityInfo %s: %w", gvmga, name, err) } - defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) + defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(sd))) //nolint:errcheck // Generate a new DACL which is the current DACL with the required ACEs added. // Must defer LocalFree on success. @@ -97,7 +103,7 @@ func GrantVmGroupAccess(name string) error { if err != nil { return err // Already wrapped } - defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) + defer syscall.LocalFree((syscall.Handle)(unsafe.Pointer(newDACL))) //nolint:errcheck // And finally use SetSecurityInfo to apply the updated DACL. if err := setSecurityInfo(fd, uint32(ot), uint32(si), uintptr(0), uintptr(0), newDACL, uintptr(0)); err != nil { @@ -110,16 +116,19 @@ func GrantVmGroupAccess(name string) error { // createFile is a helper function to call [Nt]CreateFile to get a handle to // the file or directory. func createFile(name string, isDir bool) (syscall.Handle, error) { - namep := syscall.StringToUTF16(name) + namep, err := syscall.UTF16FromString(name) + if err != nil { + return syscall.InvalidHandle, fmt.Errorf("could not convernt name to UTF-16: %w", err) + } da := uint32(desiredAccessReadControl | desiredAccessWriteDac) sm := uint32(shareModeRead | shareModeWrite) fa := uint32(syscall.FILE_ATTRIBUTE_NORMAL) if isDir { - fa = uint32(fa | syscall.FILE_FLAG_BACKUP_SEMANTICS) + fa |= syscall.FILE_FLAG_BACKUP_SEMANTICS } fd, err := syscall.CreateFile(&namep[0], da, sm, nil, syscall.OPEN_EXISTING, fa, 0) if err != nil { - return 0, fmt.Errorf("%s syscall.CreateFile %s: %w", gvmga, name, err) + return syscall.InvalidHandle, fmt.Errorf("%s syscall.CreateFile %s: %w", gvmga, name, err) } return fd, nil } @@ -128,9 +137,9 @@ func createFile(name string, isDir bool) (syscall.Handle, error) { // The caller is responsible for LocalFree of the returned DACL on success. func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintptr, error) { // Generate pointers to the SIDs based on the string SIDs - sid, err := syscall.StringToSid(sidVmGroup) + sid, err := syscall.StringToSid(sidVMGroup) if err != nil { - return 0, fmt.Errorf("%s syscall.StringToSid %s %s: %w", gvmga, name, sidVmGroup, err) + return 0, fmt.Errorf("%s syscall.StringToSid %s %s: %w", gvmga, name, sidVMGroup, err) } inheritance := inheritModeNoInheritance @@ -139,12 +148,12 @@ func generateDACLWithAcesAdded(name string, isDir bool, origDACL uintptr) (uintp } eaArray := []explicitAccess{ - explicitAccess{ + { accessPermissions: accessMaskDesiredPermission, accessMode: accessModeGrant, inheritance: inheritance, trustee: trustee{ - trusteeForm: trusteeFormIsSid, + trusteeForm: trusteeFormIsSID, trusteeType: trusteeTypeWellKnownGroup, name: uintptr(unsafe.Pointer(sid)), }, diff --git a/pkg/security/grantvmgroupaccess_test.go b/pkg/security/grantvmgroupaccess_test.go index 2dffaa12..bd648935 100644 --- a/pkg/security/grantvmgroupaccess_test.go +++ b/pkg/security/grantvmgroupaccess_test.go @@ -1,9 +1,9 @@ -//+build windows +//go:build windows +// +build windows package security import ( - "io/ioutil" "os" "path/filepath" "regexp" @@ -36,7 +36,7 @@ const ( // S-1-5-83-1-3166535780-1122986932-343720105-43916321:(I)(R,W) func TestGrantVmGroupAccess(t *testing.T) { - f, err := ioutil.TempFile("", "gvmgafile") + f, err := os.CreateTemp("", "gvmgafile") if err != nil { t.Fatal(err) } @@ -45,16 +45,12 @@ func TestGrantVmGroupAccess(t *testing.T) { os.Remove(f.Name()) }() - d, err := ioutil.TempDir("", "gvmgadir") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(d) - + d := t.TempDir() find, err := os.Create(filepath.Join(d, "find.txt")) if err != nil { t.Fatal(err) } + defer find.Close() if err := GrantVmGroupAccess(f.Name()); err != nil { t.Fatal(err) @@ -88,7 +84,6 @@ func TestGrantVmGroupAccess(t *testing.T) { find.Name(), []string{`(I)(R)`}, ) - } func verifyVMAccountDACLs(t *testing.T, name string, permissions []string) { diff --git a/privilege.go b/privilege.go index 120b6055..0ff9dac9 100644 --- a/privilege.go +++ b/privilege.go @@ -25,22 +25,17 @@ import ( //sys lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) = advapi32.LookupPrivilegeDisplayNameW const ( - SE_PRIVILEGE_ENABLED = 2 + //revive:disable-next-line:var-naming ALL_CAPS + SE_PRIVILEGE_ENABLED = windows.SE_PRIVILEGE_ENABLED - ERROR_NOT_ALL_ASSIGNED syscall.Errno = 1300 + //revive:disable-next-line:var-naming ALL_CAPS + ERROR_NOT_ALL_ASSIGNED syscall.Errno = windows.ERROR_NOT_ALL_ASSIGNED SeBackupPrivilege = "SeBackupPrivilege" SeRestorePrivilege = "SeRestorePrivilege" SeSecurityPrivilege = "SeSecurityPrivilege" ) -const ( - securityAnonymous = iota - securityIdentification - securityImpersonation - securityDelegation -) - var ( privNames = make(map[string]uint64) privNameMutex sync.Mutex @@ -52,11 +47,9 @@ type PrivilegeError struct { } func (e *PrivilegeError) Error() string { - s := "" + s := "Could not enable privilege " if len(e.privileges) > 1 { s = "Could not enable privileges " - } else { - s = "Could not enable privilege " } for i, p := range e.privileges { if i != 0 { @@ -95,7 +88,7 @@ func RunWithPrivileges(names []string, fn func() error) error { } func mapPrivileges(names []string) ([]uint64, error) { - var privileges []uint64 + privileges := make([]uint64, 0, len(names)) privNameMutex.Lock() defer privNameMutex.Unlock() for _, name := range names { @@ -128,7 +121,7 @@ func enableDisableProcessPrivilege(names []string, action uint32) error { return err } - p, _ := windows.GetCurrentProcess() + p := windows.CurrentProcess() var token windows.Token err = windows.OpenProcessToken(p, windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY, &token) if err != nil { @@ -141,10 +134,10 @@ func enableDisableProcessPrivilege(names []string, action uint32) error { func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) error { var b bytes.Buffer - binary.Write(&b, binary.LittleEndian, uint32(len(privileges))) + _ = binary.Write(&b, binary.LittleEndian, uint32(len(privileges))) for _, p := range privileges { - binary.Write(&b, binary.LittleEndian, p) - binary.Write(&b, binary.LittleEndian, action) + _ = binary.Write(&b, binary.LittleEndian, p) + _ = binary.Write(&b, binary.LittleEndian, action) } prevState := make([]byte, b.Len()) reqSize := uint32(0) @@ -152,7 +145,7 @@ func adjustPrivileges(token windows.Token, privileges []uint64, action uint32) e if !success { return err } - if err == ERROR_NOT_ALL_ASSIGNED { + if err == ERROR_NOT_ALL_ASSIGNED { //nolint:errorlint // err is Errno return &PrivilegeError{privileges} } return nil @@ -178,7 +171,7 @@ func getPrivilegeName(luid uint64) string { } func newThreadToken() (windows.Token, error) { - err := impersonateSelf(securityImpersonation) + err := impersonateSelf(windows.SecurityImpersonation) if err != nil { return 0, err } diff --git a/privileges_test.go b/privileges_test.go index 0d2f492b..2e2175bf 100644 --- a/privileges_test.go +++ b/privileges_test.go @@ -3,11 +3,15 @@ package winio -import "testing" +import ( + "errors" + "testing" +) func TestRunWithUnavailablePrivilege(t *testing.T) { err := RunWithPrivilege("SeCreateTokenPrivilege", func() error { return nil }) - if _, ok := err.(*PrivilegeError); err == nil || !ok { + var perr *PrivilegeError + if !errors.As(err, &perr) { t.Fatal("expected PrivilegeError") } } diff --git a/reparse.go b/reparse.go index 5f2f3a47..67d1a104 100644 --- a/reparse.go +++ b/reparse.go @@ -116,16 +116,16 @@ func EncodeReparsePoint(rp *ReparsePoint) []byte { } var b bytes.Buffer - binary.Write(&b, binary.LittleEndian, &data) + _ = binary.Write(&b, binary.LittleEndian, &data) if !rp.IsMountPoint { flags := uint32(0) if relative { flags |= 1 } - binary.Write(&b, binary.LittleEndian, flags) + _ = binary.Write(&b, binary.LittleEndian, flags) } - binary.Write(&b, binary.LittleEndian, ntTarget16) - binary.Write(&b, binary.LittleEndian, target16) + _ = binary.Write(&b, binary.LittleEndian, ntTarget16) + _ = binary.Write(&b, binary.LittleEndian, target16) return b.Bytes() } diff --git a/sd.go b/sd.go index 48d8557f..5550ef6b 100644 --- a/sd.go +++ b/sd.go @@ -4,6 +4,7 @@ package winio import ( + "errors" "syscall" "unsafe" @@ -19,11 +20,6 @@ import ( //sys localFree(mem uintptr) = LocalFree //sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength -const ( - cERROR_NONE_MAPPED = syscall.Errno(1332) - cERROR_INVALID_SID = syscall.Errno(1337) -) - type AccountLookupError struct { Name string Err error @@ -34,10 +30,10 @@ func (e *AccountLookupError) Error() string { return "lookup account: empty account name specified" } var s string - switch e.Err { - case cERROR_INVALID_SID: + switch { + case errors.Is(e.Err, windows.ERROR_INVALID_SID): s = "the security ID structure is invalid" - case cERROR_NONE_MAPPED: + case errors.Is(e.Err, windows.ERROR_NONE_MAPPED): s = "not found" default: s = e.Err.Error() @@ -45,6 +41,8 @@ func (e *AccountLookupError) Error() string { return "lookup account " + e.Name + ": " + s } +func (e *AccountLookupError) Unwrap() error { return e.Err } + type SddlConversionError struct { Sddl string Err error @@ -54,15 +52,19 @@ func (e *SddlConversionError) Error() string { return "convert " + e.Sddl + ": " + e.Err.Error() } +func (e *SddlConversionError) Unwrap() error { return e.Err } + // LookupSidByName looks up the SID of an account by name +// +//revive:disable-next-line:var-naming SID, not Sid func LookupSidByName(name string) (sid string, err error) { if name == "" { - return "", &AccountLookupError{name, cERROR_NONE_MAPPED} + return "", &AccountLookupError{name, windows.ERROR_NONE_MAPPED} } var sidSize, sidNameUse, refDomainSize uint32 err = lookupAccountName(nil, name, nil, &sidSize, nil, &refDomainSize, &sidNameUse) - if err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { + if err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno return "", &AccountLookupError{name, err} } sidBuffer := make([]byte, sidSize) @@ -82,9 +84,11 @@ func LookupSidByName(name string) (sid string, err error) { } // LookupNameBySid looks up the name of an account by SID +// +//revive:disable-next-line:var-naming SID, not Sid func LookupNameBySid(sid string) (name string, err error) { if sid == "" { - return "", &AccountLookupError{sid, cERROR_NONE_MAPPED} + return "", &AccountLookupError{sid, windows.ERROR_NONE_MAPPED} } sidBuffer, err := windows.UTF16PtrFromString(sid) @@ -100,7 +104,7 @@ func LookupNameBySid(sid string) (name string, err error) { var nameSize, refDomainSize, sidNameUse uint32 err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse) - if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { + if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { //nolint:errorlint // err is Errno return "", &AccountLookupError{sid, err} } diff --git a/sd_test.go b/sd_test.go index 9a6aab8a..c72bcbf4 100644 --- a/sd_test.go +++ b/sd_test.go @@ -3,20 +3,25 @@ package winio -import "testing" +import ( + "errors" + "testing" + + "golang.org/x/sys/windows" +) func TestLookupInvalidSid(t *testing.T) { _, err := LookupSidByName(".\\weoifjdsklfj") - aerr, ok := err.(*AccountLookupError) - if !ok || aerr.Err != cERROR_NONE_MAPPED { + var aerr *AccountLookupError + if !errors.As(err, &aerr) || !errors.Is(err, windows.ERROR_NONE_MAPPED) { t.Fatalf("expected AccountLookupError with ERROR_NONE_MAPPED, got %s", err) } } func TestLookupInvalidName(t *testing.T) { _, err := LookupNameBySid("notasid") - aerr, ok := err.(*AccountLookupError) - if !ok || aerr.Err != cERROR_INVALID_SID { + var aerr *AccountLookupError + if !errors.As(err, &aerr) || !errors.Is(aerr.Err, windows.ERROR_INVALID_SID) { t.Fatalf("expected AccountLookupError with ERROR_INVALID_SID got %s", err) } } @@ -36,8 +41,8 @@ func TestLookupValidSid(t *testing.T) { func TestLookupEmptyNameFails(t *testing.T) { _, err := LookupSidByName("") - aerr, ok := err.(*AccountLookupError) - if !ok || aerr.Err != cERROR_NONE_MAPPED { + var aerr *AccountLookupError + if !errors.As(err, &aerr) || !errors.Is(aerr.Err, windows.ERROR_NONE_MAPPED) { t.Fatalf("expected AccountLookupError with ERROR_NONE_MAPPED, got %s", err) } } diff --git a/syscall.go b/syscall.go index ca0de234..a6ca111b 100644 --- a/syscall.go +++ b/syscall.go @@ -1,3 +1,5 @@ +//go:build windows + package winio //go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go ./*.go diff --git a/pkg/etw/sample/noop.go b/tools/etw-provider-gen/main_others.go similarity index 71% rename from pkg/etw/sample/noop.go rename to tools/etw-provider-gen/main_others.go index bf98ceb4..7c3e856c 100644 --- a/pkg/etw/sample/noop.go +++ b/tools/etw-provider-gen/main_others.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package main diff --git a/tools/etw-provider-gen/main.go b/tools/etw-provider-gen/main_windows.go similarity index 96% rename from tools/etw-provider-gen/main.go rename to tools/etw-provider-gen/main_windows.go index 9e6df9d6..9d316fbe 100644 --- a/tools/etw-provider-gen/main.go +++ b/tools/etw-provider-gen/main_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package main diff --git a/tools/mkwinsyscall/doc.go b/tools/mkwinsyscall/doc.go index 3d638e41..20b172cc 100644 --- a/tools/mkwinsyscall/doc.go +++ b/tools/mkwinsyscall/doc.go @@ -18,13 +18,16 @@ like func declarations if //sys is replaced by func, but: - If go func name needs to be different from its winapi dll name, the winapi name could be specified at the end, after "=" sign, like + //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA - Each function that returns err needs to supply a condition, that return value of winapi will be tested against to detect failure. This would set err to windows "last-error", otherwise it will be nil. The value can be provided at end of //sys declaration, like + //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA + and is [failretval==0] by default. - If the function name ends in a "?", then the function not existing is non- @@ -32,21 +35,23 @@ like func declarations if //sys is replaced by func, but: Usage: - mkwinsyscall [flags] [path ...] - -The flags are: - - -output string - Specify output file name (standard output if omitted). - -sort - Sort DLL and Function declarations (default true). Setting to false is intended to maintain - compatibility with older versions of mkwinsyscall so that diffs are easier to read and understand. - -systemdll - Specify that all DLLs should be loaded from the Windows system directory (default true). - -trace - Generate print statement after every syscall. - winio - Import this package ("github.com/Microsoft/go-winio"). - + mkwinsyscall [flags] [path ...] + +Flags + + -output string + Output file name (standard output if omitted). + -sort + Sort DLL and function declarations (default true). + Intended to help transition from older versions of mkwinsyscall by making diffs + easier to read and understand. + -systemdll + Whether all DLLs should be loaded from the Windows system directory (default true). + -trace + Generate print statement after every syscall. + -utf16 + Encode string arguments as UTF-16 for syscalls not ending in 'A' or 'W' (default true). + -winio + Import this package ("github.com/Microsoft/go-winio"). */ package main diff --git a/tools/mkwinsyscall/mkwinsyscall.go b/tools/mkwinsyscall/mkwinsyscall.go index de362c6a..e72be313 100644 --- a/tools/mkwinsyscall/mkwinsyscall.go +++ b/tools/mkwinsyscall/mkwinsyscall.go @@ -16,7 +16,6 @@ import ( "go/parser" "go/token" "io" - "io/ioutil" "log" "os" "path/filepath" @@ -29,6 +28,24 @@ import ( "golang.org/x/sys/windows" ) +const ( + pkgSyscall = "syscall" + pkgWindows = "windows" + + // common types. + + tBool = "bool" + tBoolPtr = "*bool" + tError = "error" + tString = "string" + + // error variable names. + + varErr = "err" + varErrNTStatus = "ntStatus" + varErrHR = "hr" +) + var ( filename = flag.String("output", "", "output file name (standard output if omitted)") printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") @@ -53,20 +70,20 @@ func packagename() string { } func windowsdot() string { - if packageName == "windows" { + if packageName == pkgWindows { return "" } - return "windows." + return pkgWindows + "." } func syscalldot() string { - if packageName == "syscall" { + if packageName == pkgSyscall { return "" } - return "syscall." + return pkgSyscall + "." } -// Param is function parameter +// Param is function parameter. type Param struct { Name string Type string @@ -134,9 +151,9 @@ func (p *Param) StringTmpVarCode() string { // TmpVarCode returns source code for temp variable. func (p *Param) TmpVarCode() string { switch { - case p.Type == "bool": + case p.Type == tBool: return p.BoolTmpVarCode() - case p.Type == "*bool": + case p.Type == tBoolPtr: return p.BoolPointerTmpVarCode() case strings.HasPrefix(p.Type, "[]"): return p.SliceTmpVarCode() @@ -148,7 +165,7 @@ func (p *Param) TmpVarCode() string { // TmpVarReadbackCode returns source code for reading back the temp variable into the original variable. func (p *Param) TmpVarReadbackCode() string { switch { - case p.Type == "*bool": + case p.Type == tBoolPtr: return fmt.Sprintf("*%s = %s != 0", p.Name, p.tmpVar()) default: return "" @@ -170,11 +187,11 @@ func (p *Param) SyscallArgList() []string { t := p.HelperType() var s string switch { - case t == "*bool": + case t == tBoolPtr: s = fmt.Sprintf("unsafe.Pointer(&%s)", p.tmpVar()) case t[0] == '*': s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) - case t == "bool": + case t == tBool: s = p.tmpVar() case strings.HasPrefix(t, "[]"): return []string{ @@ -189,12 +206,12 @@ func (p *Param) SyscallArgList() []string { // IsError determines if p parameter is used to return error. func (p *Param) IsError() bool { - return p.Name == "err" && p.Type == "error" + return p.Name == varErr && p.Type == tError } // HelperType returns type of parameter p used in helper function. func (p *Param) HelperType() string { - if p.Type == "string" { + if p.Type == tString { return p.fn.StrconvType() } return p.Type @@ -226,9 +243,9 @@ type Rets struct { // ErrorVarName returns error variable name for r. func (r *Rets) ErrorVarName() string { if r.ReturnsError { - return "err" + return varErr } - if r.Type == "error" { + if r.Type == tError { return r.Name } return "" @@ -241,7 +258,7 @@ func (r *Rets) ToParams() []*Param { ps = append(ps, &Param{Name: r.Name, Type: r.Type}) } if r.ReturnsError { - ps = append(ps, &Param{Name: "err", Type: "error"}) + ps = append(ps, &Param{Name: varErr, Type: tError}) } return ps } @@ -295,8 +312,8 @@ func (r *Rets) SetErrorCode() string { const code = `if r0 != 0 { %s = %sErrno(r0) }` - const ntstatus = `if r0 != 0 { - ntstatus = %sNTStatus(r0) + const ntStatus = `if r0 != 0 { + %s = %sNTStatus(r0) }` const hrCode = `if int32(r0) < 0 { if r0&0x1fff0000 == 0x00070000 { @@ -311,22 +328,22 @@ func (r *Rets) SetErrorCode() string { if r.Name == "" { return r.useLongHandleErrorCode("r1") } - if r.Type == "error" { + if r.Type == tError { switch r.Name { - case "ntstatus": - return fmt.Sprintf(ntstatus, windowsdot()) - case "hr": + case varErrNTStatus, strings.ToLower(varErrNTStatus): // allow ntstatus to work + return fmt.Sprintf(ntStatus, r.Name, windowsdot()) + case varErrHR: return fmt.Sprintf(hrCode, r.Name, syscalldot()) default: return fmt.Sprintf(code, r.Name, syscalldot()) } } - s := "" + var s string switch { case r.Type[0] == '*': s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) - case r.Type == "bool": + case r.Type == tBool: s = fmt.Sprintf("%s = r0 != 0", r.Name) default: s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) @@ -564,7 +581,7 @@ func (f *Fn) HelperCallParamList() string { a := make([]string, 0, len(f.Params)) for _, p := range f.Params { s := p.Name - if p.Type == "string" { + if p.Type == tString { s = p.tmpVar() } a = append(a, s) @@ -583,7 +600,7 @@ func (f *Fn) MaybeAbsent() string { }` errorVar := f.Rets.ErrorVarName() if errorVar == "" { - errorVar = "err" + errorVar = varErr } return fmt.Sprintf(code, errorVar, f.DLLFuncName()) } @@ -616,7 +633,7 @@ func (f *Fn) StrconvType() string { // Otherwise it is false. func (f *Fn) HasStringParam() bool { for _, p := range f.Params { - if p.Type == "string" { + if p.Type == tString { return true } } @@ -892,7 +909,8 @@ func main() { if *filename == "" { _, err = os.Stdout.Write(data) } else { - err = ioutil.WriteFile(*filename, data, 0644) + //nolint:gosec // G306: code file, no need for wants 0600 + err = os.WriteFile(*filename, data, 0644) } if err != nil { log.Fatal(err) @@ -900,6 +918,7 @@ func main() { } // TODO: use println instead to print in the following template + const srcTemplate = ` {{define "main"}} //go:build windows diff --git a/vhd/vhd.go b/vhd/vhd.go index 46f69984..b54cad11 100644 --- a/vhd/vhd.go +++ b/vhd/vhd.go @@ -62,8 +62,8 @@ type OpenVirtualDiskParameters struct { Version2 OpenVersion2 } -// The higher level `OpenVersion2` struct uses bools to refer to `GetInfoOnly` and `ReadOnly` for ease of use. However, -// the internal windows structure uses `BOOLS` aka int32s for these types. `openVersion2` is used for translating +// The higher level `OpenVersion2` struct uses `bool`s to refer to `GetInfoOnly` and `ReadOnly` for ease of use. However, +// the internal windows structure uses `BOOL`s aka int32s for these types. `openVersion2` is used for translating // `OpenVersion2` fields to the correct windows internal field types on the `Open____` methods. type openVersion2 struct { getInfoOnly int32 @@ -87,9 +87,10 @@ type AttachVirtualDiskParameters struct { } const ( + //revive:disable-next-line:var-naming ALL_CAPS VIRTUAL_STORAGE_TYPE_DEVICE_VHDX = 0x3 - // Access Mask for opening a VHD + // Access Mask for opening a VHD. VirtualDiskAccessNone VirtualDiskAccessMask = 0x00000000 VirtualDiskAccessAttachRO VirtualDiskAccessMask = 0x00010000 VirtualDiskAccessAttachRW VirtualDiskAccessMask = 0x00020000 @@ -101,7 +102,7 @@ const ( VirtualDiskAccessAll VirtualDiskAccessMask = 0x003f0000 VirtualDiskAccessWritable VirtualDiskAccessMask = 0x00320000 - // Flags for creating a VHD + // Flags for creating a VHD. CreateVirtualDiskFlagNone CreateVirtualDiskFlag = 0x0 CreateVirtualDiskFlagFullPhysicalAllocation CreateVirtualDiskFlag = 0x1 CreateVirtualDiskFlagPreventWritesToSourceDisk CreateVirtualDiskFlag = 0x2 @@ -109,12 +110,12 @@ const ( CreateVirtualDiskFlagCreateBackingStorage CreateVirtualDiskFlag = 0x8 CreateVirtualDiskFlagUseChangeTrackingSourceLimit CreateVirtualDiskFlag = 0x10 CreateVirtualDiskFlagPreserveParentChangeTrackingState CreateVirtualDiskFlag = 0x20 - CreateVirtualDiskFlagVhdSetUseOriginalBackingStorage CreateVirtualDiskFlag = 0x40 + CreateVirtualDiskFlagVhdSetUseOriginalBackingStorage CreateVirtualDiskFlag = 0x40 //revive:disable-line:var-naming VHD, not Vhd CreateVirtualDiskFlagSparseFile CreateVirtualDiskFlag = 0x80 - CreateVirtualDiskFlagPmemCompatible CreateVirtualDiskFlag = 0x100 + CreateVirtualDiskFlagPmemCompatible CreateVirtualDiskFlag = 0x100 //revive:disable-line:var-naming PMEM, not Pmem CreateVirtualDiskFlagSupportCompressedVolumes CreateVirtualDiskFlag = 0x200 - // Flags for opening a VHD + // Flags for opening a VHD. OpenVirtualDiskFlagNone VirtualDiskFlag = 0x00000000 OpenVirtualDiskFlagNoParents VirtualDiskFlag = 0x00000001 OpenVirtualDiskFlagBlankFile VirtualDiskFlag = 0x00000002 @@ -127,7 +128,7 @@ const ( OpenVirtualDiskFlagNoWriteHardening VirtualDiskFlag = 0x00000100 OpenVirtualDiskFlagSupportCompressedVolumes VirtualDiskFlag = 0x00000200 - // Flags for attaching a VHD + // Flags for attaching a VHD. AttachVirtualDiskFlagNone AttachVirtualDiskFlag = 0x00000000 AttachVirtualDiskFlagReadOnly AttachVirtualDiskFlag = 0x00000001 AttachVirtualDiskFlagNoDriveLetter AttachVirtualDiskFlag = 0x00000002 @@ -140,12 +141,14 @@ const ( AttachVirtualDiskFlagSinglePartition AttachVirtualDiskFlag = 0x00000100 AttachVirtualDiskFlagRegisterVolume AttachVirtualDiskFlag = 0x00000200 - // Flags for detaching a VHD + // Flags for detaching a VHD. DetachVirtualDiskFlagNone DetachVirtualDiskFlag = 0x0 ) // CreateVhdx is a helper function to create a simple vhdx file at the given path using // default values. +// +//revive:disable-next-line:var-naming VHDX, not Vhdx func CreateVhdx(path string, maxSizeInGb, blockSizeInMb uint32) error { params := CreateVirtualDiskParameters{ Version: 2, @@ -172,6 +175,8 @@ func DetachVirtualDisk(handle syscall.Handle) (err error) { } // DetachVhd detaches a vhd found at `path`. +// +//revive:disable-next-line:var-naming VHD, not Vhd func DetachVhd(path string) error { handle, err := OpenVirtualDisk( path, @@ -181,12 +186,16 @@ func DetachVhd(path string) error { if err != nil { return err } - defer syscall.CloseHandle(handle) + defer syscall.CloseHandle(handle) //nolint:errcheck return DetachVirtualDisk(handle) } // AttachVirtualDisk attaches a virtual hard disk for use. -func AttachVirtualDisk(handle syscall.Handle, attachVirtualDiskFlag AttachVirtualDiskFlag, parameters *AttachVirtualDiskParameters) (err error) { +func AttachVirtualDisk( + handle syscall.Handle, + attachVirtualDiskFlag AttachVirtualDiskFlag, + parameters *AttachVirtualDiskParameters, +) (err error) { // Supports both version 1 and 2 of the attach parameters as version 2 wasn't present in RS5. if err := attachVirtualDisk( handle, @@ -203,6 +212,8 @@ func AttachVirtualDisk(handle syscall.Handle, attachVirtualDiskFlag AttachVirtua // AttachVhd attaches a virtual hard disk at `path` for use. Attaches using version 2 // of the ATTACH_VIRTUAL_DISK_PARAMETERS. +// +//revive:disable-next-line:var-naming VHD, not Vhd func AttachVhd(path string) (err error) { handle, err := OpenVirtualDisk( path, @@ -213,7 +224,7 @@ func AttachVhd(path string) (err error) { return err } - defer syscall.CloseHandle(handle) + defer syscall.CloseHandle(handle) //nolint:errcheck params := AttachVirtualDiskParameters{Version: 2} if err := AttachVirtualDisk( handle, @@ -226,7 +237,11 @@ func AttachVhd(path string) (err error) { } // OpenVirtualDisk obtains a handle to a VHD opened with supplied access mask and flags. -func OpenVirtualDisk(vhdPath string, virtualDiskAccessMask VirtualDiskAccessMask, openVirtualDiskFlags VirtualDiskFlag) (syscall.Handle, error) { +func OpenVirtualDisk( + vhdPath string, + virtualDiskAccessMask VirtualDiskAccessMask, + openVirtualDiskFlags VirtualDiskFlag, +) (syscall.Handle, error) { parameters := OpenVirtualDiskParameters{Version: 2} handle, err := OpenVirtualDiskWithParameters( vhdPath, @@ -241,7 +256,12 @@ func OpenVirtualDisk(vhdPath string, virtualDiskAccessMask VirtualDiskAccessMask } // OpenVirtualDiskWithParameters obtains a handle to a VHD opened with supplied access mask, flags and parameters. -func OpenVirtualDiskWithParameters(vhdPath string, virtualDiskAccessMask VirtualDiskAccessMask, openVirtualDiskFlags VirtualDiskFlag, parameters *OpenVirtualDiskParameters) (syscall.Handle, error) { +func OpenVirtualDiskWithParameters( + vhdPath string, + virtualDiskAccessMask VirtualDiskAccessMask, + openVirtualDiskFlags VirtualDiskFlag, + parameters *OpenVirtualDiskParameters, +) (syscall.Handle, error) { var ( handle syscall.Handle defaultType VirtualStorageType @@ -279,7 +299,12 @@ func OpenVirtualDiskWithParameters(vhdPath string, virtualDiskAccessMask Virtual } // CreateVirtualDisk creates a virtual harddisk and returns a handle to the disk. -func CreateVirtualDisk(path string, virtualDiskAccessMask VirtualDiskAccessMask, createVirtualDiskFlags CreateVirtualDiskFlag, parameters *CreateVirtualDiskParameters) (syscall.Handle, error) { +func CreateVirtualDisk( + path string, + virtualDiskAccessMask VirtualDiskAccessMask, + createVirtualDiskFlags CreateVirtualDiskFlag, + parameters *CreateVirtualDiskParameters, +) (syscall.Handle, error) { var ( handle syscall.Handle defaultType VirtualStorageType @@ -323,6 +348,8 @@ func GetVirtualDiskPhysicalPath(handle syscall.Handle) (_ string, err error) { } // CreateDiffVhd is a helper function to create a differencing virtual disk. +// +//revive:disable-next-line:var-naming VHD, not Vhd func CreateDiffVhd(diffVhdPath, baseVhdPath string, blockSizeInMB uint32) error { // Setting `ParentPath` is how to signal to create a differencing disk. createParams := &CreateVirtualDiskParameters{ diff --git a/wim/decompress.go b/wim/decompress.go index 01fe01cb..993a4131 100644 --- a/wim/decompress.go +++ b/wim/decompress.go @@ -1,3 +1,4 @@ +//go:build windows || linux // +build windows linux package wim @@ -5,7 +6,6 @@ package wim import ( "encoding/binary" "io" - "io/ioutil" "github.com/Microsoft/go-winio/wim/lzx" ) @@ -35,7 +35,6 @@ func newCompressedReader(r *io.SectionReader, originalSize int64, offset int64) for i, n := range chunks32 { chunks[i+1] = int64(n) } - } else { // 64-bit chunk offsets base = (nchunks - 1) * 8 @@ -62,7 +61,7 @@ func newCompressedReader(r *io.SectionReader, originalSize int64, offset int64) suboff := offset % chunkSize if suboff != 0 { - _, err := io.CopyN(ioutil.Discard, cr.d, suboff) + _, err := io.CopyN(io.Discard, cr.d, suboff) if err != nil { return nil, err } @@ -110,7 +109,7 @@ func (r *compressedReader) reset(n int) error { } r.d = d } else { - r.d = ioutil.NopCloser(section) + r.d = io.NopCloser(section) } return nil @@ -119,7 +118,7 @@ func (r *compressedReader) reset(n int) error { func (r *compressedReader) Read(b []byte) (int, error) { for { n, err := r.d.Read(b) - if err != io.EOF { + if err != io.EOF { //nolint:errorlint return n, err } diff --git a/wim/lzx/lzx.go b/wim/lzx/lzx.go index 4deb0df7..f7cf69a9 100644 --- a/wim/lzx/lzx.go +++ b/wim/lzx/lzx.go @@ -100,7 +100,7 @@ func (f *decompressor) ensureAtLeast(n int) error { } n, err := io.ReadAtLeast(f.r, f.b[f.bv-f.bo:], n) if err != nil { - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } else { f.fail(err) @@ -117,10 +117,8 @@ func (f *decompressor) ensureAtLeast(n int) error { // Otherwise, on error, it sets f.err. func (f *decompressor) feed() bool { err := f.ensureAtLeast(2) - if err != nil { - if err == io.ErrUnexpectedEOF { - return false - } + if err == io.ErrUnexpectedEOF { //nolint:errorlint // returns io.ErrUnexpectedEOF by contract + return false } f.c |= (uint32(f.b[f.bo+1])<<8 | uint32(f.b[f.bo])) << (16 - f.nbits) f.nbits += 16 @@ -232,9 +230,8 @@ func (f *decompressor) getCode(h *huffman) uint16 { // are, since entries with all possible suffixes were // added to the table. c := h.table[f.c>>(32-tablebits)] - if c >= 1<= 1<>(32-(h.maxbits-tablebits))] } @@ -399,41 +396,37 @@ func (f *decompressor) readTrees(readAligned bool) (main *huffman, length *huffm } aligned = buildTable(alignedLen[:]) if aligned == nil { - err = errors.New("corrupt") - return + return main, length, aligned, errors.New("corrupt") } } // The main tree is encoded in two parts. err = f.readTree(f.mainlens[:maincodesplit]) if err != nil { - return + return main, length, aligned, err } err = f.readTree(f.mainlens[maincodesplit:]) if err != nil { - return + return main, length, aligned, err } main = buildTable(f.mainlens[:]) if main == nil { - err = errors.New("corrupt") - return + return main, length, aligned, errors.New("corrupt") } // The length tree is encoding in a single part. err = f.readTree(f.lenlens[:]) if err != nil { - return + return main, length, aligned, err } length = buildTable(f.lenlens[:]) if length == nil { - err = errors.New("corrupt") - return + return main, length, aligned, errors.New("corrupt") } - err = f.err - return + return main, length, aligned, f.err } // readCompressedBlock decodes a compressed block, writing into the window @@ -465,7 +458,7 @@ func (f *decompressor) readCompressedBlock(start, end uint16, hmain, hlength, ha matchlen += 2 var matchoffset uint16 - if slot < 3 { + if slot < 3 { //nolint:nestif // todo: simplify nested complexity // The offset is one of the LRU values. matchoffset = f.lru[slot] f.lru[slot] = f.lru[0] @@ -586,7 +579,7 @@ func (f *decompressor) Read(b []byte) (int, error) { return f.windowReader.Read(b) } -func (f *decompressor) Close() error { +func (*decompressor) Close() error { return nil } diff --git a/wim/validate/noop.go b/wim/validate/main_other.go similarity index 71% rename from wim/validate/noop.go rename to wim/validate/main_other.go index bf98ceb4..7c3e856c 100644 --- a/wim/validate/noop.go +++ b/wim/validate/main_other.go @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package main diff --git a/wim/validate/validate.go b/wim/validate/main_windows.go similarity index 85% rename from wim/validate/validate.go rename to wim/validate/main_windows.go index 2536c002..b7329fff 100644 --- a/wim/validate/validate.go +++ b/wim/validate/main_windows.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package main @@ -20,7 +21,6 @@ func main() { w, err := wim.NewReader(f) if err != nil { panic(err) - } fmt.Printf("%#v\n%#v\n", w.Image[0], w.Image[0].Windows) @@ -39,13 +39,13 @@ func main() { func recur(d *wim.File) error { files, err := d.Readdir() if err != nil { - return fmt.Errorf("%s: %s", d.Name, err) + return fmt.Errorf("%s: %w", d.Name, err) } for _, f := range files { if f.IsDir() { err = recur(f) if err != nil { - return fmt.Errorf("%s: %s", f.Name, err) + return fmt.Errorf("%s: %w", f.Name, err) } } } diff --git a/wim/wim.go b/wim/wim.go index 45e0c515..8f272f3d 100644 --- a/wim/wim.go +++ b/wim/wim.go @@ -1,3 +1,4 @@ +//go:build windows || linux // +build windows linux // Package wim implements a WIM file parser. @@ -8,13 +9,12 @@ package wim import ( "bytes" - "crypto/sha1" + "crypto/sha1" //nolint:gosec // not used for secure application "encoding/binary" "encoding/xml" "errors" "fmt" "io" - "io/ioutil" "strconv" "sync" "time" @@ -22,6 +22,8 @@ import ( ) // File attribute constants from Windows. +// +//nolint:revive // var-naming: ALL_CAPS const ( FILE_ATTRIBUTE_READONLY = 0x00000001 FILE_ATTRIBUTE_HIDDEN = 0x00000002 @@ -44,6 +46,8 @@ const ( ) // Windows processor architectures. +// +//nolint:revive // var-naming: ALL_CAPS const ( PROCESSOR_ARCHITECTURE_INTEL = 0 PROCESSOR_ARCHITECTURE_MIPS = 1 @@ -62,6 +66,8 @@ const ( var wimImageTag = [...]byte{'M', 'S', 'W', 'I', 'M', 0, 0, 0} +// todo: replace this with pkg/guid.GUID (and add tests to make sure nothing breaks) + type guid struct { Data1 uint32 Data2 uint16 @@ -70,7 +76,18 @@ type guid struct { } func (g guid) String() string { - return fmt.Sprintf("%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", g.Data1, g.Data2, g.Data3, g.Data4[0], g.Data4[1], g.Data4[2], g.Data4[3], g.Data4[4], g.Data4[5], g.Data4[6], g.Data4[7]) + return fmt.Sprintf("%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + g.Data1, + g.Data2, + g.Data3, + g.Data4[0], + g.Data4[1], + g.Data4[2], + g.Data4[3], + g.Data4[4], + g.Data4[5], + g.Data4[6], + g.Data4[7]) } type resourceDescriptor struct { @@ -81,6 +98,7 @@ type resourceDescriptor struct { type resFlag byte +//nolint:deadcode,varcheck // need unused variables for iota to work const ( resFlagFree resFlag = 1 << iota resFlagMetadata @@ -120,6 +138,7 @@ type streamDescriptor struct { type hdrFlag uint32 +//nolint:deadcode,varcheck // need unused variables for iota to work const ( hdrFlagReserved hdrFlag = 1 << iota hdrFlagCompressed @@ -131,6 +150,7 @@ const ( hdrFlagRpFix ) +//nolint:deadcode,varcheck // need unused variables for iota to work const ( hdrFlagCompressReserved hdrFlag = 1 << (iota + 16) hdrFlagCompressXpress @@ -208,13 +228,13 @@ func (ft *Filetime) Time() time.Time { return time.Unix(0, nsec) } -// UnmarshalXML unmarshals the time from a WIM XML blob. +// UnmarshalXML unmarshalls the time from a WIM XML blob. func (ft *Filetime) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { - type time struct { + type Time struct { Low string `xml:"LOWPART"` High string `xml:"HIGHPART"` } - var t time + var t Time err := d.DecodeElement(&t, &start) if err != nil { return err @@ -283,6 +303,8 @@ func (e *ParseError) Error() string { return fmt.Sprintf("WIM parse error: %s %s: %s", e.Oper, e.Path, e.Err.Error()) } +func (e *ParseError) Unwrap() error { return e.Err } + // Reader provides functions to read a WIM file. type Reader struct { hdr wimHeader @@ -380,14 +402,14 @@ func NewReader(f io.ReaderAt) (*Reader, error) { return nil, err } - var info info - err = xml.Unmarshal([]byte(xmlinfo), &info) + var inf info + err = xml.Unmarshal([]byte(xmlinfo), &inf) if err != nil { return nil, &ParseError{Oper: "XML info", Err: err} } for i, img := range images { - for _, imgInfo := range info.Image { + for _, imgInfo := range inf.Image { if imgInfo.Index == i+1 { img.ImageInfo = imgInfo break @@ -417,8 +439,8 @@ func (r *Reader) resourceReaderWithOffset(hdr *resourceDescriptor, offset int64) var sr io.ReadCloser section := io.NewSectionReader(r.r, hdr.Offset, hdr.CompressedSize()) if hdr.Flags()&resFlagCompressed == 0 { - section.Seek(offset, 0) - sr = ioutil.NopCloser(section) + _, _ = section.Seek(offset, 0) + sr = io.NopCloser(section) } else { cr, err := newCompressedReader(section, hdr.OriginalSize, offset) if err != nil { @@ -436,7 +458,7 @@ func (r *Reader) readResource(hdr *resourceDescriptor) ([]byte, error) { return nil, err } defer rsrc.Close() - return ioutil.ReadAll(rsrc) + return io.ReadAll(rsrc) } func (r *Reader) readXML() (string, error) { @@ -449,17 +471,17 @@ func (r *Reader) readXML() (string, error) { } defer rsrc.Close() - XMLData := make([]uint16, r.hdr.XMLData.OriginalSize/2) - err = binary.Read(rsrc, binary.LittleEndian, XMLData) + xmlData := make([]uint16, r.hdr.XMLData.OriginalSize/2) + err = binary.Read(rsrc, binary.LittleEndian, xmlData) if err != nil { return "", &ParseError{Oper: "XML data", Err: err} } // The BOM will always indicate little-endian UTF-16. - if XMLData[0] != 0xfeff { + if xmlData[0] != 0xfeff { return "", &ParseError{Oper: "XML data", Err: errors.New("invalid BOM")} } - return string(utf16.Decode(XMLData[1:])), nil + return string(utf16.Decode(xmlData[1:])), nil } func (r *Reader) readOffsetTable(res *resourceDescriptor) (map[SHA1Hash]resourceDescriptor, []*Image, error) { @@ -475,7 +497,7 @@ func (r *Reader) readOffsetTable(res *resourceDescriptor) (map[SHA1Hash]resource for i := 0; ; i++ { var res streamDescriptor err := binary.Read(br, binary.LittleEndian, &res) - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { @@ -491,7 +513,7 @@ func (r *Reader) readOffsetTable(res *resourceDescriptor) (map[SHA1Hash]resource if err != nil { panic(fmt.Sprint(i, err)) } - hash := sha1.New() + hash := sha1.New() //nolint:gosec // not used for secure application _, err = io.Copy(hash, sec) sec.Close() if err != nil { @@ -522,12 +544,11 @@ func (r *Reader) readOffsetTable(res *resourceDescriptor) (map[SHA1Hash]resource return fileData, images, nil } -func (r *Reader) readSecurityDescriptors(rsrc io.Reader) (sds [][]byte, n int64, err error) { +func (*Reader) readSecurityDescriptors(rsrc io.Reader) (sds [][]byte, n int64, err error) { var secBlock securityblockDisk err = binary.Read(rsrc, binary.LittleEndian, &secBlock) if err != nil { - err = &ParseError{Oper: "security table", Err: err} - return + return sds, 0, &ParseError{Oper: "security table", Err: err} } n += securityblockDiskSize @@ -535,8 +556,7 @@ func (r *Reader) readSecurityDescriptors(rsrc io.Reader) (sds [][]byte, n int64, secSizes := make([]int64, secBlock.NumEntries) err = binary.Read(rsrc, binary.LittleEndian, &secSizes) if err != nil { - err = &ParseError{Oper: "security table sizes", Err: err} - return + return sds, n, &ParseError{Oper: "security table sizes", Err: err} } n += int64(secBlock.NumEntries * 8) @@ -546,8 +566,7 @@ func (r *Reader) readSecurityDescriptors(rsrc io.Reader) (sds [][]byte, n int64, sd := make([]byte, size&0xffffffff) _, err = io.ReadFull(rsrc, sd) if err != nil { - err = &ParseError{Oper: "security descriptor", Err: err} - return + return sds, n, &ParseError{Oper: "security descriptor", Err: err} } n += int64(len(sd)) sds[i] = sd @@ -555,17 +574,16 @@ func (r *Reader) readSecurityDescriptors(rsrc io.Reader) (sds [][]byte, n int64, secsize := int64((secBlock.TotalLength + 7) &^ 7) if n > secsize { - err = &ParseError{Oper: "security descriptor", Err: errors.New("security descriptor table too small")} - return + return sds, n, &ParseError{Oper: "security descriptor", Err: errors.New("security descriptor table too small")} } - _, err = io.CopyN(ioutil.Discard, rsrc, secsize-n) + _, err = io.CopyN(io.Discard, rsrc, secsize-n) if err != nil { - return + return sds, n, err } n = secsize - return + return sds, n, nil } // Open parses the image and returns the root directory. @@ -621,10 +639,10 @@ func (img *Image) readdir(offset int64) ([]*File, error) { img.curOffset = offset } if offset > img.curOffset { - _, err := io.CopyN(ioutil.Discard, img.r, offset-img.curOffset) + _, err := io.CopyN(io.Discard, img.r, offset-img.curOffset) if err != nil { img.reset() - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } return nil, err @@ -635,7 +653,7 @@ func (img *Image) readdir(offset int64) ([]*File, error) { for { e, n, err := img.readNextEntry(img.r) img.curOffset += n - if err == io.EOF { + if err == io.EOF { //nolint:errorlint break } if err != nil { @@ -699,7 +717,11 @@ func (img *Image) readNextEntry(r io.Reader) (*File, int64, error) { var ok bool offset, ok = img.wim.fileData[dentry.Hash] if !ok { - return nil, 0, &ParseError{Oper: "directory entry", Path: name, Err: fmt.Errorf("could not find file data matching hash %#v", dentry)} + return nil, 0, &ParseError{ + Oper: "directory entry", + Path: name, + Err: fmt.Errorf("could not find file data matching hash %#v", dentry), + } } } @@ -742,9 +764,9 @@ func (img *Image) readNextEntry(r io.Reader) (*File, int64, error) { f.SecurityDescriptor = img.sds[dentry.SecurityID] } - _, err = io.CopyN(ioutil.Discard, r, left) + _, err = io.CopyN(io.Discard, r, left) if err != nil { - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } return nil, 0, err @@ -771,7 +793,11 @@ func (img *Image) readNextEntry(r io.Reader) (*File, int64, error) { } if dentry.Attributes&FILE_ATTRIBUTE_REPARSE_POINT != 0 && f.Size == 0 { - return nil, 0, &ParseError{Oper: "directory entry", Path: name, Err: errors.New("reparse point is missing reparse stream")} + return nil, 0, &ParseError{ + Oper: "directory entry", + Path: name, + Err: errors.New("reparse point is missing reparse stream"), + } } return f, length, nil @@ -781,7 +807,7 @@ func (img *Image) readNextStream(r io.Reader) (*Stream, int64, error) { var length int64 err := binary.Read(r, binary.LittleEndian, &length) if err != nil { - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } return nil, 0, &ParseError{Oper: "stream length check", Err: err} @@ -818,7 +844,11 @@ func (img *Image) readNextStream(r io.Reader) (*Stream, int64, error) { var ok bool offset, ok = img.wim.fileData[sentry.Hash] if !ok { - return nil, 0, &ParseError{Oper: "stream entry", Path: name, Err: fmt.Errorf("could not find file data matching hash %v", sentry.Hash)} + return nil, 0, &ParseError{ + Oper: "stream entry", + Path: name, + Err: fmt.Errorf("could not find file data matching hash %v", sentry.Hash), + } } } @@ -832,9 +862,9 @@ func (img *Image) readNextStream(r io.Reader) (*Stream, int64, error) { offset: offset, } - _, err = io.CopyN(ioutil.Discard, r, left) + _, err = io.CopyN(io.Discard, r, left) if err != nil { - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } return nil, 0, err diff --git a/zsyscall_windows.go b/zsyscall_windows.go index 1e921f40..83f45a13 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -399,25 +399,25 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro return } -func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { +func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntStatus) { r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) - status = ntstatus(r0) + status = ntStatus(r0) return } -func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) { +func rtlDefaultNpAcl(dacl *uintptr) (status ntStatus) { r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0) - status = ntstatus(r0) + status = ntStatus(r0) return } -func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) { +func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntStatus) { r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0) - status = ntstatus(r0) + status = ntStatus(r0) return } -func rtlNtStatusToDosError(status ntstatus) (winerr error) { +func rtlNtStatusToDosError(status ntStatus) (winerr error) { r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0) if r0 != 0 { winerr = syscall.Errno(r0)