From c6679fc81ee9d59fb4f9532228ca380b0129a3b2 Mon Sep 17 00:00:00 2001 From: Hunts Chen Date: Thu, 17 Mar 2022 12:40:11 -0700 Subject: [PATCH] add function to return all inherited files This is helpful in the use case of integrating with systemd socket activation. When there is no parent, the app still need to scan or use actication.Files() to find all the fds and call upgrader.AddFile() to track them. But when there is a parent, we can access the fds from upgrader instance rather then by scanning fds or calling actication.Files() both of which creates os.File objects for the fds. Without this change, there would be two set of os.File objects that represent the same fd set. One set of os.File is hold by the upgrader whie the another set returning from fd scanning (or actication.Files()) is being used at other places. This raise a risk of closing the same fd twice at different time which is really bad when the fd get reused by other files during the two close calls. --- fds.go | 30 +++++++++++++++++++++++++++ fds_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/fds.go b/fds.go index 29d7a7c..a0f3a77 100644 --- a/fds.go +++ b/fds.go @@ -323,6 +323,36 @@ func (f *Fds) addSyscallConnLocked(kind, network, addr string, conn syscall.Conn return nil } +// Files returns all inherited files and mark them as used. +// +// The descriptors may be in blocking mode. +func (f *Fds) Files() ([]*os.File, error) { + f.mu.Lock() + defer f.mu.Unlock() + + var files []*os.File + + for key, file := range f.inherited { + if key[0] != fdKind { + continue + } + + // Make a copy of the file, since we don't want to + // allow the caller to invalidate fds in f.inherited. + dup, err := dupFd(file.fd, key) + if err != nil { + return nil, err + } + + f.used[key] = file + delete(f.inherited, key) + + files = append(files, dup.File) + } + + return files, nil +} + // File returns an inherited file or nil. // // The descriptor may be in blocking mode. diff --git a/fds_test.go b/fds_test.go index 6879921..b066022 100644 --- a/fds_test.go +++ b/fds_test.go @@ -315,3 +315,62 @@ func TestFdsFile(t *testing.T) { } file.Close() } + +func TestFdsFiles(t *testing.T) { + r1, w1, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + defer r1.Close() + + r2, w2, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + defer r2.Close() + + testcases := []struct { + f *os.File + name string + expected string + }{ + { + w1, + "test1", + "fd:test1:", + }, + { + w2, + "test2", + "fd:test2:", + }, + } + + parent := newFds(nil, nil) + for _, tc := range testcases { + if err := parent.AddFile(tc.name, tc.f); err != nil { + t.Fatal("Can't add file:", err) + } + tc.f.Close() + } + + child := newFds(parent.copy(), nil) + files, err := child.Files() + if err != nil { + t.Fatal("Can't get inherited files:", err) + } + + if len(files) != len(testcases) { + t.Fatalf("Expected %d files, got %d", len(testcases), len(files)) + } + + for i, ff := range files { + tc := testcases[i] + + if ff.Name() != tc.expected { + t.Errorf("Expected file %q, got %q", tc.expected, ff.Name()) + } + + ff.Close() + } +}