diff --git a/sd.go b/sd.go index 8edbc843..48d8557f 100644 --- a/sd.go +++ b/sd.go @@ -6,10 +6,14 @@ package winio import ( "syscall" "unsafe" + + "golang.org/x/sys/windows" ) //sys lookupAccountName(systemName *uint16, accountName string, sid *byte, sidSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountNameW +//sys lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) = advapi32.LookupAccountSidW //sys convertSidToStringSid(sid *byte, str **uint16) (err error) = advapi32.ConvertSidToStringSidW +//sys convertStringSidToSid(str *uint16, sid **byte) (err error) = advapi32.ConvertStringSidToSidW //sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW //sys convertSecurityDescriptorToStringSecurityDescriptor(sd *byte, revision uint32, secInfo uint32, sddl **uint16, sddlSize *uint32) (err error) = advapi32.ConvertSecurityDescriptorToStringSecurityDescriptorW //sys localFree(mem uintptr) = LocalFree @@ -17,6 +21,7 @@ import ( const ( cERROR_NONE_MAPPED = syscall.Errno(1332) + cERROR_INVALID_SID = syscall.Errno(1337) ) type AccountLookupError struct { @@ -30,6 +35,8 @@ func (e *AccountLookupError) Error() string { } var s string switch e.Err { + case cERROR_INVALID_SID: + s = "the security ID structure is invalid" case cERROR_NONE_MAPPED: s = "not found" default: @@ -74,6 +81,40 @@ func LookupSidByName(name string) (sid string, err error) { return sid, nil } +// LookupNameBySid looks up the name of an account by SID +func LookupNameBySid(sid string) (name string, err error) { + if sid == "" { + return "", &AccountLookupError{sid, cERROR_NONE_MAPPED} + } + + sidBuffer, err := windows.UTF16PtrFromString(sid) + if err != nil { + return "", &AccountLookupError{sid, err} + } + + var sidPtr *byte + if err = convertStringSidToSid(sidBuffer, &sidPtr); err != nil { + return "", &AccountLookupError{sid, err} + } + defer localFree(uintptr(unsafe.Pointer(sidPtr))) + + var nameSize, refDomainSize, sidNameUse uint32 + err = lookupAccountSid(nil, sidPtr, nil, &nameSize, nil, &refDomainSize, &sidNameUse) + if err != nil && err != windows.ERROR_INSUFFICIENT_BUFFER { + return "", &AccountLookupError{sid, err} + } + + nameBuffer := make([]uint16, nameSize) + refDomainBuffer := make([]uint16, refDomainSize) + err = lookupAccountSid(nil, sidPtr, &nameBuffer[0], &nameSize, &refDomainBuffer[0], &refDomainSize, &sidNameUse) + if err != nil { + return "", &AccountLookupError{sid, err} + } + + name = windows.UTF16ToString(nameBuffer) + return name, nil +} + func SddlToSecurityDescriptor(sddl string) ([]byte, error) { var sdBuffer uintptr err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil) diff --git a/sd_test.go b/sd_test.go index cb1291bc..9a6aab8a 100644 --- a/sd_test.go +++ b/sd_test.go @@ -13,10 +13,24 @@ func TestLookupInvalidSid(t *testing.T) { } } +func TestLookupInvalidName(t *testing.T) { + _, err := LookupNameBySid("notasid") + aerr, ok := err.(*AccountLookupError) + if !ok || aerr.Err != cERROR_INVALID_SID { + t.Fatalf("expected AccountLookupError with ERROR_INVALID_SID got %s", err) + } +} + func TestLookupValidSid(t *testing.T) { - sid, err := LookupSidByName("Everyone") - if err != nil || sid != "S-1-1-0" { - t.Fatalf("expected S-1-1-0, got %s, %s", sid, err) + everyone := "S-1-1-0" + name, err := LookupNameBySid(everyone) + if err != nil { + t.Fatalf("expected a valid account name, got %v", err) + } + + sid, err := LookupSidByName(name) + if err != nil || sid != everyone { + t.Fatalf("expected %s, got %s, %s", everyone, sid, err) } } diff --git a/zsyscall_windows.go b/zsyscall_windows.go index 8be405cf..81e00229 100644 --- a/zsyscall_windows.go +++ b/zsyscall_windows.go @@ -49,9 +49,11 @@ var ( procConvertSecurityDescriptorToStringSecurityDescriptorW = modadvapi32.NewProc("ConvertSecurityDescriptorToStringSecurityDescriptorW") procConvertSidToStringSidW = modadvapi32.NewProc("ConvertSidToStringSidW") procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW") + procConvertStringSidToSidW = modadvapi32.NewProc("ConvertStringSidToSidW") procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength") procImpersonateSelf = modadvapi32.NewProc("ImpersonateSelf") procLookupAccountNameW = modadvapi32.NewProc("LookupAccountNameW") + procLookupAccountSidW = modadvapi32.NewProc("LookupAccountSidW") procLookupPrivilegeDisplayNameW = modadvapi32.NewProc("LookupPrivilegeDisplayNameW") procLookupPrivilegeNameW = modadvapi32.NewProc("LookupPrivilegeNameW") procLookupPrivilegeValueW = modadvapi32.NewProc("LookupPrivilegeValueW") @@ -125,6 +127,14 @@ func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision return } +func convertStringSidToSid(str *uint16, sid **byte) (err error) { + r1, _, e1 := syscall.Syscall(procConvertStringSidToSidW.Addr(), 2, uintptr(unsafe.Pointer(str)), uintptr(unsafe.Pointer(sid)), 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func getSecurityDescriptorLength(sd uintptr) (len uint32) { r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0) len = uint32(r0) @@ -156,6 +166,14 @@ func _lookupAccountName(systemName *uint16, accountName *uint16, sid *byte, sidS return } +func lookupAccountSid(systemName *uint16, sid *byte, name *uint16, nameSize *uint32, refDomain *uint16, refDomainSize *uint32, sidNameUse *uint32) (err error) { + r1, _, e1 := syscall.Syscall9(procLookupAccountSidW.Addr(), 7, uintptr(unsafe.Pointer(systemName)), uintptr(unsafe.Pointer(sid)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameSize)), uintptr(unsafe.Pointer(refDomain)), uintptr(unsafe.Pointer(refDomainSize)), uintptr(unsafe.Pointer(sidNameUse)), 0, 0) + if r1 == 0 { + err = errnoErr(e1) + } + return +} + func lookupPrivilegeDisplayName(systemName string, name *uint16, buffer *uint16, size *uint32, languageId *uint32) (err error) { var _p0 *uint16 _p0, err = syscall.UTF16PtrFromString(systemName)