"""
数据库管理器 / Database manager.

负责数据库连接、初始化和事务管理。
Handles database connections, schema initialization and transaction management.
"""
|
|||
|
|
import sqlite3
|
|||
|
|
import threading
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Optional, Any, List, Tuple
|
|||
|
|
from contextlib import contextmanager
|
|||
|
|
|
|||
|
|
from config import Config
|
|||
|
|
from utils.logger import get_logger
|
|||
|
|
|
|||
|
|
# Module logger obtained from the project's utils.logger helper, keyed by this module's __name__.
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DatabaseManager:
    """
    Database manager (singleton).

    Keeps one SQLite connection per thread in a ``threading.local`` and
    exposes cursor/transaction context managers plus thin query/update
    helpers built on top of them.

    Transaction semantics: ``get_cursor`` (and therefore ``execute_query``,
    ``execute_update`` and ``execute_many``) commits on success and rolls
    back on error *unless* the calling thread is inside ``transaction()``.
    In that case only the outermost ``transaction()`` commits or rolls back,
    so a group of helper calls made inside it succeeds or fails together.
    Nested ``transaction()`` blocks are flattened into the outermost one.
    """

    # The single shared instance, created lazily by __new__.
    _instance = None
    # Guards singleton creation and one-time initialization.
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use (double-checked under ``_lock``)."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """
        Initialize the manager once; subsequent calls are no-ops.

        Python runs ``__init__`` on every ``DatabaseManager()`` call even
        though ``__new__`` returns the same object, so the body is guarded by
        the ``_initialized`` flag. The flag is re-checked under ``_lock`` so
        concurrent first constructions cannot both apply the schema. It is
        only set after ``_initialize_database`` succeeds, so a failed
        initialization is retried on the next construction instead of leaving
        a half-initialized singleton behind.
        """
        if getattr(self, '_initialized', False):
            return
        with self._lock:
            if getattr(self, '_initialized', False):
                return
            self._db_path = Config.DATABASE_PATH
            self._local = threading.local()
            self._initialize_database()
            self._initialized = True
            logger.info(f"Database manager initialized: {self._db_path}")

    def _initialize_database(self):
        """
        Apply ``schema.sql`` from ``Config.DATABASE_DIR`` to the database.

        A missing schema file is logged as a warning and skipped; any other
        failure is logged and re-raised.
        """
        schema_path = Config.DATABASE_DIR / "schema.sql"

        if not schema_path.exists():
            logger.warning(f"Schema file not found: {schema_path}")
            return

        try:
            with open(schema_path, 'r', encoding='utf-8') as f:
                schema_sql = f.read()

            conn = self._get_connection()
            # Connection.executescript uses a temporary cursor, so no cursor is left open.
            conn.executescript(schema_sql)
            conn.commit()
            logger.info("Database schema initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database schema: {e}")
            raise

    def _get_connection(self) -> sqlite3.Connection:
        """
        Return the calling thread's connection, opening it on first use.

        New connections wait up to 30 seconds on a locked database, return
        rows as ``sqlite3.Row`` and have foreign-key enforcement enabled.
        """
        if getattr(self._local, 'connection', None) is None:
            conn = sqlite3.connect(
                str(self._db_path),
                check_same_thread=False,
                timeout=30.0,
            )
            conn.row_factory = sqlite3.Row
            # Foreign-key enforcement is a per-connection setting in SQLite; turn it on.
            conn.execute("PRAGMA foreign_keys = ON")
            self._local.connection = conn
            logger.debug(f"Created new database connection for thread {threading.current_thread().name}")

        return self._local.connection

    def _transaction_depth(self) -> int:
        """Return how many ``transaction()`` blocks are currently open on the calling thread."""
        return getattr(self._local, 'transaction_depth', 0)

    @contextmanager
    def get_cursor(self):
        """
        Context manager yielding a cursor on the calling thread's connection.

        Outside ``transaction()`` the work is committed when the block exits
        normally and rolled back if it raises. Inside ``transaction()`` the
        commit/rollback is left to the enclosing transaction, so helper calls
        made within it do not commit each other's work early. The cursor is
        always closed; exceptions are logged and re-raised.

        Example::

            with db_manager.get_cursor() as cursor:
                cursor.execute("SELECT * FROM fields")
        """
        conn = self._get_connection()
        # When an enclosing transaction() is open, it owns commit/rollback.
        owns_transaction = self._transaction_depth() == 0
        cursor = conn.cursor()
        try:
            yield cursor
            if owns_transaction:
                conn.commit()
        except Exception as e:
            if owns_transaction:
                conn.rollback()
            logger.error(f"Database operation failed: {e}")
            raise
        finally:
            cursor.close()

    @contextmanager
    def transaction(self):
        """
        Context manager grouping several operations into one transaction.

        Yields the calling thread's connection. When the outermost
        ``transaction()`` block exits normally, the work done inside it
        (including ``get_cursor`` and ``execute_*`` calls) is committed; if it
        raises, pending changes are rolled back and the exception re-raised.
        Nested ``transaction()`` blocks neither commit nor roll back; they
        defer to the outermost block.

        NOTE(review): with the sqlite3 module's default isolation level only
        DML statements open the implicit transaction, so DDL executed here may
        be autocommitted — confirm if DDL atomicity is required.

        Example::

            with db_manager.transaction():
                db_manager.execute_update("INSERT ...", (...))
                db_manager.execute_update("UPDATE ...", (...))
        """
        conn = self._get_connection()
        depth = self._transaction_depth()
        self._local.transaction_depth = depth + 1
        try:
            yield conn
            if depth == 0:
                conn.commit()
                logger.debug("Transaction committed successfully")
        except Exception as e:
            if depth == 0:
                conn.rollback()
                logger.error(f"Transaction rolled back: {e}")
            raise
        finally:
            self._local.transaction_depth = depth

    def execute_query(self, sql: str, params: Optional[Tuple] = None) -> List[sqlite3.Row]:
        """
        Run a query and return all result rows.

        Args:
            sql: SQL query text.
            params: Bound parameters; left out of the call when falsy.

        Returns:
            List of ``sqlite3.Row`` objects.
        """
        with self.get_cursor() as cursor:
            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)
            return cursor.fetchall()

    def execute_update(self, sql: str, params: Optional[Tuple] = None) -> int:
        """
        Run a data-modifying statement (INSERT, UPDATE, DELETE).

        Args:
            sql: SQL statement text.
            params: Bound parameters; left out of the call when falsy.

        Returns:
            The cursor's ``rowcount`` (number of affected rows).
        """
        with self.get_cursor() as cursor:
            if params:
                cursor.execute(sql, params)
            else:
                cursor.execute(sql)
            return cursor.rowcount

    def execute_many(self, sql: str, params_list: List[Tuple]) -> int:
        """
        Run one statement once per parameter tuple.

        Args:
            sql: SQL statement text.
            params_list: Sequence of parameter tuples.

        Returns:
            The cursor's ``rowcount`` after ``executemany``.
        """
        with self.get_cursor() as cursor:
            cursor.executemany(sql, params_list)
            return cursor.rowcount

    def get_last_insert_id(self) -> int:
        """Return ``last_insert_rowid()`` for the calling thread's connection."""
        conn = self._get_connection()
        return conn.execute("SELECT last_insert_rowid()").fetchone()[0]

    def close(self):
        """Close the calling thread's connection, if it has one open; other threads are unaffected."""
        if hasattr(self._local, 'connection') and self._local.connection:
            self._local.connection.close()
            self._local.connection = None
            logger.info("Database connection closed")

    def vacuum(self):
        """
        Run ``VACUUM`` on the calling thread's connection.

        Best effort: failures (e.g. SQLite refuses to VACUUM inside an open
        transaction) are logged rather than raised.
        """
        try:
            conn = self._get_connection()
            conn.execute("VACUUM")
            logger.info("Database vacuumed successfully")
        except Exception as e:
            logger.error(f"Failed to vacuum database: {e}")

    def backup(self, backup_path: Path):
        """
        Write a consistent copy of the database to ``backup_path``.

        Uses SQLite's online backup API (``sqlite3.Connection.backup``)
        instead of copying the file. A raw file copy can capture a torn
        database if another connection writes during the copy, and it misses
        committed data still held in a WAL file.

        Args:
            backup_path: Destination file path.

        Raises:
            Exception: any failure is logged and re-raised.
        """
        try:
            source = self._get_connection()
            target = sqlite3.connect(str(backup_path))
            try:
                source.backup(target)
            finally:
                target.close()
            logger.info(f"Database backed up to: {backup_path}")
        except Exception as e:
            logger.error(f"Failed to backup database: {e}")
            raise
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Global database manager instance shared by importers of this module.
# NOTE: constructed at import time, so importing this module runs
# DatabaseManager.__init__ (which applies schema.sql when that file exists).
db_manager = DatabaseManager()
|