SIT/database/repositories/message_repository.py

345 lines
10 KiB
Python
Raw Normal View History

2026-01-29 09:22:54 +00:00
"""
消息数据访问层
负责消息数据的CRUD操作
"""
from typing import List, Optional
from datetime import datetime
from database import db_manager
from database.repositories.field_repository import FieldRepository
from models.message import Message
from models.field import Field
from config import ProtocolType, SerializationType
from utils.logger import get_logger
# Module-level logger used by MessageRepository for info/error reporting.
logger = get_logger(__name__)
class MessageRepository:
    """Data-access object for the ``messages`` table.

    Besides its ``messages`` row, each message owns ordered link rows in
    ``message_fields`` (``message_id``, ``field_id``, ``field_order``) that
    point at rows of the ``fields`` table. Every public read method returns
    messages with ``fields`` already loaded in ``field_order``.
    """

    # JOIN required whenever a query filters on field names (alias ``f``).
    _FIELD_JOIN_SQL = (
        " INNER JOIN message_fields mf ON m.id = mf.message_id"
        " INNER JOIN fields f ON mf.field_id = f.id"
    )

    def __init__(self):
        # Reused to convert ``fields`` rows into Field objects.
        self.field_repo = FieldRepository()

    def create(self, message: Message) -> int:
        """
        Insert a message and link its fields.

        Args:
            message: Message to persist. Each entry of ``message.fields`` is
                linked by its ``id`` in list order, so the fields are
                presumably already stored in ``fields`` (verify with caller).

        Returns:
            ID of the newly inserted message.

        Raises:
            Exception: any database error is logged and re-raised. The insert
                and the field links run inside one ``db_manager.transaction()``.
        """
        sql = """
            INSERT INTO messages (
                full_name, system_name, message_type, version,
                description, protocol, serialization
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
        )
        try:
            with db_manager.transaction():
                db_manager.execute_update(sql, params)
                message_id = db_manager.get_last_insert_id()
                # Link fields in list order (field_order = list index).
                if message.fields:
                    self._associate_fields(message_id, message.fields)
            logger.info(f"Created message: {message.full_name} (ID: {message_id})")
            return message_id
        except Exception as e:
            logger.error(f"Failed to create message: {e}")
            raise

    def get_by_id(self, message_id: int) -> Optional[Message]:
        """
        Fetch a message by primary key.

        Args:
            message_id: Message ID.

        Returns:
            The message with its fields loaded, or None if no row matches.
        """
        sql = "SELECT * FROM messages WHERE id = ?"
        try:
            rows = db_manager.execute_query(sql, (message_id,))
            return self._hydrate(rows[0]) if rows else None
        except Exception as e:
            logger.error(f"Failed to get message by id {message_id}: {e}")
            raise

    def get_by_full_name(self, full_name: str) -> Optional[Message]:
        """
        Fetch a message by its exact ``full_name``.

        Args:
            full_name: Full message name (exact match).

        Returns:
            The first matching message with its fields loaded, or None.
        """
        sql = "SELECT * FROM messages WHERE full_name = ?"
        try:
            rows = db_manager.execute_query(sql, (full_name,))
            return self._hydrate(rows[0]) if rows else None
        except Exception as e:
            logger.error(f"Failed to get message by full_name {full_name}: {e}")
            raise

    def update(self, message: Message) -> bool:
        """
        Update a message's columns and replace its field links.

        Args:
            message: Message carrying the new values; ``message.id`` selects
                the row to update.

        Returns:
            True if a row was updated, False if no row has ``message.id``.
        """
        sql = """
            UPDATE messages SET
                full_name = ?, system_name = ?, message_type = ?, version = ?,
                description = ?, protocol = ?, serialization = ?
            WHERE id = ?
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
            message.id,
        )
        try:
            with db_manager.transaction():
                rows_affected = db_manager.execute_update(sql, params)
                if rows_affected <= 0:
                    return False
                # Replace links wholesale so field_order matches the new list.
                self._clear_field_associations(message.id)
                if message.fields:
                    self._associate_fields(message.id, message.fields)
            logger.info(f"Updated message: {message.full_name} (ID: {message.id})")
            return True
        except Exception as e:
            logger.error(f"Failed to update message: {e}")
            raise

    def delete(self, message_id: int) -> bool:
        """
        Delete a message together with its field links.

        Args:
            message_id: Message ID.

        Returns:
            True if the message row was deleted, False if it did not exist.
        """
        sql = "DELETE FROM messages WHERE id = ?"
        try:
            with db_manager.transaction():
                # Remove link rows explicitly: without an ON DELETE CASCADE
                # (or with SQLite foreign keys disabled) they would be left
                # orphaned, and with a non-cascading FK the delete would fail.
                self._clear_field_associations(message_id)
                rows_affected = db_manager.execute_update(sql, (message_id,))
            if rows_affected > 0:
                logger.info(f"Deleted message ID: {message_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete message: {e}")
            raise

    def search(self, system_name: Optional[str] = None,
               message_type: Optional[str] = None,
               version: Optional[str] = None,
               field_name: Optional[str] = None,
               limit: int = 100, offset: int = 0) -> List[Message]:
        """
        Search messages; all given filters are combined with AND.

        Args:
            system_name: Substring match on the system name.
            message_type: Substring match on the message type.
            version: Exact match on the version.
            field_name: Substring match on the name of any linked field.
            limit: Maximum number of messages returned.
            offset: Number of matching messages to skip.

        Returns:
            Matching messages, newest ``created_time`` first, with fields loaded.
        """
        join_sql, where_sql, params = self._build_filters(
            system_name, message_type, version, field_name)
        # DISTINCT: the field join yields one row per matching field.
        sql = ("SELECT DISTINCT m.* FROM messages m" + join_sql + where_sql
               + " ORDER BY m.created_time DESC LIMIT ? OFFSET ?")
        params.extend([limit, offset])
        try:
            rows = db_manager.execute_query(sql, tuple(params))
            return [self._hydrate(row) for row in rows]
        except Exception as e:
            logger.error(f"Failed to search messages: {e}")
            raise

    def get_all(self, limit: int = 1000, offset: int = 0) -> List[Message]:
        """
        List messages, newest ``created_time`` first.

        Args:
            limit: Maximum number of messages returned.
            offset: Number of messages to skip.

        Returns:
            Messages with their fields loaded.
        """
        sql = "SELECT * FROM messages ORDER BY created_time DESC LIMIT ? OFFSET ?"
        try:
            rows = db_manager.execute_query(sql, (limit, offset))
            return [self._hydrate(row) for row in rows]
        except Exception as e:
            logger.error(f"Failed to get all messages: {e}")
            raise

    def count(self, system_name: Optional[str] = None,
              message_type: Optional[str] = None,
              version: Optional[str] = None,
              field_name: Optional[str] = None) -> int:
        """
        Count messages matching the same filters ``search`` accepts.

        Args:
            system_name: Substring match on the system name.
            message_type: Substring match on the message type.
            version: Exact match on the version.
            field_name: Substring match on the name of any linked field.

        Returns:
            Number of distinct matching messages.
        """
        join_sql, where_sql, params = self._build_filters(
            system_name, message_type, version, field_name)
        # With the field join a message can appear once per matching field,
        # so count distinct message ids in that case.
        aggregate = "COUNT(DISTINCT m.id)" if join_sql else "COUNT(*)"
        sql = "SELECT " + aggregate + " as count FROM messages m" + join_sql + where_sql
        try:
            rows = db_manager.execute_query(sql, tuple(params) if params else None)
            return rows[0]['count'] if rows else 0
        except Exception as e:
            logger.error(f"Failed to count messages: {e}")
            raise

    @classmethod
    def _build_filters(cls, system_name: Optional[str] = None,
                       message_type: Optional[str] = None,
                       version: Optional[str] = None,
                       field_name: Optional[str] = None):
        """
        Build the JOIN/WHERE fragments shared by ``search`` and ``count``.

        Empty or None filters are ignored. Text filters become ``LIKE``
        substring matches; ``version`` is matched exactly.

        Returns:
            ``(join_sql, where_sql, params)``: ``join_sql`` is "" or the field
            join, ``where_sql`` starts with " WHERE 1=1", and ``params`` is a
            new list of values in placeholder order.
        """
        conditions = ["1=1"]
        params = []
        if system_name:
            conditions.append("m.system_name LIKE ?")
            params.append(f"%{system_name}%")
        if message_type:
            conditions.append("m.message_type LIKE ?")
            params.append(f"%{message_type}%")
        if version:
            conditions.append("m.version = ?")
            params.append(version)
        if field_name:
            conditions.append("f.name LIKE ?")
            params.append(f"%{field_name}%")
        join_sql = cls._FIELD_JOIN_SQL if field_name else ""
        return join_sql, " WHERE " + " AND ".join(conditions), params

    def _hydrate(self, row) -> Message:
        """Build a Message from ``row`` and load its ordered fields."""
        message = self._row_to_message(row)
        message.fields = self._get_message_fields(message.id)
        return message

    def _associate_fields(self, message_id: int, fields: List[Field]):
        """Insert one ``message_fields`` link per field; field_order = list index."""
        sql = """
            INSERT INTO message_fields (message_id, field_id, field_order)
            VALUES (?, ?, ?)
        """
        params_list = [
            (message_id, field.id, idx)
            for idx, field in enumerate(fields)
        ]
        db_manager.execute_many(sql, params_list)

    def _clear_field_associations(self, message_id: int):
        """Delete every ``message_fields`` link of the given message."""
        sql = "DELETE FROM message_fields WHERE message_id = ?"
        db_manager.execute_update(sql, (message_id,))

    def _get_message_fields(self, message_id: int) -> List[Field]:
        """Return the message's linked fields ordered by ``field_order``."""
        sql = """
            SELECT f.* FROM fields f
            INNER JOIN message_fields mf ON f.id = mf.field_id
            WHERE mf.message_id = ?
            ORDER BY mf.field_order
        """
        rows = db_manager.execute_query(sql, (message_id,))
        return [self.field_repo._row_to_field(row) for row in rows]

    @staticmethod
    def _parse_datetime(value) -> Optional[datetime]:
        """
        Convert a stored timestamp column value to ``datetime``.

        Empty values (None, "") give None. Values the driver already
        converted to ``datetime`` are returned unchanged. Anything else is
        parsed with ``datetime.fromisoformat``.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    def _row_to_message(self, row) -> Message:
        """Convert a ``messages`` row into a Message; ``fields`` is left empty."""
        return Message(
            id=row['id'],
            full_name=row['full_name'],
            system_name=row['system_name'],
            message_type=row['message_type'],
            version=row['version'],
            description=row['description'],
            protocol=ProtocolType(row['protocol']),
            serialization=SerializationType(row['serialization']),
            fields=[],  # filled in by _hydrate / callers
            created_time=self._parse_datetime(row['created_time']),
            updated_time=self._parse_datetime(row['updated_time']),
        )