345 lines
10 KiB
Python
345 lines
10 KiB
Python
"""
|
||
消息数据访问层
|
||
负责消息数据的CRUD操作
|
||
"""
|
||
from typing import List, Optional
|
||
from datetime import datetime
|
||
|
||
from database import db_manager
|
||
from database.repositories.field_repository import FieldRepository
|
||
from models.message import Message
|
||
from models.field import Field
|
||
from config import ProtocolType, SerializationType
|
||
from utils.logger import get_logger
|
||
|
||
# Module-level logger from the project's get_logger helper, keyed by this
# module's __name__; used by MessageRepository for info/error messages.
logger = get_logger(__name__)
|
||
|
||
|
||
class MessageRepository:
    """Data-access object for rows of the ``messages`` table.

    Also maintains the ``message_fields`` link table, which records which
    fields belong to a message and in what order (``field_order``). Every
    public method logs any exception raised while talking to ``db_manager``
    and re-raises it unchanged.
    """

    def __init__(self):
        # Only used to reuse FieldRepository's row -> Field conversion.
        self.field_repo = FieldRepository()

    def create(self, message: Message) -> int:
        """Insert a message and link its fields in a single transaction.

        Args:
            message: Message to persist. Every entry of ``message.fields``
                must already have a database ``id``.

        Returns:
            The ID of the newly created ``messages`` row.

        Raises:
            ValueError: if one of ``message.fields`` has no ``id``.
            Re-raises any error from the database layer after logging it.
        """
        sql = """
            INSERT INTO messages (
                full_name, system_name, message_type, version,
                description, protocol, serialization
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
        )

        try:
            with db_manager.transaction():
                db_manager.execute_update(sql, params)
                message_id = db_manager.get_last_insert_id()
                # Links are written in the same transaction, so a failure here
                # also rolls back the message row.
                if message.fields:
                    self._associate_fields(message_id, message.fields)

            logger.info(f"Created message: {message.full_name} (ID: {message_id})")
            return message_id
        except Exception as e:
            logger.error(f"Failed to create message: {e}")
            raise

    def get_by_id(self, message_id: int) -> Optional[Message]:
        """Fetch one message, including its ordered fields, by primary key.

        Args:
            message_id: ID of the ``messages`` row.

        Returns:
            The message, or None if no row has this ID.
        """
        sql = "SELECT * FROM messages WHERE id = ?"

        try:
            rows = db_manager.execute_query(sql, (message_id,))
            return self._load_message(rows[0]) if rows else None
        except Exception as e:
            logger.error(f"Failed to get message by id {message_id}: {e}")
            raise

    def get_by_full_name(self, full_name: str) -> Optional[Message]:
        """Fetch a message, including its ordered fields, by exact full name.

        Args:
            full_name: Value to match exactly against ``messages.full_name``.

        Returns:
            The first matching message, or None if there is no match.
        """
        sql = "SELECT * FROM messages WHERE full_name = ?"

        try:
            rows = db_manager.execute_query(sql, (full_name,))
            return self._load_message(rows[0]) if rows else None
        except Exception as e:
            logger.error(f"Failed to get message by full_name {full_name}: {e}")
            raise

    def update(self, message: Message) -> bool:
        """Overwrite a message's columns and replace its field links.

        Args:
            message: Message whose ``id`` selects the row to update. Its
                ``fields`` list fully replaces the existing links, in order.

        Returns:
            True if a row with ``message.id`` existed and was updated,
            otherwise False (in which case the links are left untouched).

        Raises:
            ValueError: if one of ``message.fields`` has no ``id``.
            Re-raises any error from the database layer after logging it.
        """
        sql = """
            UPDATE messages SET
                full_name = ?, system_name = ?, message_type = ?, version = ?,
                description = ?, protocol = ?, serialization = ?
            WHERE id = ?
        """
        params = (
            message.full_name,
            message.system_name,
            message.message_type,
            message.version,
            message.description,
            message.protocol.value,
            message.serialization.value,
            message.id,
        )

        try:
            with db_manager.transaction():
                rows_affected = db_manager.execute_update(sql, params)
                if rows_affected <= 0:
                    return False
                # Replace the whole link list: drop the old links, then
                # re-insert them in the message's current field order.
                self._clear_field_associations(message.id)
                if message.fields:
                    self._associate_fields(message.id, message.fields)

            logger.info(f"Updated message: {message.full_name} (ID: {message.id})")
            return True
        except Exception as e:
            logger.error(f"Failed to update message: {e}")
            raise

    def delete(self, message_id: int) -> bool:
        """Delete a message together with its field links.

        The ``message_fields`` rows are removed first, in the same transaction
        as the message row. This leaves no orphaned links behind, and a
        non-cascading foreign key (if the schema declares one) cannot reject
        the delete.

        Args:
            message_id: ID of the message to delete.

        Returns:
            True if a ``messages`` row was deleted, False if none had this ID.
        """
        sql = "DELETE FROM messages WHERE id = ?"

        try:
            with db_manager.transaction():
                self._clear_field_associations(message_id)
                rows_affected = db_manager.execute_update(sql, (message_id,))

            if rows_affected > 0:
                logger.info(f"Deleted message ID: {message_id}")
                return True
            return False
        except Exception as e:
            logger.error(f"Failed to delete message: {e}")
            raise

    def search(self, system_name: Optional[str] = None,
               message_type: Optional[str] = None,
               version: Optional[str] = None,
               field_name: Optional[str] = None,
               limit: int = 100, offset: int = 0) -> List[Message]:
        """Search messages, newest first; falsy filters are ignored.

        Args:
            system_name: Substring match (SQL LIKE ``%value%``) on system name.
            message_type: Substring match (SQL LIKE) on message type.
            version: Exact match on version.
            field_name: Keep only messages linked to a field whose name
                contains this substring (SQL LIKE).
            limit: Maximum number of messages returned.
            offset: Number of matching messages to skip.

        Returns:
            Matching messages, each with its ordered fields loaded.
        """
        from_where, params = self._build_filter_clause(
            system_name, message_type, version, field_name)
        # DISTINCT: the field JOIN yields one row per matching field.
        sql = (f"SELECT DISTINCT m.*{from_where}"
               " ORDER BY m.created_time DESC LIMIT ? OFFSET ?")
        params.extend([limit, offset])

        try:
            rows = db_manager.execute_query(sql, tuple(params))
            return self._load_messages(rows)
        except Exception as e:
            logger.error(f"Failed to search messages: {e}")
            raise

    def get_all(self, limit: int = 1000, offset: int = 0) -> List[Message]:
        """Return a page of all messages, newest first.

        Args:
            limit: Maximum number of messages returned.
            offset: Number of messages to skip.

        Returns:
            Messages with their ordered fields loaded.
        """
        sql = "SELECT * FROM messages ORDER BY created_time DESC LIMIT ? OFFSET ?"

        try:
            rows = db_manager.execute_query(sql, (limit, offset))
            return self._load_messages(rows)
        except Exception as e:
            logger.error(f"Failed to get all messages: {e}")
            raise

    def count(self, system_name: Optional[str] = None,
              message_type: Optional[str] = None,
              version: Optional[str] = None,
              field_name: Optional[str] = None) -> int:
        """Count messages matching the same filters that ``search`` accepts.

        ``version`` and ``field_name`` are optional additions. Existing
        two-argument callers get the same result as before, and callers can
        now size a paginated ``search`` exactly.

        Args:
            system_name: Substring match (SQL LIKE) on system name.
            message_type: Substring match (SQL LIKE) on message type.
            version: Exact match on version.
            field_name: Count only messages linked to a field whose name
                contains this substring.

        Returns:
            Number of matching messages.
        """
        from_where, params = self._build_filter_clause(
            system_name, message_type, version, field_name)
        # The field JOIN can repeat a message once per matching field, so
        # count distinct messages to agree with search()'s DISTINCT.
        select = ("SELECT COUNT(DISTINCT m.id) AS count" if field_name
                  else "SELECT COUNT(*) AS count")
        sql = select + from_where

        try:
            rows = db_manager.execute_query(sql, tuple(params) if params else None)
            return rows[0]['count'] if rows else 0
        except Exception as e:
            logger.error(f"Failed to count messages: {e}")
            raise

    def _build_filter_clause(self, system_name: Optional[str],
                             message_type: Optional[str],
                             version: Optional[str],
                             field_name: Optional[str]) -> tuple:
        """Build the shared ``FROM ... WHERE ...`` SQL tail for search/count.

        Returns:
            ``(sql_fragment, params)``. The fragment starts with a space, and
            ``params`` is a new list the caller may extend.
        """
        clause = " FROM messages m"
        params: list = []

        if field_name:
            # Join the link/field tables only when filtering on a field name.
            clause += (" INNER JOIN message_fields mf ON m.id = mf.message_id"
                       " INNER JOIN fields f ON mf.field_id = f.id")

        clause += " WHERE 1=1"

        if system_name:
            clause += " AND m.system_name LIKE ?"
            params.append(f"%{system_name}%")
        if message_type:
            clause += " AND m.message_type LIKE ?"
            params.append(f"%{message_type}%")
        if version:
            clause += " AND m.version = ?"
            params.append(version)
        if field_name:
            clause += " AND f.name LIKE ?"
            params.append(f"%{field_name}%")

        return clause, params

    def _load_message(self, row) -> Message:
        """Convert one ``messages`` row to a Message and attach its fields."""
        message = self._row_to_message(row)
        message.fields = self._get_message_fields(message.id)
        return message

    def _load_messages(self, rows) -> List[Message]:
        """Convert rows to Messages with fields loaded.

        NOTE: this issues one extra fields query per row.
        """
        return [self._load_message(row) for row in rows]

    def _associate_fields(self, message_id: int, fields: List[Field]):
        """Link ``fields`` to a message, storing each list index as ``field_order``.

        Raises:
            ValueError: if any field has no ``id`` (it was never saved).
                Inserting such a field would create a link with a NULL
                ``field_id``.
        """
        if not fields:
            return

        unsaved = [idx for idx, field in enumerate(fields) if field.id is None]
        if unsaved:
            raise ValueError(
                f"Cannot link unsaved fields (positions {unsaved}) "
                f"to message {message_id}"
            )

        sql = """
            INSERT INTO message_fields (message_id, field_id, field_order)
            VALUES (?, ?, ?)
        """
        params_list = [
            (message_id, field.id, idx)
            for idx, field in enumerate(fields)
        ]
        db_manager.execute_many(sql, params_list)

    def _clear_field_associations(self, message_id: int):
        """Delete every ``message_fields`` link belonging to ``message_id``."""
        sql = "DELETE FROM message_fields WHERE message_id = ?"
        db_manager.execute_update(sql, (message_id,))

    def _get_message_fields(self, message_id: int) -> List[Field]:
        """Return the fields linked to a message, ordered by ``field_order``."""
        sql = """
            SELECT f.* FROM fields f
            INNER JOIN message_fields mf ON f.id = mf.field_id
            WHERE mf.message_id = ?
            ORDER BY mf.field_order
        """
        rows = db_manager.execute_query(sql, (message_id,))
        return [self.field_repo._row_to_field(row) for row in rows]

    @staticmethod
    def _parse_timestamp(value) -> Optional[datetime]:
        """Convert a stored timestamp to ``datetime``; empty/NULL becomes None.

        Values that are already ``datetime`` objects (for example, if the
        driver is configured to convert timestamp columns) are returned
        unchanged. ``datetime.fromisoformat`` accepts only ``str``.
        """
        if isinstance(value, datetime):
            return value
        if not value:
            return None
        return datetime.fromisoformat(value)

    def _row_to_message(self, row) -> Message:
        """Convert a ``messages`` row to a Message with an empty ``fields`` list.

        Fields are loaded separately by the caller (see ``_load_message``).
        """
        return Message(
            id=row['id'],
            full_name=row['full_name'],
            system_name=row['system_name'],
            message_type=row['message_type'],
            version=row['version'],
            description=row['description'],
            protocol=ProtocolType(row['protocol']),
            serialization=SerializationType(row['serialization']),
            fields=[],
            created_time=self._parse_timestamp(row['created_time']),
            updated_time=self._parse_timestamp(row['updated_time']),
        )
|