121 lines
4.4 KiB
Python
121 lines
4.4 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
数据库初始化脚本
|
||
功能:
|
||
1. 检查数据库连接
|
||
2. 创建所有定义的表(如果不存在)
|
||
3. 检查现有表的字段,如果缺失则自动添加
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import logging
|
||
from sqlalchemy import create_engine, inspect, text
|
||
from app.core.config import settings
|
||
from app.db import Base
|
||
# 导入所有模型以确保它们被注册到 Base.metadata
|
||
from app.models import user, group, business, session
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def get_engine():
|
||
"""获取数据库引擎"""
|
||
# 优先使用环境变量中的配置,如果没有则构建
|
||
db_url = settings.DATABASE_URL
|
||
if not db_url:
|
||
host = os.getenv('DB_HOST', 'localhost')
|
||
port = os.getenv('DB_PORT', '3306')
|
||
user = os.getenv('DB_USER', 'root')
|
||
password = os.getenv('DB_PASSWORD', '')
|
||
db_name = os.getenv('DB_NAME', 'designer_db')
|
||
db_url = f"mysql+pymysql://{user}:{password}@{host}:{port}/{db_name}"
|
||
|
||
return create_engine(db_url)
|
||
|
||
def map_python_type_to_sql(col_type):
|
||
"""将 SQLAlchemy 类型映射为 MySQL 类型"""
|
||
type_str = str(col_type).lower()
|
||
if 'varchar' in type_str:
|
||
return type_str
|
||
if 'string' in type_str:
|
||
length = getattr(col_type, 'length', 255)
|
||
return f"varchar({length})"
|
||
if 'integer' in type_str or 'int' in type_str:
|
||
return "int"
|
||
if 'boolean' in type_str:
|
||
return "tinyint(1)"
|
||
if 'datetime' in type_str:
|
||
return "datetime"
|
||
if 'date' in type_str:
|
||
return "date"
|
||
if 'float' in type_str:
|
||
return "float"
|
||
if 'text' in type_str:
|
||
return "text"
|
||
return "varchar(255)" # 默认
|
||
|
||
def init_db():
|
||
logger.info("🔄 开始数据库初始化检查...")
|
||
|
||
try:
|
||
engine = get_engine()
|
||
inspector = inspect(engine)
|
||
|
||
# 1. 创建缺失的表
|
||
logger.info("📊 检查表结构...")
|
||
Base.metadata.create_all(bind=engine)
|
||
logger.info("✅ 基础表结构检查完成")
|
||
|
||
# 2. 检查并补充缺失的列
|
||
logger.info("🔍 检查缺失字段...")
|
||
|
||
existing_tables = inspector.get_table_names()
|
||
|
||
with engine.connect() as conn:
|
||
for table_name, table in Base.metadata.tables.items():
|
||
if table_name not in existing_tables:
|
||
continue
|
||
|
||
# 获取数据库中现有的列
|
||
existing_columns = [col['name'] for col in inspector.get_columns(table_name)]
|
||
|
||
# 检查模型定义的列
|
||
for column in table.columns:
|
||
if column.name not in existing_columns:
|
||
logger.info(f" ➕ 发现缺失字段: {table_name}.{column.name}")
|
||
|
||
# 构建 ALTER TABLE 语句
|
||
col_type = map_python_type_to_sql(column.type)
|
||
default_val = ""
|
||
|
||
# 处理默认值 (简化处理,只处理常见类型)
|
||
if column.default:
|
||
arg = column.default.arg
|
||
if isinstance(arg, (int, float, bool)):
|
||
if isinstance(arg, bool):
|
||
arg = 1 if arg else 0
|
||
default_val = f" DEFAULT {arg}"
|
||
elif isinstance(arg, str):
|
||
default_val = f" DEFAULT '{arg}'"
|
||
|
||
nullable = "NULL" if column.nullable else "NOT NULL"
|
||
if column.nullable and not default_val:
|
||
default_val = " DEFAULT NULL"
|
||
|
||
sql = f"ALTER TABLE {table_name} ADD COLUMN {column.name} {col_type} {nullable}{default_val};"
|
||
logger.info(f" 🚀 执行: {sql}")
|
||
conn.execute(text(sql))
|
||
|
||
conn.commit()
|
||
logger.info("✅ 数据库同步完成!")
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 数据库初始化失败: {e}")
|
||
# 不抛出异常,以免阻断容器启动(如果是网络波动等临时问题)
|
||
# 但在生产环境中可能需要抛出
|
||
|
||
if __name__ == "__main__":
|
||
init_db()
|