# LibraryAPI/library_service/services/describe_er.py
# (226 lines, 7.6 KiB, Python)

"""Модуль генерации описания схемы БД"""
import inspect
import logging
from typing import Any, Dict, List, Set, Tuple, Type

from pydantic.fields import FieldInfo
from sqlalchemy.inspection import inspect as sa_inspect
from sqlmodel import SQLModel
class SchemaGenerator:
"""Сервис генерации json описания схемы БД"""
def __init__(self, db_module, dto_module=None):
self.db_models = self._get_classes(db_module, is_table=True)
self.dto_models = (
self._get_classes(dto_module, is_table=False) if dto_module else []
)
self.link_table_names = self._identify_link_tables()
self.field_descriptions = self._collect_all_descriptions()
self._table_to_model = {m.__tablename__: m for m in self.db_models}
def _get_classes(
self, module, is_table: bool | None = None
) -> List[Type[SQLModel]]:
if module is None:
return []
classes = []
for name, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and issubclass(obj, SQLModel)
and obj is not SQLModel
):
if is_table is True and hasattr(obj, "__table__"):
classes.append(obj)
elif is_table is False and not hasattr(obj, "__table__"):
classes.append(obj)
return classes
def _normalize_model_name(self, name: str) -> str:
suffixes = [
"Create",
"Read",
"Update",
"DTO",
"Base",
"List",
"Detail",
"Response",
"Request",
]
result = name
for suffix in suffixes:
if result.endswith(suffix) and len(result) > len(suffix):
result = result[: -len(suffix)]
return result
def _get_field_descriptions_from_class(self, cls: Type) -> Dict[str, str]:
descriptions = {}
for parent in cls.__mro__:
if parent is SQLModel or parent is object:
continue
fields = getattr(parent, "model_fields", {})
for field_name, field_info in fields.items():
if field_name in descriptions:
continue
desc = getattr(field_info, "description", None) or getattr(
field_info, "title", None
)
if desc:
descriptions[field_name] = desc
return descriptions
def _collect_all_descriptions(self) -> Dict[str, Dict[str, str]]:
result = {}
dto_map = {}
for dto in self.dto_models:
base_name = self._normalize_model_name(dto.__name__)
if base_name not in dto_map:
dto_map[base_name] = {}
for field, desc in self._get_field_descriptions_from_class(dto).items():
if field not in dto_map[base_name]:
dto_map[base_name][field] = desc
for model in self.db_models:
model_name = model.__name__
result[model_name] = {
**dto_map.get(model_name, {}),
**self._get_field_descriptions_from_class(model),
}
return result
def _identify_link_tables(self) -> Set[str]:
link_tables = set()
for model in self.db_models:
try:
for rel in sa_inspect(model).relationships:
if rel.secondary is not None:
link_tables.add(rel.secondary.name)
except Exception:
continue
return link_tables
def _collect_fk_relations(self) -> List[Dict[str, Any]]:
relations = []
processed: Set[Tuple[str, str, str, str]] = set()
for model in self.db_models:
if model.__tablename__ in self.link_table_names:
continue
for col in sa_inspect(model).columns:
for fk in col.foreign_keys:
target_table = fk.column.table.name
if target_table in self.link_table_names:
continue
target_model = self._table_to_model.get(target_table)
if not target_model:
continue
key = (
model.__name__,
col.name,
target_model.__name__,
fk.column.name,
)
if key not in processed:
relations.append(
{
"fromEntity": model.__name__,
"fromField": col.name,
"toEntity": target_model.__name__,
"toField": fk.column.name,
"fromMultiplicity": "N",
"toMultiplicity": "1",
}
)
processed.add(key)
return relations
def _collect_m2m_relations(self) -> List[Dict[str, Any]]:
relations = []
processed: Set[Tuple[str, str]] = set()
for model in self.db_models:
if model.__tablename__ in self.link_table_names:
continue
try:
for rel in sa_inspect(model).relationships:
if rel.direction.name != "MANYTOMANY":
continue
target_model = rel.mapper.class_
if target_model.__tablename__ in self.link_table_names:
continue
pair = tuple(sorted([model.__name__, target_model.__name__]))
if pair not in processed:
relations.append(
{
"fromEntity": pair[0],
"fromField": "id",
"toEntity": pair[1],
"toField": "id",
"fromMultiplicity": "N",
"toMultiplicity": "N",
}
)
processed.add(pair)
except Exception:
continue
return relations
def generate(self) -> Dict[str, Any]:
entities = []
for model in self.db_models:
table_name = model.__tablename__
if table_name in self.link_table_names:
continue
columns = sorted(
sa_inspect(model).columns,
key=lambda c: (
0 if c.primary_key else (1 if c.foreign_keys else 2),
c.name,
),
)
entity_fields = []
descriptions = self.field_descriptions.get(model.__name__, {})
for col in columns:
label = col.name
if col.primary_key:
label += " (PK)"
if col.foreign_keys:
label += " (FK)"
field_obj = {"id": col.name, "label": label}
if col.name in descriptions:
field_obj["tooltip"] = descriptions[col.name]
entity_fields.append(field_obj)
entities.append(
{"id": model.__name__, "title": table_name, "fields": entity_fields}
)
relations = self._collect_fk_relations() + self._collect_m2m_relations()
return {"entities": entities, "relations": relations}