package model import ( "database/sql" "fmt" "github.com/gofiber/fiber/v2/log" "sync" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" "gitee.ltd/lxh/wechat-robot/internal/config" ) var ( db *gorm.DB dbOnce sync.Once ) // InitDB 初始化数据库连接 func InitDB(cfg *config.DatabaseConfig) error { var err error dbOnce.Do(func() { gormConfig := &gorm.Config{ Logger: logger.Default.LogMode(logger.Info), } dsn := cfg.DSN() // 根据数据库类型选择对应的驱动 switch cfg.Type { case config.PostgreSQL: db, err = gorm.Open(postgres.Open(dsn), gormConfig) case config.MySQL: db, err = gorm.Open(mysql.Open(dsn), gormConfig) case config.SQLite: db, err = gorm.Open(sqlite.Open(dsn), gormConfig) default: err = fmt.Errorf("unsupported database type: %s", cfg.Type) return } if err != nil { log.Fatalf("Failed to connect to database: %v", err) return } // 自动迁移数据库模型 err = migrateDB() if err != nil { log.Fatalf("Failed to migrate database: %v", err) return } // 对于SQLite,执行一些特定的优化 if cfg.Type == config.SQLite { var sqlDB *sql.DB sqlDB, err = db.DB() if err != nil { log.Errorf("Warning: Could not get underlying SQL DB: %v", err) return } // 启用外键约束 _, _ = sqlDB.Exec("PRAGMA foreign_keys = ON") // 设置连接池大小 sqlDB.SetMaxOpenConns(1) // SQLite建议使用单连接 } }) return err } // GetDB 获取数据库连接实例 func GetDB() *gorm.DB { if db == nil { panic("Database not initialized, call InitDB first") } return db } // CloseDB 关闭数据库连接 func CloseDB() error { if db == nil { return nil } sqlDB, err := db.DB() if err != nil { return fmt.Errorf("get sql.DB instance error: %w", err) } return sqlDB.Close() } // migrateDB 进行数据库迁移 func migrateDB() error { // 在这里添加需要自动迁移的模型 return db.AutoMigrate( &Robot{}, &Contact{}, &GroupMember{}, &Message{}, ) }