Skip to content

Commit

Permalink
Move parse of PGPASSFILE to parseEnviron
Browse files Browse the repository at this point in the history
Now the connection only checks the parameter passfile, that
is populated by parseEnviron.

Refactored the test for this
  • Loading branch information
keymon committed Jun 28, 2023
1 parent 29f3a40 commit eb3a56e
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
7 changes: 3 additions & 4 deletions conn.go
Expand Up @@ -233,11 +233,8 @@ func (cn *conn) handlePgpass(o values) {
if _, ok := o["password"]; ok {
return
}
// Get passfile from the options, if empty, get it from envvar
// Get passfile from the options
filename := o["passfile"]
if filename == "" {
filename = os.Getenv("PGPASSFILE")
}
if filename == "" {
// XXX this code doesn't work on Windows where the default filename is
// XXX %APPDATA%\postgresql\pgpass.conf
Expand Down Expand Up @@ -2042,6 +2039,8 @@ func parseEnviron(env []string) (out map[string]string) {
accrue("user")
case "PGPASSWORD":
accrue("password")
case "PGPASSFILE":
accrue("passfile")
case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
unsupported()
case "PGOPTIONS":
Expand Down
23 changes: 13 additions & 10 deletions conn_test.go
Expand Up @@ -174,7 +174,10 @@ func TestPgpass(t *testing.T) {
}
testAssert("", "ok", "missing .pgpass, unexpected error %#v")
os.Setenv("PGPASSFILE", pgpassFile)
defer os.Unsetenv("PGPASSFILE")
testAssert("host=/tmp", "fail", ", unexpected error %#v")
os.Unsetenv("PGPASSFILE")

os.Remove(pgpassFile)
pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
Expand Down Expand Up @@ -212,22 +215,22 @@ localhost:*:*:*:pass_C
t.Fatalf("For %v expected %s got %s", extra, expected, pw)
}
}
// missing passfile means empty psasword
assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "")
// wrong permissions for the pgpass file means it should be ignored
assertPassword(values{"host": "example.com", "user": "foo"}, "")
assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "")
// fix the permissions and check if it has taken effect
os.Chmod(pgpassFile, 0600)
assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A")
assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback")
assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B")

assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A")
assertPassword(values{"host": "example.com", "passfile": pgpassFile, "user": "foo"}, "pass_fallback")
assertPassword(values{"host": "example.com", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_B")
// localhost also matches the default "" and UNIX sockets
assertPassword(values{"host": "", "user": "some_user"}, "pass_C")
assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C")
assertPassword(values{"host": "", "passfile": pgpassFile, "user": "some_user"}, "pass_C")
assertPassword(values{"host": "/tmp", "passfile": pgpassFile, "user": "some_user"}, "pass_C")
// passfile connection parameter takes precedence
os.Setenv("PGPASSFILE", "/tmp")
assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user", "passfile": pgpassFile}, "pass_A")

// cleanup
os.Setenv("PGPASSFILE", "")
assertPassword(values{"host": "server", "passfile": pgpassFile, "dbname": "some_db", "user": "some_user"}, "pass_A")
}

func TestExec(t *testing.T) {
Expand Down
19 changes: 19 additions & 0 deletions connector_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"database/sql"
"database/sql/driver"
"os"
"testing"
)

Expand Down Expand Up @@ -66,3 +67,21 @@ func TestNewConnector_Driver(t *testing.T) {
}
txn.Rollback()
}

func TestNewConnector_Environ(t *testing.T) {
name := ""
os.Setenv("PGPASSFILE", "/tmp/.pgpass")
defer os.Unsetenv("PGPASSFILE")
c, err := NewConnector(name)
if err != nil {
t.Fatal(err)
}
for key, expected := range map[string]string{
"passfile": "/tmp/.pgpass",
} {
if got := c.opts[key]; got != expected {
t.Fatalf("Getting values from environment variables, for %v expected %s got %s", key, expected, got)
}
}

}

0 comments on commit eb3a56e

Please sign in to comment.