diff --git a/api/.env.example b/api/.env.example
index 35aaabbc10..516a119d98 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -654,3 +654,9 @@ TENANT_ISOLATED_TASK_CONCURRENCY=1
 
 # Maximum number of segments for dataset segments API (0 for unlimited)
 DATASET_MAX_SEGMENTS_PER_REQUEST=0
+
+# Multimodal knowledge base limits
+SINGLE_CHUNK_ATTACHMENT_LIMIT=10
+ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
+ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
+IMAGE_FILE_BATCH_LIMIT=10
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index b5ffd09d01..a5916241df 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -360,6 +360,26 @@ class FileUploadConfig(BaseSettings):
         default=10,
     )
 
+    IMAGE_FILE_BATCH_LIMIT: PositiveInt = Field(
+        description="Maximum number of files allowed in an image batch upload operation",
+        default=10,
+    )
+
+    SINGLE_CHUNK_ATTACHMENT_LIMIT: PositiveInt = Field(
+        description="Maximum number of attachment files allowed in a single chunk",
+        default=10,
+    )
+
+    ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: NonNegativeInt = Field(
+        description="Maximum allowed image file size for attachments in megabytes",
+        default=2,
+    )
+
+    ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: NonNegativeInt = Field(
+        description="Timeout for downloading image attachments in seconds",
+        default=60,
+    )
+
     inner_UPLOAD_FILE_EXTENSION_BLACKLIST: str = Field(
         description=(
             "Comma-separated list of file extensions that are blocked from upload. "
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 1fad8abd52..c0422ef6f4 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -151,6 +151,7 @@ class DatasetUpdatePayload(BaseModel):
     external_knowledge_id: str | None = None
     external_knowledge_api_id: str | None = None
     icon_info: dict[str, Any] | None = None
+    is_multimodal: bool | None = False
 
     @field_validator("indexing_technique")
     @classmethod
@@ -423,17 +424,16 @@ class DatasetApi(Resource):
         payload = DatasetUpdatePayload.model_validate(console_ns.payload or {})
         payload_data = payload.model_dump(exclude_unset=True)
         current_user, current_tenant_id = current_account_with_tenant()
-        # check embedding model setting
         if (
             payload.indexing_technique == "high_quality"
             and payload.embedding_model_provider is not None
             and payload.embedding_model is not None
         ):
-            DatasetService.check_embedding_model_setting(
+            is_multimodal = DatasetService.check_is_multimodal_model(
                 dataset.tenant_id, payload.embedding_model_provider, payload.embedding_model
             )
-
+            payload.is_multimodal = is_multimodal
         # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
         DatasetPermissionService.check_permission(
             current_user, dataset, payload.permission, payload.partial_member_list
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 2520111281..6145da31a5 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -424,6 +424,10 @@ class DatasetInitApi(Resource):
                     model_type=ModelType.TEXT_EMBEDDING,
                     model=knowledge_config.embedding_model,
                 )
+                is_multimodal = DatasetService.check_is_multimodal_model(
+                    current_tenant_id, knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+                )
+                knowledge_config.is_multimodal = is_multimodal
             except InvokeAuthorizationError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index ee390cbfb7..e73abc2555 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -51,6 +51,7 @@ class SegmentCreatePayload(BaseModel):
     content: str
     answer: str | None = None
     keywords: list[str] | None = None
+    attachment_ids: list[str] | None = None
 
 
 class SegmentUpdatePayload(BaseModel):
@@ -58,6 +59,7 @@ class SegmentUpdatePayload(BaseModel):
     answer: str | None = None
     keywords: list[str] | None = None
     regenerate_child_chunks: bool = False
+    attachment_ids: list[str] | None = None
 
 
 class BatchImportPayload(BaseModel):
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index fac90a0135..db7c50f422 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Any
 
-from flask_restx import marshal
+from flask_restx import marshal, reqparse
 from pydantic import BaseModel, Field
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
@@ -33,6 +33,7 @@ class HitTestingPayload(BaseModel):
     query: str = Field(max_length=250)
     retrieval_model: dict[str, Any] | None = None
    external_retrieval_model: dict[str, Any] | None = None
+    attachment_ids: list[str] | None = None
 
 
 class DatasetsHitTestingBase:
@@ -54,16 +55,28 @@ class DatasetsHitTestingBase:
     def hit_testing_args_check(args: dict[str, Any]):
         HitTestingService.hit_testing_args_check(args)
 
+    @staticmethod
+    def parse_args():
+        parser = (
+            reqparse.RequestParser()
+            .add_argument("query", type=str, required=False, location="json")
+            .add_argument("attachment_ids", type=list, required=False, location="json")
+            .add_argument("retrieval_model", type=dict, required=False, location="json")
+            .add_argument("external_retrieval_model", type=dict, required=False, location="json")
+        )
+        return parser.parse_args()
+
     @staticmethod
     def perform_hit_testing(dataset, args):
         assert isinstance(current_user, Account)
         try:
             response = HitTestingService.retrieve(
                 dataset=dataset,
-                query=args["query"],
+                query=args.get("query"),
                 account=current_user,
-                retrieval_model=args["retrieval_model"],
-                external_retrieval_model=args["external_retrieval_model"],
+                retrieval_model=args.get("retrieval_model"),
+                external_retrieval_model=args.get("external_retrieval_model"),
+                attachment_ids=args.get("attachment_ids"),
                 limit=10,
             )
             return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index fdd7c2f479..29417dc896 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -45,6 +45,9 @@ class FileApi(Resource):
             "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
             "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
             "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
+            "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
+            "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
+            "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
         }, 200
 
     @setup_required
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index 9a9832dd4a..e2e6c11480 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -83,6 +83,7 @@ class AppRunner:
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
+        context_files: list["File"] | None = None,
     ) -> tuple[list[PromptMessage], list[str] | None]:
         """
         Organize prompt messages
@@ -111,6 +112,7 @@ class AppRunner:
                 memory=memory,
                 model_config=model_config,
                 image_detail_config=image_detail_config,
+                context_files=context_files,
             )
         else:
             memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 53188cf506..f8338b226b 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.entities.queue_entities import QueueAnnotationReplyEvent
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file import File
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent
@@ -146,6 +147,7 @@ class ChatAppRunner(AppRunner):
 
         # get context from datasets
         context = None
+        context_files: list[File] = []
         if app_config.dataset and app_config.dataset.dataset_ids:
             hit_callback = DatasetIndexToolCallbackHandler(
                 queue_manager,
@@ -156,7 +158,7 @@ class ChatAppRunner(AppRunner):
             )
 
             dataset_retrieval = DatasetRetrieval(application_generate_entity)
-            context = dataset_retrieval.retrieve(
+            context, retrieved_files = dataset_retrieval.retrieve(
                 app_id=app_record.id,
                 user_id=application_generate_entity.user_id,
                 tenant_id=app_record.tenant_id,
@@ -171,7 +173,11 @@ class ChatAppRunner(AppRunner):
                 memory=memory,
                 message_id=message.id,
                 inputs=inputs,
+                vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
+                    "enabled", False
+                ),
             )
+            context_files = retrieved_files or []
 
         # reorganize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
@@ -186,6 +192,7 @@ class ChatAppRunner(AppRunner):
             context=context,
             memory=memory,
             image_detail_config=image_detail_config,
+            context_files=context_files,
         )
 
         # check hosting moderation
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index e2be4146e1..ddfb5725b4 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
     CompletionAppGenerateEntity,
 )
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
+from core.file import File
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.message_entities import ImagePromptMessageContent
 from core.moderation.base import ModerationError
@@ -102,6 +103,7 @@ class CompletionAppRunner(AppRunner):
 
         # get context from datasets
         context = None
+        context_files: list[File] = []
         if app_config.dataset and app_config.dataset.dataset_ids:
             hit_callback = DatasetIndexToolCallbackHandler(
                 queue_manager,
@@ -116,7 +118,7 @@ class CompletionAppRunner(AppRunner):
                 query = inputs.get(dataset_config.retrieve_config.query_variable, "")
 
             dataset_retrieval = DatasetRetrieval(application_generate_entity)
-            context = dataset_retrieval.retrieve(
+            context, retrieved_files = dataset_retrieval.retrieve(
                 app_id=app_record.id,
                 user_id=application_generate_entity.user_id,
                 tenant_id=app_record.tenant_id,
@@ -130,7 +132,11 @@ class CompletionAppRunner(AppRunner):
                 hit_callback=hit_callback,
                 message_id=message.id,
                 inputs=inputs,
+                vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get(
+                    "enabled", False
+                ),
             )
+            context_files = retrieved_files or []
 
         # reorganize all inputs and template to prompt messages
         # Include: prompt template, inputs, query(optional), files(optional)
@@ -144,6 +150,7 @@ class CompletionAppRunner(AppRunner):
             query=query,
             context=context,
             image_detail_config=image_detail_config,
+            context_files=context_files,
         )
 
         # check hosting moderation
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index 14d5f38dcd..d0279349ca 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -7,7 +7,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
@@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
                     document_id,
                 )
                 continue
-            if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+            if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                 child_chunk_stmt = select(ChildChunk).where(
                     ChildChunk.index_node_id == document.metadata["doc_id"],
                     ChildChunk.dataset_id == dataset_document.dataset_id,
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 36b38b7b45..59de4f403d 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -7,7 +7,7 @@ import time
 import uuid
 from typing import Any
 
-from flask import current_app
+from flask import Flask, current_app
 from sqlalchemy import select
 from sqlalchemy.orm.exc import ObjectDeletedError
 
@@ -21,7 +21,7 @@ from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.docstore.dataset_docstore import DatasetDocumentStore
 from core.rag.extractor.entity.datasource_type import DatasourceType
 from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.index_type import IndexStructureType
 from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.rag.models.document import ChildDocument, Document
@@ -36,6 +36,7 @@ from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from libs import helper
 from libs.datetime_utils import naive_utc_now
+from models import Account
 from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import UploadFile
@@ -89,8 +90,17 @@ class IndexingRunner:
                 text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
 
                 # transform
+                current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
+                if not current_user:
+                    raise ValueError("no current user found")
+                current_user.set_tenant_id(dataset.tenant_id)
                 documents = self._transform(
-                    index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
+                    index_processor,
+                    dataset,
+                    text_docs,
+                    requeried_document.doc_language,
+                    processing_rule.to_dict(),
+                    current_user=current_user,
                 )
                 # save segment
                 self._load_segments(dataset, requeried_document, documents)
@@ -136,7 +146,7 @@
                 for document_segment in document_segments:
                     db.session.delete(document_segment)
-                    if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                    if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                         # delete child chunks
                         db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
                 db.session.commit()
@@ -152,8 +162,17 @@
                 text_docs = self._extract(index_processor, requeried_document, processing_rule.to_dict())
 
                 # transform
+                current_user = db.session.query(Account).filter_by(id=requeried_document.created_by).first()
+                if not current_user:
+                    raise ValueError("no current user found")
+                current_user.set_tenant_id(dataset.tenant_id)
                 documents = self._transform(
-                    index_processor, dataset, text_docs, requeried_document.doc_language, processing_rule.to_dict()
+                    index_processor,
+                    dataset,
+                    text_docs,
+                    requeried_document.doc_language,
+                    processing_rule.to_dict(),
+                    current_user=current_user,
                 )
                 # save segment
                 self._load_segments(dataset, requeried_document, documents)
@@ -209,7 +228,7 @@
                         "dataset_id": document_segment.dataset_id,
                     },
                 )
-                if requeried_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+                if requeried_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
                     child_chunks = document_segment.get_child_chunks()
                     if child_chunks:
                         child_documents = []
@@ -302,6 +321,7 @@
             text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
             documents = index_processor.transform(
                 text_docs,
+                current_user=None,
                 embedding_model_instance=embedding_model_instance,
                 process_rule=processing_rule.to_dict(),
                 tenant_id=tenant_id,
@@ -551,7 +571,10 @@
         indexing_start_at = time.perf_counter()
         tokens = 0
         create_keyword_thread = None
-        if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
+        if (
+            dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
+            and dataset.indexing_technique == "economy"
+        ):
             # create keyword index
             create_keyword_thread = threading.Thread(
                 target=self._process_keyword_index,
@@ -590,7 +613,7 @@
                 for future in futures:
                     tokens += future.result()
         if (
-            dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX
+            dataset_document.doc_form != IndexStructureType.PARENT_CHILD_INDEX
             and dataset.indexing_technique == "economy"
             and create_keyword_thread is not None
         ):
@@ -635,7 +658,13 @@
         db.session.commit()
 
     def _process_chunk(
-        self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
+        self,
+        flask_app: Flask,
+        index_processor: BaseIndexProcessor,
+        chunk_documents: list[Document],
+        dataset: Dataset,
+        dataset_document: DatasetDocument,
+        embedding_model_instance: ModelInstance | None,
     ):
         with flask_app.app_context():
             # check document is paused
@@ -646,8 +675,15 @@
                 page_content_list = [document.page_content for document in chunk_documents]
                 tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
 
+            multimodal_documents = []
+            for document in chunk_documents:
+                if document.attachments and dataset.is_multimodal:
+                    multimodal_documents.extend(document.attachments)
+
             # load index
-            index_processor.load(dataset, chunk_documents, with_keywords=False)
+            index_processor.load(
+                dataset, chunk_documents, multimodal_documents=multimodal_documents, with_keywords=False
+            )
 
             document_ids = [document.metadata["doc_id"] for document in chunk_documents]
             db.session.query(DocumentSegment).where(
@@ -710,6 +746,7 @@
         text_docs: list[Document],
         doc_language: str,
         process_rule: dict,
+        current_user: Account | None = None,
     ) -> list[Document]:
         # get embedding model instance
         embedding_model_instance = None
@@ -729,6 +766,7 @@
 
         documents = index_processor.transform(
             text_docs,
+            current_user,
             embedding_model_instance=embedding_model_instance,
             process_rule=process_rule,
             tenant_id=dataset.tenant_id,
@@ -737,14 +775,16 @@
 
         return documents
 
-    def _load_segments(self, dataset, dataset_document, documents):
+    def _load_segments(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]):
         # save node to document segment
         doc_store = DatasetDocumentStore(
             dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
         )
 
         # add document segments
-        doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
+        doc_store.add_documents(
+            docs=documents, save_child=dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
+        )
 
         # update document status to indexing
         cur_time = naive_utc_now()
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index a63e94d59c..5a28bbcc3a 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
-from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
@@ -200,7 +200,7 @@ class ModelInstance:
 
     def invoke_text_embedding(
         self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
-    ) -> TextEmbeddingResult:
+    ) -> EmbeddingResult:
         """
         Invoke large language model
 
@@ -212,7 +212,7 @@
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
         return cast(
-            TextEmbeddingResult,
+            EmbeddingResult,
             self._round_robin_invoke(
                 function=self.model_type_instance.invoke,
                 model=self.model,
@@ -223,6 +223,34 @@
             ),
         )
 
+    def invoke_multimodal_embedding(
+        self,
+        multimodel_documents: list[dict],
+        user: str | None = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+    ) -> EmbeddingResult:
+        """
+        Invoke multimodal embedding model
+
+        :param multimodel_documents: multimodal documents to embed
+        :param user: unique user id
+        :param input_type: input type
+        :return: embeddings result
+        """
+        if not isinstance(self.model_type_instance, TextEmbeddingModel):
+            raise Exception("Model type instance is not TextEmbeddingModel")
+        return cast(
+            EmbeddingResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                multimodel_documents=multimodel_documents,
+                user=user,
+                input_type=input_type,
+            ),
+        )
+
     def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
         """
         Get number of tokens for text embedding
@@ -276,6 +304,40 @@
             ),
         )
 
+    def invoke_multimodal_rerank(
+        self,
+        query: dict,
+        docs: list[dict],
+        score_threshold: float | None = None,
+        top_n: int | None = None,
+        user: str | None = None,
+    ) -> RerankResult:
+        """
+        Invoke multimodal rerank model
+
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if not isinstance(self.model_type_instance, RerankModel):
+            raise Exception("Model type instance is not RerankModel")
+        return cast(
+            RerankResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke_multimodal_rerank,
+                model=self.model,
+                credentials=self.credentials,
+                query=query,
+                docs=docs,
+                score_threshold=score_threshold,
+                top_n=top_n,
+                user=user,
+            ),
+        )
+
     def invoke_moderation(self, text: str, user: str | None = None) -> bool:
         """
         Invoke moderation model
@@ -461,6 +523,32 @@ class ModelManager:
             model=default_model_entity.model,
         )
 
+    def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool:
+        """
+        Check if model supports vision
+
+        :param tenant_id: tenant id
+        :param provider: provider name
+        :param model: model name
+        :param model_type: model type
+        :return: True if model supports vision, False otherwise
+        """
+        model_instance = self.get_model_instance(tenant_id, provider, model_type, model)
+        model_type_instance = model_instance.model_type_instance
+        match model_type:
+            case ModelType.LLM:
+                model_type_instance = cast(LargeLanguageModel, model_type_instance)
+            case ModelType.TEXT_EMBEDDING:
+                model_type_instance = cast(TextEmbeddingModel, model_type_instance)
+            case ModelType.RERANK:
+                model_type_instance = cast(RerankModel, model_type_instance)
+            case _:
+                raise ValueError(f"Model type {model_type} is not supported")
+        model_schema = model_type_instance.get_model_schema(model, model_instance.credentials)
+        if not model_schema:
+            return False
+        if model_schema.features and ModelFeature.VISION in model_schema.features:
+            return True
+        return False
+
 
 class LBModelManager:
     def __init__(
diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py
index 846b89d658..854c448250 100644
--- a/api/core/model_runtime/entities/text_embedding_entities.py
+++ b/api/core/model_runtime/entities/text_embedding_entities.py
@@ -19,7 +19,7 @@ class EmbeddingUsage(ModelUsage):
     latency: float
 
 
-class TextEmbeddingResult(BaseModel):
+class EmbeddingResult(BaseModel):
     """
     Model class for text embedding result.
     """
@@ -27,3 +27,13 @@ class TextEmbeddingResult(BaseModel):
     model: str
     embeddings: list[list[float]]
     usage: EmbeddingUsage
+
+
+class FileEmbeddingResult(BaseModel):
+    """
+    Model class for file embedding result.
+    """
+
+    model: str
+    embeddings: list[list[float]]
+    usage: EmbeddingUsage
diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py
index 36067118b0..0a576b832a 100644
--- a/api/core/model_runtime/model_providers/__base/rerank_model.py
+++ b/api/core/model_runtime/model_providers/__base/rerank_model.py
@@ -50,3 +50,43 @@ class RerankModel(AIModel):
             )
         except Exception as e:
             raise self._transform_invoke_error(e)
+
+    def invoke_multimodal_rerank(
+        self,
+        model: str,
+        credentials: dict,
+        query: dict,
+        docs: list[dict],
+        score_threshold: float | None = None,
+        top_n: int | None = None,
+        user: str | None = None,
+    ) -> RerankResult:
+        """
+        Invoke multimodal rerank model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        try:
+            from core.plugin.impl.model import PluginModelClient
+
+            plugin_model_manager = PluginModelClient()
+            return plugin_model_manager.invoke_multimodal_rerank(
+                tenant_id=self.tenant_id,
+                user_id=user or "unknown",
+                plugin_id=self.plugin_id,
+                provider=self.provider_name,
+                model=model,
+                credentials=credentials,
+                query=query,
+                docs=docs,
+                score_threshold=score_threshold,
+                top_n=top_n,
+            )
+        except Exception as e:
+            raise self._transform_invoke_error(e)
diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
index bd68ffe903..4c902e2c11 100644
--- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py
+++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
@@ -2,7 +2,7 @@ from pydantic import ConfigDict
 
 from core.entities.embedding_type import EmbeddingInputType
 from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
-from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
 from core.model_runtime.model_providers.__base.ai_model import AIModel
 
 
@@ -20,16 +20,18 @@ class TextEmbeddingModel(AIModel):
         self,
         model: str,
         credentials: dict,
-        texts: list[str],
+        texts: list[str] | None = None,
+        multimodel_documents: list[dict] | None = None,
         user: str | None = None,
         input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
-    ) -> TextEmbeddingResult:
+    ) -> EmbeddingResult:
         """
         Invoke text embedding model
 
         :param model: model name
         :param credentials: model credentials
         :param texts: texts to embed
+        :param multimodel_documents: multimodal documents to embed
         :param user: unique user id
         :param input_type: input type
         :return: embeddings result
@@ -38,16 +40,29 @@
         try:
             plugin_model_manager = PluginModelClient()
 
-            return plugin_model_manager.invoke_text_embedding(
-                tenant_id=self.tenant_id,
-                user_id=user or "unknown",
-                plugin_id=self.plugin_id,
-                provider=self.provider_name,
-                model=model,
-                credentials=credentials,
-                texts=texts,
-                input_type=input_type,
-            )
+            if texts:
+                return plugin_model_manager.invoke_text_embedding(
+                    tenant_id=self.tenant_id,
+                    user_id=user or "unknown",
+                    plugin_id=self.plugin_id,
+                    provider=self.provider_name,
+                    model=model,
+                    credentials=credentials,
+                    texts=texts,
+                    input_type=input_type,
+                )
+            if multimodel_documents:
+                return plugin_model_manager.invoke_multimodal_embedding(
+                    tenant_id=self.tenant_id,
+                    user_id=user or "unknown",
+                    plugin_id=self.plugin_id,
+                    provider=self.provider_name,
+                    model=model,
+                    credentials=credentials,
+                    documents=multimodel_documents,
+                    input_type=input_type,
+                )
+            raise ValueError("No texts or multimodal documents provided")
         except Exception as e:
             raise self._transform_invoke_error(e)
diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py
index 5dfc3c212e..5d70980967 100644
--- a/api/core/plugin/impl/model.py
+++ b/api/core/plugin/impl/model.py
@@ -6,7 +6,7 @@ from core.model_runtime.entities.llm_entities import LLMResultChunk
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.entities.rerank_entities import RerankResult
-from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.plugin.entities.plugin_daemon import (
     PluginBasicBooleanResponse,
@@ -243,14 +243,14 @@ class PluginModelClient(BasePluginClient):
         credentials: dict,
         texts: list[str],
         input_type: str,
-    ) -> TextEmbeddingResult:
+    ) -> EmbeddingResult:
         """
         Invoke text embedding
         """
         response = self._request_with_plugin_daemon_response_stream(
             method="POST",
             path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
-            type_=TextEmbeddingResult,
+            type_=EmbeddingResult,
             data=jsonable_encoder(
                 {
                     "user_id": user_id,
@@ -275,6 +275,48 @@ class PluginModelClient(BasePluginClient):
 
         raise ValueError("Failed to invoke text embedding")
 
+    def invoke_multimodal_embedding(
+        self,
+        tenant_id: str,
+        user_id: str,
+        plugin_id: str,
+        provider: str,
+        model: str,
+        credentials: dict,
+        documents: list[dict],
+        input_type: str,
+    ) -> EmbeddingResult:
+        """
+        Invoke multimodal embedding
+        """
+        response = self._request_with_plugin_daemon_response_stream(
+            method="POST",
+            path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke",
+            type_=EmbeddingResult,
+            data=jsonable_encoder(
+                {
+                    "user_id": user_id,
+                    "data": {
+                        "provider": provider,
+                        "model_type": "text-embedding",
+                        "model": model,
+                        "credentials": credentials,
+                        "documents": documents,
+                        "input_type": input_type,
+                    },
+                }
+            ),
+            headers={
+                "X-Plugin-ID": plugin_id,
+                "Content-Type": "application/json",
+            },
+        )
+
+        for resp in response:
+            return resp
+
+        raise ValueError("Failed to invoke multimodal embedding")
+
     def get_text_embedding_num_tokens(
         self,
         tenant_id: str,
@@ -361,6 +403,51 @@ class PluginModelClient(BasePluginClient):
 
         raise ValueError("Failed to invoke rerank")
 
+    def invoke_multimodal_rerank(
+        self,
+        tenant_id: str,
+        user_id: str,
+        plugin_id: str,
+        provider: str,
+        model: str,
+        credentials: dict,
+        query: dict,
+        docs: list[dict],
+        score_threshold: float | None = None,
+        top_n: int | None = None,
+    ) -> RerankResult:
+        """
+        Invoke multimodal rerank
+        """
+        response = self._request_with_plugin_daemon_response_stream(
+            method="POST",
+            path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke",
+            type_=RerankResult,
+            data=jsonable_encoder(
+                {
+                    "user_id": user_id,
+                    "data": {
+                        "provider": provider,
+                        "model_type": "rerank",
+                        "model": model,
+                        "credentials": credentials,
+                        "query": query,
+                        "docs": docs,
+                        "score_threshold": score_threshold,
+                        "top_n": top_n,
+                    },
+                }
+            ),
+            headers={
+                "X-Plugin-ID": plugin_id,
+                "Content-Type": "application/json",
+            },
+        )
+        for resp in response:
+            return resp
+
+        raise ValueError("Failed to invoke multimodal rerank")
+
     def invoke_tts(
         self,
         tenant_id: str,
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index d1d518a55d..f072092ea7 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -49,6 +49,7 @@ class SimplePromptTransform(PromptTransform):
         memory: TokenBufferMemory | None,
         model_config: ModelConfigWithCredentialsEntity,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
+        context_files: list["File"] | None = None,
     ) -> tuple[list[PromptMessage], list[str] | None]:
         inputs = {key: str(value) for key, value in inputs.items()}
 
@@ -64,6 +65,7 @@
                 memory=memory,
                 model_config=model_config,
                 image_detail_config=image_detail_config,
+                context_files=context_files,
             )
         else:
             prompt_messages, stops = self._get_completion_model_prompt_messages(
@@ -76,6 +78,7 @@
                 memory=memory,
                 model_config=model_config,
                 image_detail_config=image_detail_config,
+                context_files=context_files,
             )
 
         return prompt_messages, stops
@@ -187,6 +190,7 @@
         memory: TokenBufferMemory | None,
         model_config: ModelConfigWithCredentialsEntity,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
+        context_files: list["File"] | None = None,
     ) -> tuple[list[PromptMessage], list[str] | None]:
         prompt_messages: list[PromptMessage] = []
 
@@ -216,9 +220,9 @@
         )
 
         if query:
-            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config, context_files))
         else:
-            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))
+            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config, context_files))
 
         return prompt_messages, None
 
@@ -233,6 +237,7 @@
         memory: TokenBufferMemory | None,
         model_config: ModelConfigWithCredentialsEntity,
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
+        context_files: list["File"] | None = None,
    ) -> tuple[list[PromptMessage], list[str] | None]:
         # get prompt
         prompt, prompt_rules = self._get_prompt_str_and_rules(
@@ -275,20 +280,27 @@
         if stops is not None and len(stops) == 0:
             stops = None
 
-        return [self._get_last_user_message(prompt, files, image_detail_config)], stops
+        return [self._get_last_user_message(prompt, files, image_detail_config, context_files)], stops
 
     def _get_last_user_message(
         self,
         prompt: str,
         files: Sequence["File"],
         image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
+        context_files: list["File"] | None = None,
     ) -> UserPromptMessage:
+        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
         if files:
-            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
             for file in files:
                 prompt_message_contents.append(
                     file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                 )
+        if context_files:
+            for file in context_files:
+                prompt_message_contents.append(
+                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
+                )
+
+        if prompt_message_contents:
             prompt_message_contents.append(TextPromptMessageContent(data=prompt))
             prompt_message = UserPromptMessage(content=prompt_message_contents)
diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py
index cc946a72c3..bfa8781e9f 100644
--- a/api/core/rag/data_post_processor/data_post_processor.py
+++ b/api/core/rag/data_post_processor/data_post_processor.py
@@ -2,6 +2,7 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from core.rag.data_post_processor.reorder import ReorderRunner
+from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights
 from core.rag.rerank.rerank_base import BaseRerankRunner
@@ -30,9 +31,10 @@ class DataPostProcessor:
         score_threshold: float | None = None,
         top_n: int | None = None,
         user: str | None = None,
+        query_type: QueryType = QueryType.TEXT_QUERY,
     ) -> list[Document]:
         if self.rerank_runner:
-            documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
+            documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type)
         if self.reorder_runner:
             documents = self.reorder_runner.run(documents)
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 2290de19bc..cbd7cbeb64 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -1,23 +1,30 @@
 import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
+from typing import Any
 
 from flask import Flask, current_app
 from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only
 
 from configs import dify_config
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
 from core.rag.data_post_processor.data_post_processor import DataPostProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.embedding.retrieval import RetrievalSegments
 from core.rag.entities.metadata_entities import MetadataCondition
-from core.rag.index_processor.constant.index_type import IndexType
+from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.index_processor.constant.index_type import IndexStructureType
+from core.rag.index_processor.constant.query_type import QueryType
 from core.rag.models.document import Document
 from core.rag.rerank.rerank_type import RerankMode
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
+from core.tools.signature import sign_upload_file
 from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
 from models.dataset import Document as DatasetDocument
+from models.model import UploadFile
 from services.external_knowledge_service import ExternalDatasetService
 
 default_retrieval_model = {
@@ -37,14 +44,15 @@ class RetrievalService:
         retrieval_method: RetrievalMethod,
         dataset_id: str,
         query: str,
-        top_k: int,
+        top_k: int = 4,
         score_threshold: float | None = 0.0,
         reranking_model: dict | None = None,
         reranking_mode: str = "reranking_model",
         weights: dict | None = None,
         document_ids_filter: list[str] | None = None,
+        attachment_ids: list | None = None,
     ):
-        if not query:
+        if not query and not attachment_ids:
             return []
         dataset = cls._get_dataset(dataset_id)
         if not dataset:
@@ -56,69 +64,52 @@ class RetrievalService:
         # Optimize multithreading with thread pools
         with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor:  # type: ignore
             futures = []
-            if retrieval_method == RetrievalMethod.KEYWORD_SEARCH:
+            retrieval_service = RetrievalService()
+            if query:
                 futures.append(
                     executor.submit(
-                        cls.keyword_search,
+                        retrieval_service._retrieve,
                         flask_app=current_app._get_current_object(),  # type: ignore
-                        dataset_id=dataset_id,
-                        query=query,
-                        top_k=top_k,
-                        all_documents=all_documents,
-                        exceptions=exceptions,
-                        document_ids_filter=document_ids_filter,
-                    )
-                )
-            if RetrievalMethod.is_support_semantic_search(retrieval_method):
-                futures.append(
-                    executor.submit(
-                        cls.embedding_search,
-                        flask_app=current_app._get_current_object(),  # type: ignore
-                        dataset_id=dataset_id,
+                        retrieval_method=retrieval_method,
+                        dataset=dataset,
                         query=query,
                         top_k=top_k,
                         score_threshold=score_threshold,
                         reranking_model=reranking_model,
-                        all_documents=all_documents,
-                        retrieval_method=retrieval_method,
-                        exceptions=exceptions,
+                        reranking_mode=reranking_mode,
+                        weights=weights,
                         document_ids_filter=document_ids_filter,
+                        attachment_id=None,
+                        all_documents=all_documents,
+                        exceptions=exceptions,
                     )
                 )
-            if RetrievalMethod.is_support_fulltext_search(retrieval_method):
-                futures.append(
-                    executor.submit(
-                        cls.full_text_index_search,
-                        flask_app=current_app._get_current_object(),  # type: ignore
-                        dataset_id=dataset_id,
-                        query=query,
-                        top_k=top_k,
-                        score_threshold=score_threshold,
-                        reranking_model=reranking_model,
-                        all_documents=all_documents,
-                        retrieval_method=retrieval_method,
-                        exceptions=exceptions,
-                        document_ids_filter=document_ids_filter,
+            if attachment_ids:
+                for attachment_id in attachment_ids:
+                    futures.append(
+                        executor.submit(
+                            retrieval_service._retrieve,
+                            flask_app=current_app._get_current_object(),  # type: ignore
+                            retrieval_method=retrieval_method,
+                            dataset=dataset,
+                            query=None,
+                            top_k=top_k,
+                            score_threshold=score_threshold,
+                            reranking_model=reranking_model,
+                            reranking_mode=reranking_mode,
+                            weights=weights,
+                            document_ids_filter=document_ids_filter,
+                            attachment_id=attachment_id,
+                            all_documents=all_documents,
+                            exceptions=exceptions,
+                        )
                     )
-                )
-            concurrent.futures.wait(futures, timeout=30, return_when=concurrent.futures.ALL_COMPLETED)
+
+            concurrent.futures.wait(futures, timeout=3600, return_when=concurrent.futures.ALL_COMPLETED)
 
         if exceptions:
             raise ValueError(";\n".join(exceptions))
 
-        # Deduplicate documents for hybrid search to avoid duplicate chunks
-        if retrieval_method == RetrievalMethod.HYBRID_SEARCH:
-            all_documents = cls._deduplicate_documents(all_documents)
-            data_post_processor = DataPostProcessor(
-                str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
-            )
-            all_documents = data_post_processor.invoke(
-                query=query,
-                documents=all_documents,
-                score_threshold=score_threshold,
-                top_n=top_k,
-            )
-
         return all_documents
 
     @classmethod
@@ -223,6 +214,7 @@ class RetrievalService:
         retrieval_method: RetrievalMethod,
         exceptions: list,
         document_ids_filter: list[str] | None = None,
+        query_type: QueryType = QueryType.TEXT_QUERY,
): with flask_app.app_context(): try: @@ -231,14 +223,30 @@ class RetrievalService: raise ValueError("dataset not found") vector = Vector(dataset=dataset) - documents = vector.search_by_vector( - query, - search_type="similarity_score_threshold", - top_k=top_k, - score_threshold=score_threshold, - filter={"group_id": [dataset.id]}, - document_ids_filter=document_ids_filter, - ) + documents = [] + if query_type == QueryType.TEXT_QUERY: + documents.extend( + vector.search_by_vector( + query, + search_type="similarity_score_threshold", + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) + if query_type == QueryType.IMAGE_QUERY: + if not dataset.is_multimodal: + return + documents.extend( + vector.search_by_file( + file_id=query, + top_k=top_k, + score_threshold=score_threshold, + filter={"group_id": [dataset.id]}, + document_ids_filter=document_ids_filter, + ) + ) if documents: if ( @@ -250,14 +258,37 @@ class RetrievalService: data_post_processor = DataPostProcessor( str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) - all_documents.extend( - data_post_processor.invoke( - query=query, - documents=documents, - score_threshold=score_threshold, - top_n=len(documents), + if dataset.is_multimodal: + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=dataset.tenant_id, + provider=reranking_model.get("reranking_provider_name") or "", + model=reranking_model.get("reranking_model_name") or "", + model_type=ModelType.RERANK, + ) + if is_support_vision: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) + ) + else: + # not effective, return original documents + all_documents.extend(documents) + else: + all_documents.extend( + data_post_processor.invoke( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), + query_type=query_type, + ) ) - ) else: all_documents.extend(documents) except Exception as e: @@ -339,103 +370,159 @@ class RetrievalService: records = [] include_segment_ids = set() segment_child_map = {} - - # Process documents - for document in documents: - document_id = document.metadata.get("document_id") - if document_id not in dataset_documents: - continue - - dataset_document = dataset_documents[document_id] - if not dataset_document: - continue - - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - # Handle parent-child documents - child_index_node_id = document.metadata.get("doc_id") - child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) - child_chunk = db.session.scalar(child_chunk_stmt) - - if not child_chunk: + segment_file_map = {} + with Session(db.engine) as session: + # Process documents + for document in documents: + segment_id = None + attachment_info = None + child_chunk = None + document_id = document.metadata.get("document_id") + if document_id not in dataset_documents: continue - segment = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.id == child_chunk.segment_id, - ) - .options( - load_only( - DocumentSegment.id, - DocumentSegment.content, - DocumentSegment.answer, + dataset_document = dataset_documents[document_id] + if not 
dataset_document: + continue + + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + # Handle parent-child documents + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attchment_info"] + segment_id = attachment_info_dict["segment_id"] + else: + child_index_node_id = document.metadata.get("doc_id") + child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id) + child_chunk = session.scalar(child_chunk_stmt) + + if not child_chunk: + continue + segment_id = child_chunk.segment_id + + if not segment_id: + continue + + segment = ( + session.query(DocumentSegment) + .where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + .options( + load_only( + DocumentSegment.id, + DocumentSegment.content, + DocumentSegment.answer, + ) + ) + .first() ) - .first() - ) - if not segment: - continue + if not segment: + continue - if segment.id not in include_segment_ids: - include_segment_ids.add(segment.id) - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - map_detail = { - "max_score": document.metadata.get("score", 0.0), - "child_chunks": [child_chunk_detail], - } - segment_child_map[segment.id] = map_detail - record = { - "segment": segment, - } - records.append(record) + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if child_chunk: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + if attachment_info: + segment_file_map[segment.id].append(attachment_info) else: - child_chunk_detail = { - "id": child_chunk.id, - "content": child_chunk.content, - "position": child_chunk.position, - "score": document.metadata.get("score", 0.0), - } - segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) - segment_child_map[segment.id]["max_score"] = max( - segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) - ) - else: - # Handle normal documents - index_node_id = document.metadata.get("doc_id") - if not index_node_id: - continue - document_segment_stmt = select(DocumentSegment).where( - DocumentSegment.dataset_id == dataset_document.dataset_id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - segment = 
db.session.scalar(document_segment_stmt) + # Handle normal documents + segment = None + if document.metadata.get("doc_type") == DocType.IMAGE: + attachment_info_dict = cls.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + attachment_info = attachment_info_dict["attchment_info"] + segment_id = attachment_info_dict["segment_id"] + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.id == segment_id, + ) + segment = db.session.scalar(document_segment_stmt) + if segment: + segment_file_map[segment.id] = [attachment_info] + else: + index_node_id = document.metadata.get("doc_id") + if not index_node_id: + continue + document_segment_stmt = select(DocumentSegment).where( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + segment = db.session.scalar(document_segment_stmt) - if not segment: - continue - - include_segment_ids.add(segment.id) - record = { - "segment": segment, - "score": document.metadata.get("score"), # type: ignore - } - records.append(record) + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.add(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score"), # type: ignore + } + if attachment_info: + segment_file_map[segment.id] = [attachment_info] + records.append(record) + else: + if attachment_info: + attachment_infos = segment_file_map.get(segment.id, []) + if attachment_info not in attachment_infos: + attachment_infos.append(attachment_info) + segment_file_map[segment.id] = attachment_infos # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] + if record["segment"].id in segment_file_map: + record["files"] = segment_file_map[record["segment"].id] # type: ignore[assignment] result = [] for record in records: @@ -447,6 +534,11 @@ class RetrievalService: if not isinstance(child_chunks, list): child_chunks = None + # Extract files, ensuring it's a list or None + files = record.get("files") + if not isinstance(files, list): + files = None + # Extract score, ensuring it's a float or None score_value = record.get("score") score = ( @@ -456,10 +548,149 @@ class RetrievalService: ) # Create RetrievalSegments object - retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) + retrieval_segment = RetrievalSegments( + segment=segment, child_chunks=child_chunks, score=score, files=files + ) result.append(retrieval_segment) return result except Exception as e: db.session.rollback() raise e + + def _retrieve( + self, + flask_app: Flask, + retrieval_method: RetrievalMethod, + dataset: Dataset, + query: str | None = None, + top_k: int = 4, + score_threshold: float | None = 0.0, + reranking_model: dict | None = None, + reranking_mode: str = "reranking_model", + weights: dict | None = None, + document_ids_filter: list[str] | None = None, + attachment_id: str | None = None, + all_documents: list[Document] = [], + exceptions: list[str] = [], + ): + 
if not query and not attachment_id: + return + with flask_app.app_context(): + all_documents_item: list[Document] = [] + # Optimize multithreading with thread pools + with ThreadPoolExecutor(max_workers=dify_config.RETRIEVAL_SERVICE_EXECUTORS) as executor: # type: ignore + futures = [] + if retrieval_method == RetrievalMethod.KEYWORD_SEARCH and query: + futures.append( + executor.submit( + self.keyword_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + all_documents=all_documents_item, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + if RetrievalMethod.is_support_semantic_search(retrieval_method): + if query: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.TEXT_QUERY, + ) + ) + if attachment_id: + futures.append( + executor.submit( + self.embedding_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=attachment_id, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + query_type=QueryType.IMAGE_QUERY, + ) + ) + if RetrievalMethod.is_support_fulltext_search(retrieval_method) and query: + futures.append( + executor.submit( + self.full_text_index_search, + flask_app=current_app._get_current_object(), # type: ignore + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + all_documents=all_documents_item, + retrieval_method=retrieval_method, + exceptions=exceptions, + document_ids_filter=document_ids_filter, + ) + ) + concurrent.futures.wait(futures, timeout=300, return_when=concurrent.futures.ALL_COMPLETED) + + if exceptions: + raise ValueError(";\n".join(exceptions)) + + # Deduplicate documents for hybrid search to avoid duplicate chunks + if retrieval_method == RetrievalMethod.HYBRID_SEARCH: + if attachment_id and reranking_mode == RerankMode.WEIGHTED_SCORE: + all_documents.extend(all_documents_item) + all_documents_item = self._deduplicate_documents(all_documents_item) + data_post_processor = DataPostProcessor( + str(dataset.tenant_id), reranking_mode, reranking_model, weights, False + ) + + query = query or attachment_id + if not query: + return + all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY if query else QueryType.IMAGE_QUERY, + ) + + all_documents.extend(all_documents_item) + + @classmethod + def get_segment_attachment_info( + cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session + ) -> dict[str, Any] | None: + upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() + if upload_file: + attachment_binding = ( + session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.attachment_id == upload_file.id) + .first() + ) + if attachment_binding: + attchment_info = { + "id": upload_file.id, + "name": upload_file.name, + "extension": "." 
+ upload_file.extension, + "mime_type": upload_file.mime_type, + "source_url": sign_upload_file(upload_file.id, upload_file.extension), + "size": upload_file.size, + } + return {"attchment_info": attchment_info, "segment_id": attachment_binding.segment_id} + return None diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 0beb388693..3a47241293 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,3 +1,4 @@ +import base64 import logging import time from abc import ABC, abstractmethod @@ -12,10 +13,13 @@ from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings +from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from models.dataset import Dataset, Whitelist +from models.model import UploadFile logger = logging.getLogger(__name__) @@ -203,6 +207,47 @@ class Vector: self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs) logger.info("Embedding %s texts took %s s", len(texts), time.time() - start) + def create_multimodal(self, file_documents: list | None = None, **kwargs): + if file_documents: + start = time.time() + logger.info("start embedding %s files %s", len(file_documents), start) + batch_size = 1000 + total_batches = len(file_documents) + batch_size - 1 + for i in range(0, len(file_documents), batch_size): + batch = file_documents[i : i + batch_size] + batch_start = time.time() + logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch)) + + # Batch query all upload files to avoid N+1 queries + attachment_ids = [doc.metadata["doc_id"] for doc in batch] + stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids)) + upload_files = db.session.scalars(stmt).all() + upload_file_map = {str(f.id): f for f in upload_files} + + file_base64_list = [] + real_batch = [] + for document in batch: + attachment_id = document.metadata["doc_id"] + doc_type = document.metadata["doc_type"] + upload_file = upload_file_map.get(attachment_id) + if upload_file: + blob = storage.load_once(upload_file.key) + file_base64_str = base64.b64encode(blob).decode() + file_base64_list.append( + { + "content": file_base64_str, + "content_type": doc_type, + "file_id": attachment_id, + } + ) + real_batch.append(document) + batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list) + logger.info( + "Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start + ) + self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs) + logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start) + def add_texts(self, documents: list[Document], **kwargs): if kwargs.get("duplicate_check", False): documents = self._filter_duplicate_texts(documents) @@ -223,6 +268,22 @@ class Vector: query_vector = self._embeddings.embed_query(query) return self._vector_processor.search_by_vector(query_vector, **kwargs) + def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]: + upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + 
     def add_texts(self, documents: list[Document], **kwargs):
         if kwargs.get("duplicate_check", False):
             documents = self._filter_duplicate_texts(documents)
@@ -223,6 +268,22 @@ class Vector:
         query_vector = self._embeddings.embed_query(query)
         return self._vector_processor.search_by_vector(query_vector, **kwargs)
 
+    def search_by_file(self, file_id: str, **kwargs: Any) -> list[Document]:
+        upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
+        if not upload_file:
+            return []
+        blob = storage.load_once(upload_file.key)
+        file_base64_str = base64.b64encode(blob).decode()
+        multimodal_vector = self._embeddings.embed_multimodal_query(
+            {
+                "content": file_base64_str,
+                "content_type": DocType.IMAGE,
+                "file_id": file_id,
+            }
+        )
+        return self._vector_processor.search_by_vector(multimodal_vector, **kwargs)
+
     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
         return self._vector_processor.search_by_full_text(query, **kwargs)
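search_by_file mirrors search_by_vector, except the query vector is derived from a stored image rather than from text. A minimal usage sketch under the same assumptions (the file ID must be the UUID of an existing UploadFile; the retrieval kwargs follow the search_by_vector conventions):

from core.rag.datasource.vdb.vector_factory import Vector

vector = Vector(dataset)  # dataset: placeholder Dataset instance
# Use an uploaded image as the retrieval query
results = vector.search_by_file("bbbbbbbb-1111-2222-3333-444444444444", top_k=4, score_threshold=0.2)
for doc in results:
    print(doc.metadata.get("doc_id"), doc.metadata.get("score"))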
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 74a2653e9d..1fe74d3042 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -5,9 +5,9 @@
 from sqlalchemy import func, select
 
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
-from core.rag.models.document import Document
+from core.rag.models.document import AttachmentDocument, Document
 from extensions.ext_database import db
-from models.dataset import ChildChunk, Dataset, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding
 
 
 class DatasetDocumentStore:
@@ -120,6 +120,9 @@ class DatasetDocumentStore:
                 db.session.add(segment_document)
                 db.session.flush()
+                self.add_multimodel_documents_binding(
+                    segment_id=segment_document.id, multimodel_documents=doc.attachments
+                )
                 if save_child:
                     if doc.children:
                         for position, child in enumerate(doc.children, start=1):
@@ -144,6 +147,9 @@ class DatasetDocumentStore:
                 segment_document.index_node_hash = doc.metadata.get("doc_hash")
                 segment_document.word_count = len(doc.page_content)
                 segment_document.tokens = tokens
+                self.add_multimodel_documents_binding(
+                    segment_id=segment_document.id, multimodel_documents=doc.attachments
+                )
                 if save_child and doc.children:
                     # delete the existing child chunks
                     db.session.query(ChildChunk).where(
@@ -233,3 +239,15 @@ class DatasetDocumentStore:
         document_segment = db.session.scalar(stmt)
 
         return document_segment
+
+    def add_multimodel_documents_binding(self, segment_id: str, multimodel_documents: list[AttachmentDocument] | None):
+        if multimodel_documents:
+            for multimodel_document in multimodel_documents:
+                binding = SegmentAttachmentBinding(
+                    tenant_id=self._dataset.tenant_id,
+                    dataset_id=self._dataset.id,
+                    document_id=self._document_id,
+                    segment_id=segment_id,
+                    attachment_id=multimodel_document.metadata["doc_id"],
+                )
+                db.session.add(binding)
diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py
index 7fb20c1941..3cbc7db75d 100644
--- a/api/core/rag/embedding/cached_embedding.py
+++ b/api/core/rag/embedding/cached_embedding.py
@@ -104,6 +104,88 @@ class CacheEmbedding(Embeddings):
 
         return text_embeddings
 
+    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
+        """Embed multimodal documents."""
+        # use doc embedding cache or store if not exists
+        multimodel_embeddings: list[Any] = [None for _ in range(len(multimodel_documents))]
+        embedding_queue_indices = []
+        for i, multimodel_document in enumerate(multimodel_documents):
+            file_id = multimodel_document["file_id"]
+            embedding = (
+                db.session.query(Embedding)
+                .filter_by(
+                    model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
+                )
+                .first()
+            )
+            if embedding:
+                multimodel_embeddings[i] = embedding.get_embedding()
+            else:
+                embedding_queue_indices.append(i)
+
+        # NOTE: avoid closing the shared scoped session here; downstream code may still have pending work
+
+        if embedding_queue_indices:
+            embedding_queue_multimodel_documents = [multimodel_documents[i] for i in embedding_queue_indices]
+            embedding_queue_embeddings = []
+            try:
+                model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
+                model_schema = model_type_instance.get_model_schema(
+                    self._model_instance.model, self._model_instance.credentials
+                )
+                max_chunks = (
+                    model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
+                    if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties
+                    else 1
+                )
+                for i in range(0, len(embedding_queue_multimodel_documents), max_chunks):
+                    batch_multimodel_documents = embedding_queue_multimodel_documents[i : i + max_chunks]
+
+                    embedding_result = self._model_instance.invoke_multimodal_embedding(
+                        multimodel_documents=batch_multimodel_documents,
+                        user=self._user,
+                        input_type=EmbeddingInputType.DOCUMENT,
+                    )
+
+                    for vector in embedding_result.embeddings:
+                        try:
+                            # FIXME: type ignore for numpy here
+                            normalized_embedding = (vector / np.linalg.norm(vector)).tolist()  # type: ignore
+                            # stackoverflow best way: https://stackoverflow.com/questions/20319813/how-to-check-list-containing-nan
+                            if np.isnan(normalized_embedding).any():
+                                # for issue #11827 float values are not json compliant
+                                logger.warning("Normalized embedding is nan: %s", normalized_embedding)
+                                continue
+                            embedding_queue_embeddings.append(normalized_embedding)
+                        except IntegrityError:
+                            db.session.rollback()
+                        except Exception:
+                            logger.exception("Failed to transform embedding")
+                cache_embeddings = []
+                try:
+                    for i, n_embedding in zip(embedding_queue_indices, embedding_queue_embeddings):
+                        multimodel_embeddings[i] = n_embedding
+                        file_id = multimodel_documents[i]["file_id"]
+                        if file_id not in cache_embeddings:
+                            embedding_cache = Embedding(
+                                model_name=self._model_instance.model,
+                                hash=file_id,
+                                provider_name=self._model_instance.provider,
+                                embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
+                            )
+                            embedding_cache.set_embedding(n_embedding)
+                            db.session.add(embedding_cache)
+                            cache_embeddings.append(file_id)
+                    db.session.commit()
+                except IntegrityError:
+                    db.session.rollback()
+            except Exception as ex:
+                db.session.rollback()
+                logger.exception("Failed to embed documents")
+                raise ex
+
+        return multimodel_embeddings
+
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
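Note the caching convention: document-side attachment vectors are persisted in the existing Embedding table, reusing its hash column to hold the attachment's file_id rather than a content hash, so re-indexing the same file is a cache hit. A hedged sketch of what that probe amounts to (the model/provider strings and file UUID are placeholders, and Embedding is assumed to be the existing model in models.dataset):

from extensions.ext_database import db
from models.dataset import Embedding  # assumed location of the embedding cache model

file_id = "cccccccc-1111-2222-3333-444444444444"
cached = (
    db.session.query(Embedding)
    .filter_by(model_name="my-embedding-model", hash=file_id, provider_name="my-provider")
    .first()
)
vector = cached.get_embedding() if cached else None  # unpickled list[float] on a hit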
@@ -146,3 +228,46 @@ class CacheEmbedding(Embeddings):
             raise ex
 
         return embedding_results  # type: ignore
+
+    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
+        """Embed multimodal query."""
+        # use doc embedding cache or store if not exists
+        file_id = multimodel_document["file_id"]
+        embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
+        embedding = redis_client.get(embedding_cache_key)
+        if embedding:
+            redis_client.expire(embedding_cache_key, 600)
+            decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float")
+            return [float(x) for x in decoded_embedding]
+        try:
+            embedding_result = self._model_instance.invoke_multimodal_embedding(
+                multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY
+            )
+
+            embedding_results = embedding_result.embeddings[0]
+            # FIXME: type ignore for numpy here
+            embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()  # type: ignore
+            if np.isnan(embedding_results).any():
+                raise ValueError("Normalized embedding is NaN, please try again")
+        except Exception as ex:
+            if dify_config.DEBUG:
+                logger.exception("Failed to embed multimodal document '%s'", multimodel_document["file_id"])
+            raise ex
+
+        try:
+            # encode embedding to base64
+            embedding_vector = np.array(embedding_results)
+            vector_bytes = embedding_vector.tobytes()
+            # Transform to Base64
+            encoded_vector = base64.b64encode(vector_bytes)
+            # Transform to string
+            encoded_str = encoded_vector.decode("utf-8")
+            redis_client.setex(embedding_cache_key, 600, encoded_str)
+        except Exception as ex:
+            if dify_config.DEBUG:
+                logger.exception(
+                    "Failed to add embedding to redis for the multimodal document '%s'", multimodel_document["file_id"]
+                )
+            raise ex
+
+        return embedding_results  # type: ignore
diff --git a/api/core/rag/embedding/embedding_base.py b/api/core/rag/embedding/embedding_base.py
index 9f232ab910..1be55bda80 100644
--- a/api/core/rag/embedding/embedding_base.py
+++ b/api/core/rag/embedding/embedding_base.py
@@ -9,11 +9,21 @@ class Embeddings(ABC):
         """Embed search docs."""
         raise NotImplementedError
 
+    @abstractmethod
+    def embed_multimodal_documents(self, multimodel_documents: list[dict]) -> list[list[float]]:
+        """Embed multimodal documents."""
+        raise NotImplementedError
+
     @abstractmethod
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
         raise NotImplementedError
 
+    @abstractmethod
+    def embed_multimodal_query(self, multimodel_document: dict) -> list[float]:
+        """Embed multimodal query."""
+        raise NotImplementedError
+
     async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Asynchronous Embed search docs."""
         raise NotImplementedError
diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py
index 8e92191568..b54a37b49e 100644
--- a/api/core/rag/embedding/retrieval.py
+++ b/api/core/rag/embedding/retrieval.py
@@ -19,3 +19,4 @@ class RetrievalSegments(BaseModel):
     segment: DocumentSegment
     child_chunks: list[RetrievalChildChunk] | None = None
     score: float | None = None
+    files: list[dict[str, str | int]] | None = None
diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py
index aca879df7d..9f66cd9a03 100644
--- a/api/core/rag/entities/citation_metadata.py
+++ b/api/core/rag/entities/citation_metadata.py
@@ -21,3 +21,4 @@ class RetrievalSourceMetadata(BaseModel):
     page: int | None = None
     doc_metadata: dict[str, Any] | None = None
     title: str | None = None
+    files: list[dict[str, Any]] | None = None
diff --git a/api/core/rag/index_processor/constant/doc_type.py b/api/core/rag/index_processor/constant/doc_type.py
new file mode 100644
index 0000000000..93c8fecb8d
--- /dev/null
+++ b/api/core/rag/index_processor/constant/doc_type.py
@@ -0,0 +1,6 @@
+from enum import StrEnum
+
+
+class DocType(StrEnum):
+    TEXT = "text"
+    IMAGE = "image"
diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py
index 659086e808..09617413f7 100644
--- a/api/core/rag/index_processor/constant/index_type.py
+++ b/api/core/rag/index_processor/constant/index_type.py
@@ -1,7 +1,12 @@
 from enum import StrEnum
 
 
-class IndexType(StrEnum):
+class IndexStructureType(StrEnum):
     PARAGRAPH_INDEX = "text_model"
     QA_INDEX = "qa_model"
     PARENT_CHILD_INDEX = "hierarchical_model"
+
+
+class IndexTechniqueType(StrEnum):
+    ECONOMY = "economy"
+    HIGH_QUALITY = "high_quality"
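Splitting the old IndexType into IndexStructureType (how chunks are laid out) and IndexTechniqueType (how they are stored) separates two concerns that previously shared one enum. A brief, hypothetical illustration of how call sites read after the rename (dataset and dataset_document are placeholders):

from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType

# StrEnum members compare equal to their string values, so pre-existing
# string comparisons like == "high_quality" keep working after the rename.
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
    pass  # vector index; eligible for multimodal attachment embedding
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
    pass  # hierarchical chunks: parent segments with child vectors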
diff --git a/api/core/rag/index_processor/constant/query_type.py b/api/core/rag/index_processor/constant/query_type.py
new file mode 100644
index 0000000000..342bfef3f7
--- /dev/null
+++ b/api/core/rag/index_processor/constant/query_type.py
@@ -0,0 +1,6 @@
+from enum import StrEnum
+
+
+class QueryType(StrEnum):
+    TEXT_QUERY = "text_query"
+    IMAGE_QUERY = "image_query"
diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py
index d4eff53204..8a28eb477a 100644
--- a/api/core/rag/index_processor/index_processor_base.py
+++ b/api/core/rag/index_processor/index_processor_base.py
@@ -1,20 +1,34 @@
 """Abstract interface for document loader implementations."""
 
+import cgi
+import logging
+import mimetypes
+import os
+import re
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any, Optional
+from urllib.parse import unquote, urlparse
+
+import httpx
 
 from configs import dify_config
+from core.helper import ssrf_proxy
 from core.rag.extractor.entity.extract_setting import ExtractSetting
-from core.rag.models.document import Document
+from core.rag.index_processor.constant.doc_type import DocType
+from core.rag.models.document import AttachmentDocument, Document
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.rag.splitter.fixed_text_splitter import (
     EnhanceRecursiveCharacterTextSplitter,
     FixedRecursiveCharacterTextSplitter,
 )
 from core.rag.splitter.text_splitter import TextSplitter
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models import Account, ToolFile
 from models.dataset import Dataset, DatasetProcessRule
 from models.dataset import Document as DatasetDocument
+from models.model import UploadFile
 
 if TYPE_CHECKING:
     from core.model_manager import ModelInstance
@@ -28,11 +42,18 @@ class BaseIndexProcessor(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
+    def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
         raise NotImplementedError
 
     @abstractmethod
-    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
+    def load(
+        self,
+        dataset: Dataset,
+        documents: list[Document],
+        multimodal_documents: list[AttachmentDocument] | None = None,
+        with_keywords: bool = True,
+        **kwargs,
+    ):
         raise NotImplementedError
 
     @abstractmethod
@@ -96,3 +117,178 @@ class BaseIndexProcessor(ABC):
             )
 
         return character_splitter  # type: ignore
+
+    def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
+        """
+        Get the content files from the document.
+        """
+        multi_model_documents: list[AttachmentDocument] = []
+        text = document.page_content
+        images = self._extract_markdown_images(text)
+        if not images:
+            return multi_model_documents
+        upload_file_id_list = []
+
+        for image in images:
+            # Collect all upload_file_ids including duplicates to preserve occurrence count
+
+            # For data before v0.10.0
+            pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
+            match = re.search(pattern, image)
+            if match:
+                upload_file_id = match.group(1)
+                upload_file_id_list.append(upload_file_id)
+                continue
+
+            # For data after v0.10.0
+            pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
+            match = re.search(pattern, image)
+            if match:
+                upload_file_id = match.group(1)
+                upload_file_id_list.append(upload_file_id)
+                continue
+
+            # For tools directory - direct file formats (e.g., .png, .jpg, etc.)
+            # Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
+            pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
+            match = re.search(pattern, image)
+            if match:
+                if current_user:
+                    tool_file_id = match.group(1)
+                    upload_file_id = self._download_tool_file(tool_file_id, current_user)
+                    if upload_file_id:
+                        upload_file_id_list.append(upload_file_id)
+                continue
+            if current_user:
+                upload_file_id = self._download_image(image.split(" ")[0], current_user)
+                if upload_file_id:
+                    upload_file_id_list.append(upload_file_id)
+
+        if not upload_file_id_list:
+            return multi_model_documents
+
+        # Get unique IDs for database query
+        unique_upload_file_ids = list(set(upload_file_id_list))
+        upload_files = db.session.query(UploadFile).where(UploadFile.id.in_(unique_upload_file_ids)).all()
+
+        # Create a mapping from ID to UploadFile for quick lookup
+        upload_file_map = {upload_file.id: upload_file for upload_file in upload_files}
+
+        # Create a Document for each occurrence (including duplicates)
+        for upload_file_id in upload_file_id_list:
+            upload_file = upload_file_map.get(upload_file_id)
+            if upload_file:
+                multi_model_documents.append(
+                    AttachmentDocument(
+                        page_content=upload_file.name,
+                        metadata={
+                            "doc_id": upload_file.id,
+                            "doc_hash": "",
+                            "document_id": document.metadata.get("document_id"),
+                            "dataset_id": document.metadata.get("dataset_id"),
+                            "doc_type": DocType.IMAGE,
+                        },
+                    )
+                )
+        return multi_model_documents
+
+    def _extract_markdown_images(self, text: str) -> list[str]:
+        """
+        Extract the markdown images from the text.
+        """
+        pattern = r"!\[.*?\]\((.*?)\)"
+        return re.findall(pattern, text)
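To make the URL handling concrete, here is a small self-contained sketch of the same extraction patterns applied to a made-up chunk (all UUIDs are placeholders):

import re

text = (
    "Intro ![old](/files/aaaaaaaa-1111-2222-3333-444444444444/image-preview?ts=1) "
    "![new](/files/bbbbbbbb-1111-2222-3333-444444444444/file-preview) "
    "![tool](/files/tools/cccccccc-1111-2222-3333-444444444444.png?sig=x)"
)
for url in re.findall(r"!\[.*?\]\((.*?)\)", text):  # markdown image URLs
    m = re.search(r"/files/([a-f0-9\-]+)/(?:image|file)-preview", url)
    tool = re.search(r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)", url)
    print(m.group(1) if m else tool.group(1) if tool else "external")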
+ """ + from services.file_service import FileService + + MAX_IMAGE_SIZE = dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024 + DOWNLOAD_TIMEOUT = dify_config.ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT + + try: + # Download with timeout + response = ssrf_proxy.get(image_url, timeout=DOWNLOAD_TIMEOUT) + response.raise_for_status() + + # Check Content-Length header if available + content_length = response.headers.get("Content-Length") + if content_length and int(content_length) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit (size: %s bytes)", image_url, content_length) + return None + + filename = None + + content_disposition = response.headers.get("content-disposition") + if content_disposition: + _, params = cgi.parse_header(content_disposition) + if "filename" in params: + filename = params["filename"] + filename = unquote(filename) + + if not filename: + parsed_url = urlparse(image_url) + # unquote 处理 URL 中的中文 + path = unquote(parsed_url.path) + filename = os.path.basename(path) + + if not filename: + filename = "downloaded_image_file" + + name, current_ext = os.path.splitext(filename) + + content_type = response.headers.get("content-type", "").split(";")[0].strip() + + real_ext = mimetypes.guess_extension(content_type) + + if not current_ext and real_ext or current_ext in [".php", ".jsp", ".asp", ".html"] and real_ext: + filename = f"{name}{real_ext}" + # Download content with size limit + blob = b"" + for chunk in response.iter_bytes(chunk_size=8192): + blob += chunk + if len(blob) > MAX_IMAGE_SIZE: + logging.warning("Image from %s exceeds 2MB limit during download", image_url) + return None + + if not blob: + logging.warning("Image from %s is empty", image_url) + return None + + upload_file = FileService(db.engine).upload_file( + filename=filename, + content=blob, + mimetype=content_type, + user=current_user, + ) + return upload_file.id + except httpx.TimeoutException: + logging.warning("Timeout downloading image from %s after %s seconds", image_url, DOWNLOAD_TIMEOUT) + return None + except httpx.RequestError as e: + logging.warning("Error downloading image from %s: %s", image_url, str(e)) + return None + except Exception: + logging.exception("Unexpected error downloading image from %s", image_url) + return None + + def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None: + """ + Download the tool file from the ID. 
+ """ + from services.file_service import FileService + + tool_file = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() + if not tool_file: + return None + blob = storage.load_once(tool_file.file_key) + upload_file = FileService(db.engine).upload_file( + filename=tool_file.name, + content=blob, + mimetype=tool_file.mimetype, + user=current_user, + ) + return upload_file.id diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c987edf342..ea6ab24699 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -1,6 +1,6 @@ """Abstract interface for document loader implementations.""" -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor @@ -19,11 +19,11 @@ class IndexProcessorFactory: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX: + if self._index_type == IndexStructureType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX: + elif self._index_type == IndexStructureType.QA_INDEX: return QAIndexProcessor() - elif self._index_type == IndexType.PARENT_CHILD_INDEX: + elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX: return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 5e5fea7ea9..a7c879f2c4 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -11,14 +11,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule @@ -33,7 +36,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, 
**kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -69,6 +72,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): if document_node.metadata is not None: document_node.metadata["doc_id"] = doc_id document_node.metadata["doc_hash"] = hash + multimodal_documents = ( + self._get_content_files(document_node, current_user) if document_node.metadata else None + ) + if multimodal_documents: + document_node.attachments = multimodal_documents # delete Splitter character page_content = remove_leading_symbols(document_node.page_content).strip() if len(page_content) > 0: @@ -77,10 +85,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) with_keywords = False if with_keywords: keywords_list = kwargs.get("keywords_list") @@ -134,8 +151,9 @@ class ParagraphIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): + documents: list[Any] = [] + all_multimodal_documents: list[Any] = [] if isinstance(chunks, list): - documents = [] for content in chunks: metadata = { "dataset_id": dataset.id, @@ -144,26 +162,68 @@ class ParagraphIndexProcessor(BaseIndexProcessor): "doc_hash": helper.generate_text_hash(content), } doc = Document(page_content=content, metadata=metadata) + attachments = self._get_content_files(doc) + if attachments: + doc.attachments = attachments + all_multimodal_documents.extend(attachments) documents.append(doc) - if documents: - # save node to document segment - doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) - # add document segments - doc_store.add_documents(docs=documents, save_child=False) - if dataset.indexing_technique == "high_quality": - vector = Vector(dataset) - vector.create(documents) - elif dataset.indexing_technique == "economy": - keyword = Keyword(dataset) - keyword.add_texts(documents) else: - raise ValueError("Chunks is not a list") + multimodal_general_structure = MultimodalGeneralStructureChunk.model_validate(chunks) + for general_chunk in multimodal_general_structure.general_chunks: + metadata = { + "dataset_id": dataset.id, + "document_id": document.id, + "doc_id": str(uuid.uuid4()), + "doc_hash": helper.generate_text_hash(general_chunk.content), + } + doc = Document(page_content=general_chunk.content, metadata=metadata) + if general_chunk.files: + attachments = [] + for file in general_chunk.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument( + page_content=file.filename or "image_file", metadata=file_metadata + ) + attachments.append(file_document) + all_multimodal_documents.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = 
self._get_content_files(doc, current_user=account) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + documents.append(doc) + if documents: + # save node to document segment + doc_store = DatasetDocumentStore(dataset=dataset, user_id=document.created_by, document_id=document.id) + # add document segments + doc_store.add_documents(docs=documents, save_child=False) + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + vector.create(documents) + if all_multimodal_documents: + vector.create_multimodal(all_multimodal_documents) + elif dataset.indexing_technique == "economy": + keyword = Keyword(dataset) + keyword.add_texts(documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: if isinstance(chunks, list): preview = [] for content in chunks: preview.append({"content": content}) - return {"chunk_structure": IndexType.PARAGRAPH_INDEX, "preview": preview, "total_segments": len(chunks)} + return { + "chunk_structure": IndexStructureType.PARAGRAPH_INDEX, + "preview": preview, + "total_segments": len(chunks), + } else: raise ValueError("Chunks is not a list") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 4fa78e2f95..ee29d2fd65 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -13,14 +13,17 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import ChildDocument, Document, ParentChildStructureChunk +from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from libs import helper +from models import Account from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -35,7 +38,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: process_rule = kwargs.get("process_rule") if not process_rule: raise ValueError("No process rule found.") @@ -77,6 +80,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): page_content = page_content if len(page_content) > 0: document_node.page_content = page_content + multimodel_documents = self._get_content_files(document_node, current_user) + if multimodel_documents: + document_node.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -87,6 
+93,9 @@ class ParentChildIndexProcessor(BaseIndexProcessor): elif rules.parent_mode == ParentMode.FULL_DOC: page_content = "\n".join([document.page_content for document in documents]) document = Document(page_content=page_content, metadata=documents[0].metadata) + multimodel_documents = self._get_content_files(document) + if multimodel_documents: + document.attachments = multimodel_documents # parse document to child nodes child_nodes = self._split_child_nodes( document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") @@ -104,7 +113,14 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) for document in documents: @@ -114,6 +130,8 @@ class ParentChildIndexProcessor(BaseIndexProcessor): Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): # node_ids is segment's node_ids @@ -244,6 +262,24 @@ class ParentChildIndexProcessor(BaseIndexProcessor): } child_documents.append(ChildDocument(page_content=child, metadata=child_metadata)) doc = Document(page_content=parent_child.parent_content, metadata=metadata, children=child_documents) + if parent_child.files and len(parent_child.files) > 0: + attachments = [] + for file in parent_child.files: + file_metadata = { + "doc_id": file.id, + "doc_hash": "", + "document_id": document.id, + "dataset_id": dataset.id, + "doc_type": DocType.IMAGE, + } + file_document = AttachmentDocument(page_content=file.filename or "", metadata=file_metadata) + attachments.append(file_document) + doc.attachments = attachments + else: + account = AccountService.load_user(document.created_by) + if not account: + raise ValueError("Invalid account") + doc.attachments = self._get_content_files(doc, current_user=account) documents.append(doc) if documents: # update document parent mode @@ -267,12 +303,17 @@ class ParentChildIndexProcessor(BaseIndexProcessor): doc_store.add_documents(docs=documents, save_child=True) if dataset.indexing_technique == "high_quality": all_child_documents = [] + all_multimodal_documents = [] for doc in documents: if doc.children: all_child_documents.extend(doc.children) + if doc.attachments: + all_multimodal_documents.extend(doc.attachments) + vector = Vector(dataset) if all_child_documents: - vector = Vector(dataset) vector.create(all_child_documents) + if all_multimodal_documents: + vector.create_multimodal(all_multimodal_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: parent_childs = ParentChildStructureChunk.model_validate(chunks) @@ -280,7 +321,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) return { - "chunk_structure": IndexType.PARENT_CHILD_INDEX, + "chunk_structure": IndexStructureType.PARENT_CHILD_INDEX, "parent_mode": parent_childs.parent_mode, "preview": preview, "total_segments": 
len(parent_childs.parent_child_chunks), diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 3e3deb0180..1183d5fbd7 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -18,12 +18,13 @@ from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor -from core.rag.models.document import Document, QAStructureChunk +from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper +from models.account import Account from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import Rule @@ -41,7 +42,7 @@ class QAIndexProcessor(BaseIndexProcessor): ) return text_docs - def transform(self, documents: list[Document], **kwargs) -> list[Document]: + def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]: preview = kwargs.get("preview") process_rule = kwargs.get("process_rule") if not process_rule: @@ -116,7 +117,7 @@ class QAIndexProcessor(BaseIndexProcessor): try: # Skip the first row - df = pd.read_csv(file) + df = pd.read_csv(file) # type: ignore text_docs = [] for _, row in df.iterrows(): data = Document(page_content=row.iloc[0], metadata={"answer": row.iloc[1]}) @@ -128,10 +129,19 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + def load( + self, + dataset: Dataset, + documents: list[Document], + multimodal_documents: list[AttachmentDocument] | None = None, + with_keywords: bool = True, + **kwargs, + ): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) + if multimodal_documents and dataset.is_multimodal: + vector.create_multimodal(multimodal_documents) def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs): vector = Vector(dataset) @@ -197,7 +207,7 @@ class QAIndexProcessor(BaseIndexProcessor): for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) return { - "chunk_structure": IndexType.QA_INDEX, + "chunk_structure": IndexStructureType.QA_INDEX, "qa_preview": preview, "total_segments": len(qa_chunks.qa_chunks), } diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 4bd7b1d62e..611fad9a18 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,6 +4,8 @@ from typing import Any from pydantic import BaseModel, Field +from core.file import File + class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" @@ -15,7 +17,19 @@ class ChildDocument(BaseModel): """Arbitrary metadata about the page 
content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) + + +class AttachmentDocument(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + provider: str | None = "dify" + + vector: list[float] | None = None + + metadata: dict[str, Any] = Field(default_factory=dict) class Document(BaseModel): @@ -28,12 +42,31 @@ class Document(BaseModel): """Arbitrary metadata about the page content (e.g., source, relationships to other documents, etc.). """ - metadata: dict = Field(default_factory=dict) + metadata: dict[str, Any] = Field(default_factory=dict) provider: str | None = "dify" children: list[ChildDocument] | None = None + attachments: list[AttachmentDocument] | None = None + + +class GeneralChunk(BaseModel): + """ + General Chunk. + """ + + content: str + files: list[File] | None = None + + +class MultimodalGeneralStructureChunk(BaseModel): + """ + Multimodal General Structure Chunk. + """ + + general_chunks: list[GeneralChunk] + class GeneralStructureChunk(BaseModel): """ @@ -50,6 +83,7 @@ class ParentChildChunk(BaseModel): parent_content: str child_contents: list[str] + files: list[File] | None = None class ParentChildStructureChunk(BaseModel): diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 3561def008..88acb75133 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document @@ -12,6 +13,7 @@ class BaseRerankRunner(ABC): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index e855b0083f..38309d3d77 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,6 +1,15 @@ -from core.model_manager import ModelInstance +import base64 + +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.rerank_entities import RerankResult +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import UploadFile class RerankModelRunner(BaseRerankRunner): @@ -14,6 +23,7 @@ class RerankModelRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -24,6 +34,56 @@ class RerankModelRunner(BaseRerankRunner): :param user: unique user id if needed :return: """ + model_manager = ModelManager() + is_support_vision = model_manager.check_model_support_vision( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, + provider=self.rerank_model_instance.provider, + model=self.rerank_model_instance.model, + model_type=ModelType.RERANK, + ) + if not is_support_vision: + if query_type == QueryType.TEXT_QUERY: + rerank_result, 
unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + else: + return documents + else: + rerank_result, unique_documents = self.fetch_multimodal_rerank( + query, documents, score_threshold, top_n, user, query_type + ) + + rerank_documents = [] + for result in rerank_result.docs: + if score_threshold is None or result.score >= score_threshold: + # format document + rerank_document = Document( + page_content=result.text, + metadata=unique_documents[result.index].metadata, + provider=unique_documents[result.index].provider, + ) + if rerank_document.metadata is not None: + rerank_document.metadata["score"] = result.score + rerank_documents.append(rerank_document) + + rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) + return rerank_documents[:top_n] if top_n else rerank_documents + + def fetch_text_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch text rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ docs = [] doc_ids = set() unique_documents = [] @@ -33,33 +93,99 @@ class RerankModelRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - docs.append(document.page_content) - unique_documents.append(document) + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + docs.append(document.page_content) + unique_documents.append(document) elif document.provider == "external": if document not in unique_documents: docs.append(document.page_content) unique_documents.append(document) - documents = unique_documents - rerank_result = self.rerank_model_instance.invoke_rerank( query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) + return rerank_result, unique_documents - rerank_documents = [] + def fetch_multimodal_rerank( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> tuple[RerankResult, list[Document]]: + """ + Fetch multimodal rerank + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :param query_type: query type + :return: rerank result + """ + docs = [] + doc_ids = set() + unique_documents = [] + for document in documents: + if ( + document.provider == "dify" + and document.metadata is not None + and document.metadata["doc_id"] not in doc_ids + ): + if document.metadata.get("doc_type") == DocType.IMAGE: + # Query file info within db.session context to ensure thread-safe access + upload_file = ( + db.session.query(UploadFile).where(UploadFile.id == document.metadata["doc_id"]).first() + ) + if upload_file: + blob = storage.load_once(upload_file.key) + document_file_base64 = base64.b64encode(blob).decode() + document_file_dict = { + "content": document_file_base64, + "content_type": document.metadata["doc_type"], + } + docs.append(document_file_dict) + else: + document_text_dict = { + "content": document.page_content, + "content_type": 
document.metadata.get("doc_type") or DocType.TEXT, + } + docs.append(document_text_dict) + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) + elif document.provider == "external": + if document not in unique_documents: + docs.append( + { + "content": document.page_content, + "content_type": document.metadata.get("doc_type") or DocType.TEXT, + } + ) + unique_documents.append(document) - for result in rerank_result.docs: - if score_threshold is None or result.score >= score_threshold: - # format document - rerank_document = Document( - page_content=result.text, - metadata=documents[result.index].metadata, - provider=documents[result.index].provider, + documents = unique_documents + if query_type == QueryType.TEXT_QUERY: + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + return rerank_result, unique_documents + elif query_type == QueryType.IMAGE_QUERY: + # Query file info within db.session context to ensure thread-safe access + upload_file = db.session.query(UploadFile).where(UploadFile.id == query).first() + if upload_file: + blob = storage.load_once(upload_file.key) + file_query = base64.b64encode(blob).decode() + file_query_dict = { + "content": file_query, + "content_type": DocType.IMAGE, + } + rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user ) - if rerank_document.metadata is not None: - rerank_document.metadata["score"] = result.score - rerank_documents.append(rerank_document) + return rerank_result, unique_documents + else: + raise ValueError(f"Upload file not found for query: {query}") - rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True) - return rerank_documents[:top_n] if top_n else rerank_documents + else: + raise ValueError(f"Query type {query_type} is not supported") diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index c455db6095..18020608cb 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -7,6 +7,8 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner @@ -24,6 +26,7 @@ class WeightRerankRunner(BaseRerankRunner): score_threshold: float | None = None, top_n: int | None = None, user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ Run rerank model @@ -43,8 +46,10 @@ class WeightRerankRunner(BaseRerankRunner): and document.metadata is not None and document.metadata["doc_id"] not in doc_ids ): - doc_ids.add(document.metadata["doc_id"]) - unique_documents.append(document) + # weight rerank only support text documents + if not document.metadata.get("doc_type") or document.metadata.get("doc_type") == DocType.TEXT: + doc_ids.add(document.metadata["doc_id"]) + unique_documents.append(document) else: if document not in unique_documents: unique_documents.append(document) diff --git a/api/core/rag/retrieval/dataset_retrieval.py 
b/api/core/rag/retrieval/dataset_retrieval.py index 3db67efb0e..ec55d2d0cc 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -8,6 +8,7 @@ from typing import Any, Union, cast from flask import Flask, current_app from sqlalchemy import and_, or_, select +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -19,6 +20,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus +from core.file import File, FileTransferMethod, FileType from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage @@ -37,7 +39,9 @@ from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -52,10 +56,12 @@ from core.rag.retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) +from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from libs.json_in_md_parser import parse_and_check_json_markdown -from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService @@ -99,7 +105,8 @@ class DatasetRetrieval: message_id: str, memory: TokenBufferMemory | None = None, inputs: Mapping[str, Any] | None = None, - ) -> str | None: + vision_enabled: bool = False, + ) -> tuple[str | None, list[File] | None]: """ Retrieve dataset. 
         :param app_id: app_id
@@ -118,7 +125,7 @@ class DatasetRetrieval:
         """
         dataset_ids = config.dataset_ids
         if len(dataset_ids) == 0:
-            return None
+            return None, []
         retrieve_config = config.retrieve_config
 
         # check model is support tool calling
@@ -136,7 +143,7 @@ class DatasetRetrieval:
         )
 
         if not model_schema:
-            return None
+            return None, []
 
         planning_strategy = PlanningStrategy.REACT_ROUTER
         features = model_schema.features
@@ -182,8 +189,8 @@ class DatasetRetrieval:
                 tenant_id,
                 user_id,
                 user_from,
-                available_datasets,
                 query,
+                available_datasets,
                 model_instance,
                 model_config,
                 planning_strategy,
@@ -213,6 +220,7 @@ class DatasetRetrieval:
         dify_documents = [item for item in all_documents if item.provider == "dify"]
         external_documents = [item for item in all_documents if item.provider == "external"]
         document_context_list: list[DocumentContext] = []
+        context_files: list[File] = []
        retrieval_resource_list: list[RetrievalSourceMetadata] = []
         # deal with external documents
         for item in external_documents:
@@ -248,6 +256,31 @@ class DatasetRetrieval:
                             score=record.score,
                         )
                     )
+                    if vision_enabled:
+                        attachments_with_bindings = db.session.execute(
+                            select(SegmentAttachmentBinding, UploadFile)
+                            .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
+                            .where(
+                                SegmentAttachmentBinding.segment_id == segment.id,
+                            )
+                        ).all()
+                        if attachments_with_bindings:
+                            for _, upload_file in attachments_with_bindings:
+                                attachment_info = File(
+                                    id=upload_file.id,
+                                    filename=upload_file.name,
+                                    extension="." + upload_file.extension,
+                                    mime_type=upload_file.mime_type,
+                                    tenant_id=segment.tenant_id,
+                                    type=FileType.IMAGE,
+                                    transfer_method=FileTransferMethod.LOCAL_FILE,
+                                    remote_url=upload_file.source_url,
+                                    related_id=upload_file.id,
+                                    size=upload_file.size,
+                                    storage_key=upload_file.key,
+                                    url=sign_upload_file(upload_file.id, upload_file.extension),
+                                )
+                                context_files.append(attachment_info)
             if show_retrieve_source:
                 for record in records:
                     segment = record.segment
@@ -288,8 +321,10 @@ class DatasetRetrieval:
             hit_callback.return_retriever_resource_info(retrieval_resource_list)
         if document_context_list:
             document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True)
-            return str("\n".join([document_context.content for document_context in document_context_list]))
-        return ""
+            return str(
+                "\n".join([document_context.content for document_context in document_context_list])
+            ), context_files
+        return "", context_files
 
     def single_retrieve(
         self,
@@ -297,8 +332,8 @@ class DatasetRetrieval:
         tenant_id: str,
         user_id: str,
         user_from: str,
-        available_datasets: list,
         query: str,
+        available_datasets: list,
         model_instance: ModelInstance,
         model_config: ModelConfigWithCredentialsEntity,
         planning_strategy: PlanningStrategy,
@@ -336,7 +371,7 @@ class DatasetRetrieval:
 
         dataset_id, router_usage = function_call_router.invoke(query, tools, model_config, model_instance)
         self._record_usage(router_usage)
-
+        timer = None
         if dataset_id:
             # get retrieval model config
             dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
@@ -406,10 +441,19 @@ class DatasetRetrieval:
                     weights=retrieval_model_config.get("weights", None),
                     document_ids_filter=document_ids_filter,
                 )
-                self._on_query(query, [dataset_id], app_id, user_from, user_id)
+                self._on_query(query, None, [dataset_id], app_id, user_from, user_id)
 
                 if results:
-                    self._on_retrieval_end(results, message_id, timer)
+                    thread = threading.Thread(
+                        target=self._on_retrieval_end,
+                        kwargs={
+                            "flask_app": current_app._get_current_object(),  # 
type: ignore + "documents": results, + "message_id": message_id, + "timer": timer, + }, + ) + thread.start() return results return [] @@ -421,7 +465,7 @@ class DatasetRetrieval: user_id: str, user_from: str, available_datasets: list, - query: str, + query: str | None, top_k: int, score_threshold: float, reranking_mode: str, @@ -431,10 +475,11 @@ class DatasetRetrieval: message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): if not available_datasets: return [] - threads = [] + all_threads = [] all_documents: list[Document] = [] dataset_ids = [dataset.id for dataset in available_datasets] index_type_check = all( @@ -467,131 +512,226 @@ class DatasetRetrieval: 0 ].embedding_model_provider weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model - - for dataset in available_datasets: - index_type = dataset.indexing_technique - document_ids_filter = None - if dataset.provider != "external": - if metadata_condition and not metadata_filter_document_ids: - continue - if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) - if document_ids: - document_ids_filter = document_ids - else: - continue - retrieval_thread = threading.Thread( - target=self._retriever, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "dataset_id": dataset.id, - "query": query, - "top_k": top_k, - "all_documents": all_documents, - "document_ids_filter": document_ids_filter, - "metadata_condition": metadata_condition, - }, - ) - threads.append(retrieval_thread) - retrieval_thread.start() - for thread in threads: - thread.join() - with measure_time() as timer: - if reranking_enable: - # do rerank for searched documents - data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) - - all_documents = data_post_processor.invoke( - query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k + if query: + query_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + "tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": query, + "attachment_id": None, + }, ) - else: - if index_type == "economy": - all_documents = self.calculate_keyword_score(query, all_documents, top_k) - elif index_type == "high_quality": - all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) - else: - all_documents = all_documents[:top_k] if top_k else all_documents - - self._on_query(query, dataset_ids, app_id, user_from, user_id) + all_threads.append(query_thread) + query_thread.start() + if attachment_ids: + for attachment_id in attachment_ids: + attachment_thread = threading.Thread( + target=self._multiple_retrieve_thread, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "available_datasets": available_datasets, + "metadata_condition": metadata_condition, + "metadata_filter_document_ids": metadata_filter_document_ids, + "all_documents": all_documents, + 
"tenant_id": tenant_id, + "reranking_enable": reranking_enable, + "reranking_mode": reranking_mode, + "reranking_model": reranking_model, + "weights": weights, + "top_k": top_k, + "score_threshold": score_threshold, + "query": None, + "attachment_id": attachment_id, + }, + ) + all_threads.append(attachment_thread) + attachment_thread.start() + for thread in all_threads: + thread.join() + self._on_query(query, attachment_ids, dataset_ids, app_id, user_from, user_id) if all_documents: - self._on_retrieval_end(all_documents, message_id, timer) - - return all_documents - - def _on_retrieval_end(self, documents: list[Document], message_id: str | None = None, timer: dict | None = None): - """Handle retrieval end.""" - dify_documents = [document for document in documents if document.provider == "dify"] - for document in dify_documents: - if document.metadata is not None: - dataset_document_stmt = select(DatasetDocument).where( - DatasetDocument.id == document.metadata["document_id"] - ) - dataset_document = db.session.scalar(dataset_document_stmt) - if dataset_document: - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = db.session.scalar(child_chunk_stmt) - if child_chunk: - _ = ( - db.session.query(DocumentSegment) - .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, - synchronize_session=False, - ) - ) - else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) - - # if 'dataset_id' in document.metadata: - if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - - # add hit count to document segment - query.update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) - - db.session.commit() - - # get tracing instance - trace_manager: TraceQueueManager | None = ( - self.application_generate_entity.trace_manager if self.application_generate_entity else None - ) - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer - ) + # add thread to call _on_retrieval_end + retrieval_end_thread = threading.Thread( + target=self._on_retrieval_end, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "documents": all_documents, + "message_id": message_id, + "timer": timer, + }, ) + retrieval_end_thread.start() + retrieval_resource_list = [] + doc_ids_filter = [] + for document in all_documents: + if document.provider == "dify": + doc_id = document.metadata.get("doc_id") + if doc_id and doc_id not in doc_ids_filter: + doc_ids_filter.append(doc_id) + retrieval_resource_list.append(document) + elif document.provider == "external": + retrieval_resource_list.append(document) + return retrieval_resource_list - def _on_query(self, query: str, dataset_ids: list[str], app_id: str, user_from: str, user_id: str): + def _on_retrieval_end( + self, flask_app: Flask, documents: list[Document], message_id: str | None = None, timer: dict | None = None + ): + """Handle retrieval end.""" + with flask_app.app_context(): + dify_documents = [document for document in documents if document.provider == "dify"] + segment_ids 
= [] + segment_index_node_ids = [] + with Session(db.engine) as session: + for document in dify_documents: + if document.metadata is not None: + dataset_document_stmt = select(DatasetDocument).where( + DatasetDocument.id == document.metadata["document_id"] + ) + dataset_document = session.scalar(dataset_document_stmt) + if dataset_document: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + segment_id = None + if ( + "doc_type" not in document.metadata + or document.metadata.get("doc_type") == DocType.TEXT + ): + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + child_chunk = session.scalar(child_chunk_stmt) + if child_chunk: + segment_id = child_chunk.segment_id + elif ( + "doc_type" in document.metadata + and document.metadata.get("doc_type") == DocType.IMAGE + ): + attachment_info_dict = RetrievalService.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + segment_id = attachment_info_dict["segment_id"] + if segment_id: + if segment_id not in segment_ids: + segment_ids.append(segment_id) + _ = ( + session.query(DocumentSegment) + .where(DocumentSegment.id == segment_id) + .update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + ) + else: + query = None + if ( + "doc_type" not in document.metadata + or document.metadata.get("doc_type") == DocType.TEXT + ): + if document.metadata["doc_id"] not in segment_index_node_ids: + segment = ( + session.query(DocumentSegment) + .where(DocumentSegment.index_node_id == document.metadata["doc_id"]) + .first() + ) + if segment: + segment_index_node_ids.append(document.metadata["doc_id"]) + segment_ids.append(segment.id) + query = session.query(DocumentSegment).where( + DocumentSegment.id == segment.id + ) + elif ( + "doc_type" in document.metadata + and document.metadata.get("doc_type") == DocType.IMAGE + ): + attachment_info_dict = RetrievalService.get_segment_attachment_info( + dataset_document.dataset_id, + dataset_document.tenant_id, + document.metadata.get("doc_id") or "", + session, + ) + if attachment_info_dict: + segment_id = attachment_info_dict["segment_id"] + if segment_id not in segment_ids: + segment_ids.append(segment_id) + query = session.query(DocumentSegment).where(DocumentSegment.id == segment_id) + if query: + if "dataset_id" in document.metadata: + query = query.where( + DocumentSegment.dataset_id == document.metadata["dataset_id"] + ) + + # add hit count to document segment + query.update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) + + session.commit() + + # get tracing instance + trace_manager: TraceQueueManager | None = ( + self.application_generate_entity.trace_manager if self.application_generate_entity else None + ) + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer + ) + ) + + def _on_query( + self, + query: str | None, + attachment_ids: list[str] | None, + dataset_ids: list[str], + app_id: str, + user_from: str, + user_id: str, + ): """ Handle query.
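Stores both text queries and image-attachment queries as a JSON list of typed contents per dataset.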
""" - if not query: + if not query and not attachment_ids: return dataset_queries = [] for dataset_id in dataset_ids: - dataset_query = DatasetQuery( - dataset_id=dataset_id, - content=query, - source="app", - source_app_id=app_id, - created_by_role=user_from, - created_by=user_id, - ) - dataset_queries.append(dataset_query) - if dataset_queries: - db.session.add_all(dataset_queries) + contents = [] + if query: + contents.append({"content_type": QueryType.TEXT_QUERY, "content": query}) + if attachment_ids: + for attachment_id in attachment_ids: + contents.append({"content_type": QueryType.IMAGE_QUERY, "content": attachment_id}) + if contents: + dataset_query = DatasetQuery( + dataset_id=dataset_id, + content=json.dumps(contents), + source="app", + source_app_id=app_id, + created_by_role=user_from, + created_by=user_id, + ) + dataset_queries.append(dataset_query) + if dataset_queries: + db.session.add_all(dataset_queries) db.session.commit() def _retriever( @@ -603,6 +743,7 @@ class DatasetRetrieval: all_documents: list, document_ids_filter: list[str] | None = None, metadata_condition: MetadataCondition | None = None, + attachment_ids: list[str] | None = None, ): with flask_app.app_context(): dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) @@ -611,7 +752,7 @@ class DatasetRetrieval: if not dataset: return [] - if dataset.provider == "external": + if dataset.provider == "external" and query: external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset_id, @@ -663,6 +804,7 @@ class DatasetRetrieval: reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), document_ids_filter=document_ids_filter, + attachment_ids=attachment_ids, ) all_documents.extend(documents) @@ -1222,3 +1364,86 @@ class DatasetRetrieval: usage = LLMUsage.empty_usage() return full_text, usage + + def _multiple_retrieve_thread( + self, + flask_app: Flask, + available_datasets: list, + metadata_condition: MetadataCondition | None, + metadata_filter_document_ids: dict[str, list[str]] | None, + all_documents: list[Document], + tenant_id: str, + reranking_enable: bool, + reranking_mode: str, + reranking_model: dict | None, + weights: dict[str, Any] | None, + top_k: int, + score_threshold: float, + query: str | None, + attachment_id: str | None, + ): + with flask_app.app_context(): + threads = [] + all_documents_item: list[Document] = [] + index_type = None + for dataset in available_datasets: + index_type = dataset.indexing_technique + document_ids_filter = None + if dataset.provider != "external": + if metadata_condition and not metadata_filter_document_ids: + continue + if metadata_filter_document_ids: + document_ids = metadata_filter_document_ids.get(dataset.id, []) + if document_ids: + document_ids_filter = document_ids + else: + continue + retrieval_thread = threading.Thread( + target=self._retriever, + kwargs={ + "flask_app": flask_app, + "dataset_id": dataset.id, + "query": query, + "top_k": top_k, + "all_documents": all_documents_item, + "document_ids_filter": document_ids_filter, + "metadata_condition": metadata_condition, + "attachment_ids": [attachment_id] if attachment_id else None, + }, + ) + threads.append(retrieval_thread) + retrieval_thread.start() + for thread in threads: + thread.join() + + if reranking_enable: + # do rerank for searched documents + data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False) + if query: + 
all_documents_item = data_post_processor.invoke( + query=query, + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.TEXT_QUERY, + ) + if attachment_id: + all_documents_item = data_post_processor.invoke( + documents=all_documents_item, + score_threshold=score_threshold, + top_n=top_k, + query_type=QueryType.IMAGE_QUERY, + query=attachment_id, + ) + else: + if index_type == IndexTechniqueType.ECONOMY: + if not query: + all_documents_item = [] + else: + all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k) + elif index_type == IndexTechniqueType.HIGH_QUALITY: + all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold) + else: + all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item + if all_documents_item: + all_documents.extend(all_documents_item) diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json new file mode 100644 index 0000000000..1a07869662 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_general_structure.json @@ -0,0 +1,65 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_general_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "object", + "title": "Multimodal General Structure", + "description": "Schema for multimodal general structure (v1)", + "properties": { + "general_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "description": "List of files" + } + } + }, + "required": ["content"] + }, + "description": "List of content and files" + } + } +} \ No newline at end of file diff --git a/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json new file mode 100644 index 0000000000..4ffb590519 --- /dev/null +++ b/api/core/schemas/builtin/schemas/v1/multimodal_parent_child_structure.json @@ -0,0 +1,78 @@ +{ + "$id": "https://dify.ai/schemas/v1/multimodal_parent_child_structure.json", + "$schema": "http://json-schema.org/draft-07/schema#", + "version": "1.0.0", + "type": "object", + "title": "Multimodal Parent-Child Structure", + "description": "Schema for multimodal parent-child structure (v1)", + "properties": { + "parent_mode": { + "type": "string", + "description": "The mode of parent-child relationship" + }, + "parent_child_chunks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "parent_content": { + "type": "string", + "description": "The parent content" + }, + "files": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { + "type": 
"string", + "description": "file name" + }, + "size": { + "type": "number", + "description": "file size" + }, + "extension": { + "type": "string", + "description": "file extension" + }, + "type": { + "type": "string", + "description": "file type" + }, + "mime_type": { + "type": "string", + "description": "file mime type" + }, + "transfer_method": { + "type": "string", + "description": "file transfer method" + }, + "url": { + "type": "string", + "description": "file url" + }, + "related_id": { + "type": "string", + "description": "file related id" + } + }, + "required": ["name", "size", "extension", "type", "mime_type", "transfer_method", "url", "related_id"] + }, + "description": "List of files" + }, + "child_contents": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of child contents" + } + }, + "required": ["parent_content", "child_contents"] + }, + "description": "List of parent-child chunk pairs" + } + }, + "required": ["parent_mode", "parent_child_chunks"] +} \ No newline at end of file diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 5cdf473542..fef3157f27 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -25,6 +25,24 @@ def sign_tool_file(tool_file_id: str, extension: str) -> str: return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" +def sign_upload_file(upload_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url for plugin access + """ + # Use internal URL for plugin/tool file access in Docker environments + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + file_preview_url = f"{base_url}/files/{upload_file_id}/image-preview" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: """ verify signature diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py index 105823f896..80c69e94c8 100644 --- a/api/core/tools/utils/text_processing_utils.py +++ b/api/core/tools/utils/text_processing_utils.py @@ -13,5 +13,5 @@ def remove_leading_symbols(text: str) -> str: """ # Match Unicode ranges for punctuation and symbols # FIXME this pattern is confused quick fix for #11868 maybe refactor it later - pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+" + pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F\"#$%&'()*+,./:;<=>?@^_`~]+" return re.sub(pattern, "", text) diff --git a/api/core/workflow/node_events/node.py b/api/core/workflow/node_events/node.py index ebf93f2fc2..e4fa52f444 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/core/workflow/node_events/node.py @@ -3,6 +3,7 @@ from datetime import datetime from pydantic import Field +from core.file import File from core.model_runtime.entities.llm_entities import LLMUsage from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.pause_reason import PauseReason @@ -14,6 +15,7 @@ from .base import NodeEventBase class RunRetrieverResourceEvent(NodeEventBase): 
retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") + context_files: list[File] | None = Field(default=None, description="context files") class ModelInvokeCompletedEvent(NodeEventBase): diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 8aa6a5016f..86bb2495e7 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -114,7 +114,8 @@ class KnowledgeRetrievalNodeData(BaseNodeData): """ type: str = "knowledge-retrieval" - query_variable_selector: list[str] + query_variable_selector: list[str] | None | str = None + query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] retrieval_mode: Literal["single", "multiple"] multiple_retrieval_config: MultipleRetrievalConfig | None = None diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 1b57d23e24..adc474bd60 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -25,6 +25,8 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import ( + ArrayFileSegment, + FileSegment, StringSegment, ) from core.variables.segments import ArrayObjectSegment @@ -119,20 +121,41 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return "1" def _run(self) -> NodeRunResult: - # extract variables - variable = self.graph_runtime_state.variable_pool.get(self.node_data.query_variable_selector) - if not isinstance(variable, StringSegment): + if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector: return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, - error="Query variable is not string type.", - ) - query = variable.value - variables = {"query": query} - if not query: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." + process_data={}, + outputs={}, + metadata={}, + llm_usage=LLMUsage.empty_usage(), ) + variables: dict[str, Any] = {} + # extract variables + if self._node_data.query_variable_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) + if not isinstance(variable, StringSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Query variable is not string type.", + ) + query = variable.value + variables["query"] = query + + if self._node_data.query_attachment_selector: + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_attachment_selector) + if not isinstance(variable, ArrayFileSegment) and not isinstance(variable, FileSegment): + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs={}, + error="Attachments variable is not array file or file type.", + ) + if isinstance(variable, ArrayFileSegment): + variables["attachments"] = variable.value + else: + variables["attachments"] = [variable.value] + # TODO(-LAN-): Move this check outside. 
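+ # at this point "variables" holds the text query and/or the attachment files used for retrieval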
# check rate limit knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) @@ -161,7 +184,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD # retrieve knowledge usage = LLMUsage.empty_usage() try: - results, usage = self._fetch_dataset_retriever(node_data=self.node_data, query=query) + results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -198,12 +221,16 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD db.session.close() def _fetch_dataset_retriever( - self, node_data: KnowledgeRetrievalNodeData, query: str + self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[dict[str, Any]], LLMUsage]: usage = LLMUsage.empty_usage() available_datasets = [] dataset_ids = node_data.dataset_ids - + query = variables.get("query") + attachments = variables.get("attachments") + metadata_filter_document_ids = None + metadata_condition = None + metadata_usage = LLMUsage.empty_usage() # Subquery: Count the number of available documents for each dataset subquery = ( db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) @@ -234,13 +261,14 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if not dataset: continue available_datasets.append(dataset) - metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( - [dataset.id for dataset in available_datasets], query, node_data - ) - usage = self._merge_usage(usage, metadata_usage) + if query: + metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( + [dataset.id for dataset in available_datasets], query, node_data + ) + usage = self._merge_usage(usage, metadata_usage) all_documents = [] dataset_retrieval = DatasetRetrieval() - if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: raise ValueError("single_retrieval_config is required") @@ -272,7 +300,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, ) - elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: + elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": @@ -319,6 +347,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_filter_document_ids=metadata_filter_document_ids, metadata_condition=metadata_condition, + attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) usage = self._merge_usage(usage, dataset_retrieval.llm_usage) @@ -327,7 +356,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = [] # deal with external documents for item in external_documents: - source = { 
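+ # explicit annotation: the source payload mixes nested dicts, optional strings, and None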
+ source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = { "metadata": { "_source": "knowledge", "dataset_id": item.metadata.get("dataset_id"), @@ -384,6 +413,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD "doc_metadata": document.doc_metadata, }, "title": document.name, + "files": list(record.files) if record.files else None, } if segment.answer: source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" @@ -393,13 +423,21 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, + key=self._score, # type: ignore[arg-type, return-value] reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): - item["metadata"]["position"] = position + item["metadata"]["position"] = position # type: ignore[index] return retrieval_resource_list, usage + def _score(self, item: dict[str, Any]) -> float: + meta = item.get("metadata") + if isinstance(meta, dict): + s = meta.get("score") + if isinstance(s, (int, float)): + return float(s) + return 0.0 + def _get_metadata_filter_condition( self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: @@ -659,7 +697,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) variable_mapping = {} - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector + if typed_node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1a2473e0bb..10682ae38a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,8 +7,10 @@ import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal +from sqlalchemy import select + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.file import FileType, file_manager +from core.file import File, FileTransferMethod, FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -44,6 +46,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.tools.signature import sign_upload_file from core.variables import ( ArrayFileSegment, ArraySegment, @@ -72,6 +75,9 @@ from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from 
core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.runtime import VariablePool +from extensions.ext_database import db +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile from . import llm_utils from .entities import ( @@ -179,12 +185,17 @@ class LLMNode(Node[LLMNodeData]): # fetch context value generator = self._fetch_context(node_data=self.node_data) context = None + context_files: list[File] = [] for event in generator: context = event.context + context_files = event.context_files or [] yield event if context: node_inputs["#context#"] = context + if context_files: + node_inputs["#context_files#"] = [file.model_dump() for file in context_files] + # fetch model config model_instance, model_config = LLMNode._fetch_model_config( node_data_model=self.node_data.model, @@ -220,6 +231,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, tenant_id=self.tenant_id, + context_files=context_files, ) # handle invoke result @@ -654,10 +666,13 @@ class LLMNode(Node[LLMNodeData]): context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) if context_value_variable: if isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) + yield RunRetrieverResourceEvent( + retriever_resources=[], context=context_value_variable.value, context_files=[] + ) elif isinstance(context_value_variable, ArraySegment): context_str = "" original_retriever_resource: list[RetrievalSourceMetadata] = [] + context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -670,9 +685,34 @@ class LLMNode(Node[LLMNodeData]): retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) - + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, + ) + ).all() + if attachments_with_bindings: + for _, upload_file in attachments_with_bindings: + attachment_info = File( + id=upload_file.id, + filename=upload_file.name, + extension="."
+ upload_file.extension, + mime_type=upload_file.mime_type, + tenant_id=self.tenant_id, + type=FileType.IMAGE, + transfer_method=FileTransferMethod.LOCAL_FILE, + remote_url=upload_file.source_url, + related_id=upload_file.id, + size=upload_file.size, + storage_key=upload_file.key, + url=sign_upload_file(upload_file.id, upload_file.extension), + ) + context_files.append(attachment_info) yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, context=context_str.strip() + retriever_resources=original_retriever_resource, + context=context_str.strip(), + context_files=context_files, ) def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: @@ -700,6 +740,7 @@ class LLMNode(Node[LLMNodeData]): content=context_dict.get("content"), page=metadata.get("page"), doc_metadata=metadata.get("doc_metadata"), + files=context_dict.get("files"), ) return source @@ -741,6 +782,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], tenant_id: str, + context_files: list["File"] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -853,6 +895,23 @@ class LLMNode(Node[LLMNodeData]): else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + # Attach context files (images bound to retrieved segments) as additional vision inputs + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + # Remove empty messages and filter unsupported content filtered_prompt_messages = [] for prompt_message in prompt_messages: diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index 89c4d8fba9..1e5ec7d200 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -97,11 +97,27 @@ dataset_detail_fields = { "total_documents": fields.Integer, "total_available_documents": fields.Integer, "enable_api": fields.Boolean, + "is_multimodal": fields.Boolean, +} + +file_info_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + +content_fields = { + "content_type": fields.String, + "content": fields.String, + "file_info": fields.Nested(file_info_fields, allow_null=True), } dataset_query_detail_fields = { "id": fields.String, - "content": fields.String, + "queries": fields.Nested(content_fields), "source": fields.String, "source_app_id": fields.String, "created_by_role": fields.String, diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index c12ebc09c8..a707500445 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -9,6 +9,8 @@ upload_config_fields = { "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, "workflow_file_upload_limit": fields.Integer, + "image_file_batch_limit": fields.Integer, + "single_chunk_attachment_limit": fields.Integer, } diff --git a/api/fields/hit_testing_fields.py
b/api/fields/hit_testing_fields.py index 75bdff1803..e70f9fa722 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -43,9 +43,19 @@ child_chunk_fields = { "score": fields.Float, } +files_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, + "files": fields.List(fields.Nested(files_fields)), } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 2ff917d6bc..56d6b68378 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -13,6 +13,15 @@ child_chunk_fields = { "updated_at": TimestampField, } +attachment_fields = { + "id": fields.String, + "name": fields.String, + "size": fields.Integer, + "extension": fields.String, + "mime_type": fields.String, + "source_url": fields.String, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -39,4 +48,5 @@ segment_fields = { "error": fields.String, "stopped_at": TimestampField, "child_chunks": fields.List(fields.Nested(child_chunk_fields)), + "attachments": fields.List(fields.Nested(attachment_fields)), } diff --git a/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py new file mode 100644 index 0000000000..187bf7136d --- /dev/null +++ b/api/migrations/versions/2025_11_12_1537-d57accd375ae_support_multi_modal.py @@ -0,0 +1,57 @@ +"""support-multi-modal + +Revision ID: d57accd375ae +Revises: 7bb281b7a422 +Create Date: 2025-11-12 15:37:12.363670 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd57accd375ae' +down_revision = '7bb281b7a422' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('segment_attachment_bindings', + sa.Column('id', models.types.StringUUID(), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('attachment_id', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.PrimaryKeyConstraint('id', name='segment_attachment_binding_pkey') + ) + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.create_index( + 'segment_attachment_binding_tenant_dataset_document_segment_idx', + ['tenant_id', 'dataset_id', 'document_id', 'segment_id'], + unique=False + ) + batch_op.create_index('segment_attachment_binding_attachment_idx', ['attachment_id'], unique=False) + + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.add_column(sa.Column('is_multimodal', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please + with op.batch_alter_table('datasets', schema=None) as batch_op: + batch_op.drop_column('is_multimodal') + + + with op.batch_alter_table('segment_attachment_bindings', schema=None) as batch_op: + batch_op.drop_index('segment_attachment_binding_attachment_idx') + batch_op.drop_index('segment_attachment_binding_tenant_dataset_document_segment_idx') + + op.drop_table('segment_attachment_bindings') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index e072711b82..5bbf44050c 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -19,7 +19,9 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.tools.signature import sign_upload_file from extensions.ext_storage import storage from libs.uuid_utils import uuidv7 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule @@ -76,6 +78,7 @@ class Dataset(Base): pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) + is_multimodal = mapped_column(sa.Boolean, nullable=False, server_default=db.text("false")) @property def total_documents(self): @@ -728,9 +731,7 @@ class DocumentSegment(Base): created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) - updated_at: Mapped[datetime] = mapped_column( - DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() - ) + updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) completed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) error = mapped_column(LongText, nullable=True) @@ -866,6 
+867,47 @@ class DocumentSegment(Base): return text + @property + def attachments(self) -> list[dict[str, Any]]: + # Use JOIN to fetch attachments in a single query instead of two separate queries + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == self.tenant_id, + SegmentAttachmentBinding.dataset_id == self.dataset_id, + SegmentAttachmentBinding.document_id == self.document_id, + SegmentAttachmentBinding.segment_id == self.id, + ) + ).all() + if not attachments_with_bindings: + return [] + attachment_list = [] + for _, attachment in attachments_with_bindings: + upload_file_id = attachment.id + nonce = os.urandom(16).hex() + timestamp = str(int(time.time())) + data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + reference_url = dify_config.CONSOLE_API_URL or "" + base_url = f"{reference_url}/files/{upload_file_id}/image-preview" + source_url = f"{base_url}?{params}" + attachment_list.append( + { + "id": attachment.id, + "name": attachment.name, + "size": attachment.size, + "extension": attachment.extension, + "mime_type": attachment.mime_type, + "source_url": source_url, + } + ) + return attachment_list + class ChildChunk(Base): __tablename__ = "child_chunks" @@ -963,6 +1005,38 @@ class DatasetQuery(TypeBase): DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False ) + @property + def queries(self) -> list[dict[str, Any]]: + try: + queries = json.loads(self.content) + if isinstance(queries, list): + for query in queries: + if query["content_type"] == QueryType.IMAGE_QUERY: + file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + if file_info: + query["file_info"] = { + "id": file_info.id, + "name": file_info.name, + "size": file_info.size, + "extension": file_info.extension, + "mime_type": file_info.mime_type, + "source_url": sign_upload_file(file_info.id, file_info.extension), + } + else: + query["file_info"] = None + + return queries + else: + return [queries] + except JSONDecodeError: + return [ + { + "content_type": QueryType.TEXT_QUERY, + "content": self.content, + "file_info": None, + } + ] + class DatasetKeywordTable(TypeBase): __tablename__ = "dataset_keyword_tables" @@ -1470,3 +1544,25 @@ class PipelineRecommendedPlugin(TypeBase): onupdate=func.current_timestamp(), init=False, ) + + +class SegmentAttachmentBinding(Base): + __tablename__ = "segment_attachment_bindings" + __table_args__ = ( + sa.PrimaryKeyConstraint("id", name="segment_attachment_binding_pkey"), + sa.Index( + "segment_attachment_binding_tenant_dataset_document_segment_idx", + "tenant_id", + "dataset_id", + "document_id", + "segment_id", + ), + sa.Index("segment_attachment_binding_attachment_idx", "attachment_id"), + ) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7())) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + document_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + attachment_id: 
Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/services/attachment_service.py b/api/services/attachment_service.py new file mode 100644 index 0000000000..2bd5627d5e --- /dev/null +++ b/api/services/attachment_service.py @@ -0,0 +1,31 @@ +import base64 + +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +from extensions.ext_storage import storage +from models.model import UploadFile + +PREVIEW_WORDS_LIMIT = 3000 + + +class AttachmentService: + _session_maker: sessionmaker + + def __init__(self, session_factory: sessionmaker | Engine | None = None): + if isinstance(session_factory, Engine): + self._session_maker = sessionmaker(bind=session_factory) + elif isinstance(session_factory, sessionmaker): + self._session_maker = session_factory + else: + raise AssertionError("must be a sessionmaker or an Engine.") + + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index bb09311349..00f06e9405 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal +from typing import Any, Literal, cast import sqlalchemy as sa from redis.exceptions import LockNotOwnedError @@ -19,9 +19,10 @@ from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted @@ -46,6 +47,7 @@ from models.dataset import ( DocumentSegment, ExternalKnowledgeBindings, Pipeline, + SegmentAttachmentBinding, ) from models.model import UploadFile from models.provider_ids import ModelProviderID @@ -363,6 +365,27 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) + @staticmethod + def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): + try: + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=model, + ) + text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance) + model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials) + if not model_schema: + raise ValueError("Model schema not found") + if 
model_schema.features and ModelFeature.VISION in model_schema.features: + return True + else: + return False + except LLMBadRequestError: + raise ValueError("No Model available. Please configure a valid provider in the Settings -> Model Provider.") + @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: @@ -402,13 +425,13 @@ class DatasetService: if not dataset: raise ValueError("Dataset not found") # check if dataset name is exists - - if DatasetService._has_dataset_same_name( - tenant_id=dataset.tenant_id, - dataset_id=dataset_id, - name=data.get("name", dataset.name), - ): - raise ValueError("Dataset name already exists") + if data.get("name") and data.get("name") != dataset.name: + if DatasetService._has_dataset_same_name( + tenant_id=dataset.tenant_id, + dataset_id=dataset_id, + name=data.get("name", dataset.name), + ): + raise ValueError("Dataset name already exists") # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) @@ -844,6 +867,12 @@ class DatasetService: model_type=ModelType.TEXT_EMBEDDING, model=knowledge_configuration.embedding_model or "", ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -880,6 +909,12 @@ class DatasetService: dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model.model ) + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal dataset.collection_binding_id = dataset_collection_binding.id dataset.indexing_technique = knowledge_configuration.indexing_technique except LLMBadRequestError: @@ -937,6 +972,12 @@ class DatasetService: ) ) dataset.collection_binding_id = dataset_collection_binding.id + is_multimodal = DatasetService.check_is_multimodal_model( + current_user.current_tenant_id, + knowledge_configuration.embedding_model_provider, + knowledge_configuration.embedding_model, + ) + dataset.is_multimodal = is_multimodal except LLMBadRequestError: raise ValueError( "No Embedding Model available. 
Please configure a valid provider " @@ -2305,6 +2346,7 @@ class DocumentService: embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, retrieval_model=retrieval_model.model_dump() if retrieval_model else None, + is_multimodal=knowledge_config.is_multimodal, ) db.session.add(dataset) @@ -2685,6 +2727,13 @@ class SegmentService: if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") + if args.get("attachment_ids"): + if not isinstance(args["attachment_ids"], list): + raise ValueError("attachment_ids must be a list") + single_chunk_attachment_limit = dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT + if len(args["attachment_ids"]) > single_chunk_attachment_limit: + raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") + @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): assert isinstance(current_user, Account) @@ -2731,11 +2780,23 @@ class SegmentService: segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] - db.session.add(segment_document) - # update document word count - assert document.word_count is not None - document.word_count += segment_document.word_count - db.session.add(document) + db.session.add(segment_document) + # update document word count + assert document.word_count is not None + document.word_count += segment_document.word_count + db.session.add(document) + db.session.commit() + + if args.get("attachment_ids"): + for attachment_id in args["attachment_ids"]: + binding = SegmentAttachmentBinding( + tenant_id=current_user.current_tenant_id, + dataset_id=document.dataset_id, + document_id=document.id, + segment_id=segment_document.id, + attachment_id=attachment_id, + ) + db.session.add(binding) db.session.commit() # save vector index @@ -2899,7 +2960,7 @@ class SegmentService: document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks # get embedding model instance if dataset.indexing_technique == "high_quality": @@ -2926,12 +2987,11 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): if args.enabled or keyword_changed: # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) @@ -2976,7 +3036,7 @@ class SegmentService: db.session.add(document) db.session.add(segment) db.session.commit() - if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance if dataset.indexing_technique == "high_quality": # check embedding model setting @@ -3002,15 
+3062,15 @@ class SegmentService: .where(DatasetProcessRule.id == document.dataset_process_rule_id) .first() ) - if not processing_rule: - raise ValueError("No processing rule found.") - VectorService.generate_child_chunks( - segment, document, dataset, embedding_model_instance, processing_rule, True - ) - elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + if processing_rule: + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX): # update segment vector index VectorService.update_segment_vector(args.keywords, segment, dataset) - + # update multimodel vector index + VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset) except Exception as e: logger.exception("update segment index failed") segment.enabled = False @@ -3048,7 +3108,9 @@ class SegmentService: ) child_node_ids = [chunk[0] for chunk in child_chunks if chunk[0]] - delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids + ) db.session.delete(segment) # update document word count @@ -3097,7 +3159,9 @@ class SegmentService: # Start async cleanup with both parent and child node IDs if index_node_ids or child_node_ids: - delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids) + delete_segment_from_index_task.delay( + index_node_ids, dataset.id, document.id, segment_db_ids, child_node_ids + ) if document.word_count is None: document.word_count = 0 diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 131e90e195..7959734e89 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -124,6 +124,14 @@ class KnowledgeConfig(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None name: str | None = None + is_multimodal: bool = False + + +class SegmentCreateArgs(BaseModel): + content: str | None = None + answer: str | None = None + keywords: list[str] | None = None + attachment_ids: list[str] | None = None class SegmentUpdateArgs(BaseModel): @@ -132,6 +140,7 @@ class SegmentUpdateArgs(BaseModel): keywords: list[str] | None = None regenerate_child_chunks: bool = False enabled: bool | None = None + attachment_ids: list[str] | None = None class ChildChunkUpdateArgs(BaseModel): diff --git a/api/services/file_service.py b/api/services/file_service.py index 1980cd8d59..0911cf38c4 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -1,3 +1,4 @@ +import base64 import hashlib import os import uuid @@ -123,6 +124,15 @@ class FileService: return file_size <= file_size_limit + def get_file_base64(self, file_id: str) -> str: + upload_file = ( + self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + ) + if not upload_file: + raise NotFound("File not found") + blob = storage.load_once(upload_file.key) + return base64.b64encode(blob).decode() + def upload_text(self, text: str, text_name: str, user_id: str, tenant_id: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] diff --git a/api/services/hit_testing_service.py 
b/api/services/hit_testing_service.py index dfb49cf2bd..8e8e78f83f 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -1,3 +1,4 @@ +import json import logging import time from typing import Any @@ -5,6 +6,7 @@ from typing import Any from core.app.app_config.entities import ModelConfig from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -32,6 +34,7 @@ class HitTestingService: account: Account, retrieval_model: Any, # FIXME drop this any external_retrieval_model: dict, + attachment_ids: list | None = None, limit: int = 10, ): start = time.perf_counter() @@ -41,7 +44,7 @@ class HitTestingService: retrieval_model = dataset.retrieval_model or default_retrieval_model document_ids_filter = None metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) - if metadata_filtering_conditions: + if metadata_filtering_conditions and query: dataset_retrieval = DatasetRetrieval() from core.app.app_config.entities import MetadataFilteringCondition @@ -66,6 +69,7 @@ class HitTestingService: retrieval_method=RetrievalMethod(retrieval_model.get("search_method", RetrievalMethod.SEMANTIC_SEARCH)), dataset_id=dataset.id, query=query, + attachment_ids=attachment_ids, top_k=retrieval_model.get("top_k", 4), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] @@ -80,17 +84,24 @@ class HitTestingService: end = time.perf_counter() logger.debug("Hit testing retrieve in %s seconds", end - start) - - dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source="hit_testing", - source_app_id=None, - created_by_role="account", - created_by=account.id, - ) - - db.session.add(dataset_query) + dataset_queries = [] + if query: + content = {"content_type": QueryType.TEXT_QUERY, "content": query} + dataset_queries.append(content) + if attachment_ids: + for attachment_id in attachment_ids: + content = {"content_type": QueryType.IMAGE_QUERY, "content": attachment_id} + dataset_queries.append(content) + if dataset_queries: + dataset_query = DatasetQuery( + dataset_id=dataset.id, + content=json.dumps(dataset_queries), + source="hit_testing", + source_app_id=None, + created_by_role="account", + created_by=account.id, + ) + db.session.add(dataset_query) db.session.commit() return cls.compact_retrieve_response(query, all_documents) @@ -168,9 +179,14 @@ class HitTestingService: @classmethod def hit_testing_args_check(cls, args): query = args["query"] + attachment_ids = args["attachment_ids"] - if not query or len(query) > 250: - raise ValueError("Query is required and cannot exceed 250 characters") + if not attachment_ids and not query: + raise ValueError("Query or attachment_ids is required") + if query and len(query) > 250: + raise ValueError("Query cannot exceed 250 characters") + if attachment_ids and not isinstance(attachment_ids, list): + raise ValueError("Attachment_ids must be a list") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/vector_service.py b/api/services/vector_service.py index abc92a0181..f1fa33cb75 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -4,11 +4,14 @@ from core.model_manager import 
ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import AttachmentDocument, Document from extensions.ext_database import db -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models import UploadFile +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode @@ -21,9 +24,10 @@ class VectorService: cls, keywords_list: list[list[str]] | None, segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents: list[Document] = [] + multimodal_documents: list[AttachmentDocument] = [] for segment in segments: - if doc_form == IndexType.PARENT_CHILD_INDEX: + if doc_form == IndexStructureType.PARENT_CHILD_INDEX: dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() if not dataset_document: logger.warning( @@ -70,12 +74,29 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) documents.append(rag_document) + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_document: AttachmentDocument = AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + multimodal_documents.append(multimodal_document) + index_processor: BaseIndexProcessor = IndexProcessorFactory(doc_form).init_index_processor() + if len(documents) > 0: - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) + index_processor.load(dataset, documents, None, with_keywords=True, keywords_list=keywords_list) + if len(multimodal_documents) > 0: + index_processor.load(dataset, [], multimodal_documents, with_keywords=False) @classmethod def update_segment_vector(cls, keywords: list[str] | None, segment: DocumentSegment, dataset: Dataset): @@ -130,6 +151,7 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, + "doc_type": DocType.TEXT, }, ) # use full doc mode to generate segment's child chunk @@ -226,3 +248,92 @@ class VectorService: def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): vector = Vector(dataset=dataset) vector.delete_by_ids([child_chunk.index_node_id]) + + @classmethod + def update_multimodel_vector(cls, segment: DocumentSegment, attachment_ids: list[str], dataset: Dataset): + if dataset.indexing_technique != "high_quality": + return + + attachments = segment.attachments + old_attachment_ids = [attachment["id"] for attachment in attachments] 
if attachments else [] + + # Check if there's any actual change needed + if set(attachment_ids) == set(old_attachment_ids): + return + + try: + vector = Vector(dataset=dataset) + if dataset.is_multimodal: + # Delete old vectors if they exist + if old_attachment_ids: + vector.delete_by_ids(old_attachment_ids) + + # Delete existing segment attachment bindings in one operation + db.session.query(SegmentAttachmentBinding).where(SegmentAttachmentBinding.segment_id == segment.id).delete( + synchronize_session=False + ) + + if not attachment_ids: + db.session.commit() + return + + # Bulk fetch upload files in a single query + upload_file_list = db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() + + if not upload_file_list: + db.session.commit() + return + + # Create a mapping for quick lookup + upload_file_map = {upload_file.id: upload_file for upload_file in upload_file_list} + + # Prepare batch operations + bindings = [] + documents = [] + + # Create common metadata base to avoid repetition + base_metadata = { + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + } + + # Process attachments in the order specified by attachment_ids + for attachment_id in attachment_ids: + upload_file = upload_file_map.get(attachment_id) + if not upload_file: + logger.warning("Upload file not found for attachment_id: %s", attachment_id) + continue + + # Create segment attachment binding + bindings.append( + SegmentAttachmentBinding( + tenant_id=segment.tenant_id, + dataset_id=segment.dataset_id, + document_id=segment.document_id, + segment_id=segment.id, + attachment_id=upload_file.id, + ) + ) + + # Create document for vector indexing + documents.append( + Document(page_content=upload_file.name, metadata={**base_metadata, "doc_id": upload_file.id}) + ) + + # Bulk insert all bindings at once + if bindings: + db.session.add_all(bindings) + + # Add documents to vector store if any + if documents and dataset.is_multimodal: + vector.add_texts(documents, duplicate_check=True) + + # Single commit for all operations + db.session.commit() + + except Exception: + logger.exception("Failed to update multimodal vector for segment %s", segment.id) + db.session.rollback() + raise diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 933ad6b9e2..e7dead8a56 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -55,6 +56,7 @@ def add_document_to_index_task(dataset_document_id: str): ) documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -65,7 +67,7 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form
== IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -81,11 +83,25 @@ def add_document_to_index_task(dataset_document_id: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) index_type = dataset.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) # delete auto disable log db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 5f2a355d16..8608df6b8e 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -18,6 +18,7 @@ from models.dataset import ( DatasetQuery, Document, DocumentSegment, + SegmentAttachmentBinding, ) from models.model import UploadFile @@ -58,14 +59,20 @@ def clean_dataset_task( ) documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all() segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id) + ).all() # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace # This ensures all invalid doc_form values are properly handled if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()): # Use default paragraph index type for empty/invalid datasets to enable vector database cleanup - from core.rag.index_processor.constant.index_type import IndexType + from core.rag.index_processor.constant.index_type import IndexStructureType - doc_form = IndexType.PARAGRAPH_INDEX + doc_form = IndexStructureType.PARAGRAPH_INDEX logger.info( click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow") ) @@ -90,6 +97,7 @@ def clean_dataset_task( for document in documents: db.session.delete(document) + # delete image files referenced in segment content for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) @@ -107,6 +115,19 @@ def clean_dataset_task( ) db.session.delete(image_file) db.session.delete(segment) + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Failed to delete attachment file from storage, attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete() db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete() diff --git
a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 62200715cc..6d2feb1da3 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -9,7 +9,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_storage import storage -from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment +from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding from models.model import UploadFile logger = logging.getLogger(__name__) @@ -36,6 +36,16 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i raise Exception("Document has no dataset") segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + # Use JOIN to fetch attachments with bindings in a single query + attachments_with_bindings = db.session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where( + SegmentAttachmentBinding.tenant_id == dataset.tenant_id, + SegmentAttachmentBinding.dataset_id == dataset_id, + SegmentAttachmentBinding.document_id == document_id, + ) + ).all() # check segment is exist if segments: index_node_ids = [segment.index_node_id for segment in segments] @@ -69,6 +79,19 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i logger.exception("Delete file failed when document deleted, file_id: %s", file_id) db.session.delete(file) db.session.commit() + # delete segment attachments + if attachments_with_bindings: + for binding, attachment_file in attachments_with_bindings: + try: + storage.delete(attachment_file.key) + except Exception: + logger.exception( + "Failed to delete attachment file from storage, attachment_file_id: %s", + binding.attachment_id, + ) + db.session.delete(attachment_file) + db.session.delete(binding) # delete dataset metadata binding db.session.query(DatasetMetadataBinding).where( diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index 713f149c38..3d13afdec0 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task # type: ignore -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -28,7 +29,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": dataset_documents = ( @@ -119,6 +120,7 @@ def deal_dataset_index_update_task(dataset_id: str, action:
str): ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +131,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +147,25 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index dc6ef6fb61..1c7de3b1ce 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -1,14 +1,14 @@ import logging import time -from typing import Literal import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "add", "update"]): +def deal_dataset_vector_index_task(dataset_id: str, action: str): """ Async deal dataset from index :param dataset_id: dataset_id @@ -32,7 +32,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a if not dataset: raise Exception("Dataset not found") - index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX + index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "remove": index_processor.clean(dataset, None, with_keywords=False) @@ -119,6 +119,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) if segments: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -129,7 +130,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a "dataset_id": segment.dataset_id, }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == 
IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -145,9 +146,25 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a ) child_documents.append(child_document) document.children = child_documents + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents, with_keywords=False) + index_processor.load( + dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False + ) db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index e8cbd0f250..bea5c952cf 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,14 +6,15 @@ from celery import shared_task from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from models.dataset import Dataset, Document +from models.dataset import Dataset, Document, SegmentAttachmentBinding +from models.model import UploadFile logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_segment_from_index_task( - index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None + index_node_ids: list, dataset_id: str, document_id: str, segment_ids: list, child_node_ids: list | None = None ): """ Async Remove segment from index @@ -49,6 +50,21 @@ def delete_segment_from_index_task( delete_child_chunks=True, precomputed_child_node_ids=child_node_ids, ) + if dataset.is_multimodal: + # delete segment attachment binding + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) + for binding in segment_attachment_bindings: + db.session.delete(binding) + # delete upload file + db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) + db.session.commit() end_at = time.perf_counter() logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py index 9038dc179b..c2a3de29f4 100644 --- a/api/tasks/disable_segments_from_index_task.py +++ b/api/tasks/disable_segments_from_index_task.py @@ -8,7 +8,7 @@ from sqlalchemy import select from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DocumentSegment +from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding from models.dataset import Document as DatasetDocument logger = logging.getLogger(__name__) @@ -59,6 +59,16 @@ def 
disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen try: index_node_ids = [segment.index_node_id for segment in segments] + if dataset.is_multimodal: + segment_ids = [segment.id for segment in segments] + segment_attachment_bindings = ( + db.session.query(SegmentAttachmentBinding) + .where(SegmentAttachmentBinding.segment_id.in_(segment_ids)) + .all() + ) + if segment_attachment_bindings: + attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] + index_node_ids.extend(attachment_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) end_at = time.perf_counter() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 07c44f333e..7615469ed0 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -4,9 +4,10 @@ import time import click from celery import shared_task -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now @@ -67,7 +68,7 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -83,8 +84,24 @@ def enable_segment_to_index_task(segment_id: str): ) child_documents.append(child_document) document.children = child_documents + multimodal_documents = [] + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) + # save vector index - index_processor.load(dataset, [document]) + index_processor.load(dataset, [document], multimodal_documents=multimodal_documents) end_at = time.perf_counter() logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green")) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py index c5ca7a6171..9f17d09e18 100644 --- a/api/tasks/enable_segments_to_index_task.py +++ b/api/tasks/enable_segments_to_index_task.py @@ -5,9 +5,10 @@ import click from celery import shared_task from sqlalchemy import select -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import ChildDocument, Document +from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import
redis_client from libs.datetime_utils import naive_utc_now @@ -60,6 +61,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i try: documents = [] + multimodal_documents = [] for segment in segments: document = Document( page_content=segment.content, @@ -71,7 +73,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i }, ) - if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: child_chunks = segment.get_child_chunks() if child_chunks: child_documents = [] @@ -87,9 +89,24 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i ) child_documents.append(child_document) document.children = child_documents + + if dataset.is_multimodal: + for attachment in segment.attachments: + multimodal_documents.append( + AttachmentDocument( + page_content=attachment["name"], + metadata={ + "doc_id": attachment["id"], + "doc_hash": "", + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + "doc_type": DocType.IMAGE, + }, + ) + ) documents.append(document) # save vector index - index_processor.load(dataset, documents) + index_processor.load(dataset, documents, multimodal_documents=multimodal_documents) end_at = time.perf_counter() logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green")) diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 9478bb9ddb..088d6ba6ba 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestAddDocumentToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -172,7 +172,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify the expected outcomes # Verify index processor was called correctly - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes @@ -204,7 +206,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -221,7 +223,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + 
mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -360,7 +364,7 @@ class TestAddDocumentToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -391,7 +395,7 @@ class TestAddDocumentToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() @@ -465,8 +469,10 @@ class TestAddDocumentToIndexTask: # Act: Execute the task add_document_to_index_task(document.id) - # Assert: Verify index processing occurred with all completed segments - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + # Assert: Verify index processing occurred but with empty documents list + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments @@ -532,7 +538,9 @@ class TestAddDocumentToIndexTask: assert len(remaining_logs) == 0 # Verify index processing occurred normally - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify segments were enabled @@ -699,7 +707,9 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify only eligible segments were processed - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 94e9b76965..37d886f569 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -12,7 +12,7 @@ from unittest.mock import MagicMock, patch from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -164,7 +164,7 @@ class TestDeleteSegmentFromIndexTask: document.updated_at = fake.date_time_this_year() document.doc_type = kwargs.get("doc_type", "text") 
document.doc_metadata = kwargs.get("doc_metadata", {}) - document.doc_form = kwargs.get("doc_form", IndexType.PARAGRAPH_INDEX) + document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") db_session_with_containers.add(document) @@ -244,8 +244,11 @@ class TestDeleteSegmentFromIndexTask: mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor + # Extract segment IDs for the task + segment_ids = [segment.id for segment in segments] + # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None # Task should return None on success @@ -279,7 +282,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent dataset - result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, non_existent_dataset_id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found @@ -305,7 +308,7 @@ class TestDeleteSegmentFromIndexTask: index_node_ids = [f"node_{fake.uuid4()}" for _ in range(3)] # Execute the task with non-existent document - result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, non_existent_document_id, []) # Verify the task completed without exceptions assert result is None # Task should return None when document not found @@ -330,9 +333,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with disabled document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled @@ -357,9 +361,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with archived document - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when document is archived @@ -386,9 +391,10 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Execute the task with incomplete indexing - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = 
delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without exceptions assert result is None # Task should return None when indexing is not completed @@ -409,7 +415,11 @@ class TestDeleteSegmentFromIndexTask: fake = Faker() # Test different document forms - document_forms = [IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX, IndexType.PARENT_CHILD_INDEX] + document_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] for doc_form in document_forms: # Create test data for each document form @@ -420,13 +430,14 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 2, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None @@ -469,6 +480,7 @@ class TestDeleteSegmentFromIndexTask: segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor to raise an exception mock_processor = MagicMock() @@ -476,7 +488,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - should not raise exception - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed without raising exceptions assert result is None # Task should return None even when exceptions occur @@ -518,7 +530,7 @@ class TestDeleteSegmentFromIndexTask: mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, []) # Verify the task completed successfully assert result is None @@ -555,13 +567,14 @@ class TestDeleteSegmentFromIndexTask: # Create large number of segments segments = self._create_test_document_segments(db_session_with_containers, document, account, 50, fake) index_node_ids = [segment.index_node_id for segment in segments] + segment_ids = [segment.id for segment in segments] # Mock the index processor mock_processor = MagicMock() mock_index_processor_factory.return_value.init_index_processor.return_value = mock_processor # Execute the task - result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id) + result = delete_segment_from_index_task(index_node_ids, dataset.id, document.id, segment_ids) # Verify the task completed successfully assert result is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index 798fe091ab..b738646736 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -95,7 +95,7 @@ class TestEnableSegmentsToIndexTask: created_by=account.id, indexing_status="completed", enabled=True, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) db.session.add(document) db.session.commit() @@ -166,7 +166,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use different index type - document.doc_form = IndexType.QA_INDEX + document.doc_form = IndexStructureType.QA_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -185,7 +185,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(segment_ids, dataset.id, document.id) # Assert: Verify different index type handling - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.QA_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.QA_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with correct parameters @@ -328,7 +330,9 @@ class TestEnableSegmentsToIndexTask: enable_segments_to_index_task(non_existent_segment_ids, dataset.id, document.id) # Assert: Verify index processor was created but load was not called - mock_external_service_dependencies["index_processor_factory"].assert_called_once_with(IndexType.PARAGRAPH_INDEX) + mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( + IndexStructureType.PARAGRAPH_INDEX + ) mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( @@ -350,7 +354,7 @@ class TestEnableSegmentsToIndexTask: ) # Update document to use parent-child index type - document.doc_form = IndexType.PARENT_CHILD_INDEX + document.doc_form = IndexStructureType.PARENT_CHILD_INDEX db.session.commit() # Refresh dataset to ensure doc_form property reflects the updated document @@ -383,7 +387,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify parent-child index processing mock_external_service_dependencies["index_processor_factory"].assert_called_once_with( - IndexType.PARENT_CHILD_INDEX + IndexStructureType.PARENT_CHILD_INDEX ) mock_external_service_dependencies["index_processor"].load.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index d9f6dcc43c..025a0d8d70 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -53,7 +53,7 @@ from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult 
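# [Reviewer sketch, not part of the patch] The embedding-service test changes in
# this file are a mechanical rename: TextEmbeddingResult becomes EmbeddingResult
# while the constructor keywords (model / embeddings / usage) stay the same. A
# minimal self-contained example of the renamed fixture shape, assuming
# EmbeddingUsage keeps the field set these fixtures already pass:
from decimal import Decimal

from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage

usage = EmbeddingUsage(
    tokens=10,
    total_tokens=10,
    unit_price=Decimal("0.0001"),
    price_unit=Decimal("1000"),
    total_price=Decimal("0.000001"),
    currency="USD",
    latency=0.5,  # latency values mirror the fixtures in this file
)
result = EmbeddingResult(
    model="text-embedding-ada-002",
    embeddings=[[0.1] * 1536],  # one illustrative 1536-dim vector (ada-002 dimension)
    usage=usage,
)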
+from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, @@ -99,10 +99,10 @@ class TestCacheEmbeddingDocuments: @pytest.fixture def sample_embedding_result(self): - """Create a sample TextEmbeddingResult for testing. + """Create a sample EmbeddingResult for testing. Returns: - TextEmbeddingResult: Mock embedding result with proper structure + EmbeddingResult: Mock embedding result with proper structure """ # Create normalized embedding vectors (dimension 1536 for ada-002) embedding_vector = np.random.randn(1536) @@ -118,7 +118,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_vector], usage=usage, @@ -197,7 +197,7 @@ class TestCacheEmbeddingDocuments: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -296,7 +296,7 @@ class TestCacheEmbeddingDocuments: latency=0.6, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=new_embeddings, usage=usage, @@ -386,7 +386,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -449,7 +449,7 @@ class TestCacheEmbeddingDocuments: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[valid_vector.tolist(), nan_vector], usage=usage, @@ -629,7 +629,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -728,7 +728,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[nan_vector], usage=usage, @@ -793,7 +793,7 @@ class TestCacheEmbeddingQuery: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -873,13 +873,13 @@ class TestEmbeddingModelSwitching: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage, ) - result_3_small = TextEmbeddingResult( + result_3_small = EmbeddingResult( model="text-embedding-3-small", embeddings=[normalized_3_small], usage=usage, @@ -953,13 +953,13 @@ class TestEmbeddingModelSwitching: latency=0.4, ) - result_openai = TextEmbeddingResult( + result_openai = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_openai], usage=usage_openai, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1042,7 +1042,7 @@ class TestEmbeddingDimensionValidation: latency=0.7, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1095,7 +1095,7 @@ class TestEmbeddingDimensionValidation: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", 
embeddings=embeddings, usage=usage, @@ -1148,7 +1148,7 @@ class TestEmbeddingDimensionValidation: latency=0.3, ) - result_ada = TextEmbeddingResult( + result_ada = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized_ada], usage=usage_ada, @@ -1181,7 +1181,7 @@ class TestEmbeddingDimensionValidation: latency=0.4, ) - result_cohere = TextEmbeddingResult( + result_cohere = EmbeddingResult( model="embed-english-v3.0", embeddings=[normalized_cohere], usage=usage_cohere, @@ -1279,7 +1279,7 @@ class TestEmbeddingEdgeCases: latency=0.1, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1322,7 +1322,7 @@ class TestEmbeddingEdgeCases: latency=1.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1370,7 +1370,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1422,7 +1422,7 @@ class TestEmbeddingEdgeCases: latency=0.2, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1478,7 +1478,7 @@ class TestEmbeddingEdgeCases: ) # Model returns embeddings for all texts - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1546,7 +1546,7 @@ class TestEmbeddingEdgeCases: latency=0.8, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1603,7 +1603,7 @@ class TestEmbeddingEdgeCases: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1657,7 +1657,7 @@ class TestEmbeddingEdgeCases: latency=0.5, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1757,7 +1757,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, @@ -1826,7 +1826,7 @@ class TestEmbeddingCachePerformance: latency=0.5, ) - return TextEmbeddingResult( + return EmbeddingResult( model="text-embedding-ada-002", embeddings=embeddings, usage=usage, @@ -1888,7 +1888,7 @@ class TestEmbeddingCachePerformance: latency=0.3, ) - embedding_result = TextEmbeddingResult( + embedding_result = EmbeddingResult( model="text-embedding-ada-002", embeddings=[normalized], usage=usage, diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index d26e98db8d..c00fee8fe5 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -62,7 +62,7 @@ from core.indexing_runner import ( IndexingRunner, ) from core.model_runtime.entities.model_entities import ModelType -from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document from 
libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule @@ -112,7 +112,7 @@ def create_mock_dataset_document( document_id: str | None = None, dataset_id: str | None = None, tenant_id: str | None = None, - doc_form: str = IndexType.PARAGRAPH_INDEX, + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, data_source_type: str = "upload_file", doc_language: str = "English", ) -> Mock: @@ -133,8 +133,8 @@ def create_mock_dataset_document( Mock: A configured mock DatasetDocument object with all required attributes. Example: - >>> doc = create_mock_dataset_document(doc_form=IndexType.QA_INDEX) - >>> assert doc.doc_form == IndexType.QA_INDEX + >>> doc = create_mock_dataset_document(doc_form=IndexStructureType.QA_INDEX) + >>> assert doc.doc_form == IndexStructureType.QA_INDEX """ doc = Mock(spec=DatasetDocument) doc.id = document_id or str(uuid.uuid4()) @@ -276,7 +276,7 @@ class TestIndexingRunnerExtract: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} return doc @@ -616,7 +616,7 @@ class TestIndexingRunnerLoad: doc = Mock(spec=DatasetDocument) doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -700,7 +700,7 @@ class TestIndexingRunnerLoad: """Test loading with parent-child index structure.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX sample_dataset.indexing_technique = "high_quality" # Add child documents @@ -775,7 +775,7 @@ class TestIndexingRunnerRun: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX doc.doc_language = "English" doc.data_source_type = "upload_file" doc.data_source_info_dict = {"upload_file_id": str(uuid.uuid4())} @@ -802,6 +802,21 @@ class TestIndexingRunnerRun: mock_process_rule.to_dict.return_value = {"mode": "automatic", "rules": {}} mock_dependencies["db"].session.scalar.return_value = mock_process_rule + # Mock current_user (Account) for _transform + mock_current_user = MagicMock() + mock_current_user.set_tenant_id = MagicMock() + + # Setup db.session.query to return different results based on the model + def mock_query_side_effect(model): + mock_query_result = MagicMock() + if model.__name__ == "Dataset": + mock_query_result.filter_by.return_value.first.return_value = mock_dataset + elif model.__name__ == "Account": + mock_query_result.filter_by.return_value.first.return_value = mock_current_user + return mock_query_result + + mock_dependencies["db"].session.query.side_effect = mock_query_side_effect + # Mock processor mock_processor = MagicMock() mock_dependencies["factory"].return_value.init_index_processor.return_value = mock_processor @@ -1268,7 +1283,7 @@ class TestIndexingRunnerLoadSegments: doc.id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) doc.created_by = str(uuid.uuid4()) - doc.doc_form = IndexType.PARAGRAPH_INDEX + doc.doc_form = IndexStructureType.PARAGRAPH_INDEX return doc @pytest.fixture @@ -1316,7 +1331,7 @@ class TestIndexingRunnerLoadSegments: """Test loading 
segments for parent-child index.""" # Arrange runner = IndexingRunner() - sample_dataset_document.doc_form = IndexType.PARENT_CHILD_INDEX + sample_dataset_document.doc_form = IndexStructureType.PARENT_CHILD_INDEX # Add child documents for doc in sample_documents: @@ -1413,7 +1428,7 @@ class TestIndexingRunnerEstimate: tenant_id=tenant_id, extract_settings=extract_settings, tmp_processing_rule={"mode": "automatic", "rules": {}}, - doc_form=IndexType.PARAGRAPH_INDEX, + doc_form=IndexStructureType.PARAGRAPH_INDEX, ) diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index 4912884c55..ebe6c37818 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -26,6 +26,18 @@ from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +def create_mock_model_instance(): + """Create a properly configured mock ModelInstance for reranking tests.""" + mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" + return mock_instance + + class TestRerankModelRunner: """Unit tests for RerankModelRunner. @@ -37,10 +49,23 @@ class TestRerankModelRunner: - Metadata preservation and score injection """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" mock_instance = Mock(spec=ModelInstance) + # Setup provider_model_bundle chain for check_model_support_vision + mock_instance.provider_model_bundle = Mock() + mock_instance.provider_model_bundle.configuration = Mock() + mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" + mock_instance.provider = "test-provider" + mock_instance.model = "test-model" return mock_instance @pytest.fixture @@ -803,7 +828,7 @@ class TestRerankRunnerFactory: - Parameters are forwarded to runner constructor """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner via factory runner = RerankRunnerFactory.create_rerank_runner( @@ -865,7 +890,7 @@ class TestRerankRunnerFactory: - String values are properly matched """ # Arrange: Mock model instance - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() # Act: Create runner using enum value runner = RerankRunnerFactory.create_rerank_runner( @@ -886,6 +911,13 @@ class TestRerankIntegration: - Real-world usage scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_model_reranking_full_workflow(self): """Test complete model-based reranking workflow. 
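# [Reviewer sketch, not part of the patch] The autouse mock_model_manager
# fixtures added throughout this file pin ModelManager.check_model_support_vision
# to False, so the existing rerank tests keep exercising the text-only path. A
# minimal sketch of how a single test could opt into the vision branch instead,
# re-patching the same target the fixtures use (the body is a placeholder, not
# an assertion about runner internals):
from unittest.mock import patch

def test_rerank_with_vision_capable_model():
    with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
        mock_mm.return_value.check_model_support_vision.return_value = True
        mock_model_instance = create_mock_model_instance()
        # ...build RerankResult mocks and invoke the runner as in the tests above...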
@@ -895,7 +927,7 @@ class TestRerankIntegration: - Top results are returned correctly """ # Arrange: Create mock model and documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -951,7 +983,7 @@ class TestRerankIntegration: - Normalization is consistent """ # Arrange: Create mock model with various scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -991,6 +1023,13 @@ class TestRerankEdgeCases: - Concurrent reranking scenarios """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_with_empty_metadata(self): """Test reranking when documents have empty metadata. @@ -1000,7 +1039,7 @@ class TestRerankEdgeCases: - Empty metadata documents are processed correctly """ # Arrange: Create documents with empty metadata - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1046,7 +1085,7 @@ class TestRerankEdgeCases: - Score comparison logic works at boundary """ # Arrange: Create mock with various scores including negatives - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1082,7 +1121,7 @@ class TestRerankEdgeCases: - No overflow or precision issues """ # Arrange: All documents with perfect scores - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1117,7 +1156,7 @@ class TestRerankEdgeCases: - Content encoding is preserved """ # Arrange: Documents with special characters - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1159,7 +1198,7 @@ class TestRerankEdgeCases: - Content is not truncated unexpectedly """ # Arrange: Documents with very long content - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() long_content = "This is a very long document. 
" * 1000 # ~30,000 characters mock_rerank_result = RerankResult( @@ -1196,7 +1235,7 @@ class TestRerankEdgeCases: - All documents are processed correctly """ # Arrange: Create 100 documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() num_docs = 100 # Create rerank results for all documents @@ -1287,7 +1326,7 @@ class TestRerankEdgeCases: - Documents can still be ranked """ # Arrange: Empty query - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ @@ -1325,6 +1364,13 @@ class TestRerankPerformance: - Score calculation optimization """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_batch_processing(self): """Test that documents are processed in a single batch. @@ -1334,7 +1380,7 @@ class TestRerankPerformance: - Efficient batch processing """ # Arrange: Multiple documents - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[RerankDocument(index=i, text=f"Doc {i}", score=0.9 - i * 0.1) for i in range(5)], @@ -1435,6 +1481,13 @@ class TestRerankErrorHandling: - Error propagation """ + @pytest.fixture(autouse=True) + def mock_model_manager(self): + """Auto-use fixture to patch ModelManager for all tests in this class.""" + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + yield mock_mm + def test_rerank_model_invocation_error(self): """Test handling of model invocation errors. @@ -1444,7 +1497,7 @@ class TestRerankErrorHandling: - Error context is preserved """ # Arrange: Mock model that raises exception - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_model_instance.invoke_rerank.side_effect = RuntimeError("Model invocation failed") documents = [ @@ -1470,7 +1523,7 @@ class TestRerankErrorHandling: - Invalid results don't corrupt output """ # Arrange: Rerank result with invalid index - mock_model_instance = Mock(spec=ModelInstance) + mock_model_instance = create_mock_model_instance() mock_rerank_result = RerankResult( model="bge-reranker-base", docs=[ diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index 0163e42992..affd6c648f 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -425,15 +425,15 @@ class TestRetrievalService: # ==================== Vector Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_basic(self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents): + def test_vector_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic vector/semantic search functionality. 
This test validates the core vector search flow: 1. Dataset is retrieved from database - 2. embedding_search is called via ThreadPoolExecutor + 2. _retrieve is called via ThreadPoolExecutor 3. Documents are added to shared all_documents list 4. Results are returned to caller @@ -447,28 +447,28 @@ class TestRetrievalService: # Set up the mock dataset that will be "retrieved" from database mock_get_dataset.return_value = mock_dataset - # Create a side effect function that simulates embedding_search behavior - # In the real implementation, embedding_search: - # 1. Gets the dataset - # 2. Creates a Vector instance - # 3. Calls search_by_vector with embeddings - # 4. Extends all_documents with results - def side_effect_embedding_search( + # Create a side effect function that simulates _retrieve behavior + # _retrieve modifies the all_documents list in place + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - """Simulate embedding_search adding documents to the shared list.""" - all_documents.extend(sample_documents) + """Simulate _retrieve adding documents to the shared list.""" + if all_documents is not None: + all_documents.extend(sample_documents) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve # Define test parameters query = "What is Python?" # Natural language query @@ -481,7 +481,7 @@ class TestRetrievalService: # 1. Check if query is empty (early return if so) # 2. Get the dataset using _get_dataset # 3. Create ThreadPoolExecutor - # 4. Submit embedding_search task + # 4. Submit _retrieve task # 5. Wait for completion # 6. Return all_documents list results = RetrievalService.retrieve( @@ -502,15 +502,13 @@ class TestRetrievalService: # Verify documents maintain their scores (highest score first in sample_documents) assert results[0].metadata["score"] == 0.95, "First document should have highest score from sample_documents" - # Verify embedding_search was called exactly once + # Verify _retrieve was called exactly once # This confirms the search method was invoked by ThreadPoolExecutor - mock_embedding_search.assert_called_once() + mock_retrieve.assert_called_once() - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_document_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_document_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with document ID filtering. 
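
The renames above (and the `keyword_search` swap later in this file) all point at the per-method search helpers having been folded into a single `RetrievalService._retrieve` dispatcher. As a reference while reading the repeated stubs, here is a minimal sketch of the shape the tests assume; the signature is taken from the stubs, but the body is pure assumption, since the production code is not part of this diff:

from typing import Any

def _retrieve_sketch(
    flask_app: Any,
    retrieval_method: str,
    dataset: Any,
    query: str | None = None,
    top_k: int = 4,
    score_threshold: float | None = None,
    reranking_model: dict | None = None,
    reranking_mode: str = "reranking_model",
    weights: dict | None = None,
    document_ids_filter: list[str] | None = None,
    attachment_id: str | None = None,
    all_documents: list | None = None,
    exceptions: list[str] | None = None,
) -> None:
    """Hypothetical stand-in for RetrievalService._retrieve (assumed, not from this diff)."""
    try:
        # Presumed: the keyword vs. semantic/hybrid branching that used to live
        # in keyword_search/embedding_search now keys off retrieval_method.
        docs: list = []  # actual search backends elided in this sketch
        if all_documents is not None:
            all_documents.extend(docs)
    except Exception as exc:
        # Worker threads report failures instead of raising across the pool.
        if exceptions is not None:
            exceptions.append(str(exc))

The `None` defaults on the collection parameters are also why every stub in this file guards `all_documents` and `exceptions` before mutating them.
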
@@ -522,21 +520,25 @@ class TestRetrievalService: mock_get_dataset.return_value = mock_dataset filtered_docs = [sample_documents[0]] - def side_effect_embedding_search( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(filtered_docs) + if all_documents is not None: + all_documents.extend(filtered_docs) - mock_embedding_search.side_effect = side_effect_embedding_search + mock_retrieve.side_effect = side_effect_retrieve document_ids_filter = [sample_documents[0].metadata["document_id"]] # Act @@ -552,12 +554,12 @@ class TestRetrievalService: assert len(results) == 1 assert results[0].metadata["doc_id"] == "doc1" # Verify document_ids_filter was passed - call_kwargs = mock_embedding_search.call_args.kwargs + call_kwargs = mock_retrieve.call_args.kwargs assert call_kwargs["document_ids_filter"] == document_ids_filter - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_empty_results(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_empty_results(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search when no results match the query. @@ -567,8 +569,8 @@ class TestRetrievalService: """ # Arrange mock_get_dataset.return_value = mock_dataset - # embedding_search doesn't add anything to all_documents - mock_embedding_search.side_effect = lambda *args, **kwargs: None + # _retrieve doesn't add anything to all_documents + mock_retrieve.side_effect = lambda *args, **kwargs: None # Act results = RetrievalService.retrieve( @@ -583,9 +585,9 @@ class TestRetrievalService: # ==================== Keyword Search Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_keyword_search_basic(self, mock_get_dataset, mock_keyword_search, mock_dataset, sample_documents): + def test_keyword_search_basic(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test basic keyword search functionality. 
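
The `call_kwargs = mock_retrieve.call_args.kwargs` assertions in the hunk above only hold because `retrieve` evidently submits `_retrieve` with keyword arguments. A self-contained reminder of the `unittest.mock` behavior being leaned on:

from unittest.mock import Mock

mock_retrieve = Mock()
mock_retrieve("flask_app", top_k=3, document_ids_filter=["doc1"])

# call_args.kwargs exposes only the keyword-passed values of the last call;
# anything passed positionally lands in call_args.args instead.
assert mock_retrieve.call_args.kwargs["top_k"] == 3
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc1"]
assert mock_retrieve.call_args.args == ("flask_app",)

If the production code ever switched to positional submission, these assertions would start failing with a KeyError rather than a value mismatch.
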
@@ -597,12 +599,25 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - def side_effect_keyword_search( - flask_app, dataset_id, query, top_k, all_documents, exceptions, document_ids_filter=None + def side_effect_retrieve( + flask_app, + retrieval_method, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.extend(sample_documents) + if all_documents is not None: + all_documents.extend(sample_documents) - mock_keyword_search.side_effect = side_effect_keyword_search + mock_retrieve.side_effect = side_effect_retrieve query = "Python programming" top_k = 3 @@ -618,7 +633,7 @@ class TestRetrievalService: # Assert assert len(results) == 3 assert all(isinstance(doc, Document) for doc in results) - mock_keyword_search.assert_called_once() + mock_retrieve.assert_called_once() @patch("core.rag.datasource.retrieval_service.RetrievalService.keyword_search") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -1147,11 +1162,9 @@ class TestRetrievalService: # ==================== Metadata Filtering Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_metadata_filter( - self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents - ): + def test_vector_search_with_metadata_filter(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents): """ Test vector search with metadata-based document filtering. @@ -1166,21 +1179,25 @@ class TestRetrievalService: filtered_doc = sample_documents[0] filtered_doc.metadata["category"] = "programming" - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(filtered_doc) + if all_documents is not None: + all_documents.append(filtered_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve # Act results = RetrievalService.retrieve( @@ -1243,9 +1260,9 @@ class TestRetrievalService: # Assert assert results == [] - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_with_exception_handling(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that exceptions during retrieval are properly handled. 
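
Several docstrings in this file describe the same mechanics: `retrieve` fans work out through a `ThreadPoolExecutor`, and workers append into shared lists rather than returning values. A standalone illustration of that pattern (generic Python, not Dify code), which is also what the exception test below relies on:

from concurrent.futures import ThreadPoolExecutor

def worker(all_documents: list[str], exceptions: list[str]) -> None:
    try:
        all_documents.append("doc")  # list.append is thread-safe under the GIL
    except Exception as exc:
        exceptions.append(str(exc))  # report instead of raising inside the pool

all_documents: list[str] = []
exceptions: list[str] = []
with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(worker, all_documents, exceptions) for _ in range(2)]
    for future in futures:
        future.result()  # block until every worker finishes

assert all_documents == ["doc", "doc"]
assert exceptions == []
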
@@ -1256,22 +1273,26 @@ class TestRetrievalService: # Arrange mock_get_dataset.return_value = mock_dataset - # Make embedding_search add an exception to the exceptions list + # Make _retrieve add an exception to the exceptions list def side_effect_with_exception( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - exceptions.append("Search failed") + if exceptions is not None: + exceptions.append("Search failed") - mock_embedding_search.side_effect = side_effect_with_exception + mock_retrieve.side_effect = side_effect_with_exception # Act & Assert with pytest.raises(ValueError) as exc_info: @@ -1286,9 +1307,9 @@ class TestRetrievalService: # ==================== Score Threshold Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_vector_search_with_score_threshold(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test vector search with score threshold filtering. @@ -1306,21 +1327,25 @@ class TestRetrievalService: provider="dify", ) - def side_effect_embedding( + def side_effect_retrieve( flask_app, - dataset_id, - query, - top_k, - score_threshold, - reranking_model, - all_documents, retrieval_method, - exceptions, + dataset, + query=None, + top_k=4, + score_threshold=None, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, document_ids_filter=None, + attachment_id=None, + all_documents=None, + exceptions=None, ): - all_documents.append(high_score_doc) + if all_documents is not None: + all_documents.append(high_score_doc) - mock_embedding_search.side_effect = side_effect_embedding + mock_retrieve.side_effect = side_effect_retrieve score_threshold = 0.8 @@ -1339,9 +1364,9 @@ class TestRetrievalService: # ==================== Top-K Limiting Tests ==================== - @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search") + @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") - def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_embedding_search, mock_dataset): + def test_retrieve_respects_top_k_limit(self, mock_get_dataset, mock_retrieve, mock_dataset): """ Test that retrieval respects top_k parameter. 
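
The `pytest.raises(ValueError)` expectation above implies that, once the pool drains, `retrieve` re-raises anything the workers reported. Presumably something like the following runs at the tail of `retrieve`; this is inferred from the test, not taken from the production code:

def raise_collected(exceptions: list[str]) -> None:
    # The test appends a plain string ("Search failed"), so the collected
    # messages are presumably folded into the ValueError text.
    if exceptions:
        raise ValueError(";".join(exceptions))
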
@@ -1362,22 +1387,26 @@
             for i in range(10)
         ]
 
-        def side_effect_embedding(
+        def side_effect_retrieve(
             flask_app,
-            dataset_id,
-            query,
-            top_k,
-            score_threshold,
-            reranking_model,
-            all_documents,
             retrieval_method,
-            exceptions,
+            dataset,
+            query=None,
+            top_k=4,
+            score_threshold=None,
+            reranking_model=None,
+            reranking_mode="reranking_model",
+            weights=None,
             document_ids_filter=None,
+            attachment_id=None,
+            all_documents=None,
+            exceptions=None,
         ):
             # Return only top_k documents
-            all_documents.extend(many_docs[:top_k])
+            if all_documents is not None:
+                all_documents.extend(many_docs[:top_k])
 
-        mock_embedding_search.side_effect = side_effect_embedding
+        mock_retrieve.side_effect = side_effect_retrieve
 
         top_k = 3
@@ -1390,9 +1419,9 @@
         )
 
         # Assert
-        # Verify top_k was passed to embedding_search
-        assert mock_embedding_search.called
-        call_kwargs = mock_embedding_search.call_args.kwargs
+        # Verify _retrieve was called
+        assert mock_retrieve.called
+        call_kwargs = mock_retrieve.call_args.kwargs
         assert call_kwargs["top_k"] == top_k
         # Verify we got the right number of results
         assert len(results) == top_k
@@ -1421,11 +1450,9 @@
 
     # ==================== Reranking Tests ====================
 
-    @patch("core.rag.datasource.retrieval_service.RetrievalService.embedding_search")
+    @patch("core.rag.datasource.retrieval_service.RetrievalService._retrieve")
     @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset")
-    def test_semantic_search_with_reranking(
-        self, mock_get_dataset, mock_embedding_search, mock_dataset, sample_documents
-    ):
+    def test_semantic_search_with_reranking(self, mock_get_dataset, mock_retrieve, mock_dataset, sample_documents):
         """
         Test semantic search with reranking model.
@@ -1439,22 +1466,26 @@
         # Simulate reranking changing order
         reranked_docs = list(reversed(sample_documents))
 
-        def side_effect_embedding(
+        def side_effect_retrieve(
             flask_app,
-            dataset_id,
-            query,
-            top_k,
-            score_threshold,
-            reranking_model,
-            all_documents,
             retrieval_method,
-            exceptions,
+            dataset,
+            query=None,
+            top_k=4,
+            score_threshold=None,
+            reranking_model=None,
+            reranking_mode="reranking_model",
+            weights=None,
             document_ids_filter=None,
+            attachment_id=None,
+            all_documents=None,
+            exceptions=None,
         ):
-            # embedding_search handles reranking internally
-            all_documents.extend(reranked_docs)
+            # _retrieve handles reranking internally
+            if all_documents is not None:
+                all_documents.extend(reranked_docs)
 
-        mock_embedding_search.side_effect = side_effect_embedding
+        mock_retrieve.side_effect = side_effect_retrieve
 
         reranking_model = {
             "reranking_provider_name": "cohere",
@@ -1473,7 +1504,7 @@
         # Assert
         # For semantic search with reranking, reranking_model should be passed
         assert len(results) == 3
-        call_kwargs = mock_embedding_search.call_args.kwargs
+        call_kwargs = mock_retrieve.call_args.kwargs
         assert call_kwargs["reranking_model"] == reranking_model
diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py
index 8bfc97ae63..8af47e8967 100644
--- a/api/tests/unit_tests/utils/test_text_processing.py
+++ b/api/tests/unit_tests/utils/test_text_processing.py
@@ -8,7 +8,9 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
     [
         ("...Hello, World!", "Hello, World!"),
         ("。测试中文标点", "测试中文标点"),
-        ("!@#Test symbols", "Test symbols"),
+        # Note: "!" is not in the removal pattern; only "@#" would be stripped, leaving "!Test symbols".
+        # The pattern intentionally excludes "!" as per the #11868 fix.
+        ("@#Test symbols", "Test symbols"),
         ("Hello, World!", "Hello, World!"),
         ("", ""),
         (" ", " "),
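
Reading the updated cases together: `remove_leading_symbols` strips a leading run of punctuation but now deliberately leaves `!` alone (the #11868 fix the comment cites). A hypothetical pattern that is consistent with just these table entries; the real implementation lives in `core.tools.utils.text_processing_utils` and may differ:

import re

def remove_leading_symbols_sketch(text: str) -> str:
    # Strip leading punctuation, excluding "!", which the #11868 fix keeps.
    return re.sub(r"^[.。,，、@#$%^&*(){}\[\]<>/\\|~:：;；?？'\"`]+", "", text)

assert remove_leading_symbols_sketch("...Hello, World!") == "Hello, World!"
assert remove_leading_symbols_sketch("。测试中文标点") == "测试中文标点"
assert remove_leading_symbols_sketch("@#Test symbols") == "Test symbols"
assert remove_leading_symbols_sketch("Hello, World!") == "Hello, World!"
assert remove_leading_symbols_sketch("") == ""
assert remove_leading_symbols_sketch(" ") == " "
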
diff --git a/docker/.env.example b/docker/.env.example
index b71c38e07a..80e87425c1 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -808,6 +808,19 @@ UPLOAD_FILE_BATCH_LIMIT=5
 # Recommended: exe,bat,cmd,com,scr,vbs,ps1,msi,dll
 UPLOAD_FILE_EXTENSION_BLACKLIST=
 
+# Maximum number of files allowed in a single chunk attachment, default 10.
+SINGLE_CHUNK_ATTACHMENT_LIMIT=10
+
+# Maximum number of files allowed in an image batch upload operation, default 10.
+IMAGE_FILE_BATCH_LIMIT=10
+
+# Maximum allowed image file size for attachments in megabytes, default 2.
+ATTACHMENT_IMAGE_FILE_SIZE_LIMIT=2
+
+# Timeout for downloading image attachments in seconds, default 60.
+ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT=60
+
+
 # ETL type, support: `dify`, `Unstructured`
 # `dify` Dify's proprietary file extraction scheme
 # `Unstructured` Unstructured.io file extraction scheme
@@ -1415,4 +1428,4 @@ WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE=100
 WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK=0
 
 # Tenant isolated task queue configuration
-TENANT_ISOLATED_TASK_CONCURRENCY=1
\ No newline at end of file
+TENANT_ISOLATED_TASK_CONCURRENCY=1
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 7ae8a70699..3e416c36c9 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -364,6 +364,10 @@ x-shared-env: &shared-api-worker-env
   UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
   UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
   UPLOAD_FILE_EXTENSION_BLACKLIST: ${UPLOAD_FILE_EXTENSION_BLACKLIST:-}
+  SINGLE_CHUNK_ATTACHMENT_LIMIT: ${SINGLE_CHUNK_ATTACHMENT_LIMIT:-10}
+  IMAGE_FILE_BATCH_LIMIT: ${IMAGE_FILE_BATCH_LIMIT:-10}
+  ATTACHMENT_IMAGE_FILE_SIZE_LIMIT: ${ATTACHMENT_IMAGE_FILE_SIZE_LIMIT:-2}
+  ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT: ${ATTACHMENT_IMAGE_DOWNLOAD_TIMEOUT:-60}
   ETL_TYPE: ${ETL_TYPE:-dify}
   UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-}
   UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-}
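
For orientation, this is roughly how the four new knobs compose on the consuming side. A minimal sketch using plain `os.environ` with the defaults above; the API actually reads them through the pydantic config fields added earlier in this PR, and `validate_chunk_attachments` here is purely illustrative:

import os

SINGLE_CHUNK_ATTACHMENT_LIMIT = int(os.environ.get("SINGLE_CHUNK_ATTACHMENT_LIMIT", "10"))
ATTACHMENT_IMAGE_FILE_SIZE_LIMIT = int(os.environ.get("ATTACHMENT_IMAGE_FILE_SIZE_LIMIT", "2"))

def validate_chunk_attachments(attachment_sizes_bytes: list[int]) -> None:
    """Hypothetical guard mirroring how the limits are meant to be enforced."""
    if len(attachment_sizes_bytes) > SINGLE_CHUNK_ATTACHMENT_LIMIT:
        raise ValueError(f"a chunk may carry at most {SINGLE_CHUNK_ATTACHMENT_LIMIT} attachments")
    max_bytes = ATTACHMENT_IMAGE_FILE_SIZE_LIMIT * 1024 * 1024  # limit is configured in MB
    for size in attachment_sizes_bytes:
        if size > max_bytes:
            raise ValueError(f"attachment exceeds {ATTACHMENT_IMAGE_FILE_SIZE_LIMIT} MB")
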