feat(scan): add provider abstraction layer for flexible target sourcing

- Add TargetProvider base class and ProviderContext for unified target acquisition
- Implement DatabaseTargetProvider for database-backed target queries
- Implement ListTargetProvider for in-memory target lists (fast scan phase 1)
- Implement SnapshotTargetProvider for snapshot table reads (fast scan phase 2+)
- Implement PipelineTargetProvider for pipeline stage outputs
- Add comprehensive provider tests covering common properties and individual providers
- Update screenshot_flow to support both legacy mode (target_id) and provider mode
- Add backward compatibility layer for existing task exports (directory, fingerprint, port, site, url_fetch, vuln scans)
- Add task backward compatibility tests
- Update .gitignore to exclude .hypothesis/ cache directory
- Update frontend ANSI log viewer component
- Update backend requirements.txt with new dependencies
- Enables flexible data source integration while maintaining backward compatibility with existing database-driven workflows
This commit is contained in:
yyhuni
2026-01-09 09:02:09 +08:00
parent f5ad8e68e9
commit 38e2856c08
23 changed files with 2276 additions and 59 deletions

1
.gitignore vendored
View File

@@ -64,6 +64,7 @@ backend/.env.local
.coverage
htmlcov/
*.cover
.hypothesis/
# ============================
# 后端 (Go) 相关

View File

@@ -5,6 +5,10 @@
1. 从数据库获取 URL 列表websites 和/或 endpoints
2. 批量截图并保存快照
3. 同步到资产表
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
"""
# Django 环境初始化
@@ -12,6 +16,7 @@ from apps.common.prefect_django_setup import setup_django_for_prefect
import logging
from pathlib import Path
from typing import Optional
from prefect import flow
from apps.scan.tasks.screenshot import capture_screenshots_task
@@ -25,6 +30,7 @@ from apps.scan.services.target_export_service import (
get_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider
logger = logging.getLogger(__name__)
@@ -88,11 +94,16 @@ def screenshot_flow(
target_name: str,
target_id: int,
scan_workspace_dir: str,
enabled_tools: dict
enabled_tools: dict,
provider: Optional[TargetProvider] = None
) -> dict:
"""
截图 Flow
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库获取 URL
2. Provider 模式:使用 TargetProvider 从任意数据源获取 URL
工作流程:
Step 1: 解析配置
Step 2: 收集 URL 列表
@@ -105,6 +116,7 @@ def screenshot_flow(
target_id: 目标 ID
scan_workspace_dir: 扫描工作空间目录
enabled_tools: 启用的工具配置
provider: TargetProvider 实例(新模式,可选)
Returns:
dict: {
@@ -124,6 +136,7 @@ def screenshot_flow(
f" Scan ID: {scan_id}\n" +
f" Target: {target_name}\n" +
f" Workspace: {scan_workspace_dir}\n" +
f" Mode: {'Provider' if provider else 'Legacy'}\n" +
"="*60
)
@@ -136,14 +149,31 @@ def screenshot_flow(
logger.info("截图配置 - 并发: %d, URL来源: %s", concurrency, url_sources)
# Step 2: 使用统一服务收集 URL带黑名单过滤和回退
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
# Step 2: 收集 URL 列表
if provider is not None:
# Provider 模式:使用 TargetProvider 获取 URL
logger.info("使用 Provider 模式获取 URL - Provider: %s", type(provider).__name__)
urls = list(provider.iter_urls())
# 应用黑名单过滤(如果有)
blacklist_filter = provider.get_blacklist_filter()
if blacklist_filter:
urls = [url for url in urls if blacklist_filter.is_allowed(url)]
source_info = 'provider'
tried_sources = ['provider']
else:
# 传统模式:使用统一服务收集 URL带黑名单过滤和回退
data_sources = _map_url_sources_to_data_sources(url_sources)
result = get_urls_with_fallback(target_id, sources=data_sources)
urls = result['urls']
source_info = result['source']
tried_sources = result['tried_sources']
urls = result['urls']
logger.info(
"URL 收集完成 - 来源: %s, 数量: %d, 尝试过: %s",
result['source'], result['total_count'], result['tried_sources']
source_info, len(urls), tried_sources
)
if not urls:
@@ -159,7 +189,7 @@ def screenshot_flow(
'synced': 0
}
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {result['source']})")
user_log(scan_id, "screenshot", f"Found {len(urls)} URLs to capture (source: {source_info})")
# Step 3: 批量截图
logger.info("Step 3: 批量截图 - %d 个 URL", len(urls))

View File

@@ -0,0 +1,56 @@
"""
扫描目标提供者模块
提供统一的目标获取接口,支持多种数据源:
- DatabaseTargetProvider: 从数据库查询(完整扫描)
- ListTargetProvider: 使用内存列表快速扫描阶段1
- SnapshotTargetProvider: 从快照表读取快速扫描阶段2+
- PipelineTargetProvider: 使用管道输出Phase 2
使用方式:
from apps.scan.providers import (
DatabaseTargetProvider,
ListTargetProvider,
SnapshotTargetProvider,
ProviderContext
)
# 数据库模式(完整扫描)
provider = DatabaseTargetProvider(target_id=123)
# 列表模式快速扫描阶段1
context = ProviderContext(target_id=1, scan_id=100)
provider = ListTargetProvider(
targets=["a.test.com"],
context=context
)
# 快照模式快速扫描阶段2+
context = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=context
)
# 使用 Provider
for host in provider.iter_hosts():
scan(host)
"""
from .base import TargetProvider, ProviderContext
from .list_provider import ListTargetProvider
from .database_provider import DatabaseTargetProvider
from .snapshot_provider import SnapshotTargetProvider, SnapshotType
from .pipeline_provider import PipelineTargetProvider, StageOutput
__all__ = [
'TargetProvider',
'ProviderContext',
'ListTargetProvider',
'DatabaseTargetProvider',
'SnapshotTargetProvider',
'SnapshotType',
'PipelineTargetProvider',
'StageOutput',
]

View File

@@ -0,0 +1,162 @@
"""
扫描目标提供者基础模块
定义 ProviderContext 数据类和 TargetProvider 抽象基类。
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Iterator, Optional, TYPE_CHECKING
import ipaddress
import logging
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
@dataclass
class ProviderContext:
"""
Provider 上下文,携带元数据
Attributes:
target_id: 关联的 Target ID用于结果保存None 表示临时扫描(不保存)
scan_id: 扫描任务 ID
"""
target_id: Optional[int] = None
scan_id: Optional[int] = None
class TargetProvider(ABC):
"""
扫描目标提供者抽象基类
职责:
- 提供扫描目标域名、IP、URL 等)的迭代器
- 提供黑名单过滤器
- 携带上下文信息target_id, scan_id 等)
- 自动展开 CIDR子类无需关心
使用方式:
provider = create_target_provider(target_id=123)
for host in provider.iter_hosts():
print(host)
"""
def __init__(self, context: Optional[ProviderContext] = None):
"""
初始化 Provider
Args:
context: Provider 上下文None 时创建默认上下文
"""
self._context = context or ProviderContext()
@property
def context(self) -> ProviderContext:
"""返回 Provider 上下文"""
return self._context
@staticmethod
def _expand_host(host: str) -> Iterator[str]:
"""
展开主机(如果是 CIDR 则展开为多个 IP否则直接返回
这是一个内部方法,由 iter_hosts() 自动调用。
Args:
host: 主机字符串IP/域名/CIDR
Yields:
str: 单个主机IP 或域名)
示例:
"192.168.1.0/30""192.168.1.1", "192.168.1.2"
"192.168.1.1""192.168.1.1"
"example.com""example.com"
"invalid" → (跳过,不返回)
"""
from apps.common.validators import detect_target_type
from apps.targets.models import Target
host = host.strip()
if not host:
return
# 统一使用 detect_target_type 检测类型
try:
target_type = detect_target_type(host)
if target_type == Target.TargetType.CIDR:
# 展开 CIDR
network = ipaddress.ip_network(host, strict=False)
if network.num_addresses == 1:
yield str(network.network_address)
else:
for ip in network.hosts():
yield str(ip)
elif target_type == Target.TargetType.IP:
# 单个 IP
yield host
elif target_type == Target.TargetType.DOMAIN:
# 域名
yield host
except ValueError as e:
# 无效格式,跳过并记录警告
logger.warning("跳过无效的主机格式 '%s': %s", host, str(e))
def iter_hosts(self) -> Iterator[str]:
"""
迭代主机列表(域名/IP
自动展开 CIDR子类无需关心。
Yields:
str: 主机名或 IP 地址(单个,不包含 CIDR
"""
for host in self._iter_raw_hosts():
yield from self._expand_host(host)
@abstractmethod
def _iter_raw_hosts(self) -> Iterator[str]:
"""
迭代原始主机列表(可能包含 CIDR
子类实现此方法,返回原始数据即可,不需要处理 CIDR 展开。
Yields:
str: 主机名、IP 地址或 CIDR
"""
pass
@abstractmethod
def iter_urls(self) -> Iterator[str]:
"""
迭代 URL 列表
Yields:
str: URL 字符串
"""
pass
@abstractmethod
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""
获取黑名单过滤器
Returns:
BlacklistFilter: 黑名单过滤器实例,或 None不过滤
"""
pass
@property
def target_id(self) -> Optional[int]:
"""返回关联的 target_id临时扫描返回 None"""
return self._context.target_id
@property
def scan_id(self) -> Optional[int]:
"""返回关联的 scan_id"""
return self._context.scan_id

View File

@@ -0,0 +1,131 @@
"""
数据库目标提供者模块
提供基于数据库查询的目标提供者实现。
"""
import logging
from typing import Iterator, Optional, TYPE_CHECKING
from .base import TargetProvider, ProviderContext
if TYPE_CHECKING:
from apps.common.utils import BlacklistFilter
logger = logging.getLogger(__name__)
class DatabaseTargetProvider(TargetProvider):
"""
数据库目标提供者 - 从 Target 表及关联资产表查询
这是现有行为的封装,保持向后兼容。
数据来源:
- iter_hosts(): 根据 Target 类型返回域名/IP
- DOMAIN: 根域名 + Subdomain 表
- IP: 直接返回 IP
- CIDR: 使用 _expand_host() 展开为所有主机 IP
- iter_urls(): WebSite/Endpoint 表,带回退链
使用方式:
provider = DatabaseTargetProvider(target_id=123)
for host in provider.iter_hosts():
scan(host)
"""
def __init__(self, target_id: int, context: Optional[ProviderContext] = None):
"""
初始化数据库目标提供者
Args:
target_id: 目标 ID必需
context: Provider 上下文
"""
ctx = context or ProviderContext()
ctx.target_id = target_id
super().__init__(ctx)
self._blacklist_filter: Optional['BlacklistFilter'] = None # 延迟加载
def iter_hosts(self) -> Iterator[str]:
"""
从数据库查询主机列表,自动展开 CIDR 并应用黑名单过滤
重写基类方法以支持黑名单过滤(需要在 CIDR 展开后过滤)
"""
blacklist = self.get_blacklist_filter()
for host in self._iter_raw_hosts():
# 展开 CIDR
for expanded_host in self._expand_host(host):
# 应用黑名单过滤
if not blacklist or blacklist.is_allowed(expanded_host):
yield expanded_host
def _iter_raw_hosts(self) -> Iterator[str]:
"""
从数据库查询原始主机列表(可能包含 CIDR
根据 Target 类型决定数据来源:
- DOMAIN: 根域名 + Subdomain 表
- IP: 直接返回 target.name
- CIDR: 返回 CIDR 字符串(由 iter_hosts() 展开)
注意:此方法不应用黑名单过滤,过滤在 iter_hosts() 中进行
"""
from apps.targets.services import TargetService
from apps.targets.models import Target
from apps.asset.services.asset.subdomain_service import SubdomainService
target = TargetService().get_target(self.target_id)
if not target:
logger.warning("Target ID %d 不存在", self.target_id)
return
if target.type == Target.TargetType.DOMAIN:
# 返回根域名
yield target.name
# 返回子域名
subdomain_service = SubdomainService()
for domain in subdomain_service.iter_subdomain_names_by_target(
target_id=self.target_id,
chunk_size=1000
):
if domain != target.name: # 避免重复
yield domain
elif target.type == Target.TargetType.IP:
yield target.name
elif target.type == Target.TargetType.CIDR:
# 直接返回 CIDR由 iter_hosts() 展开并过滤
yield target.name
def iter_urls(self) -> Iterator[str]:
"""
从数据库查询 URL 列表
使用现有的回退链逻辑Endpoint → WebSite → Default
"""
from apps.scan.services.target_export_service import (
_iter_urls_with_fallback, DataSource
)
blacklist = self.get_blacklist_filter()
for url, source in _iter_urls_with_fallback(
target_id=self.target_id,
sources=[DataSource.ENDPOINT, DataSource.WEBSITE, DataSource.DEFAULT],
blacklist_filter=blacklist
):
yield url
def get_blacklist_filter(self) -> Optional['BlacklistFilter']:
"""获取黑名单过滤器(延迟加载)"""
if self._blacklist_filter is None:
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
rules = BlacklistService().get_rules(self.target_id)
self._blacklist_filter = BlacklistFilter(rules)
return self._blacklist_filter

View File

@@ -0,0 +1,84 @@
"""
列表目标提供者模块
提供基于内存列表的目标提供者实现。
"""
from typing import Iterator, Optional, List
from .base import TargetProvider, ProviderContext
class ListTargetProvider(TargetProvider):
"""
列表目标提供者 - 直接使用内存中的列表
用于快速扫描、临时扫描等场景,只扫描用户指定的目标。
特点:
- 不查询数据库
- 不应用黑名单过滤(用户明确指定的目标)
- 不关联 target_id由调用方负责创建 Target
- 自动检测输入类型URL/域名/IP/CIDR
- 自动展开 CIDR
使用方式:
# 快速扫描:用户提供目标,自动识别类型
provider = ListTargetProvider(targets=[
"example.com", # 域名
"192.168.1.0/24", # CIDR自动展开
"https://api.example.com" # URL
])
for host in provider.iter_hosts():
scan(host)
"""
def __init__(
self,
targets: Optional[List[str]] = None,
context: Optional[ProviderContext] = None
):
"""
初始化列表目标提供者
Args:
targets: 目标列表自动识别类型URL/域名/IP/CIDR
context: Provider 上下文
"""
from apps.common.validators import detect_input_type
ctx = context or ProviderContext()
super().__init__(ctx)
# 自动分类目标
self._hosts = []
self._urls = []
if targets:
for target in targets:
target = target.strip()
if not target:
continue
try:
input_type = detect_input_type(target)
if input_type == 'url':
self._urls.append(target)
else:
# domain/ip/cidr 都作为 host
self._hosts.append(target)
except ValueError:
# 无法识别类型,默认作为 host
self._hosts.append(target)
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代原始主机列表(可能包含 CIDR"""
yield from self._hosts
def iter_urls(self) -> Iterator[str]:
"""迭代 URL 列表"""
yield from self._urls
def get_blacklist_filter(self) -> None:
"""列表模式不使用黑名单过滤"""
return None

View File

@@ -0,0 +1,91 @@
"""
管道目标提供者模块
提供基于管道阶段输出的目标提供者实现。
用于 Phase 2 管道模式的阶段间数据传递。
"""
from dataclasses import dataclass, field
from typing import Iterator, Optional, List, Dict, Any
from .base import TargetProvider, ProviderContext
@dataclass
class StageOutput:
"""
阶段输出数据
用于在管道阶段之间传递数据。
Attributes:
hosts: 主机列表(域名/IP
urls: URL 列表
new_targets: 新发现的目标列表
stats: 统计信息
success: 是否成功
error: 错误信息
"""
hosts: List[str] = field(default_factory=list)
urls: List[str] = field(default_factory=list)
new_targets: List[str] = field(default_factory=list)
stats: Dict[str, Any] = field(default_factory=dict)
success: bool = True
error: Optional[str] = None
class PipelineTargetProvider(TargetProvider):
"""
管道目标提供者 - 使用上一阶段的输出
用于 Phase 2 管道模式的阶段间数据传递。
特点:
- 不查询数据库
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 直接使用 StageOutput 中的数据
使用方式Phase 2
stage1_output = stage1.run(input)
provider = PipelineTargetProvider(
previous_output=stage1_output,
target_id=123
)
for host in provider.iter_hosts():
stage2.scan(host)
"""
def __init__(
self,
previous_output: StageOutput,
target_id: Optional[int] = None,
context: Optional[ProviderContext] = None
):
"""
初始化管道目标提供者
Args:
previous_output: 上一阶段的输出
target_id: 可选,关联到某个 Target用于保存结果
context: Provider 上下文
"""
ctx = context or ProviderContext(target_id=target_id)
super().__init__(ctx)
self._previous_output = previous_output
def _iter_raw_hosts(self) -> Iterator[str]:
"""迭代上一阶段输出的原始主机(可能包含 CIDR"""
yield from self._previous_output.hosts
def iter_urls(self) -> Iterator[str]:
"""迭代上一阶段输出的 URL"""
yield from self._previous_output.urls
def get_blacklist_filter(self) -> None:
"""管道传递的数据已经过滤过了"""
return None
@property
def previous_output(self) -> StageOutput:
"""返回上一阶段的输出"""
return self._previous_output

View File

@@ -0,0 +1,175 @@
"""
快照目标提供者模块
提供基于快照表的目标提供者实现。
用于快速扫描的阶段间数据传递。
"""
import logging
from typing import Iterator, Optional, Literal
from .base import TargetProvider, ProviderContext
logger = logging.getLogger(__name__)
# 快照类型定义
SnapshotType = Literal["subdomain", "website", "endpoint", "host_port"]
class SnapshotTargetProvider(TargetProvider):
"""
快照目标提供者 - 从快照表读取本次扫描的数据
用于快速扫描的阶段间数据传递,解决精确扫描控制问题。
核心价值:
- 只返回本次扫描scan_id发现的资产
- 避免扫描历史数据DatabaseTargetProvider 会扫描所有历史资产)
特点:
- 通过 scan_id 过滤快照表
- 不应用黑名单过滤(数据已在上一阶段过滤)
- 支持多种快照类型subdomain/website/endpoint/host_port
使用场景:
# 快速扫描流程
用户输入: a.test.com
创建 Target: test.com (id=1)
创建 Scan: scan_id=100
# 阶段1: 子域名发现
provider = ListTargetProvider(
targets=["a.test.com"],
context=ProviderContext(target_id=1, scan_id=100)
)
# 发现: b.a.test.com, c.a.test.com
# 保存: SubdomainSnapshot(scan_id=100) + Subdomain(target_id=1)
# 阶段2: 端口扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回: b.a.test.com, c.a.test.com本次扫描发现的
# 不返回: www.test.com, api.test.com历史数据
# 阶段3: 网站扫描
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port",
context=ProviderContext(target_id=1, scan_id=100)
)
# 只返回本次扫描发现的 IP:Port
"""
def __init__(
self,
scan_id: int,
snapshot_type: SnapshotType,
context: Optional[ProviderContext] = None
):
"""
初始化快照目标提供者
Args:
scan_id: 扫描任务 ID必需
snapshot_type: 快照类型
- "subdomain": 子域名快照SubdomainSnapshot
- "website": 网站快照WebsiteSnapshot
- "endpoint": 端点快照EndpointSnapshot
- "host_port": 主机端口映射快照HostPortMappingSnapshot
context: Provider 上下文
"""
ctx = context or ProviderContext()
ctx.scan_id = scan_id
super().__init__(ctx)
self._scan_id = scan_id
self._snapshot_type = snapshot_type
def _iter_raw_hosts(self) -> Iterator[str]:
"""
从快照表迭代主机列表
根据 snapshot_type 选择不同的快照表:
- subdomain: SubdomainSnapshot.name
- host_port: HostPortMappingSnapshot.host (返回 host:port 格式,不经过验证)
"""
if self._snapshot_type == "subdomain":
from apps.asset.services.snapshot import SubdomainSnapshotsService
service = SubdomainSnapshotsService()
yield from service.iter_subdomain_names_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "host_port":
# host_port 类型不使用 _iter_raw_hosts直接在 iter_hosts 中处理
# 这里返回空,避免被基类的 iter_hosts 调用
return
else:
# 其他类型暂不支持 iter_hosts
logger.warning(
"快照类型 '%s' 不支持 iter_hosts返回空迭代器",
self._snapshot_type
)
return
def iter_hosts(self) -> Iterator[str]:
"""
迭代主机列表
对于 host_port 类型,返回 host:port 格式,不经过 CIDR 展开验证
"""
if self._snapshot_type == "host_port":
# host_port 类型直接返回 host:port不经过 _expand_host 验证
from apps.asset.services.snapshot import HostPortMappingSnapshotsService
service = HostPortMappingSnapshotsService()
queryset = service.get_by_scan(scan_id=self._scan_id)
for mapping in queryset.iterator(chunk_size=1000):
yield f"{mapping.host}:{mapping.port}"
else:
# 其他类型使用基类的 iter_hosts会调用 _iter_raw_hosts 并展开 CIDR
yield from super().iter_hosts()
def iter_urls(self) -> Iterator[str]:
"""
从快照表迭代 URL 列表
根据 snapshot_type 选择不同的快照表:
- website: WebsiteSnapshot.url
- endpoint: EndpointSnapshot.url
"""
if self._snapshot_type == "website":
from apps.asset.services.snapshot import WebsiteSnapshotsService
service = WebsiteSnapshotsService()
yield from service.iter_website_urls_by_scan(
scan_id=self._scan_id,
chunk_size=1000
)
elif self._snapshot_type == "endpoint":
from apps.asset.services.snapshot import EndpointSnapshotsService
service = EndpointSnapshotsService()
# 从快照表获取端点 URL
queryset = service.get_by_scan(scan_id=self._scan_id)
for endpoint in queryset.iterator(chunk_size=1000):
yield endpoint.url
else:
# 其他类型暂不支持 iter_urls
logger.warning(
"快照类型 '%s' 不支持 iter_urls返回空迭代器",
self._snapshot_type
)
return
def get_blacklist_filter(self) -> None:
"""快照数据已在上一阶段过滤过了"""
return None
@property
def snapshot_type(self) -> SnapshotType:
"""返回快照类型"""
return self._snapshot_type

View File

@@ -0,0 +1,3 @@
"""
扫描目标提供者测试模块
"""

View File

@@ -0,0 +1,256 @@
"""
通用属性测试
包含跨多个 Provider 的通用属性测试:
- Property 4: Context Propagation
- Property 5: Non-Database Provider Blacklist Filter
- Property 7: CIDR Expansion Consistency
"""
import pytest
from hypothesis import given, strategies as st, settings
from ipaddress import IPv4Network
from apps.scan.providers import (
ProviderContext,
ListTargetProvider,
DatabaseTargetProvider,
PipelineTargetProvider,
SnapshotTargetProvider
)
from apps.scan.providers.pipeline_provider import StageOutput
class TestContextPropagation:
"""
Property 4: Context Propagation
*For any* ProviderContext传入 Provider 构造函数后,
Provider 的 target_id 和 scan_id 属性应该与 context 中的值一致。
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_list_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (ListTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_database_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (DatabaseTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=999, scan_id=scan_id)
# DatabaseTargetProvider 会覆盖 context 中的 target_id
provider = DatabaseTargetProvider(target_id=target_id, context=ctx)
assert provider.target_id == target_id # 使用构造函数参数
assert provider.scan_id == scan_id # 使用 context 中的值
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_pipeline_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (PipelineTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=scan_id)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == target_id
assert provider.scan_id == scan_id
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
@given(
target_id=st.integers(min_value=1, max_value=10000),
scan_id=st.integers(min_value=1, max_value=10000)
)
@settings(max_examples=100)
def test_property_4_snapshot_provider_context_propagation(self, target_id, scan_id):
"""
Property 4: Context Propagation (SnapshotTargetProvider)
Feature: scan-target-provider, Property 4: Context Propagation
**Validates: Requirements 1.3, 1.5, 7.4, 7.5**
"""
ctx = ProviderContext(target_id=target_id, scan_id=999)
# SnapshotTargetProvider 会覆盖 context 中的 scan_id
provider = SnapshotTargetProvider(
scan_id=scan_id,
snapshot_type="subdomain",
context=ctx
)
assert provider.target_id == target_id # 使用 context 中的值
assert provider.scan_id == scan_id # 使用构造函数参数
assert provider.context.target_id == target_id
assert provider.context.scan_id == scan_id
class TestNonDatabaseProviderBlacklistFilter:
"""
Property 5: Non-Database Provider Blacklist Filter
*For any* ListTargetProvider 或 PipelineTargetProvider 实例,
get_blacklist_filter() 方法应该返回 None。
**Validates: Requirements 3.4, 9.4, 9.5**
"""
@given(targets=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_list_provider_no_blacklist(self, targets):
"""
Property 5: Non-Database Provider Blacklist Filter (ListTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = ListTargetProvider(targets=targets)
assert provider.get_blacklist_filter() is None
@given(hosts=st.lists(st.text(min_size=1, max_size=20), max_size=10))
@settings(max_examples=100)
def test_property_5_pipeline_provider_no_blacklist(self, hosts):
"""
Property 5: Non-Database Provider Blacklist Filter (PipelineTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_property_5_snapshot_provider_no_blacklist(self):
"""
Property 5: Non-Database Provider Blacklist Filter (SnapshotTargetProvider)
Feature: scan-target-provider, Property 5: Non-Database Provider Blacklist Filter
**Validates: Requirements 3.4, 9.4, 9.5**
"""
provider = SnapshotTargetProvider(scan_id=1, snapshot_type="subdomain")
assert provider.get_blacklist_filter() is None
class TestCIDRExpansionConsistency:
"""
Property 7: CIDR Expansion Consistency
*For any* CIDR 字符串(如 "192.168.1.0/24"),所有 Provider 的 iter_hosts()
方法应该将其展开为相同的单个 IP 地址列表。
**Validates: Requirements 1.1, 3.6**
"""
@given(
# 生成小的 CIDR 范围以避免测试超时
network_prefix=st.integers(min_value=1, max_value=254),
cidr_suffix=st.integers(min_value=28, max_value=30) # /28 = 16 IPs, /30 = 4 IPs
)
@settings(max_examples=50, deadline=None)
def test_property_7_cidr_expansion_consistency(self, network_prefix, cidr_suffix):
"""
Property 7: CIDR Expansion Consistency
Feature: scan-target-provider, Property 7: CIDR Expansion Consistency
**Validates: Requirements 1.1, 3.6**
For any CIDR string, all Providers should expand it to the same IP list.
"""
cidr = f"192.168.{network_prefix}.0/{cidr_suffix}"
# 计算预期的 IP 列表
network = IPv4Network(cidr, strict=False)
# 排除网络地址和广播地址
expected_ips = [str(ip) for ip in network.hosts()]
# 如果 CIDR 太小(/31 或 /32使用所有地址
if not expected_ips:
expected_ips = [str(ip) for ip in network]
# ListTargetProvider
list_provider = ListTargetProvider(targets=[cidr])
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=[cidr])
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证:所有 Provider 展开的结果应该一致
assert list_result == expected_ips, f"ListProvider CIDR expansion mismatch for {cidr}"
assert pipeline_result == expected_ips, f"PipelineProvider CIDR expansion mismatch for {cidr}"
assert list_result == pipeline_result, f"Providers produce different results for {cidr}"
def test_cidr_expansion_with_multiple_cidrs(self):
"""测试多个 CIDR 的展开一致性"""
cidrs = ["192.168.1.0/30", "10.0.0.0/30"]
# 计算预期结果
expected_ips = []
for cidr in cidrs:
network = IPv4Network(cidr, strict=False)
expected_ips.extend([str(ip) for ip in network.hosts()])
# ListTargetProvider
list_provider = ListTargetProvider(targets=cidrs)
list_result = list(list_provider.iter_hosts())
# PipelineTargetProvider
stage_output = StageOutput(hosts=cidrs)
pipeline_provider = PipelineTargetProvider(previous_output=stage_output)
pipeline_result = list(pipeline_provider.iter_hosts())
# 验证
assert list_result == expected_ips
assert pipeline_result == expected_ips
assert list_result == pipeline_result
def test_mixed_hosts_and_cidrs(self):
"""测试混合主机和 CIDR 的处理"""
targets = ["example.com", "192.168.1.0/30", "test.com"]
# 计算预期结果
network = IPv4Network("192.168.1.0/30", strict=False)
cidr_ips = [str(ip) for ip in network.hosts()]
expected = ["example.com"] + cidr_ips + ["test.com"]
# ListTargetProvider
list_provider = ListTargetProvider(targets=targets)
list_result = list(list_provider.iter_hosts())
# 验证
assert list_result == expected

View File

@@ -0,0 +1,158 @@
"""
DatabaseTargetProvider 属性测试
Property 7: DatabaseTargetProvider Blacklist Application
*For any* 带有黑名单规则的 target_idDatabaseTargetProvider 的 iter_hosts() 和 iter_urls()
应该过滤掉匹配黑名单规则的目标。
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
"""
import pytest
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings
from apps.scan.providers.database_provider import DatabaseTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
class MockBlacklistFilter:
"""模拟黑名单过滤器"""
def __init__(self, blocked_patterns: list):
self.blocked_patterns = blocked_patterns
def is_allowed(self, target: str) -> bool:
"""检查目标是否被允许(不在黑名单中)"""
for pattern in self.blocked_patterns:
if pattern in target:
return False
return True
class TestDatabaseTargetProviderProperties:
"""DatabaseTargetProvider 属性测试类"""
@given(
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=20),
blocked_keyword=st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=5
)
)
@settings(max_examples=100)
def test_property_7_blacklist_filters_hosts(self, hosts, blocked_keyword):
"""
Property 7: DatabaseTargetProvider Blacklist Application (hosts)
Feature: scan-target-provider, Property 7: DatabaseTargetProvider Blacklist Application
**Validates: Requirements 2.3, 10.1, 10.2, 10.3**
For any set of hosts and a blacklist keyword, the provider should filter out
all hosts containing the blocked keyword.
"""
# 创建模拟的黑名单过滤器
mock_filter = MockBlacklistFilter([blocked_keyword])
# 创建 provider 并注入模拟的黑名单过滤器
provider = DatabaseTargetProvider(target_id=1)
provider._blacklist_filter = mock_filter
# 模拟 Target 和 SubdomainService
mock_target = MagicMock()
mock_target.type = 'domain'
mock_target.name = hosts[0] if hosts else 'example.com'
with patch('apps.targets.services.TargetService') as mock_target_service, \
patch('apps.asset.services.asset.subdomain_service.SubdomainService') as mock_subdomain_service:
mock_target_service.return_value.get_target.return_value = mock_target
mock_subdomain_service.return_value.iter_subdomain_names_by_target.return_value = iter(hosts[1:] if len(hosts) > 1 else [])
# 获取结果
result = list(provider.iter_hosts())
# 验证:所有结果都不包含被阻止的关键词
for host in result:
assert blocked_keyword not in host, f"Host '{host}' should be filtered by blacklist keyword '{blocked_keyword}'"
# 验证:所有不包含关键词的主机都应该在结果中
if hosts:
all_hosts = [hosts[0]] + [h for h in hosts[1:] if h != hosts[0]]
expected_allowed = [h for h in all_hosts if blocked_keyword not in h]
else:
expected_allowed = []
assert set(result) == set(expected_allowed)
class TestDatabaseTargetProviderUnit:
"""DatabaseTargetProvider 单元测试类"""
def test_target_id_in_context(self):
"""测试 target_id 正确设置到上下文中"""
provider = DatabaseTargetProvider(target_id=123)
assert provider.target_id == 123
assert provider.context.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(scan_id=789)
provider = DatabaseTargetProvider(target_id=123, context=ctx)
assert provider.target_id == 123 # target_id 被覆盖
assert provider.scan_id == 789
def test_blacklist_filter_lazy_loading(self):
"""测试黑名单过滤器延迟加载"""
provider = DatabaseTargetProvider(target_id=123)
# 初始时 _blacklist_filter 为 None
assert provider._blacklist_filter is None
# 模拟 BlacklistService
with patch('apps.common.services.BlacklistService') as mock_service, \
patch('apps.common.utils.BlacklistFilter') as mock_filter_class:
mock_service.return_value.get_rules.return_value = []
mock_filter_instance = MagicMock()
mock_filter_class.return_value = mock_filter_instance
# 第一次调用
result1 = provider.get_blacklist_filter()
assert result1 == mock_filter_instance
# 第二次调用应该返回缓存的实例
result2 = provider.get_blacklist_filter()
assert result2 == mock_filter_instance
# BlacklistService 只应该被调用一次
mock_service.return_value.get_rules.assert_called_once_with(123)
def test_nonexistent_target_returns_empty(self):
"""测试不存在的 target 返回空迭代器"""
provider = DatabaseTargetProvider(target_id=99999)
with patch('apps.targets.services.TargetService') as mock_service, \
patch('apps.common.services.BlacklistService') as mock_blacklist_service:
mock_service.return_value.get_target.return_value = None
mock_blacklist_service.return_value.get_rules.return_value = []
result = list(provider.iter_hosts())
assert result == []

View File

@@ -0,0 +1,152 @@
"""
ListTargetProvider 属性测试
Property 1: ListTargetProvider Round-Trip
*For any* 主机列表和 URL 列表,创建 ListTargetProvider 后迭代 iter_hosts() 和 iter_urls()
应该返回与输入相同的元素(顺序相同)。
**Validates: Requirements 3.1, 3.2**
"""
import pytest
from hypothesis import given, strategies as st, settings, assume
from apps.scan.providers.list_provider import ListTargetProvider
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
# 生成简单的域名格式: subdomain.domain.tld
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestListTargetProviderProperties:
"""ListTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_hosts_round_trip(self, hosts):
"""
Property 1: ListTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any host list, creating a ListTargetProvider and iterating iter_hosts()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=hosts)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_1_urls_round_trip(self, urls):
"""
Property 1: ListTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any URL list, creating a ListTargetProvider and iterating iter_urls()
should return the same elements in the same order.
"""
# ListTargetProvider 使用 targets 参数,自动分类为 hosts/urls
provider = ListTargetProvider(targets=urls)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_1_combined_round_trip(self, hosts, urls):
"""
Property 1: ListTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 1: ListTargetProvider Round-Trip
**Validates: Requirements 3.1, 3.2**
For any combination of hosts and URLs, both should round-trip correctly.
"""
# 合并 hosts 和 urlsListTargetProvider 会自动分类
combined = hosts + urls
provider = ListTargetProvider(targets=combined)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestListTargetProviderUnit:
"""ListTargetProvider 单元测试类"""
def test_empty_lists(self):
"""测试空列表返回空迭代器 - Requirements 3.5"""
provider = ListTargetProvider()
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 3.4"""
provider = ListTargetProvider(targets=["example.com"])
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 3.3"""
ctx = ProviderContext(target_id=123)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = ListTargetProvider(targets=["example.com"], context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789

View File

@@ -0,0 +1,180 @@
"""
PipelineTargetProvider 属性测试
Property 3: PipelineTargetProvider Round-Trip
*For any* StageOutput 对象PipelineTargetProvider 的 iter_hosts() 和 iter_urls()
应该返回与 StageOutput 中 hosts 和 urls 列表相同的元素。
**Validates: Requirements 5.1, 5.2**
"""
import pytest
from hypothesis import given, strategies as st, settings
from apps.scan.providers.pipeline_provider import PipelineTargetProvider, StageOutput
from apps.scan.providers.base import ProviderContext
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
# 生成有效 IP 地址的策略
def valid_ip_strategy():
"""生成有效的 IPv4 地址"""
octet = st.integers(min_value=1, max_value=254)
return st.builds(
lambda a, b, c, d: f"{a}.{b}.{c}.{d}",
octet, octet, octet, octet
)
# 组合策略:域名或 IP
host_strategy = st.one_of(valid_domain_strategy(), valid_ip_strategy())
# 生成有效 URL 的策略
def valid_url_strategy():
"""生成有效的 URL"""
domain = valid_domain_strategy()
return st.builds(
lambda d, path: f"https://{d}/{path}" if path else f"https://{d}",
domain,
st.one_of(
st.just(""),
st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=1,
max_size=10
)
)
)
url_strategy = valid_url_strategy()
class TestPipelineTargetProviderProperties:
"""PipelineTargetProvider 属性测试类"""
@given(hosts=st.lists(host_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_hosts_round_trip(self, hosts):
"""
Property 3: PipelineTargetProvider Round-Trip (hosts)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with hosts, PipelineTargetProvider should return
the same hosts in the same order.
"""
stage_output = StageOutput(hosts=hosts)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_hosts())
assert result == hosts
@given(urls=st.lists(url_strategy, max_size=50))
@settings(max_examples=100)
def test_property_3_urls_round_trip(self, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (urls)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with urls, PipelineTargetProvider should return
the same urls in the same order.
"""
stage_output = StageOutput(urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
result = list(provider.iter_urls())
assert result == urls
@given(
hosts=st.lists(host_strategy, max_size=30),
urls=st.lists(url_strategy, max_size=30)
)
@settings(max_examples=100)
def test_property_3_combined_round_trip(self, hosts, urls):
"""
Property 3: PipelineTargetProvider Round-Trip (combined)
Feature: scan-target-provider, Property 3: PipelineTargetProvider Round-Trip
**Validates: Requirements 5.1, 5.2**
For any StageOutput with both hosts and urls, both should round-trip correctly.
"""
stage_output = StageOutput(hosts=hosts, urls=urls)
provider = PipelineTargetProvider(previous_output=stage_output)
hosts_result = list(provider.iter_hosts())
urls_result = list(provider.iter_urls())
assert hosts_result == hosts
assert urls_result == urls
class TestPipelineTargetProviderUnit:
"""PipelineTargetProvider 单元测试类"""
def test_empty_stage_output(self):
"""测试空 StageOutput 返回空迭代器 - Requirements 5.5"""
stage_output = StageOutput()
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == []
assert list(provider.iter_urls()) == []
def test_blacklist_filter_returns_none(self):
"""测试黑名单过滤器返回 None - Requirements 5.3"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.get_blacklist_filter() is None
def test_target_id_association(self):
"""测试 target_id 关联 - Requirements 5.4"""
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, target_id=123)
assert provider.target_id == 123
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
stage_output = StageOutput(hosts=["example.com"])
provider = PipelineTargetProvider(previous_output=stage_output, context=ctx)
assert provider.target_id == 456
assert provider.scan_id == 789
def test_previous_output_property(self):
"""测试 previous_output 属性"""
stage_output = StageOutput(hosts=["example.com"], urls=["https://example.com"])
provider = PipelineTargetProvider(previous_output=stage_output)
assert provider.previous_output is stage_output
assert provider.previous_output.hosts == ["example.com"]
assert provider.previous_output.urls == ["https://example.com"]
def test_stage_output_with_metadata(self):
"""测试带元数据的 StageOutput"""
stage_output = StageOutput(
hosts=["example.com"],
urls=["https://example.com"],
new_targets=["new.example.com"],
stats={"count": 1},
success=True,
error=None
)
provider = PipelineTargetProvider(previous_output=stage_output)
assert list(provider.iter_hosts()) == ["example.com"]
assert list(provider.iter_urls()) == ["https://example.com"]
assert provider.previous_output.new_targets == ["new.example.com"]
assert provider.previous_output.stats == {"count": 1}

View File

@@ -0,0 +1,191 @@
"""
SnapshotTargetProvider 单元测试
"""
import pytest
from unittest.mock import Mock, patch
from apps.scan.providers import SnapshotTargetProvider, ProviderContext
class TestSnapshotTargetProvider:
"""SnapshotTargetProvider 测试类"""
def test_init_with_scan_id_and_type(self):
"""测试初始化"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
assert provider.scan_id == 100
assert provider.snapshot_type == "subdomain"
assert provider.target_id is None # 默认 context
def test_init_with_context(self):
"""测试带 context 初始化"""
ctx = ProviderContext(target_id=1, scan_id=100)
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain",
context=ctx
)
assert provider.scan_id == 100
assert provider.target_id == 1
assert provider.snapshot_type == "subdomain"
@patch('apps.asset.services.snapshot.SubdomainSnapshotsService')
def test_iter_hosts_subdomain(self, mock_service_class):
"""测试从子域名快照迭代主机"""
# Mock service
mock_service = Mock()
mock_service.iter_subdomain_names_by_scan.return_value = iter([
"a.example.com",
"b.example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["a.example.com", "b.example.com"]
mock_service.iter_subdomain_names_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.HostPortMappingSnapshotsService')
def test_iter_hosts_host_port(self, mock_service_class):
"""测试从主机端口映射快照迭代主机"""
# Mock queryset
mock_mapping1 = Mock()
mock_mapping1.host = "example.com"
mock_mapping1.port = 80
mock_mapping2 = Mock()
mock_mapping2.host = "example.com"
mock_mapping2.port = 443
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_mapping1, mock_mapping2])
# Mock service
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="host_port"
)
# 迭代主机
hosts = list(provider.iter_hosts())
assert hosts == ["example.com:80", "example.com:443"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
@patch('apps.asset.services.snapshot.WebsiteSnapshotsService')
def test_iter_urls_website(self, mock_service_class):
"""测试从网站快照迭代 URL"""
# Mock service
mock_service = Mock()
mock_service.iter_website_urls_by_scan.return_value = iter([
"http://example.com",
"https://example.com"
])
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website"
)
# 迭代 URL
urls = list(provider.iter_urls())
assert urls == ["http://example.com", "https://example.com"]
mock_service.iter_website_urls_by_scan.assert_called_once_with(
scan_id=100,
chunk_size=1000
)
@patch('apps.asset.services.snapshot.EndpointSnapshotsService')
def test_iter_urls_endpoint(self, mock_service_class):
"""测试从端点快照迭代 URL"""
# Mock queryset
mock_endpoint1 = Mock()
mock_endpoint1.url = "http://example.com/api/v1"
mock_endpoint2 = Mock()
mock_endpoint2.url = "http://example.com/api/v2"
mock_queryset = Mock()
mock_queryset.iterator.return_value = iter([mock_endpoint1, mock_endpoint2])
# Mock service
mock_service = Mock()
mock_service.get_by_scan.return_value = mock_queryset
mock_service_class.return_value = mock_service
# 创建 provider
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="endpoint"
)
# 迭代 URL
urls = list(provider.iter_urls())
assert urls == ["http://example.com/api/v1", "http://example.com/api/v2"]
mock_service.get_by_scan.assert_called_once_with(scan_id=100)
def test_iter_hosts_unsupported_type(self):
"""测试不支持的快照类型iter_hosts"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="website" # website 不支持 iter_hosts
)
hosts = list(provider.iter_hosts())
assert hosts == []
def test_iter_urls_unsupported_type(self):
"""测试不支持的快照类型iter_urls"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain" # subdomain 不支持 iter_urls
)
urls = list(provider.iter_urls())
assert urls == []
def test_get_blacklist_filter(self):
"""测试黑名单过滤器(快照模式不使用黑名单)"""
provider = SnapshotTargetProvider(
scan_id=100,
snapshot_type="subdomain"
)
assert provider.get_blacklist_filter() is None
def test_context_propagation(self):
"""测试上下文传递"""
ctx = ProviderContext(target_id=456, scan_id=789)
provider = SnapshotTargetProvider(
scan_id=100, # 会被 context 覆盖
snapshot_type="subdomain",
context=ctx
)
assert provider.target_id == 456
assert provider.scan_id == 100 # scan_id 在 __init__ 中被设置

View File

@@ -1,36 +1,48 @@
"""
导出站点 URL 到 TXT 文件的 Task
使用 export_urls_with_fallback 用例函数处理回退链逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
数据源: WebSite.url → Default
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@task(name="export_sites")
def export_sites_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
output_file: str = "",
provider: Optional[TargetProvider] = None,
batch_size: int = 1000,
) -> dict:
"""
导出目标下的所有站点 URL 到 TXT 文件
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例(新模式)
batch_size: 每次读取的批次大小,默认 1000
Returns:
@@ -44,6 +56,17 @@ def export_sites_task(
ValueError: 参数错误
IOError: 文件写入失败
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -62,3 +85,32 @@ def export_sites_task(
'output_file': result['output_file'],
'total_count': result['total_count'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
}

View File

@@ -1,11 +1,16 @@
"""
导出 URL 任务
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
用于指纹识别前导出目标下的 URL 到文件
使用 export_urls_with_fallback 用例函数处理回退链逻辑
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
@@ -13,33 +18,51 @@ from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@task(name="export_urls_for_fingerprint")
def export_urls_for_fingerprint_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
output_file: str = "",
source: str = 'website', # 保留参数,兼容旧调用(实际值由回退链决定)
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出目标下的 URL 到文件(用于指纹识别)
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
output_file: 输出文件路径
source: 数据源类型(保留参数,兼容旧调用,实际值由回退链决定)
provider: TargetProvider 实例(新模式)
batch_size: 批量读取大小
Returns:
dict: {'output_file': str, 'total_count': int, 'source': str}
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -58,3 +81,32 @@ def export_urls_for_fingerprint_task(
'total_count': result['total_count'],
'source': result['source'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'output_file': str(output_path),
'total_count': total_count,
'source': 'provider',
}

View File

@@ -1,7 +1,9 @@
"""
导出主机列表到 TXT 文件的 Task
使用 TargetExportService.export_hosts() 统一处理导出逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
根据 Target 类型决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名
@@ -9,30 +11,39 @@
- CIDR: 展开 CIDR 范围内的所有 IP
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import create_export_service
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
logger = logging.getLogger(__name__)
@task(name="export_hosts")
def export_hosts_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出主机列表到 TXT 文件
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
根据 Target 类型自动决定导出内容:
- DOMAIN: 从 Subdomain 表导出子域名(流式处理,支持 10万+ 域名)
- IP: 直接写入 target.name单个 IP
- CIDR: 展开 CIDR 范围内的所有可用 IP
Args:
target_id: 目标 ID
output_file: 输出文件路径(绝对路径)
target_id: 目标 ID传统模式向后兼容
provider: TargetProvider 实例(新模式)
batch_size: 每次读取的批次大小,默认 1000仅对 DOMAIN 类型有效)
Returns:
@@ -40,26 +51,60 @@ def export_hosts_task(
'success': bool,
'output_file': str,
'total_count': int,
'target_type': str
'target_type': str # 仅传统模式返回
}
Raises:
ValueError: Target 不存在
ValueError: 参数错误target_id 和 provider 都未提供)
IOError: 文件写入失败
"""
# 使用工厂函数创建导出服务
export_service = create_export_service(target_id)
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
result = export_service.export_hosts(
target_id=target_id,
output_path=output_file,
batch_size=batch_size
)
# 向后兼容:如果没有提供 provider使用 target_id 创建 DatabaseTargetProvider
if provider is None:
logger.info("使用传统模式 - Target ID: %d", target_id)
provider = DatabaseTargetProvider(target_id=target_id)
use_legacy_mode = True
else:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
use_legacy_mode = False
# 保持返回值格式不变(向后兼容)
return {
'success': result['success'],
'output_file': result['output_file'],
'total_count': result['total_count'],
'target_type': result['target_type']
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出主机列表
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for host in provider.iter_hosts():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(host):
continue
f.write(f"{host}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个主机...", total_count)
logger.info("✓ 主机列表导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
# 构建返回值
result = {
'success': True,
'output_file': str(output_path),
'total_count': total_count,
}
# 传统模式:保持返回值格式不变(向后兼容)
if use_legacy_mode:
# 获取 target_type仅传统模式需要
from apps.targets.services import TargetService
target = TargetService().get_target(target_id)
result['target_type'] = target.type if target else 'unknown'
return result

View File

@@ -1,8 +1,9 @@
"""
导出站点URL到文件的Task
直接使用 HostPortMapping 表查询 host+port 组合拼接成URL格式写入文件
使用 TargetExportService.generate_default_urls() 处理默认值回退逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
特殊逻辑:
- 80 端口:只生成 HTTP URL省略端口号
@@ -10,6 +11,7 @@
- 其他端口:生成 HTTP 和 HTTPS 两个URL带端口号
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
@@ -17,6 +19,7 @@ from apps.asset.services import HostPortMappingService
from apps.scan.services.target_export_service import create_export_service
from apps.common.services import BlacklistService
from apps.common.utils import BlacklistFilter
from apps.scan.providers import TargetProvider, DatabaseTargetProvider, ProviderContext
logger = logging.getLogger(__name__)
@@ -39,26 +42,30 @@ def _generate_urls_from_port(host: str, port: int) -> list[str]:
@task(name="export_site_urls")
def export_site_urls_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出目标下的所有站点URL到文件(基于 HostPortMapping 表)
导出目标下的所有站点URL到文件
数据源: HostPortMapping (host + port) → Default
支持两种模式:
1. 传统模式(向后兼容):传入 target_id从 HostPortMapping 表导出
2. Provider 模式:传入 provider从任意数据源导出
特殊逻辑:
传统模式特殊逻辑:
- 80 端口:只生成 HTTP URL省略端口号
- 443 端口:只生成 HTTPS URL省略端口号
- 其他端口:生成 HTTP 和 HTTPS 两个URL带端口号
回退逻辑:
回退逻辑(仅传统模式)
- 如果 HostPortMapping 为空,使用 generate_default_urls() 生成默认 URL
Args:
target_id: 目标ID
output_file: 输出文件路径(绝对路径)
target_id: 目标ID传统模式向后兼容
provider: TargetProvider 实例(新模式)
batch_size: 每次处理的批次大小
Returns:
@@ -66,14 +73,62 @@ def export_site_urls_task(
'success': bool,
'output_file': str,
'total_urls': int,
'association_count': int, # 主机端口关联数量
'source': str, # 数据来源: "host_port" | "default"
'association_count': int, # 主机端口关联数量(仅传统模式)
'source': str, # 数据来源: "host_port" | "default" | "provider"
}
Raises:
ValueError: 参数错误
IOError: 文件写入失败
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# 向后兼容:如果没有提供 provider使用传统模式
if provider is None:
logger.info("使用传统模式 - Target ID: %d, 输出文件: %s", target_id, output_file)
return _export_site_urls_legacy(target_id, output_file, batch_size)
# Provider 模式
logger.info("使用 Provider 模式 - Provider: %s, 输出文件: %s", type(provider).__name__, output_file)
# 确保输出目录存在
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
# 使用 Provider 导出 URL 列表
total_urls = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_urls += 1
if total_urls % 1000 == 0:
logger.info("已导出 %d 个URL...", total_urls)
logger.info("✓ URL导出完成 - 总数: %d, 文件: %s", total_urls, str(output_path))
return {
'success': True,
'output_file': str(output_path),
'total_urls': total_urls,
'source': 'provider',
}
def _export_site_urls_legacy(target_id: int, output_file: str, batch_size: int) -> dict:
"""
传统模式:从 HostPortMapping 表导出 URL
保持原有逻辑不变,确保向后兼容
"""
logger.info("开始统计站点URL - Target ID: %d, 输出文件: %s", target_id, output_file)
# 确保输出目录存在

View File

@@ -0,0 +1,240 @@
"""
Task 向后兼容性测试
Property 8: Task Backward Compatibility
*For any* 任务调用,当仅提供 target_id 参数时,任务应该创建 DatabaseTargetProvider
并使用它进行数据访问,行为与改造前一致。
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
"""
import tempfile
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock
from hypothesis import given, strategies as st, settings
from apps.scan.tasks.port_scan.export_hosts_task import export_hosts_task
from apps.scan.tasks.site_scan.export_site_urls_task import export_site_urls_task
from apps.scan.providers import ListTargetProvider
# 生成有效域名的策略
def valid_domain_strategy():
"""生成有效的域名"""
label = st.text(
alphabet=st.characters(whitelist_categories=('L',), min_codepoint=97, max_codepoint=122),
min_size=2,
max_size=10
)
return st.builds(
lambda a, b, c: f"{a}.{b}.{c}",
label, label, st.sampled_from(['com', 'net', 'org', 'io'])
)
class TestExportHostsTaskBackwardCompatibility:
"""export_hosts_task 向后兼容性测试"""
@given(
target_id=st.integers(min_value=1, max_value=1000),
hosts=st.lists(valid_domain_strategy(), min_size=1, max_size=10)
)
@settings(max_examples=50, deadline=None)
def test_property_8_legacy_mode_creates_database_provider(self, target_id, hosts):
"""
Property 8: Task Backward Compatibility (export_hosts_task)
Feature: scan-target-provider, Property 8: Task Backward Compatibility
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
For any target_id, when calling export_hosts_task with only target_id,
it should create a DatabaseTargetProvider and use it for data access.
"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
# Mock Target 和 SubdomainService
mock_target = MagicMock()
mock_target.type = 'domain'
mock_target.name = hosts[0]
with patch('apps.scan.tasks.port_scan.export_hosts_task.DatabaseTargetProvider') as mock_provider_class, \
patch('apps.targets.services.TargetService') as mock_target_service:
# 创建 mock provider 实例
mock_provider = MagicMock()
mock_provider.iter_hosts.return_value = iter(hosts)
mock_provider.get_blacklist_filter.return_value = None
mock_provider_class.return_value = mock_provider
# Mock TargetService
mock_target_service.return_value.get_target.return_value = mock_target
# 调用任务(传统模式:只传 target_id
result = export_hosts_task(
output_file=output_file,
target_id=target_id
)
# 验证:应该创建了 DatabaseTargetProvider
mock_provider_class.assert_called_once_with(target_id=target_id)
# 验证:返回值包含必需字段
assert result['success'] is True
assert result['output_file'] == output_file
assert result['total_count'] == len(hosts)
assert 'target_type' in result # 传统模式应该返回 target_type
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == hosts
finally:
Path(output_file).unlink(missing_ok=True)
def test_legacy_mode_with_provider_parameter(self):
"""测试当同时提供 target_id 和 provider 时provider 优先"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
hosts = ['example.com', 'test.com']
provider = ListTargetProvider(targets=hosts)
# 调用任务(同时提供 target_id 和 provider
result = export_hosts_task(
output_file=output_file,
target_id=123, # 应该被忽略
provider=provider
)
# 验证:使用了 provider
assert result['success'] is True
assert result['total_count'] == len(hosts)
assert 'target_type' not in result # Provider 模式不返回 target_type
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == hosts
finally:
Path(output_file).unlink(missing_ok=True)
def test_error_when_no_parameters(self):
"""测试当 target_id 和 provider 都未提供时抛出错误"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
export_hosts_task(output_file=output_file)
finally:
Path(output_file).unlink(missing_ok=True)
class TestExportSiteUrlsTaskBackwardCompatibility:
"""export_site_urls_task 向后兼容性测试"""
def test_property_8_legacy_mode_uses_traditional_logic(self):
"""
Property 8: Task Backward Compatibility (export_site_urls_task)
Feature: scan-target-provider, Property 8: Task Backward Compatibility
**Validates: Requirements 6.1, 6.2, 6.4, 6.5**
When calling export_site_urls_task with only target_id,
it should use the traditional logic (_export_site_urls_legacy).
"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
target_id = 123
# Mock HostPortMappingService
mock_associations = [
{'host': 'example.com', 'port': 80},
{'host': 'test.com', 'port': 443},
]
with patch('apps.scan.tasks.site_scan.export_site_urls_task.HostPortMappingService') as mock_service_class, \
patch('apps.scan.tasks.site_scan.export_site_urls_task.BlacklistService') as mock_blacklist_service:
# Mock HostPortMappingService
mock_service = MagicMock()
mock_service.iter_host_port_by_target.return_value = iter(mock_associations)
mock_service_class.return_value = mock_service
# Mock BlacklistService
mock_blacklist = MagicMock()
mock_blacklist.get_rules.return_value = []
mock_blacklist_service.return_value = mock_blacklist
# 调用任务(传统模式:只传 target_id
result = export_site_urls_task(
output_file=output_file,
target_id=target_id
)
# 验证:返回值包含传统模式的字段
assert result['success'] is True
assert result['output_file'] == output_file
assert result['total_urls'] == 2 # 80 端口生成 1 个 URL443 端口生成 1 个 URL
assert 'association_count' in result # 传统模式应该返回 association_count
assert result['association_count'] == 2
assert result['source'] == 'host_port'
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert 'http://example.com' in lines
assert 'https://test.com' in lines
finally:
Path(output_file).unlink(missing_ok=True)
def test_provider_mode_uses_provider_logic(self):
"""测试当提供 provider 时使用 Provider 模式"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
urls = ['https://example.com', 'https://test.com']
provider = ListTargetProvider(targets=urls)
# 调用任务Provider 模式)
result = export_site_urls_task(
output_file=output_file,
provider=provider
)
# 验证:使用了 provider
assert result['success'] is True
assert result['total_urls'] == len(urls)
assert 'association_count' not in result # Provider 模式不返回 association_count
assert result['source'] == 'provider'
# 验证:文件内容正确
with open(output_file, 'r') as f:
lines = [line.strip() for line in f.readlines()]
assert lines == urls
finally:
Path(output_file).unlink(missing_ok=True)
def test_error_when_no_parameters(self):
"""测试当 target_id 和 provider 都未提供时抛出错误"""
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f:
output_file = f.name
try:
with pytest.raises(ValueError, match="必须提供 target_id 或 provider 参数之一"):
export_site_urls_task(output_file=output_file)
finally:
Path(output_file).unlink(missing_ok=True)

View File

@@ -1,17 +1,23 @@
"""
导出站点 URL 列表任务
使用 export_urls_with_fallback 用例函数处理回退链逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
数据源: WebSite.url → Default用于 katana 等爬虫工具)
"""
import logging
from typing import Optional
from pathlib import Path
from prefect import task
from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@@ -23,21 +29,27 @@ logger = logging.getLogger(__name__)
)
def export_sites_task(
output_file: str,
target_id: int,
scan_id: int,
target_id: Optional[int] = None,
scan_id: Optional[int] = None,
provider: Optional[TargetProvider] = None,
batch_size: int = 1000
) -> dict:
"""
导出站点 URL 列表到文件(用于 katana 等爬虫工具)
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. WebSite 表 - 站点级别 URL
2. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
output_file: 输出文件路径
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
scan_id: 扫描 ID保留参数兼容旧调用
provider: TargetProvider 实例(新模式)
batch_size: 批次大小(内存优化)
Returns:
@@ -50,6 +62,17 @@ def export_sites_task(
ValueError: 参数错误
RuntimeError: 执行失败
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -67,3 +90,31 @@ def export_sites_task(
'output_file': result['output_file'],
'asset_count': result['total_count'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> dict:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
'output_file': str(output_path),
'asset_count': total_count,
}

View File

@@ -1,15 +1,18 @@
"""导出 Endpoint URL 到文件的 Task
使用 export_urls_with_fallback 用例函数处理回退链逻辑
支持两种模式:
1. 传统模式(向后兼容):使用 target_id 从数据库导出
2. Provider 模式:使用 TargetProvider 从任意数据源导出
数据源优先级(回退链):
数据源优先级(回退链,仅传统模式
1. Endpoint.url - 最精细的 URL含路径、参数等
2. WebSite.url - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
"""
import logging
from typing import Dict
from typing import Dict, Optional
from pathlib import Path
from prefect import task
@@ -17,26 +20,33 @@ from apps.scan.services.target_export_service import (
export_urls_with_fallback,
DataSource,
)
from apps.scan.providers import TargetProvider, DatabaseTargetProvider
logger = logging.getLogger(__name__)
@task(name="export_endpoints")
def export_endpoints_task(
target_id: int,
output_file: str,
target_id: Optional[int] = None,
output_file: str = "",
provider: Optional[TargetProvider] = None,
batch_size: int = 1000,
) -> Dict[str, object]:
"""导出目标下的所有 Endpoint URL 到文本文件。
数据源优先级(回退链)
支持两种模式
1. 传统模式(向后兼容):传入 target_id从数据库导出
2. Provider 模式:传入 provider从任意数据源导出
数据源优先级(回退链,仅传统模式):
1. Endpoint 表 - 最精细的 URL含路径、参数等
2. WebSite 表 - 站点级别 URL
3. 默认生成 - 根据 Target 类型生成 http(s)://target_name
Args:
target_id: 目标 ID
target_id: 目标 ID(传统模式,向后兼容)
output_file: 输出文件路径(绝对路径)
provider: TargetProvider 实例(新模式)
batch_size: 每次从数据库迭代的批大小
Returns:
@@ -44,9 +54,20 @@ def export_endpoints_task(
"success": bool,
"output_file": str,
"total_count": int,
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none"
"source": str, # 数据来源: "endpoint" | "website" | "default" | "none" | "provider"
}
"""
# 参数验证:至少提供一个
if target_id is None and provider is None:
raise ValueError("必须提供 target_id 或 provider 参数之一")
# Provider 模式:使用 TargetProvider 导出
if provider is not None:
logger.info("使用 Provider 模式 - Provider: %s", type(provider).__name__)
return _export_with_provider(output_file, provider)
# 传统模式:使用 export_urls_with_fallback
logger.info("使用传统模式 - Target ID: %d", target_id)
result = export_urls_with_fallback(
target_id=target_id,
output_file=output_file,
@@ -65,3 +86,33 @@ def export_endpoints_task(
"total_count": result['total_count'],
"source": result['source'],
}
def _export_with_provider(output_file: str, provider: TargetProvider) -> Dict[str, object]:
"""使用 Provider 导出 URL"""
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
total_count = 0
blacklist_filter = provider.get_blacklist_filter()
with open(output_path, 'w', encoding='utf-8', buffering=8192) as f:
for url in provider.iter_urls():
# 应用黑名单过滤(如果有)
if blacklist_filter and not blacklist_filter.is_allowed(url):
continue
f.write(f"{url}\n")
total_count += 1
if total_count % 1000 == 0:
logger.info("已导出 %d 个 URL...", total_count)
logger.info("✓ URL 导出完成 - 总数: %d, 文件: %s", total_count, str(output_path))
return {
"success": True,
"output_file": str(output_path),
"total_count": total_count,
"source": "provider",
}

View File

@@ -38,6 +38,7 @@ packaging>=21.0 # 版本比较
# 测试框架
pytest==8.0.0
pytest-django==4.7.0
hypothesis>=6.100.0 # 属性测试框架
# 工具库
python-dateutil==2.9.0

View File

@@ -88,7 +88,7 @@ function highlightSearch(html: string, query: string): string {
// 转义正则特殊字符
const escapedQuery = query.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")
const regex = new RegExp(`(${escapedQuery})`, "gi")
const regex = new RegExp(`(${escapedQuery})`, "giu")
// 在标签外的文本中高亮关键词
return html.replace(/(<[^>]+>)|([^<]+)/g, (match, tag, text) => {