diff --git a/datasource.go b/datasource.go index 4bfb818..d9d14d7 100644 --- a/datasource.go +++ b/datasource.go @@ -227,11 +227,10 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, if errors.Is(err, ErrorQuery) && !errors.Is(err, context.DeadlineExceeded) { for i := 0; i < ds.driverSettings.Retries; i++ { backend.Logger.Warn(fmt.Sprintf("query failed. retrying %d times", i)) - db, err := ds.c.Connect(dbConn.settings, q.ConnectionArgs) + db, err := ds.dbReconnect(dbConn, q, cacheKey) if err != nil { return nil, err } - ds.storeDBConnection(cacheKey, dbConnection{db, dbConn.settings}) if ds.driverSettings.Pause > 0 { time.Sleep(time.Duration(ds.driverSettings.Pause * int(time.Second))) @@ -247,11 +246,10 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, if errors.Is(err, context.DeadlineExceeded) { for i := 0; i < ds.driverSettings.Retries; i++ { backend.Logger.Warn(fmt.Sprintf("connection timed out. retrying %d times", i)) - db, err := ds.c.Connect(dbConn.settings, q.ConnectionArgs) + db, err := ds.dbReconnect(dbConn, q, cacheKey) if err != nil { continue } - ds.storeDBConnection(cacheKey, dbConnection{db, dbConn.settings}) res, err = QueryDB(ctx, db, ds.c.Converters(), fillMode, q) if err == nil { @@ -263,6 +261,19 @@ func (ds *SQLDatasource) handleQuery(ctx context.Context, req backend.DataQuery, return nil, err } +func (ds *SQLDatasource) dbReconnect(dbConn dbConnection, q *Query, cacheKey string) (*sql.DB, error) { + if err := dbConn.db.Close(); err != nil { + backend.Logger.Warn(fmt.Sprintf("closing existing connection failed: %s", err.Error())) + } + + db, err := ds.c.Connect(dbConn.settings, q.ConnectionArgs) + if err != nil { + return nil, err + } + ds.storeDBConnection(cacheKey, dbConnection{db, dbConn.settings}) + return db, nil +} + // CheckHealth pings the connected SQL database func (ds *SQLDatasource) CheckHealth(ctx context.Context, req *backend.CheckHealthRequest) (*backend.CheckHealthResult, error) { key := defaultKey(getDatasourceUID(*req.PluginContext.DataSourceInstanceSettings)) diff --git a/datasource_test.go b/datasource_test.go index ea75cde..9615b9d 100644 --- a/datasource_test.go +++ b/datasource_test.go @@ -18,13 +18,13 @@ import ( ) type fakeDriver struct { - db *sql.DB + openDBfn func() (*sql.DB, error) Driver } func (d fakeDriver) Connect(backend.DataSourceInstanceSettings, json.RawMessage) (db *sql.DB, err error) { - return d.db, nil + return d.openDBfn() } func (d fakeDriver) Macros() Macros { @@ -41,7 +41,7 @@ func Test_getDBConnectionFromQuery(t *testing.T) { db := &sql.DB{} db2 := &sql.DB{} db3 := &sql.DB{} - d := &fakeDriver{db: db3} + d := &fakeDriver{openDBfn: func() (*sql.DB, error) { return db3, nil }} tests := []struct { desc string dsUID string @@ -144,7 +144,7 @@ func Test_timeout_retries(t *testing.T) { t.Errorf("failed to connect to mock driver: %v", err) } timeoutDriver := fakeDriver{ - db: db, + openDBfn: func() (*sql.DB, error) { return db, nil }, } retries := 5 max := time.Duration(testTimeout) * time.Second @@ -178,12 +178,15 @@ func Test_error_retries(t *testing.T) { } mockDriver := "sqlmock-error" mock.RegisterDriver(mockDriver, handler) - db, err := sql.Open(mockDriver, "") - if err != nil { - t.Errorf("failed to connect to mock driver: %v", err) - } + timeoutDriver := fakeDriver{ - db: db, + openDBfn: func() (*sql.DB, error) { + db, err := sql.Open(mockDriver, "") + if err != nil { + t.Errorf("failed to connect to mock driver: %v", err) + } + return db, nil + }, } retries := 5 max := time.Duration(10) * time.Second @@ -192,6 +195,7 @@ func Test_error_retries(t *testing.T) { key := defaultKey(dsUID) // Add the mandatory default db + db, _ := timeoutDriver.Connect(settings, nil) ds.storeDBConnection(key, dbConnection{db, settings}) ctx := context.Background()