улучшение векторного поиска, добавление перегенерации векторов

This commit is contained in:
2026-02-17 00:06:11 +03:00
parent 213d2bcb5a
commit e8e3310afa
6 changed files with 124 additions and 16 deletions
+1
View File
@@ -20,6 +20,7 @@ from .core import (
)
from library_service.settings import get_logger
# Получение логгера
logger = get_logger()
+5 -2
View File
@@ -1,4 +1,5 @@
"""Основной модуль"""
from library_service.services.embeddings import ensure_embeddings
from starlette.middleware.base import BaseHTTPMiddleware
import asyncio, sys, traceback
@@ -22,7 +23,7 @@ from library_service.settings import (
get_app,
get_logger,
OLLAMA_URL,
ASSISTANT_LLM,
ASSISTANT_LLM, EMBEDDINGS_MODEL, REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS,
)
@@ -53,7 +54,7 @@ async def lifespan(_):
logger.info("[+] Loading ollama models...")
try:
ollama_client = Client(host=OLLAMA_URL)
ollama_client.pull("mxbai-embed-large")
ollama_client.pull(EMBEDDINGS_MODEL)
if ASSISTANT_LLM:
ollama_client.pull(ASSISTANT_LLM)
@@ -63,6 +64,8 @@ async def lifespan(_):
except ResponseError as e:
logger.error(f"[-] Failed to pull models {e}")
ensure_embeddings(REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS)
asyncio.create_task(cleanup_task())
logger.info("[+] Starting application...")
yield # Обработка запросов
+12 -14
View File
@@ -7,15 +7,13 @@ from datetime import datetime, timezone
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status, UploadFile, File
from ollama import Client
from pydantic import Field
from sqlalchemy import text, case, distinct
from sqlalchemy.orm import selectinload, defer
from sqlmodel import Session, select, col, func
from library_service.auth import RequireStaff, OptionalAuth
from library_service.services import transcode_image
from library_service.settings import get_session, OLLAMA_URL, BOOKS_PREVIEW_DIR
from library_service.settings import get_session, BOOKS_PREVIEW_DIR
from library_service.models.enums import BookStatus
from library_service.models.db import (
Author,
@@ -37,10 +35,14 @@ from library_service.models.dto.misc import (
BookWithAuthorsAndGenres,
BookFilteredList,
)
from library_service.services import (
transcode_image,
generate_book_embedding,
generate_search_embedding
)
router = APIRouter(prefix="/books", tags=["books"])
ollama_client = Client(host=OLLAMA_URL)
def close_active_loan(session: Session, book_id: int) -> None:
@@ -102,7 +104,7 @@ def filter_books(
if q:
if current_user:
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=q)["embedding"]
emb = generate_search_embedding(q)
distance_col = Book.embedding.cosine_distance(emb) # ty: ignore
statement = statement.where(Book.embedding.is_not(None)) # ty: ignore
@@ -133,9 +135,8 @@ def create_book(
session: Session = Depends(get_session),
):
"""Создает новую книгу в системе"""
full_text = book.title + " " + book.description
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text)
db_book = Book(**book.model_dump(), embedding=emb["embedding"])
emb = generate_book_embedding(book.title, book.description)
db_book = Book(**book.model_dump(), embedding=emb)
session.add(db_book)
session.commit()
@@ -263,13 +264,10 @@ def update_book(
if book_update.description is not None:
db_book.description = book_update.description
full_text = (
(book_update.title or db_book.title)
+ " "
+ (book_update.description or db_book.description)
db_book.embedding = generate_book_embedding(
book_update.title or db_book.title,
book_update.description or db_book.description,
)
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text)
db_book.embedding = emb["embedding"]
if book_update.page_count is not None:
db_book.page_count = book_update.page_count
+10
View File
@@ -14,6 +14,12 @@ from .captcha import (
)
from .describe_er import SchemaGenerator
from .image_processing import transcode_image
from .embeddings import (
get_ollama_client,
generate_embedding,
generate_book_embedding,
generate_search_embedding,
)
__all__ = [
"limiter",
@@ -30,4 +36,8 @@ __all__ = [
"prng",
"SchemaGenerator",
"transcode_image",
"get_ollama_client",
"generate_embedding",
"generate_book_embedding",
"generate_search_embedding",
]
+93
View File
@@ -0,0 +1,93 @@
"""Модуль работы с векторными эмбеддингами"""
from typing import List, Optional
from ollama import Client
from library_service.settings import OLLAMA_URL, EMBEDDINGS_MODEL, get_logger
_client: Optional[Client] = None
logger = get_logger()
def get_ollama_client() -> Client:
    """Return the module-wide Ollama client, creating it on first use.

    The client is stored in the module-level ``_client`` so every caller
    shares one connection object (lazy singleton).
    """
    global _client
    if _client is not None:
        return _client
    _client = Client(host=OLLAMA_URL)
    return _client
def generate_embedding(text: str) -> List[float]:
    """Return the embedding vector for *text* using the configured model.

    Delegates to the shared Ollama client; the raw response dict is
    unwrapped to its ``"embedding"`` list.
    """
    return get_ollama_client().embeddings(model=EMBEDDINGS_MODEL, prompt=text)["embedding"]
def generate_book_embedding(title: str, description: str) -> List[float]:
    """Embed a book from its title and description.

    Both fields are folded into a single labelled prompt so the model sees
    them as one document.
    """
    return generate_embedding(f"Название книги: {title}. Описание: {description}")
def generate_search_embedding(query: str) -> List[float]:
    """Embed a user search query.

    The query is wrapped in the retrieval instruction prefix expected by
    instruction-tuned embedding models before being embedded.
    """
    return generate_embedding(
        f"Represent this sentence for searching relevant passages: {query}"
    )
def regenerate_embeddings(force: bool = False) -> int:
    """Generate embeddings for books stored in the database.

    With ``force=False`` only books whose ``embedding`` is still NULL are
    processed; with ``force=True`` every book is re-embedded.  Failures on
    individual books are logged and skipped.  Returns the number of books
    successfully embedded.
    """
    # Imported locally to avoid a circular import at module load time.
    from sqlmodel import Session, select
    from library_service.settings import engine
    from library_service.models.db import Book

    with Session(engine) as session:
        query = select(Book)
        if not force:
            # SQLAlchemy needs `== None` (not `is None`) to emit IS NULL.
            query = query.where(Book.embedding == None)  # noqa: E711
        pending = session.exec(query).all()

        if not pending:
            logger.info("[=] No books to process")
            return 0

        logger.info(f"[+] Generating embeddings for {len(pending)} books...")
        done = 0
        for record in pending:
            try:
                record.embedding = generate_book_embedding(
                    record.title,
                    record.description or ""
                )
                session.add(record)
                logger.debug(f" [+] Book {record.id}: {record.title[:50]}")
                done += 1
            except Exception as e:
                # Best-effort: one bad book must not abort the whole batch.
                logger.warning(f" [-] Book {record.id}: {e}")
        session.commit()
        logger.info(f"[+] Embedding generation complete: {done}/{len(pending)}")
        return done
def ensure_embeddings(force: bool, skip: bool) -> None:
    """Check the database and generate any missing embeddings on startup.

    ``skip`` short-circuits the whole step; ``force`` is forwarded to
    :func:`regenerate_embeddings` to re-embed every book.  Any failure is
    logged rather than propagated so startup is never blocked.
    """
    if skip:
        logger.info("[=] Embeddings generation skipped")
        return
    logger.info("[+] Checking embeddings...")
    try:
        count = regenerate_embeddings(force=force)
        message = (
            f"[+] Generated {count} embeddings"
            if count > 0
            else "[+] All embeddings up to date"
        )
        logger.info(message)
    except Exception as e:
        logger.error(f"[-] Embeddings generation failed: {e}")
+3
View File
@@ -100,6 +100,9 @@ PASSWORD = os.getenv("POSTGRES_PASSWORD")
DATABASE = os.getenv("POSTGRES_DB")
OLLAMA_URL = os.getenv("OLLAMA_URL")
EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "bge-m3")
REGENERATE_EMBEDDINGS_FORCE = os.getenv("REGENERATE_EMBEDDINGS", "").lower() in ("1", "true", "yes")
SKIP_REGENERATE_EMBEDDINGS = os.getenv("SKIP_EMBEDDINGS", "").lower() in ("1", "true", "yes")
ASSISTANT_LLM = ""
logger = get_logger()