Files
LibraryAPI/library_service/services/describe_er.py

284 lines
9.3 KiB
Python

"""Модуль генерации описания схемы БД"""
import enum
import inspect
import types
from typing import (
    Annotated,
    Any,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    Type,
    Union,
    get_args,
    get_origin,
)

from sqlalchemy import Enum as SAEnum
from sqlalchemy.inspection import inspect as sa_inspect
from sqlmodel import SQLModel
class SchemaGenerator:
"""Сервис генерации json описания схемы БД"""
def __init__(self, db_module, dto_module=None):
self.db_models = self._get_classes(db_module, is_table=True)
self.dto_models = (
self._get_classes(dto_module, is_table=False) if dto_module else []
)
self.link_table_names = self._identify_link_tables()
self.field_descriptions = self._collect_all_descriptions()
self._table_to_model = {m.__tablename__: m for m in self.db_models}
def _get_classes(
self, module, is_table: bool | None = None
) -> List[Type[SQLModel]]:
if module is None:
return []
classes = []
for name, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and issubclass(obj, SQLModel)
and obj is not SQLModel
):
if is_table is True and hasattr(obj, "__table__"):
classes.append(obj)
elif is_table is False and not hasattr(obj, "__table__"):
classes.append(obj)
return classes
def _normalize_model_name(self, name: str) -> str:
suffixes = [
"Create",
"Read",
"Update",
"DTO",
"Base",
"List",
"Detail",
"Response",
"Request",
]
result = name
for suffix in suffixes:
if result.endswith(suffix) and len(result) > len(suffix):
result = result[: -len(suffix)]
return result
def _get_field_descriptions_from_class(self, cls: Type) -> Dict[str, str]:
descriptions = {}
for parent in cls.__mro__:
if parent is SQLModel or parent is object:
continue
fields = getattr(parent, "model_fields", {})
for field_name, field_info in fields.items():
if field_name in descriptions:
continue
desc = getattr(field_info, "description", None) or getattr(
field_info, "title", None
)
if desc:
descriptions[field_name] = desc
return descriptions
def _collect_all_descriptions(self) -> Dict[str, Dict[str, str]]:
result = {}
dto_map = {}
for dto in self.dto_models:
base_name = self._normalize_model_name(dto.__name__)
if base_name not in dto_map:
dto_map[base_name] = {}
for field, desc in self._get_field_descriptions_from_class(dto).items():
if field not in dto_map[base_name]:
dto_map[base_name][field] = desc
for model in self.db_models:
model_name = model.__name__
result[model_name] = {
**dto_map.get(model_name, {}),
**self._get_field_descriptions_from_class(model),
}
return result
def _identify_link_tables(self) -> Set[str]:
link_tables = set()
for model in self.db_models:
try:
for rel in sa_inspect(model).relationships:
if rel.secondary is not None:
link_tables.add(rel.secondary.name)
except Exception:
continue
return link_tables
def _collect_fk_relations(self) -> List[Dict[str, Any]]:
relations = []
processed: Set[Tuple[str, str, str, str]] = set()
for model in self.db_models:
if model.__tablename__ in self.link_table_names:
continue
for col in sa_inspect(model).columns:
for fk in col.foreign_keys:
target_table = fk.column.table.name
if target_table in self.link_table_names:
continue
target_model = self._table_to_model.get(target_table)
if not target_model:
continue
key = (
model.__name__,
col.name,
target_model.__name__,
fk.column.name,
)
if key not in processed:
relations.append(
{
"fromEntity": model.__name__,
"fromField": col.name,
"toEntity": target_model.__name__,
"toField": fk.column.name,
"fromMultiplicity": "N",
"toMultiplicity": "1",
}
)
processed.add(key)
return relations
def _collect_m2m_relations(self) -> List[Dict[str, Any]]:
relations = []
processed: Set[Tuple[str, str]] = set()
for model in self.db_models:
if model.__tablename__ in self.link_table_names:
continue
try:
for rel in sa_inspect(model).relationships:
if rel.direction.name != "MANYTOMANY":
continue
target_model = rel.mapper.class_
if target_model.__tablename__ in self.link_table_names:
continue
pair = tuple(sorted([model.__name__, target_model.__name__]))
if pair not in processed:
relations.append(
{
"fromEntity": pair[0],
"fromField": "id",
"toEntity": pair[1],
"toField": "id",
"fromMultiplicity": "N",
"toMultiplicity": "N",
}
)
processed.add(pair)
except Exception:
continue
return relations
def _extract_enum_from_annotation(self, annotation) -> Optional[Type[enum.Enum]]:
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
return annotation
origin = get_origin(annotation)
if origin is Union:
for arg in get_args(annotation):
if isinstance(arg, type) and issubclass(arg, enum.Enum):
return arg
return None
def _get_enum_values(self, model: Type[SQLModel], col) -> Optional[List[str]]:
if isinstance(col.type, SAEnum):
if col.type.enum_class is not None:
return [e.value for e in col.type.enum_class]
if col.type.enums:
return list(col.type.enums)
try:
annotations = {}
for cls in model.__mro__:
if hasattr(cls, "__annotations__"):
annotations.update(cls.__annotations__)
if col.name in annotations:
annotation = annotations[col.name]
enum_class = self._extract_enum_from_annotation(annotation)
if enum_class:
return [e.value for e in enum_class]
except Exception:
pass
return None
def generate(self) -> Dict[str, Any]:
entities = []
for model in self.db_models:
table_name = model.__tablename__
if table_name in self.link_table_names:
continue
columns = sorted(
sa_inspect(model).columns,
key=lambda c: (
0 if c.primary_key else (1 if c.foreign_keys else 2),
c.name,
),
)
entity_fields = []
descriptions = self.field_descriptions.get(model.__name__, {})
for col in columns:
label = col.name
if col.primary_key:
label += " (PK)"
if col.foreign_keys:
label += " (FK)"
field_obj = {"id": col.name, "label": label}
tooltip_parts = []
if col.name in descriptions:
tooltip_parts.append(descriptions[col.name])
enum_values = self._get_enum_values(model, col)
if enum_values:
tooltip_parts.append(
"Варианты:\n" + "\n".join(f"{v}" for v in enum_values)
)
if tooltip_parts:
field_obj["tooltip"] = "\n\n".join(tooltip_parts)
entity_fields.append(field_obj)
entities.append(
{"id": model.__name__, "title": table_name, "fields": entity_fields}
)
relations = self._collect_fk_relations() + self._collect_m2m_relations()
return {"entities": entities, "relations": relations}