diff --git a/library_service/auth/recovery.py b/library_service/auth/recovery.py index 393daba..d49c915 100644 --- a/library_service/auth/recovery.py +++ b/library_service/auth/recovery.py @@ -20,6 +20,7 @@ from .core import ( ) from library_service.settings import get_logger +# Получение логгера logger = get_logger() diff --git a/library_service/main.py b/library_service/main.py index 0e2516b..1aed07f 100644 --- a/library_service/main.py +++ b/library_service/main.py @@ -1,4 +1,5 @@ """Основной модуль""" +from library_service.services.embeddings import ensure_embeddings from starlette.middleware.base import BaseHTTPMiddleware import asyncio, sys, traceback @@ -22,7 +23,7 @@ from library_service.settings import ( get_app, get_logger, OLLAMA_URL, - ASSISTANT_LLM, + ASSISTANT_LLM, EMBEDDINGS_MODEL, REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS, ) @@ -53,7 +54,7 @@ async def lifespan(_): logger.info("[+] Loading ollama models...") try: ollama_client = Client(host=OLLAMA_URL) - ollama_client.pull("mxbai-embed-large") + ollama_client.pull(EMBEDDINGS_MODEL) if ASSISTANT_LLM: ollama_client.pull(ASSISTANT_LLM) @@ -63,6 +64,8 @@ async def lifespan(_): except ResponseError as e: logger.error(f"[-] Failed to pull models {e}") + ensure_embeddings(REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS) + asyncio.create_task(cleanup_task()) logger.info("[+] Starting application...") yield # Обработка запросов diff --git a/library_service/routers/books.py b/library_service/routers/books.py index 45a05ef..c0dffda 100644 --- a/library_service/routers/books.py +++ b/library_service/routers/books.py @@ -7,15 +7,13 @@ from datetime import datetime, timezone from typing import List from fastapi import APIRouter, Depends, HTTPException, Path, Query, status, UploadFile, File -from ollama import Client from pydantic import Field from sqlalchemy import text, case, distinct from sqlalchemy.orm import selectinload, defer from sqlmodel import Session, select, col, func from library_service.auth import RequireStaff, OptionalAuth -from library_service.services import transcode_image -from library_service.settings import get_session, OLLAMA_URL, BOOKS_PREVIEW_DIR +from library_service.settings import get_session, BOOKS_PREVIEW_DIR from library_service.models.enums import BookStatus from library_service.models.db import ( Author, @@ -37,10 +35,14 @@ from library_service.models.dto.misc import ( BookWithAuthorsAndGenres, BookFilteredList, ) +from library_service.services import ( + transcode_image, + generate_book_embedding, + generate_search_embedding +) router = APIRouter(prefix="/books", tags=["books"]) -ollama_client = Client(host=OLLAMA_URL) def close_active_loan(session: Session, book_id: int) -> None: @@ -102,7 +104,7 @@ def filter_books( if q: if current_user: - emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=q)["embedding"] + emb = generate_search_embedding(q) distance_col = Book.embedding.cosine_distance(emb) # ty: ignore statement = statement.where(Book.embedding.is_not(None)) # ty: ignore @@ -133,9 +135,8 @@ def create_book( session: Session = Depends(get_session), ): """Создает новую книгу в системе""" - full_text = book.title + " " + book.description - emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text) - db_book = Book(**book.model_dump(), embedding=emb["embedding"]) + emb = generate_book_embedding(book.title, book.description) + db_book = Book(**book.model_dump(), embedding=emb) session.add(db_book) session.commit() @@ -263,13 +264,10 @@ def update_book( if book_update.description is not None: db_book.description = book_update.description - full_text = ( - (book_update.title or db_book.title) - + " " - + (book_update.description or db_book.description) + db_book.embedding = generate_book_embedding( + book_update.title or db_book.title, + book_update.description or db_book.description, ) - emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text) - db_book.embedding = emb["embedding"] if book_update.page_count is not None: db_book.page_count = book_update.page_count diff --git a/library_service/services/__init__.py b/library_service/services/__init__.py index 0208e9b..d498f70 100644 --- a/library_service/services/__init__.py +++ b/library_service/services/__init__.py @@ -14,6 +14,12 @@ from .captcha import ( ) from .describe_er import SchemaGenerator from .image_processing import transcode_image +from .embeddings import ( + get_ollama_client, + generate_embedding, + generate_book_embedding, + generate_search_embedding, +) __all__ = [ "limiter", @@ -30,4 +36,8 @@ __all__ = [ "prng", "SchemaGenerator", "transcode_image", + "get_ollama_client", + "generate_embedding", + "generate_book_embedding", + "generate_search_embedding", ] diff --git a/library_service/services/embeddings.py b/library_service/services/embeddings.py new file mode 100644 index 0000000..3f0a694 --- /dev/null +++ b/library_service/services/embeddings.py @@ -0,0 +1,93 @@ +"""Модуль работы с векторными эмбеддингами""" +from typing import List, Optional + +from ollama import Client + +from library_service.settings import OLLAMA_URL, EMBEDDINGS_MODEL, get_logger + + +_client: Optional[Client] = None +logger = get_logger() + + +def get_ollama_client() -> Client: + """Возвращает singleton клиент Ollama""" + global _client + if _client is None: + _client = Client(host=OLLAMA_URL) + return _client + + +def generate_embedding(text: str) -> List[float]: + """Генерирует эмбеддинг для текста.""" + client = get_ollama_client() + response = client.embeddings(model=EMBEDDINGS_MODEL, prompt=text) + return response["embedding"] + + +def generate_book_embedding(title: str, description: str) -> List[float]: + """Генерирует эмбеддинг для книги на основе названия и описания.""" + full_text = f"Название книги: {title}. Описание: {description}" + return generate_embedding(full_text) + + +def generate_search_embedding(query: str) -> List[float]: + """Генерирует эмбеддинг для поискового запроса.""" + search_prompt = f"Represent this sentence for searching relevant passages: {query}" + return generate_embedding(search_prompt) + + +def regenerate_embeddings(force: bool = False) -> int: + """Генерирует эмбеддинги для книг в БД.""" + from sqlmodel import Session, select + from library_service.settings import engine + from library_service.models.db import Book + + with Session(engine) as session: + statement = select(Book) + + if not force: + statement = statement.where(Book.embedding == None) # noqa: E711 + + books = session.exec(statement).all() + + if not books: + logger.info("[=] No books to process") + return 0 + + logger.info(f"[+] Generating embeddings for {len(books)} books...") + processed = 0 + + for book in books: + try: + book.embedding = generate_book_embedding( + book.title, + book.description or "" + ) + session.add(book) + logger.debug(f" [+] Book {book.id}: {book.title[:50]}") + processed += 1 + except Exception as e: + logger.warning(f" [-] Book {book.id}: {e}") + + session.commit() + logger.info(f"[+] Embedding generation complete: {processed}/{len(books)}") + return processed + + +def ensure_embeddings(force: bool, skip: bool) -> None: + """Проверяет и генерирует отсутствующие эмбеддинги""" + + if skip: + logger.info("[=] Embeddings generation skipped") + return + + logger.info("[+] Checking embeddings...") + try: + count = regenerate_embeddings(force=force) + if count > 0: + logger.info(f"[+] Generated {count} embeddings") + else: + logger.info("[+] All embeddings up to date") + except Exception as e: + logger.error(f"[-] Embeddings generation failed: {e}") diff --git a/library_service/settings.py b/library_service/settings.py index 065f3cf..0fbcde9 100644 --- a/library_service/settings.py +++ b/library_service/settings.py @@ -100,6 +100,9 @@ PASSWORD = os.getenv("POSTGRES_PASSWORD") DATABASE = os.getenv("POSTGRES_DB") OLLAMA_URL = os.getenv("OLLAMA_URL") +EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "bge-m3") +REGENERATE_EMBEDDINGS_FORCE = os.getenv("REGENERATE_EMBEDDINGS", "").lower() in ("1", "true", "yes") +SKIP_REGENERATE_EMBEDDINGS = os.getenv("SKIP_EMBEDDINGS", "").lower() in ("1", "true", "yes") ASSISTANT_LLM = "" logger = get_logger()