Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

automtls: fix bidirectional communication when AutoMTLS is enabled #193

Merged
merged 5 commits into from
May 3, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ func (c *Client) Start() (addr net.Addr, err error) {

c.config.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
ServerName: "localhost",
}
}
Expand Down Expand Up @@ -774,7 +776,7 @@ func (c *Client) Start() (addr net.Addr, err error) {
}

// loadServerCert is used by AutoMTLS to read an x.509 cert returned by the
// server, and load it as the RootCA for the client TLSConfig.
// server, and load it as the RootCA and ClientCA for the client TLSConfig.
func (c *Client) loadServerCert(cert string) error {
certPool := x509.NewCertPool()

Expand All @@ -791,6 +793,7 @@ func (c *Client) loadServerCert(cert string) error {
certPool.AddCert(x509Cert)

c.config.TLSConfig.RootCAs = certPool
c.config.TLSConfig.ClientCAs = certPool
calvn marked this conversation as resolved.
Show resolved Hide resolved
return nil
}

Expand Down
2 changes: 2 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ func Serve(opts *ServeConfig) {
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCertPool,
MinVersion: tls.VersionTLS12,
RootCAs: clientCertPool,
ServerName: "localhost",
}
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved

// We send back the raw leaf cert data for the client rather than the
Expand Down
70 changes: 69 additions & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,75 @@ func TestServer_testMode(t *testing.T) {
t.Logf("HELLO")
}

func TestServer_testMode_AutoMTLS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

closeCh := make(chan struct{})
go Serve(&ServeConfig{
HandshakeConfig: testVersionedHandshake,
VersionedPlugins: map[int]PluginSet{
2: testGRPCPluginMap,
},
GRPCServer: DefaultGRPCServer,
Logger: hclog.NewNullLogger(),
Test: &ServeTestConfig{
Context: ctx,
ReattachConfigCh: nil,
CloseCh: closeCh,
},
})

// Connect!
process := helperProcess("test-mtls")
c := NewClient(&ClientConfig{
Cmd: process,
HandshakeConfig: testVersionedHandshake,
VersionedPlugins: map[int]PluginSet{
2: testGRPCPluginMap,
},
AllowedProtocols: []Protocol{ProtocolGRPC},
AutoMTLS: true,
})
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}

// Grab the impl
raw, err := client.Dispense("test")
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}

tester, ok := raw.(testInterface)
if !ok {
t.Fatalf("bad: %#v", raw)
}

n := tester.Double(3)
if n != 6 {
t.Fatal("invalid response", n)
}

// ensure we can make use of bidirectional communication with AutoMTLS
// enabled
err = tester.Bidirectional()
if err != nil {
t.Fatal("invalid response", err)
}

// Pinging should work
if err := client.Ping(); err != nil {
fairclothjm marked this conversation as resolved.
Show resolved Hide resolved
t.Fatalf("should not err: %s", err)
}

c.Kill()
calvn marked this conversation as resolved.
Show resolved Hide resolved
// Canceling should cause an exit
cancel()
<-closeCh
}

func TestRmListener_impl(t *testing.T) {
var _ net.Listener = new(rmListener)
}
Expand Down Expand Up @@ -145,7 +214,6 @@ func TestProtocolSelection_no_server(t *testing.T) {
if protocol != ProtocolNetRPC {
t.Fatalf("bad protocol %s", protocol)
}

}

func TestServer_testStdLogger(t *testing.T) {
Expand Down