package config import ( "fmt" "path/filepath" ) // DBType 数据库类型 type DBType string const ( // PostgreSQL 数据库 PostgreSQL DBType = "postgres" // MySQL 数据库 MySQL DBType = "mysql" // SQLite 数据库 SQLite DBType = "sqlite" ) // DatabaseConfig 数据库配置 type DatabaseConfig struct { Type DBType `mapstructure:"type"` // 数据库类型:postgres, mysql, sqlite Host string `mapstructure:"host"` // 数据库主机地址(SQLite不需要) Port int `mapstructure:"port"` // 数据库端口(SQLite不需要) User string `mapstructure:"user"` // 数据库用户名(SQLite不需要) Password string `mapstructure:"password"` // 数据库密码(SQLite不需要) DBName string `mapstructure:"dbname"` // 数据库名称或SQLite的文件路径 SSLMode string `mapstructure:"sslmode"` // SSL模式(仅PostgreSQL需要) Charset string `mapstructure:"charset"` // 字符集(仅MySQL需要) } // Validate 验证数据库配置 func (c *DatabaseConfig) Validate() error { if c.Type == "" { c.Type = PostgreSQL // 默认使用PostgreSQL } if c.Type != PostgreSQL && c.Type != MySQL && c.Type != SQLite { return fmt.Errorf("unsupported database type: %s", c.Type) } if c.Type == SQLite { // SQLite只需要文件路径 if c.DBName == "" { c.DBName = "data.db" // 默认数据库文件名 } // 确保目录存在 dir := filepath.Dir(c.DBName) if dir != "." && dir != "/" { if err := ensureDir(dir); err != nil { return fmt.Errorf("failed to create directory for SQLite: %w", err) } } } else { // PostgreSQL 和 MySQL 需要主机和用户名等 if c.Host == "" { return fmt.Errorf("database host cannot be empty") } if c.Port <= 0 { if c.Type == PostgreSQL { c.Port = 5432 // PostgreSQL默认端口 } else { c.Port = 3306 // MySQL默认端口 } } if c.User == "" { return fmt.Errorf("database user cannot be empty") } if c.DBName == "" { return fmt.Errorf("database name cannot be empty") } } // PostgreSQL特有配置 if c.Type == PostgreSQL && c.SSLMode == "" { c.SSLMode = "disable" } // MySQL特有配置 if c.Type == MySQL && c.Charset == "" { c.Charset = "utf8mb4" } return nil } // DSN 返回数据库连接字符串 func (c *DatabaseConfig) DSN() string { switch c.Type { case PostgreSQL: return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", c.Host, c.Port, c.User, c.Password, c.DBName, c.SSLMode) case MySQL: return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local", c.User, c.Password, c.Host, c.Port, c.DBName, c.Charset) case SQLite: return c.DBName default: return "" } } // 确保目录存在 func ensureDir(dirName string) error { // 使用 os.MkdirAll 创建目录,如果不存在的话 // 这里简化处理,实际代码可能需要导入 "os" 包并添加相关实现 return nil }