345 lines
10 KiB
Python
345 lines
10 KiB
Python
|
|
"""
Message data access layer.

Handles CRUD operations on message data.
"""
|
|||
|
|
from typing import List, Optional
|
|||
|
|
from datetime import datetime
|
|||
|
|
|
|||
|
|
from database import db_manager
|
|||
|
|
from database.repositories.field_repository import FieldRepository
|
|||
|
|
from models.message import Message
|
|||
|
|
from models.field import Field
|
|||
|
|
from config import ProtocolType, SerializationType
|
|||
|
|
from utils.logger import get_logger
|
|||
|
|
|
|||
|
|
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MessageRepository:
    """Data-access class for message records.

    Wraps the SQL needed to create, read, update, delete, search and count
    rows of the ``messages`` table, and maintains the ordered links between
    a message and its fields in the ``message_fields`` table. All statements
    are executed through the module-level ``db_manager``.
    """

    def __init__(self):
        # Used to turn rows of the ``fields`` table into Field objects.
        self.field_repo = FieldRepository()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create(self, message: Message) -> int:
        """Insert a new message and link its fields.

        Args:
            message: Message to persist. ``protocol`` and ``serialization``
                are stored through their ``.value``.

        Returns:
            The ID of the newly created message.

        Raises:
            Exception: Any error raised by the database layer is logged and
                re-raised unchanged.
        """
        sql = """
            INSERT INTO messages (
                full_name, system_name, message_type, version,
                description, protocol, serialization
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
        )

        try:
            with db_manager.transaction():
                db_manager.execute_update(sql, params)
                message_id = db_manager.get_last_insert_id()

                # Link the fields inside the same transaction as the insert.
                if message.fields:
                    self._associate_fields(message_id, message.fields)

            # Logged only after the transaction block has exited cleanly.
            logger.info(f"Created message: {message.full_name} (ID: {message_id})")
            return message_id
        except Exception as e:
            logger.error(f"Failed to create message: {e}")
            raise

    def get_by_id(self, message_id: int) -> Optional[Message]:
        """Fetch a message, including its fields, by primary key.

        Args:
            message_id: ID of the message.

        Returns:
            The message, or ``None`` when no row matches.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        sql = "SELECT * FROM messages WHERE id = ?"

        try:
            return self._load_one(sql, (message_id,))
        except Exception as e:
            logger.error(f"Failed to get message by id {message_id}: {e}")
            raise

    def get_by_full_name(self, full_name: str) -> Optional[Message]:
        """Fetch a message, including its fields, by its full name.

        Args:
            full_name: Exact ``full_name`` value to look up.

        Returns:
            The first matching message, or ``None`` when no row matches.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        sql = "SELECT * FROM messages WHERE full_name = ?"

        try:
            return self._load_one(sql, (full_name,))
        except Exception as e:
            logger.error(f"Failed to get message by full_name {full_name}: {e}")
            raise

    def update(self, message: Message) -> bool:
        """Update a message row and replace its field links.

        Args:
            message: Message carrying the new values; ``message.id``
                selects the row to update.

        Returns:
            ``True`` when a row was updated, ``False`` when no row matched.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        sql = """
            UPDATE messages SET
                full_name = ?, system_name = ?, message_type = ?, version = ?,
                description = ?, protocol = ?, serialization = ?
            WHERE id = ?
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
            message.id,
        )

        try:
            with db_manager.transaction():
                rows_affected = db_manager.execute_update(sql, params)
                updated = rows_affected > 0

                if updated:
                    # Replace the links wholesale: drop the old ones, then
                    # insert the current list in order.
                    self._clear_field_associations(message.id)
                    if message.fields:
                        self._associate_fields(message.id, message.fields)

            if updated:
                logger.info(f"Updated message: {message.full_name} (ID: {message.id})")
            return updated
        except Exception as e:
            logger.error(f"Failed to update message: {e}")
            raise

    def delete(self, message_id: int) -> bool:
        """Delete a message together with its field links.

        The link rows and the message row are removed inside one
        transaction, matching how ``create`` and ``update`` write.

        Args:
            message_id: ID of the message to delete.

        Returns:
            ``True`` when a message row was deleted, ``False`` when no row
            matched.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        sql = "DELETE FROM messages WHERE id = ?"

        try:
            with db_manager.transaction():
                # Remove link rows first so none are left pointing at a
                # deleted message.
                # NOTE(review): if the schema already declares
                # ON DELETE CASCADE this is redundant but harmless - confirm.
                self._clear_field_associations(message_id)
                rows_affected = db_manager.execute_update(sql, (message_id,))

            deleted = rows_affected > 0
            if deleted:
                logger.info(f"Deleted message ID: {message_id}")
            return deleted
        except Exception as e:
            logger.error(f"Failed to delete message: {e}")
            raise

    def search(self, system_name: Optional[str] = None,
               message_type: Optional[str] = None,
               version: Optional[str] = None,
               field_name: Optional[str] = None,
               limit: int = 100, offset: int = 0) -> List[Message]:
        """Search messages, newest first.

        Filters that are ``None`` or empty are ignored; the remaining ones
        are combined with AND.

        Args:
            system_name: Substring of the system name (``LIKE %value%``).
            message_type: Substring of the message type (``LIKE %value%``).
            version: Exact version string.
            field_name: Substring of a field name; only messages linked to
                a matching field are returned.
            limit: Maximum number of messages to return.
            offset: Number of matching messages to skip.

        Returns:
            Matching messages with their fields loaded, ordered by
            ``created_time`` descending.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        filter_sql, params = self._build_filters(
            system_name, message_type, version, field_name
        )
        # DISTINCT collapses the duplicates produced when several fields of
        # one message match ``field_name``.
        sql = (
            "SELECT DISTINCT m.* FROM messages m"
            + filter_sql
            + " ORDER BY m.created_time DESC LIMIT ? OFFSET ?"
        )
        params.extend([limit, offset])

        try:
            rows = db_manager.execute_query(sql, tuple(params))
            return [self._hydrate(row) for row in rows]
        except Exception as e:
            logger.error(f"Failed to search messages: {e}")
            raise

    def get_all(self, limit: int = 1000, offset: int = 0) -> List[Message]:
        """Return messages, newest first, with their fields loaded.

        Args:
            limit: Maximum number of messages to return.
            offset: Number of messages to skip.

        Returns:
            Messages ordered by ``created_time`` descending.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        sql = "SELECT * FROM messages ORDER BY created_time DESC LIMIT ? OFFSET ?"

        try:
            rows = db_manager.execute_query(sql, (limit, offset))
            return [self._hydrate(row) for row in rows]
        except Exception as e:
            logger.error(f"Failed to get all messages: {e}")
            raise

    def count(self, system_name: Optional[str] = None,
              message_type: Optional[str] = None,
              version: Optional[str] = None,
              field_name: Optional[str] = None) -> int:
        """Count messages matching the given filters.

        Accepts the same filters as ``search`` so that a total can be
        computed for any search; filters that are ``None`` or empty are
        ignored.

        Args:
            system_name: Substring of the system name (``LIKE %value%``).
            message_type: Substring of the message type (``LIKE %value%``).
            version: Exact version string.
            field_name: Substring of a field name; only messages linked to
                a matching field are counted.

        Returns:
            Number of matching messages (0 when the query yields no row).

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        filter_sql, params = self._build_filters(
            system_name, message_type, version, field_name
        )
        # DISTINCT keeps a message from being counted once per matching
        # field when the field join is active.
        sql = "SELECT COUNT(DISTINCT m.id) as count FROM messages m" + filter_sql

        try:
            rows = db_manager.execute_query(sql, tuple(params) if params else None)
            return rows[0]['count'] if rows else 0
        except Exception as e:
            logger.error(f"Failed to count messages: {e}")
            raise

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _build_filters(system_name: Optional[str],
                       message_type: Optional[str],
                       version: Optional[str],
                       field_name: Optional[str]) -> tuple:
        """Build the JOIN/WHERE SQL fragment shared by search and count.

        The fragment assumes the ``messages`` table is aliased as ``m``.

        Returns:
            A ``(sql_fragment, params)`` pair; ``params`` is a new list in
            placeholder order.
        """
        join_sql = ""
        conditions = []
        params = []

        if system_name:
            conditions.append("m.system_name LIKE ?")
            params.append(f"%{system_name}%")

        if message_type:
            conditions.append("m.message_type LIKE ?")
            params.append(f"%{message_type}%")

        if version:
            conditions.append("m.version = ?")
            params.append(version)

        if field_name:
            # Filtering on a field name requires joining through the link
            # table to the fields table.
            join_sql = (
                " INNER JOIN message_fields mf ON m.id = mf.message_id"
                " INNER JOIN fields f ON mf.field_id = f.id"
            )
            conditions.append("f.name LIKE ?")
            params.append(f"%{field_name}%")

        where_sql = " WHERE " + " AND ".join(conditions) if conditions else ""
        return join_sql + where_sql, params

    def _load_one(self, sql: str, params: tuple) -> Optional[Message]:
        """Run a query and return its first row as a Message, or None."""
        rows = db_manager.execute_query(sql, params)
        if not rows:
            return None
        return self._hydrate(rows[0])

    def _hydrate(self, row) -> Message:
        """Convert a ``messages`` row to a Message and load its fields."""
        message = self._row_to_message(row)
        message.fields = self._get_message_fields(message.id)
        return message

    def _associate_fields(self, message_id: int, fields: List[Field]):
        """Link fields to a message, storing list position as the order."""
        sql = """
            INSERT INTO message_fields (message_id, field_id, field_order)
            VALUES (?, ?, ?)
        """
        params_list = [
            (message_id, field.id, idx)
            for idx, field in enumerate(fields)
        ]
        db_manager.execute_many(sql, params_list)

    def _clear_field_associations(self, message_id: int):
        """Remove every field link of the given message."""
        sql = "DELETE FROM message_fields WHERE message_id = ?"
        db_manager.execute_update(sql, (message_id,))

    def _get_message_fields(self, message_id: int) -> List[Field]:
        """Return the fields linked to a message, in ``field_order``."""
        sql = """
            SELECT f.* FROM fields f
            INNER JOIN message_fields mf ON f.id = mf.field_id
            WHERE mf.message_id = ?
            ORDER BY mf.field_order
        """
        rows = db_manager.execute_query(sql, (message_id,))
        return [self.field_repo._row_to_field(row) for row in rows]

    @staticmethod
    def _parse_timestamp(value) -> Optional[datetime]:
        """Convert a stored timestamp to ``datetime``.

        Empty values become ``None``. Values that already are ``datetime``
        objects are returned as-is, so the conversion works whether or not
        the database driver parses timestamp columns itself; anything else
        is parsed as an ISO-format string.
        """
        if not value:
            return None
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    def _row_to_message(self, row) -> Message:
        """Convert a ``messages`` row into a Message with no fields loaded."""
        return Message(
            id=row['id'],
            full_name=row['full_name'],
            system_name=row['system_name'],
            message_type=row['message_type'],
            version=row['version'],
            description=row['description'],
            protocol=ProtocolType(row['protocol']),
            serialization=SerializationType(row['serialization']),
            fields=[],  # populated afterwards by the caller
            created_time=self._parse_timestamp(row['created_time']),
            updated_time=self._parse_timestamp(row['updated_time']),
        )
|