AI RAG 从0开始构建简单RAG orzCat 2025-03-21 2025-03-21 前言 使用了很多开源的RAG方案,包括Dify,coze,cherry-studio等等,但是真正要了解RAG或者随着需求的升级,不可避免要自己开始搓RAG 现在RAG方向的发展很快,graphrag、lightrag等新方法我们暂且按下不表,先学习使用langchain完成一个标准流程的rag结构
了解RAG流程 做之前先了解RAG的基本流程,由于已经确定了使用langchain去开发,可以通过AI整理信息归纳出整体流程
整体流程
数据准备阶段
文档加载 📌 工具:LangChain Document Loaders
📄 支持格式:PDF/TXT/HTML/Markdown/数据库等
文本分块 🔧 方法:递归分块/固定窗口重叠分块 🛠️ 工具:RecursiveCharacterTextSplitter
向量化处理 🤖 模型:BGE、sentence-transformers
存储阶段 🗄️ 数据库选择 :
轻量级首选 :ChromaDB(内存/持久化模式)
生产级方案 :
Elasticsearch(支持混合搜索) [[Elasticsearch笔记]]
Pinecone(全托管云服务)
检索阶段 🔍 检索器类型 :
基础:VectorStoreRetriever
(纯向量检索)
增强:
混合检索:结合BM25+向量搜索(需Elasticsearch)
重排阶段(Reranking) 🎯 核心价值 :提升结果相关性
开源方案 :
BAAI/bge-reranker-base
cross-encoder/ms-marco-MiniLM-L6 ⚖️ 平衡策略:保留top 3-5个最终结果
生成阶段 💬 选择适当的LLM
[!NOTE] 总结 RAG的整体流程就是 加载文档->向量化->存储->检索->重排->回答 所有模型都使用硅基模型 嵌入模型BAAI/bge-m3,重排序模型BAAI/bge-reranker-v2-m3,对话模型deepseek-ai/DeepSeek-V3 数据库使用 Elasticsearch 混合检索
部署Elasticsearch windowns下部署 按照官网流程 Download Elasticsearch | Elastic
下载
运行bin\elasticsearch.bat
生成一个新密码 用户名:elastic
1 ./elasticsearch-reset-password -u elastic -i
开发 为了完全不会开发的宝宝选手(我自己),我直接使用cursor进行开发!
初步构建
帮我构建一个rag系统,按照文档加载分块,向量化,储存(使用elasticsearch),检索(BM25+向量检索),重排,生成的流程来,不同功能放在不同代码中
这样基本的架构就出来了,接下来有一些细节需要修改完善,具体根据cursor生成的代码不同需要修改的部分也不同,主要是以下几个方面:
1 2 3 4 5 6 7 8 9 10 11 12 13 def __init__ (self ): self .es = Elasticsearch( "https://localhost:9200" , basic_auth=("elastic" , "xxx" ), verify_certs=False )
如果想看到数据库的索引,可以下载ES的相关工具Kibana
确保文档的导入正确,有时候curosr会用上不存在的库,注意让他多改几遍
根据文件名命令索引,初次生成的代码可能用的固定的名称进行索引,对后续的扩展不友好 接下来根据自己的需要修改一下聊天模型的prompt和主代码的启动逻辑,安装好依赖 ,跑起来进行测试即可
代码 这里我是用的命令行的方式将文件的路径导入,python app.py C:/documents/xxx.md
app.py 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 from typing import List , Dict import osimport argparsefrom document_processor import DocumentProcessorfrom vector_store import VectorStorefrom retriever import Retrieverfrom reranker import Rerankerfrom generator import Generatorclass RAGSystem : def __init__ (self ): self .doc_processor = DocumentProcessor() self .vector_store = VectorStore() self .retriever = Retriever() self .reranker = Reranker() self .generator = Generator() def process_documents (self, documents_path: str ) -> None : """处理并索引文档""" print (f"开始处理文档: {documents_path} " ) processed_docs, index_name = self .doc_processor.process(documents_path) print (f"文档处理完成,共处理 {len (processed_docs)} 个文档片段" ) print (f"正在将文档存入向量数据库(索引:{index_name} )..." ) self .vector_store.store(processed_docs, index_name) print ("文档存储完成!系统已准备就绪。" ) def query (self, query: str ) -> str : """处理用户查询""" print ("\n正在检索相关文档..." ) retrieved_docs, index_name = self .retriever.retrieve(query) print ("正在重排序文档..." ) reranked_docs = self .reranker.rerank(query, retrieved_docs, index_name) print ("正在生成回答...\n" ) response = self .generator.generate(query, reranked_docs) return response def main (): parser = argparse.ArgumentParser(description='RAG系统 - 文档问答助手' ) parser.add_argument('docs_path' , type =str , help ='文档目录路径 (支持 Windows 和 Linux 路径)' ) args = parser.parse_args() docs_path = os.path.normpath(args.docs_path) if not os.path.exists(docs_path): print (f"错误:路径 '{docs_path} ' 不存在!" ) return rag_system = RAGSystem() try : rag_system.process_documents(docs_path) print ("\n现在您可以开始提问了!(输入 'q' 或 'quit' 退出)" ) while True : query = input ("\n请输入您的问题: " ).strip() if query.lower() in ['q' , 'quit' , 'exit' ]: print ("感谢使用!再见!" ) break if not query: print ("问题不能为空,请重新输入!" ) continue try : response = rag_system.query(query) print ("\n回答:" ) print (response) except Exception as e: print (f"\n生成回答时出错:{str (e)} " ) print ("请重试或联系管理员。" ) except KeyboardInterrupt: print ("\n\n程序被用户中断。感谢使用!" ) except Exception as e: print (f"\n程序运行出错:{str (e)} " ) if __name__ == "__main__" : main()
document_processor.py 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 from typing import List , Dict from langchain.text_splitter import RecursiveCharacterTextSplitterfrom langchain_community.document_loaders import ( DirectoryLoader, UnstructuredMarkdownLoader, PyPDFLoader, TextLoader ) import osclass DocumentLoader : """通用文档加载器""" def __init__ (self, file_path: str ): self .file_path = file_path self .extension = os.path.splitext(file_path)[1 ].lower() def load (self ): try : if self .extension == '.md' : loader = UnstructuredMarkdownLoader(self .file_path, encoding='utf-8' ) elif self .extension == '.pdf' : loader = PyPDFLoader(self .file_path) elif self .extension == '.txt' : loader = TextLoader(self .file_path, encoding='utf-8' ) else : raise ValueError(f"不支持的文件格式: {self.extension} " ) return loader.load() except UnicodeDecodeError: if self .extension in ['.md' , '.txt' ]: loader = TextLoader(self .file_path, encoding='gbk' ) return loader.load() raise class DocumentProcessor : def __init__ (self ): self .text_splitter = RecursiveCharacterTextSplitter( chunk_size=1000 , chunk_overlap=200 , length_function=len , ) def get_index_name (self, path: str ) -> str : """根据文件路径生成索引名称""" if os.path.isdir(path): return f"rag_{os.path.basename(path).lower()} " else : return f"rag_{os.path.splitext(os.path.basename(path))[0 ].lower()} " def process (self, path: str ) -> tuple [List [Dict ], str ]: """ 加载并处理文档,支持目录或单个文件 返回:(处理后的文档列表, 索引名称) """ index_name = self .get_index_name(path) if os.path.isdir(path): documents = [] for root, _, files in os.walk(path): for file in files: file_path = os.path.join(root, file) try : loader = DocumentLoader(file_path) documents.extend(loader.load()) except Exception as e: print (f"警告:加载文件 {file_path} 时出错: {str (e)} " ) continue else : try : loader = DocumentLoader(path) documents = loader.load() except Exception as e: print (f"加载文件时出错: {str (e)} " ) raise chunks = self .text_splitter.split_documents(documents) processed_docs = [] for i, chunk in enumerate (chunks): processed_docs.append({ 'id' : f'doc_{i} ' , 'content' : chunk.page_content, 'metadata' : chunk.metadata }) return processed_docs, index_name
generator.py 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 from typing import List , Dict import requestsclass Generator : def __init__ (self ): self .api_key = "sk-xxx" self .api_base = "https://api.siliconflow.cn/v1" def generate (self, query: str , context_docs: List [Dict ] ) -> str : """使用SiliconFlow的chat API生成回答""" context_with_refs = [] for i, doc in enumerate (context_docs, 1 ): source_name = doc['index_name' ][4 :] if doc['index_name' ].startswith('rag_' ) else doc['index_name' ] context_with_refs.append(f"[{i} ] {doc['content' ]} \n来源:{source_name} " ) context = "\n\n" .join(context_with_refs) headers = { "Authorization" : f"Bearer {self.api_key} " , "Content-Type" : "application/json" } system_prompt = """你是一个有帮助的助手。请基于提供的参考内容回答问题。 1. 如果从参考内容中找到答案,请在相关内容后用方括号标注来源编号,例如:[1]、[2] 2. 如果内容来自多个来源,请标注所有相关来源 3. 如果无法从参考内容中得到答案,请明确说明 4. 回答要简洁清晰,避免重复引用 5. 在回答的最后,列出所有引用的来源文件名称""" response = requests.post( f"{self.api_base} /chat/completions" , headers=headers, json={ "model" : "deepseek-ai/DeepSeek-V3" , "messages" : [ {"role" : "system" , "content" : system_prompt}, {"role" : "user" , "content" : f""" 参考内容: {context} 问题:{query} 请按照要求回答问题,包括引用标注和来源列表。 """ } ], "temperature" : 0.7 , "max_tokens" : 1024 } ) if response.status_code != 200 : raise Exception(f"Error in generation: {response.text} " ) return response.json()["choices" ][0 ]["message" ]["content" ]
reranker.py 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 from typing import List , Dict import requestsclass Reranker : def __init__ (self ): self .api_key = "sk-xxx" self .api_base = "https://api.siliconflow.cn/v1" def rerank (self, query: str , documents: List [Dict ], index_name: str , top_k: int = 5 ) -> List [Dict ]: """使用SiliconFlow的rerank API重排序文档""" headers = { "Authorization" : f"Bearer {self.api_key} " , "Content-Type" : "application/json" } docs = [doc['content' ] for doc in documents] response = requests.post( f"{self.api_base} /rerank" , headers=headers, json={ "model" : "BAAI/bge-reranker-v2-m3" , "query" : query, "documents" : docs, "top_n" : top_k } ) if response.status_code != 200 : raise Exception(f"Error in reranking: {response.text} " ) results = response.json()["results" ] reranked_docs = [] for result in results: doc_index = result["index" ] original_doc = documents[doc_index].copy() original_doc['rerank_score' ] = result["relevance_score" ] original_doc['index_name' ] = index_name reranked_docs.append(original_doc) return reranked_docs
retriever.py 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 from typing import List , Dict , Tuple import requestsfrom elasticsearch import Elasticsearchimport osclass Retriever : def __init__ (self ): self .es = Elasticsearch( "https://localhost:9200" , basic_auth=("elastic" , "xxx" ), verify_certs=False ) self .api_key = "sk-xxx" self .api_base = "https://api.siliconflow.cn/v1" def get_embedding (self, text: str ) -> List [float ]: """调用SiliconFlow的embedding API获取向量""" headers = { "Authorization" : f"Bearer {self.api_key} " , "Content-Type" : "application/json" } response = requests.post( f"{self.api_base} /embeddings" , headers=headers, json={ "model" : "BAAI/bge-m3" , "input" : text } ) if response.status_code == 200 : return response.json()["data" ][0 ]["embedding" ] else : raise Exception(f"Error getting embedding: {response.text} " ) def get_all_indices (self ) -> List [str ]: """获取所有 RAG 相关的索引""" indices = self .es.indices.get_alias().keys() return [idx for idx in indices if idx.startswith('rag_' )] def retrieve (self, query: str , top_k: int = 10 ) -> Tuple [List [Dict ], str ]: """混合检索:结合 BM25 和向量检索""" indices = self .get_all_indices() if not indices: raise Exception("没有找到可用的文档索引!" ) query_vector = self .get_embedding(query) all_results = [] for index in indices: script_query = { "script_score" : { "query" : { "match" : { "content" : query } }, "script" : { "source" : "cosineSimilarity(params.query_vector, 'vector') + 1.0" , "params" : {"query_vector" : query_vector} } } } response = self .es.search( index=index, body={ "query" : script_query, "size" : top_k } ) for hit in response['hits' ]['hits' ]: result = { 'id' : hit['_id' ], 'content' : hit['_source' ]['content' ], 'score' : hit['_score' ], 'metadata' : hit['_source' ]['metadata' ], 'index' : index } all_results.append(result) all_results.sort(key=lambda x: x['score' ], reverse=True ) top_results = all_results[:top_k] if top_results: most_relevant_index = top_results[0 ]['index' ] else : most_relevant_index = indices[0 ] return top_results, most_relevant_index
vector_store.py 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 from typing import List , Dict import requestsimport numpy as npfrom elasticsearch import Elasticsearchclass VectorStore : def __init__ (self ): self .es = Elasticsearch( "https://localhost:9200" , basic_auth=("elastic" , "xxx" ), verify_certs=False ) self .api_key = "sk-xxx" self .api_base = "https://api.siliconflow.cn/v1" def get_embedding (self, text: str ) -> List [float ]: """调用SiliconFlow的embedding API获取向量""" headers = { "Authorization" : f"Bearer {self.api_key} " , "Content-Type" : "application/json" } response = requests.post( f"{self.api_base} /embeddings" , headers=headers, json={ "model" : "BAAI/bge-m3" , "input" : text } ) if response.status_code == 200 : return response.json()["data" ][0 ]["embedding" ] else : raise Exception(f"Error getting embedding: {response.text} " ) def store (self, documents: List [Dict ], index_name: str ) -> None : """将文档存储到 Elasticsearch""" if not self .es.indices.exists(index=index_name): self .create_index(index_name) bulk_data = [] for doc in documents: vector = self .get_embedding(doc['content' ]) bulk_data.append({ "index" : { "_index" : index_name, "_id" : doc['id' ] } }) bulk_data.append({ "content" : doc['content' ], "vector" : vector, "metadata" : doc['metadata' ] }) self .es.bulk(operations=bulk_data) def create_index (self, index_name: str ): """创建 Elasticsearch 索引""" settings = { "mappings" : { "properties" : { "content" : {"type" : "text" }, "vector" : { "type" : "dense_vector" , "dims" : 1024 }, "metadata" : {"type" : "object" } } } } self .es.indices.create(index=index_name, body=settings)
效果展示
开始升级
重写运行结构 :现在每跑一次都要重复一整个流程,非常不方便,我们将app.py的运行流程改成交互式的,显示当前所有索引,然后询问是直接回答还是添加文件。直接向cursor描述这个需求即可,注意添加文件用count的api获取当前document的数量,然后正确编号添加(不然cursor就变成覆盖了,改了半天也不知道用api去操作)
多知识库 :实际上我们想要的是有多个知识库,每个知识库里面有各种文件,而ES做到这一点我们就只需要将其每个索引当作一个知识库,索引当中的document增加元数据file_name用于区分,这样我们后续操作file_name即可了
追踪来源 :之前显示的来源列表是索引名称,现在索引名称其实是知识库的名称,我们更希望的是这段来源的文件名称,修改也很简单,在生成里面获取该doc的元数据信息再获取文件名即可
扩展
进阶扩展1:图片理解知识库 [[imgRAG]]解决问题:
处理图文结合的markdown/pdf
理解图片并将其嵌入知识库
输出的内容图文结合