mirror of
https://github.com/wowlikon/LiB.git
synced 2026-03-21 23:53:38 +00:00
улучшение векторного поиска, добавление перегенерации векторов
This commit is contained in:
@@ -20,6 +20,7 @@ from .core import (
|
||||
)
|
||||
from library_service.settings import get_logger
|
||||
|
||||
# Получение логгера
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Основной модуль"""
|
||||
from library_service.services.embeddings import ensure_embeddings
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
import asyncio, sys, traceback
|
||||
@@ -22,7 +23,7 @@ from library_service.settings import (
|
||||
get_app,
|
||||
get_logger,
|
||||
OLLAMA_URL,
|
||||
ASSISTANT_LLM,
|
||||
ASSISTANT_LLM, EMBEDDINGS_MODEL, REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS,
|
||||
)
|
||||
|
||||
|
||||
@@ -53,7 +54,7 @@ async def lifespan(_):
|
||||
logger.info("[+] Loading ollama models...")
|
||||
try:
|
||||
ollama_client = Client(host=OLLAMA_URL)
|
||||
ollama_client.pull("mxbai-embed-large")
|
||||
ollama_client.pull(EMBEDDINGS_MODEL)
|
||||
|
||||
if ASSISTANT_LLM:
|
||||
ollama_client.pull(ASSISTANT_LLM)
|
||||
@@ -63,6 +64,8 @@ async def lifespan(_):
|
||||
except ResponseError as e:
|
||||
logger.error(f"[-] Failed to pull models {e}")
|
||||
|
||||
ensure_embeddings(REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS)
|
||||
|
||||
asyncio.create_task(cleanup_task())
|
||||
logger.info("[+] Starting application...")
|
||||
yield # Обработка запросов
|
||||
|
||||
@@ -7,15 +7,13 @@ from datetime import datetime, timezone
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status, UploadFile, File
|
||||
from ollama import Client
|
||||
from pydantic import Field
|
||||
from sqlalchemy import text, case, distinct
|
||||
from sqlalchemy.orm import selectinload, defer
|
||||
from sqlmodel import Session, select, col, func
|
||||
|
||||
from library_service.auth import RequireStaff, OptionalAuth
|
||||
from library_service.services import transcode_image
|
||||
from library_service.settings import get_session, OLLAMA_URL, BOOKS_PREVIEW_DIR
|
||||
from library_service.settings import get_session, BOOKS_PREVIEW_DIR
|
||||
from library_service.models.enums import BookStatus
|
||||
from library_service.models.db import (
|
||||
Author,
|
||||
@@ -37,10 +35,14 @@ from library_service.models.dto.misc import (
|
||||
BookWithAuthorsAndGenres,
|
||||
BookFilteredList,
|
||||
)
|
||||
from library_service.services import (
|
||||
transcode_image,
|
||||
generate_book_embedding,
|
||||
generate_search_embedding
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/books", tags=["books"])
|
||||
ollama_client = Client(host=OLLAMA_URL)
|
||||
|
||||
|
||||
def close_active_loan(session: Session, book_id: int) -> None:
|
||||
@@ -102,7 +104,7 @@ def filter_books(
|
||||
|
||||
if q:
|
||||
if current_user:
|
||||
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=q)["embedding"]
|
||||
emb = generate_search_embedding(q)
|
||||
distance_col = Book.embedding.cosine_distance(emb) # ty: ignore
|
||||
statement = statement.where(Book.embedding.is_not(None)) # ty: ignore
|
||||
|
||||
@@ -133,9 +135,8 @@ def create_book(
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
"""Создает новую книгу в системе"""
|
||||
full_text = book.title + " " + book.description
|
||||
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text)
|
||||
db_book = Book(**book.model_dump(), embedding=emb["embedding"])
|
||||
emb = generate_book_embedding(book.title, book.description)
|
||||
db_book = Book(**book.model_dump(), embedding=emb)
|
||||
|
||||
session.add(db_book)
|
||||
session.commit()
|
||||
@@ -263,13 +264,10 @@ def update_book(
|
||||
if book_update.description is not None:
|
||||
db_book.description = book_update.description
|
||||
|
||||
full_text = (
|
||||
(book_update.title or db_book.title)
|
||||
+ " "
|
||||
+ (book_update.description or db_book.description)
|
||||
db_book.embedding = generate_book_embedding(
|
||||
book_update.title or db_book.title,
|
||||
book_update.description or db_book.description,
|
||||
)
|
||||
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text)
|
||||
db_book.embedding = emb["embedding"]
|
||||
|
||||
if book_update.page_count is not None:
|
||||
db_book.page_count = book_update.page_count
|
||||
|
||||
@@ -14,6 +14,12 @@ from .captcha import (
|
||||
)
|
||||
from .describe_er import SchemaGenerator
|
||||
from .image_processing import transcode_image
|
||||
from .embeddings import (
|
||||
get_ollama_client,
|
||||
generate_embedding,
|
||||
generate_book_embedding,
|
||||
generate_search_embedding,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"limiter",
|
||||
@@ -30,4 +36,8 @@ __all__ = [
|
||||
"prng",
|
||||
"SchemaGenerator",
|
||||
"transcode_image",
|
||||
"get_ollama_client",
|
||||
"generate_embedding",
|
||||
"generate_book_embedding",
|
||||
"generate_search_embedding",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Модуль работы с векторными эмбеддингами"""
|
||||
from typing import List, Optional
|
||||
|
||||
from ollama import Client
|
||||
|
||||
from library_service.settings import OLLAMA_URL, EMBEDDINGS_MODEL, get_logger
|
||||
|
||||
|
||||
_client: Optional[Client] = None
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_ollama_client() -> Client:
    """Return the process-wide Ollama client, creating it on first call."""
    global _client
    if _client is not None:
        return _client
    # Lazy construction: the client is only built when something actually
    # needs to talk to Ollama, not at import time.
    _client = Client(host=OLLAMA_URL)
    return _client
|
||||
|
||||
|
||||
def generate_embedding(text: str) -> List[float]:
    """Compute and return the embedding vector for *text* via the Ollama API."""
    return get_ollama_client().embeddings(model=EMBEDDINGS_MODEL, prompt=text)["embedding"]
|
||||
|
||||
|
||||
def generate_book_embedding(title: str, description: str) -> List[float]:
    """Embed a book by folding its title and description into one prompt."""
    # The prompt wording is part of the stored-vector format: changing it
    # would make new vectors incomparable with existing ones.
    return generate_embedding(f"Название книги: {title}. Описание: {description}")
|
||||
|
||||
|
||||
def generate_search_embedding(query: str) -> List[float]:
    """Embed a search query using the retrieval-style instruction prompt."""
    # Instruction prefix recommended for asymmetric (query vs. passage)
    # retrieval with mxbai/bge-style embedding models.
    prompt = f"Represent this sentence for searching relevant passages: {query}"
    return generate_embedding(prompt)
|
||||
|
||||
|
||||
def regenerate_embeddings(force: bool = False) -> int:
    """Generate embeddings for books stored in the database.

    When *force* is False, only books with a missing embedding are
    processed; when True, every book is re-embedded.  Failures on
    individual books are logged and skipped.  Returns the number of
    books successfully processed.
    """
    # Imported lazily to avoid import cycles at module load time.
    from sqlmodel import Session, select
    from library_service.settings import engine
    from library_service.models.db import Book

    with Session(engine) as session:
        query = select(Book)
        if not force:
            # SQLAlchemy needs `== None` (not `is None`) to render IS NULL.
            query = query.where(Book.embedding == None)  # noqa: E711

        books = session.exec(query).all()
        if not books:
            logger.info("[=] No books to process")
            return 0

        logger.info(f"[+] Generating embeddings for {len(books)} books...")
        processed = 0
        for book in books:
            try:
                book.embedding = generate_book_embedding(
                    book.title,
                    book.description or ""
                )
                session.add(book)
                logger.debug(f" [+] Book {book.id}: {book.title[:50]}")
                processed += 1
            except Exception as e:
                # Best-effort: one bad book must not abort the whole run.
                logger.warning(f" [-] Book {book.id}: {e}")

        # Single commit so the run is persisted atomically at the end.
        session.commit()
        logger.info(f"[+] Embedding generation complete: {processed}/{len(books)}")
        return processed
|
||||
|
||||
|
||||
def ensure_embeddings(force: bool, skip: bool) -> None:
    """Check for missing embeddings and generate them, unless skipped.

    *force* re-embeds every book; *skip* bypasses generation entirely.
    Any failure is logged rather than raised, so application startup
    is never blocked by embedding problems.
    """
    if skip:
        logger.info("[=] Embeddings generation skipped")
        return

    logger.info("[+] Checking embeddings...")
    try:
        count = regenerate_embeddings(force=force)
        message = f"[+] Generated {count} embeddings" if count > 0 else "[+] All embeddings up to date"
        logger.info(message)
    except Exception as e:
        logger.error(f"[-] Embeddings generation failed: {e}")
|
||||
@@ -100,6 +100,9 @@ PASSWORD = os.getenv("POSTGRES_PASSWORD")
|
||||
DATABASE = os.getenv("POSTGRES_DB")
|
||||
|
||||
OLLAMA_URL = os.getenv("OLLAMA_URL")
|
||||
EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "bge-m3")
|
||||
REGENERATE_EMBEDDINGS_FORCE = os.getenv("REGENERATE_EMBEDDINGS", "").lower() in ("1", "true", "yes")
|
||||
SKIP_REGENERATE_EMBEDDINGS = os.getenv("SKIP_EMBEDDINGS", "").lower() in ("1", "true", "yes")
|
||||
|
||||
ASSISTANT_LLM = ""
|
||||
logger = get_logger()
|
||||
|
||||
Reference in New Issue
Block a user