From 4f50eee45a6d31948fec07b33e5ca2dac2c0301d Mon Sep 17 00:00:00 2001 From: huoji Date: Fri, 17 Nov 2023 04:19:50 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=A7=A3=E9=99=A4IP=E5=B0=81?= =?UTF-8?q?=E7=A6=81=E7=9A=84=E9=80=89=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- linux_kernel/client_msg.c | 10 +++++++-- linux_kernel/msg.h | 1 + linux_kernel/network.c | 7 ++++++ linux_kernel/network.h | 1 + linux_service/events/network.cpp | 37 +++++++++++++++++++++++++++++++- linux_service/ip_blacktable.cpp | 4 +++- linux_service/msg.h | 1 + 7 files changed, 57 insertions(+), 4 deletions(-) diff --git a/linux_kernel/client_msg.c b/linux_kernel/client_msg.c index d5eb52b..778b831 100644 --- a/linux_kernel/client_msg.c +++ b/linux_kernel/client_msg.c @@ -1,12 +1,18 @@ #include "client_msg.h" void dispath_client_msg(struct client_msg_t* msg) { + uint32_t target_ip_address; + size_t block_time; switch (msg->type) { case SD_MSG_TYPE_CLIENT_BLOCK_IP: - const size_t target_ip_address = msg->u.ip_address.src_ip; - const size_t block_time = msg->u.ip_address.block_time; + target_ip_address = msg->u.ip_address.src_ip; + block_time = msg->u.ip_address.block_time; block_ip_address(target_ip_address, block_time); break; + case SD_MSG_TYPE_CLIENT_UNBLOCK_IP: + target_ip_address = msg->u.ip_address.src_ip; + unblock_ip_address(target_ip_address); + break; default: printk(KERN_INFO "Unknown msg type: %d\n", msg->type); break; diff --git a/linux_kernel/msg.h b/linux_kernel/msg.h index 270609c..0db4ff5 100644 --- a/linux_kernel/msg.h +++ b/linux_kernel/msg.h @@ -7,6 +7,7 @@ typedef enum _msg_type { SD_MSG_TYPE_SYN_ATTACK = 1, SD_MSG_TYPE_CLIENT_BLOCK_IP = 2, SD_MSG_TYPE_SSH_BF_ATTACK = 3, + SD_MSG_TYPE_CLIENT_UNBLOCK_IP = 4, }; typedef struct kernel_msg_t { diff --git a/linux_kernel/network.c b/linux_kernel/network.c index cb54892..01ee34f 100644 --- a/linux_kernel/network.c +++ b/linux_kernel/network.c @@ -24,6 +24,13 @@ bool check_is_blacklist_ip(u32 ip_address) { } return data->info.ip_meta_info.is_attack; } +void unblock_ip_address(u32 ip_address) { + struct ip_hashmap_node_t *data = get_ipdata_by_hashmap(ip_address); + if (data == NULL) { + return; + } + data->info.ip_meta_info.is_attack = false; +} bool check_syn_attack(struct iphdr *ip_header, struct sk_buff *skb) { bool is_block = false; do { diff --git a/linux_kernel/network.h b/linux_kernel/network.h index 011c377..04b36ce 100644 --- a/linux_kernel/network.h +++ b/linux_kernel/network.h @@ -13,3 +13,4 @@ extern unsigned int network_callback(const struct nf_hook_ops *ops, int (*okfn)(struct sk_buff *)); extern void block_ip_address(u32 ip_address, size_t time_sec); extern bool check_is_blacklist_ip(u32 ip_address); +extern void unblock_ip_address(u32 ip_address); diff --git a/linux_service/events/network.cpp b/linux_service/events/network.cpp index 0c70479..f2cdeb8 100644 --- a/linux_service/events/network.cpp +++ b/linux_service/events/network.cpp @@ -1,6 +1,12 @@ #include "network.h" +#include +#include +#include +#include namespace network_event { - +// read write lock +std::shared_mutex ip_blacklist_lock; +std::unordered_map ip_blacklist_cache; auto block_ip(uint32_t ip_address, size_t time_sec) -> bool { client_msg_t msg{0}; msg.check_sum = MSG_CHECK_SUM; @@ -9,11 +15,40 @@ auto block_ip(uint32_t ip_address, size_t time_sec) -> bool { msg.u.ip_address.block_time = time_sec; return client_msg::call_driver(msg); } +auto unblock_ip(uint32_t ip_address) -> bool { + client_msg_t msg{0}; + msg.check_sum = MSG_CHECK_SUM; + msg.type = static_cast(_msg_type::SD_MSG_TYPE_CLIENT_UNBLOCK_IP); + msg.u.ip_address.src_ip = ip_address; + return client_msg::call_driver(msg); +} auto on_ip_connect(uint32_t ip_address) -> bool { + std::shared_lock lock(ip_blacklist_lock); + if (ip_blacklist_cache.find(ip_address) != ip_blacklist_cache.end()) { + const auto current_time = std::time(nullptr); + const auto block_time = ip_blacklist_cache[ip_address]; + if (current_time - block_time < MAX_BLOCK_TIME) { + LOG("IP %s is in cache block list\n", + tools::cover_ip(ip_address).c_str()); + return true; + } + // cover lock to write lock, remove the ip from cache + lock.unlock(); + std::unique_lock ulock(ip_blacklist_lock); + ip_blacklist_cache.erase(ip_address); + } const auto is_still_in_block_list = global::ip_blacklist_db->selectRecordByIpAndTime(ip_address, MAX_BLOCK_TIME); if (is_still_in_block_list) { + const auto block_time = is_still_in_block_list.value().time; + if (block_time != 0) { + lock.unlock(); + std::unique_lock ulock(ip_blacklist_lock); + ip_blacklist_cache[ip_address] = + is_still_in_block_list.value().time; + } + LOG("IP %s is still in block list\n", tools::cover_ip(ip_address).c_str()); return true; diff --git a/linux_service/ip_blacktable.cpp b/linux_service/ip_blacktable.cpp index a0b435e..5a92947 100644 --- a/linux_service/ip_blacktable.cpp +++ b/linux_service/ip_blacktable.cpp @@ -97,6 +97,7 @@ auto IpBlacklistDB::deleteRecord(int id) -> void { sqlite3_finalize(stmt); } } +// if time == 0 , it means forever auto IpBlacklistDB::selectRecordByIpAndTime(uint32_t ip, uint64_t time_second) -> std::optional { std::vector records; @@ -109,7 +110,8 @@ auto IpBlacklistDB::selectRecordByIpAndTime(uint32_t ip, uint64_t time_second) // Adjusted SQL query to check if the timestamp is greater than or equal to // past_time const char *sql = - "SELECT * FROM ip_black_table WHERE ip = ? AND time >= ?;"; + "SELECT * FROM ip_black_table WHERE ip = ? AND (time >= ? OR time = " + "0);"; sqlite3_stmt *stmt; rc = sqlite3_prepare_v2(db, sql, -1, &stmt, NULL); diff --git a/linux_service/msg.h b/linux_service/msg.h index fc967a3..6f4c820 100644 --- a/linux_service/msg.h +++ b/linux_service/msg.h @@ -8,6 +8,7 @@ enum class _msg_type { SD_MSG_TYPE_SYN_ATTACK = 1, SD_MSG_TYPE_CLIENT_BLOCK_IP = 2, SD_MSG_TYPE_SSH_BF_ATTACK = 3, + SD_MSG_TYPE_CLIENT_UNBLOCK_IP = 4, }; typedef struct kernel_msg_t { unsigned long check_sum;