diff --git a/internal/config/config.go b/internal/config/config.go index 64bad81c..67f7d919 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,6 +18,7 @@ const ( // Google Secret Manager googleSecretManagerProject = "GOOGLE_SECRET_MANAGER_PROJECT" + trustedProxies = "TRUSTED_PROXIES" // general settings appURL = "APP_URL" @@ -125,7 +126,7 @@ func getSecretManagerValue(name string, defaultValue string) string { return defaultValue } - log.Printf("Reading '%s' from Secret Manager\n", name) + log.Printf("Reading '%s' from Secret Manager", name) // Create the client. ctx := context.Background() @@ -151,6 +152,10 @@ func getSecretManagerValue(name string, defaultValue string) string { return string(result.Payload.Data) } +func TrustedProxies() []string { + return viper.GetStringSlice(trustedProxies) +} + func AppURL() string { return viper.GetString(appURL) } diff --git a/internal/page/session_handler.go b/internal/page/session_handler.go index 932b8379..152106d3 100644 --- a/internal/page/session_handler.go +++ b/internal/page/session_handler.go @@ -135,6 +135,12 @@ func (h *sessionHandler) executeBatch(commands []*command.Command) (results []st return nil, err } messages = append(messages, NewMessage("", UpdateControlPropsAction, payload)) + } else if cmdName == command.Get { + value, err := h.get(cmd) + if err != nil { + return nil, err + } + results = append(results, value) } else if cmdName == command.Clean { payload, err := h.cleanWithMessage(cmd) if err != nil { diff --git a/internal/server/client_ip.go b/internal/server/client_ip.go new file mode 100644 index 00000000..29171dac --- /dev/null +++ b/internal/server/client_ip.go @@ -0,0 +1,128 @@ +package server + +import ( + "net" + "strings" + + "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" +) + +var ( + trustedCIDRs []*net.IPNet + remoteIPHeaders = []string{"X-Forwarded-For", "X-Real-IP"} +) + +// ClientIP implements a best effort algorithm to return the real client IP. +// It called c.RemoteIP() under the hood, to check if the remote IP is a trusted proxy or not. +// If it's it will then try to parse the headers defined in Engine.RemoteIPHeaders (defaulting to [X-Forwarded-For, X-Real-Ip]). +// If the headers are nots syntactically valid OR the remote IP does not correspong to a trusted proxy, +// the remote IP (coming form Request.RemoteAddr) is returned. +func ClientIP(c *gin.Context) string { + remoteIP, trusted := RemoteIP(c) + if remoteIP == nil { + return "" + } + + if trusted && remoteIPHeaders != nil { + for _, headerName := range remoteIPHeaders { + ip, valid := validateHeader(c.Request.Header.Get(headerName)) + if valid { + return ip + } + } + } + return remoteIP.String() +} + +// RemoteIP parses the IP from Request.RemoteAddr, normalizes and returns the IP (without the port). +// It also checks if the remoteIP is a trusted proxy or not. +// In order to perform this validation, it will see if the IP is contained within at least one of the CIDR blocks +// defined in Engine.TrustedProxies +func RemoteIP(c *gin.Context) (net.IP, bool) { + ip, _, err := net.SplitHostPort(strings.TrimSpace(c.Request.RemoteAddr)) + if err != nil { + return nil, false + } + remoteIP := net.ParseIP(ip) + if remoteIP == nil { + return nil, false + } + + if trustedCIDRs != nil { + for _, cidr := range trustedCIDRs { + if cidr.Contains(remoteIP) { + return remoteIP, true + } + } + } + + return remoteIP, false +} + +func validateHeader(header string) (clientIP string, valid bool) { + if header == "" { + return "", false + } + items := strings.Split(header, ",") + for i, ipStr := range items { + ipStr = strings.TrimSpace(ipStr) + ip := net.ParseIP(ipStr) + if ip == nil { + return "", false + } + + // We need to return the first IP in the list, but, + // we should not early return since we need to validate that + // the rest of the header is syntactically valid + if i == 0 { + clientIP = ipStr + valid = true + } + } + return +} + +func prepareTrustedCIDRs(engine *gin.Engine) { + if engine.TrustedProxies == nil { + return + } + + trustedCIDRs = make([]*net.IPNet, 0, len(engine.TrustedProxies)) + for _, trustedProxy := range engine.TrustedProxies { + if !strings.Contains(trustedProxy, "/") { + ip := parseIP(trustedProxy) + if ip == nil { + log.Errorf("error parsing IP: %s", trustedProxy) + return + } + + switch len(ip) { + case net.IPv4len: + trustedProxy += "/32" + case net.IPv6len: + trustedProxy += "/128" + } + } + _, cidrNet, err := net.ParseCIDR(trustedProxy) + if err != nil { + log.Errorf("error parsing CIDR: %s", err) + return + } + trustedCIDRs = append(trustedCIDRs, cidrNet) + } +} + +// parseIP parse a string representation of an IP and returns a net.IP with the +// minimum byte representation or nil if input is invalid. +func parseIP(ip string) net.IP { + parsedIP := net.ParseIP(ip) + + if ipv4 := parsedIP.To4(); ipv4 != nil { + // return ip in a 4-byte representation + return ipv4 + } + + // return ip in a 16-byte representation or nil + return parsedIP +} diff --git a/internal/server/oauth_handler.go b/internal/server/oauth_handler.go index f055d5cd..af0dbd04 100644 --- a/internal/server/oauth_handler.go +++ b/internal/server/oauth_handler.go @@ -94,7 +94,7 @@ func oauthHandler(c *gin.Context, authProvider string) { } // create new principal and update its details from API - principal := auth.NewPrincipal(authProvider, c.ClientIP(), c.Request.UserAgent(), state.GroupsEnabled) + principal := auth.NewPrincipal(authProvider, ClientIP(c), c.Request.UserAgent(), state.GroupsEnabled) principal.SetToken(token) err = principal.UpdateDetails() diff --git a/internal/server/server.go b/internal/server/server.go index 6d224058..92a803e8 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -74,6 +74,12 @@ func Start(ctx context.Context, wg *sync.WaitGroup, serverPort int) { // Set the router as the default one shipped with Gin router := gin.Default() + if config.TrustedProxies() != nil && len(config.TrustedProxies()) > 0 { + router.TrustedProxies = config.TrustedProxies() + log.Println("Trusted proxies:", router.TrustedProxies) + prepareTrustedCIDRs(router) + } + // force SSL if config.ForceSSL() { router.Use(secure.Secure(secure.Options{ @@ -184,7 +190,7 @@ func websocketHandler(c *gin.Context) { } wsc := page_connection.NewWebSocket(conn) - page.NewClient(wsc, c.ClientIP(), principal) + page.NewClient(wsc, ClientIP(c), principal) } func getSecurityPrincipal(c *gin.Context) (*auth.SecurityPrincipal, error) { @@ -198,7 +204,7 @@ func getSecurityPrincipal(c *gin.Context) (*auth.SecurityPrincipal, error) { principal = store.GetSecurityPrincipal(principalID) if principal == nil { return nil, nil - } else if principal.ClientIP != c.ClientIP() || principal.UserAgentHash != utils.SHA1(c.Request.UserAgent()) { + } else if principal.ClientIP != ClientIP(c) || principal.UserAgentHash != utils.SHA1(c.Request.UserAgent()) { log.Errorln("Principal not found or its IP address or User Agent do not match") store.DeleteSecurityPrincipal(principalID) } else { diff --git a/internal/store/store.go b/internal/store/store.go index 19928ac9..63543d69 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -133,21 +133,24 @@ func DeleteExpiredClient(clientID string) []string { cache.SetRemove(fmt.Sprintf(sessionWebClientsKey, pageID, sessionID), clientID) cache.SetRemove(fmt.Sprintf(pageHostClientsKey, pageID), clientID) - for _, sessionID := range GetPageHostClientSessions(pageID, clientID) { - RemoveSessionHostClient(pageID, sessionID, clientID) - - sessionClients := GetSessionWebClients(pageID, sessionID) - for _, clientID := range sessionClients { - clients = append(clients, clientID) - RemoveSessionWebClient(pageID, sessionID, clientID) + page := GetPageByID(pageID) + if page.IsApp { + for _, sessionID := range GetPageHostClientSessions(pageID, clientID) { + RemoveSessionHostClient(pageID, sessionID, clientID) + + sessionClients := GetSessionWebClients(pageID, sessionID) + for _, clientID := range sessionClients { + clients = append(clients, clientID) + RemoveSessionWebClient(pageID, sessionID, clientID) + } + + DeleteSession(pageID, sessionID) } + RemovePageHostClientSessions(pageID, clientID) - DeleteSession(pageID, sessionID) - } - RemovePageHostClientSessions(pageID, clientID) - - if len(GetPageHostClients(pageID)) == 0 { - DeletePage(pageID) + if len(GetPageHostClients(pageID)) == 0 { + DeletePage(pageID) + } } } cache.Remove(fmt.Sprintf(clientSessionsKey, clientID))