Files
gluetun/internal/firewall/vpn.go
2020-07-13 02:15:32 +00:00

107 lines
3.2 KiB
Go

package firewall
import (
"context"
"fmt"
"github.com/qdm12/private-internet-access-docker/internal/models"
)
// SetVPNConnections reconciles the firewall rules that accept outbound
// traffic to VPN servers with the given list of connections.
//
// When the firewall is disabled, no rule is touched: the internal list is
// simply replaced by a copy of connections. Otherwise, the connections no
// longer wanted are removed and the newly wanted ones are added, both on the
// interface of the current default route. Removal failures are only logged
// (see removeConnections); an error obtaining the default route or adding a
// rule is wrapped and returned. The whole operation runs under c.stateMutex.
func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) {
	c.stateMutex.Lock()
	defer c.stateMutex.Unlock()

	if !c.enabled {
		c.logger.Info("firewall disabled, only updating VPN connections internal list")
		// Keep our own copy so later mutations by the caller do not leak in.
		c.vpnConnections = make([]models.OpenVPNConnection, len(connections))
		copy(c.vpnConnections, connections)
		return nil
	}

	c.logger.Info("setting VPN connections through firewall...")

	// Both diffs are computed from the same snapshot of c.vpnConnections,
	// before either removeConnections or addConnections mutates it.
	stale := findConnectionsToRemove(c.vpnConnections, connections)
	fresh := findConnectionsToAdd(c.vpnConnections, connections)
	if len(stale)+len(fresh) == 0 {
		return nil // already in sync, no need to query the routing table
	}

	iface, _, err := c.routing.DefaultRoute()
	if err != nil {
		return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
	}

	// Best effort: individual removal failures are logged, not returned.
	c.removeConnections(ctx, stale, iface)

	if err = c.addConnections(ctx, fresh, iface); err != nil {
		return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
	}
	return nil
}
// removeConnectionFromConnections removes the first element of connections
// that is Equal to connection and returns the shortened slice. If no element
// matches, connections is returned unchanged.
//
// The removal moves the last element into the freed slot, so element order is
// not preserved and the backing array of connections is modified in place.
func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection {
	lastIndex := len(connections) - 1
	for idx := 0; idx <= lastIndex; idx++ {
		if !connection.Equal(connections[idx]) {
			continue
		}
		connections[idx] = connections[lastIndex]
		return connections[:lastIndex]
	}
	return connections
}
// findConnectionsToAdd returns, in their original order, the elements of
// newConnections that have no Equal counterpart in oldConnections. It returns
// nil when every new connection is already present. Duplicates within
// newConnections are preserved in the result.
func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) {
nextCandidate:
	for _, candidate := range newConnections {
		for _, existing := range oldConnections {
			if existing.Equal(candidate) {
				continue nextCandidate // already present, nothing to add
			}
		}
		connectionsToAdd = append(connectionsToAdd, candidate)
	}
	return connectionsToAdd
}
// findConnectionsToRemove returns, in their original order, the elements of
// oldConnections that have no Equal counterpart in newConnections. It returns
// nil when every old connection is still wanted. Duplicates within
// oldConnections are preserved in the result.
func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) {
	for i := range oldConnections {
		// Scan until a match is found or the new list is exhausted.
		j := 0
		for j < len(newConnections) && !oldConnections[i].Equal(newConnections[j]) {
			j++
		}
		if j == len(newConnections) {
			connectionsToRemove = append(connectionsToRemove, oldConnections[i])
		}
	}
	return connectionsToRemove
}
// removeConnections deletes, for each given connection, the firewall rule
// accepting output traffic to that VPN connection on defaultInterface. Each
// connection whose rule was removed is also dropped from c.vpnConnections.
//
// Removal is best effort: on failure the error is logged, the connection stays
// in c.vpnConnections, and processing continues with the next connection.
// c.vpnConnections is mutated here; SetVPNConnections calls this while holding
// c.stateMutex.
func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) {
	const remove = true
	for _, vpnConnection := range connections {
		err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, vpnConnection, remove)
		if err == nil {
			c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, vpnConnection)
			continue
		}
		c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err)
	}
}
// addConnections installs, in order, a firewall rule accepting output traffic
// to each given VPN connection on defaultInterface. Each connection is appended
// to c.vpnConnections once its rule is in place.
//
// It stops at the first failure and returns that error unwrapped. Connections
// handled before the failure remain recorded in c.vpnConnections.
func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error {
	const remove = false
	for _, vpnConnection := range connections {
		err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, vpnConnection, remove)
		if err != nil {
			return err
		}
		c.vpnConnections = append(c.vpnConnections, vpnConnection)
	}
	return nil
}