Skip to content

Commit

Permalink
Merge pull request #99858 from freehan/firewall-fix
Browse files Browse the repository at this point in the history
Revert "Revert "fix a bug where only service with less than 100 ports can have GCE lo…
  • Loading branch information
k8s-ci-robot committed Mar 8, 2021
2 parents 2d3acce + e3a5347 commit 97cd5bb
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (g *Cloud) ensureExternalLoadBalancer(clusterName string, clusterID string,
return nil, err
}

firewallExists, firewallNeedsUpdate, err := g.firewallNeedsUpdate(loadBalancerName, serviceName.String(), g.region, ipAddressToUse, ports, sourceRanges)
firewallExists, firewallNeedsUpdate, err := g.firewallNeedsUpdate(loadBalancerName, serviceName.String(), ipAddressToUse, ports, sourceRanges)
if err != nil {
return nil, err
}
Expand All @@ -181,13 +181,13 @@ func (g *Cloud) ensureExternalLoadBalancer(clusterName string, clusterID string,
// without needing to be deleted and recreated.
if firewallExists {
klog.Infof("ensureExternalLoadBalancer(%s): Updating firewall.", lbRefStr)
if err := g.updateFirewall(apiService, MakeFirewallName(loadBalancerName), g.region, desc, sourceRanges, ports, hosts); err != nil {
if err := g.updateFirewall(apiService, MakeFirewallName(loadBalancerName), desc, sourceRanges, ports, hosts); err != nil {
return nil, err
}
klog.Infof("ensureExternalLoadBalancer(%s): Updated firewall.", lbRefStr)
} else {
klog.Infof("ensureExternalLoadBalancer(%s): Creating firewall.", lbRefStr)
if err := g.createFirewall(apiService, MakeFirewallName(loadBalancerName), g.region, desc, sourceRanges, ports, hosts); err != nil {
if err := g.createFirewall(apiService, MakeFirewallName(loadBalancerName), desc, sourceRanges, ports, hosts); err != nil {
return nil, err
}
klog.Infof("ensureExternalLoadBalancer(%s): Created firewall.", lbRefStr)
Expand Down Expand Up @@ -845,7 +845,7 @@ func translateAffinityType(affinityType v1.ServiceAffinity) string {
}
}

func (g *Cloud) firewallNeedsUpdate(name, serviceName, region, ipAddress string, ports []v1.ServicePort, sourceRanges utilnet.IPNetSet) (exists bool, needsUpdate bool, err error) {
func (g *Cloud) firewallNeedsUpdate(name, serviceName, ipAddress string, ports []v1.ServicePort, sourceRanges utilnet.IPNetSet) (exists bool, needsUpdate bool, err error) {
fw, err := g.GetFirewall(MakeFirewallName(name))
if err != nil {
if isHTTPErrorCode(err, http.StatusNotFound) {
Expand All @@ -860,15 +860,15 @@ func (g *Cloud) firewallNeedsUpdate(name, serviceName, region, ipAddress string,
return true, true, nil
}
// Make sure the allowed ports match.
allowedPorts := make([]string, len(ports))
for ix := range ports {
allowedPorts[ix] = strconv.Itoa(int(ports[ix].Port))
}
if !equalStringSets(allowedPorts, fw.Allowed[0].Ports) {
portNums, portRanges, _ := getPortsAndProtocol(ports)
// This logic checks if the existing firewall rules contains either enumerated service ports or port ranges.
// This is to prevent unnecessary noop updates to the firewall rule when the existing firewall rule is
// set up via the previous pattern using enumerated ports instead of port ranges.
if !equalStringSets(portNums, fw.Allowed[0].Ports) && !equalStringSets(portRanges, fw.Allowed[0].Ports) {
return true, true, nil
}
// The service controller already verified that the protocol matches on all ports, no need to check.

// The service controller already verified that the protocol matches on all ports, no need to check.
actualSourceRanges, err := utilnet.ParseIPNets(fw.SourceRanges...)
if err != nil {
// This really shouldn't happen... GCE has returned something unexpected
Expand Down Expand Up @@ -899,7 +899,7 @@ func (g *Cloud) ensureHTTPHealthCheckFirewall(svc *v1.Service, serviceName, ipAd
return fmt.Errorf("error getting firewall for health checks: %v", err)
}
klog.Infof("Creating firewall %v for health checks.", fwName)
if err := g.createFirewall(svc, fwName, region, desc, sourceRanges, ports, hosts); err != nil {
if err := g.createFirewall(svc, fwName, desc, sourceRanges, ports, hosts); err != nil {
return err
}
klog.Infof("Created firewall %v for health checks.", fwName)
Expand All @@ -912,7 +912,7 @@ func (g *Cloud) ensureHTTPHealthCheckFirewall(svc *v1.Service, serviceName, ipAd
!equalStringSets(fw.Allowed[0].Ports, []string{strconv.Itoa(int(ports[0].Port))}) ||
!equalStringSets(fw.SourceRanges, sourceRanges.StringSlice()) {
klog.Warningf("Firewall %v exists but parameters have drifted - updating...", fwName)
if err := g.updateFirewall(svc, fwName, region, desc, sourceRanges, ports, hosts); err != nil {
if err := g.updateFirewall(svc, fwName, desc, sourceRanges, ports, hosts); err != nil {
klog.Warningf("Failed to reconcile firewall %v parameters.", fwName)
return err
}
Expand Down Expand Up @@ -948,8 +948,8 @@ func createForwardingRule(s CloudForwardingRuleService, name, serviceName, regio
return nil
}

func (g *Cloud) createFirewall(svc *v1.Service, name, region, desc string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) error {
firewall, err := g.firewallObject(name, region, desc, sourceRanges, ports, hosts)
func (g *Cloud) createFirewall(svc *v1.Service, name, desc string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) error {
firewall, err := g.firewallObject(name, desc, sourceRanges, ports, hosts)
if err != nil {
return err
}
Expand All @@ -966,8 +966,8 @@ func (g *Cloud) createFirewall(svc *v1.Service, name, region, desc string, sourc
return nil
}

func (g *Cloud) updateFirewall(svc *v1.Service, name, region, desc string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) error {
firewall, err := g.firewallObject(name, region, desc, sourceRanges, ports, hosts)
func (g *Cloud) updateFirewall(svc *v1.Service, name, desc string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) error {
firewall, err := g.firewallObject(name, desc, sourceRanges, ports, hosts)
if err != nil {
return err
}
Expand All @@ -985,11 +985,11 @@ func (g *Cloud) updateFirewall(svc *v1.Service, name, region, desc string, sourc
return nil
}

func (g *Cloud) firewallObject(name, region, desc string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) (*compute.Firewall, error) {
allowedPorts := make([]string, len(ports))
for ix := range ports {
allowedPorts[ix] = strconv.Itoa(int(ports[ix].Port))
}
func (g *Cloud) firewallObject(name, desc string, sourceRanges utilnet.IPNetSet, ports []v1.ServicePort, hosts []*gceInstance) (*compute.Firewall, error) {
// Concatenate service ports into port ranges. This help to workaround the gce firewall limitation where only
// 100 ports or port ranges can be used in a firewall rule.
_, portRanges, _ := getPortsAndProtocol(ports)

// If the node tags to be used for this cluster have been predefined in the
// provider config, just use them. Otherwise, invoke computeHostTags method to get the tags.
hostTags := g.nodeTags
Expand All @@ -1014,7 +1014,7 @@ func (g *Cloud) firewallObject(name, region, desc string, sourceRanges utilnet.I
// mixed TCP and UDP ports. It should be possible to use a
// single firewall rule for both a TCP and UDP lb.
IPProtocol: strings.ToLower(string(ports[0].Protocol)),
Ports: allowedPorts,
Ports: portRanges,
},
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ package gce
import (
"context"
"fmt"

"reflect"
"strings"
"testing"

Expand All @@ -33,6 +35,9 @@ import (
"github.com/GoogleCloudPlatform/k8s-cloud-provider/pkg/cloud/mock"
"k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/client-go/tools/record"
utilnet "k8s.io/utils/net"
)
Expand Down Expand Up @@ -681,11 +686,20 @@ func TestFirewallNeedsUpdate(t *testing.T) {
gce, err := fakeGCECloud(DefaultTestClusterValues())
require.NoError(t, err)
svc := fakeLoadbalancerService("")
svc.Spec.Ports = []v1.ServicePort{
{Name: "port1", Protocol: v1.ProtocolTCP, Port: int32(80), TargetPort: intstr.FromInt(80)},
{Name: "port2", Protocol: v1.ProtocolTCP, Port: int32(81), TargetPort: intstr.FromInt(81)},
{Name: "port3", Protocol: v1.ProtocolTCP, Port: int32(82), TargetPort: intstr.FromInt(82)},
{Name: "port4", Protocol: v1.ProtocolTCP, Port: int32(84), TargetPort: intstr.FromInt(84)},
{Name: "port5", Protocol: v1.ProtocolTCP, Port: int32(85), TargetPort: intstr.FromInt(85)},
{Name: "port6", Protocol: v1.ProtocolTCP, Port: int32(86), TargetPort: intstr.FromInt(86)},
{Name: "port7", Protocol: v1.ProtocolTCP, Port: int32(88), TargetPort: intstr.FromInt(87)},
}

status, err := createExternalLoadBalancer(gce, svc, []string{"test-node-1"}, vals.ClusterName, vals.ClusterID, vals.ZoneName)
require.NotNil(t, status)
require.NoError(t, err)
svcName := "/" + svc.ObjectMeta.Name
region := vals.Region

ipAddr := status.Ingress[0].IP
lbName := gce.GetLoadBalancerName(context.TODO(), "", svc)
Expand Down Expand Up @@ -795,6 +809,78 @@ func TestFirewallNeedsUpdate(t *testing.T) {
needsUpdate: false,
hasErr: false,
},
"Backward compatible with previous firewall setup with enumerated ports": {
lbName: lbName,
ipAddr: ipAddr,
ports: svc.Spec.Ports,
ipnet: ipnet,
fwIPProtocol: "tcp",
getHook: func(ctx context.Context, key *meta.Key, m *cloud.MockFirewalls) (bool, *compute.Firewall, error) {
obj, ok := m.Objects[*key]
if !ok {
return false, nil, nil
}
fw, err := copyFirewallObj(obj.Obj.(*compute.Firewall))
if err != nil {
return true, nil, err
}
// enumerate the service ports in the firewall rule
fw.Allowed[0].Ports = []string{"80", "81", "82", "84", "85", "86", "88"}
return true, fw, nil
},
sourceRange: fw.SourceRanges[0],
exists: true,
needsUpdate: false,
hasErr: false,
},
"need to update previous firewall setup with enumerated ports ": {
lbName: lbName,
ipAddr: ipAddr,
ports: svc.Spec.Ports,
ipnet: ipnet,
fwIPProtocol: "tcp",
getHook: func(ctx context.Context, key *meta.Key, m *cloud.MockFirewalls) (bool, *compute.Firewall, error) {
obj, ok := m.Objects[*key]
if !ok {
return false, nil, nil
}
fw, err := copyFirewallObj(obj.Obj.(*compute.Firewall))
if err != nil {
return true, nil, err
}
// enumerate the service ports in the firewall rule
fw.Allowed[0].Ports = []string{"80", "81", "82", "84", "85", "86"}
return true, fw, nil
},
sourceRange: fw.SourceRanges[0],
exists: true,
needsUpdate: true,
hasErr: false,
},
"need to update port-ranges ": {
lbName: lbName,
ipAddr: ipAddr,
ports: svc.Spec.Ports,
ipnet: ipnet,
fwIPProtocol: "tcp",
getHook: func(ctx context.Context, key *meta.Key, m *cloud.MockFirewalls) (bool, *compute.Firewall, error) {
obj, ok := m.Objects[*key]
if !ok {
return false, nil, nil
}
fw, err := copyFirewallObj(obj.Obj.(*compute.Firewall))
if err != nil {
return true, nil, err
}
// enumerate the service ports in the firewall rule
fw.Allowed[0].Ports = []string{"80-82", "86"}
return true, fw, nil
},
sourceRange: fw.SourceRanges[0],
exists: true,
needsUpdate: true,
hasErr: false,
},
} {
t.Run(desc, func(t *testing.T) {
fw, err = gce.GetFirewall(MakeFirewallName(tc.lbName))
Expand All @@ -813,11 +899,9 @@ func TestFirewallNeedsUpdate(t *testing.T) {
exists, needsUpdate, err := gce.firewallNeedsUpdate(
tc.lbName,
svcName,
region,
tc.ipAddr,
tc.ports,
tc.ipnet)

assert.Equal(t, tc.exists, exists, "'exists' didn't return as expected "+desc)
assert.Equal(t, tc.needsUpdate, needsUpdate, "'needsUpdate' didn't return as expected "+desc)
if tc.hasErr {
Expand Down Expand Up @@ -947,7 +1031,6 @@ func TestCreateAndUpdateFirewallSucceedsOnXPN(t *testing.T) {
gce.createFirewall(
svc,
gce.GetLoadBalancerName(context.TODO(), "", svc),
gce.region,
"A sad little firewall",
ipnet,
svc.Spec.Ports,
Expand All @@ -960,7 +1043,6 @@ func TestCreateAndUpdateFirewallSucceedsOnXPN(t *testing.T) {
gce.updateFirewall(
svc,
gce.GetLoadBalancerName(context.TODO(), "", svc),
gce.region,
"A sad little firewall",
ipnet,
svc.Spec.Ports,
Expand Down Expand Up @@ -1262,3 +1344,129 @@ func TestNeedToUpdateHttpHealthChecks(t *testing.T) {
})
}
}

func TestFirewallObject(t *testing.T) {
t.Parallel()
vals := DefaultTestClusterValues()
gce, err := fakeGCECloud(vals)
gce.nodeTags = []string{"node-tags"}
require.NoError(t, err)
srcRanges := []string{"10.10.0.0/24", "10.20.0.0/24"}
sourceRanges, _ := utilnet.ParseIPNets(srcRanges...)
fwName := "test-fw"
fwDesc := "test-desc"
baseFw := compute.Firewall{
Name: fwName,
Description: fwDesc,
Network: gce.networkURL,
SourceRanges: []string{},
TargetTags: gce.nodeTags,
Allowed: []*compute.FirewallAllowed{
{
IPProtocol: "tcp",
Ports: []string{"80"},
},
},
}

for _, tc := range []struct {
desc string
sourceRanges utilnet.IPNetSet
svcPorts []v1.ServicePort
expectedFirewall func(fw compute.Firewall) compute.Firewall
}{
{
desc: "empty source ranges",
sourceRanges: utilnet.IPNetSet{},
svcPorts: []v1.ServicePort{
{Name: "port1", Protocol: v1.ProtocolTCP, Port: int32(80), TargetPort: intstr.FromInt(80)},
},
expectedFirewall: func(fw compute.Firewall) compute.Firewall {
return fw
},
},
{
desc: "has source ranges",
sourceRanges: sourceRanges,
svcPorts: []v1.ServicePort{
{Name: "port1", Protocol: v1.ProtocolTCP, Port: int32(80), TargetPort: intstr.FromInt(80)},
},
expectedFirewall: func(fw compute.Firewall) compute.Firewall {
fw.SourceRanges = srcRanges
return fw
},
},
{
desc: "has multiple ports",
sourceRanges: sourceRanges,
svcPorts: []v1.ServicePort{
{Name: "port1", Protocol: v1.ProtocolTCP, Port: int32(80), TargetPort: intstr.FromInt(80)},
{Name: "port2", Protocol: v1.ProtocolTCP, Port: int32(82), TargetPort: intstr.FromInt(82)},
{Name: "port3", Protocol: v1.ProtocolTCP, Port: int32(84), TargetPort: intstr.FromInt(84)},
},
expectedFirewall: func(fw compute.Firewall) compute.Firewall {
fw.Allowed = []*compute.FirewallAllowed{
{
IPProtocol: "tcp",
Ports: []string{"80", "82", "84"},
},
}
fw.SourceRanges = srcRanges
return fw
},
},
{
desc: "has multiple ports",
sourceRanges: sourceRanges,
svcPorts: []v1.ServicePort{
{Name: "port1", Protocol: v1.ProtocolTCP, Port: int32(80), TargetPort: intstr.FromInt(80)},
{Name: "port2", Protocol: v1.ProtocolTCP, Port: int32(81), TargetPort: intstr.FromInt(81)},
{Name: "port3", Protocol: v1.ProtocolTCP, Port: int32(82), TargetPort: intstr.FromInt(82)},
{Name: "port4", Protocol: v1.ProtocolTCP, Port: int32(84), TargetPort: intstr.FromInt(84)},
{Name: "port5", Protocol: v1.ProtocolTCP, Port: int32(85), TargetPort: intstr.FromInt(85)},
{Name: "port6", Protocol: v1.ProtocolTCP, Port: int32(86), TargetPort: intstr.FromInt(86)},
{Name: "port7", Protocol: v1.ProtocolTCP, Port: int32(88), TargetPort: intstr.FromInt(87)},
},
expectedFirewall: func(fw compute.Firewall) compute.Firewall {
fw.Allowed = []*compute.FirewallAllowed{
{
IPProtocol: "tcp",
Ports: []string{"80-82", "84-86", "88"},
},
}
fw.SourceRanges = srcRanges
return fw
},
},
} {
t.Run(tc.desc, func(t *testing.T) {
ret, err := gce.firewallObject(fwName, fwDesc, tc.sourceRanges, tc.svcPorts, nil)
require.NoError(t, err)
expectedFirewall := tc.expectedFirewall(baseFw)
retSrcRanges := sets.NewString(ret.SourceRanges...)
expectSrcRanges := sets.NewString(expectedFirewall.SourceRanges...)
if !expectSrcRanges.Equal(retSrcRanges) {
t.Errorf("expect firewall source ranges to be %v, but got %v", expectSrcRanges, retSrcRanges)
}
ret.SourceRanges = nil
expectedFirewall.SourceRanges = nil
if !reflect.DeepEqual(*ret, expectedFirewall) {
t.Errorf("expect firewall to be %+v, but got %+v", expectedFirewall, ret)
}
})
}
}

func copyFirewallObj(firewall *compute.Firewall) (*compute.Firewall, error) {
// make a copy of the original obj via json marshal and unmarshal
jsonObj, err := firewall.MarshalJSON()
if err != nil {
return nil, err
}
var fw compute.Firewall
err = json.Unmarshal(jsonObj, &fw)
if err != nil {
return nil, err
}
return &fw, nil
}

0 comments on commit 97cd5bb

Please sign in to comment.