210 lines
6.0 KiB
Python
210 lines
6.0 KiB
Python
"""
数据库管理器 (Database manager)

负责数据库连接、初始化和事务管理
(Manages database connections, schema initialization and transactions.)
"""
|
||
import sqlite3
|
||
import threading
|
||
from pathlib import Path
|
||
from typing import Optional, Any, List, Tuple
|
||
from contextlib import contextmanager
|
||
|
||
from config import Config
|
||
from utils.logger import get_logger
|
||
|
||
# Module-level logger obtained from the project's logging helper
# (utils.logger.get_logger), named after this module.
logger = get_logger(__name__)
|
||
|
||
|
||
class DatabaseManager:
    """
    Database manager (singleton).

    Keeps one SQLite connection per thread in a ``threading.local`` and offers
    cursor / transaction context managers plus small query helpers.

    Commit semantics: ``get_cursor()`` and the ``execute_*`` helpers commit
    after each call, except when they run inside ``transaction()``. There the
    outermost ``transaction()`` block owns the commit/rollback, so several
    helper calls form one atomic unit. Nested ``transaction()`` blocks join the
    enclosing one (no savepoints).
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        # Double-checked locking: only take the lock while no instance exists.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """
        Initialize the singleton once.

        Reads the database path from ``Config``, creates the per-thread
        connection store and applies the schema. ``__init__`` runs on every
        ``DatabaseManager()`` call, so the work is guarded by ``_initialized``.
        The flag is set only after the schema step succeeded, so a failed
        initialization is retried by the next call instead of leaving a
        half-initialized singleton behind.

        Raises:
            Exception: whatever ``_initialize_database`` raises.
        """
        if getattr(self, '_initialized', False):
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished first.
            if getattr(self, '_initialized', False):
                return
            self._db_path = Config.DATABASE_PATH
            self._local = threading.local()
            try:
                self._initialize_database()
            except Exception:
                # Don't keep this thread's connection from the failed attempt.
                self.close()
                raise
            self._initialized = True
        logger.info(f"Database manager initialized: {self._db_path}")

    def _initialize_database(self):
        """
        Apply ``schema.sql`` from ``Config.DATABASE_DIR`` to the database.

        A missing schema file is only logged as a warning. Any error while
        reading or executing the script is logged and re-raised.
        """
        schema_path = Config.DATABASE_DIR / "schema.sql"
        if not schema_path.exists():
            logger.warning(f"Schema file not found: {schema_path}")
            return

        try:
            with open(schema_path, 'r', encoding='utf-8') as f:
                schema_sql = f.read()
            conn = self._get_connection()
            conn.executescript(schema_sql)
            conn.commit()
            logger.info("Database schema initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database schema: {e}")
            raise

    def _get_connection(self) -> sqlite3.Connection:
        """
        Return the calling thread's connection, opening it on first use.

        New connections return ``sqlite3.Row`` rows, wait up to 30 seconds for
        locks (``timeout=30.0``) and run ``PRAGMA foreign_keys = ON``.
        """
        conn = getattr(self._local, 'connection', None)
        if conn is None:
            conn = sqlite3.connect(
                str(self._db_path),
                # NOTE(review): connections are thread-local, so disabling the
                # same-thread check looks unnecessary — confirm before changing.
                check_same_thread=False,
                timeout=30.0,
            )
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA foreign_keys = ON")
            self._local.connection = conn
            logger.debug(f"Created new database connection for thread {threading.current_thread().name}")
        return conn

    def _transaction_depth(self) -> int:
        """Number of ``transaction()`` blocks currently open on this thread."""
        return getattr(self._local, 'transaction_depth', 0)

    @contextmanager
    def get_cursor(self):
        """
        Context manager yielding a cursor on this thread's connection.

        Outside ``transaction()``, the work is committed on normal exit and
        rolled back on any exception, which is then re-raised. Inside
        ``transaction()``, neither happens here; the enclosing transaction
        decides. The cursor is always closed.

        Example::

            with db_manager.get_cursor() as cursor:
                cursor.execute("SELECT * FROM fields")
        """
        conn = self._get_connection()
        owns_commit = self._transaction_depth() == 0
        cursor = conn.cursor()
        try:
            yield cursor
            if owns_commit:
                conn.commit()
        # BaseException (not Exception) so e.g. KeyboardInterrupt does not
        # leave half-done writes pending for a later commit to persist.
        except BaseException as e:
            if owns_commit:
                conn.rollback()
            logger.error(f"Database operation failed: {e}")
            raise
        finally:
            cursor.close()

    @contextmanager
    def transaction(self):
        """
        Context manager grouping several operations into one transaction.

        Yields this thread's connection. The outermost block commits on
        normal exit, and rolls back on any exception before re-raising it.
        ``get_cursor()`` / ``execute_*`` calls made inside it do not commit on
        their own. Nested blocks join the outer transaction.

        Example::

            with db_manager.transaction():
                db_manager.execute_update("INSERT ...", (...))
                db_manager.execute_update("UPDATE ...", (...))
        """
        conn = self._get_connection()
        depth = self._transaction_depth()
        outermost = depth == 0
        self._local.transaction_depth = depth + 1
        try:
            yield conn
            if outermost:
                conn.commit()
                logger.debug("Transaction committed successfully")
        # BaseException: an interrupted transaction must not stay pending.
        except BaseException as e:
            if outermost:
                conn.rollback()
                logger.error(f"Transaction rolled back: {e}")
            raise
        finally:
            self._local.transaction_depth = depth

    def execute_query(self, sql: str, params: Optional[Tuple] = None) -> List[sqlite3.Row]:
        """
        Run a query and return all result rows.

        Args:
            sql: SQL query.
            params: Bind parameters; ``None`` or empty means no parameters.

        Returns:
            List of ``sqlite3.Row`` objects.
        """
        with self.get_cursor() as cursor:
            cursor.execute(sql, params or ())
            return cursor.fetchall()

    def execute_update(self, sql: str, params: Optional[Tuple] = None) -> int:
        """
        Run a single INSERT / UPDATE / DELETE statement.

        Args:
            sql: SQL statement.
            params: Bind parameters; ``None`` or empty means no parameters.

        Returns:
            ``cursor.rowcount`` for the statement (number of affected rows).
        """
        with self.get_cursor() as cursor:
            cursor.execute(sql, params or ())
            return cursor.rowcount

    def execute_many(self, sql: str, params_list: List[Tuple]) -> int:
        """
        Run one statement once per parameter tuple (``executemany``).

        Args:
            sql: SQL statement.
            params_list: One bind-parameter tuple per execution.

        Returns:
            ``cursor.rowcount`` after ``executemany`` (affected rows).
        """
        with self.get_cursor() as cursor:
            cursor.executemany(sql, params_list)
            return cursor.rowcount

    def get_last_insert_id(self) -> int:
        """Return ``last_insert_rowid()`` from this thread's connection."""
        return self._get_connection().execute("SELECT last_insert_rowid()").fetchone()[0]

    def close(self):
        """Close the calling thread's connection, if one is open.

        Connections opened by other threads are not touched.
        """
        conn = getattr(self._local, 'connection', None)
        if conn is not None:
            conn.close()
            self._local.connection = None
            logger.info("Database connection closed")

    def vacuum(self):
        """
        Run ``VACUUM`` on this thread's connection.

        Best-effort: failures (e.g. VACUUM attempted while a transaction is
        open on this connection) are logged, not raised.
        """
        try:
            self._get_connection().execute("VACUUM")
            logger.info("Database vacuumed successfully")
        except Exception as e:
            logger.error(f"Failed to vacuum database: {e}")

    def backup(self, backup_path: Path):
        """
        Write a consistent copy of the database to ``backup_path``.

        Uses SQLite's online backup API (``sqlite3.Connection.backup``,
        Python 3.7+) rather than copying the file. A raw file copy of a live
        database can be torn by concurrent writes, misses content still in a
        WAL file, and cannot handle ``:memory:`` databases. Any existing
        database at ``backup_path`` is overwritten.

        Args:
            backup_path: Destination file path.

        Raises:
            Exception: any error opening the destination or running the
                backup (logged, then re-raised).
        """
        try:
            source = self._get_connection()
            target = sqlite3.connect(str(backup_path))
            try:
                source.backup(target)
            finally:
                target.close()
            logger.info(f"Database backed up to: {backup_path}")
        except Exception as e:
            logger.error(f"Failed to backup database: {e}")
            raise
|
||
|
||
|
||
# Global database manager instance. It is created when this module is imported,
# so importing the module runs the schema initialization (schema.sql from
# Config.DATABASE_DIR, when present). Because DatabaseManager is a singleton,
# later DatabaseManager() calls return this same object.
db_manager = DatabaseManager()
|