mirror of
https://github.com/wowlikon/LiB.git
synced 2026-03-21 23:53:38 +00:00
улучшение векторного поиска, добавление перегенерации векторов
This commit is contained in:
@@ -20,6 +20,7 @@ from .core import (
|
|||||||
)
|
)
|
||||||
from library_service.settings import get_logger
|
from library_service.settings import get_logger
|
||||||
|
|
||||||
|
# Получение логгера
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Основной модуль"""
|
"""Основной модуль"""
|
||||||
|
from library_service.services.embeddings import ensure_embeddings
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
import asyncio, sys, traceback
|
import asyncio, sys, traceback
|
||||||
@@ -22,7 +23,7 @@ from library_service.settings import (
|
|||||||
get_app,
|
get_app,
|
||||||
get_logger,
|
get_logger,
|
||||||
OLLAMA_URL,
|
OLLAMA_URL,
|
||||||
ASSISTANT_LLM,
|
ASSISTANT_LLM, EMBEDDINGS_MODEL, REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -53,7 +54,7 @@ async def lifespan(_):
|
|||||||
logger.info("[+] Loading ollama models...")
|
logger.info("[+] Loading ollama models...")
|
||||||
try:
|
try:
|
||||||
ollama_client = Client(host=OLLAMA_URL)
|
ollama_client = Client(host=OLLAMA_URL)
|
||||||
ollama_client.pull("mxbai-embed-large")
|
ollama_client.pull(EMBEDDINGS_MODEL)
|
||||||
|
|
||||||
if ASSISTANT_LLM:
|
if ASSISTANT_LLM:
|
||||||
ollama_client.pull(ASSISTANT_LLM)
|
ollama_client.pull(ASSISTANT_LLM)
|
||||||
@@ -63,6 +64,8 @@ async def lifespan(_):
|
|||||||
except ResponseError as e:
|
except ResponseError as e:
|
||||||
logger.error(f"[-] Failed to pull models {e}")
|
logger.error(f"[-] Failed to pull models {e}")
|
||||||
|
|
||||||
|
ensure_embeddings(REGENERATE_EMBEDDINGS_FORCE, SKIP_REGENERATE_EMBEDDINGS)
|
||||||
|
|
||||||
asyncio.create_task(cleanup_task())
|
asyncio.create_task(cleanup_task())
|
||||||
logger.info("[+] Starting application...")
|
logger.info("[+] Starting application...")
|
||||||
yield # Обработка запросов
|
yield # Обработка запросов
|
||||||
|
|||||||
@@ -7,15 +7,13 @@ from datetime import datetime, timezone
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status, UploadFile, File
|
from fastapi import APIRouter, Depends, HTTPException, Path, Query, status, UploadFile, File
|
||||||
from ollama import Client
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sqlalchemy import text, case, distinct
|
from sqlalchemy import text, case, distinct
|
||||||
from sqlalchemy.orm import selectinload, defer
|
from sqlalchemy.orm import selectinload, defer
|
||||||
from sqlmodel import Session, select, col, func
|
from sqlmodel import Session, select, col, func
|
||||||
|
|
||||||
from library_service.auth import RequireStaff, OptionalAuth
|
from library_service.auth import RequireStaff, OptionalAuth
|
||||||
from library_service.services import transcode_image
|
from library_service.settings import get_session, BOOKS_PREVIEW_DIR
|
||||||
from library_service.settings import get_session, OLLAMA_URL, BOOKS_PREVIEW_DIR
|
|
||||||
from library_service.models.enums import BookStatus
|
from library_service.models.enums import BookStatus
|
||||||
from library_service.models.db import (
|
from library_service.models.db import (
|
||||||
Author,
|
Author,
|
||||||
@@ -37,10 +35,14 @@ from library_service.models.dto.misc import (
|
|||||||
BookWithAuthorsAndGenres,
|
BookWithAuthorsAndGenres,
|
||||||
BookFilteredList,
|
BookFilteredList,
|
||||||
)
|
)
|
||||||
|
from library_service.services import (
|
||||||
|
transcode_image,
|
||||||
|
generate_book_embedding,
|
||||||
|
generate_search_embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/books", tags=["books"])
|
router = APIRouter(prefix="/books", tags=["books"])
|
||||||
ollama_client = Client(host=OLLAMA_URL)
|
|
||||||
|
|
||||||
|
|
||||||
def close_active_loan(session: Session, book_id: int) -> None:
|
def close_active_loan(session: Session, book_id: int) -> None:
|
||||||
@@ -102,7 +104,7 @@ def filter_books(
|
|||||||
|
|
||||||
if q:
|
if q:
|
||||||
if current_user:
|
if current_user:
|
||||||
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=q)["embedding"]
|
emb = generate_search_embedding(q)
|
||||||
distance_col = Book.embedding.cosine_distance(emb) # ty: ignore
|
distance_col = Book.embedding.cosine_distance(emb) # ty: ignore
|
||||||
statement = statement.where(Book.embedding.is_not(None)) # ty: ignore
|
statement = statement.where(Book.embedding.is_not(None)) # ty: ignore
|
||||||
|
|
||||||
@@ -133,9 +135,8 @@ def create_book(
|
|||||||
session: Session = Depends(get_session),
|
session: Session = Depends(get_session),
|
||||||
):
|
):
|
||||||
"""Создает новую книгу в системе"""
|
"""Создает новую книгу в системе"""
|
||||||
full_text = book.title + " " + book.description
|
emb = generate_book_embedding(book.title, book.description)
|
||||||
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text)
|
db_book = Book(**book.model_dump(), embedding=emb)
|
||||||
db_book = Book(**book.model_dump(), embedding=emb["embedding"])
|
|
||||||
|
|
||||||
session.add(db_book)
|
session.add(db_book)
|
||||||
session.commit()
|
session.commit()
|
||||||
@@ -263,13 +264,10 @@ def update_book(
|
|||||||
if book_update.description is not None:
|
if book_update.description is not None:
|
||||||
db_book.description = book_update.description
|
db_book.description = book_update.description
|
||||||
|
|
||||||
full_text = (
|
db_book.embedding = generate_book_embedding(
|
||||||
(book_update.title or db_book.title)
|
book_update.title or db_book.title,
|
||||||
+ " "
|
book_update.description or db_book.description,
|
||||||
+ (book_update.description or db_book.description)
|
|
||||||
)
|
)
|
||||||
emb = ollama_client.embeddings(model="mxbai-embed-large", prompt=full_text)
|
|
||||||
db_book.embedding = emb["embedding"]
|
|
||||||
|
|
||||||
if book_update.page_count is not None:
|
if book_update.page_count is not None:
|
||||||
db_book.page_count = book_update.page_count
|
db_book.page_count = book_update.page_count
|
||||||
|
|||||||
@@ -14,6 +14,12 @@ from .captcha import (
|
|||||||
)
|
)
|
||||||
from .describe_er import SchemaGenerator
|
from .describe_er import SchemaGenerator
|
||||||
from .image_processing import transcode_image
|
from .image_processing import transcode_image
|
||||||
|
from .embeddings import (
|
||||||
|
get_ollama_client,
|
||||||
|
generate_embedding,
|
||||||
|
generate_book_embedding,
|
||||||
|
generate_search_embedding,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"limiter",
|
"limiter",
|
||||||
@@ -30,4 +36,8 @@ __all__ = [
|
|||||||
"prng",
|
"prng",
|
||||||
"SchemaGenerator",
|
"SchemaGenerator",
|
||||||
"transcode_image",
|
"transcode_image",
|
||||||
|
"get_ollama_client",
|
||||||
|
"generate_embedding",
|
||||||
|
"generate_book_embedding",
|
||||||
|
"generate_search_embedding",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,93 @@
|
|||||||
|
"""Модуль работы с векторными эмбеддингами"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from ollama import Client
|
||||||
|
|
||||||
|
from library_service.settings import OLLAMA_URL, EMBEDDINGS_MODEL, get_logger
|
||||||
|
|
||||||
|
|
||||||
|
# Cached Ollama client; stays None until get_ollama_client() creates it.
_client: Optional[Client] = None
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_ollama_client() -> Client:
    """Return the module-wide Ollama client, creating it on first use.

    The client is built once with ``OLLAMA_URL`` as its host and cached in the
    module-level ``_client`` variable; later calls return that same object.
    """
    global _client
    if _client is not None:
        return _client
    _client = Client(host=OLLAMA_URL)
    return _client
|
||||||
|
|
||||||
|
|
||||||
|
def generate_embedding(text: str) -> List[float]:
    """Return the embedding vector for ``text``.

    Sends ``text`` as the prompt to the configured ``EMBEDDINGS_MODEL`` through
    the shared Ollama client and returns the ``"embedding"`` entry of the reply.
    """
    return get_ollama_client().embeddings(
        model=EMBEDDINGS_MODEL,
        prompt=text,
    )["embedding"]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_book_embedding(title: str, description: Optional[str]) -> List[float]:
    """Return the embedding vector for a book built from its title and description.

    Args:
        title: Book title.
        description: Book description. ``None`` is treated as an empty string
            so the literal text "None" never ends up in the embedded prompt.

    Returns:
        The embedding produced by ``generate_embedding`` for the combined text.
    """
    # An f-string renders None as "None"; normalise first so a missing
    # description yields exactly the same prompt as an empty one.
    description_text = description or ""
    full_text = f"Название книги: {title}. Описание: {description_text}"
    return generate_embedding(full_text)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_search_embedding(query: str) -> List[float]:
    """Return the embedding vector for a search query.

    The query is prefixed with a fixed retrieval instruction before it is
    embedded with ``generate_embedding``.
    """
    return generate_embedding(
        f"Represent this sentence for searching relevant passages: {query}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def regenerate_embeddings(force: bool = False) -> int:
    """Generate and store embeddings for books in the database.

    Args:
        force: When true, every book is re-embedded; otherwise only books
            whose ``embedding`` column is NULL are processed.

    Returns:
        Number of books whose embedding was generated successfully. This is 0
        both when no book matched and when every matched book failed.
    """
    # Imported inside the function, as in the original code
    # (presumably to avoid an import cycle with settings/models — TODO confirm).
    from sqlmodel import Session, select
    from library_service.settings import engine
    from library_service.models.db import Book

    with Session(engine) as session:
        statement = select(Book)

        if not force:
            # Explicit IS NULL operator instead of `== None`, so no lint
            # suppression is needed.
            statement = statement.where(Book.embedding.is_(None))  # ty: ignore

        books = session.exec(statement).all()

        if not books:
            logger.info("[=] No books to process")
            return 0

        logger.info(f"[+] Generating embeddings for {len(books)} books...")
        processed = 0

        for book in books:
            # Best effort: a failure for one book is logged and the loop
            # continues with the next one.
            try:
                book.embedding = generate_book_embedding(
                    book.title,
                    book.description or "",
                )
                session.add(book)
                logger.debug(f" [+] Book {book.id}: {book.title[:50]}")
                processed += 1
            except Exception as e:
                logger.warning(f" [-] Book {book.id}: {e}")

        # Single commit for the whole run.
        session.commit()
        logger.info(f"[+] Embedding generation complete: {processed}/{len(books)}")
        return processed
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_embeddings(force: bool, skip: bool) -> None:
    """Run embedding generation unless it is explicitly skipped.

    Args:
        force: Passed through to ``regenerate_embeddings`` as ``force``.
        skip: When true, only a log line is written and nothing is generated.

    Any exception raised by the generation step is logged as an error and is
    not re-raised.
    """
    if skip:
        logger.info("[=] Embeddings generation skipped")
        return

    logger.info("[+] Checking embeddings...")
    try:
        generated = regenerate_embeddings(force=force)
        # NOTE(review): a zero count is reported as "up to date", but
        # regenerate_embeddings also returns 0 when every book failed.
        summary = (
            f"[+] Generated {generated} embeddings"
            if generated > 0
            else "[+] All embeddings up to date"
        )
        logger.info(summary)
    except Exception as e:
        logger.error(f"[-] Embeddings generation failed: {e}")
|
||||||
@@ -100,6 +100,9 @@ PASSWORD = os.getenv("POSTGRES_PASSWORD")
|
|||||||
DATABASE = os.getenv("POSTGRES_DB")
|
DATABASE = os.getenv("POSTGRES_DB")
|
||||||
|
|
||||||
OLLAMA_URL = os.getenv("OLLAMA_URL")
|
OLLAMA_URL = os.getenv("OLLAMA_URL")
|
||||||
|
# Name of the model used for embeddings; falls back to "bge-m3" when the env var is unset.
EMBEDDINGS_MODEL = os.getenv("EMBEDDINGS_MODEL", "bge-m3")
|
||||||
|
# True when the REGENERATE_EMBEDDINGS env var is "1", "true" or "yes" (case-insensitive).
REGENERATE_EMBEDDINGS_FORCE = os.getenv("REGENERATE_EMBEDDINGS", "").lower() in ("1", "true", "yes")
|
||||||
|
# True when the SKIP_EMBEDDINGS env var is "1", "true" or "yes" (case-insensitive).
SKIP_REGENERATE_EMBEDDINGS = os.getenv("SKIP_EMBEDDINGS", "").lower() in ("1", "true", "yes")
|
||||||
|
|
||||||
ASSISTANT_LLM = ""
|
ASSISTANT_LLM = ""
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|||||||
Reference in New Issue
Block a user