Files
DP/Server/init_db.py
zuowei1216 1b19ff1b92 20251222
2025-12-22 21:06:29 +08:00

121 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
数据库初始化脚本
功能:
1. 检查数据库连接
2. 创建所有定义的表(如果不存在)
3. 检查现有表的字段,如果缺失则自动添加
"""
import os
import sys
import logging
from sqlalchemy import create_engine, inspect, text
from app.core.config import settings
from app.db import Base
# 导入所有模型以确保它们被注册到 Base.metadata
from app.models import user, group, business, session
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_engine():
"""获取数据库引擎"""
# 优先使用环境变量中的配置,如果没有则构建
db_url = settings.DATABASE_URL
if not db_url:
host = os.getenv('DB_HOST', 'localhost')
port = os.getenv('DB_PORT', '3306')
user = os.getenv('DB_USER', 'root')
password = os.getenv('DB_PASSWORD', '')
db_name = os.getenv('DB_NAME', 'designer_db')
db_url = f"mysql+pymysql://{user}:{password}@{host}:{port}/{db_name}"
return create_engine(db_url)
def map_python_type_to_sql(col_type):
"""将 SQLAlchemy 类型映射为 MySQL 类型"""
type_str = str(col_type).lower()
if 'varchar' in type_str:
return type_str
if 'string' in type_str:
length = getattr(col_type, 'length', 255)
return f"varchar({length})"
if 'integer' in type_str or 'int' in type_str:
return "int"
if 'boolean' in type_str:
return "tinyint(1)"
if 'datetime' in type_str:
return "datetime"
if 'date' in type_str:
return "date"
if 'float' in type_str:
return "float"
if 'text' in type_str:
return "text"
return "varchar(255)" # 默认
def init_db():
logger.info("🔄 开始数据库初始化检查...")
try:
engine = get_engine()
inspector = inspect(engine)
# 1. 创建缺失的表
logger.info("📊 检查表结构...")
Base.metadata.create_all(bind=engine)
logger.info("✅ 基础表结构检查完成")
# 2. 检查并补充缺失的列
logger.info("🔍 检查缺失字段...")
existing_tables = inspector.get_table_names()
with engine.connect() as conn:
for table_name, table in Base.metadata.tables.items():
if table_name not in existing_tables:
continue
# 获取数据库中现有的列
existing_columns = [col['name'] for col in inspector.get_columns(table_name)]
# 检查模型定义的列
for column in table.columns:
if column.name not in existing_columns:
logger.info(f" 发现缺失字段: {table_name}.{column.name}")
# 构建 ALTER TABLE 语句
col_type = map_python_type_to_sql(column.type)
default_val = ""
# 处理默认值 (简化处理,只处理常见类型)
if column.default:
arg = column.default.arg
if isinstance(arg, (int, float, bool)):
if isinstance(arg, bool):
arg = 1 if arg else 0
default_val = f" DEFAULT {arg}"
elif isinstance(arg, str):
default_val = f" DEFAULT '{arg}'"
nullable = "NULL" if column.nullable else "NOT NULL"
if column.nullable and not default_val:
default_val = " DEFAULT NULL"
sql = f"ALTER TABLE {table_name} ADD COLUMN {column.name} {col_type} {nullable}{default_val};"
logger.info(f" 🚀 执行: {sql}")
conn.execute(text(sql))
conn.commit()
logger.info("✅ 数据库同步完成!")
except Exception as e:
logger.error(f"❌ 数据库初始化失败: {e}")
# 不抛出异常,以免阻断容器启动(如果是网络波动等临时问题)
# 但在生产环境中可能需要抛出
if __name__ == "__main__":
init_db()