SIT/database/repositories/message_repository.py

345 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
消息数据访问层
负责消息数据的CRUD操作
"""
from typing import List, Optional
from datetime import datetime
from database import db_manager
from database.repositories.field_repository import FieldRepository
from models.message import Message
from models.field import Field
from config import ProtocolType, SerializationType
from utils.logger import get_logger
logger = get_logger(__name__)
class MessageRepository:
    """Data-access class for messages.

    Reads and writes rows of the ``messages`` table and maintains the ordered
    links between a message and its fields in ``message_fields``. All
    database access goes through the module-level ``db_manager``.
    """

    def __init__(self):
        # Used to convert rows of the ``fields`` table into Field objects.
        self.field_repo = FieldRepository()

    def create(self, message: Message) -> int:
        """Insert a message and link its fields.

        The insert and the field links are written inside one
        ``db_manager.transaction()`` block.

        Args:
            message: Message to persist.

        Returns:
            ID of the newly created message.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = """
            INSERT INTO messages (
                full_name, system_name, message_type, version,
                description, protocol, serialization
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
        )
        try:
            with db_manager.transaction():
                db_manager.execute_update(sql, params)
                message_id = db_manager.get_last_insert_id()
                # Link the fields to the new message.
                if message.fields:
                    self._associate_fields(message_id, message.fields)
            # Log only once the transaction block has exited without error.
            logger.info(f"Created message: {message.full_name} (ID: {message_id})")
            return message_id
        except Exception as e:
            logger.error(f"Failed to create message: {e}")
            raise

    def get_by_id(self, message_id: int) -> Optional[Message]:
        """Fetch a message, including its fields, by ID.

        Args:
            message_id: Message ID.

        Returns:
            The message, or None when no row matches.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = "SELECT * FROM messages WHERE id = ?"
        try:
            rows = db_manager.execute_query(sql, (message_id,))
            if not rows:
                return None
            return self._hydrate(rows[0])
        except Exception as e:
            logger.error(f"Failed to get message by id {message_id}: {e}")
            raise

    def get_by_full_name(self, full_name: str) -> Optional[Message]:
        """Fetch a message, including its fields, by full name.

        Args:
            full_name: Full name of the message.

        Returns:
            The first matching message, or None when no row matches.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = "SELECT * FROM messages WHERE full_name = ?"
        try:
            rows = db_manager.execute_query(sql, (full_name,))
            if not rows:
                return None
            return self._hydrate(rows[0])
        except Exception as e:
            logger.error(f"Failed to get message by full_name {full_name}: {e}")
            raise

    def update(self, message: Message) -> bool:
        """Update a message row and replace its field links.

        When the UPDATE touches at least one row, the existing field links
        are removed and re-created from ``message.fields`` in the same
        transaction.

        Args:
            message: Message carrying the new values; ``message.id`` selects
                the row to update.

        Returns:
            True if a row was updated, False otherwise.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = """
            UPDATE messages SET
                full_name = ?, system_name = ?, message_type = ?, version = ?,
                description = ?, protocol = ?, serialization = ?
            WHERE id = ?
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
            message.id,
        )
        try:
            with db_manager.transaction():
                rows_affected = db_manager.execute_update(sql, params)
                updated = rows_affected > 0
                if updated:
                    # Drop the old links first, then add the new ones.
                    self._clear_field_associations(message.id)
                    if message.fields:
                        self._associate_fields(message.id, message.fields)
            if updated:
                logger.info(f"Updated message: {message.full_name} (ID: {message.id})")
            return updated
        except Exception as e:
            logger.error(f"Failed to update message: {e}")
            raise

    def delete(self, message_id: int) -> bool:
        """Delete a message together with its field links.

        The link rows in ``message_fields`` are removed before the message
        row, inside one transaction, so the two deletes are issued together
        and no link rows are left pointing at a deleted message.
        NOTE(review): this is redundant but harmless if the schema already
        cascades deletes to ``message_fields`` -- confirm against the schema.

        Args:
            message_id: Message ID.

        Returns:
            True if a message row was deleted, False otherwise.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = "DELETE FROM messages WHERE id = ?"
        try:
            with db_manager.transaction():
                self._clear_field_associations(message_id)
                rows_affected = db_manager.execute_update(sql, (message_id,))
            if rows_affected > 0:
                logger.info(f"Deleted message ID: {message_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete message: {e}")
            raise

    def search(self, system_name: Optional[str] = None,
               message_type: Optional[str] = None,
               version: Optional[str] = None,
               field_name: Optional[str] = None,
               limit: int = 100, offset: int = 0) -> List[Message]:
        """Search messages by optional criteria.

        Criteria that are None or empty are ignored; the remaining ones are
        combined with AND. Results are ordered by ``created_time``, newest
        first.

        Args:
            system_name: System name, matched as a substring (LIKE).
            message_type: Message type, matched as a substring (LIKE).
            version: Version, matched exactly.
            field_name: Field name, matched as a substring (LIKE); selects
                messages linked to at least one such field.
            limit: Maximum number of messages to return.
            offset: Number of leading results to skip.

        Returns:
            Matching messages with their fields loaded.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = "SELECT DISTINCT m.* FROM messages m"
        params = []
        # Filtering by field name needs the link and field tables joined in;
        # DISTINCT above collapses messages that match through several fields.
        if field_name:
            sql += """
                INNER JOIN message_fields mf ON m.id = mf.message_id
                INNER JOIN fields f ON mf.field_id = f.id
            """
        # "WHERE 1=1" lets every filter be appended uniformly as " AND ...".
        sql += " WHERE 1=1"
        if system_name:
            sql += " AND m.system_name LIKE ?"
            params.append(f"%{system_name}%")
        if message_type:
            sql += " AND m.message_type LIKE ?"
            params.append(f"%{message_type}%")
        if version:
            sql += " AND m.version = ?"
            params.append(version)
        if field_name:
            sql += " AND f.name LIKE ?"
            params.append(f"%{field_name}%")
        sql += " ORDER BY m.created_time DESC LIMIT ? OFFSET ?"
        params.extend([limit, offset])
        try:
            rows = db_manager.execute_query(sql, tuple(params))
            return [self._hydrate(row) for row in rows]
        except Exception as e:
            logger.error(f"Failed to search messages: {e}")
            raise

    def get_all(self, limit: int = 1000, offset: int = 0) -> List[Message]:
        """Return messages ordered by ``created_time``, newest first.

        Args:
            limit: Maximum number of messages to return.
            offset: Number of leading results to skip.

        Returns:
            Messages with their fields loaded.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = "SELECT * FROM messages ORDER BY created_time DESC LIMIT ? OFFSET ?"
        try:
            rows = db_manager.execute_query(sql, (limit, offset))
            return [self._hydrate(row) for row in rows]
        except Exception as e:
            logger.error(f"Failed to get all messages: {e}")
            raise

    def count(self, system_name: Optional[str] = None,
              message_type: Optional[str] = None) -> int:
        """Count messages, optionally filtered.

        Args:
            system_name: System name, matched as a substring (LIKE).
            message_type: Message type, matched as a substring (LIKE).

        Returns:
            Number of matching messages; 0 when the query yields no row.

        Raises:
            Exception: Any error from the database layer is logged and
                re-raised unchanged.
        """
        sql = "SELECT COUNT(*) as count FROM messages WHERE 1=1"
        params = []
        if system_name:
            sql += " AND system_name LIKE ?"
            params.append(f"%{system_name}%")
        if message_type:
            sql += " AND message_type LIKE ?"
            params.append(f"%{message_type}%")
        try:
            # None (rather than an empty tuple) is passed when there are no
            # filters, as the original call did.
            rows = db_manager.execute_query(sql, tuple(params) if params else None)
            return rows[0]['count'] if rows else 0
        except Exception as e:
            logger.error(f"Failed to count messages: {e}")
            raise

    def _hydrate(self, row) -> Message:
        """Build a Message from a ``messages`` row and load its fields."""
        message = self._row_to_message(row)
        message.fields = self._get_message_fields(message.id)
        return message

    def _associate_fields(self, message_id: int, fields: List[Field]):
        """Link fields to a message, storing list position as ``field_order``."""
        if not fields:
            # Nothing to link; avoid issuing an empty batch.
            return
        sql = """
            INSERT INTO message_fields (message_id, field_id, field_order)
            VALUES (?, ?, ?)
        """
        params_list = [
            (message_id, field.id, idx)
            for idx, field in enumerate(fields)
        ]
        db_manager.execute_many(sql, params_list)

    def _clear_field_associations(self, message_id: int):
        """Remove every field link of a message."""
        sql = "DELETE FROM message_fields WHERE message_id = ?"
        db_manager.execute_update(sql, (message_id,))

    def _get_message_fields(self, message_id: int) -> List[Field]:
        """Return the fields linked to a message, ordered by ``field_order``."""
        sql = """
            SELECT f.* FROM fields f
            INNER JOIN message_fields mf ON f.id = mf.field_id
            WHERE mf.message_id = ?
            ORDER BY mf.field_order
        """
        rows = db_manager.execute_query(sql, (message_id,))
        return [self.field_repo._row_to_field(row) for row in rows]

    @staticmethod
    def _parse_timestamp(value) -> Optional[datetime]:
        """Parse an ISO-format timestamp string; empty or None yields None."""
        return datetime.fromisoformat(value) if value else None

    def _row_to_message(self, row) -> Message:
        """Convert a ``messages`` row into a Message without its fields."""
        return Message(
            id=row['id'],
            full_name=row['full_name'],
            system_name=row['system_name'],
            message_type=row['message_type'],
            version=row['version'],
            description=row['description'],
            protocol=ProtocolType(row['protocol']),
            serialization=SerializationType(row['serialization']),
            fields=[],  # loaded afterwards by the caller
            created_time=self._parse_timestamp(row['created_time']),
            updated_time=self._parse_timestamp(row['updated_time']),
        )