save all
This commit is contained in:
BIN
local_rag/File/tmp_abcd/百炼系列平板电脑产品介绍.pdf
Normal file
BIN
local_rag/File/tmp_abcd/百炼系列平板电脑产品介绍.pdf
Normal file
Binary file not shown.
2
local_rag/README.md
Normal file
2
local_rag/README.md
Normal file
@@ -0,0 +1,2 @@
|
||||
Run: uvicorn main:app --port 7866
|
||||
Then visit 127.0.0.1:7866
|
||||
File diff suppressed because one or more lines are too long
1
local_rag/VectorStore/tmp_abcd/docstore.json
Normal file
1
local_rag/VectorStore/tmp_abcd/docstore.json
Normal file
File diff suppressed because one or more lines are too long
1
local_rag/VectorStore/tmp_abcd/graph_store.json
Normal file
1
local_rag/VectorStore/tmp_abcd/graph_store.json
Normal file
@@ -0,0 +1 @@
|
||||
{"graph_dict": {}}
|
||||
1
local_rag/VectorStore/tmp_abcd/image__vector_store.json
Normal file
1
local_rag/VectorStore/tmp_abcd/image__vector_store.json
Normal file
@@ -0,0 +1 @@
|
||||
{"embedding_dict": {}, "text_id_to_ref_doc_id": {}, "metadata_dict": {}}
|
||||
1
local_rag/VectorStore/tmp_abcd/index_store.json
Normal file
1
local_rag/VectorStore/tmp_abcd/index_store.json
Normal file
@@ -0,0 +1 @@
|
||||
{"index_store/data": {"ab9a86f4-0029-48e5-b823-eccfc6f58622": {"__type__": "vector_store", "__data__": "{\"index_id\": \"ab9a86f4-0029-48e5-b823-eccfc6f58622\", \"summary\": null, \"nodes_dict\": {\"c7e14d46-b930-4a98-acaf-2854ac14c1de\": \"c7e14d46-b930-4a98-acaf-2854ac14c1de\", \"fb1f2a38-146e-4840-8f0d-221033dbc849\": \"fb1f2a38-146e-4840-8f0d-221033dbc849\"}, \"doc_id_dict\": {}, \"embeddings_dict\": {}}"}}}
|
||||
93
local_rag/chat.py
Normal file
93
local_rag/chat.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from llama_index.core import StorageContext,load_index_from_storage,Settings
|
||||
from llama_index.embeddings.dashscope import (
|
||||
DashScopeEmbedding,
|
||||
DashScopeTextEmbeddingModels,
|
||||
DashScopeTextEmbeddingType,
|
||||
)
|
||||
from llama_index.postprocessor.dashscope_rerank import DashScopeRerank
|
||||
from create_kb import *
|
||||
DB_PATH = "VectorStore"
|
||||
TMP_NAME = "tmp_abcd"
|
||||
EMBED_MODEL = DashScopeEmbedding(
|
||||
model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
|
||||
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
|
||||
)
|
||||
# 若使用本地嵌入模型,请取消以下注释:
|
||||
# from langchain_community.embeddings import ModelScopeEmbeddings
|
||||
# from llama_index.embeddings.langchain import LangchainEmbedding
|
||||
# embeddings = ModelScopeEmbeddings(model_id="modelscope/iic/nlp_gte_sentence-embedding_chinese-large")
|
||||
# EMBED_MODEL = LangchainEmbedding(embeddings)
|
||||
|
||||
# 设置嵌入模型
|
||||
Settings.embed_model = EMBED_MODEL
|
||||
|
||||
def get_model_response(multi_modal_input,history,model,temperature,max_tokens,history_round,db_name,similarity_threshold,chunk_cnt):
|
||||
# prompt = multi_modal_input['text']
|
||||
prompt = history[-1][0]
|
||||
tmp_files = multi_modal_input['files']
|
||||
if os.path.exists(os.path.join("File",TMP_NAME)):
|
||||
db_name = TMP_NAME
|
||||
else:
|
||||
if tmp_files:
|
||||
create_tmp_kb(tmp_files)
|
||||
db_name = TMP_NAME
|
||||
# 获取index
|
||||
print(f"prompt:{prompt},tmp_files:{tmp_files},db_name:{db_name}")
|
||||
try:
|
||||
dashscope_rerank = DashScopeRerank(top_n=chunk_cnt,return_documents=True)
|
||||
storage_context = StorageContext.from_defaults(
|
||||
persist_dir=os.path.join(DB_PATH,db_name)
|
||||
)
|
||||
index = load_index_from_storage(storage_context)
|
||||
print("index获取完成")
|
||||
retriever_engine = index.as_retriever(
|
||||
similarity_top_k=20,
|
||||
)
|
||||
# 获取chunk
|
||||
retrieve_chunk = retriever_engine.retrieve(prompt)
|
||||
print(f"原始chunk为:{retrieve_chunk}")
|
||||
try:
|
||||
results = dashscope_rerank.postprocess_nodes(retrieve_chunk, query_str=prompt)
|
||||
print(f"rerank成功,重排后的chunk为:{results}")
|
||||
except:
|
||||
results = retrieve_chunk[:chunk_cnt]
|
||||
print(f"rerank失败,chunk为:{results}")
|
||||
chunk_text = ""
|
||||
chunk_show = ""
|
||||
for i in range(len(results)):
|
||||
if results[i].score >= similarity_threshold:
|
||||
chunk_text = chunk_text + f"## {i+1}:\n {results[i].text}\n"
|
||||
chunk_show = chunk_show + f"## {i+1}:\n {results[i].text}\nscore: {round(results[i].score,2)}\n"
|
||||
print(f"已获取chunk:{chunk_text}")
|
||||
prompt_template = f"请参考以下内容:{chunk_text},以合适的语气回答用户的问题:{prompt}。如果参考内容中有图片链接也请直接返回。"
|
||||
except Exception as e:
|
||||
print(f"异常信息:{e}")
|
||||
prompt_template = prompt
|
||||
chunk_show = ""
|
||||
history[-1][-1] = ""
|
||||
client = OpenAI(
|
||||
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
)
|
||||
system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}
|
||||
messages = []
|
||||
history_round = min(len(history),history_round)
|
||||
for i in range(history_round):
|
||||
messages.append({'role': 'user', 'content': history[-history_round+i][0]})
|
||||
messages.append({'role': 'assistant', 'content': history[-history_round+i][1]})
|
||||
messages.append({'role': 'user', 'content': prompt_template})
|
||||
messages = [system_message] + messages
|
||||
completion = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True
|
||||
)
|
||||
assistant_response = ""
|
||||
for chunk in completion:
|
||||
assistant_response += chunk.choices[0].delta.content
|
||||
history[-1][-1] = assistant_response
|
||||
yield history,chunk_show
|
||||
134
local_rag/create_kb.py
Normal file
134
local_rag/create_kb.py
Normal file
@@ -0,0 +1,134 @@
|
||||
#####################################
|
||||
###### 创建知识库 #######
|
||||
#####################################
|
||||
import gradio as gr
|
||||
import os
|
||||
import shutil
|
||||
from llama_index.core import VectorStoreIndex,Settings,SimpleDirectoryReader
|
||||
from llama_index.embeddings.dashscope import (
|
||||
DashScopeEmbedding,
|
||||
DashScopeTextEmbeddingModels,
|
||||
DashScopeTextEmbeddingType,
|
||||
)
|
||||
from llama_index.core.schema import TextNode
|
||||
from upload_file import *
|
||||
DB_PATH = "VectorStore"
|
||||
STRUCTURED_FILE_PATH = "File/Structured"
|
||||
UNSTRUCTURED_FILE_PATH = "File/Unstructured"
|
||||
TMP_NAME = "tmp_abcd"
|
||||
EMBED_MODEL = DashScopeEmbedding(
|
||||
model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
|
||||
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
|
||||
)
|
||||
# 若使用本地嵌入模型,请取消以下注释:
|
||||
# from langchain_community.embeddings import ModelScopeEmbeddings
|
||||
# from llama_index.embeddings.langchain import LangchainEmbedding
|
||||
# embeddings = ModelScopeEmbeddings(model_id="modelscope/iic/nlp_gte_sentence-embedding_chinese-large")
|
||||
# EMBED_MODEL = LangchainEmbedding(embeddings)
|
||||
|
||||
|
||||
# 设置嵌入模型
|
||||
Settings.embed_model = EMBED_MODEL
|
||||
# 刷新知识库
|
||||
def refresh_knowledge_base():
|
||||
return os.listdir(DB_PATH)
|
||||
|
||||
# 创建非结构化向量数据库
|
||||
def create_unstructured_db(db_name:str,label_name:list):
|
||||
print(f"知识库名称为:{db_name},类目名称为:{label_name}")
|
||||
if label_name is None:
|
||||
gr.Info("没有选择类目")
|
||||
elif len(db_name) == 0:
|
||||
gr.Info("没有命名知识库")
|
||||
# 判断是否存在同名向量数据库
|
||||
elif db_name in os.listdir(DB_PATH):
|
||||
gr.Info("知识库已存在,请换个名字或删除原来知识库再创建")
|
||||
else:
|
||||
gr.Info("正在创建知识库,请等待知识库创建成功信息显示后前往RAG问答")
|
||||
documents = []
|
||||
for label in label_name:
|
||||
label_path = os.path.join(UNSTRUCTURED_FILE_PATH,label)
|
||||
documents.extend(SimpleDirectoryReader(label_path).load_data())
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents
|
||||
)
|
||||
db_path = os.path.join(DB_PATH,db_name)
|
||||
if not os.path.exists(db_path):
|
||||
os.mkdir(db_path)
|
||||
index.storage_context.persist(db_path)
|
||||
elif os.path.exists(db_path):
|
||||
pass
|
||||
gr.Info("知识库创建成功,可前往RAG问答进行提问")
|
||||
|
||||
# 创建结构化向量数据库
|
||||
def create_structured_db(db_name:str,data_table:list):
|
||||
print(f"知识库名称为:{db_name},数据表名称为:{data_table}")
|
||||
if data_table is None:
|
||||
gr.Info("没有选择数据表")
|
||||
elif len(db_name) == 0:
|
||||
gr.Info("没有命名知识库")
|
||||
# 判断是否存在同名向量数据库
|
||||
elif db_name in os.listdir(DB_PATH):
|
||||
gr.Info("知识库已存在,请换个名字或删除原来知识库再创建")
|
||||
else:
|
||||
gr.Info("正在创建知识库,请等待知识库创建成功信息显示后前往RAG问答")
|
||||
documents = []
|
||||
for label in data_table:
|
||||
label_path = os.path.join(STRUCTURED_FILE_PATH,label)
|
||||
documents.extend(SimpleDirectoryReader(label_path).load_data())
|
||||
# index = VectorStoreIndex.from_documents(
|
||||
# documents
|
||||
# )
|
||||
nodes = []
|
||||
for doc in documents:
|
||||
doc_content = doc.get_content().split('\n')
|
||||
for chunk in doc_content:
|
||||
node = TextNode(text=chunk)
|
||||
node.metadata = {'source': doc.get_doc_id(),'file_name':doc.metadata['file_name']}
|
||||
nodes = nodes + [node]
|
||||
index = VectorStoreIndex(nodes)
|
||||
db_path = os.path.join(DB_PATH,db_name)
|
||||
if not os.path.exists(db_path):
|
||||
os.mkdir(db_path)
|
||||
index.storage_context.persist(db_path)
|
||||
gr.Info("知识库创建成功,可前往RAG问答进行提问")
|
||||
|
||||
|
||||
# 删除指定名称知识库
|
||||
def delete_db(db_name:str):
|
||||
if db_name is not None:
|
||||
folder_path = os.path.join(DB_PATH, db_name)
|
||||
if os.path.exists(folder_path):
|
||||
shutil.rmtree(folder_path)
|
||||
gr.Info(f"已成功删除{db_name}知识库")
|
||||
print(f"已成功删除{db_name}知识库")
|
||||
else:
|
||||
gr.Info(f"{db_name}知识库不存在")
|
||||
print(f"{db_name}知识库不存在")
|
||||
|
||||
# 实时更新知识库列表
|
||||
def update_knowledge_base():
|
||||
return gr.update(choices=os.listdir(DB_PATH))
|
||||
|
||||
# 临时文件创建知识库
|
||||
def create_tmp_kb(files):
|
||||
if not os.path.exists(os.path.join("File",TMP_NAME)):
|
||||
os.mkdir(os.path.join("File",TMP_NAME))
|
||||
for file in files:
|
||||
file_name = os.path.basename(file)
|
||||
shutil.move(file,os.path.join("File",TMP_NAME,file_name))
|
||||
documents = SimpleDirectoryReader(os.path.join("File",TMP_NAME)).load_data()
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents
|
||||
)
|
||||
db_path = os.path.join(DB_PATH,TMP_NAME)
|
||||
if not os.path.exists(db_path):
|
||||
os.mkdir(db_path)
|
||||
index.storage_context.persist(db_path)
|
||||
|
||||
# 清除tmp文件夹下内容
|
||||
def clear_tmp():
|
||||
if os.path.exists(os.path.join("File",TMP_NAME)):
|
||||
shutil.rmtree(os.path.join("File",TMP_NAME))
|
||||
if os.path.exists(os.path.join(DB_PATH,TMP_NAME)):
|
||||
shutil.rmtree(os.path.join(DB_PATH,TMP_NAME))
|
||||
150
local_rag/html_string.py
Normal file
150
local_rag/html_string.py
Normal file
@@ -0,0 +1,150 @@
|
||||
main_html = """<!DOCTYPE html>
|
||||
<html lang="zh">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>阿里云本地RAG解决方案</title>
|
||||
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
background-color: #f5f5f5;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
header {
|
||||
background-color: #2196f3;
|
||||
color: white;
|
||||
width: 100%;
|
||||
padding: 1.5em;
|
||||
text-align: center;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
main {
|
||||
margin: 2em;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
background-color: white;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
width: 90%;
|
||||
max-width: 800px;
|
||||
padding: 2em;
|
||||
}
|
||||
h1 {
|
||||
color: #333;
|
||||
}
|
||||
p {
|
||||
color: #666;
|
||||
font-size: 1.1em;
|
||||
}
|
||||
ul {
|
||||
list-style-type: none;
|
||||
padding: 0;
|
||||
}
|
||||
ul li {
|
||||
background-color: #2196f3;
|
||||
margin: 0.5em 0;
|
||||
padding: 1em;
|
||||
border-radius: 4px;
|
||||
transition: background-color 0.3s;
|
||||
}
|
||||
ul li a {
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
ul li:hover {
|
||||
background-color: #1976d2;
|
||||
}
|
||||
.material-icons {
|
||||
margin-right: 0.5em;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>阿里云本地RAG解决方案</h1>
|
||||
</header>
|
||||
<main>
|
||||
<p>如果您需要基于上传的文档与模型直接对话,请直接访问<a href="/chat">RAG问答</a>,并在输入框位置上传文件,就可以开始对话了。(此次上传的数据在页面刷新后无法保留,若您希望可以持久使用、维护知识库,请创建知识库)。</p>
|
||||
<p>如果您需要创建或更新知识库,请按照<a href="/upload_data">上传数据</a>、<a href="/create_knowledge_base">创建知识库</a>操作,在<a href="/chat">RAG问答</a>中的“知识库选择”位置选择您需要使用的知识库。</p>
|
||||
<p>如果您需要基于已创建好的知识库进行问答,请直接访问<a href="/chat">RAG问答</a>,在“加载知识库”处选择您已创建的知识库。</p>
|
||||
<ul>
|
||||
<li><a href="/upload_data"><span class="material-icons"></span> 1. 上传数据</a></li>
|
||||
<li><a href="/create_knowledge_base"><span class="material-icons"></span> 2. 创建知识库</a></li>
|
||||
<li><a href="/chat"><span class="material-icons"></span> 3. RAG问答</a></li>
|
||||
</ul>
|
||||
</main>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
plain_html = """<!DOCTYPE html>
|
||||
<html lang="zh">
|
||||
<head>
|
||||
<title>RAG问答</title>
|
||||
<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
|
||||
<style>
|
||||
.links-container {
|
||||
display: flex;
|
||||
justify-content: center; /* 在容器中居中分布子元素 */
|
||||
list-style-type: none; /* 去掉ul默认的列表样式 */
|
||||
padding: 0; /* 去掉ul默认的内边距 */
|
||||
margin: 0; /* 去掉ul默认的外边距 */
|
||||
}
|
||||
.links-container li {
|
||||
margin: 0 5px; /* 每个li元素的左右留出一些空间 */
|
||||
padding: 10px 15px; /* 添加内边距 */
|
||||
border: 1px solid #ccc; /* 添加边框 */
|
||||
border-radius: 5px; /* 添加圆角 */
|
||||
background-color: #f9f9f9; /* 背景颜色 */
|
||||
transition: background-color 0.3s; /* 背景颜色变化的过渡效果 */
|
||||
display: flex; /* 使用flex布局 */
|
||||
align-items: center; /* 垂直居中对齐 */
|
||||
height: 50px; /* 设置固定高度,确保一致 */
|
||||
}
|
||||
.links-container li:hover {
|
||||
background-color: #e0e0e0; /* 悬停时的背景颜色 */
|
||||
}
|
||||
.links-container a {
|
||||
text-decoration: none !important; /* 去掉链接的下划线 */
|
||||
color: #333; /* 链接颜色 */
|
||||
font-family: Arial, sans-serif; /* 字体 */
|
||||
font-size: 14px; /* 字体大小 */
|
||||
display: flex; /* 使用flex布局 */
|
||||
align-items: center; /* 垂直居中对齐 */
|
||||
height: 100%; /* 确保链接高度与父元素一致 */
|
||||
}
|
||||
.material-icons {
|
||||
font-size: 20px; /* 图标大小 */
|
||||
margin-right: 8px; /* 图标和文字间的间距 */
|
||||
text-decoration: none; /* 确保图标没有下划线 */
|
||||
}
|
||||
|
||||
/* 深色模式样式 */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
.links-container li {
|
||||
background-color: #333; /* 深色模式下的背景颜色 */
|
||||
border-color: #555; /* 深色模式下的边框颜色 */
|
||||
}
|
||||
.links-container li:hover {
|
||||
background-color: #555; /* 深色模式下悬停时的背景颜色 */
|
||||
}
|
||||
.links-container a {
|
||||
color: #f9f9f9; /* 深色模式下的文字颜色 */
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<ul class="links-container">
|
||||
<li><a href="/"><span class="material-icons">home</span> 主页</a></li>
|
||||
<li><a href="/upload_data"><span class="material-icons">cloud_upload</span> 上传数据</a></li>
|
||||
<li><a href="/create_knowledge_base"><span class="material-icons">library_add</span> 创建知识库</a></li>
|
||||
<li><a href="/chat"><span class="material-icons">question_answer</span> RAG问答</a></li>
|
||||
</ul>
|
||||
</body>
|
||||
</html>"""
|
||||
BIN
local_rag/images/tongyi.png
Normal file
BIN
local_rag/images/tongyi.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 121 KiB |
BIN
local_rag/images/user.jpeg
Normal file
BIN
local_rag/images/user.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
116
local_rag/main.py
Normal file
116
local_rag/main.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import HTMLResponse
|
||||
import gradio as gr
|
||||
import os
|
||||
from html_string import main_html,plain_html
|
||||
from upload_file import *
|
||||
from create_kb import *
|
||||
from chat import get_model_response
|
||||
def user(user_message, history):
|
||||
print(user_message)
|
||||
return {'text': '','files': user_message['files']}, history + [[user_message['text'], None]]
|
||||
|
||||
#####################################
|
||||
###### gradio界面 #######
|
||||
#####################################
|
||||
|
||||
def get_chat_block():
|
||||
with gr.Blocks(theme=gr.themes.Base(),css=".gradio_container { background-color: #f0f0f0; }") as chat:
|
||||
gr.HTML(plain_html)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
chatbot = gr.Chatbot(label="Chatbot",height=750,avatar_images=("images/user.jpeg","images/tongyi.png"))
|
||||
with gr.Row():
|
||||
#
|
||||
input_message = gr.MultimodalTextbox(label="请输入",file_types=[".xlsx",".csv",".docx",".pdf",".txt"],scale=7)
|
||||
clear_btn = gr.ClearButton(chatbot,input_message,scale=1)
|
||||
# 模型与知识库参数
|
||||
with gr.Column(scale=5):
|
||||
knowledge_base =gr.Dropdown(choices=os.listdir(DB_PATH),label="加载知识库",interactive=True,scale=2)
|
||||
with gr.Accordion(label="召回文本段",open=False):
|
||||
chunk_text = gr.Textbox(label="召回文本段",interactive=False,scale=5,lines=10)
|
||||
with gr.Accordion(label="模型设置",open=True):
|
||||
model =gr.Dropdown(choices=['qwen-max','qwen-plus','qwen-turbo'],label="选择模型",interactive=True,value="qwen-max",scale=2)
|
||||
temperature = gr.Slider(maximum=2,minimum=0,interactive=True,label="温度参数",step=0.01,value=0.85,scale=2)
|
||||
max_tokens = gr.Slider(maximum=8000,minimum=0,interactive=True,label="最大回复长度",step=50,value=1024,scale=2)
|
||||
history_round = gr.Slider(maximum=30,minimum=1,interactive=True,label="携带上下文轮数",step=1,value=3,scale=2)
|
||||
with gr.Accordion(label="RAG参数设置",open=True):
|
||||
chunk_cnt = gr.Slider(maximum=20,minimum=1,interactive=True,label="选择召回片段数",step=1,value=5,scale=2)
|
||||
similarity_threshold = gr.Slider(maximum=1,minimum=0,interactive=True,label="相似度阈值",step=0.01,value=0.2,scale=2)
|
||||
input_message.submit(fn=user,inputs=[input_message,chatbot],outputs=[input_message,chatbot],queue=False).then(
|
||||
fn=get_model_response,inputs=[input_message,chatbot,model,temperature,max_tokens,history_round,knowledge_base,similarity_threshold,chunk_cnt],outputs=[chatbot,chunk_text]
|
||||
)
|
||||
chat.load(update_knowledge_base,[],knowledge_base)
|
||||
chat.load(clear_tmp)
|
||||
return chat
|
||||
|
||||
|
||||
def get_upload_block():
|
||||
with gr.Blocks(theme=gr.themes.Base()) as upload:
|
||||
gr.HTML(plain_html)
|
||||
with gr.Tab("非结构化数据"):
|
||||
with gr.Accordion(label="新建类目",open=True):
|
||||
with gr.Column(scale=2):
|
||||
unstructured_file = gr.Files(file_types=["pdf","docx","txt"])
|
||||
with gr.Row():
|
||||
new_label = gr.Textbox(label="类目名称",placeholder="请输入类目名称",scale=5)
|
||||
create_label_btn = gr.Button("新建类目",variant="primary",scale=1)
|
||||
with gr.Accordion(label="管理类目",open=False):
|
||||
with gr.Row():
|
||||
data_label =gr.Dropdown(choices=os.listdir(UNSTRUCTURED_FILE_PATH),label="管理类目",interactive=True,scale=8,multiselect=True)
|
||||
delete_label_btn = gr.Button("删除类目",variant="stop",scale=1)
|
||||
with gr.Tab("结构化数据"):
|
||||
with gr.Accordion(label="新建数据表",open=True):
|
||||
with gr.Column(scale=2):
|
||||
structured_file = gr.Files(file_types=["xlsx","csv"])
|
||||
with gr.Row():
|
||||
new_label_1 = gr.Textbox(label="数据表名称",placeholder="请输入数据表名称",scale=5)
|
||||
create_label_btn_1 = gr.Button("新建数据表",variant="primary",scale=1)
|
||||
with gr.Accordion(label="管理数据表",open=False):
|
||||
with gr.Row():
|
||||
data_label_1 =gr.Dropdown(choices=os.listdir(STRUCTURED_FILE_PATH),label="管理数据表",interactive=True,scale=8,multiselect=True)
|
||||
delete_data_table_btn = gr.Button("删除数据表",variant="stop",scale=1)
|
||||
delete_label_btn.click(delete_label,inputs=[data_label]).then(fn=update_label,outputs=[data_label])
|
||||
create_label_btn.click(fn=upload_unstructured_file,inputs=[unstructured_file,new_label]).then(fn=update_label,outputs=[data_label])
|
||||
delete_data_table_btn.click(delete_data_table,inputs=[data_label_1]).then(fn=update_datatable,outputs=[data_label_1])
|
||||
create_label_btn_1.click(fn=upload_structured_file,inputs=[structured_file,new_label_1]).then(fn=update_datatable,outputs=[data_label_1])
|
||||
upload.load(update_label,[],data_label)
|
||||
upload.load(update_datatable,[],data_label_1)
|
||||
return upload
|
||||
|
||||
def get_knowledge_base_block():
|
||||
with gr.Blocks(theme=gr.themes.Base()) as knowledge:
|
||||
gr.HTML(plain_html)
|
||||
# 非结构化数据知识库
|
||||
with gr.Tab("非结构化数据"):
|
||||
with gr.Row():
|
||||
data_label_2 =gr.Dropdown(choices=os.listdir(UNSTRUCTURED_FILE_PATH),label="选择类目",interactive=True,scale=2,multiselect=True)
|
||||
knowledge_base_name = gr.Textbox(label="知识库名称",placeholder="请输入知识库名称",scale=2)
|
||||
create_knowledge_base_btn = gr.Button("确认创建知识库",variant="primary",scale=1)
|
||||
# 结构化数据知识库
|
||||
with gr.Tab("结构化数据"):
|
||||
with gr.Row():
|
||||
data_label_3 =gr.Dropdown(choices=os.listdir(STRUCTURED_FILE_PATH),label="选择数据表",interactive=True,scale=2,multiselect=True)
|
||||
knowledge_base_name_1 = gr.Textbox(label="知识库名称",placeholder="请输入知识库名称",scale=2)
|
||||
create_knowledge_base_btn_1 = gr.Button("确认创建知识库",variant="primary",scale=1)
|
||||
with gr.Row():
|
||||
knowledge_base =gr.Dropdown(choices=os.listdir(DB_PATH),label="管理知识库",interactive=True,scale=4)
|
||||
delete_db_btn = gr.Button("删除知识库",variant="stop",scale=1)
|
||||
create_knowledge_base_btn.click(fn=create_unstructured_db,inputs=[knowledge_base_name,data_label_2]).then(update_knowledge_base,outputs=[knowledge_base])
|
||||
delete_db_btn.click(delete_db,inputs=[knowledge_base]).then(update_knowledge_base,outputs=[knowledge_base])
|
||||
create_knowledge_base_btn_1.click(fn=create_structured_db,inputs=[knowledge_base_name_1,data_label_3]).then(update_knowledge_base,outputs=[knowledge_base])
|
||||
knowledge.load(update_knowledge_base,[],knowledge_base)
|
||||
knowledge.load(update_label,[],data_label_2)
|
||||
knowledge.load(update_datatable,[],data_label_3)
|
||||
return knowledge
|
||||
|
||||
app = FastAPI()
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
def read_main():
|
||||
html_content = main_html
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
|
||||
app = gr.mount_gradio_app(app, get_chat_block(), path="/chat")
|
||||
app = gr.mount_gradio_app(app, get_upload_block(), path="/upload_data")
|
||||
app = gr.mount_gradio_app(app, get_knowledge_base_block(), path="/create_knowledge_base")
|
||||
26
local_rag/requirements.txt
Normal file
26
local_rag/requirements.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
gradio==4.32.0
|
||||
faiss-cpu==1.8.0.post1
|
||||
dashscope==1.20.4
|
||||
openai==1.55.3
|
||||
httpx==0.27.0
|
||||
llama-index-vector-stores-faiss==0.1.2
|
||||
llama-index-embeddings-dashscope==0.1.4
|
||||
llama-index-readers-file==0.1.33
|
||||
matplotlib==3.9.3
|
||||
docx2txt==0.8
|
||||
openpyxl==3.1.5
|
||||
llama-index-core==0.10.67
|
||||
uvicorn==0.30.6
|
||||
fastapi==0.112.0
|
||||
llama-index-postprocessor-dashscope-rerank-custom==0.1.0
|
||||
simplejson==3.19.3
|
||||
pydantic==2.10.6
|
||||
# modelscope==1.18.0
|
||||
# langchain_community==0.2.16
|
||||
# transformers==4.44.2
|
||||
# llama_index.embeddings.huggingface==0.2.3
|
||||
# llama-index-embeddings-langchain==0.1.2
|
||||
# datasets==2.21.0
|
||||
# oss2==2.19.0
|
||||
# sortedcontainers==2.4.0
|
||||
# addict==2.4.0s
|
||||
107
local_rag/upload_file.py
Normal file
107
local_rag/upload_file.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#####################################
|
||||
####### 上传文件 #######
|
||||
#####################################
|
||||
import gradio as gr
|
||||
import os
|
||||
import shutil
|
||||
import pandas as pd
|
||||
STRUCTURED_FILE_PATH = "File/Structured"
|
||||
UNSTRUCTURED_FILE_PATH = "File/Unstructured"
|
||||
# 刷新非结构化类目
|
||||
def refresh_label():
|
||||
return os.listdir(UNSTRUCTURED_FILE_PATH)
|
||||
|
||||
# 刷新结构化数据表
|
||||
def refresh_data_table():
|
||||
return os.listdir(STRUCTURED_FILE_PATH)
|
||||
|
||||
# 上传非结构化数据
|
||||
def upload_unstructured_file(files,label_name):
|
||||
if files is None:
|
||||
gr.Info("请上传文件")
|
||||
elif len(label_name) == 0:
|
||||
gr.Info("请输入类目名称")
|
||||
# 判断类目是否存在
|
||||
elif label_name in os.listdir(UNSTRUCTURED_FILE_PATH):
|
||||
gr.Info(f"{label_name}类目已存在")
|
||||
else:
|
||||
try:
|
||||
if not os.path.exists(os.path.join(UNSTRUCTURED_FILE_PATH,label_name)):
|
||||
os.mkdir(os.path.join(UNSTRUCTURED_FILE_PATH,label_name))
|
||||
for file in files:
|
||||
print(file)
|
||||
file_path = file.name
|
||||
file_name = os.path.basename(file_path)
|
||||
destination_file_path = os.path.join(UNSTRUCTURED_FILE_PATH,label_name,file_name)
|
||||
shutil.move(file_path,destination_file_path)
|
||||
gr.Info(f"文件已上传至{label_name}类目中,请前往创建知识库")
|
||||
except:
|
||||
gr.Info(f"请勿重复上传")
|
||||
|
||||
# 上传结构化数据
|
||||
def upload_structured_file(files,label_name):
|
||||
if files is None:
|
||||
gr.Info("请上传文件")
|
||||
elif len(label_name) == 0:
|
||||
gr.Info("请输入数据表名称")
|
||||
# 判断数据表是否存在
|
||||
elif label_name in os.listdir(STRUCTURED_FILE_PATH):
|
||||
gr.Info(f"{label_name}数据表已存在")
|
||||
else:
|
||||
try:
|
||||
if not os.path.exists(os.path.join(STRUCTURED_FILE_PATH,label_name)):
|
||||
os.mkdir(os.path.join(STRUCTURED_FILE_PATH,label_name))
|
||||
for file in files:
|
||||
file_path = file.name
|
||||
file_name = os.path.basename(file_path)
|
||||
destination_file_path = os.path.join(STRUCTURED_FILE_PATH,label_name,file_name)
|
||||
shutil.move(file_path,destination_file_path)
|
||||
if os.path.splitext(destination_file_path)[1] == ".xlsx":
|
||||
df = pd.read_excel(destination_file_path)
|
||||
elif os.path.splitext(destination_file_path)[1] == ".csv":
|
||||
df = pd.read_csv(destination_file_path)
|
||||
txt_file_name = os.path.splitext(file_name)[0]+'.txt'
|
||||
columns = df.columns
|
||||
with open(os.path.join(STRUCTURED_FILE_PATH,label_name,txt_file_name),"w") as file:
|
||||
for idx,row in df.iterrows():
|
||||
file.write("【")
|
||||
info = []
|
||||
for col in columns:
|
||||
info.append(f"{col}:{row[col]}")
|
||||
infos = ",".join(info)
|
||||
file.write(infos)
|
||||
if idx != len(df)-1:
|
||||
file.write("】\n")
|
||||
else:
|
||||
file.write("】")
|
||||
os.remove(destination_file_path)
|
||||
gr.Info(f"文件已上传至{label_name}数据表中,请前往创建知识库")
|
||||
except:
|
||||
gr.Info(f"请勿重复上传")
|
||||
|
||||
# 实时更新结构化数据表
|
||||
def update_datatable():
|
||||
return gr.update(choices=os.listdir(STRUCTURED_FILE_PATH))
|
||||
|
||||
|
||||
# 实时更新非结构化类目
|
||||
def update_label():
|
||||
return gr.update(choices=os.listdir(UNSTRUCTURED_FILE_PATH))
|
||||
|
||||
# 删除类目
|
||||
def delete_label(label_name):
|
||||
if label_name is not None:
|
||||
for label in label_name:
|
||||
folder_path = os.path.join(UNSTRUCTURED_FILE_PATH,label)
|
||||
if os.path.exists(folder_path):
|
||||
shutil.rmtree(folder_path)
|
||||
gr.Info(f"{label}类目已删除")
|
||||
|
||||
# 删除数据表
|
||||
def delete_data_table(table_name):
|
||||
if table_name is not None:
|
||||
for table in table_name:
|
||||
folder_path = os.path.join(STRUCTURED_FILE_PATH,table)
|
||||
if os.path.exists(folder_path):
|
||||
shutil.rmtree(folder_path)
|
||||
gr.Info(f"{table}数据表已删除")
|
||||
Reference in New Issue
Block a user