SIT/database/db_manage.py

210 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库管理器
负责数据库连接、初始化和事务管理
"""
import sqlite3
import threading
from pathlib import Path
from typing import Optional, Any, List, Tuple
from contextlib import contextmanager
from config import Config
from utils.logger import get_logger
# Module logger, obtained from the project's utils.logger helper with this
# module's __name__.
logger = get_logger(__name__)
class DatabaseManager:
    """
    Database manager (singleton).

    Each thread gets its own SQLite connection. The connection is created
    lazily and cached in a ``threading.local``. The class provides cursor and
    transaction context managers, plus thin helpers for queries, updates,
    VACUUM and backups.

    Commit rules:
      * ``get_cursor`` (and therefore ``execute_query``, ``execute_update``
        and ``execute_many``) commits when its block succeeds and rolls back
        when it fails. This applies only outside ``transaction()``.
      * Inside ``transaction()`` those helpers neither commit nor roll back.
        The outermost ``transaction()`` commits on success or rolls back on
        error, so the whole ``with`` body is one atomic unit. Nested
        ``transaction()`` blocks join the outer one.
    """

    _instance = None
    # Shared by __new__ (instance creation) and __init__ (one-time setup).
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on the first call."""
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have created it meanwhile.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """
        Perform one-time setup on the shared instance.

        Python runs ``__init__`` on every ``DatabaseManager()`` call, even
        though ``__new__`` always returns the same object. Only the first
        successful call does any work: it reads ``Config.DATABASE_PATH`` and
        applies the schema script.

        Raises:
            Exception: whatever ``_initialize_database`` raises. The instance
                then stays uninitialized and the next construction retries.
        """
        with self._lock:
            if getattr(self, '_initialized', False):
                return
            self._db_path = Config.DATABASE_PATH
            self._local = threading.local()
            self._initialize_database()
            # Flag success only after the schema step. Setting the flag first
            # would leave a half-initialized singleton that never retries.
            self._initialized = True
        logger.info(f"Database manager initialized: {self._db_path}")

    def _initialize_database(self):
        """
        Apply the schema script ``Config.DATABASE_DIR / "schema.sql"``.

        If the file does not exist, a warning is logged and nothing else
        happens. Otherwise the script is run with ``executescript`` and then
        committed.

        Raises:
            Exception: any error reading or executing the script is logged
                and re-raised.
        """
        schema_path = Config.DATABASE_DIR / "schema.sql"
        if not schema_path.exists():
            logger.warning(f"Schema file not found: {schema_path}")
            return
        try:
            schema_sql = schema_path.read_text(encoding='utf-8')
            conn = self._get_connection()
            cursor = conn.cursor()
            try:
                cursor.executescript(schema_sql)
            finally:
                cursor.close()
            conn.commit()
            logger.info("Database schema initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize database schema: {e}")
            raise

    def _get_connection(self) -> sqlite3.Connection:
        """
        Return the calling thread's connection, opening it on first use.

        New connections wait up to 30 s on a locked database, return rows as
        ``sqlite3.Row`` and have foreign-key enforcement switched on.
        """
        conn = getattr(self._local, 'connection', None)
        if conn is None:
            # NOTE(review): connections are cached per thread, so
            # check_same_thread=False presumably only matters if a connection
            # object is handed to another thread. Confirm whether that happens.
            conn = sqlite3.connect(
                str(self._db_path),
                check_same_thread=False,
                timeout=30.0,
            )
            conn.row_factory = sqlite3.Row
            self._local.connection = conn
            # SQLite leaves foreign-key enforcement off by default, and the
            # setting is per connection, so enable it on every new one.
            conn.execute("PRAGMA foreign_keys = ON")
            logger.debug(f"Created new database connection for thread {threading.current_thread().name}")
        return conn

    def _transaction_depth(self) -> int:
        """Return how many ``transaction()`` blocks are active on the calling thread."""
        return getattr(self._local, 'tx_depth', 0)

    @contextmanager
    def get_cursor(self):
        """
        Yield a cursor on the calling thread's connection.

        Outside ``transaction()``, the block is committed when it completes
        and rolled back (then re-raised) when it raises. Inside
        ``transaction()``, commit and rollback are left to the enclosing
        transaction so its statements stay atomic. The cursor is always
        closed afterwards.

        Example:
            with db_manager.get_cursor() as cursor:
                cursor.execute("SELECT * FROM fields")
        """
        conn = self._get_connection()
        # Only an operation running outside any transaction() finalizes its own work.
        owns_commit = self._transaction_depth() == 0
        cursor = conn.cursor()
        try:
            yield cursor
            if owns_commit:
                conn.commit()
        except BaseException as e:
            # BaseException so that e.g. KeyboardInterrupt also rolls back
            # instead of leaving half-done work for a later commit.
            if owns_commit:
                conn.rollback()
            logger.error(f"Database operation failed: {e}")
            raise
        finally:
            cursor.close()

    @contextmanager
    def transaction(self):
        """
        Run a block of database operations as a single transaction.

        Yields the calling thread's connection. The outermost
        ``transaction()`` commits when the block finishes normally, and rolls
        back (then re-raises) when it raises. ``get_cursor``-based helpers
        called inside the block do not commit on their own. Nested
        ``transaction()`` blocks simply join the outer one.

        Example:
            with db_manager.transaction():
                db_manager.execute_update("INSERT ...", params)
                db_manager.execute_update("UPDATE ...", params)
        """
        conn = self._get_connection()
        depth = self._transaction_depth()
        self._local.tx_depth = depth + 1
        try:
            yield conn
            if depth == 0:
                conn.commit()
                logger.debug("Transaction committed successfully")
        except BaseException as e:
            # BaseException so that e.g. KeyboardInterrupt also rolls back.
            # Otherwise the next get_cursor() commit on this connection would
            # persist the partial work.
            if depth == 0:
                conn.rollback()
                logger.error(f"Transaction rolled back: {e}")
            raise
        finally:
            self._local.tx_depth = depth

    def execute_query(self, sql: str, params: Optional[Tuple] = None) -> List[sqlite3.Row]:
        """
        Run a query and return all result rows.

        Args:
            sql: SQL statement, using positional ``?`` placeholders for values.
            params: Placeholder values. ``None`` or empty means no parameters.

        Returns:
            Every row produced by the statement, as ``sqlite3.Row`` objects.
        """
        with self.get_cursor() as cursor:
            cursor.execute(sql, params or ())
            return cursor.fetchall()

    def execute_update(self, sql: str, params: Optional[Tuple] = None) -> int:
        """
        Run a data-modifying statement (INSERT, UPDATE, DELETE).

        Args:
            sql: SQL statement, using positional ``?`` placeholders for values.
            params: Placeholder values. ``None`` or empty means no parameters.

        Returns:
            ``cursor.rowcount``, the number of rows modified. It is -1 for
            statements that do not modify rows, such as DDL.
        """
        with self.get_cursor() as cursor:
            cursor.execute(sql, params or ())
            return cursor.rowcount

    def execute_many(self, sql: str, params_list: List[Tuple]) -> int:
        """
        Run one statement once per parameter tuple.

        Args:
            sql: SQL statement, using positional ``?`` placeholders for values.
            params_list: One parameter tuple per execution.

        Returns:
            ``cursor.rowcount``. For ``executemany`` this is the total number
            of rows modified across all executions.
        """
        with self.get_cursor() as cursor:
            cursor.executemany(sql, params_list)
            return cursor.rowcount

    def get_last_insert_id(self) -> int:
        """Return ``last_insert_rowid()`` for the calling thread's connection."""
        conn = self._get_connection()
        return conn.execute("SELECT last_insert_rowid()").fetchone()[0]

    def close(self):
        """
        Close the calling thread's connection, if it has one.

        Connections opened by other threads are not touched. A later call on
        this thread opens a fresh connection.
        """
        conn = getattr(self._local, 'connection', None)
        if conn is not None:
            conn.close()
            self._local.connection = None
            logger.info("Database connection closed")

    def vacuum(self):
        """
        Run ``VACUUM`` on the database (best effort).

        Failures are logged, not raised. NOTE: SQLite refuses to VACUUM while
        a transaction is open on the connection. In that case this method
        only logs the error.
        """
        try:
            conn = self._get_connection()
            conn.execute("VACUUM")
            logger.info("Database vacuumed successfully")
        except Exception as e:
            logger.error(f"Failed to vacuum database: {e}")

    def backup(self, backup_path: Path):
        """
        Write a consistent copy of the database to ``backup_path``.

        The copy is made with SQLite's online backup API through the calling
        thread's connection, not by copying file bytes. A raw file copy can
        capture a half-written database, or miss changes still held in the
        WAL file when WAL mode is used. An existing SQLite database at the
        destination is replaced. If ``backup_path`` is a directory, the backup
        is written inside it under the database's file name.

        Args:
            backup_path: Destination file or directory (``Path`` or ``str``).

        Raises:
            Exception: any failure is logged and re-raised.
        """
        try:
            target = Path(backup_path)
            if target.is_dir():
                target = target / Path(self._db_path).name
            source = self._get_connection()
            dest = sqlite3.connect(str(target))
            try:
                source.backup(dest)
            finally:
                dest.close()
            logger.info(f"Database backed up to: {backup_path}")
        except Exception as e:
            logger.error(f"Failed to backup database: {e}")
            raise
# Global database manager instance.
# NOTE: this is created at import time, so importing this module reads
# Config.DATABASE_PATH. If schema.sql exists, it also opens a SQLite
# connection for the importing thread and runs the schema script.
db_manager = DatabaseManager()