mirror of
https://github.com/yyhuni/xingrin.git
synced 2026-01-24 08:13:10 +08:00
feat(scan): add provider abstraction layer for flexible target sourcing
- Add TargetProvider base class and ProviderContext for unified target acquisition - Implement DatabaseTargetProvider for database-backed target queries - Implement ListTargetProvider for in-memory target lists (fast scan phase 1) - Implement SnapshotTargetProvider for snapshot table reads (fast scan phase 2+) - Implement PipelineTargetProvider for pipeline stage outputs - Add comprehensive provider tests covering common properties and individual providers - Update screenshot_flow to support both legacy mode (target_id) and provider mode - Add backward compatibility layer for existing task exports (directory, fingerprint, port, site, url_fetch, vuln scans) - Add task backward compatibility tests - Update .gitignore to exclude .hypothesis/ cache directory - Update frontend ANSI log viewer component - Update backend requirements.txt with new dependencies - Enables flexible data source integration while maintaining backward compatibility with existing database-driven workflows
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -64,6 +64,7 @@ backend/.env.local
|
||||
.coverage
|
||||
htmlcov/
|
||||
*.cover
|
||||
.hypothesis/
|
||||
|
||||
# ============================
|
||||
# 后端 (Go) 相关
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
1. 从数据库获取 URL 列表(websites 和/或 endpoints)
|
||||
2. 批量截图并保存快照
|
||||
3. 同步到资产表
|
||||
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
|
||||
"""
|
||||
|
||||
# Django 环境初始化
|
||||
@@ -12,6 +16,7 @@ from apps.common.prefect_django_setup import setup_django_for_prefect
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from prefect import flow
|
||||
|
||||
from apps.scan.tasks.screenshot import capture_screenshots_task
|
||||
@@ -25,6 +30,7 @@ from apps.scan.services.target_export_service import (
|
||||
get_urls_with_fallback,
|
||||
DataSource,
|
||||
)
|
||||
from apps.scan.providers import TargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -88,11 +94,16 @@ def screenshot_flow(
|
||||
target_name: str,
|
||||
target_id: int,
|
||||
scan_workspace_dir: str,
|
||||
enabled_tools: dict
|
||||
enabled_tools: dict,
|
||||
provider: Optional[TargetProvider] = None
|
||||
) -> dict:
|
||||
"""
|
||||
截图 Flow
|
||||
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
|
||||
|
||||
工作流程:
|
||||
Step 1: 解析配置
|
||||
Step 2: 收集 URL 列表
|
||||
@@ -105,6 +116,7 @@ def screenshot_flow(
|
||||
target_id: 目标 ID
|
||||
scan_workspace_dir: 扫描工作空间目录
|
||||
enabled_tools: 启用的工具配置
|
||||
provider: TargetProvider 实例(新模式,可选)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
@@ -124,6 +136,7 @@ def screenshot_flow(
|
||||
f" Scan ID: {scan_id}\n" +
|
||||
f" Target: {target_name}\n" +
|
||||
f" Workspace: {scan_workspace_dir}\n" +
|
||||
f" Mode: {'Provider' if provider else 'Legacy'}\n" +
|
||||
"="*60
|
||||
)
|
||||
|
||||
@@ -136,14 +149,31 @@ def screenshot_flow(
|
||||
|
||||
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, url_sources)
|
||||
|
||||
# Step 2: 使用统一服务收集 URL(带黑名单过滤和回退)
|
||||
data_sources = _map_url_sources_to_data_sources(url_sources)
|
||||
result = get_urls_with_fallback(target_id, sources=data_sources)
|
||||
# Step 2: 收集 URL 列表
|
||||
if provider is not None:
|
||||
# Provider 模式:使用 TargetProvider 获取 URL
|
||||
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
|
||||
urls = list(provider.iter_urls())
|
||||
|
||||
# 应用黑名单过滤(如果有)
|
||||
blacklist_filter = provider.get_blacklist_filter()
|
||||
if blacklist_filter:
|
||||
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
|
||||
|
||||
source_info = 'provider'
|
||||
tried_sources = ['provider']
|
||||
else:
|
||||
# 传统模式:使用统一服务收集 URL(带黑名单过滤和回退)
|
||||
data_sources = _map_url_sources_to_data_sources(url_sources)
|
||||
result = get_urls_with_fallback(target_id, sources=data_sources)
|
||||
|
||||
urls = result['urls']
|
||||
source_info = result['source']
|
||||
tried_sources = result['tried_sources']
|
||||
|
||||
urls = result['urls']
|
||||
logger.info(
|
||||
"URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
|
||||
result['source'], result['total_count'], result['tried_sources']
|
||||
source_info, len(urls), tried_sources
|
||||
)
|
||||
|
||||
if not urls:
|
||||
@@ -159,7 +189,7 @@ def screenshot_flow(
|
||||
'synced': 0
|
||||
}
|
||||
|
||||
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {result['source']})")
|
||||
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {source_info})")
|
||||
|
||||
# Step 3: 批量截图
|
||||
logger.info("Step 3: 批量截图 - %d 个 URL", len(urls))
|
||||
|
||||
56
backend/apps/scan/providers/__init__.py
Normal file
56
backend/apps/scan/providers/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
扫描目标提供者模块
|
||||
|
||||
提供统一的目标获取接口,支持多种数据源:
|
||||
- DatabaseTargetProvider: 从数据库查询(完整扫描)
|
||||
- ListTargetProvider: 使用内存列表(快速扫描阶段1)
|
||||
- SnapshotTargetProvider: 从快照表读取(快速扫描阶段2+)
|
||||
- PipelineTargetProvider: 使用管道输出(Phase 2)
|
||||
|
||||
使用方式:
|
||||
from apps.scan.providers import (
|
||||
DatabaseTargetProvider,
|
||||
ListTargetProvider,
|
||||
SnapshotTargetProvider,
|
||||
ProviderContext
|
||||
)
|
||||
|
||||
# 数据库模式(完整扫描)
|
||||
provider = DatabaseTargetProvider(target_id=123)
|
||||
|
||||
# 列表模式(快速扫描阶段1)
|
||||
context = ProviderContext(target_id=1, scan_id=100)
|
||||
provider = ListTargetProvider(
|
||||
targets=["a.test.com"],
|
||||
context=context
|
||||
)
|
||||
|
||||
# 快照模式(快速扫描阶段2+)
|
||||
context = ProviderContext(target_id=1, scan_id=100)
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain",
|
||||
context=context
|
||||
)
|
||||
|
||||
# 使用 Provider
|
||||
for host in provider.iter_hosts():
|
||||
scan(host)
|
||||
"""
|
||||
|
||||
from .base import TargetProvider, ProviderContext
|
||||
from .list_provider import ListTargetProvider
|
||||
from .database_provider import DatabaseTargetProvider
|
||||
from .snapshot_provider import SnapshotTargetProvider, SnapshotType
|
||||
from .pipeline_provider import PipelineTargetProvider, StageOutput
|
||||
|
||||
__all__ = [
|
||||
'TargetProvider',
|
||||
'ProviderContext',
|
||||
'ListTargetProvider',
|
||||
'DatabaseTargetProvider',
|
||||
'SnapshotTargetProvider',
|
||||
'SnapshotType',
|
||||
'PipelineTargetProvider',
|
||||
'StageOutput',
|
||||
]
|
||||
162
backend/apps/scan/providers/base.py
Normal file
162
backend/apps/scan/providers/base.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
扫描目标提供者基础模块
|
||||
|
||||
定义 ProviderContext 数据类和 TargetProvider 抽象基类。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, Optional, TYPE_CHECKING
|
||||
import ipaddress
|
||||
import logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apps.common.utils import BlacklistFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderContext:
|
||||
"""
|
||||
Provider 上下文,携带元数据
|
||||
|
||||
Attributes:
|
||||
target_id: 关联的 Target ID(用于结果保存),None 表示临时扫描(不保存)
|
||||
scan_id: 扫描任务 ID
|
||||
"""
|
||||
target_id: Optional[int] = None
|
||||
scan_id: Optional[int] = None
|
||||
|
||||
|
||||
class TargetProvider(ABC):
|
||||
"""
|
||||
扫描目标提供者抽象基类
|
||||
|
||||
职责:
|
||||
- 提供扫描目标(域名、IP、URL 等)的迭代器
|
||||
- 提供黑名单过滤器
|
||||
- 携带上下文信息(target_id, scan_id 等)
|
||||
- 自动展开 CIDR(子类无需关心)
|
||||
|
||||
使用方式:
|
||||
provider = create_target_provider(target_id=123)
|
||||
for host in provider.iter_hosts():
|
||||
print(host)
|
||||
"""
|
||||
|
||||
def __init__(self, context: Optional[ProviderContext] = None):
|
||||
"""
|
||||
初始化 Provider
|
||||
|
||||
Args:
|
||||
context: Provider 上下文,None 时创建默认上下文
|
||||
"""
|
||||
self._context = context or ProviderContext()
|
||||
|
||||
@property
|
||||
def context(self) -> ProviderContext:
|
||||
"""返回 Provider 上下文"""
|
||||
return self._context
|
||||
|
||||
@staticmethod
|
||||
def _expand_host(host: str) -> Iterator[str]:
|
||||
"""
|
||||
展开主机(如果是 CIDR 则展开为多个 IP,否则直接返回)
|
||||
|
||||
这是一个内部方法,由 iter_hosts() 自动调用。
|
||||
|
||||
Args:
|
||||
host: 主机字符串(IP/域名/CIDR)
|
||||
|
||||
Yields:
|
||||
str: 单个主机(IP 或域名)
|
||||
|
||||
示例:
|
||||
"192.168.1.0/30" → "192.168.1.1", "192.168.1.2"
|
||||
"192.168.1.1" → "192.168.1.1"
|
||||
"example.com" → "example.com"
|
||||
"invalid" → (跳过,不返回)
|
||||
"""
|
||||
from apps.common.validators import detect_target_type
|
||||
from apps.targets.models import Target
|
||||
|
||||
host = host.strip()
|
||||
if not host:
|
||||
return
|
||||
|
||||
# 统一使用 detect_target_type 检测类型
|
||||
try:
|
||||
target_type = detect_target_type(host)
|
||||
|
||||
if target_type == Target.TargetType.CIDR:
|
||||
# 展开 CIDR
|
||||
network = ipaddress.ip_network(host, strict=False)
|
||||
if network.num_addresses == 1:
|
||||
yield str(network.network_address)
|
||||
else:
|
||||
for ip in network.hosts():
|
||||
yield str(ip)
|
||||
elif target_type == Target.TargetType.IP:
|
||||
# 单个 IP
|
||||
yield host
|
||||
elif target_type == Target.TargetType.DOMAIN:
|
||||
# 域名
|
||||
yield host
|
||||
except ValueError as e:
|
||||
# 无效格式,跳过并记录警告
|
||||
logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
|
||||
|
||||
def iter_hosts(self) -> Iterator[str]:
|
||||
"""
|
||||
迭代主机列表(域名/IP)
|
||||
|
||||
自动展开 CIDR,子类无需关心。
|
||||
|
||||
Yields:
|
||||
str: 主机名或 IP 地址(单个,不包含 CIDR)
|
||||
"""
|
||||
for host in self._iter_raw_hosts():
|
||||
yield from self._expand_host(host)
|
||||
|
||||
@abstractmethod
|
||||
def _iter_raw_hosts(self) -> Iterator[str]:
|
||||
"""
|
||||
迭代原始主机列表(可能包含 CIDR)
|
||||
|
||||
子类实现此方法,返回原始数据即可,不需要处理 CIDR 展开。
|
||||
|
||||
Yields:
|
||||
str: 主机名、IP 地址或 CIDR
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def iter_urls(self) -> Iterator[str]:
|
||||
"""
|
||||
迭代 URL 列表
|
||||
|
||||
Yields:
|
||||
str: URL 字符串
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
|
||||
"""
|
||||
获取黑名单过滤器
|
||||
|
||||
Returns:
|
||||
BlacklistFilter: 黑名单过滤器实例,或 None(不过滤)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def target_id(self) -> Optional[int]:
|
||||
"""返回关联的 target_id,临时扫描返回 None"""
|
||||
return self._context.target_id
|
||||
|
||||
@property
|
||||
def scan_id(self) -> Optional[int]:
|
||||
"""返回关联的 scan_id"""
|
||||
return self._context.scan_id
|
||||
131
backend/apps/scan/providers/database_provider.py
Normal file
131
backend/apps/scan/providers/database_provider.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
数据库目标提供者模块
|
||||
|
||||
提供基于数据库查询的目标提供者实现。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator, Optional, TYPE_CHECKING
|
||||
|
||||
from .base import TargetProvider, ProviderContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apps.common.utils import BlacklistFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseTargetProvider(TargetProvider):
|
||||
"""
|
||||
数据库目标提供者 - 从 Target 表及关联资产表查询
|
||||
|
||||
这是现有行为的封装,保持向后兼容。
|
||||
|
||||
数据来源:
|
||||
- iter_hosts(): 根据 Target 类型返回域名/IP
|
||||
- DOMAIN: 根域名 + Subdomain 表
|
||||
- IP: 直接返回 IP
|
||||
- CIDR: 使用 _expand_host() 展开为所有主机 IP
|
||||
- iter_urls(): WebSite/Endpoint 表,带回退链
|
||||
|
||||
使用方式:
|
||||
provider = DatabaseTargetProvider(target_id=123)
|
||||
for host in provider.iter_hosts():
|
||||
scan(host)
|
||||
"""
|
||||
|
||||
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
|
||||
"""
|
||||
初始化数据库目标提供者
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID(必需)
|
||||
context: Provider 上下文
|
||||
"""
|
||||
ctx = context or ProviderContext()
|
||||
ctx.target_id = target_id
|
||||
super().__init__(ctx)
|
||||
self._blacklist_filter: Optional['BlacklistFilter'] = None # 延迟加载
|
||||
|
||||
def iter_hosts(self) -> Iterator[str]:
|
||||
"""
|
||||
从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤
|
||||
|
||||
重写基类方法以支持黑名单过滤(需要在 CIDR 展开后过滤)
|
||||
"""
|
||||
blacklist = self.get_blacklist_filter()
|
||||
|
||||
for host in self._iter_raw_hosts():
|
||||
# 展开 CIDR
|
||||
for expanded_host in self._expand_host(host):
|
||||
# 应用黑名单过滤
|
||||
if not blacklist or blacklist.is_allowed(expanded_host):
|
||||
yield expanded_host
|
||||
|
||||
def _iter_raw_hosts(self) -> Iterator[str]:
|
||||
"""
|
||||
从数据库查询原始主机列表(可能包含 CIDR)
|
||||
|
||||
根据 Target 类型决定数据来源:
|
||||
- DOMAIN: 根域名 + Subdomain 表
|
||||
- IP: 直接返回 target.name
|
||||
- CIDR: 返回 CIDR 字符串(由 iter_hosts() 展开)
|
||||
|
||||
注意:此方法不应用黑名单过滤,过滤在 iter_hosts() 中进行
|
||||
"""
|
||||
from apps.targets.services import TargetService
|
||||
from apps.targets.models import Target
|
||||
from apps.asset.services.asset.subdomain_service import SubdomainService
|
||||
|
||||
target = TargetService().get_target(self.target_id)
|
||||
if not target:
|
||||
logger.warning("Target ID %d 不存在", self.target_id)
|
||||
return
|
||||
|
||||
if target.type == Target.TargetType.DOMAIN:
|
||||
# 返回根域名
|
||||
yield target.name
|
||||
|
||||
# 返回子域名
|
||||
subdomain_service = SubdomainService()
|
||||
for domain in subdomain_service.iter_subdomain_names_by_target(
|
||||
target_id=self.target_id,
|
||||
chunk_size=1000
|
||||
):
|
||||
if domain != target.name: # 避免重复
|
||||
yield domain
|
||||
|
||||
elif target.type == Target.TargetType.IP:
|
||||
yield target.name
|
||||
|
||||
elif target.type == Target.TargetType.CIDR:
|
||||
# 直接返回 CIDR,由 iter_hosts() 展开并过滤
|
||||
yield target.name
|
||||
|
||||
def iter_urls(self) -> Iterator[str]:
|
||||
"""
|
||||
从数据库查询 URL 列表
|
||||
|
||||
使用现有的回退链逻辑:Endpoint → WebSite → Default
|
||||
"""
|
||||
from apps.scan.services.target_export_service import (
|
||||
_iter_urls_with_fallback, DataSource
|
||||
)
|
||||
|
||||
blacklist = self.get_blacklist_filter()
|
||||
|
||||
for url, source in _iter_urls_with_fallback(
|
||||
target_id=self.target_id,
|
||||
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
|
||||
blacklist_filter=blacklist
|
||||
):
|
||||
yield url
|
||||
|
||||
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
|
||||
"""获取黑名单过滤器(延迟加载)"""
|
||||
if self._blacklist_filter is None:
|
||||
from apps.common.services import BlacklistService
|
||||
from apps.common.utils import BlacklistFilter
|
||||
rules = BlacklistService().get_rules(self.target_id)
|
||||
self._blacklist_filter = BlacklistFilter(rules)
|
||||
return self._blacklist_filter
|
||||
84
backend/apps/scan/providers/list_provider.py
Normal file
84
backend/apps/scan/providers/list_provider.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
列表目标提供者模块
|
||||
|
||||
提供基于内存列表的目标提供者实现。
|
||||
"""
|
||||
|
||||
from typing import Iterator, Optional, List
|
||||
|
||||
from .base import TargetProvider, ProviderContext
|
||||
|
||||
|
||||
class ListTargetProvider(TargetProvider):
|
||||
"""
|
||||
列表目标提供者 - 直接使用内存中的列表
|
||||
|
||||
用于快速扫描、临时扫描等场景,只扫描用户指定的目标。
|
||||
|
||||
特点:
|
||||
- 不查询数据库
|
||||
- 不应用黑名单过滤(用户明确指定的目标)
|
||||
- 不关联 target_id(由调用方负责创建 Target)
|
||||
- 自动检测输入类型(URL/域名/IP/CIDR)
|
||||
- 自动展开 CIDR
|
||||
|
||||
使用方式:
|
||||
# 快速扫描:用户提供目标,自动识别类型
|
||||
provider = ListTargetProvider(targets=[
|
||||
"example.com", # 域名
|
||||
"192.168.1.0/24", # CIDR(自动展开)
|
||||
"https://api.example.com" # URL
|
||||
])
|
||||
for host in provider.iter_hosts():
|
||||
scan(host)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
targets: Optional[List[str]] = None,
|
||||
context: Optional[ProviderContext] = None
|
||||
):
|
||||
"""
|
||||
初始化列表目标提供者
|
||||
|
||||
Args:
|
||||
targets: 目标列表(自动识别类型:URL/域名/IP/CIDR)
|
||||
context: Provider 上下文
|
||||
"""
|
||||
from apps.common.validators import detect_input_type
|
||||
|
||||
ctx = context or ProviderContext()
|
||||
super().__init__(ctx)
|
||||
|
||||
# 自动分类目标
|
||||
self._hosts = []
|
||||
self._urls = []
|
||||
|
||||
if targets:
|
||||
for target in targets:
|
||||
target = target.strip()
|
||||
if not target:
|
||||
continue
|
||||
|
||||
try:
|
||||
input_type = detect_input_type(target)
|
||||
if input_type == 'url':
|
||||
self._urls.append(target)
|
||||
else:
|
||||
# domain/ip/cidr 都作为 host
|
||||
self._hosts.append(target)
|
||||
except ValueError:
|
||||
# 无法识别类型,默认作为 host
|
||||
self._hosts.append(target)
|
||||
|
||||
def _iter_raw_hosts(self) -> Iterator[str]:
|
||||
"""迭代原始主机列表(可能包含 CIDR)"""
|
||||
yield from self._hosts
|
||||
|
||||
def iter_urls(self) -> Iterator[str]:
|
||||
"""迭代 URL 列表"""
|
||||
yield from self._urls
|
||||
|
||||
def get_blacklist_filter(self) -> None:
|
||||
"""列表模式不使用黑名单过滤"""
|
||||
return None
|
||||
91
backend/apps/scan/providers/pipeline_provider.py
Normal file
91
backend/apps/scan/providers/pipeline_provider.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
管道目标提供者模块
|
||||
|
||||
提供基于管道阶段输出的目标提供者实现。
|
||||
用于 Phase 2 管道模式的阶段间数据传递。
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator, Optional, List, Dict, Any
|
||||
|
||||
from .base import TargetProvider, ProviderContext
|
||||
|
||||
|
||||
@dataclass
|
||||
class StageOutput:
|
||||
"""
|
||||
阶段输出数据
|
||||
|
||||
用于在管道阶段之间传递数据。
|
||||
|
||||
Attributes:
|
||||
hosts: 主机列表(域名/IP)
|
||||
urls: URL 列表
|
||||
new_targets: 新发现的目标列表
|
||||
stats: 统计信息
|
||||
success: 是否成功
|
||||
error: 错误信息
|
||||
"""
|
||||
hosts: List[str] = field(default_factory=list)
|
||||
urls: List[str] = field(default_factory=list)
|
||||
new_targets: List[str] = field(default_factory=list)
|
||||
stats: Dict[str, Any] = field(default_factory=dict)
|
||||
success: bool = True
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class PipelineTargetProvider(TargetProvider):
|
||||
"""
|
||||
管道目标提供者 - 使用上一阶段的输出
|
||||
|
||||
用于 Phase 2 管道模式的阶段间数据传递。
|
||||
|
||||
特点:
|
||||
- 不查询数据库
|
||||
- 不应用黑名单过滤(数据已在上一阶段过滤)
|
||||
- 直接使用 StageOutput 中的数据
|
||||
|
||||
使用方式(Phase 2):
|
||||
stage1_output = stage1.run(input)
|
||||
provider = PipelineTargetProvider(
|
||||
previous_output=stage1_output,
|
||||
target_id=123
|
||||
)
|
||||
for host in provider.iter_hosts():
|
||||
stage2.scan(host)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
previous_output: StageOutput,
|
||||
target_id: Optional[int] = None,
|
||||
context: Optional[ProviderContext] = None
|
||||
):
|
||||
"""
|
||||
初始化管道目标提供者
|
||||
|
||||
Args:
|
||||
previous_output: 上一阶段的输出
|
||||
target_id: 可选,关联到某个 Target(用于保存结果)
|
||||
context: Provider 上下文
|
||||
"""
|
||||
ctx = context or ProviderContext(target_id=target_id)
|
||||
super().__init__(ctx)
|
||||
self._previous_output = previous_output
|
||||
|
||||
def _iter_raw_hosts(self) -> Iterator[str]:
|
||||
"""迭代上一阶段输出的原始主机(可能包含 CIDR)"""
|
||||
yield from self._previous_output.hosts
|
||||
|
||||
def iter_urls(self) -> Iterator[str]:
|
||||
"""迭代上一阶段输出的 URL"""
|
||||
yield from self._previous_output.urls
|
||||
|
||||
def get_blacklist_filter(self) -> None:
|
||||
"""管道传递的数据已经过滤过了"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def previous_output(self) -> StageOutput:
|
||||
"""返回上一阶段的输出"""
|
||||
return self._previous_output
|
||||
175
backend/apps/scan/providers/snapshot_provider.py
Normal file
175
backend/apps/scan/providers/snapshot_provider.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
快照目标提供者模块
|
||||
|
||||
提供基于快照表的目标提供者实现。
|
||||
用于快速扫描的阶段间数据传递。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator, Optional, Literal
|
||||
|
||||
from .base import TargetProvider, ProviderContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 快照类型定义
|
||||
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]
|
||||
|
||||
|
||||
class SnapshotTargetProvider(TargetProvider):
|
||||
"""
|
||||
快照目标提供者 - 从快照表读取本次扫描的数据
|
||||
|
||||
用于快速扫描的阶段间数据传递,解决精确扫描控制问题。
|
||||
|
||||
核心价值:
|
||||
- 只返回本次扫描(scan_id)发现的资产
|
||||
- 避免扫描历史数据(DatabaseTargetProvider 会扫描所有历史资产)
|
||||
|
||||
特点:
|
||||
- 通过 scan_id 过滤快照表
|
||||
- 不应用黑名单过滤(数据已在上一阶段过滤)
|
||||
- 支持多种快照类型(subdomain/website/endpoint/host_port)
|
||||
|
||||
使用场景:
|
||||
# 快速扫描流程
|
||||
用户输入: a.test.com
|
||||
创建 Target: test.com (id=1)
|
||||
创建 Scan: scan_id=100
|
||||
|
||||
# 阶段1: 子域名发现
|
||||
provider = ListTargetProvider(
|
||||
targets=["a.test.com"],
|
||||
context=ProviderContext(target_id=1, scan_id=100)
|
||||
)
|
||||
# 发现: b.a.test.com, c.a.test.com
|
||||
# 保存: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)
|
||||
|
||||
# 阶段2: 端口扫描
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain",
|
||||
context=ProviderContext(target_id=1, scan_id=100)
|
||||
)
|
||||
# 只返回: b.a.test.com, c.a.test.com(本次扫描发现的)
|
||||
# 不返回: www.test.com, api.test.com(历史数据)
|
||||
|
||||
# 阶段3: 网站扫描
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="host_port",
|
||||
context=ProviderContext(target_id=1, scan_id=100)
|
||||
)
|
||||
# 只返回本次扫描发现的 IP:Port
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scan_id: int,
|
||||
snapshot_type: SnapshotType,
|
||||
context: Optional[ProviderContext] = None
|
||||
):
|
||||
"""
|
||||
初始化快照目标提供者
|
||||
|
||||
Args:
|
||||
scan_id: 扫描任务 ID(必需)
|
||||
snapshot_type: 快照类型
|
||||
- "subdomain": 子域名快照(SubdomainSnapshot)
|
||||
- "website": 网站快照(WebsiteSnapshot)
|
||||
- "endpoint": 端点快照(EndpointSnapshot)
|
||||
- "host_port": 主机端口映射快照(HostPortMappingSnapshot)
|
||||
context: Provider 上下文
|
||||
"""
|
||||
ctx = context or ProviderContext()
|
||||
ctx.scan_id = scan_id
|
||||
super().__init__(ctx)
|
||||
self._scan_id = scan_id
|
||||
self._snapshot_type = snapshot_type
|
||||
|
||||
def _iter_raw_hosts(self) -> Iterator[str]:
|
||||
"""
|
||||
从快照表迭代主机列表
|
||||
|
||||
根据 snapshot_type 选择不同的快照表:
|
||||
- subdomain: SubdomainSnapshot.name
|
||||
- host_port: HostPortMappingSnapshot.host (返回 host:port 格式,不经过验证)
|
||||
"""
|
||||
if self._snapshot_type == "subdomain":
|
||||
from apps.asset.services.snapshot import SubdomainSnapshotsService
|
||||
service = SubdomainSnapshotsService()
|
||||
yield from service.iter_subdomain_names_by_scan(
|
||||
scan_id=self._scan_id,
|
||||
chunk_size=1000
|
||||
)
|
||||
|
||||
elif self._snapshot_type == "host_port":
|
||||
# host_port 类型不使用 _iter_raw_hosts,直接在 iter_hosts 中处理
|
||||
# 这里返回空,避免被基类的 iter_hosts 调用
|
||||
return
|
||||
|
||||
else:
|
||||
# 其他类型暂不支持 iter_hosts
|
||||
logger.warning(
|
||||
"快照类型 '%s' 不支持 iter_hosts,返回空迭代器",
|
||||
self._snapshot_type
|
||||
)
|
||||
return
|
||||
|
||||
def iter_hosts(self) -> Iterator[str]:
|
||||
"""
|
||||
迭代主机列表
|
||||
|
||||
对于 host_port 类型,返回 host:port 格式,不经过 CIDR 展开验证
|
||||
"""
|
||||
if self._snapshot_type == "host_port":
|
||||
# host_port 类型直接返回 host:port,不经过 _expand_host 验证
|
||||
from apps.asset.services.snapshot import HostPortMappingSnapshotsService
|
||||
service = HostPortMappingSnapshotsService()
|
||||
queryset = service.get_by_scan(scan_id=self._scan_id)
|
||||
for mapping in queryset.iterator(chunk_size=1000):
|
||||
yield f"{mapping.host}:{mapping.port}"
|
||||
else:
|
||||
# 其他类型使用基类的 iter_hosts(会调用 _iter_raw_hosts 并展开 CIDR)
|
||||
yield from super().iter_hosts()
|
||||
|
||||
def iter_urls(self) -> Iterator[str]:
|
||||
"""
|
||||
从快照表迭代 URL 列表
|
||||
|
||||
根据 snapshot_type 选择不同的快照表:
|
||||
- website: WebsiteSnapshot.url
|
||||
- endpoint: EndpointSnapshot.url
|
||||
"""
|
||||
if self._snapshot_type == "website":
|
||||
from apps.asset.services.snapshot import WebsiteSnapshotsService
|
||||
service = WebsiteSnapshotsService()
|
||||
yield from service.iter_website_urls_by_scan(
|
||||
scan_id=self._scan_id,
|
||||
chunk_size=1000
|
||||
)
|
||||
|
||||
elif self._snapshot_type == "endpoint":
|
||||
from apps.asset.services.snapshot import EndpointSnapshotsService
|
||||
service = EndpointSnapshotsService()
|
||||
# 从快照表获取端点 URL
|
||||
queryset = service.get_by_scan(scan_id=self._scan_id)
|
||||
for endpoint in queryset.iterator(chunk_size=1000):
|
||||
yield endpoint.url
|
||||
|
||||
else:
|
||||
# 其他类型暂不支持 iter_urls
|
||||
logger.warning(
|
||||
"快照类型 '%s' 不支持 iter_urls,返回空迭代器",
|
||||
self._snapshot_type
|
||||
)
|
||||
return
|
||||
|
||||
def get_blacklist_filter(self) -> None:
|
||||
"""快照数据已在上一阶段过滤过了"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def snapshot_type(self) -> SnapshotType:
|
||||
"""返回快照类型"""
|
||||
return self._snapshot_type
|
||||
3
backend/apps/scan/providers/tests/__init__.py
Normal file
3
backend/apps/scan/providers/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
扫描目标提供者测试模块
|
||||
"""
|
||||
256
backend/apps/scan/providers/tests/test_common_properties.py
Normal file
256
backend/apps/scan/providers/tests/test_common_properties.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
通用属性测试
|
||||
|
||||
包含跨多个 Provider 的通用属性测试:
|
||||
- Property 4: Context Propagation
|
||||
- Property 5: Non-Database Provider Blacklist Filter
|
||||
- Property 7: CIDR Expansion Consistency
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, strategies as st, settings
|
||||
from ipaddress import IPv4Network
|
||||
|
||||
from apps.scan.providers import (
|
||||
ProviderContext,
|
||||
ListTargetProvider,
|
||||
DatabaseTargetProvider,
|
||||
PipelineTargetProvider,
|
||||
SnapshotTargetProvider
|
||||
)
|
||||
from apps.scan.providers.pipeline_provider import StageOutput
|
||||
|
||||
|
||||
class TestContextPropagation:
|
||||
"""
|
||||
Property 4: Context Propagation
|
||||
|
||||
*For any* ProviderContext,传入 Provider 构造函数后,
|
||||
Provider 的 target_id 和 scan_id 属性应该与 context 中的值一致。
|
||||
|
||||
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
|
||||
"""
|
||||
|
||||
@given(
|
||||
target_id=st.integers(min_value=1, max_value=10000),
|
||||
scan_id=st.integers(min_value=1, max_value=10000)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
|
||||
"""
|
||||
Property 4: Context Propagation (ListTargetProvider)
|
||||
|
||||
Feature: scan-target-provider, Property 4: Context Propagation
|
||||
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
|
||||
"""
|
||||
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
|
||||
provider = ListTargetProvider(targets=["example.com"], context=ctx)
|
||||
|
||||
assert provider.target_id == target_id
|
||||
assert provider.scan_id == scan_id
|
||||
assert provider.context.target_id == target_id
|
||||
assert provider.context.scan_id == scan_id
|
||||
|
||||
@given(
|
||||
target_id=st.integers(min_value=1, max_value=10000),
|
||||
scan_id=st.integers(min_value=1, max_value=10000)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
|
||||
"""
|
||||
Property 4: Context Propagation (DatabaseTargetProvider)
|
||||
|
||||
Feature: scan-target-provider, Property 4: Context Propagation
|
||||
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
|
||||
"""
|
||||
ctx = ProviderContext(target_id=999, scan_id=scan_id)
|
||||
# DatabaseTargetProvider 会覆盖 context 中的 target_id
|
||||
provider = DatabaseTargetProvider(target_id=target_id, context=ctx)
|
||||
|
||||
assert provider.target_id == target_id # 使用构造函数参数
|
||||
assert provider.scan_id == scan_id # 使用 context 中的值
|
||||
assert provider.context.target_id == target_id
|
||||
assert provider.context.scan_id == scan_id
|
||||
|
||||
@given(
|
||||
target_id=st.integers(min_value=1, max_value=10000),
|
||||
scan_id=st.integers(min_value=1, max_value=10000)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
|
||||
"""
|
||||
Property 4: Context Propagation (PipelineTargetProvider)
|
||||
|
||||
Feature: scan-target-provider, Property 4: Context Propagation
|
||||
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
|
||||
"""
|
||||
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
|
||||
stage_output = StageOutput(hosts=["example.com"])
|
||||
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
|
||||
|
||||
assert provider.target_id == target_id
|
||||
assert provider.scan_id == scan_id
|
||||
assert provider.context.target_id == target_id
|
||||
assert provider.context.scan_id == scan_id
|
||||
|
||||
@given(
|
||||
target_id=st.integers(min_value=1, max_value=10000),
|
||||
scan_id=st.integers(min_value=1, max_value=10000)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
|
||||
"""
|
||||
Property 4: Context Propagation (SnapshotTargetProvider)
|
||||
|
||||
Feature: scan-target-provider, Property 4: Context Propagation
|
||||
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
|
||||
"""
|
||||
ctx = ProviderContext(target_id=target_id, scan_id=999)
|
||||
# SnapshotTargetProvider 会覆盖 context 中的 scan_id
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=scan_id,
|
||||
snapshot_type="subdomain",
|
||||
context=ctx
|
||||
)
|
||||
|
||||
assert provider.target_id == target_id # 使用 context 中的值
|
||||
assert provider.scan_id == scan_id # 使用构造函数参数
|
||||
assert provider.context.target_id == target_id
|
||||
assert provider.context.scan_id == scan_id
|
||||
|
||||
|
||||
class TestNonDatabaseProviderBlacklistFilter:
|
||||
"""
|
||||
Property 5: Non-Database Provider Blacklist Filter
|
||||
|
||||
*For any* ListTargetProvider 或 PipelineTargetProvider 实例,
|
||||
get_blacklist_filter() 方法应该返回 None。
|
||||
|
||||
**Validates: Requirements 3.4, 9.4, 9.5**
|
||||
"""
|
||||
|
||||
@given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
|
||||
@settings(max_examples=100)
|
||||
def test_property_5_list_provider_no_blacklist(self, targets):
|
||||
"""
|
||||
Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)
|
||||
|
||||
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
|
||||
**Validates: Requirements 3.4, 9.4, 9.5**
|
||||
"""
|
||||
provider = ListTargetProvider(targets=targets)
|
||||
assert provider.get_blacklist_filter() is None
|
||||
|
||||
@given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
|
||||
@settings(max_examples=100)
|
||||
def test_property_5_pipeline_provider_no_blacklist(self, hosts):
|
||||
"""
|
||||
Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)
|
||||
|
||||
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
|
||||
**Validates: Requirements 3.4, 9.4, 9.5**
|
||||
"""
|
||||
stage_output = StageOutput(hosts=hosts)
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
assert provider.get_blacklist_filter() is None
|
||||
|
||||
def test_property_5_snapshot_provider_no_blacklist(self):
|
||||
"""
|
||||
Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)
|
||||
|
||||
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
|
||||
**Validates: Requirements 3.4, 9.4, 9.5**
|
||||
"""
|
||||
provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
|
||||
assert provider.get_blacklist_filter() is None
|
||||
|
||||
|
||||
class TestCIDRExpansionConsistency:
|
||||
"""
|
||||
Property 7: CIDR Expansion Consistency
|
||||
|
||||
*For any* CIDR 字符串(如 "192.168.1.0/24"),所有 Provider 的 iter_hosts()
|
||||
方法应该将其展开为相同的单个 IP 地址列表。
|
||||
|
||||
**Validates: Requirements 1.1, 3.6**
|
||||
"""
|
||||
|
||||
@given(
|
||||
# 生成小的 CIDR 范围以避免测试超时
|
||||
network_prefix=st.integers(min_value=1, max_value=254),
|
||||
cidr_suffix=st.integers(min_value=28, max_value=30) # /28 = 16 IPs, /30 = 4 IPs
|
||||
)
|
||||
@settings(max_examples=50, deadline=None)
|
||||
def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
|
||||
"""
|
||||
Property 7: CIDR Expansion Consistency
|
||||
|
||||
Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
|
||||
**Validates: Requirements 1.1, 3.6**
|
||||
|
||||
For any CIDR string, all Providers should expand it to the same IP list.
|
||||
"""
|
||||
cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"
|
||||
|
||||
# 计算预期的 IP 列表
|
||||
network = IPv4Network(cidr, strict=False)
|
||||
# 排除网络地址和广播地址
|
||||
expected_ips = [str(ip) for ip in network.hosts()]
|
||||
|
||||
# 如果 CIDR 太小(/31 或 /32),使用所有地址
|
||||
if not expected_ips:
|
||||
expected_ips = [str(ip) for ip in network]
|
||||
|
||||
# ListTargetProvider
|
||||
list_provider = ListTargetProvider(targets=[cidr])
|
||||
list_result = list(list_provider.iter_hosts())
|
||||
|
||||
# PipelineTargetProvider
|
||||
stage_output = StageOutput(hosts=[cidr])
|
||||
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
pipeline_result = list(pipeline_provider.iter_hosts())
|
||||
|
||||
# 验证:所有 Provider 展开的结果应该一致
|
||||
assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
|
||||
assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
|
||||
assert list_result == pipeline_result, f"Providers produce different results for {cidr}"
|
||||
|
||||
def test_cidr_expansion_with_multiple_cidrs(self):
|
||||
"""测试多个 CIDR 的展开一致性"""
|
||||
cidrs = ["192.168.1.0/30", "10.0.0.0/30"]
|
||||
|
||||
# 计算预期结果
|
||||
expected_ips = []
|
||||
for cidr in cidrs:
|
||||
network = IPv4Network(cidr, strict=False)
|
||||
expected_ips.extend([str(ip) for ip in network.hosts()])
|
||||
|
||||
# ListTargetProvider
|
||||
list_provider = ListTargetProvider(targets=cidrs)
|
||||
list_result = list(list_provider.iter_hosts())
|
||||
|
||||
# PipelineTargetProvider
|
||||
stage_output = StageOutput(hosts=cidrs)
|
||||
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
pipeline_result = list(pipeline_provider.iter_hosts())
|
||||
|
||||
# 验证
|
||||
assert list_result == expected_ips
|
||||
assert pipeline_result == expected_ips
|
||||
assert list_result == pipeline_result
|
||||
|
||||
def test_mixed_hosts_and_cidrs(self):
|
||||
"""测试混合主机和 CIDR 的处理"""
|
||||
targets = ["example.com", "192.168.1.0/30", "test.com"]
|
||||
|
||||
# 计算预期结果
|
||||
network = IPv4Network("192.168.1.0/30", strict=False)
|
||||
cidr_ips = [str(ip) for ip in network.hosts()]
|
||||
expected = ["example.com"] + cidr_ips + ["test.com"]
|
||||
|
||||
# ListTargetProvider
|
||||
list_provider = ListTargetProvider(targets=targets)
|
||||
list_result = list(list_provider.iter_hosts())
|
||||
|
||||
# 验证
|
||||
assert list_result == expected
|
||||
158
backend/apps/scan/providers/tests/test_database_provider.py
Normal file
158
backend/apps/scan/providers/tests/test_database_provider.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
DatabaseTargetProvider 属性测试
|
||||
|
||||
Property 7: DatabaseTargetProvider Blacklist Application
|
||||
*For any* 带有黑名单规则的 target_id,DatabaseTargetProvider 的 iter_hosts() 和 iter_urls()
|
||||
应该过滤掉匹配黑名单规则的目标。
|
||||
|
||||
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from hypothesis import given, strategies as st, settings
|
||||
|
||||
from apps.scan.providers.database_provider import DatabaseTargetProvider
|
||||
from apps.scan.providers.base import ProviderContext
|
||||
|
||||
|
||||
# 生成有效域名的策略
|
||||
def valid_domain_strategy():
|
||||
"""生成有效的域名"""
|
||||
label = st.text(
|
||||
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
|
||||
min_size=2,
|
||||
max_size=10
|
||||
)
|
||||
return st.builds(
|
||||
lambda a, b, c: f"{a}.{b}.{c}",
|
||||
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
|
||||
)
|
||||
|
||||
|
||||
class MockBlacklistFilter:
|
||||
"""模拟黑名单过滤器"""
|
||||
|
||||
def __init__(self, blocked_patterns: list):
|
||||
self.blocked_patterns = blocked_patterns
|
||||
|
||||
def is_allowed(self, target: str) -> bool:
|
||||
"""检查目标是否被允许(不在黑名单中)"""
|
||||
for pattern in self.blocked_patterns:
|
||||
if pattern in target:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class TestDatabaseTargetProviderProperties:
|
||||
"""DatabaseTargetProvider 属性测试类"""
|
||||
|
||||
@given(
|
||||
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
|
||||
blocked_keyword=st.text(
|
||||
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
|
||||
min_size=2,
|
||||
max_size=5
|
||||
)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
|
||||
"""
|
||||
Property 7: DatabaseTargetProvider Blacklist Application (hosts)
|
||||
|
||||
Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
|
||||
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
|
||||
|
||||
For any set of hosts and a blacklist keyword, the provider should filter out
|
||||
all hosts containing the blocked keyword.
|
||||
"""
|
||||
# 创建模拟的黑名单过滤器
|
||||
mock_filter = MockBlacklistFilter([blocked_keyword])
|
||||
|
||||
# 创建 provider 并注入模拟的黑名单过滤器
|
||||
provider = DatabaseTargetProvider(target_id=1)
|
||||
provider._blacklist_filter = mock_filter
|
||||
|
||||
# 模拟 Target 和 SubdomainService
|
||||
mock_target = MagicMock()
|
||||
mock_target.type = 'domain'
|
||||
mock_target.name = hosts[0] if hosts else 'example.com'
|
||||
|
||||
with patch('apps.targets.services.TargetService') as mock_target_service, \
|
||||
patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
|
||||
|
||||
mock_target_service.return_value.get_target.return_value = mock_target
|
||||
mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])
|
||||
|
||||
# 获取结果
|
||||
result = list(provider.iter_hosts())
|
||||
|
||||
# 验证:所有结果都不包含被阻止的关键词
|
||||
for host in result:
|
||||
assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"
|
||||
|
||||
# 验证:所有不包含关键词的主机都应该在结果中
|
||||
if hosts:
|
||||
all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
|
||||
expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
|
||||
else:
|
||||
expected_allowed = []
|
||||
|
||||
assert set(result) == set(expected_allowed)
|
||||
|
||||
|
||||
class TestDatabaseTargetProviderUnit:
|
||||
"""DatabaseTargetProvider 单元测试类"""
|
||||
|
||||
def test_target_id_in_context(self):
|
||||
"""测试 target_id 正确设置到上下文中"""
|
||||
provider = DatabaseTargetProvider(target_id=123)
|
||||
assert provider.target_id == 123
|
||||
assert provider.context.target_id == 123
|
||||
|
||||
def test_context_propagation(self):
|
||||
"""测试上下文传递"""
|
||||
ctx = ProviderContext(scan_id=789)
|
||||
provider = DatabaseTargetProvider(target_id=123, context=ctx)
|
||||
|
||||
assert provider.target_id == 123 # target_id 被覆盖
|
||||
assert provider.scan_id == 789
|
||||
|
||||
def test_blacklist_filter_lazy_loading(self):
|
||||
"""测试黑名单过滤器延迟加载"""
|
||||
provider = DatabaseTargetProvider(target_id=123)
|
||||
|
||||
# 初始时 _blacklist_filter 为 None
|
||||
assert provider._blacklist_filter is None
|
||||
|
||||
# 模拟 BlacklistService
|
||||
with patch('apps.common.services.BlacklistService') as mock_service, \
|
||||
patch('apps.common.utils.BlacklistFilter') as mock_filter_class:
|
||||
|
||||
mock_service.return_value.get_rules.return_value = []
|
||||
mock_filter_instance = MagicMock()
|
||||
mock_filter_class.return_value = mock_filter_instance
|
||||
|
||||
# 第一次调用
|
||||
result1 = provider.get_blacklist_filter()
|
||||
assert result1 == mock_filter_instance
|
||||
|
||||
# 第二次调用应该返回缓存的实例
|
||||
result2 = provider.get_blacklist_filter()
|
||||
assert result2 == mock_filter_instance
|
||||
|
||||
# BlacklistService 只应该被调用一次
|
||||
mock_service.return_value.get_rules.assert_called_once_with(123)
|
||||
|
||||
def test_nonexistent_target_returns_empty(self):
|
||||
"""测试不存在的 target 返回空迭代器"""
|
||||
provider = DatabaseTargetProvider(target_id=99999)
|
||||
|
||||
with patch('apps.targets.services.TargetService') as mock_service, \
|
||||
patch('apps.common.services.BlacklistService') as mock_blacklist_service:
|
||||
|
||||
mock_service.return_value.get_target.return_value = None
|
||||
mock_blacklist_service.return_value.get_rules.return_value = []
|
||||
|
||||
result = list(provider.iter_hosts())
|
||||
assert result == []
|
||||
152
backend/apps/scan/providers/tests/test_list_provider.py
Normal file
152
backend/apps/scan/providers/tests/test_list_provider.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
ListTargetProvider 属性测试
|
||||
|
||||
Property 1: ListTargetProvider Round-Trip
|
||||
*For any* 主机列表和 URL 列表,创建 ListTargetProvider 后迭代 iter_hosts() 和 iter_urls()
|
||||
应该返回与输入相同的元素(顺序相同)。
|
||||
|
||||
**Validates: Requirements 3.1, 3.2**
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, strategies as st, settings, assume
|
||||
|
||||
from apps.scan.providers.list_provider import ListTargetProvider
|
||||
from apps.scan.providers.base import ProviderContext
|
||||
|
||||
|
||||
# 生成有效域名的策略
|
||||
def valid_domain_strategy():
|
||||
"""生成有效的域名"""
|
||||
# 生成简单的域名格式: subdomain.domain.tld
|
||||
label = st.text(
|
||||
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
|
||||
min_size=2,
|
||||
max_size=10
|
||||
)
|
||||
return st.builds(
|
||||
lambda a, b, c: f"{a}.{b}.{c}",
|
||||
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
|
||||
)
|
||||
|
||||
# 生成有效 IP 地址的策略
|
||||
def valid_ip_strategy():
|
||||
"""生成有效的 IPv4 地址"""
|
||||
octet = st.integers(min_value=1, max_value=254)
|
||||
return st.builds(
|
||||
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
|
||||
octet, octet, octet, octet
|
||||
)
|
||||
|
||||
# 组合策略:域名或 IP
|
||||
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
|
||||
|
||||
# 生成有效 URL 的策略
|
||||
def valid_url_strategy():
|
||||
"""生成有效的 URL"""
|
||||
domain = valid_domain_strategy()
|
||||
return st.builds(
|
||||
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
|
||||
domain,
|
||||
st.one_of(
|
||||
st.just(""),
|
||||
st.text(
|
||||
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
|
||||
min_size=1,
|
||||
max_size=10
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
url_strategy = valid_url_strategy()
|
||||
|
||||
|
||||
class TestListTargetProviderProperties:
|
||||
"""ListTargetProvider 属性测试类"""
|
||||
|
||||
@given(hosts=st.lists(host_strategy, max_size=50))
|
||||
@settings(max_examples=100)
|
||||
def test_property_1_hosts_round_trip(self, hosts):
|
||||
"""
|
||||
Property 1: ListTargetProvider Round-Trip (hosts)
|
||||
|
||||
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
|
||||
**Validates: Requirements 3.1, 3.2**
|
||||
|
||||
For any host list, creating a ListTargetProvider and iterating iter_hosts()
|
||||
should return the same elements in the same order.
|
||||
"""
|
||||
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
|
||||
provider = ListTargetProvider(targets=hosts)
|
||||
result = list(provider.iter_hosts())
|
||||
assert result == hosts
|
||||
|
||||
@given(urls=st.lists(url_strategy, max_size=50))
|
||||
@settings(max_examples=100)
|
||||
def test_property_1_urls_round_trip(self, urls):
|
||||
"""
|
||||
Property 1: ListTargetProvider Round-Trip (urls)
|
||||
|
||||
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
|
||||
**Validates: Requirements 3.1, 3.2**
|
||||
|
||||
For any URL list, creating a ListTargetProvider and iterating iter_urls()
|
||||
should return the same elements in the same order.
|
||||
"""
|
||||
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
|
||||
provider = ListTargetProvider(targets=urls)
|
||||
result = list(provider.iter_urls())
|
||||
assert result == urls
|
||||
|
||||
@given(
|
||||
hosts=st.lists(host_strategy, max_size=30),
|
||||
urls=st.lists(url_strategy, max_size=30)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_property_1_combined_round_trip(self, hosts, urls):
|
||||
"""
|
||||
Property 1: ListTargetProvider Round-Trip (combined)
|
||||
|
||||
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
|
||||
**Validates: Requirements 3.1, 3.2**
|
||||
|
||||
For any combination of hosts and URLs, both should round-trip correctly.
|
||||
"""
|
||||
# 合并 hosts 和 urls,ListTargetProvider 会自动分类
|
||||
combined = hosts + urls
|
||||
provider = ListTargetProvider(targets=combined)
|
||||
|
||||
hosts_result = list(provider.iter_hosts())
|
||||
urls_result = list(provider.iter_urls())
|
||||
|
||||
assert hosts_result == hosts
|
||||
assert urls_result == urls
|
||||
|
||||
|
||||
class TestListTargetProviderUnit:
|
||||
"""ListTargetProvider 单元测试类"""
|
||||
|
||||
def test_empty_lists(self):
|
||||
"""测试空列表返回空迭代器 - Requirements 3.5"""
|
||||
provider = ListTargetProvider()
|
||||
assert list(provider.iter_hosts()) == []
|
||||
assert list(provider.iter_urls()) == []
|
||||
|
||||
def test_blacklist_filter_returns_none(self):
|
||||
"""测试黑名单过滤器返回 None - Requirements 3.4"""
|
||||
provider = ListTargetProvider(targets=["example.com"])
|
||||
assert provider.get_blacklist_filter() is None
|
||||
|
||||
def test_target_id_association(self):
|
||||
"""测试 target_id 关联 - Requirements 3.3"""
|
||||
ctx = ProviderContext(target_id=123)
|
||||
provider = ListTargetProvider(targets=["example.com"], context=ctx)
|
||||
assert provider.target_id == 123
|
||||
|
||||
def test_context_propagation(self):
|
||||
"""测试上下文传递"""
|
||||
ctx = ProviderContext(target_id=456, scan_id=789)
|
||||
provider = ListTargetProvider(targets=["example.com"], context=ctx)
|
||||
|
||||
assert provider.target_id == 456
|
||||
assert provider.scan_id == 789
|
||||
180
backend/apps/scan/providers/tests/test_pipeline_provider.py
Normal file
180
backend/apps/scan/providers/tests/test_pipeline_provider.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
PipelineTargetProvider 属性测试
|
||||
|
||||
Property 3: PipelineTargetProvider Round-Trip
|
||||
*For any* StageOutput 对象,PipelineTargetProvider 的 iter_hosts() 和 iter_urls()
|
||||
应该返回与 StageOutput 中 hosts 和 urls 列表相同的元素。
|
||||
|
||||
**Validates: Requirements 5.1, 5.2**
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from hypothesis import given, strategies as st, settings
|
||||
|
||||
from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
|
||||
from apps.scan.providers.base import ProviderContext
|
||||
|
||||
|
||||
# 生成有效域名的策略
|
||||
def valid_domain_strategy():
|
||||
"""生成有效的域名"""
|
||||
label = st.text(
|
||||
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
|
||||
min_size=2,
|
||||
max_size=10
|
||||
)
|
||||
return st.builds(
|
||||
lambda a, b, c: f"{a}.{b}.{c}",
|
||||
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
|
||||
)
|
||||
|
||||
# 生成有效 IP 地址的策略
|
||||
def valid_ip_strategy():
|
||||
"""生成有效的 IPv4 地址"""
|
||||
octet = st.integers(min_value=1, max_value=254)
|
||||
return st.builds(
|
||||
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
|
||||
octet, octet, octet, octet
|
||||
)
|
||||
|
||||
# 组合策略:域名或 IP
|
||||
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
|
||||
|
||||
# 生成有效 URL 的策略
|
||||
def valid_url_strategy():
|
||||
"""生成有效的 URL"""
|
||||
domain = valid_domain_strategy()
|
||||
return st.builds(
|
||||
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
|
||||
domain,
|
||||
st.one_of(
|
||||
st.just(""),
|
||||
st.text(
|
||||
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
|
||||
min_size=1,
|
||||
max_size=10
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
url_strategy = valid_url_strategy()
|
||||
|
||||
|
||||
class TestPipelineTargetProviderProperties:
|
||||
"""PipelineTargetProvider 属性测试类"""
|
||||
|
||||
@given(hosts=st.lists(host_strategy, max_size=50))
|
||||
@settings(max_examples=100)
|
||||
def test_property_3_hosts_round_trip(self, hosts):
|
||||
"""
|
||||
Property 3: PipelineTargetProvider Round-Trip (hosts)
|
||||
|
||||
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
|
||||
**Validates: Requirements 5.1, 5.2**
|
||||
|
||||
For any StageOutput with hosts, PipelineTargetProvider should return
|
||||
the same hosts in the same order.
|
||||
"""
|
||||
stage_output = StageOutput(hosts=hosts)
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
result = list(provider.iter_hosts())
|
||||
assert result == hosts
|
||||
|
||||
@given(urls=st.lists(url_strategy, max_size=50))
|
||||
@settings(max_examples=100)
|
||||
def test_property_3_urls_round_trip(self, urls):
|
||||
"""
|
||||
Property 3: PipelineTargetProvider Round-Trip (urls)
|
||||
|
||||
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
|
||||
**Validates: Requirements 5.1, 5.2**
|
||||
|
||||
For any StageOutput with urls, PipelineTargetProvider should return
|
||||
the same urls in the same order.
|
||||
"""
|
||||
stage_output = StageOutput(urls=urls)
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
result = list(provider.iter_urls())
|
||||
assert result == urls
|
||||
|
||||
@given(
|
||||
hosts=st.lists(host_strategy, max_size=30),
|
||||
urls=st.lists(url_strategy, max_size=30)
|
||||
)
|
||||
@settings(max_examples=100)
|
||||
def test_property_3_combined_round_trip(self, hosts, urls):
|
||||
"""
|
||||
Property 3: PipelineTargetProvider Round-Trip (combined)
|
||||
|
||||
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
|
||||
**Validates: Requirements 5.1, 5.2**
|
||||
|
||||
For any StageOutput with both hosts and urls, both should round-trip correctly.
|
||||
"""
|
||||
stage_output = StageOutput(hosts=hosts, urls=urls)
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
|
||||
hosts_result = list(provider.iter_hosts())
|
||||
urls_result = list(provider.iter_urls())
|
||||
|
||||
assert hosts_result == hosts
|
||||
assert urls_result == urls
|
||||
|
||||
|
||||
class TestPipelineTargetProviderUnit:
|
||||
"""PipelineTargetProvider 单元测试类"""
|
||||
|
||||
def test_empty_stage_output(self):
|
||||
"""测试空 StageOutput 返回空迭代器 - Requirements 5.5"""
|
||||
stage_output = StageOutput()
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
|
||||
assert list(provider.iter_hosts()) == []
|
||||
assert list(provider.iter_urls()) == []
|
||||
|
||||
def test_blacklist_filter_returns_none(self):
|
||||
"""测试黑名单过滤器返回 None - Requirements 5.3"""
|
||||
stage_output = StageOutput(hosts=["example.com"])
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
assert provider.get_blacklist_filter() is None
|
||||
|
||||
def test_target_id_association(self):
|
||||
"""测试 target_id 关联 - Requirements 5.4"""
|
||||
stage_output = StageOutput(hosts=["example.com"])
|
||||
provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
|
||||
assert provider.target_id == 123
|
||||
|
||||
def test_context_propagation(self):
|
||||
"""测试上下文传递"""
|
||||
ctx = ProviderContext(target_id=456, scan_id=789)
|
||||
stage_output = StageOutput(hosts=["example.com"])
|
||||
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
|
||||
|
||||
assert provider.target_id == 456
|
||||
assert provider.scan_id == 789
|
||||
|
||||
def test_previous_output_property(self):
|
||||
"""测试 previous_output 属性"""
|
||||
stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
|
||||
assert provider.previous_output is stage_output
|
||||
assert provider.previous_output.hosts == ["example.com"]
|
||||
assert provider.previous_output.urls == ["https://example.com"]
|
||||
|
||||
def test_stage_output_with_metadata(self):
|
||||
"""测试带元数据的 StageOutput"""
|
||||
stage_output = StageOutput(
|
||||
hosts=["example.com"],
|
||||
urls=["https://example.com"],
|
||||
new_targets=["new.example.com"],
|
||||
stats={"count": 1},
|
||||
success=True,
|
||||
error=None
|
||||
)
|
||||
provider = PipelineTargetProvider(previous_output=stage_output)
|
||||
|
||||
assert list(provider.iter_hosts()) == ["example.com"]
|
||||
assert list(provider.iter_urls()) == ["https://example.com"]
|
||||
assert provider.previous_output.new_targets == ["new.example.com"]
|
||||
assert provider.previous_output.stats == {"count": 1}
|
||||
191
backend/apps/scan/providers/tests/test_snapshot_provider.py
Normal file
191
backend/apps/scan/providers/tests/test_snapshot_provider.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
SnapshotTargetProvider 单元测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from apps.scan.providers import SnapshotTargetProvider, ProviderContext
|
||||
|
||||
|
||||
class TestSnapshotTargetProvider:
|
||||
"""SnapshotTargetProvider 测试类"""
|
||||
|
||||
def test_init_with_scan_id_and_type(self):
|
||||
"""测试初始化"""
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain"
|
||||
)
|
||||
|
||||
assert provider.scan_id == 100
|
||||
assert provider.snapshot_type == "subdomain"
|
||||
assert provider.target_id is None # 默认 context
|
||||
|
||||
def test_init_with_context(self):
|
||||
"""测试带 context 初始化"""
|
||||
ctx = ProviderContext(target_id=1, scan_id=100)
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain",
|
||||
context=ctx
|
||||
)
|
||||
|
||||
assert provider.scan_id == 100
|
||||
assert provider.target_id == 1
|
||||
assert provider.snapshot_type == "subdomain"
|
||||
|
||||
@patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
|
||||
def test_iter_hosts_subdomain(self, mock_service_class):
|
||||
"""测试从子域名快照迭代主机"""
|
||||
# Mock service
|
||||
mock_service = Mock()
|
||||
mock_service.iter_subdomain_names_by_scan.return_value = iter([
|
||||
"a.example.com",
|
||||
"b.example.com"
|
||||
])
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# 创建 provider
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain"
|
||||
)
|
||||
|
||||
# 迭代主机
|
||||
hosts = list(provider.iter_hosts())
|
||||
|
||||
assert hosts == ["a.example.com", "b.example.com"]
|
||||
mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
|
||||
scan_id=100,
|
||||
chunk_size=1000
|
||||
)
|
||||
|
||||
@patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
|
||||
def test_iter_hosts_host_port(self, mock_service_class):
|
||||
"""测试从主机端口映射快照迭代主机"""
|
||||
# Mock queryset
|
||||
mock_mapping1 = Mock()
|
||||
mock_mapping1.host = "example.com"
|
||||
mock_mapping1.port = 80
|
||||
|
||||
mock_mapping2 = Mock()
|
||||
mock_mapping2.host = "example.com"
|
||||
mock_mapping2.port = 443
|
||||
|
||||
mock_queryset = Mock()
|
||||
mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])
|
||||
|
||||
# Mock service
|
||||
mock_service = Mock()
|
||||
mock_service.get_by_scan.return_value = mock_queryset
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# 创建 provider
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="host_port"
|
||||
)
|
||||
|
||||
# 迭代主机
|
||||
hosts = list(provider.iter_hosts())
|
||||
|
||||
assert hosts == ["example.com:80", "example.com:443"]
|
||||
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
|
||||
|
||||
@patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
|
||||
def test_iter_urls_website(self, mock_service_class):
|
||||
"""测试从网站快照迭代 URL"""
|
||||
# Mock service
|
||||
mock_service = Mock()
|
||||
mock_service.iter_website_urls_by_scan.return_value = iter([
|
||||
"http://example.com",
|
||||
"https://example.com"
|
||||
])
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# 创建 provider
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="website"
|
||||
)
|
||||
|
||||
# 迭代 URL
|
||||
urls = list(provider.iter_urls())
|
||||
|
||||
assert urls == ["http://example.com", "https://example.com"]
|
||||
mock_service.iter_website_urls_by_scan.assert_called_once_with(
|
||||
scan_id=100,
|
||||
chunk_size=1000
|
||||
)
|
||||
|
||||
@patch('apps.asset.services.snapshot.EndpointSnapshotsService')
|
||||
def test_iter_urls_endpoint(self, mock_service_class):
|
||||
"""测试从端点快照迭代 URL"""
|
||||
# Mock queryset
|
||||
mock_endpoint1 = Mock()
|
||||
mock_endpoint1.url = "http://example.com/api/v1"
|
||||
|
||||
mock_endpoint2 = Mock()
|
||||
mock_endpoint2.url = "http://example.com/api/v2"
|
||||
|
||||
mock_queryset = Mock()
|
||||
mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])
|
||||
|
||||
# Mock service
|
||||
mock_service = Mock()
|
||||
mock_service.get_by_scan.return_value = mock_queryset
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# 创建 provider
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="endpoint"
|
||||
)
|
||||
|
||||
# 迭代 URL
|
||||
urls = list(provider.iter_urls())
|
||||
|
||||
assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
|
||||
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
|
||||
|
||||
def test_iter_hosts_unsupported_type(self):
|
||||
"""测试不支持的快照类型(iter_hosts)"""
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="website" # website 不支持 iter_hosts
|
||||
)
|
||||
|
||||
hosts = list(provider.iter_hosts())
|
||||
assert hosts == []
|
||||
|
||||
def test_iter_urls_unsupported_type(self):
|
||||
"""测试不支持的快照类型(iter_urls)"""
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain" # subdomain 不支持 iter_urls
|
||||
)
|
||||
|
||||
urls = list(provider.iter_urls())
|
||||
assert urls == []
|
||||
|
||||
def test_get_blacklist_filter(self):
|
||||
"""测试黑名单过滤器(快照模式不使用黑名单)"""
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100,
|
||||
snapshot_type="subdomain"
|
||||
)
|
||||
|
||||
assert provider.get_blacklist_filter() is None
|
||||
|
||||
def test_context_propagation(self):
|
||||
"""测试上下文传递"""
|
||||
ctx = ProviderContext(target_id=456, scan_id=789)
|
||||
provider = SnapshotTargetProvider(
|
||||
scan_id=100, # 会被 context 覆盖
|
||||
snapshot_type="subdomain",
|
||||
context=ctx
|
||||
)
|
||||
|
||||
assert provider.target_id == 456
|
||||
assert provider.scan_id == 100 # scan_id 在 __init__ 中被设置
|
||||
@@ -1,36 +1,48 @@
|
||||
"""
|
||||
导出站点 URL 到 TXT 文件的 Task
|
||||
|
||||
使用 export_urls_with_fallback 用例函数处理回退链逻辑
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
数据源: WebSite.url → Default
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
from apps.scan.services.target_export_service import (
|
||||
export_urls_with_fallback,
|
||||
DataSource,
|
||||
)
|
||||
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(name="export_sites")
|
||||
def export_sites_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_id: Optional[int] = None,
|
||||
output_file: str = "",
|
||||
provider: Optional[TargetProvider] = None,
|
||||
batch_size: int = 1000,
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的所有站点 URL 到 TXT 文件
|
||||
|
||||
数据源优先级(回退链):
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):传入 target_id,从数据库导出
|
||||
2. Provider 模式:传入 provider,从任意数据源导出
|
||||
|
||||
数据源优先级(回退链,仅传统模式):
|
||||
1. WebSite 表 - 站点级别 URL
|
||||
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_id: 目标 ID(传统模式,向后兼容)
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
provider: TargetProvider 实例(新模式)
|
||||
batch_size: 每次读取的批次大小,默认 1000
|
||||
|
||||
Returns:
|
||||
@@ -44,6 +56,17 @@ def export_sites_task(
|
||||
ValueError: 参数错误
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
# 参数验证:至少提供一个
|
||||
if target_id is None and provider is None:
|
||||
raise ValueError("必须提供 target_id 或 provider 参数之一")
|
||||
|
||||
# Provider 模式:使用 TargetProvider 导出
|
||||
if provider is not None:
|
||||
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
|
||||
return _export_with_provider(output_file, provider)
|
||||
|
||||
# 传统模式:使用 export_urls_with_fallback
|
||||
logger.info("使用传统模式 - Target ID: %d", target_id)
|
||||
result = export_urls_with_fallback(
|
||||
target_id=target_id,
|
||||
output_file=output_file,
|
||||
@@ -62,3 +85,32 @@ def export_sites_task(
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count'],
|
||||
}
|
||||
|
||||
|
||||
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
|
||||
"""使用 Provider 导出 URL"""
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_count = 0
|
||||
blacklist_filter = provider.get_blacklist_filter()
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in provider.iter_urls():
|
||||
# 应用黑名单过滤(如果有)
|
||||
if blacklist_filter and not blacklist_filter.is_allowed(url):
|
||||
continue
|
||||
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 1000 == 0:
|
||||
logger.info("已导出 %d 个 URL...", total_count)
|
||||
|
||||
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count,
|
||||
}
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
"""
|
||||
导出 URL 任务
|
||||
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
用于指纹识别前导出目标下的 URL 到文件
|
||||
使用 export_urls_with_fallback 用例函数处理回退链逻辑
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from prefect import task
|
||||
|
||||
@@ -13,33 +18,51 @@ from apps.scan.services.target_export_service import (
|
||||
export_urls_with_fallback,
|
||||
DataSource,
|
||||
)
|
||||
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(name="export_urls_for_fingerprint")
|
||||
def export_urls_for_fingerprint_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_id: Optional[int] = None,
|
||||
output_file: str = "",
|
||||
source: str = 'website', # 保留参数,兼容旧调用(实际值由回退链决定)
|
||||
provider: Optional[TargetProvider] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的 URL 到文件(用于指纹识别)
|
||||
|
||||
数据源优先级(回退链):
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):传入 target_id,从数据库导出
|
||||
2. Provider 模式:传入 provider,从任意数据源导出
|
||||
|
||||
数据源优先级(回退链,仅传统模式):
|
||||
1. WebSite 表 - 站点级别 URL
|
||||
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_id: 目标 ID(传统模式,向后兼容)
|
||||
output_file: 输出文件路径
|
||||
source: 数据源类型(保留参数,兼容旧调用,实际值由回退链决定)
|
||||
provider: TargetProvider 实例(新模式)
|
||||
batch_size: 批量读取大小
|
||||
|
||||
Returns:
|
||||
dict: {'output_file': str, 'total_count': int, 'source': str}
|
||||
"""
|
||||
# 参数验证:至少提供一个
|
||||
if target_id is None and provider is None:
|
||||
raise ValueError("必须提供 target_id 或 provider 参数之一")
|
||||
|
||||
# Provider 模式:使用 TargetProvider 导出
|
||||
if provider is not None:
|
||||
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
|
||||
return _export_with_provider(output_file, provider)
|
||||
|
||||
# 传统模式:使用 export_urls_with_fallback
|
||||
logger.info("使用传统模式 - Target ID: %d", target_id)
|
||||
result = export_urls_with_fallback(
|
||||
target_id=target_id,
|
||||
output_file=output_file,
|
||||
@@ -58,3 +81,32 @@ def export_urls_for_fingerprint_task(
|
||||
'total_count': result['total_count'],
|
||||
'source': result['source'],
|
||||
}
|
||||
|
||||
|
||||
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
|
||||
"""使用 Provider 导出 URL"""
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_count = 0
|
||||
blacklist_filter = provider.get_blacklist_filter()
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in provider.iter_urls():
|
||||
# 应用黑名单过滤(如果有)
|
||||
if blacklist_filter and not blacklist_filter.is_allowed(url):
|
||||
continue
|
||||
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 1000 == 0:
|
||||
logger.info("已导出 %d 个 URL...", total_count)
|
||||
|
||||
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
|
||||
|
||||
return {
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count,
|
||||
'source': 'provider',
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
导出主机列表到 TXT 文件的 Task
|
||||
|
||||
使用 TargetExportService.export_hosts() 统一处理导出逻辑
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
根据 Target 类型决定导出内容:
|
||||
- DOMAIN: 从 Subdomain 表导出子域名
|
||||
@@ -9,30 +11,39 @@
|
||||
- CIDR: 展开 CIDR 范围内的所有 IP
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
from apps.scan.services.target_export_service import create_export_service
|
||||
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(name="export_hosts")
|
||||
def export_hosts_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_id: Optional[int] = None,
|
||||
provider: Optional[TargetProvider] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出主机列表到 TXT 文件
|
||||
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):传入 target_id,从数据库导出
|
||||
2. Provider 模式:传入 provider,从任意数据源导出
|
||||
|
||||
根据 Target 类型自动决定导出内容:
|
||||
- DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名)
|
||||
- IP: 直接写入 target.name(单个 IP)
|
||||
- CIDR: 展开 CIDR 范围内的所有可用 IP
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
target_id: 目标 ID(传统模式,向后兼容)
|
||||
provider: TargetProvider 实例(新模式)
|
||||
batch_size: 每次读取的批次大小,默认 1000(仅对 DOMAIN 类型有效)
|
||||
|
||||
Returns:
|
||||
@@ -40,26 +51,60 @@ def export_hosts_task(
|
||||
'success': bool,
|
||||
'output_file': str,
|
||||
'total_count': int,
|
||||
'target_type': str
|
||||
'target_type': str # 仅传统模式返回
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValueError: Target 不存在
|
||||
ValueError: 参数错误(target_id 和 provider 都未提供)
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
# 使用工厂函数创建导出服务
|
||||
export_service = create_export_service(target_id)
|
||||
# 参数验证:至少提供一个
|
||||
if target_id is None and provider is None:
|
||||
raise ValueError("必须提供 target_id 或 provider 参数之一")
|
||||
|
||||
result = export_service.export_hosts(
|
||||
target_id=target_id,
|
||||
output_path=output_file,
|
||||
batch_size=batch_size
|
||||
)
|
||||
# 向后兼容:如果没有提供 provider,使用 target_id 创建 DatabaseTargetProvider
|
||||
if provider is None:
|
||||
logger.info("使用传统模式 - Target ID: %d", target_id)
|
||||
provider = DatabaseTargetProvider(target_id=target_id)
|
||||
use_legacy_mode = True
|
||||
else:
|
||||
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
|
||||
use_legacy_mode = False
|
||||
|
||||
# 保持返回值格式不变(向后兼容)
|
||||
return {
|
||||
'success': result['success'],
|
||||
'output_file': result['output_file'],
|
||||
'total_count': result['total_count'],
|
||||
'target_type': result['target_type']
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用 Provider 导出主机列表
|
||||
total_count = 0
|
||||
blacklist_filter = provider.get_blacklist_filter()
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for host in provider.iter_hosts():
|
||||
# 应用黑名单过滤(如果有)
|
||||
if blacklist_filter and not blacklist_filter.is_allowed(host):
|
||||
continue
|
||||
|
||||
f.write(f"{host}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 1000 == 0:
|
||||
logger.info("已导出 %d 个主机...", total_count)
|
||||
|
||||
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
|
||||
|
||||
# 构建返回值
|
||||
result = {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_count': total_count,
|
||||
}
|
||||
|
||||
# 传统模式:保持返回值格式不变(向后兼容)
|
||||
if use_legacy_mode:
|
||||
# 获取 target_type(仅传统模式需要)
|
||||
from apps.targets.services import TargetService
|
||||
target = TargetService().get_target(target_id)
|
||||
result['target_type'] = target.type if target else 'unknown'
|
||||
|
||||
return result
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
导出站点URL到文件的Task
|
||||
|
||||
直接使用 HostPortMapping 表查询 host+port 组合,拼接成URL格式写入文件
|
||||
使用 TargetExportService.generate_default_urls() 处理默认值回退逻辑
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
特殊逻辑:
|
||||
- 80 端口:只生成 HTTP URL(省略端口号)
|
||||
@@ -10,6 +11,7 @@
|
||||
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
@@ -17,6 +19,7 @@ from apps.asset.services import HostPortMappingService
|
||||
from apps.scan.services.target_export_service import create_export_service
|
||||
from apps.common.services import BlacklistService
|
||||
from apps.common.utils import BlacklistFilter
|
||||
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,26 +42,30 @@ def _generate_urls_from_port(host: str, port: int) -> list[str]:
|
||||
|
||||
@task(name="export_site_urls")
|
||||
def export_site_urls_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_id: Optional[int] = None,
|
||||
provider: Optional[TargetProvider] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出目标下的所有站点URL到文件(基于 HostPortMapping 表)
|
||||
导出目标下的所有站点URL到文件
|
||||
|
||||
数据源: HostPortMapping (host + port) → Default
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):传入 target_id,从 HostPortMapping 表导出
|
||||
2. Provider 模式:传入 provider,从任意数据源导出
|
||||
|
||||
特殊逻辑:
|
||||
传统模式特殊逻辑:
|
||||
- 80 端口:只生成 HTTP URL(省略端口号)
|
||||
- 443 端口:只生成 HTTPS URL(省略端口号)
|
||||
- 其他端口:生成 HTTP 和 HTTPS 两个URL(带端口号)
|
||||
|
||||
回退逻辑:
|
||||
回退逻辑(仅传统模式):
|
||||
- 如果 HostPortMapping 为空,使用 generate_default_urls() 生成默认 URL
|
||||
|
||||
Args:
|
||||
target_id: 目标ID
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
target_id: 目标ID(传统模式,向后兼容)
|
||||
provider: TargetProvider 实例(新模式)
|
||||
batch_size: 每次处理的批次大小
|
||||
|
||||
Returns:
|
||||
@@ -66,14 +73,62 @@ def export_site_urls_task(
|
||||
'success': bool,
|
||||
'output_file': str,
|
||||
'total_urls': int,
|
||||
'association_count': int, # 主机端口关联数量
|
||||
'source': str, # 数据来源: "host_port" | "default"
|
||||
'association_count': int, # 主机端口关联数量(仅传统模式)
|
||||
'source': str, # 数据来源: "host_port" | "default" | "provider"
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValueError: 参数错误
|
||||
IOError: 文件写入失败
|
||||
"""
|
||||
# 参数验证:至少提供一个
|
||||
if target_id is None and provider is None:
|
||||
raise ValueError("必须提供 target_id 或 provider 参数之一")
|
||||
|
||||
# 向后兼容:如果没有提供 provider,使用传统模式
|
||||
if provider is None:
|
||||
logger.info("使用传统模式 - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
return _export_site_urls_legacy(target_id, output_file, batch_size)
|
||||
|
||||
# Provider 模式
|
||||
logger.info("使用 Provider 模式 - Provider: %s, 输出文件: %s", type(provider).__name__, output_file)
|
||||
|
||||
# 确保输出目录存在
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用 Provider 导出 URL 列表
|
||||
total_urls = 0
|
||||
blacklist_filter = provider.get_blacklist_filter()
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in provider.iter_urls():
|
||||
# 应用黑名单过滤(如果有)
|
||||
if blacklist_filter and not blacklist_filter.is_allowed(url):
|
||||
continue
|
||||
|
||||
f.write(f"{url}\n")
|
||||
total_urls += 1
|
||||
|
||||
if total_urls % 1000 == 0:
|
||||
logger.info("已导出 %d 个URL...", total_urls)
|
||||
|
||||
logger.info("✓ URL导出完成 - 总数: %d, 文件: %s", total_urls, str(output_path))
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'output_file': str(output_path),
|
||||
'total_urls': total_urls,
|
||||
'source': 'provider',
|
||||
}
|
||||
|
||||
|
||||
def _export_site_urls_legacy(target_id: int, output_file: str, batch_size: int) -> dict:
|
||||
"""
|
||||
传统模式:从 HostPortMapping 表导出 URL
|
||||
|
||||
保持原有逻辑不变,确保向后兼容
|
||||
"""
|
||||
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
|
||||
|
||||
# 确保输出目录存在
|
||||
|
||||
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Task 向后兼容性测试
|
||||
|
||||
Property 8: Task Backward Compatibility
|
||||
*For any* 任务调用,当仅提供 target_id 参数时,任务应该创建 DatabaseTargetProvider
|
||||
并使用它进行数据访问,行为与改造前一致。
|
||||
|
||||
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
from hypothesis import given, strategies as st, settings
|
||||
|
||||
from apps.scan.tasks.port_scan.export_hosts_task import export_hosts_task
|
||||
from apps.scan.tasks.site_scan.export_site_urls_task import export_site_urls_task
|
||||
from apps.scan.providers import ListTargetProvider
|
||||
|
||||
|
||||
# 生成有效域名的策略
|
||||
def valid_domain_strategy():
|
||||
"""生成有效的域名"""
|
||||
label = st.text(
|
||||
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
|
||||
min_size=2,
|
||||
max_size=10
|
||||
)
|
||||
return st.builds(
|
||||
lambda a, b, c: f"{a}.{b}.{c}",
|
||||
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
|
||||
)
|
||||
|
||||
|
||||
class TestExportHostsTaskBackwardCompatibility:
|
||||
"""export_hosts_task 向后兼容性测试"""
|
||||
|
||||
@given(
|
||||
target_id=st.integers(min_value=1, max_value=1000),
|
||||
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=10)
|
||||
)
|
||||
@settings(max_examples=50, deadline=None)
|
||||
def test_property_8_legacy_mode_creates_database_provider(self, target_id, hosts):
|
||||
"""
|
||||
Property 8: Task Backward Compatibility (export_hosts_task)
|
||||
|
||||
Feature: scan-target-provider, Property 8: Task Backward Compatibility
|
||||
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
|
||||
|
||||
For any target_id, when calling export_hosts_task with only target_id,
|
||||
it should create a DatabaseTargetProvider and use it for data access.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
|
||||
output_file = f.name
|
||||
|
||||
try:
|
||||
# Mock Target 和 SubdomainService
|
||||
mock_target = MagicMock()
|
||||
mock_target.type = 'domain'
|
||||
mock_target.name = hosts[0]
|
||||
|
||||
with patch('apps.scan.tasks.port_scan.export_hosts_task.DatabaseTargetProvider') as mock_provider_class, \
|
||||
patch('apps.targets.services.TargetService') as mock_target_service:
|
||||
|
||||
# 创建 mock provider 实例
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.iter_hosts.return_value = iter(hosts)
|
||||
mock_provider.get_blacklist_filter.return_value = None
|
||||
mock_provider_class.return_value = mock_provider
|
||||
|
||||
# Mock TargetService
|
||||
mock_target_service.return_value.get_target.return_value = mock_target
|
||||
|
||||
# 调用任务(传统模式:只传 target_id)
|
||||
result = export_hosts_task(
|
||||
output_file=output_file,
|
||||
target_id=target_id
|
||||
)
|
||||
|
||||
# 验证:应该创建了 DatabaseTargetProvider
|
||||
mock_provider_class.assert_called_once_with(target_id=target_id)
|
||||
|
||||
# 验证:返回值包含必需字段
|
||||
assert result['success'] is True
|
||||
assert result['output_file'] == output_file
|
||||
assert result['total_count'] == len(hosts)
|
||||
assert 'target_type' in result # 传统模式应该返回 target_type
|
||||
|
||||
# 验证:文件内容正确
|
||||
with open(output_file, 'r') as f:
|
||||
lines = [line.strip() for line in f.readlines()]
|
||||
assert lines == hosts
|
||||
|
||||
finally:
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
|
||||
def test_legacy_mode_with_provider_parameter(self):
|
||||
"""测试当同时提供 target_id 和 provider 时,provider 优先"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
|
||||
output_file = f.name
|
||||
|
||||
try:
|
||||
hosts = ['example.com', 'test.com']
|
||||
provider = ListTargetProvider(targets=hosts)
|
||||
|
||||
# 调用任务(同时提供 target_id 和 provider)
|
||||
result = export_hosts_task(
|
||||
output_file=output_file,
|
||||
target_id=123, # 应该被忽略
|
||||
provider=provider
|
||||
)
|
||||
|
||||
# 验证:使用了 provider
|
||||
assert result['success'] is True
|
||||
assert result['total_count'] == len(hosts)
|
||||
assert 'target_type' not in result # Provider 模式不返回 target_type
|
||||
|
||||
# 验证:文件内容正确
|
||||
with open(output_file, 'r') as f:
|
||||
lines = [line.strip() for line in f.readlines()]
|
||||
assert lines == hosts
|
||||
|
||||
finally:
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
|
||||
def test_error_when_no_parameters(self):
|
||||
"""测试当 target_id 和 provider 都未提供时抛出错误"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
|
||||
output_file = f.name
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
|
||||
export_hosts_task(output_file=output_file)
|
||||
|
||||
finally:
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
|
||||
|
||||
class TestExportSiteUrlsTaskBackwardCompatibility:
|
||||
"""export_site_urls_task 向后兼容性测试"""
|
||||
|
||||
def test_property_8_legacy_mode_uses_traditional_logic(self):
|
||||
"""
|
||||
Property 8: Task Backward Compatibility (export_site_urls_task)
|
||||
|
||||
Feature: scan-target-provider, Property 8: Task Backward Compatibility
|
||||
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
|
||||
|
||||
When calling export_site_urls_task with only target_id,
|
||||
it should use the traditional logic (_export_site_urls_legacy).
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
|
||||
output_file = f.name
|
||||
|
||||
try:
|
||||
target_id = 123
|
||||
|
||||
# Mock HostPortMappingService
|
||||
mock_associations = [
|
||||
{'host': 'example.com', 'port': 80},
|
||||
{'host': 'test.com', 'port': 443},
|
||||
]
|
||||
|
||||
with patch('apps.scan.tasks.site_scan.export_site_urls_task.HostPortMappingService') as mock_service_class, \
|
||||
patch('apps.scan.tasks.site_scan.export_site_urls_task.BlacklistService') as mock_blacklist_service:
|
||||
|
||||
# Mock HostPortMappingService
|
||||
mock_service = MagicMock()
|
||||
mock_service.iter_host_port_by_target.return_value = iter(mock_associations)
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
# Mock BlacklistService
|
||||
mock_blacklist = MagicMock()
|
||||
mock_blacklist.get_rules.return_value = []
|
||||
mock_blacklist_service.return_value = mock_blacklist
|
||||
|
||||
# 调用任务(传统模式:只传 target_id)
|
||||
result = export_site_urls_task(
|
||||
output_file=output_file,
|
||||
target_id=target_id
|
||||
)
|
||||
|
||||
# 验证:返回值包含传统模式的字段
|
||||
assert result['success'] is True
|
||||
assert result['output_file'] == output_file
|
||||
assert result['total_urls'] == 2 # 80 端口生成 1 个 URL,443 端口生成 1 个 URL
|
||||
assert 'association_count' in result # 传统模式应该返回 association_count
|
||||
assert result['association_count'] == 2
|
||||
assert result['source'] == 'host_port'
|
||||
|
||||
# 验证:文件内容正确
|
||||
with open(output_file, 'r') as f:
|
||||
lines = [line.strip() for line in f.readlines()]
|
||||
assert 'http://example.com' in lines
|
||||
assert 'https://test.com' in lines
|
||||
|
||||
finally:
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
|
||||
def test_provider_mode_uses_provider_logic(self):
|
||||
"""测试当提供 provider 时使用 Provider 模式"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
|
||||
output_file = f.name
|
||||
|
||||
try:
|
||||
urls = ['https://example.com', 'https://test.com']
|
||||
provider = ListTargetProvider(targets=urls)
|
||||
|
||||
# 调用任务(Provider 模式)
|
||||
result = export_site_urls_task(
|
||||
output_file=output_file,
|
||||
provider=provider
|
||||
)
|
||||
|
||||
# 验证:使用了 provider
|
||||
assert result['success'] is True
|
||||
assert result['total_urls'] == len(urls)
|
||||
assert 'association_count' not in result # Provider 模式不返回 association_count
|
||||
assert result['source'] == 'provider'
|
||||
|
||||
# 验证:文件内容正确
|
||||
with open(output_file, 'r') as f:
|
||||
lines = [line.strip() for line in f.readlines()]
|
||||
assert lines == urls
|
||||
|
||||
finally:
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
|
||||
def test_error_when_no_parameters(self):
|
||||
"""测试当 target_id 和 provider 都未提供时抛出错误"""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
|
||||
output_file = f.name
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
|
||||
export_site_urls_task(output_file=output_file)
|
||||
|
||||
finally:
|
||||
Path(output_file).unlink(missing_ok=True)
|
||||
@@ -1,17 +1,23 @@
|
||||
"""
|
||||
导出站点 URL 列表任务
|
||||
|
||||
使用 export_urls_with_fallback 用例函数处理回退链逻辑
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
数据源: WebSite.url → Default(用于 katana 等爬虫工具)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from prefect import task
|
||||
|
||||
from apps.scan.services.target_export_service import (
|
||||
export_urls_with_fallback,
|
||||
DataSource,
|
||||
)
|
||||
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,21 +29,27 @@ logger = logging.getLogger(__name__)
|
||||
)
|
||||
def export_sites_task(
|
||||
output_file: str,
|
||||
target_id: int,
|
||||
scan_id: int,
|
||||
target_id: Optional[int] = None,
|
||||
scan_id: Optional[int] = None,
|
||||
provider: Optional[TargetProvider] = None,
|
||||
batch_size: int = 1000
|
||||
) -> dict:
|
||||
"""
|
||||
导出站点 URL 列表到文件(用于 katana 等爬虫工具)
|
||||
|
||||
数据源优先级(回退链):
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):传入 target_id,从数据库导出
|
||||
2. Provider 模式:传入 provider,从任意数据源导出
|
||||
|
||||
数据源优先级(回退链,仅传统模式):
|
||||
1. WebSite 表 - 站点级别 URL
|
||||
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
|
||||
|
||||
Args:
|
||||
output_file: 输出文件路径
|
||||
target_id: 目标 ID
|
||||
target_id: 目标 ID(传统模式,向后兼容)
|
||||
scan_id: 扫描 ID(保留参数,兼容旧调用)
|
||||
provider: TargetProvider 实例(新模式)
|
||||
batch_size: 批次大小(内存优化)
|
||||
|
||||
Returns:
|
||||
@@ -50,6 +62,17 @@ def export_sites_task(
|
||||
ValueError: 参数错误
|
||||
RuntimeError: 执行失败
|
||||
"""
|
||||
# 参数验证:至少提供一个
|
||||
if target_id is None and provider is None:
|
||||
raise ValueError("必须提供 target_id 或 provider 参数之一")
|
||||
|
||||
# Provider 模式:使用 TargetProvider 导出
|
||||
if provider is not None:
|
||||
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
|
||||
return _export_with_provider(output_file, provider)
|
||||
|
||||
# 传统模式:使用 export_urls_with_fallback
|
||||
logger.info("使用传统模式 - Target ID: %d", target_id)
|
||||
result = export_urls_with_fallback(
|
||||
target_id=target_id,
|
||||
output_file=output_file,
|
||||
@@ -67,3 +90,31 @@ def export_sites_task(
|
||||
'output_file': result['output_file'],
|
||||
'asset_count': result['total_count'],
|
||||
}
|
||||
|
||||
|
||||
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
|
||||
"""使用 Provider 导出 URL"""
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_count = 0
|
||||
blacklist_filter = provider.get_blacklist_filter()
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in provider.iter_urls():
|
||||
# 应用黑名单过滤(如果有)
|
||||
if blacklist_filter and not blacklist_filter.is_allowed(url):
|
||||
continue
|
||||
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 1000 == 0:
|
||||
logger.info("已导出 %d 个 URL...", total_count)
|
||||
|
||||
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
|
||||
|
||||
return {
|
||||
'output_file': str(output_path),
|
||||
'asset_count': total_count,
|
||||
}
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
"""导出 Endpoint URL 到文件的 Task
|
||||
|
||||
使用 export_urls_with_fallback 用例函数处理回退链逻辑
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):使用 target_id 从数据库导出
|
||||
2. Provider 模式:使用 TargetProvider 从任意数据源导出
|
||||
|
||||
数据源优先级(回退链):
|
||||
数据源优先级(回退链,仅传统模式):
|
||||
1. Endpoint.url - 最精细的 URL(含路径、参数等)
|
||||
2. WebSite.url - 站点级别 URL
|
||||
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from prefect import task
|
||||
|
||||
@@ -17,26 +20,33 @@ from apps.scan.services.target_export_service import (
|
||||
export_urls_with_fallback,
|
||||
DataSource,
|
||||
)
|
||||
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@task(name="export_endpoints")
|
||||
def export_endpoints_task(
|
||||
target_id: int,
|
||||
output_file: str,
|
||||
target_id: Optional[int] = None,
|
||||
output_file: str = "",
|
||||
provider: Optional[TargetProvider] = None,
|
||||
batch_size: int = 1000,
|
||||
) -> Dict[str, object]:
|
||||
"""导出目标下的所有 Endpoint URL 到文本文件。
|
||||
|
||||
数据源优先级(回退链):
|
||||
支持两种模式:
|
||||
1. 传统模式(向后兼容):传入 target_id,从数据库导出
|
||||
2. Provider 模式:传入 provider,从任意数据源导出
|
||||
|
||||
数据源优先级(回退链,仅传统模式):
|
||||
1. Endpoint 表 - 最精细的 URL(含路径、参数等)
|
||||
2. WebSite 表 - 站点级别 URL
|
||||
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
|
||||
|
||||
Args:
|
||||
target_id: 目标 ID
|
||||
target_id: 目标 ID(传统模式,向后兼容)
|
||||
output_file: 输出文件路径(绝对路径)
|
||||
provider: TargetProvider 实例(新模式)
|
||||
batch_size: 每次从数据库迭代的批大小
|
||||
|
||||
Returns:
|
||||
@@ -44,9 +54,20 @@ def export_endpoints_task(
|
||||
"success": bool,
|
||||
"output_file": str,
|
||||
"total_count": int,
|
||||
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none"
|
||||
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none" | "provider"
|
||||
}
|
||||
"""
|
||||
# 参数验证:至少提供一个
|
||||
if target_id is None and provider is None:
|
||||
raise ValueError("必须提供 target_id 或 provider 参数之一")
|
||||
|
||||
# Provider 模式:使用 TargetProvider 导出
|
||||
if provider is not None:
|
||||
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
|
||||
return _export_with_provider(output_file, provider)
|
||||
|
||||
# 传统模式:使用 export_urls_with_fallback
|
||||
logger.info("使用传统模式 - Target ID: %d", target_id)
|
||||
result = export_urls_with_fallback(
|
||||
target_id=target_id,
|
||||
output_file=output_file,
|
||||
@@ -65,3 +86,33 @@ def export_endpoints_task(
|
||||
"total_count": result['total_count'],
|
||||
"source": result['source'],
|
||||
}
|
||||
|
||||
|
||||
def _export_with_provider(output_file: str, provider: TargetProvider) -> Dict[str, object]:
|
||||
"""使用 Provider 导出 URL"""
|
||||
output_path = Path(output_file)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
total_count = 0
|
||||
blacklist_filter = provider.get_blacklist_filter()
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
|
||||
for url in provider.iter_urls():
|
||||
# 应用黑名单过滤(如果有)
|
||||
if blacklist_filter and not blacklist_filter.is_allowed(url):
|
||||
continue
|
||||
|
||||
f.write(f"{url}\n")
|
||||
total_count += 1
|
||||
|
||||
if total_count % 1000 == 0:
|
||||
logger.info("已导出 %d 个 URL...", total_count)
|
||||
|
||||
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"output_file": str(output_path),
|
||||
"total_count": total_count,
|
||||
"source": "provider",
|
||||
}
|
||||
|
||||
@@ -38,6 +38,7 @@ packaging>=21.0 # 版本比较
|
||||
# 测试框架
|
||||
pytest==8.0.0
|
||||
pytest-django==4.7.0
|
||||
hypothesis>=6.100.0 # 属性测试框架
|
||||
|
||||
# 工具库
|
||||
python-dateutil==2.9.0
|
||||
|
||||
@@ -88,7 +88,7 @@ function highlightSearch(html: string, query: string): string {
|
||||
|
||||
// 转义正则特殊字符
|
||||
const escapedQuery = query.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
|
||||
const regex = new RegExp(`(${escapedQuery})`, "gi")
|
||||
const regex = new RegExp(`(${escapedQuery})`, "giu")
|
||||
|
||||
// 在标签外的文本中高亮关键词
|
||||
return html.replace(/(<[^>]+>)|([^<]+)/g, (match, tag, text) => {
|
||||
|
||||
Reference in New Issue
Block a user