"""
Message data-access layer.

Provides CRUD operations for message records stored in the ``messages`` table
and for their ordered field links stored in the ``message_fields`` join table.
"""
from typing import Any, List, Optional, Tuple
from datetime import datetime

from database import db_manager
from database.repositories.field_repository import FieldRepository
from models.message import Message
from models.field import Field
from config import ProtocolType, SerializationType
from utils.logger import get_logger

logger = get_logger(__name__)


class MessageRepository:
    """Data-access class for ``Message`` objects.

    Read methods return populated messages: each ``messages`` row is converted
    by ``_row_to_message`` and its fields are then loaded, ordered by
    ``message_fields.field_order``, through ``_get_message_fields``.
    """

    def __init__(self):
        # Used to convert joined ``fields`` rows into Field objects.
        self.field_repo = FieldRepository()

    def create(self, message: Message) -> int:
        """
        Create a message and link its fields.

        The message row and its field links are written inside a single
        ``db_manager.transaction()`` block.

        Args:
            message: Message to persist. Entries of ``message.fields`` are
                linked by their ``id`` in list order (list index becomes
                ``field_order``).

        Returns:
            ID of the newly created message.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        sql = """
            INSERT INTO messages (
                full_name, system_name, message_type, version,
                description, protocol, serialization
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
        )

        try:
            with db_manager.transaction():
                db_manager.execute_update(sql, params)
                message_id = db_manager.get_last_insert_id()

                # Link the message's fields, preserving list order.
                if message.fields:
                    self._associate_fields(message_id, message.fields)

                logger.info(f"Created message: {message.full_name} (ID: {message_id})")
                return message_id
        except Exception as e:
            logger.error(f"Failed to create message: {e}")
            raise

    def get_by_id(self, message_id: int) -> Optional[Message]:
        """
        Fetch a message by its ID.

        Args:
            message_id: Message ID.

        Returns:
            The message with its fields loaded, or None if no row matches.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        sql = "SELECT * FROM messages WHERE id = ?"

        try:
            rows = db_manager.execute_query(sql, (message_id,))
            if rows:
                message = self._row_to_message(rows[0])
                # Load the linked fields.
                message.fields = self._get_message_fields(message_id)
                return message
            return None
        except Exception as e:
            logger.error(f"Failed to get message by id {message_id}: {e}")
            raise

    def get_by_full_name(self, full_name: str) -> Optional[Message]:
        """
        Fetch a message by its exact full name.

        Args:
            full_name: Full message name (exact match).

        Returns:
            The first matching message with its fields loaded, or None.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        sql = "SELECT * FROM messages WHERE full_name = ?"

        try:
            rows = db_manager.execute_query(sql, (full_name,))
            if rows:
                message = self._row_to_message(rows[0])
                message.fields = self._get_message_fields(message.id)
                return message
            return None
        except Exception as e:
            logger.error(f"Failed to get message by full_name {full_name}: {e}")
            raise

    def update(self, message: Message) -> bool:
        """
        Update a message and replace its field links.

        The row is matched by ``message.id``. If a row was updated, its
        existing field links are deleted and rebuilt from ``message.fields``
        in the same transaction.

        Args:
            message: Message carrying the new values and its ``id``.

        Returns:
            True if a row was updated, False if no row matched ``message.id``.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        sql = """
            UPDATE messages SET
                full_name = ?, system_name = ?, message_type = ?, version = ?,
                description = ?, protocol = ?, serialization = ?
            WHERE id = ?
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
            message.id,
        )

        try:
            with db_manager.transaction():
                rows_affected = db_manager.execute_update(sql, params)

                # Only touch field links when the message row actually exists.
                if rows_affected > 0:
                    # Drop the old links, then insert the new ones.
                    self._clear_field_associations(message.id)
                    if message.fields:
                        self._associate_fields(message.id, message.fields)

                    logger.info(f"Updated message: {message.full_name} (ID: {message.id})")
                    return True
                return False
        except Exception as e:
            logger.error(f"Failed to update message: {e}")
            raise

    def delete(self, message_id: int) -> bool:
        """
        Delete a message together with its field links.

        The ``message_fields`` rows and the ``messages`` row are removed in
        one transaction. The join rows are cleared explicitly rather than
        relying on a foreign-key cascade, which may not be enforced (e.g.
        SQLite without ``PRAGMA foreign_keys = ON``).

        Args:
            message_id: Message ID.

        Returns:
            True if the message row was deleted, False if it did not exist.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        sql = "DELETE FROM messages WHERE id = ?"

        try:
            with db_manager.transaction():
                # Remove join rows first so no orphaned links remain and a
                # non-cascading foreign key cannot block the parent delete.
                self._clear_field_associations(message_id)
                rows_affected = db_manager.execute_update(sql, (message_id,))

            if rows_affected > 0:
                logger.info(f"Deleted message ID: {message_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete message: {e}")
            raise

    def search(self,
               system_name: Optional[str] = None,
               message_type: Optional[str] = None,
               version: Optional[str] = None,
               field_name: Optional[str] = None,
               limit: int = 100,
               offset: int = 0) -> List[Message]:
        """
        Search messages, newest first.

        Args:
            system_name: System name substring (SQL ``LIKE %...%``).
            message_type: Message type substring (SQL ``LIKE %...%``).
            version: Exact version string.
            field_name: Substring of a field name; only messages linked to a
                matching field are returned.
            limit: Maximum number of messages to return.
            offset: Number of matching messages to skip.

        Returns:
            Matching messages with their fields loaded, ordered by
            ``created_time`` descending.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        filter_sql, params = self._build_search_filters(
            system_name, message_type, version, field_name)
        # DISTINCT collapses duplicates produced by the optional field join.
        sql = ("SELECT DISTINCT m.*" + filter_sql
               + " ORDER BY m.created_time DESC LIMIT ? OFFSET ?")
        params.extend([limit, offset])

        try:
            rows = db_manager.execute_query(sql, tuple(params))
            messages = []
            for row in rows:
                message = self._row_to_message(row)
                message.fields = self._get_message_fields(message.id)
                messages.append(message)
            return messages
        except Exception as e:
            logger.error(f"Failed to search messages: {e}")
            raise

    def get_all(self, limit: int = 1000, offset: int = 0) -> List[Message]:
        """
        List messages, newest first.

        Args:
            limit: Maximum number of messages to return.
            offset: Number of messages to skip.

        Returns:
            Messages with their fields loaded, ordered by ``created_time``
            descending.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        sql = "SELECT * FROM messages ORDER BY created_time DESC LIMIT ? OFFSET ?"

        try:
            rows = db_manager.execute_query(sql, (limit, offset))
            messages = []
            for row in rows:
                message = self._row_to_message(row)
                message.fields = self._get_message_fields(message.id)
                messages.append(message)
            return messages
        except Exception as e:
            logger.error(f"Failed to get all messages: {e}")
            raise

    def count(self,
              system_name: Optional[str] = None,
              message_type: Optional[str] = None,
              version: Optional[str] = None,
              field_name: Optional[str] = None) -> int:
        """
        Count messages matching the same filters that ``search`` accepts.

        ``version`` and ``field_name`` were added (keyword-compatible) so a
        caller can get a total consistent with a paginated ``search``.

        Args:
            system_name: System name substring (SQL ``LIKE %...%``).
            message_type: Message type substring (SQL ``LIKE %...%``).
            version: Exact version string.
            field_name: Substring of a linked field's name.

        Returns:
            Number of distinct matching messages.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        filter_sql, params = self._build_search_filters(
            system_name, message_type, version, field_name)
        # The field join can yield one row per matching field, so count
        # distinct message ids in that case to agree with search()'s DISTINCT.
        count_expr = "COUNT(DISTINCT m.id)" if field_name else "COUNT(*)"
        sql = f"SELECT {count_expr} as count" + filter_sql

        try:
            rows = db_manager.execute_query(sql, tuple(params) if params else None)
            return rows[0]['count'] if rows else 0
        except Exception as e:
            logger.error(f"Failed to count messages: {e}")
            raise

    @staticmethod
    def _build_search_filters(system_name: Optional[str] = None,
                              message_type: Optional[str] = None,
                              version: Optional[str] = None,
                              field_name: Optional[str] = None
                              ) -> Tuple[str, List[Any]]:
        """Build the ``FROM ... WHERE ...`` tail shared by search() and count().

        The messages table is aliased ``m``. When ``field_name`` is given,
        ``message_fields`` (``mf``) and ``fields`` (``f``) are joined in, so a
        message can appear once per matching field; callers must de-duplicate.
        Empty/None filters are skipped.

        Returns:
            ``(sql_fragment, params)``. The fragment starts with `` FROM``,
            and ``params`` lists values in placeholder order.
        """
        sql = " FROM messages m"
        params: List[Any] = []

        # Filtering by field name requires joining through the link table.
        if field_name:
            sql += """
                INNER JOIN message_fields mf ON m.id = mf.message_id
                INNER JOIN fields f ON mf.field_id = f.id
            """

        sql += " WHERE 1=1"

        if system_name:
            sql += " AND m.system_name LIKE ?"
            params.append(f"%{system_name}%")
        if message_type:
            sql += " AND m.message_type LIKE ?"
            params.append(f"%{message_type}%")
        if version:
            sql += " AND m.version = ?"
            params.append(version)
        if field_name:
            sql += " AND f.name LIKE ?"
            params.append(f"%{field_name}%")

        return sql, params

    def _associate_fields(self, message_id: int, fields: List[Field]):
        """Insert ``message_fields`` rows linking ``fields`` to a message.

        Each field is linked by its ``id``; its index in ``fields`` is stored
        as ``field_order``.
        """
        sql = """
            INSERT INTO message_fields (message_id, field_id, field_order)
            VALUES (?, ?, ?)
        """
        params_list = [
            (message_id, field.id, idx)
            for idx, field in enumerate(fields)
        ]
        db_manager.execute_many(sql, params_list)

    def _clear_field_associations(self, message_id: int):
        """Delete every ``message_fields`` row for the given message."""
        sql = "DELETE FROM message_fields WHERE message_id = ?"
        db_manager.execute_update(sql, (message_id,))

    def _get_message_fields(self, message_id: int) -> List[Field]:
        """Return the fields linked to a message, ordered by ``field_order``."""
        sql = """
            SELECT f.* FROM fields f
            INNER JOIN message_fields mf ON f.id = mf.field_id
            WHERE mf.message_id = ?
            ORDER BY mf.field_order
        """
        rows = db_manager.execute_query(sql, (message_id,))
        return [self.field_repo._row_to_field(row) for row in rows]

    def _row_to_message(self, row) -> Message:
        """Convert a ``messages`` row into a Message.

        ``protocol``/``serialization`` are mapped through their enums, and
        timestamps are parsed with ``datetime.fromisoformat`` (None when the
        column is empty). ``fields`` is left empty for the caller to load.
        """
        return Message(
            id=row['id'],
            full_name=row['full_name'],
            system_name=row['system_name'],
            message_type=row['message_type'],
            version=row['version'],
            description=row['description'],
            protocol=ProtocolType(row['protocol']),
            serialization=SerializationType(row['serialization']),
            fields=[],  # populated later by the caller
            created_time=datetime.fromisoformat(row['created_time']) if row['created_time'] else None,
            updated_time=datetime.fromisoformat(row['updated_time']) if row['updated_time'] else None,
        )