2025-03-27 16:27:41 +08:00

108 lines
2.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"fmt"
"log"
"sync"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gitee.ltd/lxh/wechat-robot/internal/config"
)
// Package-level database state shared by InitDB, GetDB, CloseDB and migrateDB.
var (
	// dbOnce guards the initialization closure inside InitDB so it runs at most once.
	dbOnce sync.Once
	// db is the GORM handle assigned by InitDB; it is nil until then.
	db *gorm.DB
)
// InitDB opens the database described by cfg, runs the automatic schema
// migration and stores the resulting handle in the package-level db variable.
//
// The initialization work is guarded by dbOnce, so it is attempted at most
// once per process. Connection and migration failures are returned to the
// caller (wrapped with context) instead of terminating the process, and db is
// left nil in that case. Because the guarded work is never retried, any later
// call made after a failed attempt also reports an error rather than
// pretending that the database is ready.
//
// A nil cfg is rejected up front without consuming the one-time attempt.
func InitDB(cfg *config.DatabaseConfig) error {
	if cfg == nil {
		return fmt.Errorf("database config is nil")
	}

	var err error
	dbOnce.Do(func() {
		gormConfig := &gorm.Config{
			Logger: logger.Default.LogMode(logger.Info),
		}
		dsn := cfg.DSN()

		// Pick the driver that matches the configured database type.
		var dialector gorm.Dialector
		switch cfg.Type {
		case config.PostgreSQL:
			dialector = postgres.Open(dsn)
		case config.MySQL:
			dialector = mysql.Open(dsn)
		case config.SQLite:
			dialector = sqlite.Open(dsn)
		default:
			err = fmt.Errorf("unsupported database type: %s", cfg.Type)
			return
		}

		db, err = gorm.Open(dialector, gormConfig)
		if err != nil {
			// Never keep a handle around for a connection that failed to open.
			db = nil
			err = fmt.Errorf("connect to database: %w", err)
			return
		}

		// Auto-migrate the database models.
		if err = migrateDB(); err != nil {
			err = fmt.Errorf("migrate database: %w", err)
			// Best-effort cleanup: release the connection and reset the handle
			// so GetDB does not hand out an unmigrated database.
			if sqlDB, dbErr := db.DB(); dbErr == nil {
				_ = sqlDB.Close()
			}
			db = nil
			return
		}

		// SQLite-specific tuning. Problems here are only logged: the
		// connection itself is usable, so they do not fail initialization.
		if cfg.Type == config.SQLite {
			// Distinct names keep the outer err (the function result) unshadowed.
			sqlDB, dbErr := db.DB()
			if dbErr != nil {
				log.Printf("Warning: Could not get underlying SQL DB: %v", dbErr)
				return
			}
			// SQLite is recommended to run with a single connection. Limit the
			// pool first so the PRAGMA below is issued on that one connection.
			sqlDB.SetMaxOpenConns(1)
			// Enable foreign key constraints.
			if _, execErr := sqlDB.Exec("PRAGMA foreign_keys = ON"); execErr != nil {
				log.Printf("Warning: Could not enable SQLite foreign keys: %v", execErr)
			}
		}
	})
	if err != nil {
		return err
	}
	if db == nil {
		// The one-time attempt happened in an earlier call and did not succeed.
		return fmt.Errorf("database initialization failed in a previous InitDB call")
	}
	return nil
}
// GetDB returns the package-level database handle.
// It panics when the handle is still nil, i.e. before InitDB has assigned it.
func GetDB() *gorm.DB {
	if db != nil {
		return db
	}
	panic("Database not initialized, call InitDB first")
}
// CloseDB closes the connection behind the package-level database handle.
// It returns nil when no handle has been set, a wrapped error when the
// underlying sql.DB cannot be obtained, and otherwise the result of Close.
func CloseDB() error {
	if db == nil {
		// Nothing was opened, so there is nothing to close.
		return nil
	}

	underlying, err := db.DB()
	if err == nil {
		return underlying.Close()
	}
	return fmt.Errorf("get sql.DB instance error: %w", err)
}
// migrateDB 进行数据库迁移
func migrateDB() error {
// 在这里添加需要自动迁移的模型
return db.AutoMigrate(
&Robot{},
&Contact{},
&GroupMember{},
&Message{},
)
}