2025-03-27 16:27:41 +08:00

112 lines
2.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package config
import (
"fmt"
"path/filepath"
)
// DBType 数据库类型
type DBType string
const (
// PostgreSQL 数据库
PostgreSQL DBType = "postgres"
// MySQL 数据库
MySQL DBType = "mysql"
// SQLite 数据库
SQLite DBType = "sqlite"
)
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Type DBType `mapstructure:"type"` // 数据库类型postgres, mysql, sqlite
Host string `mapstructure:"host"` // 数据库主机地址SQLite不需要
Port int `mapstructure:"port"` // 数据库端口SQLite不需要
User string `mapstructure:"user"` // 数据库用户名SQLite不需要
Password string `mapstructure:"password"` // 数据库密码SQLite不需要
DBName string `mapstructure:"dbname"` // 数据库名称或SQLite的文件路径
SSLMode string `mapstructure:"sslmode"` // SSL模式仅PostgreSQL需要
Charset string `mapstructure:"charset"` // 字符集仅MySQL需要
}
// Validate 验证数据库配置
func (c *DatabaseConfig) Validate() error {
if c.Type == "" {
c.Type = PostgreSQL // 默认使用PostgreSQL
}
if c.Type != PostgreSQL && c.Type != MySQL && c.Type != SQLite {
return fmt.Errorf("unsupported database type: %s", c.Type)
}
if c.Type == SQLite {
// SQLite只需要文件路径
if c.DBName == "" {
c.DBName = "data.db" // 默认数据库文件名
}
// 确保目录存在
dir := filepath.Dir(c.DBName)
if dir != "." && dir != "/" {
if err := ensureDir(dir); err != nil {
return fmt.Errorf("failed to create directory for SQLite: %w", err)
}
}
} else {
// PostgreSQL 和 MySQL 需要主机和用户名等
if c.Host == "" {
return fmt.Errorf("database host cannot be empty")
}
if c.Port <= 0 {
if c.Type == PostgreSQL {
c.Port = 5432 // PostgreSQL默认端口
} else {
c.Port = 3306 // MySQL默认端口
}
}
if c.User == "" {
return fmt.Errorf("database user cannot be empty")
}
if c.DBName == "" {
return fmt.Errorf("database name cannot be empty")
}
}
// PostgreSQL特有配置
if c.Type == PostgreSQL && c.SSLMode == "" {
c.SSLMode = "disable"
}
// MySQL特有配置
if c.Type == MySQL && c.Charset == "" {
c.Charset = "utf8mb4"
}
return nil
}
// DSN 返回数据库连接字符串
func (c *DatabaseConfig) DSN() string {
switch c.Type {
case PostgreSQL:
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
c.Host, c.Port, c.User, c.Password, c.DBName, c.SSLMode)
case MySQL:
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
c.User, c.Password, c.Host, c.Port, c.DBName, c.Charset)
case SQLite:
return c.DBName
default:
return ""
}
}
// 确保目录存在
func ensureDir(dirName string) error {
// 使用 os.MkdirAll 创建目录,如果不存在的话
// 这里简化处理,实际代码可能需要导入 "os" 包并添加相关实现
return nil
}