# knightutils/local_rag/chat.py
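"""Streaming RAG chat backend for the local knowledge-base demo.

Loads a persisted LlamaIndex vector store, retrieves and reranks relevant
chunks with DashScope, and streams the answer from a DashScope
OpenAI-compatible chat endpoint.
"""
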
import os

from openai import OpenAI
from llama_index.core import StorageContext, load_index_from_storage, Settings
from llama_index.embeddings.dashscope import (
    DashScopeEmbedding,
    DashScopeTextEmbeddingModels,
    DashScopeTextEmbeddingType,
)
from llama_index.postprocessor.dashscope_rerank import DashScopeRerank
from create_kb import create_tmp_kb
DB_PATH = "VectorStore"  # root directory for persisted indexes
TMP_NAME = "tmp_abcd"    # name of the temporary KB built from ad-hoc uploads
EMBED_MODEL = DashScopeEmbedding(
    model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
    text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
)
# To use a local embedding model instead, uncomment the following:
# from langchain_community.embeddings import ModelScopeEmbeddings
# from llama_index.embeddings.langchain import LangchainEmbedding
# embeddings = ModelScopeEmbeddings(model_id="modelscope/iic/nlp_gte_sentence-embedding_chinese-large")
# EMBED_MODEL = LangchainEmbedding(embeddings)
# Set the global embedding model
Settings.embed_model = EMBED_MODEL
def get_model_response(multi_modal_input, history, model, temperature, max_tokens,
                       history_round, db_name, similarity_threshold, chunk_cnt):
    # prompt = multi_modal_input['text']
    prompt = history[-1][0]
    tmp_files = multi_modal_input['files']
    # If a temporary knowledge base already exists, use it; otherwise build one
    # from any files uploaded with this message.
    if os.path.exists(os.path.join("File", TMP_NAME)):
        db_name = TMP_NAME
    elif tmp_files:
        create_tmp_kb(tmp_files)
        db_name = TMP_NAME
    # Load the persisted index
    print(f"prompt:{prompt},tmp_files:{tmp_files},db_name:{db_name}")
    try:
        dashscope_rerank = DashScopeRerank(top_n=chunk_cnt, return_documents=True)
        storage_context = StorageContext.from_defaults(
            persist_dir=os.path.join(DB_PATH, db_name)
        )
        index = load_index_from_storage(storage_context)
        print("index loaded")
        # Over-retrieve, then let the reranker pick the best chunk_cnt chunks
        retriever_engine = index.as_retriever(
            similarity_top_k=20,
        )
        retrieve_chunk = retriever_engine.retrieve(prompt)
        print(f"raw retrieved chunks: {retrieve_chunk}")
        try:
            results = dashscope_rerank.postprocess_nodes(retrieve_chunk, query_str=prompt)
            print(f"rerank succeeded, reranked chunks: {results}")
        except Exception:
            # Fall back to the raw similarity ranking if reranking fails
            results = retrieve_chunk[:chunk_cnt]
            print(f"rerank failed, using top chunks: {results}")
chunk_text = ""
chunk_show = ""
for i in range(len(results)):
if results[i].score >= similarity_threshold:
chunk_text = chunk_text + f"## {i+1}:\n {results[i].text}\n"
chunk_show = chunk_show + f"## {i+1}:\n {results[i].text}\nscore: {round(results[i].score,2)}\n"
print(f"已获取chunk{chunk_text}")
prompt_template = f"请参考以下内容:{chunk_text},以合适的语气回答用户的问题:{prompt}。如果参考内容中有图片链接也请直接返回。"
except Exception as e:
print(f"异常信息:{e}")
prompt_template = prompt
chunk_show = ""
history[-1][-1] = ""
    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    system_message = {'role': 'system', 'content': 'You are a helpful assistant.'}
    # Replay the last history_round turns as conversation context
    messages = []
    history_round = min(len(history), history_round)
    for i in range(history_round):
        messages.append({'role': 'user', 'content': history[-history_round + i][0]})
        messages.append({'role': 'assistant', 'content': history[-history_round + i][1]})
    messages.append({'role': 'user', 'content': prompt_template})
    messages = [system_message] + messages
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stream=True,
    )
    # Stream tokens back, updating the last history entry as they arrive.
    # Guard against chunks with empty choices or a None delta, which
    # OpenAI-compatible streaming endpoints can emit.
    assistant_response = ""
    for chunk in completion:
        if chunk.choices and chunk.choices[0].delta.content:
            assistant_response += chunk.choices[0].delta.content
        history[-1][-1] = assistant_response
        yield history, chunk_show
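

# A minimal usage sketch, assuming a Gradio-style chat history of
# [user_message, assistant_message] pairs and a knowledge base already
# persisted under VectorStore/. The KB name ("my_kb") and the parameter
# values below are illustrative assumptions, not part of this module.
if __name__ == "__main__":
    demo_history = [["What is RAG?", ""]]  # last assistant slot empty; the stream fills it
    stream = get_model_response(
        multi_modal_input={"files": []},   # no ad-hoc uploads for this run
        history=demo_history,
        model="qwen-plus",                 # any DashScope chat model id
        temperature=0.7,
        max_tokens=1024,
        history_round=3,
        db_name="my_kb",                   # hypothetical persisted KB name
        similarity_threshold=0.3,
        chunk_cnt=5,
    )
    for updated_history, retrieved_chunks in stream:
        pass  # each yield carries the partially streamed answer plus the shown chunks
    print(demo_history[-1][1])  # full answer, accumulated in place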