diff --git a/conn.go b/conn.go index da4ff9de..bc098360 100644 --- a/conn.go +++ b/conn.go @@ -233,7 +233,8 @@ func (cn *conn) handlePgpass(o values) { if _, ok := o["password"]; ok { return } - filename := os.Getenv("PGPASSFILE") + // Get passfile from the options + filename := o["passfile"] if filename == "" { // XXX this code doesn't work on Windows where the default filename is // XXX %APPDATA%\postgresql\pgpass.conf @@ -2038,6 +2039,8 @@ func parseEnviron(env []string) (out map[string]string) { accrue("user") case "PGPASSWORD": accrue("password") + case "PGPASSFILE": + accrue("passfile") case "PGSERVICE", "PGSERVICEFILE", "PGREALM": unsupported() case "PGOPTIONS": diff --git a/conn_test.go b/conn_test.go index 8fb81ac8..96f70ddd 100644 --- a/conn_test.go +++ b/conn_test.go @@ -174,7 +174,10 @@ func TestPgpass(t *testing.T) { } testAssert("", "ok", "missing .pgpass, unexpected error %#v") os.Setenv("PGPASSFILE", pgpassFile) + defer os.Unsetenv("PGPASSFILE") testAssert("host=/tmp", "fail", ", unexpected error %#v") + os.Unsetenv("PGPASSFILE") + os.Remove(pgpassFile) pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644) if err != nil { @@ -189,6 +192,7 @@ localhost:*:*:*:pass_C if err != nil { t.Fatalf("Unexpected error writing pgpass file %#v", err) } + defer os.Remove(pgpassFile) pgpass.Close() assertPassword := func(extra values, expected string) { @@ -211,19 +215,22 @@ localhost:*:*:*:pass_C t.Fatalf("For %v expected %s got %s", extra, expected, pw) } } + // missing passfile means empty psasword + assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "") // wrong permissions for the pgpass file means it should be ignored - assertPassword(values{"host": "example.com", "user": "foo"}, "") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "") // fix the permissions and check if it has taken effect os.Chmod(pgpassFile, 0600) - assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A") - assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback") - assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B") + + assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "pass_fallback") + assertPassword(values{"host": "example.com", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_B") // localhost also matches the default "" and UNIX sockets - assertPassword(values{"host": "", "user": "some_user"}, "pass_C") - assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C") - // cleanup - os.Remove(pgpassFile) - os.Setenv("PGPASSFILE", "") + assertPassword(values{"host": "", "passfile": pgpassFile, "user": "some_user"}, "pass_C") + assertPassword(values{"host": "/tmp", "passfile": pgpassFile, "user": "some_user"}, "pass_C") + // passfile connection parameter takes precedence + os.Setenv("PGPASSFILE", "/tmp") + assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A") } func TestExec(t *testing.T) { diff --git a/connector_test.go b/connector_test.go index d68810e9..ab34fc50 100644 --- a/connector_test.go +++ b/connector_test.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "database/sql/driver" + "os" "testing" ) @@ -66,3 +67,21 @@ func TestNewConnector_Driver(t *testing.T) { } txn.Rollback() } + +func TestNewConnector_Environ(t *testing.T) { + name := "" + os.Setenv("PGPASSFILE", "/tmp/.pgpass") + defer os.Unsetenv("PGPASSFILE") + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + for key, expected := range map[string]string{ + "passfile": "/tmp/.pgpass", + } { + if got := c.opts[key]; got != expected { + t.Fatalf("Getting values from environment variables, for %v expected %s got %s", key, expected, got) + } + } + +}