cvc/src/loreal.com/dit/middlewares/CircuitBreaker.go

315 lines
7.8 KiB
Go

package middlewares
import (
"loreal.com/dit/endpoint"
"fmt"
"net/http"
"sync"
"time"
)
// CircuitBreakerState identifies which of its three states a CircuitBreaker
// is currently in.
type CircuitBreakerState int

// The states a CircuitBreaker can be in. StateClosed is the zero value.
const (
	StateClosed CircuitBreakerState = iota
	StateHalfOpen
	StateOpen
)

// String implements fmt.Stringer. It returns "closed", "half-open" or "open"
// for the declared states and "unknown state: N" for any other value.
func (s CircuitBreakerState) String() string {
	names := [...]string{
		StateClosed:   "closed",
		StateHalfOpen: "half-open",
		StateOpen:     "open",
	}
	if s >= 0 && int(s) < len(names) {
		return names[s]
	}
	// %d does not invoke String, so this cannot recurse.
	return fmt.Sprintf("unknown state: %d", s)
}
// CircuitBreakerCounts tallies the requests seen by a CircuitBreaker together
// with their outcomes.
//
// The CircuitBreaker resets its Counts whenever its state changes and, while
// closed, at the end of every configured interval. Results of requests that
// were sent before a reset are not recorded in the new Counts.
type CircuitBreakerCounts struct {
	Requests             uint32 // requests admitted since the last reset
	TotalSuccesses       uint32 // successful results since the last reset
	TotalFailures        uint32 // failed results since the last reset
	ConsecutiveSuccesses uint32 // length of the current run of successes
	ConsecutiveFailures  uint32 // length of the current run of failures
}

// onRequest records that one more request has been admitted.
func (c *CircuitBreakerCounts) onRequest() {
	c.Requests += 1
}

// onSuccess records a successful result and ends any run of failures.
func (c *CircuitBreakerCounts) onSuccess() {
	c.ConsecutiveFailures = 0
	c.ConsecutiveSuccesses += 1
	c.TotalSuccesses += 1
}

// onFailure records a failed result and ends any run of successes.
func (c *CircuitBreakerCounts) onFailure() {
	c.ConsecutiveSuccesses = 0
	c.ConsecutiveFailures += 1
	c.TotalFailures += 1
}

// clear resets every counter to zero.
func (c *CircuitBreakerCounts) clear() {
	*c = CircuitBreakerCounts{}
}
// CircuitBreakerSettings is the configuration from which a CircuitBreaker is
// built. Zero values select the defaults described on each field.
type CircuitBreakerSettings struct {
	// Name identifies the CircuitBreaker. It is passed to OnStateChange and,
	// when non-empty, included in the "open" error message.
	Name string

	// MaxRequests is the maximum number of requests allowed to pass through
	// while the CircuitBreaker is half-open. A value of 0 allows exactly 1.
	MaxRequests uint32

	// Interval is the cyclic period, in the closed state, after which the
	// internal Counts are cleared. A value of 0 means Counts are never
	// cleared while the breaker stays closed.
	Interval time.Duration

	// Timeout is how long the breaker stays open before it becomes
	// half-open. A value of 0 selects a timeout of 60 seconds.
	Timeout time.Duration

	// ReadyToTrip is called with a copy of Counts whenever a request fails in
	// the closed state; returning true moves the breaker to the open state.
	// When nil, a default is used that trips once there are more than 5
	// consecutive failures.
	ReadyToTrip func(counts CircuitBreakerCounts) bool

	// OnStateChange, when non-nil, is called on every state change with the
	// breaker's name, the previous state and the new state. It is invoked
	// while the breaker's internal mutex is held, so it must not call back
	// into the same CircuitBreaker.
	OnStateChange func(name string, from, to CircuitBreakerState)
}
// CircuitBreaker is a state machine that prevents sending requests which are
// likely to fail. Obtain one from newCircuitBreaker and run requests through
// Execute.
type CircuitBreaker struct {
	// Configuration, fixed after construction.
	name          string
	maxRequests   uint32        // requests admitted while half-open
	interval      time.Duration // closed-state counting period; 0 disables it
	timeout       time.Duration // how long the breaker stays open
	readyToTrip   func(counts CircuitBreakerCounts) bool
	onStateChange func(name string, from, to CircuitBreakerState)

	// mutex is held by the methods that read or update the fields below.
	mutex sync.Mutex

	state CircuitBreakerState
	// generation is incremented every time counts are reset; results that
	// belong to an earlier generation are discarded.
	generation uint64
	counts     CircuitBreakerCounts
	// expiry is the end of the current counting interval while closed, the
	// moment the breaker becomes half-open while open, and the zero Time
	// when there is no deadline.
	expiry time.Time
}
// newCircuitBreaker builds a CircuitBreaker from st. Zero or nil values of
// MaxRequests, Timeout and ReadyToTrip are replaced by 1, defaultTimeout and
// defaultReadyToTrip respectively. The returned breaker starts in the closed
// state with a fresh generation.
func newCircuitBreaker(st CircuitBreakerSettings) *CircuitBreaker {
	cb := &CircuitBreaker{
		name:          st.Name,
		maxRequests:   st.MaxRequests,
		interval:      st.Interval,
		timeout:       st.Timeout,
		readyToTrip:   st.ReadyToTrip,
		onStateChange: st.OnStateChange,
	}

	// Fill in defaults for everything the caller left unset.
	if cb.maxRequests == 0 {
		cb.maxRequests = 1
	}
	if cb.timeout == 0 {
		cb.timeout = defaultTimeout
	}
	if cb.readyToTrip == nil {
		cb.readyToTrip = defaultReadyToTrip
	}

	cb.toNewGeneration(time.Now())
	return cb
}
// defaultTimeout is the open-state duration used when the settings leave
// Timeout at zero.
const defaultTimeout = 60 * time.Second
// defaultReadyToTrip is the ReadyToTrip used when the settings do not supply
// one. It reports true once more than 5 consecutive failures were recorded.
func defaultReadyToTrip(counts CircuitBreakerCounts) bool {
	const maxConsecutiveFailures = 5
	return counts.ConsecutiveFailures > maxConsecutiveFailures
}
// State returns the current state of the CircuitBreaker. Because the state
// is evaluated against the current time, calling State may itself move an
// open breaker whose timeout has elapsed to half-open, or start a new
// counting interval for a closed one.
func (cb *CircuitBreaker) State() CircuitBreakerState {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	current, _ := cb.currentState(time.Now())
	return current
}
// Execute runs req if the CircuitBreaker currently admits requests.
//
// When the breaker rejects the request, Execute returns a nil response and
// an error immediately, without calling req. Otherwise it returns whatever
// req returns. Only the error returned by req decides success or failure;
// the response itself (for example its status code) is not inspected.
//
// If req panics, the panic is recorded as a failure and then re-raised with
// the original panic value.
func (cb *CircuitBreaker) Execute(req func() (*http.Response, error)) (*http.Response, error) {
	gen, rejected := cb.beforeRequest()
	if rejected != nil {
		return nil, rejected
	}

	defer func() {
		if p := recover(); p != nil {
			cb.afterRequest(gen, fmt.Errorf("panic in request"))
			panic(p)
		}
	}()

	resp, err := req()
	cb.afterRequest(gen, err)
	return resp, err
}
// beforeRequest decides whether a new request may proceed. It returns the
// current generation together with a nil error when the request is admitted
// (and counted), or a non-nil error when the breaker is open or the
// half-open request quota is already used up.
func (cb *CircuitBreaker) beforeRequest() (uint64, error) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	state, gen := cb.currentState(time.Now())
	switch {
	case state == StateOpen:
		return gen, cb.errorStateOpen()
	case state == StateHalfOpen && cb.counts.Requests >= cb.maxRequests:
		return gen, fmt.Errorf("too many requests")
	}

	cb.counts.onRequest()
	return gen, nil
}
// afterRequest records the outcome of a request that was admitted in
// generation before. A nil err counts as a success, anything else as a
// failure. If the breaker has moved on to another generation in the
// meantime, the outcome is dropped.
func (cb *CircuitBreaker) afterRequest(before uint64, err error) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	now := time.Now()
	state, current := cb.currentState(now)
	if current != before {
		// The result belongs to counts that have already been cleared.
		return
	}

	if err != nil {
		cb.onFailure(state, now)
		return
	}
	cb.onSuccess(state, now)
}
// onSuccess applies a successful result observed in the given state. In the
// closed and half-open states the success is counted; in the half-open state
// the breaker additionally closes once the run of consecutive successes
// reaches maxRequests. In any other state nothing happens.
// The caller must hold cb.mutex.
func (cb *CircuitBreaker) onSuccess(state CircuitBreakerState, now time.Time) {
	if state != StateClosed && state != StateHalfOpen {
		return
	}

	cb.counts.onSuccess()
	if state == StateHalfOpen && cb.counts.ConsecutiveSuccesses >= cb.maxRequests {
		cb.setState(StateClosed, now)
	}
}
// onFailure applies a failed result observed in the given state. A failure
// while half-open reopens the breaker immediately. A failure while closed is
// counted and opens the breaker if readyToTrip approves. In any other state
// nothing happens.
// The caller must hold cb.mutex.
func (cb *CircuitBreaker) onFailure(state CircuitBreakerState, now time.Time) {
	if state == StateHalfOpen {
		cb.setState(StateOpen, now)
		return
	}
	if state != StateClosed {
		return
	}

	cb.counts.onFailure()
	// readyToTrip receives a copy of the counts, taken after this failure
	// has been recorded.
	if cb.readyToTrip(cb.counts) {
		cb.setState(StateOpen, now)
	}
}
// currentState brings the breaker up to date with respect to now and returns
// the resulting state and generation. A closed breaker whose counting
// interval has expired starts a new generation; an open breaker whose
// timeout has expired becomes half-open.
// The caller must hold cb.mutex.
func (cb *CircuitBreaker) currentState(now time.Time) (CircuitBreakerState, uint64) {
	expired := cb.expiry.Before(now)

	if cb.state == StateClosed {
		// A zero expiry means no counting interval is configured.
		if expired && !cb.expiry.IsZero() {
			cb.toNewGeneration(now)
		}
	} else if cb.state == StateOpen && expired {
		cb.setState(StateHalfOpen, now)
	}

	return cb.state, cb.generation
}
// setState moves the breaker into state, starts a new generation and then
// notifies onStateChange if one is configured. It does nothing when the
// breaker is already in that state.
// The caller must hold cb.mutex; the callback therefore runs under the lock.
func (cb *CircuitBreaker) setState(state CircuitBreakerState, now time.Time) {
	prev := cb.state
	if prev == state {
		return
	}

	cb.state = state
	cb.toNewGeneration(now)

	if notify := cb.onStateChange; notify != nil {
		notify(cb.name, prev, state)
	}
}
// toNewGeneration starts a new generation: it bumps the generation number,
// clears the counts and recomputes expiry for the current state.
//
//   - closed:    now + interval, or the zero Time when interval is 0
//   - open:      now + timeout
//   - half-open: the zero Time (no deadline)
func (cb *CircuitBreaker) toNewGeneration(now time.Time) {
	cb.generation++
	cb.counts.clear()

	// Start from "no deadline" and override it for the states that have one.
	cb.expiry = time.Time{}
	switch {
	case cb.state == StateClosed && cb.interval != 0:
		cb.expiry = now.Add(cb.interval)
	case cb.state == StateOpen:
		cb.expiry = now.Add(cb.timeout)
	}
}
// errorStateOpen builds the error returned when a request is rejected
// because the breaker is open. The breaker's name is included when it has
// one.
func (cb *CircuitBreaker) errorStateOpen() error {
	if cb.name != "" {
		return fmt.Errorf("circuit breaker '%s' is open", cb.name)
	}
	return fmt.Errorf("circuit breaker is open")
}
// ClientCircuitBreaker returns a ClientMiddleware that protects an HTTPClient
// with a circuit breaker configured by settings.
//
// One CircuitBreaker is created for each client the middleware wraps and is
// shared by every request sent through that wrapped client, so failures
// accumulate across requests and the breaker can actually open. While the
// breaker rejects requests, the wrapped client returns a nil response and an
// error without calling the underlying client.
//
// Only a non-nil error from the underlying client's Do counts as a failure;
// the response status code is not inspected.
func ClientCircuitBreaker(settings CircuitBreakerSettings) endpoint.ClientMiddleware {
	return func(c endpoint.HTTPClient) endpoint.HTTPClient {
		// The breaker must outlive individual requests: building it inside
		// the request closure would hand every request a fresh, closed
		// breaker with zeroed counts.
		cb := newCircuitBreaker(settings)
		return func(r *http.Request) (resp *http.Response, err error) {
			return cb.Execute(func() (*http.Response, error) {
				return c.Do(r)
			})
		}
	}
}