""" 规则引擎模块 提供违禁词检测、规则冲突检测和规则版本管理功能 验收标准: - 违禁词召回率 ≥ 95% - 误报率 ≤ 5% - 语境感知检测能力 """ import re from dataclasses import dataclass, field from typing import Any from datetime import datetime @dataclass class DetectionResult: """检测结果""" word: str position: int context: str = "" severity: str = "medium" confidence: float = 1.0 @dataclass class ProhibitedWordResult: """违禁词检测结果""" detected_words: list[DetectionResult] total_count: int has_violations: bool @dataclass class ContextClassificationResult: """语境分类结果""" context_type: str # "advertisement", "daily", "unknown" confidence: float is_advertisement: bool @dataclass class ConflictDetail: """冲突详情""" rule1: dict[str, Any] rule2: dict[str, Any] conflict_type: str description: str @dataclass class ConflictResult: """规则冲突检测结果""" has_conflicts: bool conflicts: list[ConflictDetail] @dataclass class RuleVersion: """规则版本""" version_id: str rules: dict[str, Any] created_at: datetime is_active: bool = True class ContextClassifier: """语境分类器""" # 广告语境关键词 AD_KEYWORDS = { "产品", "购买", "下单", "优惠", "折扣", "促销", "限时", "效果", "功效", "推荐", "种草", "链接", "商品", "价格", } # 日常语境关键词 DAILY_KEYWORDS = { "今天", "昨天", "明天", "心情", "感觉", "天气", "朋友", "家人", "生活", "日常", "分享", "记录", } def classify(self, text: str) -> ContextClassificationResult: """分类文本语境""" if not text: return ContextClassificationResult( context_type="unknown", confidence=0.0, is_advertisement=False, ) ad_score = sum(1 for kw in self.AD_KEYWORDS if kw in text) daily_score = sum(1 for kw in self.DAILY_KEYWORDS if kw in text) total = ad_score + daily_score if total == 0: return ContextClassificationResult( context_type="unknown", confidence=0.5, is_advertisement=False, ) if ad_score > daily_score: return ContextClassificationResult( context_type="advertisement", confidence=ad_score / (ad_score + daily_score), is_advertisement=True, ) else: return ContextClassificationResult( context_type="daily", confidence=daily_score / (ad_score + daily_score), is_advertisement=False, ) class ProhibitedWordDetector: """违禁词检测器""" def __init__(self, rules: list[dict[str, Any]] | None = None): """ 初始化检测器 Args: rules: 违禁词规则列表,每个规则包含 word, reason, severity 等字段 """ self.rules = rules or [] self.context_classifier = ContextClassifier() self._build_pattern() def _build_pattern(self) -> None: """构建正则表达式模式""" if not self.rules: self.pattern = None return words = [re.escape(r.get("word", "")) for r in self.rules if r.get("word")] if words: # 按长度降序排序,确保长词优先匹配 words.sort(key=len, reverse=True) self.pattern = re.compile("|".join(words)) else: self.pattern = None def detect( self, text: str, context: str = "advertisement" ) -> ProhibitedWordResult: """ 检测文本中的违禁词 Args: text: 待检测文本 context: 语境类型 ("advertisement" 或 "daily") Returns: 检测结果 """ if not text or not self.pattern: return ProhibitedWordResult( detected_words=[], total_count=0, has_violations=False, ) # 如果是日常语境,降低敏感度 if context == "daily": return ProhibitedWordResult( detected_words=[], total_count=0, has_violations=False, ) detected = [] for match in self.pattern.finditer(text): word = match.group() rule = self._find_rule(word) detected.append(DetectionResult( word=word, position=match.start(), context=text[max(0, match.start()-10):match.end()+10], severity=rule.get("severity", "medium") if rule else "medium", confidence=0.95, )) return ProhibitedWordResult( detected_words=detected, total_count=len(detected), has_violations=len(detected) > 0, ) def detect_with_context_awareness(self, text: str) -> ProhibitedWordResult: """ 带语境感知的违禁词检测 自动判断文本语境,在日常语境下降低敏感度 """ context_result = self.context_classifier.classify(text) if context_result.is_advertisement: return self.detect(text, context="advertisement") else: return self.detect(text, context="daily") def _find_rule(self, word: str) -> dict[str, Any] | None: """查找匹配的规则""" for rule in self.rules: if rule.get("word") == word: return rule return None class RuleConflictDetector: """规则冲突检测器""" def detect_conflicts( self, brief_rules: dict[str, Any], platform_rules: dict[str, Any] ) -> ConflictResult: """ 检测 Brief 规则和平台规则之间的冲突 Args: brief_rules: Brief 定义的规则 platform_rules: 平台规则 Returns: 冲突检测结果 """ conflicts = [] brief_forbidden = set( w.get("word", "") for w in brief_rules.get("forbidden_words", []) ) platform_forbidden = set( w.get("word", "") for w in platform_rules.get("forbidden_words", []) ) # 检查是否有 Brief 允许但平台禁止的词 # (这里简化实现,实际可能需要更复杂的逻辑) # 检查卖点是否包含平台禁用词 selling_points = brief_rules.get("selling_points", []) for sp in selling_points: text = sp.get("text", "") for forbidden in platform_forbidden: if forbidden in text: conflicts.append(ConflictDetail( rule1={"type": "selling_point", "text": text}, rule2={"type": "platform_forbidden", "word": forbidden}, conflict_type="selling_point_contains_forbidden", description=f"卖点 '{text}' 包含平台禁用词 '{forbidden}'", )) return ConflictResult( has_conflicts=len(conflicts) > 0, conflicts=conflicts, ) def check_compatibility( self, rule1: dict[str, Any], rule2: dict[str, Any] ) -> bool: """检查两条规则是否兼容""" # 简化实现:检查是否有直接冲突 if rule1.get("type") == "required" and rule2.get("type") == "forbidden": if rule1.get("word") == rule2.get("word"): return False return True class RuleVersionManager: """规则版本管理器""" def __init__(self): self.versions: list[RuleVersion] = [] self._current_version: RuleVersion | None = None def create_version(self, rules: dict[str, Any]) -> RuleVersion: """创建新版本""" version = RuleVersion( version_id=f"v{len(self.versions) + 1}", rules=rules, created_at=datetime.now(), is_active=True, ) # 将之前的版本设为非活动 if self._current_version: self._current_version.is_active = False self.versions.append(version) self._current_version = version return version def get_current_version(self) -> RuleVersion | None: """获取当前活动版本""" return self._current_version def rollback(self, version_id: str) -> RuleVersion | None: """回滚到指定版本""" for version in self.versions: if version.version_id == version_id: # 将当前版本设为非活动 if self._current_version: self._current_version.is_active = False # 激活目标版本 version.is_active = True self._current_version = version return version return None def get_history(self) -> list[RuleVersion]: """获取版本历史""" return list(self.versions) class PlatformRuleSyncService: """平台规则同步服务""" def __init__(self): self.synced_rules: dict[str, dict[str, Any]] = {} self.last_sync: dict[str, datetime] = {} def sync_platform_rules(self, platform: str) -> dict[str, Any]: """ 同步平台规则 Args: platform: 平台标识 (douyin, xiaohongshu, etc.) Returns: 同步后的规则 """ # 模拟同步(实际应从平台 API 获取) rules = { "platform": platform, "version": "2026.01", "forbidden_words": [ {"word": "最", "category": "ad_law"}, {"word": "第一", "category": "ad_law"}, ], "synced_at": datetime.now().isoformat(), } self.synced_rules[platform] = rules self.last_sync[platform] = datetime.now() return rules def get_rules(self, platform: str) -> dict[str, Any] | None: """获取已同步的平台规则""" return self.synced_rules.get(platform) def is_sync_needed(self, platform: str, max_age_hours: int = 24) -> bool: """检查是否需要重新同步""" if platform not in self.last_sync: return True age = datetime.now() - self.last_sync[platform] return age.total_seconds() > max_age_hours * 3600