112 lines
2.9 KiB
Go
112 lines
2.9 KiB
Go
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"path/filepath"
|
||
)
|
||
|
||
// DBType 数据库类型
|
||
type DBType string
|
||
|
||
const (
|
||
// PostgreSQL 数据库
|
||
PostgreSQL DBType = "postgres"
|
||
// MySQL 数据库
|
||
MySQL DBType = "mysql"
|
||
// SQLite 数据库
|
||
SQLite DBType = "sqlite"
|
||
)
|
||
|
||
// DatabaseConfig 数据库配置
|
||
type DatabaseConfig struct {
|
||
Type DBType `mapstructure:"type"` // 数据库类型:postgres, mysql, sqlite
|
||
Host string `mapstructure:"host"` // 数据库主机地址(SQLite不需要)
|
||
Port int `mapstructure:"port"` // 数据库端口(SQLite不需要)
|
||
User string `mapstructure:"user"` // 数据库用户名(SQLite不需要)
|
||
Password string `mapstructure:"password"` // 数据库密码(SQLite不需要)
|
||
DBName string `mapstructure:"dbname"` // 数据库名称或SQLite的文件路径
|
||
SSLMode string `mapstructure:"sslmode"` // SSL模式(仅PostgreSQL需要)
|
||
Charset string `mapstructure:"charset"` // 字符集(仅MySQL需要)
|
||
}
|
||
|
||
// Validate 验证数据库配置
|
||
func (c *DatabaseConfig) Validate() error {
|
||
if c.Type == "" {
|
||
c.Type = PostgreSQL // 默认使用PostgreSQL
|
||
}
|
||
|
||
if c.Type != PostgreSQL && c.Type != MySQL && c.Type != SQLite {
|
||
return fmt.Errorf("unsupported database type: %s", c.Type)
|
||
}
|
||
|
||
if c.Type == SQLite {
|
||
// SQLite只需要文件路径
|
||
if c.DBName == "" {
|
||
c.DBName = "data.db" // 默认数据库文件名
|
||
}
|
||
// 确保目录存在
|
||
dir := filepath.Dir(c.DBName)
|
||
if dir != "." && dir != "/" {
|
||
if err := ensureDir(dir); err != nil {
|
||
return fmt.Errorf("failed to create directory for SQLite: %w", err)
|
||
}
|
||
}
|
||
} else {
|
||
// PostgreSQL 和 MySQL 需要主机和用户名等
|
||
if c.Host == "" {
|
||
return fmt.Errorf("database host cannot be empty")
|
||
}
|
||
|
||
if c.Port <= 0 {
|
||
if c.Type == PostgreSQL {
|
||
c.Port = 5432 // PostgreSQL默认端口
|
||
} else {
|
||
c.Port = 3306 // MySQL默认端口
|
||
}
|
||
}
|
||
|
||
if c.User == "" {
|
||
return fmt.Errorf("database user cannot be empty")
|
||
}
|
||
|
||
if c.DBName == "" {
|
||
return fmt.Errorf("database name cannot be empty")
|
||
}
|
||
}
|
||
|
||
// PostgreSQL特有配置
|
||
if c.Type == PostgreSQL && c.SSLMode == "" {
|
||
c.SSLMode = "disable"
|
||
}
|
||
|
||
// MySQL特有配置
|
||
if c.Type == MySQL && c.Charset == "" {
|
||
c.Charset = "utf8mb4"
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DSN 返回数据库连接字符串
|
||
func (c *DatabaseConfig) DSN() string {
|
||
switch c.Type {
|
||
case PostgreSQL:
|
||
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||
c.Host, c.Port, c.User, c.Password, c.DBName, c.SSLMode)
|
||
case MySQL:
|
||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
||
c.User, c.Password, c.Host, c.Port, c.DBName, c.Charset)
|
||
case SQLite:
|
||
return c.DBName
|
||
default:
|
||
return ""
|
||
}
|
||
}
|
||
|
||
// 确保目录存在
|
||
func ensureDir(dirName string) error {
|
||
// 使用 os.MkdirAll 创建目录,如果不存在的话
|
||
// 这里简化处理,实际代码可能需要导入 "os" 包并添加相关实现
|
||
return nil
|
||
}
|