107 lines
3.2 KiB
Go
107 lines
3.2 KiB
Go
package firewall
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
|
)
|
|
|
|
func (c *configurator) SetVPNConnections(ctx context.Context, connections []models.OpenVPNConnection) (err error) {
|
|
c.stateMutex.Lock()
|
|
defer c.stateMutex.Unlock()
|
|
|
|
if !c.enabled {
|
|
c.logger.Info("firewall disabled, only updating VPN connections internal list")
|
|
c.vpnConnections = make([]models.OpenVPNConnection, len(connections))
|
|
copy(c.vpnConnections, connections)
|
|
return nil
|
|
}
|
|
|
|
c.logger.Info("setting VPN connections through firewall...")
|
|
|
|
connectionsToAdd := findConnectionsToAdd(c.vpnConnections, connections)
|
|
connectionsToRemove := findConnectionsToRemove(c.vpnConnections, connections)
|
|
if len(connectionsToAdd) == 0 && len(connectionsToRemove) == 0 {
|
|
return nil
|
|
}
|
|
|
|
defaultInterface, _, err := c.routing.DefaultRoute()
|
|
if err != nil {
|
|
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
|
}
|
|
|
|
c.removeConnections(ctx, connectionsToRemove, defaultInterface)
|
|
if err := c.addConnections(ctx, connectionsToAdd, defaultInterface); err != nil {
|
|
return fmt.Errorf("cannot set VPN connections through firewall: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func removeConnectionFromConnections(connections []models.OpenVPNConnection, connection models.OpenVPNConnection) []models.OpenVPNConnection {
|
|
L := len(connections)
|
|
for i := range connections {
|
|
if connection.Equal(connections[i]) {
|
|
connections[i] = connections[L-1]
|
|
connections = connections[:L-1]
|
|
break
|
|
}
|
|
}
|
|
return connections
|
|
}
|
|
|
|
func findConnectionsToAdd(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToAdd []models.OpenVPNConnection) {
|
|
for _, newConnection := range newConnections {
|
|
found := false
|
|
for _, oldConnection := range oldConnections {
|
|
if oldConnection.Equal(newConnection) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
connectionsToAdd = append(connectionsToAdd, newConnection)
|
|
}
|
|
}
|
|
return connectionsToAdd
|
|
}
|
|
|
|
func findConnectionsToRemove(oldConnections, newConnections []models.OpenVPNConnection) (connectionsToRemove []models.OpenVPNConnection) {
|
|
for _, oldConnection := range oldConnections {
|
|
found := false
|
|
for _, newConnection := range newConnections {
|
|
if oldConnection.Equal(newConnection) {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
connectionsToRemove = append(connectionsToRemove, oldConnection)
|
|
}
|
|
}
|
|
return connectionsToRemove
|
|
}
|
|
|
|
func (c *configurator) removeConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) {
|
|
for _, conn := range connections {
|
|
const remove = true
|
|
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
|
c.logger.Error("cannot remove outdated VPN connection through firewall: %s", err)
|
|
continue
|
|
}
|
|
c.vpnConnections = removeConnectionFromConnections(c.vpnConnections, conn)
|
|
}
|
|
}
|
|
|
|
func (c *configurator) addConnections(ctx context.Context, connections []models.OpenVPNConnection, defaultInterface string) error {
|
|
const remove = false
|
|
for _, conn := range connections {
|
|
if err := c.acceptOutputTrafficToVPN(ctx, defaultInterface, conn, remove); err != nil {
|
|
return err
|
|
}
|
|
c.vpnConnections = append(c.vpnConnections, conn)
|
|
}
|
|
return nil
|
|
}
|