From 3d613208bca2e74f2a20e04126ed30bcb5c4cc27 Mon Sep 17 00:00:00 2001 From: Hector Rivas Gandara Date: Fri, 7 Jul 2023 16:57:34 +0100 Subject: [PATCH] Use and respect the passfile connection parameter (#1129) * Use and respect the passfile connection parameter The postgres documentation[1] regarding the password file, states that: password file to use can be specified using the connection parameter passfile or the environment variable PGPASSFILE. The current implementation of lib/pq only respects the environment variable PGPASSFILE. This is not correct, but also limiting, as the PGPASSFILE is global and we might want to use different files for different clients in the same program. Fixing that is easy, by just checking the parameter passfile first, and if not, pull the value from PGPASSFILE. This also moves the parsing of PGPASSFILE to `parseEnviron`. Now the connection only checks the parameter passfile, that is populated by `parseEnviron`. [1] https://www.postgresql.org/docs/current/libpq-pgpass.html --- conn.go | 5 ++++- conn_test.go | 25 ++++++++++++++++--------- connector_test.go | 19 +++++++++++++++++++ 3 files changed, 39 insertions(+), 10 deletions(-) diff --git a/conn.go b/conn.go index da4ff9de..bc098360 100644 --- a/conn.go +++ b/conn.go @@ -233,7 +233,8 @@ func (cn *conn) handlePgpass(o values) { if _, ok := o["password"]; ok { return } - filename := os.Getenv("PGPASSFILE") + // Get passfile from the options + filename := o["passfile"] if filename == "" { // XXX this code doesn't work on Windows where the default filename is // XXX %APPDATA%\postgresql\pgpass.conf @@ -2038,6 +2039,8 @@ func parseEnviron(env []string) (out map[string]string) { accrue("user") case "PGPASSWORD": accrue("password") + case "PGPASSFILE": + accrue("passfile") case "PGSERVICE", "PGSERVICEFILE", "PGREALM": unsupported() case "PGOPTIONS": diff --git a/conn_test.go b/conn_test.go index 8fb81ac8..96f70ddd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -174,7 +174,10 @@ func TestPgpass(t *testing.T) { } testAssert("", "ok", "missing .pgpass, unexpected error %#v") os.Setenv("PGPASSFILE", pgpassFile) + defer os.Unsetenv("PGPASSFILE") testAssert("host=/tmp", "fail", ", unexpected error %#v") + os.Unsetenv("PGPASSFILE") + os.Remove(pgpassFile) pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644) if err != nil { @@ -189,6 +192,7 @@ localhost:*:*:*:pass_C if err != nil { t.Fatalf("Unexpected error writing pgpass file %#v", err) } + defer os.Remove(pgpassFile) pgpass.Close() assertPassword := func(extra values, expected string) { @@ -211,19 +215,22 @@ localhost:*:*:*:pass_C t.Fatalf("For %v expected %s got %s", extra, expected, pw) } } + // missing passfile means empty psasword + assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "") // wrong permissions for the pgpass file means it should be ignored - assertPassword(values{"host": "example.com", "user": "foo"}, "") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "") // fix the permissions and check if it has taken effect os.Chmod(pgpassFile, 0600) - assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A") - assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback") - assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B") + + assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "pass_fallback") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_B") // localhost also matches the default "" and UNIX sockets - assertPassword(values{"host": "", "user": "some_user"}, "pass_C") - assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C") - // cleanup - os.Remove(pgpassFile) - os.Setenv("PGPASSFILE", "") + assertPassword(values{"host": "", "passfile": pgpassFile, "user": "some_user"}, "pass_C") + assertPassword(values{"host": "/tmp", "passfile": pgpassFile, "user": "some_user"}, "pass_C") + // passfile connection parameter takes precedence + os.Setenv("PGPASSFILE", "/tmp") + assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A") } func TestExec(t *testing.T) { diff --git a/connector_test.go b/connector_test.go index d68810e9..ab34fc50 100644 --- a/connector_test.go +++ b/connector_test.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "database/sql/driver" + "os" "testing" ) @@ -66,3 +67,21 @@ func TestNewConnector_Driver(t *testing.T) { } txn.Rollback() } + +func TestNewConnector_Environ(t *testing.T) { + name := "" + os.Setenv("PGPASSFILE", "/tmp/.pgpass") + defer os.Unsetenv("PGPASSFILE") + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + for key, expected := range map[string]string{ + "passfile": "/tmp/.pgpass", + } { + if got := c.opts[key]; got != expected { + t.Fatalf("Getting values from environment variables, for %v expected %s got %s", key, expected, got) + } + } + +}