180 lines
5.5 KiB
Python
180 lines
5.5 KiB
Python
"""用户管理服务 - 内存数据库模块。
|
||
|
||
使用线程安全的 dict 模拟数据库存储,提供基础的 CRUD 操作。
|
||
所有数据存储在内存中,服务重启后数据会丢失。
|
||
"""
|
||
|
||
import threading
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Optional
|
||
|
||
from app.models import UserInDB
|
||
|
||
|
||
class InMemoryDatabase:
|
||
"""基于内存的简易数据库,线程安全。
|
||
|
||
使用 dict 存储用户数据,key 为用户 ID,value 为 UserInDB 对象。
|
||
同时维护 username -> user_id 的映射表以支持按用户名查询。
|
||
|
||
Attributes:
|
||
_users: 用户数据存储,格式 {user_id: UserInDB}
|
||
_username_index: 用户名索引,格式 {username: user_id}
|
||
_lock: 线程锁,保证并发安全
|
||
"""
|
||
|
||
def __init__(self):
|
||
"""初始化内存数据库,并插入预设的假数据。"""
|
||
self._users: dict[str, UserInDB] = {}
|
||
self._username_index: dict[str, str] = {}
|
||
self._lock = threading.Lock()
|
||
self._init_seed_data()
|
||
|
||
def _init_seed_data(self):
|
||
"""初始化预设的假数据,用于演示和测试。"""
|
||
seed_users = [
|
||
{
|
||
"id": "user-001",
|
||
"username": "alice",
|
||
"email": "alice@example.com",
|
||
"hashed_password": "e99a18c428cb38d5f260853678922e03", # hash of "abc123"
|
||
"created_at": "2024-01-15T08:30:00",
|
||
"is_active": True,
|
||
},
|
||
{
|
||
"id": "user-002",
|
||
"username": "bob",
|
||
"email": "bob@example.com",
|
||
"hashed_password": "e99a18c428cb38d5f260853678922e03",
|
||
"created_at": "2024-02-20T10:15:00",
|
||
"is_active": True,
|
||
},
|
||
{
|
||
"id": "user-003",
|
||
"username": "charlie",
|
||
"email": "charlie@example.com",
|
||
"hashed_password": "e99a18c428cb38d5f260853678922e03",
|
||
"created_at": "2024-03-10T14:45:00",
|
||
"is_active": True,
|
||
},
|
||
]
|
||
for user_data in seed_users:
|
||
user = UserInDB(**user_data)
|
||
self._users[user.id] = user
|
||
self._username_index[user.username] = user.id
|
||
|
||
def insert_user(self, user: UserInDB) -> UserInDB:
|
||
"""插入一个新用户。
|
||
|
||
Args:
|
||
user: 要插入的用户对象
|
||
|
||
Returns:
|
||
插入后的用户对象
|
||
"""
|
||
with self._lock:
|
||
self._users[user.id] = user
|
||
self._username_index[user.username] = user.id
|
||
return user
|
||
|
||
def get_user_by_id(self, user_id: str) -> Optional[UserInDB]:
|
||
"""根据用户 ID 查询用户。
|
||
|
||
Args:
|
||
user_id: 用户唯一标识
|
||
|
||
Returns:
|
||
如果找到则返回用户对象,否则返回 None
|
||
"""
|
||
with self._lock:
|
||
return self._users.get(user_id)
|
||
|
||
def get_user_by_username(self, username: str) -> Optional[UserInDB]:
|
||
"""根据用户名查询用户。
|
||
|
||
Args:
|
||
username: 用户名
|
||
|
||
Returns:
|
||
如果找到则返回用户对象,否则返回 None
|
||
"""
|
||
with self._lock:
|
||
user_id = self._username_index.get(username)
|
||
if user_id is None:
|
||
return None
|
||
return self._users.get(user_id)
|
||
|
||
def get_all_users(self, skip: int = 0, limit: int = 10) -> list[UserInDB]:
|
||
"""获取用户列表,支持分页。
|
||
|
||
Args:
|
||
skip: 跳过的记录数,默认 0
|
||
limit: 返回的最大记录数,默认 10
|
||
|
||
Returns:
|
||
用户对象列表
|
||
"""
|
||
with self._lock:
|
||
all_users = list(self._users.values())
|
||
return all_users[skip : skip + limit]
|
||
|
||
def update_user(self, user_id: str, update_data: dict) -> Optional[UserInDB]:
|
||
"""更新用户信息。
|
||
|
||
Args:
|
||
user_id: 要更新的用户 ID
|
||
update_data: 要更新的字段字典
|
||
|
||
Returns:
|
||
更新后的用户对象,如果用户不存在则返回 None
|
||
"""
|
||
with self._lock:
|
||
user = self._users.get(user_id)
|
||
if user is None:
|
||
return None
|
||
|
||
# 如果更新了用户名,需要同步更新索引
|
||
old_username = user.username
|
||
new_username = update_data.get("username")
|
||
if new_username and new_username != old_username:
|
||
# 检查新用户名是否已被占用
|
||
if new_username in self._username_index:
|
||
return None # 用户名冲突
|
||
del self._username_index[old_username]
|
||
self._username_index[new_username] = user_id
|
||
|
||
updated_user = user.model_copy(update=update_data)
|
||
self._users[user_id] = updated_user
|
||
return updated_user
|
||
|
||
def delete_user(self, user_id: str) -> bool:
|
||
"""删除用户。
|
||
|
||
Args:
|
||
user_id: 要删除的用户 ID
|
||
|
||
Returns:
|
||
删除成功返回 True,用户不存在返回 False
|
||
"""
|
||
with self._lock:
|
||
user = self._users.get(user_id)
|
||
if user is None:
|
||
return False
|
||
del self._users[user_id]
|
||
del self._username_index[user.username]
|
||
return True
|
||
|
||
def count_users(self) -> int:
|
||
"""统计用户总数。
|
||
|
||
Returns:
|
||
用户数量
|
||
"""
|
||
with self._lock:
|
||
return len(self._users)
|
||
|
||
|
||
# 全局单例数据库实例
|
||
db = InMemoryDatabase()
|