RAG 混合检索价值与知识增强挑战

Retrieval-Augmented Generation (RAG) 作为大语言模型知识增强的核心技术,其检索质量直接决定了生成内容的准确性和可靠性。传统 RAG 系统往往面临检索精度不足、知识覆盖不全、多模态融合困难、实时更新滞后、生成质量不稳定等核心痛点。现代化混合检索需要从架构层面考虑多源异构、从算法角度考虑语义理解、从工程维度考虑实时性能,构建智能化的知识增强体系。

企业级 RAG 需要解决跨域知识整合、多语言语义理解、实时知识更新、生成质量控制、隐私安全保护等复杂挑战。通过向量检索与关键词检索的深度融合、知识图谱的结构化补充、多模态信息的统一表示和智能化的质量评估,可以实现知识检索的精准匹配和生成内容的高质量输出,为大语言模型应用提供可靠的知识基础设施。

核心检索架构与混合策略

多路召回检索架构

构建基于多路召回的混合检索架构,支持语义与字面匹配融合:

# hybrid_retriever.py
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from abc import ABC, abstractmethod
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from elasticsearch import Elasticsearch
import networkx as nx
from PIL import Image
import clip
import spacy
from rank_bm25 import BM25Okapi
import redis
import json
from datetime import datetime, timedelta
import asyncio
import hashlib
@dataclass
class RetrievalResult:
"""检索结果数据结构"""
content: str
score: float
source: str
metadata: Dict[str, Any]
embedding: Optional[np.ndarray] = None
confidence: float = 0.0
retrieval_method: str = ""
@dataclass
class HybridRetrievalConfig:
"""混合检索配置"""
    # 向量检索配置
    vector_weight: float = 0.4
vector_model: str = "sentence-transformers/all-mpnet-base-v2"
vector_top_k: int = 20
vector_threshold: float = 0.7
# 关键词检索配置
keyword_weight: float = 0.3
keyword_top_k: int = 15
keyword_boost_fields: List[str] = None
# 知识图谱检索配置
graph_weight: float = 0.2
graph_hop_limit: int = 2
graph_entity_types: List[str] = None
# 多模态检索配置
multimodal_weight: float = 0.1
    image_model: str = "ViT-B/32"  # openai/CLIP 的模型名,供 clip.load 直接加载
audio_model: str = "wav2vec2-base"
# 重排序配置
rerank_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
rerank_top_k: int = 10
rerank_threshold: float = 0.8
# 缓存配置
cache_enabled: bool = True
cache_ttl: int = 3600
cache_key_prefix: str = "rag_hybrid"
class BaseRetriever(ABC):
"""基础检索器抽象类"""
def __init__(self, config: Dict[str, Any]):
self.config = config
self.name = self.__class__.__name__
@abstractmethod
async def retrieve(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
"""执行检索"""
pass
@abstractmethod
async def similarity_score(self, query: str, document: str) -> float:
"""计算相似度分数"""
pass
class VectorRetriever(BaseRetriever):
"""向量检索器"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.model = SentenceTransformer(config.get('model', 'all-mpnet-base-v2'))
self.index = None # 向量索引
self.documents = []
self.embeddings = None
async def build_index(self, documents: List[str], metadata: List[Dict] = None):
"""构建向量索引"""
print(f"构建向量索引,文档数量: {len(documents)}")
# 生成文档嵌入
self.embeddings = self.model.encode(documents, show_progress_bar=True)
self.documents = documents
self.metadata = metadata or [{}] * len(documents)
# 构建 FAISS 索引(假设使用 FAISS)
import faiss
dimension = self.embeddings.shape[1]
self.index = faiss.IndexFlatIP(dimension) # 内积相似度
faiss.normalize_L2(self.embeddings) # L2 归一化
self.index.add(self.embeddings.astype(np.float32))
print(f"向量索引构建完成,维度: {dimension}")
async def retrieve(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
"""向量检索"""
# 生成查询嵌入
query_embedding = self.model.encode([query])
faiss.normalize_L2(query_embedding)
# 相似度搜索
scores, indices = self.index.search(query_embedding.astype(np.float32), top_k)
results = []
for i, (score, idx) in enumerate(zip(scores[0], indices[0])):
if score > self.config.get('threshold', 0.5):
results.append(RetrievalResult(
content=self.documents[idx],
score=float(score),
source="vector",
metadata=self.metadata[idx],
embedding=self.embeddings[idx],
confidence=float(score),
retrieval_method="vector_similarity"
))
return results
async def similarity_score(self, query: str, document: str) -> float:
"""计算向量相似度"""
query_embedding = self.model.encode([query])
doc_embedding = self.model.encode([document])
# 余弦相似度
similarity = np.dot(query_embedding[0], doc_embedding[0]) / (
np.linalg.norm(query_embedding[0]) * np.linalg.norm(doc_embedding[0])
)
return float(similarity)
class KeywordRetriever(BaseRetriever):
"""关键词检索器"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
        # Elasticsearch 客户端预留给倒排索引检索;本实现的 retrieve 默认走内存中的 BM25
        self.elasticsearch = Elasticsearch(
            hosts=config.get('hosts', ['http://localhost:9200']),
            timeout=30
        )
        self.index_name = config.get('index', 'documents')
self.tokenized_corpus = []
self.bm25 = None
async def build_index(self, documents: List[str]):
"""构建关键词索引"""
print(f"构建关键词索引,文档数量: {len(documents)}")
# 使用 BM25 算法
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
stop_words = set(stopwords.words('english'))
# 文本预处理和分词
self.tokenized_corpus = []
for doc in documents:
tokens = word_tokenize(doc.lower())
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
self.tokenized_corpus.append(tokens)
# 构建 BM25 模型
self.bm25 = BM25Okapi(self.tokenized_corpus)
self.documents = documents
print("关键词索引构建完成")
async def retrieve(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
"""关键词检索"""
# BM25 检索
import nltk
from nltk.tokenize import word_tokenize
query_tokens = word_tokenize(query.lower())
query_tokens = [token for token in query_tokens if token.isalnum()]
# 计算 BM25 分数
scores = self.bm25.get_scores(query_tokens)
# 获取 top-k 结果
top_indices = np.argsort(scores)[-top_k:][::-1]
results = []
for idx in top_indices:
if scores[idx] > 0: # 只返回有分数的文档
results.append(RetrievalResult(
content=self.documents[idx],
score=float(scores[idx]),
source="keyword",
metadata={"bm25_score": float(scores[idx])},
confidence=float(scores[idx]) / max(scores) if max(scores) > 0 else 0,
retrieval_method="bm25_keyword"
))
return results
async def similarity_score(self, query: str, document: str) -> float:
"""计算关键词相似度"""
# 简单的词频相似度
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
query_tokens = set(word_tokenize(query.lower()))
doc_tokens = set(word_tokenize(document.lower()))
if not query_tokens or not doc_tokens:
return 0.0
# Jaccard 相似度
intersection = len(query_tokens.intersection(doc_tokens))
union = len(query_tokens.union(doc_tokens))
return intersection / union if union > 0 else 0.0
class GraphRetriever(BaseRetriever):
"""知识图谱检索器"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.graph = nx.DiGraph()
self.entity_linker = spacy.load("en_core_web_sm")
async def build_graph(self, triples: List[Tuple[str, str, str]]):
"""构建知识图谱"""
print(f"构建知识图谱,三元组数量: {len(triples)}")
        for subject, predicate, obj in triples:
            self.graph.add_edge(subject, obj, relation=predicate)
print(f"知识图谱构建完成,实体数量: {self.graph.number_of_nodes()}")
async def retrieve(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
"""图谱检索"""
# 实体识别
doc = self.entity_linker(query)
entities = [ent.text for ent in doc.ents]
if not entities:
return []
results = []
for entity in entities:
if entity in self.graph:
# 多跳邻居检索
neighbors = list(nx.single_source_shortest_path_length(
self.graph, entity, cutoff=self.config.get('hop_limit', 2)
).keys())
for neighbor in neighbors[1:]: # 排除自身
# 构建路径信息
path = nx.shortest_path(self.graph, entity, neighbor)
path_info = []
for i in range(len(path) - 1):
edge_data = self.graph.get_edge_data(path[i], path[i + 1])
relation = edge_data.get('relation', 'related')
path_info.append(f"{path[i]} -[{relation}]-> {path[i + 1]}")
results.append(RetrievalResult(
content=f"Entity: {neighbor}, Path: {' -> '.join(path_info)}",
score=1.0 / len(path), # 路径越短分数越高
source="knowledge_graph",
metadata={
"source_entity": entity,
"target_entity": neighbor,
"path_length": len(path),
"path_info": path_info
},
confidence=1.0 / len(path),
retrieval_method="graph_traversal"
))
# 按分数排序并返回 top-k
results.sort(key=lambda x: x.score, reverse=True)
return results[:top_k]
async def similarity_score(self, query: str, document: str) -> float:
"""计算图谱相似度"""
# 基于实体重叠的相似度
doc = self.entity_linker(query)
query_entities = set([ent.text for ent in doc.ents])
# 简单的实体匹配
if query_entities:
matched_entities = sum(1 for entity in query_entities if entity in document)
return matched_entities / len(query_entities)
return 0.0
class MultimodalRetriever(BaseRetriever):
"""多模态检索器"""
def __init__(self, config: Dict[str, Any]):
super().__init__(config)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# 加载 CLIP 模型
self.clip_model, self.clip_preprocess = clip.load(
config.get('image_model', 'ViT-B/32'),
device=self.device
)
self.image_embeddings = []
self.image_metadata = []
async def build_multimodal_index(self, images: List[Image.Image], metadata: List[Dict]):
"""构建多模态索引"""
print(f"构建多模态索引,图像数量: {len(images)}")
with torch.no_grad():
for i, (image, meta) in enumerate(zip(images, metadata)):
# 预处理图像
image_input = self.clip_preprocess(image).unsqueeze(0).to(self.device)
# 提取图像特征
image_features = self.clip_model.encode_image(image_input)
image_features /= image_features.norm(dim=-1, keepdim=True)
self.image_embeddings.append(image_features.cpu().numpy())
self.image_metadata.append(meta)
if i % 100 == 0:
print(f"处理进度: {i}/{len(images)}")
print("多模态索引构建完成")
async def retrieve(self, query: str, top_k: int = 5) -> List[RetrievalResult]:
"""多模态检索"""
with torch.no_grad():
# 文本特征提取
text = clip.tokenize([query]).to(self.device)
text_features = self.clip_model.encode_text(text)
text_features /= text_features.norm(dim=-1, keepdim=True)
results = []
for i, (image_features, meta) in enumerate(zip(self.image_embeddings, self.image_metadata)):
# 计算相似度
similarity = np.dot(text_features.cpu().numpy(), image_features.T).item()
if similarity > 0.2: # 相似度阈值
results.append(RetrievalResult(
content=f"Image: {meta.get('filename', 'unknown')}, Description: {meta.get('description', '')}",
score=similarity,
source="multimodal",
metadata={
"filename": meta.get('filename'),
"description": meta.get('description'),
"similarity": similarity,
"image_index": i
},
confidence=similarity,
retrieval_method="clip_multimodal"
))
# 按相似度排序
results.sort(key=lambda x: x.score, reverse=True)
return results[:top_k]
    async def similarity_score(self, query: str, image_description: str) -> float:
        """计算多模态相似度(简化:用 CLIP 文本编码器比较查询与图像描述)"""
        with torch.no_grad():
            tokens = clip.tokenize([query, image_description]).to(self.device)
            features = self.clip_model.encode_text(tokens)
            features = features / features.norm(dim=-1, keepdim=True)
            return float((features[0] @ features[1]).item())
class HybridRetriever:
"""混合检索器协调器"""
def __init__(self, config: HybridRetrievalConfig):
self.config = config
self.retrievers = {}
        self.cache = redis.Redis(
            host=getattr(config, 'cache_host', 'localhost'),
            port=getattr(config, 'cache_port', 6379),
            decode_responses=True
        )
# 初始化各个检索器
self._initialize_retrievers()
def _initialize_retrievers(self):
"""初始化各个检索器"""
# 向量检索器
if self.config.vector_weight > 0:
self.retrievers['vector'] = VectorRetriever({
'model': self.config.vector_model,
'threshold': self.config.vector_threshold
})
# 关键词检索器
if self.config.keyword_weight > 0:
self.retrievers['keyword'] = KeywordRetriever({
'top_k': self.config.keyword_top_k,
'boost_fields': self.config.keyword_boost_fields or []
})
# 知识图谱检索器
if self.config.graph_weight > 0:
self.retrievers['graph'] = GraphRetriever({
'hop_limit': self.config.graph_hop_limit,
'entity_types': self.config.graph_entity_types or []
})
# 多模态检索器
if self.config.multimodal_weight > 0:
self.retrievers['multimodal'] = MultimodalRetriever({
'image_model': self.config.image_model,
'audio_model': self.config.audio_model
})
async def hybrid_retrieve(self, query: str, top_k: int = 10) -> List[RetrievalResult]:
"""执行混合检索"""
# 检查缓存
if self.config.cache_enabled:
            cache_key = f"{self.config.cache_key_prefix}:{hashlib.md5(query.encode()).hexdigest()}:{top_k}"
cached_result = self.cache.get(cache_key)
if cached_result:
return [RetrievalResult(**item) for item in json.loads(cached_result)]
        # 并行执行各个检索器(asyncio.gather 并发收集,并取更多结果用于融合)
        names = list(self.retrievers.keys())
        tasks = [self.retrievers[name].retrieve(query, top_k * 2) for name in names]
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        retrieval_results = {}
        for name, outcome in zip(names, outcomes):
            if isinstance(outcome, Exception):
                print(f"检索器 {name} 执行失败: {outcome}")
                retrieval_results[name] = []
            else:
                retrieval_results[name] = outcome
# 结果融合与重排序
fused_results = await self.fuse_results(retrieval_results, query)
        # 缓存结果(embedding 为 ndarray,无法 JSON 序列化,缓存时置空)
        if self.config.cache_enabled:
            cache_items = []
            for result in fused_results:
                item = dict(result.__dict__)
                item['embedding'] = None
                cache_items.append(item)
            self.cache.setex(cache_key, self.config.cache_ttl, json.dumps(cache_items))
return fused_results[:top_k]
async def fuse_results(self, retrieval_results: Dict[str, List[RetrievalResult]],
query: str) -> List[RetrievalResult]:
"""融合多路检索结果"""
# 收集所有结果
all_results = []
        for method, results in retrieval_results.items():
            weight = getattr(self.config, f'{method}_weight', 0.0)
            if not results:
                continue
            # 先按方法内最大分数归一化,再应用权重,避免不同检索器分数量纲不一致
            max_score = max(result.score for result in results)
            for result in results:
                normalized = result.score / max_score if max_score > 0 else 0.0
                result.score = normalized * weight
                all_results.append(result)
# 去重(基于内容相似度)
unique_results = self.deduplicate_results(all_results)
# 重排序
reranked_results = await self.rerank_results(unique_results, query)
return reranked_results
def deduplicate_results(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
"""结果去重"""
seen_content = set()
unique_results = []
for result in results:
# 使用内容哈希作为去重键
content_hash = hash(result.content.strip().lower())
if content_hash not in seen_content:
seen_content.add(content_hash)
unique_results.append(result)
return unique_results
async def rerank_results(self, results: List[RetrievalResult], query: str) -> List[RetrievalResult]:
"""重排序结果"""
if not results:
return results
# 使用交叉编码器进行重排序
try:
from sentence_transformers import CrossEncoder
reranker = CrossEncoder(self.config.rerank_model)
# 准备输入对
input_pairs = [[query, result.content] for result in results]
# 获取重排序分数
rerank_scores = reranker.predict(input_pairs)
# 更新分数
for i, result in enumerate(results):
result.score = float(rerank_scores[i])
result.confidence = float(rerank_scores[i])
# 按重排序分数排序
results.sort(key=lambda x: x.score, reverse=True)
except Exception as e:
print(f"重排序失败: {e}")
# 回退到原始分数排序
results.sort(key=lambda x: x.score, reverse=True)
return results
async def build_knowledge_base(self, documents: List[str],
triples: List[Tuple[str, str, str]] = None,
images: List[Image.Image] = None,
metadata: List[Dict] = None):
"""构建知识库"""
print("开始构建混合检索知识库...")
# 构建向量索引
if 'vector' in self.retrievers:
await self.retrievers['vector'].build_index(documents, metadata)
# 构建关键词索引
if 'keyword' in self.retrievers:
await self.retrievers['keyword'].build_index(documents)
# 构建知识图谱
if 'graph' in self.retrievers and triples:
await self.retrievers['graph'].build_graph(triples)
# 构建多模态索引
if 'multimodal' in self.retrievers and images:
await self.retrievers['multimodal'].build_multimodal_index(images, metadata)
print("混合检索知识库构建完成")
知识增强生成与质量控制

智能上下文组装

实现基于相关性和多样性的上下文智能组装:

# context_assembler.py
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx
from transformers import AutoTokenizer
from hybrid_retriever import RetrievalResult  # 复用上文 hybrid_retriever.py 中的检索结果结构
@dataclass
class ContextConfig:
"""上下文组装配置"""
max_tokens: int = 2000
max_chunks: int = 10
diversity_weight: float = 0.3
relevance_weight: float = 0.7
clustering_enabled: bool = True
redundancy_removal: bool = True
source_diversity: bool = True
temporal_relevance: bool = True
class IntelligentContextAssembler:
"""智能上下文组装器"""
def __init__(self, config: ContextConfig, tokenizer_name: str = "gpt2"):
self.config = config
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
self.tokenizer.pad_token = self.tokenizer.eos_token
def assemble_context(self, query: str, retrieval_results: List[RetrievalResult]) -> Dict[str, Any]:
"""智能组装上下文"""
if not retrieval_results:
return {
"context": "",
"chunks": [],
"metadata": {
"total_chunks": 0,
"total_tokens": 0,
"sources": [],
"diversity_score": 0.0
}
}
# 1. 相关性重排序
ranked_results = self.rerank_by_relevance(query, retrieval_results)
# 2. 多样性选择
if self.config.clustering_enabled:
diverse_results = self.select_diverse_results(ranked_results)
else:
diverse_results = ranked_results
# 3. 冗余移除
if self.config.redundancy_removal:
filtered_results = self.remove_redundancy(diverse_results)
else:
filtered_results = diverse_results
# 4. 上下文组装
context_chunks = self.build_context_chunks(filtered_results)
# 5. 元数据生成
metadata = self.generate_metadata(context_chunks)
return {
"context": "\n\n".join(chunk["content"] for chunk in context_chunks),
"chunks": context_chunks,
"metadata": metadata
}
def rerank_by_relevance(self, query: str, results: List[RetrievalResult]) -> List[RetrievalResult]:
"""基于查询相关性重排序"""
# 使用交叉编码器计算相关性分数
try:
from sentence_transformers import CrossEncoder
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
input_pairs = [[query, result.content] for result in results]
relevance_scores = reranker.predict(input_pairs)
# 更新分数并排序
for i, result in enumerate(results):
result.score = float(relevance_scores[i])
results.sort(key=lambda x: x.score, reverse=True)
except Exception as e:
print(f"重排序失败: {e}")
# 使用原始分数
results.sort(key=lambda x: x.score, reverse=True)
return results
def select_diverse_results(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
"""选择多样性结果"""
if len(results) < 3:
return results
# 提取嵌入向量
embeddings = []
valid_results = []
for result in results:
if result.embedding is not None:
embeddings.append(result.embedding)
valid_results.append(result)
if len(embeddings) < 3:
return results[:self.config.max_chunks]
embeddings = np.array(embeddings)
# K-means 聚类
n_clusters = min(len(embeddings), self.config.max_chunks // 2)
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
cluster_labels = kmeans.fit_predict(embeddings)
# 从每个聚类中选择代表性结果
diverse_results = []
for cluster_id in range(n_clusters):
cluster_indices = np.where(cluster_labels == cluster_id)[0]
# 选择聚类中分数最高的结果
best_idx = max(cluster_indices, key=lambda i: valid_results[i].score)
diverse_results.append(valid_results[best_idx])
# 如果聚类中有多个结果,也选择多样性较高的
if len(cluster_indices) > 1:
# 计算聚类内多样性
cluster_embeddings = embeddings[cluster_indices]
centroid = np.mean(cluster_embeddings, axis=0)
# 选择与质心距离适中的结果
distances = cosine_similarity([centroid], cluster_embeddings)[0]
diversity_idx = cluster_indices[np.argsort(distances)[len(distances)//2]]
if diversity_idx != best_idx:
diverse_results.append(valid_results[diversity_idx])
return diverse_results
def remove_redundancy(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
"""移除冗余内容"""
if len(results) < 2:
return results
filtered_results = [results[0]] # 保留第一个结果
for i in range(1, len(results)):
current_result = results[i]
is_redundant = False
# 检查与已保留结果的相似度
for kept_result in filtered_results:
similarity = self.calculate_content_similarity(
current_result.content,
kept_result.content
)
if similarity > 0.8: # 相似度阈值
is_redundant = True
break
if not is_redundant:
filtered_results.append(current_result)
return filtered_results
def calculate_content_similarity(self, text1: str, text2: str) -> float:
"""计算文本相似度"""
# 使用 Jaccard 相似度
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = len(words1.intersection(words2))
union = len(words1.union(words2))
return intersection / union if union > 0 else 0.0
def build_context_chunks(self, results: List[RetrievalResult]) -> List[Dict[str, Any]]:
"""构建上下文块"""
chunks = []
total_tokens = 0
for result in results:
# 计算当前结果的 token 数量
tokens = self.tokenizer.encode(result.content)
token_count = len(tokens)
# 检查是否超出 token 限制
if total_tokens + token_count > self.config.max_tokens:
break
# 创建上下文块
chunk = {
"content": result.content,
"tokens": token_count,
"source": result.source,
"score": result.score,
"confidence": result.confidence,
"retrieval_method": result.retrieval_method,
"metadata": result.metadata
}
chunks.append(chunk)
total_tokens += token_count
# 检查是否超出块数限制
if len(chunks) >= self.config.max_chunks:
break
return chunks
def generate_metadata(self, chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
"""生成上下文元数据"""
if not chunks:
return {
"total_chunks": 0,
"total_tokens": 0,
"sources": [],
"diversity_score": 0.0,
"average_confidence": 0.0
}
# 统计信息
total_tokens = sum(chunk["tokens"] for chunk in chunks)
sources = list(set(chunk["source"] for chunk in chunks))
# 计算平均置信度
avg_confidence = np.mean([chunk["confidence"] for chunk in chunks])
# 计算多样性分数
diversity_score = self.calculate_diversity_score(chunks)
return {
"total_chunks": len(chunks),
"total_tokens": total_tokens,
"sources": sources,
"diversity_score": diversity_score,
"average_confidence": avg_confidence,
"chunk_sizes": [chunk["tokens"] for chunk in chunks]
}
def calculate_diversity_score(self, chunks: List[Dict[str, Any]]) -> float:
"""计算多样性分数"""
if len(chunks) < 2:
return 0.0
# 基于来源的多样性
sources = [chunk["source"] for chunk in chunks]
unique_sources = len(set(sources))
source_diversity = unique_sources / len(sources) if sources else 0.0
# 基于内容的多样性(简化版本)
content_diversity = 0.0
if len(chunks) > 1:
similarities = []
for i in range(len(chunks)):
for j in range(i + 1, len(chunks)):
sim = self.calculate_content_similarity(
chunks[i]["content"],
chunks[j]["content"]
)
similarities.append(sim)
avg_similarity = np.mean(similarities) if similarities else 0.0
content_diversity = 1.0 - avg_similarity
        # 综合多样性分数:启用来源多样性时,来源与内容多样性各占一半,否则仅看内容多样性
        if self.config.source_diversity:
            diversity_score = 0.5 * source_diversity + 0.5 * content_diversity
        else:
            diversity_score = content_diversity
        return diversity_score
class QualityEvaluator:
"""生成质量评估器"""
def __init__(self):
self.metrics = {
'relevance': self.evaluate_relevance,
'accuracy': self.evaluate_accuracy,
'completeness': self.evaluate_completeness,
'coherence': self.evaluate_coherence
}
def evaluate_generation(self, query: str, context: str, generated_text: str) -> Dict[str, float]:
"""评估生成质量"""
scores = {}
for metric_name, metric_func in self.metrics.items():
try:
score = metric_func(query, context, generated_text)
scores[metric_name] = score
except Exception as e:
print(f"评估指标 {metric_name} 失败: {e}")
scores[metric_name] = 0.0
# 综合分数
scores['overall'] = np.mean(list(scores.values()))
return scores
def evaluate_relevance(self, query: str, context: str, generated_text: str) -> float:
"""评估相关性"""
# 使用嵌入相似度
try:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
query_embedding = model.encode([query])
generated_embedding = model.encode([generated_text])
similarity = cosine_similarity(query_embedding, generated_embedding)[0][0]
return float(similarity)
except Exception as e:
print(f"相关性评估失败: {e}")
# 回退到关键词匹配
query_words = set(query.lower().split())
generated_words = set(generated_text.lower().split())
if not query_words:
return 0.0
overlap = len(query_words.intersection(generated_words))
return overlap / len(query_words)
def evaluate_accuracy(self, query: str, context: str, generated_text: str) -> float:
"""评估准确性"""
# 检查生成内容是否基于上下文
context_words = set(context.lower().split())
generated_words = set(generated_text.lower().split())
if not generated_words:
return 0.0
# 计算基于上下文的比例
context_based_words = generated_words.intersection(context_words)
accuracy = len(context_based_words) / len(generated_words)
return accuracy
def evaluate_completeness(self, query: str, context: str, generated_text: str) -> float:
"""评估完整性"""
# 检查是否回答了查询的所有方面
# 简化实现:检查关键概念覆盖
query_concepts = self.extract_key_concepts(query)
generated_concepts = self.extract_key_concepts(generated_text)
if not query_concepts:
return 0.0
covered_concepts = sum(1 for concept in query_concepts
if any(covered in generated_concepts for covered in self.expand_concept(concept)))
return covered_concepts / len(query_concepts)
def evaluate_coherence(self, query: str, context: str, generated_text: str) -> float:
"""评估连贯性"""
# 使用语言模型评估流畅度
try:
from transformers import pipeline
            # 用掩码词预测命中率近似衡量句子连贯性(并非严格意义上的困惑度)
unmasker = pipeline("fill-mask", model="bert-base-uncased")
# 随机遮蔽一些词语并检查预测准确性
sentences = generated_text.split('.')
coherence_scores = []
for sentence in sentences[:5]: # 检查前5个句子
if len(sentence.split()) > 5:
words = sentence.split()
masked_sentence = " ".join(words[:-1]) + " [MASK]"
try:
predictions = unmasker(masked_sentence)
if predictions and predictions[0]['token_str'] == words[-1]:
coherence_scores.append(1.0)
else:
coherence_scores.append(0.5)
                    except Exception:
coherence_scores.append(0.5)
return np.mean(coherence_scores) if coherence_scores else 0.5
except Exception as e:
print(f"连贯性评估失败: {e}")
# 回退到简单的句子结构检查
sentences = generated_text.split('.')
if not sentences:
return 0.0
# 检查平均句子长度
avg_length = np.mean([len(sentence.split()) for sentence in sentences])
# 理想句子长度在10-30个词之间
coherence = 1.0 - abs(avg_length - 20) / 20
return max(0.0, coherence)
def extract_key_concepts(self, text: str) -> List[str]:
"""提取关键概念"""
try:
import spacy
nlp = spacy.load("en_core_web_sm")
doc = nlp(text)
concepts = []
# 提取命名实体
for ent in doc.ents:
concepts.append(ent.text.lower())
# 提取名词短语
for chunk in doc.noun_chunks:
if len(chunk.text.split()) <= 3: # 限制短语长度
concepts.append(chunk.text.lower())
return list(set(concepts)) # 去重
except Exception as e:
print(f"概念提取失败: {e}")
# 回退到关键词提取
words = text.lower().split()
# 简单的词频统计
word_freq = {}
for word in words:
if len(word) > 3: # 过滤短词
word_freq[word] = word_freq.get(word, 0) + 1
# 返回频率最高的词作为概念
return [word for word, freq in sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:10]]
def expand_concept(self, concept: str) -> List[str]:
"""扩展概念(同义词等)"""
# 简单的扩展:返回概念本身和复数形式
expansions = [concept]
# 添加复数形式(简单规则)
if concept.endswith('s'):
expansions.append(concept[:-1]) # 去s
else:
expansions.append(concept + 's') # 加s
return expansions
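下面是上下文组装与质量评估的调用示意(RetrievalResult 复用上文 hybrid_retriever.py 的定义;示例中的检索结果与生成文本均为假设数据,首次运行需下载相应的 tokenizer 与评估模型,模型不可用时评估器会回退到内部的简化指标):

# assemble_and_evaluate_example.py(示意)
from hybrid_retriever import RetrievalResult
from context_assembler import ContextConfig, IntelligentContextAssembler, QualityEvaluator

query = "What is hybrid retrieval in RAG?"
results = [
    RetrievalResult(content="Hybrid retrieval fuses vector similarity search with keyword matching.",
                    score=0.91, source="vector", metadata={}, confidence=0.91,
                    retrieval_method="vector_similarity"),
    RetrievalResult(content="BM25 ranks documents by term frequency and inverse document frequency.",
                    score=0.74, source="keyword", metadata={}, confidence=0.74,
                    retrieval_method="bm25_keyword"),
]

# 组装上下文:限制 token 预算,示例中关闭聚类以避免依赖嵌入向量
assembler = IntelligentContextAssembler(ContextConfig(max_tokens=512, clustering_enabled=False))
assembled = assembler.assemble_context(query, results)
print(assembled["metadata"])

# 假设 generated_text 来自下游大语言模型的输出
generated_text = "Hybrid retrieval combines vector similarity search with BM25 keyword matching."
scores = QualityEvaluator().evaluate_generation(query, assembled["context"], generated_text)
print(scores)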
实时更新与知识同步

增量知识更新机制

实现高效的增量知识更新与同步机制:

# knowledge_updater.py
import asyncio
import aiohttp
from typing import List, Dict, Any, Optional, Set
from dataclasses import dataclass
from datetime import datetime, timedelta
import redis
import json
import hashlib
from collections import defaultdict
import sqlite3
from abc import ABC, abstractmethod
@dataclass
class KnowledgeUpdate:
"""知识更新数据结构"""
update_id: str
content: str
source: str
update_type: str # add, modify, delete
timestamp: datetime
metadata: Dict[str, Any]
priority: int = 1
dependencies: List[str] = None
@dataclass
class UpdateConfig:
"""更新配置"""
batch_size: int = 100
update_interval: int = 300 # 5分钟
max_retry_attempts: int = 3
retry_delay: int = 60
deduplication_enabled: bool = True
conflict_resolution: str = "last_write_wins"
notification_enabled: bool = True
class UpdateSource(ABC):
"""更新源抽象类"""
def __init__(self, source_id: str, config: Dict[str, Any]):
self.source_id = source_id
self.config = config
self.last_update_time = None
@abstractmethod
async def fetch_updates(self, since: datetime) -> List[KnowledgeUpdate]:
"""获取更新"""
pass
@abstractmethod
async def validate_update(self, update: KnowledgeUpdate) -> bool:
"""验证更新"""
pass
class DatabaseUpdateSource(UpdateSource):
"""数据库更新源"""
def __init__(self, source_id: str, config: Dict[str, Any]):
super().__init__(source_id, config)
self.connection_string = config.get('connection_string')
self.table_name = config.get('table_name', 'knowledge_base')
self.sync_query = config.get('sync_query', f"SELECT * FROM {self.table_name} WHERE updated_at > ?")
async def fetch_updates(self, since: datetime) -> List[KnowledgeUpdate]:
"""从数据库获取更新"""
updates = []
try:
# 使用 asyncio 兼容的数据库连接
import aiosqlite
async with aiosqlite.connect(self.connection_string) as db:
async with db.execute(self.sync_query, (since,)) as cursor:
async for row in cursor:
update = self.parse_database_row(row)
if update and await self.validate_update(update):
updates.append(update)
self.last_update_time = datetime.now()
except Exception as e:
print(f"数据库更新源 {self.source_id} 获取更新失败: {e}")
return updates
def parse_database_row(self, row) -> Optional[KnowledgeUpdate]:
"""解析数据库行"""
try:
# 假设 row 是字典或可以通过索引访问
if isinstance(row, dict):
content = row.get('content', '')
update_type = row.get('update_type', 'modify')
timestamp = row.get('updated_at', datetime.now())
update_id = row.get('id', str(hash(content)))
else:
# 假设 row 是元组,需要知道列的顺序
content = row[1] if len(row) > 1 else ''
update_type = row[3] if len(row) > 3 else 'modify'
timestamp = row[4] if len(row) > 4 else datetime.now()
update_id = row[0] if len(row) > 0 else str(hash(content))
return KnowledgeUpdate(
update_id=update_id,
content=content,
source=self.source_id,
update_type=update_type,
timestamp=timestamp,
metadata={"database_source": self.source_id}
)
except Exception as e:
print(f"解析数据库行失败: {e}")
return None
async def validate_update(self, update: KnowledgeUpdate) -> bool:
"""验证数据库更新"""
# 基本验证:内容不为空,时间戳合理
if not update.content or len(update.content.strip()) == 0:
return False
# 时间戳不能是未来时间
if update.timestamp > datetime.now() + timedelta(hours=1):
return False
return True
class APIUpdateSource(UpdateSource):
"""API 更新源"""
def __init__(self, source_id: str, config: Dict[str, Any]):
super().__init__(source_id, config)
self.api_endpoint = config.get('api_endpoint')
self.api_key = config.get('api_key')
self.headers = config.get('headers', {})
self.rate_limit = config.get('rate_limit', 100) # 每分钟请求数
self.session = None
async def fetch_updates(self, since: datetime) -> List[KnowledgeUpdate]:
"""从 API 获取更新"""
updates = []
try:
if not self.session:
self.session = aiohttp.ClientSession(
headers=self.headers,
timeout=aiohttp.ClientTimeout(total=30)
)
# 构建 API 请求
params = {
'since': since.isoformat(),
'limit': self.config.get('batch_size', 100)
}
if self.api_key:
params['api_key'] = self.api_key
async with self.session.get(self.api_endpoint, params=params) as response:
if response.status == 200:
data = await response.json()
updates = self.parse_api_response(data)
# 验证更新
valid_updates = []
for update in updates:
if await self.validate_update(update):
valid_updates.append(update)
updates = valid_updates
else:
print(f"API 请求失败: {response.status}")
self.last_update_time = datetime.now()
except Exception as e:
print(f"API 更新源 {self.source_id} 获取更新失败: {e}")
return updates
def parse_api_response(self, data: Dict[str, Any]) -> List[KnowledgeUpdate]:
"""解析 API 响应"""
updates = []
# 假设 API 返回的数据格式
if 'updates' in data:
for item in data['updates']:
try:
update = KnowledgeUpdate(
update_id=item.get('id', str(hash(str(item)))),
content=item.get('content', ''),
source=self.source_id,
update_type=item.get('type', 'modify'),
timestamp=datetime.fromisoformat(item.get('timestamp', datetime.now().isoformat())),
metadata=item.get('metadata', {}),
priority=item.get('priority', 1)
)
updates.append(update)
except Exception as e:
print(f"解析 API 数据项失败: {e}")
return updates
async def validate_update(self, update: KnowledgeUpdate) -> bool:
"""验证 API 更新"""
# 基本验证
if not update.content or len(update.content.strip()) == 0:
return False
# 检查内容长度
max_length = self.config.get('max_content_length', 10000)
if len(update.content) > max_length:
return False
# 检查更新类型
valid_types = ['add', 'modify', 'delete']
if update.update_type not in valid_types:
return False
return True
class IncrementalKnowledgeUpdater:
"""增量知识更新器"""
def __init__(self, config: UpdateConfig):
self.config = config
self.update_sources = {}
        self.cache = redis.Redis(
            host=getattr(config, 'redis_host', 'localhost'),
            port=getattr(config, 'redis_port', 6379),
            decode_responses=True
        )
self.update_queue = asyncio.Queue(maxsize=1000)
self.processing_stats = defaultdict(int)
# 初始化数据库
self.init_database()
def init_database(self):
"""初始化更新跟踪数据库"""
conn = sqlite3.connect('knowledge_updates.db')
cursor = conn.cursor()
cursor.execute('''
CREATE TABLE IF NOT EXISTS update_log (
update_id TEXT PRIMARY KEY,
source TEXT,
update_type TEXT,
content_hash TEXT,
timestamp TEXT,
status TEXT,
retry_count INTEGER,
error_message TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS update_dependencies (
update_id TEXT,
dependency_id TEXT,
FOREIGN KEY (update_id) REFERENCES update_log(update_id),
FOREIGN KEY (dependency_id) REFERENCES update_log(update_id)
)
''')
conn.commit()
conn.close()
def add_update_source(self, source: UpdateSource):
"""添加更新源"""
self.update_sources[source.source_id] = source
async def start_update_loop(self):
"""启动更新循环"""
print("启动增量知识更新循环...")
# 启动多个工作协程
workers = []
for i in range(3): # 3个工作协程
worker = asyncio.create_task(self.update_worker(f"worker-{i}"))
workers.append(worker)
# 启动更新收集协程
collector = asyncio.create_task(self.update_collector())
# 启动重试处理协程
retry_handler = asyncio.create_task(self.retry_handler())
# 等待所有协程完成
await asyncio.gather(collector, retry_handler, *workers)
async def update_collector(self):
"""更新收集器"""
while True:
try:
# 从所有源收集更新
all_updates = []
for source_id, source in self.update_sources.items():
last_update = self.get_last_update_time(source_id)
try:
updates = await source.fetch_updates(last_update)
all_updates.extend(updates)
print(f"从源 {source_id} 收集到 {len(updates)} 个更新")
except Exception as e:
print(f"收集源 {source_id} 更新失败: {e}")
# 去重处理
if self.config.deduplication_enabled:
all_updates = await self.deduplicate_updates(all_updates)
# 添加到处理队列
for update in all_updates:
await self.update_queue.put(update)
# 更新统计
self.processing_stats['collected'] += len(all_updates)
# 等待下一轮收集
await asyncio.sleep(self.config.update_interval)
except Exception as e:
print(f"更新收集器错误: {e}")
await asyncio.sleep(60) # 错误后等待1分钟
async def update_worker(self, worker_id: str):
"""更新工作协程"""
print(f"更新工作协程 {worker_id} 启动")
while True:
try:
# 从队列获取更新
update = await self.update_queue.get()
# 处理更新
success = await self.process_update(update)
if success:
self.processing_stats['successful'] += 1
else:
self.processing_stats['failed'] += 1
# 标记任务完成
self.update_queue.task_done()
except asyncio.CancelledError:
print(f"更新工作协程 {worker_id} 被取消")
break
except Exception as e:
print(f"更新工作协程 {worker_id} 错误: {e}")
self.processing_stats['errors'] += 1
async def process_update(self, update: KnowledgeUpdate) -> bool:
"""处理单个更新"""
try:
# 检查是否已经处理过
if await self.is_update_processed(update.update_id):
print(f"更新 {update.update_id} 已经处理过,跳过")
return True
# 检查依赖关系
if not await self.check_dependencies(update):
print(f"更新 {update.update_id} 依赖未满足,延迟处理")
await self.delay_update(update)
return False
# 执行更新
success = await self.execute_update(update)
if success:
# 记录成功
await self.record_update_success(update)
# 发送通知
if self.config.notification_enabled:
await self.send_update_notification(update)
return True
else:
# 记录失败并重试
await self.record_update_failure(update, "执行失败")
return await self.retry_update(update)
except Exception as e:
print(f"处理更新 {update.update_id} 失败: {e}")
await self.record_update_failure(update, str(e))
return await self.retry_update(update)
async def deduplicate_updates(self, updates: List[KnowledgeUpdate]) -> List[KnowledgeUpdate]:
"""更新去重"""
if not self.config.deduplication_enabled:
return updates
deduplicated = []
seen_hashes = set()
for update in updates:
# 计算内容哈希
content_hash = hashlib.md5(update.content.encode()).hexdigest()
if content_hash not in seen_hashes:
seen_hashes.add(content_hash)
update.metadata['content_hash'] = content_hash
deduplicated.append(update)
else:
print(f"检测到重复更新: {update.update_id}")
return deduplicated
async def check_dependencies(self, update: KnowledgeUpdate) -> bool:
"""检查更新依赖"""
if not update.dependencies:
return True
for dependency_id in update.dependencies:
if not await self.is_update_processed(dependency_id):
return False
return True
async def execute_update(self, update: KnowledgeUpdate) -> bool:
"""执行更新操作"""
# 这里应该调用实际的知识库更新逻辑
# 简化实现
print(f"执行更新: {update.update_id}, 类型: {update.update_type}")
# 模拟更新操作
await asyncio.sleep(0.1) # 模拟处理时间
# 更新向量索引
await self.update_vector_index(update)
# 更新关键词索引
await self.update_keyword_index(update)
# 更新知识图谱
if update.update_type in ['add', 'modify']:
await self.update_knowledge_graph(update)
return True
async def update_vector_index(self, update: KnowledgeUpdate):
"""更新向量索引"""
# 这里应该调用向量索引的更新逻辑
pass
async def update_keyword_index(self, update: KnowledgeUpdate):
"""更新关键词索引"""
# 这里应该调用关键词索引的更新逻辑
pass
async def update_knowledge_graph(self, update: KnowledgeUpdate):
"""更新知识图谱"""
# 这里应该调用知识图谱的更新逻辑
pass
async def retry_update(self, update: KnowledgeUpdate) -> bool:
"""重试更新"""
retry_count = update.metadata.get('retry_count', 0)
if retry_count >= self.config.max_retry_attempts:
print(f"更新 {update.update_id} 达到最大重试次数,放弃")
return False
# 增加重试计数
update.metadata['retry_count'] = retry_count + 1
# 延迟后重新加入队列
await asyncio.sleep(self.config.retry_delay * (retry_count + 1))
await self.update_queue.put(update)
return True
    async def retry_handler(self):
        """重试处理协程"""
        while True:
            try:
                # 定期检查失败的更新并重新入队(process_failed_retries 需结合 update_log 表另行实现)
                if hasattr(self, 'process_failed_retries'):
                    await self.process_failed_retries()
                await asyncio.sleep(300)  # 每5分钟检查一次
            except Exception as e:
                print(f"重试处理器错误: {e}")
                await asyncio.sleep(60)
def get_last_update_time(self, source_id: str) -> datetime:
"""获取最后更新时间"""
cache_key = f"last_update:{source_id}"
last_time = self.cache.get(cache_key)
if last_time:
return datetime.fromisoformat(last_time)
else:
# 默认返回24小时前
return datetime.now() - timedelta(hours=24)
async def is_update_processed(self, update_id: str) -> bool:
"""检查更新是否已处理"""
cache_key = f"processed:{update_id}"
return self.cache.exists(cache_key)
async def record_update_success(self, update: KnowledgeUpdate):
"""记录更新成功"""
# 缓存记录
cache_key = f"processed:{update.update_id}"
self.cache.setex(cache_key, 86400, update.timestamp.isoformat()) # 保留24小时
# 数据库记录
conn = sqlite3.connect('knowledge_updates.db')
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO update_log
(update_id, source, update_type, content_hash, timestamp, status, retry_count)
VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
update.update_id,
update.source,
update.update_type,
update.metadata.get('content_hash', ''),
update.timestamp.isoformat(),
'success',
update.metadata.get('retry_count', 0)
))
conn.commit()
conn.close()
async def record_update_failure(self, update: KnowledgeUpdate, error_message: str):
"""记录更新失败"""
conn = sqlite3.connect('knowledge_updates.db')
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO update_log
(update_id, source, update_type, content_hash, timestamp, status, retry_count, error_message)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
''', (
update.update_id,
update.source,
update.update_type,
update.metadata.get('content_hash', ''),
update.timestamp.isoformat(),
'failed',
update.metadata.get('retry_count', 0),
error_message
))
conn.commit()
conn.close()
async def send_update_notification(self, update: KnowledgeUpdate):
"""发送更新通知"""
# 这里可以实现 WebSocket 通知、Webhook 调用等
notification = {
"type": "knowledge_update",
"update_id": update.update_id,
"source": update.source,
"update_type": update.update_type,
"timestamp": update.timestamp.isoformat(),
"priority": update.priority
}
print(f"发送更新通知: {notification}")
# 实际实现中可以调用 WebSocket、Webhook 等
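最后是增量更新器的启动示意(SQLite 路径与 API 端点均为假设值,运行依赖本地 Redis;start_update_loop 会持续运行直至被取消):

# run_updater_example.py(示意)
import asyncio
from knowledge_updater import (UpdateConfig, IncrementalKnowledgeUpdater,
                               DatabaseUpdateSource, APIUpdateSource)

async def main():
    updater = IncrementalKnowledgeUpdater(UpdateConfig(update_interval=300, batch_size=50))

    # 注册两个假设的更新源:本地 SQLite 与 HTTP 更新接口
    updater.add_update_source(DatabaseUpdateSource("local_db", {
        "connection_string": "knowledge_base.db",
        "table_name": "knowledge_base"
    }))
    updater.add_update_source(APIUpdateSource("news_api", {
        "api_endpoint": "https://example.com/api/updates",  # 假设端点
        "batch_size": 50
    }))

    await updater.start_update_loop()

if __name__ == "__main__":
    asyncio.run(main())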
