-
Notifications
You must be signed in to change notification settings - Fork 0
/
pull-updater.go
244 lines (212 loc) · 7.65 KB
/
pull-updater.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
package main
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"time"
"github.com/google/go-github/v53/github"
"github.com/gregjones/httpcache"
"github.com/palantir/go-githubapp/githubapp"
"github.com/pkg/errors"
"github.com/rcrowley/go-metrics"
"github.com/rs/zerolog"
"gopkg.in/yaml.v2"
)
// PRBranchUpdateHandler handles GitHub "push" webhook events and brings open
// pull requests up to date with their base branch. It embeds the
// githubapp.ClientCreator used to build per-installation GitHub API clients.
type PRBranchUpdateHandler struct {
	githubapp.ClientCreator

	// preamble is prepended to the comment posted on a pull request when
	// GitHub schedules the branch update instead of completing it immediately.
	preamble string

	// labels lists label names a pull request must all carry to be updated;
	// when empty, no label filtering is applied.
	labels []string
}
// Config is the top-level structure of the YAML configuration file
// (main loads it from "config.yml"): HTTP server settings, GitHub App
// settings, and this application's own options.
type Config struct {
	Server    HTTPConfig          `yaml:"server"`
	Github    githubapp.Config    `yaml:"github"`
	AppConfig MyApplicationConfig `yaml:"app_configuration"`
}
// MyApplicationConfig holds the options specific to this app, read from the
// "app_configuration" section of the configuration file.
type MyApplicationConfig struct {
	// PullRequestPreamble is prepended to the comment posted on a pull request
	// when GitHub schedules its branch update.
	PullRequestPreamble string `yaml:"pull_request_preamble"`

	// PullRequestLabels are label names (compared case-insensitively) that a
	// pull request must all carry to be updated; empty means no filtering.
	PullRequestLabels []string `yaml:"pull_request_labels"`
}
// HTTPConfig holds the listen address of the webhook HTTP server; main joins
// the two fields as "Address:Port".
type HTTPConfig struct {
	Address string `yaml:"address"`
	Port    int    `yaml:"port"`
}
// readConfig loads and decodes the YAML configuration file at path.
//
// Decoding is strict: keys that do not map to a Config field, and duplicate
// keys, are reported as errors so typos in the file do not go unnoticed.
func readConfig(path string) (*Config, error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, errors.Wrapf(readErr, "failed reading server config file: %s", path)
	}

	cfg := new(Config)
	if parseErr := yaml.UnmarshalStrict(raw, cfg); parseErr != nil {
		return nil, errors.Wrap(parseErr, "failed parsing configuration file")
	}
	return cfg, nil
}
// Handles reports the webhook event types the dispatcher should route to this
// handler; only "push" deliveries are of interest.
func (h *PRBranchUpdateHandler) Handles() []string {
	events := []string{"push"}
	return events
}
// hasAllLabels reports whether every name in configLabels matches
// (case-insensitively) at least one label in prLabels. It stops at the first
// missing name. An empty configLabels is trivially satisfied.
func hasAllLabels(configLabels []string, prLabels []*github.Label) bool {
	for i := range configLabels {
		if found := contains(prLabels, configLabels[i]); !found {
			return false
		}
	}
	return true
}
// contains reports whether any label in prLabels is named configLabel.
// Names are compared case-insensitively with strings.EqualFold (Unicode case
// folding). Unlike lowercasing both sides, this allocates nothing and does not
// redo work on the loop-invariant configLabel for every element. Nil entries
// are safe because GetName returns "" for a nil label.
func contains(prLabels []*github.Label, configLabel string) bool {
	for _, prLabel := range prLabels {
		if strings.EqualFold(prLabel.GetName(), configLabel) {
			return true
		}
	}
	return false
}
// Handle is called when the server receives a webhook push event.
//
// If the push was to the repository's default branch, Handle does the following:
//   - lists every open pull request that targets the default branch;
//   - keeps those carrying all configured labels (all of them when no labels
//     are configured);
//   - asks GitHub to update each pull request whose head is behind its base.
//
// When GitHub schedules an update asynchronously, or the update request fails,
// the outcome is posted as a comment on the pull request.
//
// A failure on one pull request is logged and does not stop the remaining ones
// from being processed. Handle then returns an error summarising how many
// failed.
func (h *PRBranchUpdateHandler) Handle(ctx context.Context, eventType, deliveryID string, payload []byte) error {
	// Use the context logger (zerolog falls back to DefaultContextLogger when
	// the context has none) so any request-scoped fields are kept.
	logger := zerolog.Ctx(ctx)

	var pushEvent github.PushEvent
	if err := json.Unmarshal(payload, &pushEvent); err != nil {
		return errors.Wrap(err, "failed to parse push event payload")
	}

	repo := pushEvent.GetRepo()
	repoOwner := repo.GetOwner().GetLogin()
	repoName := repo.GetName()
	repoDefaultBranch := repo.GetDefaultBranch()

	// Only pushes to the default branch matter; bail out before creating an
	// installation client for any other ref.
	if pushEvent.GetRef() != fmt.Sprintf("refs/heads/%s", repoDefaultBranch) {
		return nil
	}

	client, err := h.NewInstallationClient(githubapp.GetInstallationIDFromEvent(&pushEvent))
	if err != nil {
		return err
	}

	logger.Info().Msgf("Getting all open pull requests for %s/%s", repoOwner, repoName)
	pullRequests, err := listOpenPullRequests(ctx, client, repoOwner, repoName, repoDefaultBranch)
	if err != nil {
		return err
	}
	logger.Info().Msgf("Found %d open pull requests", len(pullRequests))

	failed := 0
	for _, pr := range pullRequests {
		prNum := pr.GetNumber()

		if len(h.labels) > 0 {
			logger.Info().Msgf("Checking if pull request %s/%s#%d has the correct labels", repoOwner, repoName, prNum)
			if !hasAllLabels(h.labels, pr.Labels) {
				logger.Info().Msgf("Pull request %s/%s#%d does not have the correct labels", repoOwner, repoName, prNum)
				continue
			}
			logger.Info().Msgf("Pull request %s/%s#%d has the correct labels", repoOwner, repoName, prNum)
		}

		if err := h.updateIfBehind(ctx, logger, client, repoOwner, repoName, repoDefaultBranch, pr); err != nil {
			failed++
			logger.Error().Err(err).Msgf("Failed to process pull request %s/%s#%d", repoOwner, repoName, prNum)
		}
	}

	if failed > 0 {
		return errors.Errorf("failed to process %d of %d pull requests", failed, len(pullRequests))
	}
	return nil
}

// listOpenPullRequests returns every open pull request in owner/repo whose base
// branch is baseBranch, following pagination until the last page.
func listOpenPullRequests(ctx context.Context, client *github.Client, owner, repo, baseBranch string) ([]*github.PullRequest, error) {
	opts := &github.PullRequestListOptions{
		State: "open",
		// Only pull requests targeting the pushed branch can fall behind it.
		Base:        baseBranch,
		ListOptions: github.ListOptions{PerPage: 100},
	}

	var all []*github.PullRequest
	for {
		page, resp, err := client.PullRequests.List(ctx, owner, repo, opts)
		if err != nil {
			return nil, err
		}
		all = append(all, page...)
		if resp.NextPage == 0 {
			return all, nil
		}
		opts.Page = resp.NextPage
	}
}

// updateIfBehind compares pr's head with its base branch. When the head is
// behind, it asks GitHub to update the pull request branch:
//   - If GitHub schedules the update, a comment made of h.preamble and
//     GitHub's message is posted.
//   - If the update request fails, a comment describing the failure is posted.
//
// It returns an error only when the comparison or posting the comment fails.
func (h *PRBranchUpdateHandler) updateIfBehind(ctx context.Context, logger *zerolog.Logger, client *github.Client, owner, repo, defaultBranch string, pr *github.PullRequest) error {
	prNum := pr.GetNumber()
	headRef := pr.GetHead().GetRef()
	baseRef := pr.GetBase().GetRef()

	// The head label ("owner:branch") also resolves branches that live in a
	// fork, which a bare ref does not; fall back to the ref if it is missing.
	head := pr.GetHead().GetLabel()
	if head == "" {
		head = headRef
	}

	comparison, _, err := client.Repositories.CompareCommits(ctx, owner, repo, baseRef, head, nil)
	if err != nil {
		// A failed comparison must not be mistaken for "up to date".
		return errors.Wrapf(err, "failed to compare %s...%s", baseRef, head)
	}

	behindBy := comparison.GetBehindBy()
	if behindBy < 1 {
		logger.Info().Msgf("Pull request %s/%s#%d on branch %s is up to date with default branch %s", owner, repo, prNum, headRef, defaultBranch)
		return nil
	}

	logger.Info().Msgf("Pull request %s/%s#%d is behind default branch %s by %d commits", owner, repo, prNum, defaultBranch, behindBy)
	updateResponse, _, err := client.PullRequests.UpdateBranch(ctx, owner, repo, prNum, nil)
	if err == nil {
		logger.Info().Msgf("Updated pull request %s/%s#%d. Message: %s", owner, repo, prNum, updateResponse.GetMessage())
		return nil
	}

	var body string
	var accepted *github.AcceptedError
	if errors.As(err, &accepted) {
		// GitHub answers 202 Accepted and performs the update asynchronously.
		// go-github reports that as *AcceptedError with a nil typed response,
		// so GitHub's message must be decoded from the raw body.
		logger.Info().Msgf("Job scheduled on GitHub side")
		message := acceptedUpdateMessage(accepted)
		logger.Info().Msgf("Updated pull request %s/%s#%d. Message: %s", owner, repo, prNum, message)
		body = fmt.Sprintf("%s\n\n%s", h.preamble, message)
	} else {
		logger.Warn().Err(err).Msgf("Failed to update pull request %s/%s#%d", owner, repo, prNum)
		body = fmt.Sprintf("Failed to update pull request. Error: %s", err.Error())
	}

	logger.Info().Msgf("Commenting on pull request %s/%s#%d", owner, repo, prNum)
	if _, _, err := client.Issues.CreateComment(ctx, owner, repo, prNum, &github.IssueComment{Body: &body}); err != nil {
		return errors.Wrapf(err, "failed to comment on pull request %s/%s#%d", owner, repo, prNum)
	}
	return nil
}

// acceptedUpdateMessage extracts the "message" field from the body GitHub sent
// with a 202 Accepted update-branch response. It returns "" when the body is
// empty or cannot be decoded.
func acceptedUpdateMessage(accepted *github.AcceptedError) string {
	var resp github.PullRequestBranchUpdateResponse
	if err := json.Unmarshal(accepted.Raw, &resp); err != nil {
		return ""
	}
	return resp.GetMessage()
}
func main() {
// Read the configuration file
config, err := readConfig("config.yml")
if err != nil {
panic(err)
}
// Create the logger
logger := zerolog.New(os.Stdout).With().Timestamp().Logger()
zerolog.DefaultContextLogger = &logger
// Create the metrics registry
metricsRegistry := metrics.DefaultRegistry
// Create the GitHub App client creator
cc, err := githubapp.NewDefaultCachingClientCreator(
config.Github,
githubapp.WithClientUserAgent("pr-updater-app/1.0.0"),
githubapp.WithClientTimeout(3*time.Second),
githubapp.WithClientCaching(true, func() httpcache.Cache { return httpcache.NewMemoryCache() }),
githubapp.WithClientMiddleware(
githubapp.ClientMetrics(metricsRegistry),
),
)
if err != nil {
panic(err)
}
// Create the HTTP handler
prBranchUpdateHandler := &PRBranchUpdateHandler{
ClientCreator: cc,
preamble: config.AppConfig.PullRequestPreamble,
labels: config.AppConfig.PullRequestLabels,
}
webhookHandler := githubapp.NewDefaultEventDispatcher(config.Github, prBranchUpdateHandler)
// Create the HTTP server
http.Handle(githubapp.DefaultWebhookRoute, webhookHandler)
addr := fmt.Sprintf("%s:%d", config.Server.Address, config.Server.Port)
logger.Info().Msgf("Starting server on %s...", addr)
// Start the HTTP server
err = http.ListenAndServe(addr, nil)
if err != nil {
panic(err)
}
}