283 lines
6.7 KiB
Go
283 lines
6.7 KiB
Go
package global
|
|
//
|
|
//import (
|
|
// "encoding/json"
|
|
// "fmt"
|
|
// "github.com/go-oauth2/oauth2/v4"
|
|
// "github.com/go-oauth2/oauth2/v4/models"
|
|
// "gorm.io/driver/mysql"
|
|
// "gorm.io/driver/postgres"
|
|
// "gorm.io/driver/sqlserver"
|
|
// "gorm.io/gorm"
|
|
// "gorm.io/gorm/logger"
|
|
// "io"
|
|
// "log"
|
|
// "os"
|
|
// "time"
|
|
//)
|
|
//
|
|
//
|
|
//
|
|
//// NewConfig create mysql configuration instance
|
|
//func NewConfig(dsn string, dbType DBType, tableName string) *Config {
|
|
// return &Config{
|
|
// DSN: dsn,
|
|
// DBType: dbType,
|
|
// TableName: tableName,
|
|
// MaxLifetime: time.Hour * 2,
|
|
// }
|
|
//}
|
|
//
|
|
//// Config gorm configuration
|
|
//type Config struct {
|
|
// DSN string
|
|
// DBType DBType
|
|
// TableName string
|
|
// MaxLifetime time.Duration
|
|
//}
|
|
//
|
|
//type DBType int8
|
|
//
|
|
//const (
|
|
// MySQL = iota
|
|
// PostgreSQL
|
|
// SQLite
|
|
// SQLServer
|
|
//)
|
|
//
|
|
//var defaultConfig = &gorm.Config{
|
|
// Logger: logger.New(
|
|
// log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
|
|
// logger.Config{
|
|
// SlowThreshold: time.Second, // slow SQL
|
|
// LogLevel: logger.Info, // log level
|
|
// Colorful: true, // color
|
|
// },
|
|
// ),
|
|
//}
|
|
//
|
|
//// NewStore create mysql store instance,
|
|
//func NewStore(config *Config, gcInterval int) *Store {
|
|
// var d gorm.Dialector
|
|
// switch config.DBType {
|
|
// case MySQL:
|
|
// d = mysql.New(mysql.Config{
|
|
// DSN: config.DSN,
|
|
// })
|
|
// case PostgreSQL:
|
|
// d = postgres.New(postgres.Config{
|
|
// DSN: config.DSN,
|
|
// })
|
|
// case SQLite:
|
|
// d = sqlite.Open(config.DSN)
|
|
// case SQLServer:
|
|
// d = sqlserver.Open(config.DSN)
|
|
// default:
|
|
// fmt.Println("unsupported databases")
|
|
// return nil
|
|
// }
|
|
// db, err := gorm.Open(d, defaultConfig)
|
|
// if err != nil {
|
|
// panic(err)
|
|
// }
|
|
// // default client pool
|
|
// s, err := db.DB()
|
|
// if err != nil {
|
|
// panic(err)
|
|
// }
|
|
// s.SetMaxIdleConns(10)
|
|
// s.SetMaxOpenConns(100)
|
|
// s.SetConnMaxLifetime(time.Hour)
|
|
//
|
|
// return NewStoreWithDB(config, db, gcInterval)
|
|
//}
|
|
//
|
|
//func NewStoreWithDB(config *Config, db *gorm.DB, gcInterval int) *Store {
|
|
// store := &Store{
|
|
// db: db,
|
|
// tableName: "oauth2_token",
|
|
// stdout: os.Stderr,
|
|
// }
|
|
// if config.TableName != "" {
|
|
// store.tableName = config.TableName
|
|
// }
|
|
// interval := 600
|
|
// if gcInterval > 0 {
|
|
// interval = gcInterval
|
|
// }
|
|
// store.ticker = time.NewTicker(time.Second * time.Duration(interval))
|
|
//
|
|
// if !db.Migrator().HasTable(store.tableName) {
|
|
// if err := db.Table(store.tableName).Migrator().CreateTable(&StoreItem{}); err != nil {
|
|
// panic(err)
|
|
// }
|
|
// }
|
|
//
|
|
// go store.gc()
|
|
// return store
|
|
//}
|
|
//
|
|
//// Store mysql token store
|
|
//type Store struct {
|
|
// tableName string
|
|
// db *gorm.DB
|
|
// stdout io.Writer
|
|
// ticker *time.Ticker
|
|
//}
|
|
//
|
|
//// SetStdout set error output
|
|
//func (s *Store) SetStdout(stdout io.Writer) *Store {
|
|
// s.stdout = stdout
|
|
// return s
|
|
//}
|
|
//
|
|
//// Close close the store
|
|
//func (s *Store) Close() {
|
|
// s.ticker.Stop()
|
|
//}
|
|
//
|
|
//func (s *Store) errorf(format string, args ...interface{}) {
|
|
// if s.stdout != nil {
|
|
// buf := fmt.Sprintf(format, args...)
|
|
// s.stdout.Write([]byte(buf))
|
|
// }
|
|
//}
|
|
//
|
|
//func (s *Store) gc() {
|
|
// for range s.ticker.C {
|
|
// now := time.Now().Unix()
|
|
// var count int64
|
|
// if err := s.db.Table(s.tableName).Where("expired_at <= ?", now).Or("code = ? and access = ? AND refresh = ?", "", "", "").Count(&count).Error; err != nil {
|
|
// s.errorf("[ERROR]:%s\n", err)
|
|
// return
|
|
// }
|
|
// if count > 0 {
|
|
// // not soft delete.
|
|
// if err := s.db.Table(s.tableName).Where("expired_at <= ?", now).Or("code = ? and access = ? AND refresh = ?", "", "", "").Unscoped().Delete(&StoreItem{}).Error; err != nil {
|
|
// s.errorf("[ERROR]:%s\n", err)
|
|
// }
|
|
// }
|
|
// }
|
|
//}
|
|
//
|
|
//// Create create and store the new token information
|
|
//func (s *Store) Create(ctx context.Context, info oauth2.TokenInfo) error {
|
|
// jv, err := json.Marshal(info)
|
|
// if err != nil {
|
|
// return err
|
|
// }
|
|
// item := &StoreItem{
|
|
// Data: string(jv),
|
|
// }
|
|
//
|
|
// if code := info.GetCode(); code != "" {
|
|
// item.Code = code
|
|
// item.ExpiredAt = info.GetCodeCreateAt().Add(info.GetCodeExpiresIn()).Unix()
|
|
// } else {
|
|
// item.Access = info.GetAccess()
|
|
// item.ExpiredAt = info.GetAccessCreateAt().Add(info.GetAccessExpiresIn()).Unix()
|
|
//
|
|
// if refresh := info.GetRefresh(); refresh != "" {
|
|
// item.Refresh = info.GetRefresh()
|
|
// item.ExpiredAt = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Unix()
|
|
// }
|
|
// }
|
|
//
|
|
// return s.db.WithContext(ctx).Table(s.tableName).Create(item).Error
|
|
//}
|
|
//
|
|
//// RemoveByCode delete the authorization code
|
|
//func (s *Store) RemoveByCode(ctx context.Context, code string) error {
|
|
// return s.db.WithContext(ctx).
|
|
// Table(s.tableName).
|
|
// Where("code = ?", code).
|
|
// Update("code", "").
|
|
// Error
|
|
//}
|
|
//
|
|
//// RemoveByAccess use the access token to delete the token information
|
|
//func (s *Store) RemoveByAccess(ctx context.Context, access string) error {
|
|
// return s.db.WithContext(ctx).
|
|
// Table(s.tableName).
|
|
// Where("access = ?", access).
|
|
// Update("access", "").
|
|
// Error
|
|
//}
|
|
//
|
|
//// RemoveByRefresh use the refresh token to delete the token information
|
|
//func (s *Store) RemoveByRefresh(ctx context.Context, refresh string) error {
|
|
// return s.db.WithContext(ctx).
|
|
// Table(s.tableName).
|
|
// Where("refresh = ?", refresh).
|
|
// Update("refresh", "").
|
|
// Error
|
|
//}
|
|
//
|
|
//func (s *Store) toTokenInfo(data string) oauth2.TokenInfo {
|
|
// var tm models.Token
|
|
// err := json.Unmarshal([]byte(data), &tm)
|
|
// if err != nil {
|
|
// return nil
|
|
// }
|
|
// return &tm
|
|
//}
|
|
//
|
|
//// GetByCode use the authorization code for token information data
|
|
//func (s *Store) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
|
// if code == "" {
|
|
// return nil, nil
|
|
// }
|
|
//
|
|
// var item StoreItem
|
|
// if err := s.db.WithContext(ctx).
|
|
// Table(s.tableName).
|
|
// Where("code = ?", code).
|
|
// Find(&item).Error; err != nil {
|
|
// return nil, err
|
|
// }
|
|
// if item.ID == 0 {
|
|
// return nil, nil
|
|
// }
|
|
//
|
|
// return s.toTokenInfo(item.Data), nil
|
|
//}
|
|
//
|
|
//// GetByAccess use the access token for token information data
|
|
//func (s *Store) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
|
// if access == "" {
|
|
// return nil, nil
|
|
// }
|
|
//
|
|
// var item StoreItem
|
|
// if err := s.db.WithContext(ctx).
|
|
// Table(s.tableName).
|
|
// Where("access = ?", access).
|
|
// Find(&item).Error; err != nil {
|
|
// return nil, err
|
|
// }
|
|
// if item.ID == 0 {
|
|
// return nil, nil
|
|
// }
|
|
//
|
|
// return s.toTokenInfo(item.Data), nil
|
|
//}
|
|
//
|
|
//// GetByRefresh use the refresh token for token information data
|
|
//func (s *Store) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
|
// if refresh == "" {
|
|
// return nil, nil
|
|
// }
|
|
//
|
|
// var item StoreItem
|
|
// if err := s.db.WithContext(ctx).
|
|
// Table(s.tableName).
|
|
// Where("refresh = ?", refresh).
|
|
// Find(&item).Error; err != nil {
|
|
// return nil, err
|
|
// }
|
|
// if item.ID == 0 {
|
|
// return nil, nil
|
|
// }
|
|
//
|
|
// return s.toTokenInfo(item.Data), nil
|
|
//}
|