从0开始构建简单RAG

前言

使用了很多开源的RAG方案,包括Dify,coze,cherry-studio等等,但是真正要了解RAG或者随着需求的升级,不可避免要自己开始搓RAG
现在RAG方向的发展很快,graphrag、lightrag等新方法我们暂且按下不表,先学习使用langchain完成一个标准流程的rag结构

了解RAG流程

做之前先了解RAG的基本流程,由于已经确定了使用langchain去开发,可以通过AI整理信息归纳出整体流程

整体流程

  1. 数据准备阶段
    • 文档加载
      📌 工具:LangChain Document Loaders
      📄 支持格式:PDF/TXT/HTML/Markdown/数据库等
    • 文本分块
      🔧 方法:递归分块/固定窗口重叠分块
      🛠️ 工具:RecursiveCharacterTextSplitter
    • 向量化处理
      🤖 模型:BGE、sentence-transformers
  2. 存储阶段
    🗄️ 数据库选择
    • 轻量级首选:ChromaDB(内存/持久化模式)
    • 生产级方案
      • Elasticsearch(支持混合搜索)[[Elasticsearch笔记]]
      • Pinecone(全托管云服务)
  3. 检索阶段
    🔍 检索器类型
    • 基础:VectorStoreRetriever(纯向量检索)
    • 增强:
      • 混合检索:结合BM25+向量搜索(需Elasticsearch)
  4. 重排阶段(Reranking)
    🎯 核心价值:提升结果相关性
    • 开源方案
      • BAAI/bge-reranker-base
      • cross-encoder/ms-marco-MiniLM-L6
        ⚖️ 平衡策略:保留top 3-5个最终结果
  5. 生成阶段
    💬 选择适当的LLM

[!NOTE] 总结
RAG的整体流程就是 加载文档->向量化->存储->检索->重排->回答
所有模型都使用硅基模型 嵌入模型BAAI/bge-m3,重排序模型BAAI/bge-reranker-v2-m3,对话模型deepseek-ai/DeepSeek-V3
数据库使用 Elasticsearch
混合检索

部署Elasticsearch

windowns下部署

按照官网流程 Download Elasticsearch | Elastic

  • 下载
  • 运行bin\elasticsearch.bat
  • 生成一个新密码 用户名:elastic
1
./elasticsearch-reset-password -u elastic -i

开发

为了完全不会开发的宝宝选手(我自己),我直接使用cursor进行开发!

初步构建

帮我构建一个rag系统,按照文档加载分块,向量化,储存(使用elasticsearch),检索(BM25+向量检索),重排,生成的流程来,不同功能放在不同代码中

这样基本的架构就出来了,接下来有一些细节需要修改完善,具体根据cursor生成的代码不同需要修改的部分也不同,主要是以下几个方面:

  • 使用siliconFlow的模型和正确发起请求 可以直接在cursor中导入硅基的文档然后直接@ 结合 Cursor 使用 - SiliconFlow
  • 确保检索和嵌入时正确连接ES
1
2
3
4
5
6
7
8
9
10
11
12
13
    def __init__(self):

        # ES 8.x 的连接配置

        self.es = Elasticsearch(

            "https://localhost:9200",  # 注意是 https

            basic_auth=("elastic", "xxx"),  # 替换为你的密码

            verify_certs=False  # 开发环境可以禁用证书验证

        )

如果想看到数据库的索引,可以下载ES的相关工具Kibana

  • 确保文档的导入正确,有时候curosr会用上不存在的库,注意让他多改几遍
  • 根据文件名命令索引,初次生成的代码可能用的固定的名称进行索引,对后续的扩展不友好
    接下来根据自己的需要修改一下聊天模型的prompt和主代码的启动逻辑,安装好依赖,跑起来进行测试即可

代码

这里我是用的命令行的方式将文件的路径导入,python app.py C:/documents/xxx.md

app.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from typing import List, Dict
import os
import argparse
from document_processor import DocumentProcessor
from vector_store import VectorStore
from retriever import Retriever
from reranker import Reranker
from generator import Generator

class RAGSystem:
def __init__(self):
self.doc_processor = DocumentProcessor()
self.vector_store = VectorStore()
self.retriever = Retriever()
self.reranker = Reranker()
self.generator = Generator()

def process_documents(self, documents_path: str) -> None:
"""处理并索引文档"""
print(f"开始处理文档: {documents_path}")
# 处理文档
processed_docs, index_name = self.doc_processor.process(documents_path)
print(f"文档处理完成,共处理 {len(processed_docs)} 个文档片段")

# 存储到向量数据库
print(f"正在将文档存入向量数据库(索引:{index_name})...")
self.vector_store.store(processed_docs, index_name)
print("文档存储完成!系统已准备就绪。")

def query(self, query: str) -> str:
"""处理用户查询"""
print("\n正在检索相关文档...")
# 检索相关文档
retrieved_docs, index_name = self.retriever.retrieve(query)

print("正在重排序文档...")
# 重排序
reranked_docs = self.reranker.rerank(query, retrieved_docs, index_name)

print("正在生成回答...\n")
# 生成回答
response = self.generator.generate(query, reranked_docs)
return response

def main():
# 创建命令行参数解析器
parser = argparse.ArgumentParser(description='RAG系统 - 文档问答助手')
parser.add_argument('docs_path', type=str, help='文档目录路径 (支持 Windows 和 Linux 路径)')

# 解析命令行参数
args = parser.parse_args()

# 规范化路径
docs_path = os.path.normpath(args.docs_path)

# 检查路径是否存在
if not os.path.exists(docs_path):
print(f"错误:路径 '{docs_path}' 不存在!")
return

# 初始化RAG系统
rag_system = RAGSystem()

try:
# 处理文档
rag_system.process_documents(docs_path)

# 进入交互式问答循环
print("\n现在您可以开始提问了!(输入 'q' 或 'quit' 退出)")
while True:
query = input("\n请输入您的问题: ").strip()

# 检查是否退出
if query.lower() in ['q', 'quit', 'exit']:
print("感谢使用!再见!")
break

if not query:
print("问题不能为空,请重新输入!")
continue

try:
# 获取回答
response = rag_system.query(query)
print("\n回答:")
print(response)
except Exception as e:
print(f"\n生成回答时出错:{str(e)}")
print("请重试或联系管理员。")

except KeyboardInterrupt:
print("\n\n程序被用户中断。感谢使用!")
except Exception as e:
print(f"\n程序运行出错:{str(e)}")

if __name__ == "__main__":
main()

document_processor.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import List, Dict
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import (
DirectoryLoader,
UnstructuredMarkdownLoader,
PyPDFLoader,
TextLoader
)
import os

class DocumentLoader:
"""通用文档加载器"""
def __init__(self, file_path: str):
self.file_path = file_path
self.extension = os.path.splitext(file_path)[1].lower()

def load(self):
try:
if self.extension == '.md':
loader = UnstructuredMarkdownLoader(self.file_path, encoding='utf-8')
elif self.extension == '.pdf':
loader = PyPDFLoader(self.file_path)
elif self.extension == '.txt':
loader = TextLoader(self.file_path, encoding='utf-8')
else:
raise ValueError(f"不支持的文件格式: {self.extension}")

return loader.load()
except UnicodeDecodeError:
# 如果 utf-8 失败,尝试 gbk
if self.extension in ['.md', '.txt']:
loader = TextLoader(self.file_path, encoding='gbk')
return loader.load()
raise

class DocumentProcessor:
def __init__(self):
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
length_function=len,
)

def get_index_name(self, path: str) -> str:
"""根据文件路径生成索引名称"""
if os.path.isdir(path):
# 如果是目录,使用目录名
return f"rag_{os.path.basename(path).lower()}"
else:
# 如果是文件,使用文件名(不含扩展名)
return f"rag_{os.path.splitext(os.path.basename(path))[0].lower()}"

def process(self, path: str) -> tuple[List[Dict], str]:
"""
加载并处理文档,支持目录或单个文件
返回:(处理后的文档列表, 索引名称)
"""
# 获取索引名称
index_name = self.get_index_name(path)

# 判断是目录还是文件
if os.path.isdir(path):
# 如果是目录,使用 DirectoryLoader
documents = []
for root, _, files in os.walk(path):
for file in files:
file_path = os.path.join(root, file)
try:
loader = DocumentLoader(file_path)
documents.extend(loader.load())
except Exception as e:
print(f"警告:加载文件 {file_path} 时出错: {str(e)}")
continue
else:
# 如果是单个文件
try:
loader = DocumentLoader(path)
documents = loader.load()
except Exception as e:
print(f"加载文件时出错: {str(e)}")
raise

# 分块
chunks = self.text_splitter.split_documents(documents)

# 处理成统一格式
processed_docs = []
for i, chunk in enumerate(chunks):
processed_docs.append({
'id': f'doc_{i}',
'content': chunk.page_content,
'metadata': chunk.metadata
})

return processed_docs, index_name

generator.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from typing import List, Dict
import requests

class Generator:
def __init__(self):
self.api_key = "sk-xxx" # 替换为您的API密钥
self.api_base = "https://api.siliconflow.cn/v1"

def generate(self, query: str, context_docs: List[Dict]) -> str:
"""使用SiliconFlow的chat API生成回答"""
# 构建带有引用标记的上下文
context_with_refs = []
for i, doc in enumerate(context_docs, 1):
# 获取文件名(去掉rag_前缀)
source_name = doc['index_name'][4:] if doc['index_name'].startswith('rag_') else doc['index_name']
context_with_refs.append(f"[{i}] {doc['content']}\n来源:{source_name}")

context = "\n\n".join(context_with_refs)

headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}

system_prompt = """你是一个有帮助的助手。请基于提供的参考内容回答问题。
1. 如果从参考内容中找到答案,请在相关内容后用方括号标注来源编号,例如:[1]、[2]
2. 如果内容来自多个来源,请标注所有相关来源
3. 如果无法从参考内容中得到答案,请明确说明
4. 回答要简洁清晰,避免重复引用
5. 在回答的最后,列出所有引用的来源文件名称"""

response = requests.post(
f"{self.api_base}/chat/completions",
headers=headers,
json={
"model": "deepseek-ai/DeepSeek-V3",
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"""
参考内容:
{context}

问题:{query}

请按照要求回答问题,包括引用标注和来源列表。
"""}
],
"temperature": 0.7,
"max_tokens": 1024
}
)

if response.status_code != 200:
raise Exception(f"Error in generation: {response.text}")

return response.json()["choices"][0]["message"]["content"]

reranker.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import List, Dict
import requests

class Reranker:
def __init__(self):
self.api_key = "sk-xxx" # 替换为您的API密钥
self.api_base = "https://api.siliconflow.cn/v1"

def rerank(self, query: str, documents: List[Dict], index_name: str, top_k: int = 5) -> List[Dict]:
"""使用SiliconFlow的rerank API重排序文档"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}

# 准备文档列表
docs = [doc['content'] for doc in documents]

response = requests.post(
f"{self.api_base}/rerank",
headers=headers,
json={
"model": "BAAI/bge-reranker-v2-m3",
"query": query,
"documents": docs,
"top_n": top_k
}
)

if response.status_code != 200:
raise Exception(f"Error in reranking: {response.text}")

# 处理结果
results = response.json()["results"]
reranked_docs = []

for result in results:
doc_index = result["index"]
original_doc = documents[doc_index].copy()
original_doc['rerank_score'] = result["relevance_score"]
original_doc['index_name'] = index_name
reranked_docs.append(original_doc)

return reranked_docs

retriever.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from typing import List, Dict, Tuple
import requests
from elasticsearch import Elasticsearch
import os

class Retriever:
def __init__(self):
# 使用与 vector_store.py 相同的 ES 配置
self.es = Elasticsearch(
"https://localhost:9200", # 注意是 https
basic_auth=("elastic", "xxx"), # 使用相同的密码
verify_certs=False # 开发环境可以禁用证书验证
)
self.api_key = "sk-xxx"
self.api_base = "https://api.siliconflow.cn/v1"

def get_embedding(self, text: str) -> List[float]:
"""调用SiliconFlow的embedding API获取向量"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}

response = requests.post(
f"{self.api_base}/embeddings",
headers=headers,
json={
"model": "BAAI/bge-m3",
"input": text
}
)

if response.status_code == 200:
return response.json()["data"][0]["embedding"]
else:
raise Exception(f"Error getting embedding: {response.text}")

def get_all_indices(self) -> List[str]:
"""获取所有 RAG 相关的索引"""
indices = self.es.indices.get_alias().keys()
return [idx for idx in indices if idx.startswith('rag_')]

def retrieve(self, query: str, top_k: int = 10) -> Tuple[List[Dict], str]:
"""混合检索:结合 BM25 和向量检索"""
# 获取所有 RAG 索引
indices = self.get_all_indices()
if not indices:
raise Exception("没有找到可用的文档索引!")

# 计算查询向量
query_vector = self.get_embedding(query)

# 在所有索引中搜索
all_results = []
for index in indices:
# 构建混合查询
script_query = {
"script_score": {
"query": {
"match": {
"content": query # BM25
}
},
"script": {
"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
"params": {"query_vector": query_vector}
}
}
}

# 执行检索
response = self.es.search(
index=index,
body={
"query": script_query,
"size": top_k
}
)

# 处理结果
for hit in response['hits']['hits']:
result = {
'id': hit['_id'],
'content': hit['_source']['content'],
'score': hit['_score'],
'metadata': hit['_source']['metadata'],
'index': index
}
all_results.append(result)

# 按分数排序并选择最相关的文档
all_results.sort(key=lambda x: x['score'], reverse=True)
top_results = all_results[:top_k]

# 如果有结果,返回最相关文档所在的索引
if top_results:
most_relevant_index = top_results[0]['index']
else:
most_relevant_index = indices[0] # 如果没有结果,返回第一个索引

return top_results, most_relevant_index

vector_store.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import List, Dict
import requests
import numpy as np
from elasticsearch import Elasticsearch

class VectorStore:
def __init__(self):
# ES 8.x 的连接配置
self.es = Elasticsearch(
"https://localhost:9200", # 注意是 https
basic_auth=("elastic", "xxx"), # 替换为你的密码
verify_certs=False # 开发环境可以禁用证书验证
)
self.api_key = "sk-xxx"
self.api_base = "https://api.siliconflow.cn/v1"

def get_embedding(self, text: str) -> List[float]:
"""调用SiliconFlow的embedding API获取向量"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}

response = requests.post(
f"{self.api_base}/embeddings",
headers=headers,
json={
"model": "BAAI/bge-m3",
"input": text
}
)

if response.status_code == 200:
return response.json()["data"][0]["embedding"]
else:
raise Exception(f"Error getting embedding: {response.text}")

def store(self, documents: List[Dict], index_name: str) -> None:
"""将文档存储到 Elasticsearch"""
# 创建索引
if not self.es.indices.exists(index=index_name):
self.create_index(index_name)

# 批量索引文档
bulk_data = []
for doc in documents:
# 获取文档向量
vector = self.get_embedding(doc['content'])

# 准备索引数据
bulk_data.append({
"index": {
"_index": index_name,
"_id": doc['id']
}
})
bulk_data.append({
"content": doc['content'],
"vector": vector,
"metadata": doc['metadata']
})

# 批量写入
self.es.bulk(operations=bulk_data)

def create_index(self, index_name: str):
"""创建 Elasticsearch 索引"""
settings = {
"mappings": {
"properties": {
"content": {"type": "text"},
"vector": {
"type": "dense_vector",
"dims": 1024 # bge-m3的向量维度
},
"metadata": {"type": "object"}
}
}
}
self.es.indices.create(index=index_name, body=settings)

效果展示


开始升级

  1. 重写运行结构:现在每跑一次都要重复一整个流程,非常不方便,我们将app.py的运行流程改成交互式的,显示当前所有索引,然后询问是直接回答还是添加文件。直接向cursor描述这个需求即可,注意添加文件用count的api获取当前document的数量,然后正确编号添加(不然cursor就变成覆盖了,改了半天也不知道用api去操作)
  2. 多知识库:实际上我们想要的是有多个知识库,每个知识库里面有各种文件,而ES做到这一点我们就只需要将其每个索引当作一个知识库,索引当中的document增加元数据file_name用于区分,这样我们后续操作file_name即可了
  3. 追踪来源:之前显示的来源列表是索引名称,现在索引名称其实是知识库的名称,我们更希望的是这段来源的文件名称,修改也很简单,在生成里面获取该doc的元数据信息再获取文件名即可

扩展

  • 将功能封装好为后端接口:这样就可以做rag的前端了
  • 完善对于文件的删除和查看:删通过ES的api获取所有doc的元数据把file_name对上的doc全删掉;查看可以通过保存上传的文件并给出地址的方式
  • 更好的引用列表:完善生成内容的prompt,将索引到的chunk进行总结加上文件名一起放在最后的引用列表(最好点击文件名可以跳转到对应位置
  • 增加更多的文件格式:现在pdf用的是PyPDFLoader,基于pypdf,只能支持文本类的pdf,markdown也是同样没有对图片进行处理

进阶扩展1:图片理解知识库 [[imgRAG]]

解决问题:

  • 处理图文结合的markdown/pdf
  • 理解图片并将其嵌入知识库
  • 输出的内容图文结合