mirror of
https://github.com/wowlikon/LiB.git
synced 2026-02-04 12:31:09 +00:00
284 lines
9.3 KiB
Python
284 lines
9.3 KiB
Python
"""Модуль генерации описания схемы БД"""
|
|
|
|
import enum
|
|
import inspect
|
|
from typing import (
|
|
List,
|
|
Dict,
|
|
Any,
|
|
Set,
|
|
Type,
|
|
Tuple,
|
|
Optional,
|
|
Union,
|
|
get_origin,
|
|
get_args,
|
|
)
|
|
|
|
from sqlalchemy import Enum as SAEnum
|
|
from sqlalchemy.inspection import inspect as sa_inspect
|
|
from sqlmodel import SQLModel
|
|
|
|
|
|
class SchemaGenerator:
|
|
"""Сервис генерации json описания схемы БД"""
|
|
|
|
def __init__(self, db_module, dto_module=None):
|
|
self.db_models = self._get_classes(db_module, is_table=True)
|
|
self.dto_models = (
|
|
self._get_classes(dto_module, is_table=False) if dto_module else []
|
|
)
|
|
self.link_table_names = self._identify_link_tables()
|
|
self.field_descriptions = self._collect_all_descriptions()
|
|
self._table_to_model = {m.__tablename__: m for m in self.db_models}
|
|
|
|
def _get_classes(
|
|
self, module, is_table: bool | None = None
|
|
) -> List[Type[SQLModel]]:
|
|
if module is None:
|
|
return []
|
|
|
|
classes = []
|
|
for name, obj in inspect.getmembers(module):
|
|
if (
|
|
inspect.isclass(obj)
|
|
and issubclass(obj, SQLModel)
|
|
and obj is not SQLModel
|
|
):
|
|
if is_table is True and hasattr(obj, "__table__"):
|
|
classes.append(obj)
|
|
elif is_table is False and not hasattr(obj, "__table__"):
|
|
classes.append(obj)
|
|
return classes
|
|
|
|
def _normalize_model_name(self, name: str) -> str:
|
|
suffixes = [
|
|
"Create",
|
|
"Read",
|
|
"Update",
|
|
"DTO",
|
|
"Base",
|
|
"List",
|
|
"Detail",
|
|
"Response",
|
|
"Request",
|
|
]
|
|
result = name
|
|
for suffix in suffixes:
|
|
if result.endswith(suffix) and len(result) > len(suffix):
|
|
result = result[: -len(suffix)]
|
|
return result
|
|
|
|
def _get_field_descriptions_from_class(self, cls: Type) -> Dict[str, str]:
|
|
descriptions = {}
|
|
|
|
for parent in cls.__mro__:
|
|
if parent is SQLModel or parent is object:
|
|
continue
|
|
|
|
fields = getattr(parent, "model_fields", {})
|
|
for field_name, field_info in fields.items():
|
|
if field_name in descriptions:
|
|
continue
|
|
|
|
desc = getattr(field_info, "description", None) or getattr(
|
|
field_info, "title", None
|
|
)
|
|
if desc:
|
|
descriptions[field_name] = desc
|
|
|
|
return descriptions
|
|
|
|
def _collect_all_descriptions(self) -> Dict[str, Dict[str, str]]:
|
|
result = {}
|
|
|
|
dto_map = {}
|
|
for dto in self.dto_models:
|
|
base_name = self._normalize_model_name(dto.__name__)
|
|
if base_name not in dto_map:
|
|
dto_map[base_name] = {}
|
|
|
|
for field, desc in self._get_field_descriptions_from_class(dto).items():
|
|
if field not in dto_map[base_name]:
|
|
dto_map[base_name][field] = desc
|
|
|
|
for model in self.db_models:
|
|
model_name = model.__name__
|
|
result[model_name] = {
|
|
**dto_map.get(model_name, {}),
|
|
**self._get_field_descriptions_from_class(model),
|
|
}
|
|
|
|
return result
|
|
|
|
def _identify_link_tables(self) -> Set[str]:
|
|
link_tables = set()
|
|
for model in self.db_models:
|
|
try:
|
|
for rel in sa_inspect(model).relationships:
|
|
if rel.secondary is not None:
|
|
link_tables.add(rel.secondary.name)
|
|
except Exception:
|
|
continue
|
|
return link_tables
|
|
|
|
def _collect_fk_relations(self) -> List[Dict[str, Any]]:
|
|
relations = []
|
|
processed: Set[Tuple[str, str, str, str]] = set()
|
|
|
|
for model in self.db_models:
|
|
if model.__tablename__ in self.link_table_names:
|
|
continue
|
|
|
|
for col in sa_inspect(model).columns:
|
|
for fk in col.foreign_keys:
|
|
target_table = fk.column.table.name
|
|
if target_table in self.link_table_names:
|
|
continue
|
|
|
|
target_model = self._table_to_model.get(target_table)
|
|
if not target_model:
|
|
continue
|
|
|
|
key = (
|
|
model.__name__,
|
|
col.name,
|
|
target_model.__name__,
|
|
fk.column.name,
|
|
)
|
|
|
|
if key not in processed:
|
|
relations.append(
|
|
{
|
|
"fromEntity": model.__name__,
|
|
"fromField": col.name,
|
|
"toEntity": target_model.__name__,
|
|
"toField": fk.column.name,
|
|
"fromMultiplicity": "N",
|
|
"toMultiplicity": "1",
|
|
}
|
|
)
|
|
processed.add(key)
|
|
return relations
|
|
|
|
def _collect_m2m_relations(self) -> List[Dict[str, Any]]:
|
|
relations = []
|
|
processed: Set[Tuple[str, str]] = set()
|
|
|
|
for model in self.db_models:
|
|
if model.__tablename__ in self.link_table_names:
|
|
continue
|
|
|
|
try:
|
|
for rel in sa_inspect(model).relationships:
|
|
if rel.direction.name != "MANYTOMANY":
|
|
continue
|
|
|
|
target_model = rel.mapper.class_
|
|
if target_model.__tablename__ in self.link_table_names:
|
|
continue
|
|
|
|
pair = tuple(sorted([model.__name__, target_model.__name__]))
|
|
if pair not in processed:
|
|
relations.append(
|
|
{
|
|
"fromEntity": pair[0],
|
|
"fromField": "id",
|
|
"toEntity": pair[1],
|
|
"toField": "id",
|
|
"fromMultiplicity": "N",
|
|
"toMultiplicity": "N",
|
|
}
|
|
)
|
|
processed.add(pair)
|
|
except Exception:
|
|
continue
|
|
|
|
return relations
|
|
|
|
def _extract_enum_from_annotation(self, annotation) -> Optional[Type[enum.Enum]]:
|
|
if isinstance(annotation, type) and issubclass(annotation, enum.Enum):
|
|
return annotation
|
|
|
|
origin = get_origin(annotation)
|
|
if origin is Union:
|
|
for arg in get_args(annotation):
|
|
if isinstance(arg, type) and issubclass(arg, enum.Enum):
|
|
return arg
|
|
|
|
return None
|
|
|
|
def _get_enum_values(self, model: Type[SQLModel], col) -> Optional[List[str]]:
|
|
if isinstance(col.type, SAEnum):
|
|
if col.type.enum_class is not None:
|
|
return [e.value for e in col.type.enum_class]
|
|
if col.type.enums:
|
|
return list(col.type.enums)
|
|
|
|
try:
|
|
annotations = {}
|
|
for cls in model.__mro__:
|
|
if hasattr(cls, "__annotations__"):
|
|
annotations.update(cls.__annotations__)
|
|
|
|
if col.name in annotations:
|
|
annotation = annotations[col.name]
|
|
enum_class = self._extract_enum_from_annotation(annotation)
|
|
if enum_class:
|
|
return [e.value for e in enum_class]
|
|
except Exception:
|
|
pass
|
|
|
|
return None
|
|
|
|
def generate(self) -> Dict[str, Any]:
|
|
entities = []
|
|
|
|
for model in self.db_models:
|
|
table_name = model.__tablename__
|
|
if table_name in self.link_table_names:
|
|
continue
|
|
|
|
columns = sorted(
|
|
sa_inspect(model).columns,
|
|
key=lambda c: (
|
|
0 if c.primary_key else (1 if c.foreign_keys else 2),
|
|
c.name,
|
|
),
|
|
)
|
|
|
|
entity_fields = []
|
|
descriptions = self.field_descriptions.get(model.__name__, {})
|
|
|
|
for col in columns:
|
|
label = col.name
|
|
if col.primary_key:
|
|
label += " (PK)"
|
|
if col.foreign_keys:
|
|
label += " (FK)"
|
|
|
|
field_obj = {"id": col.name, "label": label}
|
|
|
|
tooltip_parts = []
|
|
|
|
if col.name in descriptions:
|
|
tooltip_parts.append(descriptions[col.name])
|
|
|
|
enum_values = self._get_enum_values(model, col)
|
|
if enum_values:
|
|
tooltip_parts.append(
|
|
"Варианты:\n" + "\n".join(f"• {v}" for v in enum_values)
|
|
)
|
|
|
|
if tooltip_parts:
|
|
field_obj["tooltip"] = "\n\n".join(tooltip_parts)
|
|
|
|
entity_fields.append(field_obj)
|
|
|
|
entities.append(
|
|
{"id": model.__name__, "title": table_name, "fields": entity_fields}
|
|
)
|
|
|
|
relations = self._collect_fk_relations() + self._collect_m2m_relations()
|
|
return {"entities": entities, "relations": relations}
|