Commit e5362b80e2 (parent f6c7c65d6c), 2025-08-27 22:22:18 +08:00
32 changed files with 914 additions and 0 deletions

2
local_rag/README.md Normal file
View File

@@ -0,0 +1,2 @@
Run: uvicorn main:app --port 7866
Then visit 127.0.0.1:7866
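Assumed setup notes (not part of the original two-line README): chat.py reads the DASHSCOPE_API_KEY environment variable, and the DashScope embedding model used in create_kb.py presumably picks it up as well, so something like the following is needed before the first run:
pip install -r requirements.txt   (assuming the pinned dependency list in this commit is saved under that name)
export DASHSCOPE_API_KEY=sk-...   (placeholder key)
uvicorn main:app --port 7866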

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{"graph_dict": {}}

View File

@@ -0,0 +1 @@
{"embedding_dict": {}, "text_id_to_ref_doc_id": {}, "metadata_dict": {}}

View File

@@ -0,0 +1 @@
{"index_store/data": {"ab9a86f4-0029-48e5-b823-eccfc6f58622": {"__type__": "vector_store", "__data__": "{\"index_id\": \"ab9a86f4-0029-48e5-b823-eccfc6f58622\", \"summary\": null, \"nodes_dict\": {\"c7e14d46-b930-4a98-acaf-2854ac14c1de\": \"c7e14d46-b930-4a98-acaf-2854ac14c1de\", \"fb1f2a38-146e-4840-8f0d-221033dbc849\": \"fb1f2a38-146e-4840-8f0d-221033dbc849\"}, \"doc_id_dict\": {}, \"embeddings_dict\": {}}"}}}

93
local_rag/chat.py Normal file
View File

@@ -0,0 +1,93 @@
import os
from openai import OpenAI
from llama_index.core import StorageContext, load_index_from_storage, Settings
from llama_index.embeddings.dashscope import (
    DashScopeEmbedding,
    DashScopeTextEmbeddingModels,
    DashScopeTextEmbeddingType,
)
from llama_index.postprocessor.dashscope_rerank import DashScopeRerank
from create_kb import *

DB_PATH = "VectorStore"
TMP_NAME = "tmp_abcd"

EMBED_MODEL = DashScopeEmbedding(
    model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
    text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
)
# To use a local embedding model instead, uncomment the following:
# from langchain_community.embeddings import ModelScopeEmbeddings
# from llama_index.embeddings.langchain import LangchainEmbedding
# embeddings = ModelScopeEmbeddings(model_id="modelscope/iic/nlp_gte_sentence-embedding_chinese-large")
# EMBED_MODEL = LangchainEmbedding(embeddings)

# Set the global embedding model
Settings.embed_model = EMBED_MODEL


def get_model_response(multi_modal_input, history, model, temperature, max_tokens, history_round, db_name, similarity_threshold, chunk_cnt):
    # prompt = multi_modal_input['text']
    prompt = history[-1][0]
    tmp_files = multi_modal_input['files']
    if os.path.exists(os.path.join("File", TMP_NAME)):
        db_name = TMP_NAME
    else:
        if tmp_files:
            create_tmp_kb(tmp_files)
            db_name = TMP_NAME
    # Load the index
    print(f"prompt: {prompt}, tmp_files: {tmp_files}, db_name: {db_name}")
    try:
        dashscope_rerank = DashScopeRerank(top_n=chunk_cnt, return_documents=True)
        storage_context = StorageContext.from_defaults(
            persist_dir=os.path.join(DB_PATH, db_name)
        )
        index = load_index_from_storage(storage_context)
        print("Index loaded")
        retriever_engine = index.as_retriever(
            similarity_top_k=20,
        )
        # Retrieve candidate chunks
        retrieve_chunk = retriever_engine.retrieve(prompt)
        print(f"Retrieved chunks: {retrieve_chunk}")
        try:
            results = dashscope_rerank.postprocess_nodes(retrieve_chunk, query_str=prompt)
            print(f"Rerank succeeded, reranked chunks: {results}")
        except Exception as e:
            results = retrieve_chunk[:chunk_cnt]
            print(f"Rerank failed ({e}), falling back to the top retrieved chunks: {results}")
        chunk_text = ""
        chunk_show = ""
        for i in range(len(results)):
            if results[i].score >= similarity_threshold:
                chunk_text = chunk_text + f"## {i+1}:\n {results[i].text}\n"
                chunk_show = chunk_show + f"## {i+1}:\n {results[i].text}\nscore: {round(results[i].score, 2)}\n"
        print(f"Selected chunks: {chunk_text}")
        prompt_template = f"请参考以下内容:{chunk_text},以合适的语气回答用户的问题:{prompt}。如果参考内容中有图片链接也请直接返回。"
    except Exception as e:
        print(f"Exception: {e}")
        prompt_template = prompt
        chunk_show = ""
    history[-1][-1] = ""
    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}
    messages = []
    # The last history entry is the current turn (its assistant slot is still empty)
    history_round = min(len(history), history_round)
    for i in range(history_round):
        messages.append({'role': 'user', 'content': history[-history_round + i][0]})
        messages.append({'role': 'assistant', 'content': history[-history_round + i][1]})
    messages.append({'role': 'user', 'content': prompt_template})
    messages = [system_message] + messages
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True
    )
    assistant_response = ""
    for chunk in completion:
        # The final streamed chunk may carry no content, so guard against None
        if chunk.choices and chunk.choices[0].delta.content:
            assistant_response += chunk.choices[0].delta.content
        history[-1][-1] = assistant_response
        yield history, chunk_show
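For reference, a sketch of driving `get_model_response` outside the Gradio UI: the generator yields the updated chat history together with the retrieved chunks. All concrete values below (knowledge base name, thresholds) are illustrative, and a matching folder under `VectorStore/` plus a DASHSCOPE_API_KEY are assumed to exist.

```python
# Hypothetical manual invocation of the streaming generator in chat.py.
from chat import get_model_response

history = [["什么是RAG?", None]]                  # last turn: user question, empty assistant slot
multi_modal_input = {"text": "", "files": []}     # no ad-hoc file upload

for updated_history, retrieved_chunks in get_model_response(
    multi_modal_input, history,
    model="qwen-max", temperature=0.85, max_tokens=1024,
    history_round=3, db_name="my_kb",             # "my_kb" must already exist under VectorStore/
    similarity_threshold=0.2, chunk_cnt=5,
):
    print(updated_history[-1][1])                 # partial assistant answer streamed so far
```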

134
local_rag/create_kb.py Normal file
View File

@@ -0,0 +1,134 @@
#####################################
######  Create knowledge base  ######
#####################################
import gradio as gr
import os
import shutil
from llama_index.core import VectorStoreIndex, Settings, SimpleDirectoryReader
from llama_index.embeddings.dashscope import (
    DashScopeEmbedding,
    DashScopeTextEmbeddingModels,
    DashScopeTextEmbeddingType,
)
from llama_index.core.schema import TextNode
from upload_file import *

DB_PATH = "VectorStore"
STRUCTURED_FILE_PATH = "File/Structured"
UNSTRUCTURED_FILE_PATH = "File/Unstructured"
TMP_NAME = "tmp_abcd"

EMBED_MODEL = DashScopeEmbedding(
    model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
    text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
)
# To use a local embedding model instead, uncomment the following:
# from langchain_community.embeddings import ModelScopeEmbeddings
# from llama_index.embeddings.langchain import LangchainEmbedding
# embeddings = ModelScopeEmbeddings(model_id="modelscope/iic/nlp_gte_sentence-embedding_chinese-large")
# EMBED_MODEL = LangchainEmbedding(embeddings)

# Set the global embedding model
Settings.embed_model = EMBED_MODEL


# Refresh the list of knowledge bases
def refresh_knowledge_base():
    return os.listdir(DB_PATH)


# Create a vector store from unstructured data
def create_unstructured_db(db_name: str, label_name: list):
    print(f"Knowledge base name: {db_name}, categories: {label_name}")
    if label_name is None:
        gr.Info("没有选择类目")
    elif len(db_name) == 0:
        gr.Info("没有命名知识库")
    # Check whether a knowledge base with the same name already exists
    elif db_name in os.listdir(DB_PATH):
        gr.Info("知识库已存在,请换个名字或删除原来知识库再创建")
    else:
        gr.Info("正在创建知识库请等待知识库创建成功信息显示后前往RAG问答")
        documents = []
        for label in label_name:
            label_path = os.path.join(UNSTRUCTURED_FILE_PATH, label)
            documents.extend(SimpleDirectoryReader(label_path).load_data())
        index = VectorStoreIndex.from_documents(
            documents
        )
        db_path = os.path.join(DB_PATH, db_name)
        if not os.path.exists(db_path):
            os.mkdir(db_path)
            index.storage_context.persist(db_path)
        gr.Info("知识库创建成功可前往RAG问答进行提问")


# Create a vector store from structured data
def create_structured_db(db_name: str, data_table: list):
    print(f"Knowledge base name: {db_name}, data tables: {data_table}")
    if data_table is None:
        gr.Info("没有选择数据表")
    elif len(db_name) == 0:
        gr.Info("没有命名知识库")
    # Check whether a knowledge base with the same name already exists
    elif db_name in os.listdir(DB_PATH):
        gr.Info("知识库已存在,请换个名字或删除原来知识库再创建")
    else:
        gr.Info("正在创建知识库请等待知识库创建成功信息显示后前往RAG问答")
        documents = []
        for label in data_table:
            label_path = os.path.join(STRUCTURED_FILE_PATH, label)
            documents.extend(SimpleDirectoryReader(label_path).load_data())
        # index = VectorStoreIndex.from_documents(
        #     documents
        # )
        # Build one TextNode per line so that each table row becomes its own chunk
        nodes = []
        for doc in documents:
            doc_content = doc.get_content().split('\n')
            for chunk in doc_content:
                node = TextNode(text=chunk)
                node.metadata = {'source': doc.get_doc_id(), 'file_name': doc.metadata['file_name']}
                nodes = nodes + [node]
        index = VectorStoreIndex(nodes)
        db_path = os.path.join(DB_PATH, db_name)
        if not os.path.exists(db_path):
            os.mkdir(db_path)
        index.storage_context.persist(db_path)
        gr.Info("知识库创建成功可前往RAG问答进行提问")


# Delete a knowledge base by name
def delete_db(db_name: str):
    if db_name is not None:
        folder_path = os.path.join(DB_PATH, db_name)
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
            gr.Info(f"已成功删除{db_name}知识库")
            print(f"Deleted knowledge base {db_name}")
        else:
            gr.Info(f"{db_name}知识库不存在")
            print(f"Knowledge base {db_name} does not exist")


# Refresh the knowledge base dropdown
def update_knowledge_base():
    return gr.update(choices=os.listdir(DB_PATH))


# Build a temporary knowledge base from files uploaded in the chat box
def create_tmp_kb(files):
    if not os.path.exists(os.path.join("File", TMP_NAME)):
        os.mkdir(os.path.join("File", TMP_NAME))
    for file in files:
        file_name = os.path.basename(file)
        shutil.move(file, os.path.join("File", TMP_NAME, file_name))
    documents = SimpleDirectoryReader(os.path.join("File", TMP_NAME)).load_data()
    index = VectorStoreIndex.from_documents(
        documents
    )
    db_path = os.path.join(DB_PATH, TMP_NAME)
    if not os.path.exists(db_path):
        os.mkdir(db_path)
    index.storage_context.persist(db_path)


# Remove the temporary files and the temporary knowledge base
def clear_tmp():
    if os.path.exists(os.path.join("File", TMP_NAME)):
        shutil.rmtree(os.path.join("File", TMP_NAME))
    if os.path.exists(os.path.join(DB_PATH, TMP_NAME)):
        shutil.rmtree(os.path.join(DB_PATH, TMP_NAME))
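The structured branch deliberately skips the default text splitter and builds one `TextNode` per line, so every table row produced by upload_file.py becomes its own retrievable chunk. A stand-alone sketch of that chunking step, with made-up sample text and metadata:

```python
# One node per row: the chunking used inside create_structured_db, in isolation.
from llama_index.core.schema import TextNode

doc_text = "name:Alice,age:30\nname:Bob,age:25"   # shape of what upload_structured_file writes
nodes = []
for line in doc_text.split("\n"):
    node = TextNode(text=line)
    node.metadata = {"source": "doc-1", "file_name": "people.txt"}  # illustrative metadata only
    nodes.append(node)

print([n.text for n in nodes])   # ['name:Alice,age:30', 'name:Bob,age:25']
# These nodes are then passed to VectorStoreIndex(nodes) and persisted to disk.
```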

150
local_rag/html_string.py Normal file
View File

@@ -0,0 +1,150 @@
main_html = """<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>阿里云本地RAG解决方案</title>
<link rel="stylesheet" href="https://fonts.googleapis.com/icon?family=Material+Icons">
<style>
body {
font-family: Arial, sans-serif;
background-color: #f5f5f5;
margin: 0;
padding: 0;
display: flex;
flex-direction: column;
align-items: center;
}
header {
background-color: #2196f3;
color: white;
width: 100%;
padding: 1.5em;
text-align: center;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
main {
margin: 2em;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
background-color: white;
border-radius: 8px;
overflow: hidden;
width: 90%;
max-width: 800px;
padding: 2em;
}
h1 {
color: #333;
}
p {
color: #666;
font-size: 1.1em;
}
ul {
list-style-type: none;
padding: 0;
}
ul li {
background-color: #2196f3;
margin: 0.5em 0;
padding: 1em;
border-radius: 4px;
transition: background-color 0.3s;
}
ul li a {
color: white;
text-decoration: none;
display: flex;
align-items: center;
}
ul li:hover {
background-color: #1976d2;
}
.material-icons {
margin-right: 0.5em;
}
</style>
</head>
<body>
<header>
<h1>阿里云本地RAG解决方案</h1>
</header>
<main>
<p>如果您需要基于上传的文档与模型直接对话,请直接访问<a href="/chat">RAG问答</a>,并在输入框位置上传文件,就可以开始对话了。(此次上传的数据在页面刷新后无法保留,若您希望可以持久使用、维护知识库,请创建知识库)。</p>
<p>如果您需要创建或更新知识库,请按照<a href="/upload_data">上传数据</a>、<a href="/create_knowledge_base">创建知识库</a>操作,在<a href="/chat">RAG问答</a>中的“知识库选择”位置选择您需要使用的知识库。</p>
<p>如果您需要基于已创建好的知识库进行问答,请直接访问<a href="/chat">RAG问答</a>,在“加载知识库”处选择您已创建的知识库。</p>
<ul>
<li><a href="/upload_data"><span class="material-icons"></span> 1. 上传数据</a></li>
<li><a href="/create_knowledge_base"><span class="material-icons"></span> 2. 创建知识库</a></li>
<li><a href="/chat"><span class="material-icons"></span> 3. RAG问答</a></li>
</ul>
</main>
</body>
</html>"""
plain_html = """<!DOCTYPE html>
<html lang="zh">
<head>
<title>RAG问答</title>
<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet">
<style>
.links-container {
    display: flex;
    justify-content: center; /* center the items horizontally */
    list-style-type: none; /* drop the default ul list markers */
    padding: 0; /* drop the default ul padding */
    margin: 0; /* drop the default ul margin */
}
.links-container li {
    margin: 0 5px; /* small horizontal gap between items */
    padding: 10px 15px; /* inner padding */
    border: 1px solid #ccc; /* border */
    border-radius: 5px; /* rounded corners */
    background-color: #f9f9f9; /* background color */
    transition: background-color 0.3s; /* background transition */
    display: flex; /* flex layout */
    align-items: center; /* vertical centering */
    height: 50px; /* fixed height for consistency */
}
.links-container li:hover {
    background-color: #e0e0e0; /* background on hover */
}
.links-container a {
    text-decoration: none !important; /* remove link underline */
    color: #333; /* link color */
    font-family: Arial, sans-serif; /* font */
    font-size: 14px; /* font size */
    display: flex; /* flex layout */
    align-items: center; /* vertical centering */
    height: 100%; /* match parent height */
}
.material-icons {
    font-size: 20px; /* icon size */
    margin-right: 8px; /* gap between icon and text */
    text-decoration: none; /* no underline on icons */
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
    .links-container li {
        background-color: #333; /* dark mode background */
        border-color: #555; /* dark mode border */
    }
    .links-container li:hover {
        background-color: #555; /* dark mode hover background */
    }
    .links-container a {
        color: #f9f9f9; /* dark mode text color */
    }
}
</style>
</head>
<body>
<ul class="links-container">
<li><a href="/"><span class="material-icons">home</span> 主页</a></li>
<li><a href="/upload_data"><span class="material-icons">cloud_upload</span> 上传数据</a></li>
<li><a href="/create_knowledge_base"><span class="material-icons">library_add</span> 创建知识库</a></li>
<li><a href="/chat"><span class="material-icons">question_answer</span> RAG问答</a></li>
</ul>
</body>
</html>"""

BIN
local_rag/images/tongyi.png Normal file
Binary file not shown. Size: 121 KiB

BIN
local_rag/images/user.jpeg Normal file
Binary file not shown. Size: 14 KiB

116
local_rag/main.py Normal file
View File

@@ -0,0 +1,116 @@
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
import gradio as gr
import os
from html_string import main_html, plain_html
from upload_file import *
from create_kb import *
from chat import get_model_response


def user(user_message, history):
    print(user_message)
    return {'text': '', 'files': user_message['files']}, history + [[user_message['text'], None]]


#####################################
######        Gradio UI        ######
#####################################
def get_chat_block():
    with gr.Blocks(theme=gr.themes.Base(), css=".gradio_container { background-color: #f0f0f0; }") as chat:
        gr.HTML(plain_html)
        with gr.Row():
            with gr.Column(scale=10):
                chatbot = gr.Chatbot(label="Chatbot", height=750, avatar_images=("images/user.jpeg", "images/tongyi.png"))
                with gr.Row():
                    input_message = gr.MultimodalTextbox(label="请输入", file_types=[".xlsx", ".csv", ".docx", ".pdf", ".txt"], scale=7)
                    clear_btn = gr.ClearButton([chatbot, input_message], scale=1)
            # Model and knowledge base parameters
            with gr.Column(scale=5):
                knowledge_base = gr.Dropdown(choices=os.listdir(DB_PATH), label="加载知识库", interactive=True, scale=2)
                with gr.Accordion(label="召回文本段", open=False):
                    chunk_text = gr.Textbox(label="召回文本段", interactive=False, scale=5, lines=10)
                with gr.Accordion(label="模型设置", open=True):
                    model = gr.Dropdown(choices=['qwen-max', 'qwen-plus', 'qwen-turbo'], label="选择模型", interactive=True, value="qwen-max", scale=2)
                    temperature = gr.Slider(maximum=2, minimum=0, interactive=True, label="温度参数", step=0.01, value=0.85, scale=2)
                    max_tokens = gr.Slider(maximum=8000, minimum=0, interactive=True, label="最大回复长度", step=50, value=1024, scale=2)
                    history_round = gr.Slider(maximum=30, minimum=1, interactive=True, label="携带上下文轮数", step=1, value=3, scale=2)
                with gr.Accordion(label="RAG参数设置", open=True):
                    chunk_cnt = gr.Slider(maximum=20, minimum=1, interactive=True, label="选择召回片段数", step=1, value=5, scale=2)
                    similarity_threshold = gr.Slider(maximum=1, minimum=0, interactive=True, label="相似度阈值", step=0.01, value=0.2, scale=2)
        input_message.submit(fn=user, inputs=[input_message, chatbot], outputs=[input_message, chatbot], queue=False).then(
            fn=get_model_response, inputs=[input_message, chatbot, model, temperature, max_tokens, history_round, knowledge_base, similarity_threshold, chunk_cnt], outputs=[chatbot, chunk_text]
        )
        chat.load(update_knowledge_base, [], knowledge_base)
        chat.load(clear_tmp)
    return chat


def get_upload_block():
    with gr.Blocks(theme=gr.themes.Base()) as upload:
        gr.HTML(plain_html)
        with gr.Tab("非结构化数据"):
            with gr.Accordion(label="新建类目", open=True):
                with gr.Column(scale=2):
                    unstructured_file = gr.Files(file_types=["pdf", "docx", "txt"])
                    with gr.Row():
                        new_label = gr.Textbox(label="类目名称", placeholder="请输入类目名称", scale=5)
                        create_label_btn = gr.Button("新建类目", variant="primary", scale=1)
            with gr.Accordion(label="管理类目", open=False):
                with gr.Row():
                    data_label = gr.Dropdown(choices=os.listdir(UNSTRUCTURED_FILE_PATH), label="管理类目", interactive=True, scale=8, multiselect=True)
                    delete_label_btn = gr.Button("删除类目", variant="stop", scale=1)
        with gr.Tab("结构化数据"):
            with gr.Accordion(label="新建数据表", open=True):
                with gr.Column(scale=2):
                    structured_file = gr.Files(file_types=["xlsx", "csv"])
                    with gr.Row():
                        new_label_1 = gr.Textbox(label="数据表名称", placeholder="请输入数据表名称", scale=5)
                        create_label_btn_1 = gr.Button("新建数据表", variant="primary", scale=1)
            with gr.Accordion(label="管理数据表", open=False):
                with gr.Row():
                    data_label_1 = gr.Dropdown(choices=os.listdir(STRUCTURED_FILE_PATH), label="管理数据表", interactive=True, scale=8, multiselect=True)
                    delete_data_table_btn = gr.Button("删除数据表", variant="stop", scale=1)
        delete_label_btn.click(delete_label, inputs=[data_label]).then(fn=update_label, outputs=[data_label])
        create_label_btn.click(fn=upload_unstructured_file, inputs=[unstructured_file, new_label]).then(fn=update_label, outputs=[data_label])
        delete_data_table_btn.click(delete_data_table, inputs=[data_label_1]).then(fn=update_datatable, outputs=[data_label_1])
        create_label_btn_1.click(fn=upload_structured_file, inputs=[structured_file, new_label_1]).then(fn=update_datatable, outputs=[data_label_1])
        upload.load(update_label, [], data_label)
        upload.load(update_datatable, [], data_label_1)
    return upload


def get_knowledge_base_block():
    with gr.Blocks(theme=gr.themes.Base()) as knowledge:
        gr.HTML(plain_html)
        # Knowledge bases built from unstructured data
        with gr.Tab("非结构化数据"):
            with gr.Row():
                data_label_2 = gr.Dropdown(choices=os.listdir(UNSTRUCTURED_FILE_PATH), label="选择类目", interactive=True, scale=2, multiselect=True)
                knowledge_base_name = gr.Textbox(label="知识库名称", placeholder="请输入知识库名称", scale=2)
                create_knowledge_base_btn = gr.Button("确认创建知识库", variant="primary", scale=1)
        # Knowledge bases built from structured data
        with gr.Tab("结构化数据"):
            with gr.Row():
                data_label_3 = gr.Dropdown(choices=os.listdir(STRUCTURED_FILE_PATH), label="选择数据表", interactive=True, scale=2, multiselect=True)
                knowledge_base_name_1 = gr.Textbox(label="知识库名称", placeholder="请输入知识库名称", scale=2)
                create_knowledge_base_btn_1 = gr.Button("确认创建知识库", variant="primary", scale=1)
        with gr.Row():
            knowledge_base = gr.Dropdown(choices=os.listdir(DB_PATH), label="管理知识库", interactive=True, scale=4)
            delete_db_btn = gr.Button("删除知识库", variant="stop", scale=1)
        create_knowledge_base_btn.click(fn=create_unstructured_db, inputs=[knowledge_base_name, data_label_2]).then(update_knowledge_base, outputs=[knowledge_base])
        delete_db_btn.click(delete_db, inputs=[knowledge_base]).then(update_knowledge_base, outputs=[knowledge_base])
        create_knowledge_base_btn_1.click(fn=create_structured_db, inputs=[knowledge_base_name_1, data_label_3]).then(update_knowledge_base, outputs=[knowledge_base])
        knowledge.load(update_knowledge_base, [], knowledge_base)
        knowledge.load(update_label, [], data_label_2)
        knowledge.load(update_datatable, [], data_label_3)
    return knowledge


app = FastAPI()


@app.get("/", response_class=HTMLResponse)
def read_main():
    html_content = main_html
    return HTMLResponse(content=html_content)


app = gr.mount_gradio_app(app, get_chat_block(), path="/chat")
app = gr.mount_gradio_app(app, get_upload_block(), path="/upload_data")
app = gr.mount_gradio_app(app, get_knowledge_base_block(), path="/create_knowledge_base")
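main.py only builds the FastAPI app; the README starts it with the uvicorn CLI. If a direct `python main.py` entry point is preferred, a guard like the following could be appended to main.py (an assumption, not part of this commit):

```python
# Optional launcher, mirroring the README command `uvicorn main:app --port 7866`.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=7866)
```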

View File

@@ -0,0 +1,26 @@
gradio==4.32.0
faiss-cpu==1.8.0.post1
dashscope==1.20.4
openai==1.55.3
httpx==0.27.0
llama-index-vector-stores-faiss==0.1.2
llama-index-embeddings-dashscope==0.1.4
llama-index-readers-file==0.1.33
matplotlib==3.9.3
docx2txt==0.8
openpyxl==3.1.5
llama-index-core==0.10.67
uvicorn==0.30.6
fastapi==0.112.0
llama-index-postprocessor-dashscope-rerank-custom==0.1.0
simplejson==3.19.3
pydantic==2.10.6
# modelscope==1.18.0
# langchain_community==0.2.16
# transformers==4.44.2
# llama-index-embeddings-huggingface==0.2.3
# llama-index-embeddings-langchain==0.1.2
# datasets==2.21.0
# oss2==2.19.0
# sortedcontainers==2.4.0
# addict==2.4.0

107
local_rag/upload_file.py Normal file
View File

@@ -0,0 +1,107 @@
#####################################
######      Upload files       ######
#####################################
import gradio as gr
import os
import shutil
import pandas as pd

STRUCTURED_FILE_PATH = "File/Structured"
UNSTRUCTURED_FILE_PATH = "File/Unstructured"


# Refresh the unstructured categories
def refresh_label():
    return os.listdir(UNSTRUCTURED_FILE_PATH)


# Refresh the structured data tables
def refresh_data_table():
    return os.listdir(STRUCTURED_FILE_PATH)


# Upload unstructured data
def upload_unstructured_file(files, label_name):
    if files is None:
        gr.Info("请上传文件")
    elif len(label_name) == 0:
        gr.Info("请输入类目名称")
    # Check whether the category already exists
    elif label_name in os.listdir(UNSTRUCTURED_FILE_PATH):
        gr.Info(f"{label_name}类目已存在")
    else:
        try:
            if not os.path.exists(os.path.join(UNSTRUCTURED_FILE_PATH, label_name)):
                os.mkdir(os.path.join(UNSTRUCTURED_FILE_PATH, label_name))
            for file in files:
                print(file)
                file_path = file.name
                file_name = os.path.basename(file_path)
                destination_file_path = os.path.join(UNSTRUCTURED_FILE_PATH, label_name, file_name)
                shutil.move(file_path, destination_file_path)
            gr.Info(f"文件已上传至{label_name}类目中,请前往创建知识库")
        except Exception:
            gr.Info("请勿重复上传")


# Upload structured data
def upload_structured_file(files, label_name):
    if files is None:
        gr.Info("请上传文件")
    elif len(label_name) == 0:
        gr.Info("请输入数据表名称")
    # Check whether the data table already exists
    elif label_name in os.listdir(STRUCTURED_FILE_PATH):
        gr.Info(f"{label_name}数据表已存在")
    else:
        try:
            if not os.path.exists(os.path.join(STRUCTURED_FILE_PATH, label_name)):
                os.mkdir(os.path.join(STRUCTURED_FILE_PATH, label_name))
            for file in files:
                file_path = file.name
                file_name = os.path.basename(file_path)
                destination_file_path = os.path.join(STRUCTURED_FILE_PATH, label_name, file_name)
                shutil.move(file_path, destination_file_path)
                # Convert each spreadsheet row into one "column:value" text line
                if os.path.splitext(destination_file_path)[1] == ".xlsx":
                    df = pd.read_excel(destination_file_path)
                elif os.path.splitext(destination_file_path)[1] == ".csv":
                    df = pd.read_csv(destination_file_path)
                txt_file_name = os.path.splitext(file_name)[0] + '.txt'
                columns = df.columns
                with open(os.path.join(STRUCTURED_FILE_PATH, label_name, txt_file_name), "w") as file:
                    for idx, row in df.iterrows():
                        info = []
                        for col in columns:
                            info.append(f"{col}:{row[col]}")
                        infos = ",".join(info)
                        file.write(infos)
                        if idx != len(df) - 1:
                            file.write("\n")
                os.remove(destination_file_path)
            gr.Info(f"文件已上传至{label_name}数据表中,请前往创建知识库")
        except Exception:
            gr.Info("请勿重复上传")


# Refresh the structured data table dropdown
def update_datatable():
    return gr.update(choices=os.listdir(STRUCTURED_FILE_PATH))


# Refresh the unstructured category dropdown
def update_label():
    return gr.update(choices=os.listdir(UNSTRUCTURED_FILE_PATH))


# Delete categories
def delete_label(label_name):
    if label_name is not None:
        for label in label_name:
            folder_path = os.path.join(UNSTRUCTURED_FILE_PATH, label)
            if os.path.exists(folder_path):
                shutil.rmtree(folder_path)
                gr.Info(f"{label}类目已删除")


# Delete data tables
def delete_data_table(table_name):
    if table_name is not None:
        for table in table_name:
            folder_path = os.path.join(STRUCTURED_FILE_PATH, table)
            if os.path.exists(folder_path):
                shutil.rmtree(folder_path)
                gr.Info(f"{table}数据表已删除")