Skip to content

Commit

Permalink
Merge pull request #123720 from HirazawaUi/fix-slow-dra-test
Browse files: browse the repository at this point in the history
kubelet: fix slow dra unit test
  • Loading branch information
k8s-ci-robot committed Mar 25, 2024
2 parents 20d0ab7 + 10b6319 commit 227c2e7
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 19 deletions.
3 changes: 2 additions & 1 deletion pkg/kubelet/cm/devicemanager/plugin/v1beta1/handler.go
Expand Up @@ -19,6 +19,7 @@ package v1beta1
import (
"fmt"
"os"
"time"

core "k8s.io/api/core/v1"
"k8s.io/klog/v2"
Expand All @@ -37,7 +38,7 @@ func (s *server) GetPluginHandler() cache.PluginHandler {
return s
}

func (s *server) RegisterPlugin(pluginName string, endpoint string, versions []string) error {
func (s *server) RegisterPlugin(pluginName string, endpoint string, versions []string, pluginClientTimeout *time.Duration) error {
klog.V(2).InfoS("Registering plugin at endpoint", "plugin", pluginName, "endpoint", endpoint)
return s.connectClient(pluginName, endpoint)
}
Expand Down
24 changes: 18 additions & 6 deletions pkg/kubelet/cm/dra/manager_test.go
Expand Up @@ -84,7 +84,7 @@ type fakeDRAServerInfo struct {
teardownFn tearDown
}

func setupFakeDRADriverGRPCServer(shouldTimeout bool) (fakeDRAServerInfo, error) {
func setupFakeDRADriverGRPCServer(shouldTimeout bool, pluginClientTimeout *time.Duration) (fakeDRAServerInfo, error) {
socketDir, err := os.MkdirTemp("", "dra")
if err != nil {
return fakeDRAServerInfo{
Expand Down Expand Up @@ -117,7 +117,7 @@ func setupFakeDRADriverGRPCServer(shouldTimeout bool) (fakeDRAServerInfo, error)
driverName: driverName,
}
if shouldTimeout {
timeout := plugin.PluginClientTimeout + time.Second
timeout := *pluginClientTimeout * 2
fakeDRADriverGRPCServer.timeout = &timeout
}

Expand Down Expand Up @@ -758,14 +758,20 @@ func TestPrepareResources(t *testing.T) {
}
}

draServerInfo, err := setupFakeDRADriverGRPCServer(test.wantTimeout)
var pluginClientTimeout *time.Duration
if test.wantTimeout {
timeout := time.Millisecond * 20
pluginClientTimeout = &timeout
}

draServerInfo, err := setupFakeDRADriverGRPCServer(test.wantTimeout, pluginClientTimeout)
if err != nil {
t.Fatal(err)
}
defer draServerInfo.teardownFn()

plg := plugin.NewRegistrationHandler(nil, getFakeNode)
if err := plg.RegisterPlugin(test.driverName, draServerInfo.socketName, []string{"1.27"}); err != nil {
if err := plg.RegisterPlugin(test.driverName, draServerInfo.socketName, []string{"1.27"}, pluginClientTimeout); err != nil {
t.Fatalf("failed to register plugin %s, err: %v", test.driverName, err)
}
defer plg.DeRegisterPlugin(test.driverName) // for sake of next tests
Expand Down Expand Up @@ -1058,14 +1064,20 @@ func TestUnprepareResources(t *testing.T) {
t.Fatalf("failed to create a new instance of the claimInfoCache, err: %v", err)
}

draServerInfo, err := setupFakeDRADriverGRPCServer(test.wantTimeout)
var pluginClientTimeout *time.Duration
if test.wantTimeout {
timeout := time.Millisecond * 20
pluginClientTimeout = &timeout
}

draServerInfo, err := setupFakeDRADriverGRPCServer(test.wantTimeout, pluginClientTimeout)
if err != nil {
t.Fatal(err)
}
defer draServerInfo.teardownFn()

plg := plugin.NewRegistrationHandler(nil, getFakeNode)
if err := plg.RegisterPlugin(test.driverName, draServerInfo.socketName, []string{"1.27"}); err != nil {
if err := plg.RegisterPlugin(test.driverName, draServerInfo.socketName, []string{"1.27"}, pluginClientTimeout); err != nil {
t.Fatalf("failed to register plugin %s, err: %v", test.driverName, err)
}
defer plg.DeRegisterPlugin(test.driverName) // for sake of next tests
Expand Down
4 changes: 2 additions & 2 deletions pkg/kubelet/cm/dra/plugin/client.go
Expand Up @@ -154,7 +154,7 @@ func (p *plugin) NodePrepareResources(
return nil, err
}

ctx, cancel := context.WithTimeout(ctx, PluginClientTimeout)
ctx, cancel := context.WithTimeout(ctx, p.clientTimeout)
defer cancel()

version := p.getVersion()
Expand Down Expand Up @@ -183,7 +183,7 @@ func (p *plugin) NodeUnprepareResources(
return nil, err
}

ctx, cancel := context.WithTimeout(ctx, PluginClientTimeout)
ctx, cancel := context.WithTimeout(ctx, p.clientTimeout)
defer cancel()

version := p.getVersion()
Expand Down
5 changes: 3 additions & 2 deletions pkg/kubelet/cm/dra/plugin/client_test.go
Expand Up @@ -268,8 +268,9 @@ func TestNodeUnprepareResource(t *testing.T) {
defer teardown()

p := &plugin{
endpoint: addr,
version: v1alpha3Version,
endpoint: addr,
version: v1alpha3Version,
clientTimeout: PluginClientTimeout,
}

conn, err := p.getOrCreateGRPCConn()
Expand Down
11 changes: 10 additions & 1 deletion pkg/kubelet/cm/dra/plugin/plugin.go
Expand Up @@ -49,6 +49,7 @@ type plugin struct {
endpoint string
version string
highestSupportedVersion *utilversion.Version
clientTimeout time.Duration
}

func (p *plugin) getOrCreateGRPCConn() (*grpc.ClientConn, error) {
Expand Down Expand Up @@ -116,19 +117,27 @@ func NewRegistrationHandler(kubeClient kubernetes.Interface, getNode func() (*v1
}

// RegisterPlugin is called when a plugin can be registered.
func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string, versions []string) error {
func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string, versions []string, pluginClientTimeout *time.Duration) error {
klog.InfoS("Register new DRA plugin", "name", pluginName, "endpoint", endpoint)

highestSupportedVersion, err := h.validateVersions("RegisterPlugin", pluginName, versions)
if err != nil {
return err
}

var timeout time.Duration
if pluginClientTimeout == nil {
timeout = PluginClientTimeout
} else {
timeout = *pluginClientTimeout
}

pluginInstance := &plugin{
conn: nil,
endpoint: endpoint,
version: v1alpha3Version,
highestSupportedVersion: highestSupportedVersion,
clientTimeout: timeout,
}

// Storing endpoint of newly registered DRA Plugin into the map, where plugin name will be the key
Expand Down
2 changes: 1 addition & 1 deletion pkg/kubelet/cm/dra/plugin/plugin_test.go
Expand Up @@ -56,7 +56,7 @@ func TestRegistrationHandler_ValidatePlugin(t *testing.T) {
description: "plugin already registered with a higher supported version",
handler: func() *RegistrationHandler {
handler := newRegistrationHandler()
if err := handler.RegisterPlugin("this-plugin-already-exists-and-has-a-long-name-so-it-doesnt-collide", "", []string{"v1.1.0"}); err != nil {
if err := handler.RegisterPlugin("this-plugin-already-exists-and-has-a-long-name-so-it-doesnt-collide", "", []string{"v1.1.0"}, nil); err != nil {
t.Fatal(err)
}
return handler
Expand Down
4 changes: 3 additions & 1 deletion pkg/kubelet/pluginmanager/cache/types.go
Expand Up @@ -16,6 +16,8 @@ limitations under the License.

package cache

import "time"

// PluginHandler is an interface that a client of the pluginwatcher API needs to implement in
// order to consume plugins.
// The PluginHandler follows the following simple state machine:
Expand Down Expand Up @@ -51,7 +53,7 @@ type PluginHandler interface {
// RegisterPlugin is called so that the plugin can be registered by any
// plugin consumer.
// Errors encountered here can still be notified to the plugin.
RegisterPlugin(pluginName, endpoint string, versions []string) error
RegisterPlugin(pluginName, endpoint string, versions []string, pluginClientTimeout *time.Duration) error
// DeRegisterPlugin is called once the pluginwatcher observes that the socket has
// been deleted.
DeRegisterPlugin(pluginName string)
Expand Down
Expand Up @@ -121,7 +121,7 @@ func (og *operationGenerator) GenerateRegisterPluginFunc(
if err != nil {
klog.ErrorS(err, "RegisterPlugin error -- failed to add plugin", "path", socketPath)
}
if err := handler.RegisterPlugin(infoResp.Name, infoResp.Endpoint, infoResp.SupportedVersions); err != nil {
if err := handler.RegisterPlugin(infoResp.Name, infoResp.Endpoint, infoResp.SupportedVersions, nil); err != nil {
return og.notifyPlugin(client, false, fmt.Sprintf("RegisterPlugin error -- plugin registration failed with err: %v", err))
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/kubelet/pluginmanager/plugin_manager_test.go
Expand Up @@ -59,7 +59,7 @@ func (f *fakePluginHandler) ValidatePlugin(pluginName string, endpoint string, v
}

// RegisterPlugin is a fake method
func (f *fakePluginHandler) RegisterPlugin(pluginName, endpoint string, versions []string) error {
func (f *fakePluginHandler) RegisterPlugin(pluginName, endpoint string, versions []string, pluginClientTimeout *time.Duration) error {
f.Lock()
defer f.Unlock()
f.events = append(f.events, "register "+pluginName)
Expand Down
2 changes: 1 addition & 1 deletion pkg/kubelet/pluginmanager/reconciler/reconciler_test.go
Expand Up @@ -127,7 +127,7 @@ func (d *DummyImpl) ValidatePlugin(pluginName string, endpoint string, versions
}

// RegisterPlugin is a dummy implementation
func (d *DummyImpl) RegisterPlugin(pluginName string, endpoint string, versions []string) error {
func (d *DummyImpl) RegisterPlugin(pluginName string, endpoint string, versions []string, pluginClientTimeout *time.Duration) error {
return nil
}

Expand Down
11 changes: 9 additions & 2 deletions pkg/volume/csi/csi_plugin.go
Expand Up @@ -109,7 +109,7 @@ func (h *RegistrationHandler) ValidatePlugin(pluginName string, endpoint string,
}

// RegisterPlugin is called when a plugin can be registered
func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string, versions []string) error {
func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string, versions []string, pluginClientTimeout *time.Duration) error {
klog.Infof(log("Register new plugin with name: %s at endpoint: %s", pluginName, endpoint))

highestSupportedVersion, err := h.validateVersions("RegisterPlugin", pluginName, endpoint, versions)
Expand All @@ -130,7 +130,14 @@ func (h *RegistrationHandler) RegisterPlugin(pluginName string, endpoint string,
return err
}

ctx, cancel := context.WithTimeout(context.Background(), csiTimeout)
var timeout time.Duration
if pluginClientTimeout == nil {
timeout = csiTimeout
} else {
timeout = *pluginClientTimeout
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

driverNodeID, maxVolumePerNode, accessibleTopology, err := csi.NodeGetInfo(ctx)
Expand Down

0 comments on commit 227c2e7

Please sign in to comment.