Compare commits
No commits in common. 'faded4e634563291c0cb419df16ddbc681627925' and '17eefa152f2041d8a28cf95d7cd4c53a37629365' have entirely different histories.
faded4e634
...
17eefa152f
@ -1 +0,0 @@ |
|||||||
* text=auto eol=lf |
|
@ -0,0 +1,74 @@ |
|||||||
|
package middleware |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/labstack/echo/v4" |
||||||
|
"github.com/labstack/echo/v4/middleware" |
||||||
|
) |
||||||
|
|
||||||
|
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||||
|
type KeyAuthValidator = middleware.KeyAuthValidator |
||||||
|
|
||||||
|
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||||
|
type KeyAuthErrorHandler = middleware.KeyAuthErrorHandler |
||||||
|
|
||||||
|
// KeyAuthConfig defines the config for KeyAuth middleware.
|
||||||
|
type KeyAuthConfig struct { |
||||||
|
Skipper Skipper |
||||||
|
|
||||||
|
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||||
|
// to extract key from the request.
|
||||||
|
// Optional. Default value "header:Authorization".
|
||||||
|
// Possible values:
|
||||||
|
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||||
|
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
|
||||||
|
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
|
||||||
|
// want to cut is `<auth-scheme> ` note the space at the end.
|
||||||
|
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
|
||||||
|
// - "query:<name>"
|
||||||
|
// - "form:<name>"
|
||||||
|
// - "cookie:<name>"
|
||||||
|
// Multiple sources example:
|
||||||
|
// - "header:Authorization,header:X-Api-Key"
|
||||||
|
KeyLookup string |
||||||
|
|
||||||
|
// AuthScheme to be used in the Authorization header.
|
||||||
|
// Optional. Default value "Bearer".
|
||||||
|
AuthScheme string |
||||||
|
|
||||||
|
// Validator is a function to validate key.
|
||||||
|
// Required.
|
||||||
|
Validator KeyAuthValidator |
||||||
|
|
||||||
|
// ErrorHandler defines a function which is executed for an invalid key.
|
||||||
|
// It may be used to define a custom error.
|
||||||
|
ErrorHandler KeyAuthErrorHandler |
||||||
|
|
||||||
|
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
|
||||||
|
// ignore the error (by returning `nil`).
|
||||||
|
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
|
||||||
|
// In that case you can use ErrorHandler to set a default public key auth value in the request context
|
||||||
|
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
|
||||||
|
ContinueOnIgnoredError bool |
||||||
|
} |
||||||
|
|
||||||
|
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
|
||||||
|
var DefaultKeyAuthConfig = KeyAuthConfig{ |
||||||
|
Skipper: DefaultSkipper, |
||||||
|
KeyLookup: "header:" + echo.HeaderAuthorization, |
||||||
|
AuthScheme: "Bearer", |
||||||
|
} |
||||||
|
|
||||||
|
func (a *KeyAuthConfig) ToMiddleware() echo.MiddlewareFunc { |
||||||
|
return middleware.KeyAuthWithConfig(middleware.KeyAuthConfig{ |
||||||
|
Skipper: a.Skipper, |
||||||
|
KeyLookup: a.KeyLookup, |
||||||
|
AuthScheme: a.AuthScheme, |
||||||
|
Validator: a.Validator, |
||||||
|
ErrorHandler: a.ErrorHandler, |
||||||
|
ContinueOnIgnoredError: a.ContinueOnIgnoredError, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func KeyAuth() echo.MiddlewareFunc { |
||||||
|
return DefaultKeyAuthConfig.ToMiddleware() |
||||||
|
} |
@ -0,0 +1,24 @@ |
|||||||
|
package middleware |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/labstack/echo/v4" |
||||||
|
"github.com/labstack/echo/v4/middleware" |
||||||
|
) |
||||||
|
|
||||||
|
// Skipper defines a function to skip middleware. Returning true skips processing
|
||||||
|
// the middleware.
|
||||||
|
type Skipper = middleware.Skipper |
||||||
|
|
||||||
|
// BeforeFunc defines a function which is executed just before the middleware.
|
||||||
|
type BeforeFunc = middleware.BeforeFunc |
||||||
|
|
||||||
|
type ValuesExtractor = middleware.ValuesExtractor |
||||||
|
|
||||||
|
type ToMiddleware interface { |
||||||
|
ToMiddleware() echo.MiddlewareFunc |
||||||
|
} |
||||||
|
|
||||||
|
// DefaultSkipper returns false which processes the middleware.
|
||||||
|
func DefaultSkipper(echo.Context) bool { |
||||||
|
return false |
||||||
|
} |
@ -0,0 +1,268 @@ |
|||||||
|
package middleware |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/labstack/echo/v4/middleware" |
||||||
|
"net/http" |
||||||
|
"sync" |
||||||
|
"time" |
||||||
|
|
||||||
|
"github.com/labstack/echo/v4" |
||||||
|
"golang.org/x/time/rate" |
||||||
|
) |
||||||
|
|
||||||
|
// RateLimiterStore is the interface to be implemented by custom stores.
|
||||||
|
type RateLimiterStore interface { |
||||||
|
// Allow Stores for the rate limiter have to implement the Allow method
|
||||||
|
Allow(identifier string) (bool, error) |
||||||
|
} |
||||||
|
|
||||||
|
type ( |
||||||
|
// RateLimiterConfig defines the configuration for the rate limiter
|
||||||
|
RateLimiterConfig struct { |
||||||
|
Skipper Skipper |
||||||
|
BeforeFunc middleware.BeforeFunc |
||||||
|
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
|
||||||
|
IdentifierExtractor Extractor |
||||||
|
// Store defines a store for the rate limiter
|
||||||
|
Store RateLimiterStore |
||||||
|
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
|
||||||
|
ErrorHandler func(context echo.Context, err error) error |
||||||
|
// DenyHandler provides a handler to be called when RateLimiter denies access
|
||||||
|
DenyHandler func(context echo.Context, identifier string, err error) error |
||||||
|
} |
||||||
|
// Extractor is used to extract data from echo.Context
|
||||||
|
Extractor func(context echo.Context) (string, error) |
||||||
|
) |
||||||
|
|
||||||
|
// errors
|
||||||
|
var ( |
||||||
|
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
|
||||||
|
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") |
||||||
|
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
|
||||||
|
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") |
||||||
|
) |
||||||
|
|
||||||
|
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
|
||||||
|
var DefaultRateLimiterConfig = RateLimiterConfig{ |
||||||
|
Skipper: middleware.DefaultSkipper, |
||||||
|
IdentifierExtractor: func(ctx echo.Context) (string, error) { |
||||||
|
id := ctx.RealIP() |
||||||
|
return id, nil |
||||||
|
}, |
||||||
|
ErrorHandler: func(context echo.Context, err error) error { |
||||||
|
return &echo.HTTPError{ |
||||||
|
Code: ErrExtractorError.Code, |
||||||
|
Message: ErrExtractorError.Message, |
||||||
|
Internal: err, |
||||||
|
} |
||||||
|
}, |
||||||
|
DenyHandler: func(context echo.Context, identifier string, err error) error { |
||||||
|
return &echo.HTTPError{ |
||||||
|
Code: ErrRateLimitExceeded.Code, |
||||||
|
Message: ErrRateLimitExceeded.Message, |
||||||
|
Internal: err, |
||||||
|
} |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
/* |
||||||
|
RateLimiter returns a rate limiting middleware |
||||||
|
|
||||||
|
e := echo.New() |
||||||
|
|
||||||
|
limiterStore := middleware.NewRateLimiterMemoryStore(20) |
||||||
|
|
||||||
|
e.GET("/rate-limited", func(c echo.Context) error { |
||||||
|
return c.String(http.StatusOK, "test") |
||||||
|
}, RateLimiter(limiterStore)) |
||||||
|
*/ |
||||||
|
func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { |
||||||
|
config := DefaultRateLimiterConfig |
||||||
|
config.Store = store |
||||||
|
return config.ToMiddleware() |
||||||
|
} |
||||||
|
|
||||||
|
/* |
||||||
|
ToMiddleware returns a rate limiting middleware |
||||||
|
|
||||||
|
e := echo.New() |
||||||
|
|
||||||
|
config := middleware.RateLimiterConfig{ |
||||||
|
Skipper: DefaultSkipper, |
||||||
|
Store: middleware.NewRateLimiterMemoryStore( |
||||||
|
middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} |
||||||
|
) |
||||||
|
IdentifierExtractor: func(ctx echo.Context) (string, error) { |
||||||
|
id := ctx.RealIP() |
||||||
|
return id, nil |
||||||
|
}, |
||||||
|
ErrorHandler: func(context echo.Context, err error) error { |
||||||
|
return context.JSON(http.StatusTooManyRequests, nil) |
||||||
|
}, |
||||||
|
DenyHandler: func(context echo.Context, identifier string) error { |
||||||
|
return context.JSON(http.StatusForbidden, nil) |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
e.GET("/rate-limited", func(c echo.Context) error { |
||||||
|
return c.String(http.StatusOK, "test") |
||||||
|
}, middleware.RateLimiterWithConfig(config)) |
||||||
|
*/ |
||||||
|
func (config *RateLimiterConfig) ToMiddleware() echo.MiddlewareFunc { |
||||||
|
if config.Skipper == nil { |
||||||
|
config.Skipper = DefaultSkipper |
||||||
|
} |
||||||
|
if config.IdentifierExtractor == nil { |
||||||
|
config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor |
||||||
|
} |
||||||
|
if config.ErrorHandler == nil { |
||||||
|
config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler |
||||||
|
} |
||||||
|
if config.DenyHandler == nil { |
||||||
|
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler |
||||||
|
} |
||||||
|
if config.Store == nil { |
||||||
|
panic("Store configuration must be provided") |
||||||
|
} |
||||||
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
||||||
|
return func(c echo.Context) error { |
||||||
|
if config.Skipper(c) { |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
if config.BeforeFunc != nil { |
||||||
|
config.BeforeFunc(c) |
||||||
|
} |
||||||
|
|
||||||
|
identifier, err := config.IdentifierExtractor(c) |
||||||
|
if err != nil { |
||||||
|
c.Error(config.ErrorHandler(c, err)) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
if allow, err := config.Store.Allow(identifier); !allow { |
||||||
|
c.Error(config.DenyHandler(c, identifier, err)) |
||||||
|
return nil |
||||||
|
} |
||||||
|
return next(c) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
type ( |
||||||
|
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
|
||||||
|
RateLimiterMemoryStore struct { |
||||||
|
visitors map[string]*Visitor |
||||||
|
mutex sync.Mutex |
||||||
|
rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
|
||||||
|
|
||||||
|
burst int |
||||||
|
expiresIn time.Duration |
||||||
|
lastCleanup time.Time |
||||||
|
|
||||||
|
timeNow func() time.Time |
||||||
|
} |
||||||
|
// Visitor signifies a unique user's limiter details
|
||||||
|
Visitor struct { |
||||||
|
*rate.Limiter |
||||||
|
lastSeen time.Time |
||||||
|
} |
||||||
|
) |
||||||
|
|
||||||
|
/* |
||||||
|
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with |
||||||
|
the provided rate (as req/s). |
||||||
|
for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
|
||||||
|
|
||||||
|
Burst and ExpiresIn will be set to default values. |
||||||
|
|
||||||
|
Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. |
||||||
|
|
||||||
|
Example (with 20 requests/sec): |
||||||
|
|
||||||
|
limiterStore := middleware.NewRateLimiterMemoryStore(20) |
||||||
|
*/ |
||||||
|
func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { |
||||||
|
return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ |
||||||
|
Rate: rate, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
/* |
||||||
|
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore |
||||||
|
with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of |
||||||
|
the configured rate if not provided or set to 0. |
||||||
|
|
||||||
|
The build-in memory store is usually capable for modest loads. For higher loads other |
||||||
|
store implementations should be considered. |
||||||
|
|
||||||
|
Characteristics: |
||||||
|
* Concurrency above 100 parallel requests may causes measurable lock contention |
||||||
|
* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map |
||||||
|
* A high number of requests from a single IP address may cause lock contention |
||||||
|
|
||||||
|
Example: |
||||||
|
|
||||||
|
limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( |
||||||
|
middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute}, |
||||||
|
) |
||||||
|
*/ |
||||||
|
func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { |
||||||
|
store = &RateLimiterMemoryStore{} |
||||||
|
|
||||||
|
store.rate = config.Rate |
||||||
|
store.burst = config.Burst |
||||||
|
store.expiresIn = config.ExpiresIn |
||||||
|
if config.ExpiresIn == 0 { |
||||||
|
store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn |
||||||
|
} |
||||||
|
if config.Burst == 0 { |
||||||
|
store.burst = int(config.Rate) |
||||||
|
} |
||||||
|
store.visitors = make(map[string]*Visitor) |
||||||
|
store.timeNow = time.Now |
||||||
|
store.lastCleanup = store.timeNow() |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
|
||||||
|
type RateLimiterMemoryStoreConfig struct { |
||||||
|
Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
|
||||||
|
Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached.
|
||||||
|
ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
|
||||||
|
} |
||||||
|
|
||||||
|
// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore
|
||||||
|
var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ |
||||||
|
ExpiresIn: 3 * time.Minute, |
||||||
|
} |
||||||
|
|
||||||
|
// Allow implements RateLimiterStore.Allow
|
||||||
|
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { |
||||||
|
store.mutex.Lock() |
||||||
|
limiter, exists := store.visitors[identifier] |
||||||
|
if !exists { |
||||||
|
limiter = new(Visitor) |
||||||
|
limiter.Limiter = rate.NewLimiter(store.rate, store.burst) |
||||||
|
store.visitors[identifier] = limiter |
||||||
|
} |
||||||
|
now := store.timeNow() |
||||||
|
limiter.lastSeen = now |
||||||
|
if now.Sub(store.lastCleanup) > store.expiresIn { |
||||||
|
store.cleanupStaleVisitors() |
||||||
|
} |
||||||
|
store.mutex.Unlock() |
||||||
|
return limiter.AllowN(store.timeNow(), 1), nil |
||||||
|
} |
||||||
|
|
||||||
|
/* |
||||||
|
cleanupStaleVisitors helps manage the size of the visitors map by removing stale records |
||||||
|
of users who haven't visited again after the configured expiry time has elapsed |
||||||
|
*/ |
||||||
|
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { |
||||||
|
for id, visitor := range store.visitors { |
||||||
|
if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { |
||||||
|
delete(store.visitors, id) |
||||||
|
} |
||||||
|
} |
||||||
|
store.lastCleanup = store.timeNow() |
||||||
|
} |
@ -1,108 +0,0 @@ |
|||||||
package runtime |
|
||||||
|
|
||||||
import ( |
|
||||||
"context" |
|
||||||
"sorbet/pkg/log" |
|
||||||
"sync" |
|
||||||
) |
|
||||||
|
|
||||||
var ( |
|
||||||
// 并发组控制器
|
|
||||||
wg sync.WaitGroup |
|
||||||
// 运行时上下文对象
|
|
||||||
ctx context.Context |
|
||||||
// 信号量,控制任务并发数量
|
|
||||||
semaphore chan struct{} |
|
||||||
// 异步任务通道
|
|
||||||
tasks chan func(context.Context) |
|
||||||
// 上下文取消函数
|
|
||||||
cancel context.CancelFunc |
|
||||||
// 程序退出通道
|
|
||||||
exit chan chan error |
|
||||||
) |
|
||||||
|
|
||||||
func tryCancel() { |
|
||||||
cancelFunc := cancel |
|
||||||
if cancelFunc == nil { |
|
||||||
log.Panic("context already cancelled") |
|
||||||
} |
|
||||||
cancel = nil |
|
||||||
cancelFunc() |
|
||||||
// 等待任务执行完成
|
|
||||||
for len(semaphore) > 0 { |
|
||||||
work(nil) |
|
||||||
} |
|
||||||
// 理论上在没有信号量的情况下,是没有任务队列的。
|
|
||||||
for len(tasks) > 0 { |
|
||||||
log.Panic("nonempty task channel") |
|
||||||
} |
|
||||||
// 销毁注册的服务
|
|
||||||
for i := len(servlets) - 1; i >= 0; i-- { |
|
||||||
if err := servlets[i].Stop(); err != nil { |
|
||||||
log.Warn("servlet Stop returned with error: ", err) |
|
||||||
} |
|
||||||
} |
|
||||||
ctx = nil |
|
||||||
} |
|
||||||
|
|
||||||
// Go 异步任务
|
|
||||||
func Go(fn func(context.Context)) { |
|
||||||
if semaphore == nil || ctx == nil { |
|
||||||
log.Panic("not initialized") |
|
||||||
return |
|
||||||
} |
|
||||||
select { |
|
||||||
case semaphore <- struct{}{}: |
|
||||||
// If we are below our limit, spawn a new worker rather
|
|
||||||
// than waiting for one to become available.
|
|
||||||
async(&wg, func(context.Context) { |
|
||||||
work(fn) |
|
||||||
}) |
|
||||||
case tasks <- fn: |
|
||||||
// 放到任务队列中
|
|
||||||
// A worker is available and has accepted the task.
|
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// Async 异步执行函数
|
|
||||||
func Async(fn func(context.Context)) { |
|
||||||
async(&wg, fn) |
|
||||||
} |
|
||||||
|
|
||||||
// 执行任务
|
|
||||||
func work(f func(context.Context)) { |
|
||||||
defer func() { |
|
||||||
if semaphore != nil { |
|
||||||
// 释放
|
|
||||||
<-semaphore |
|
||||||
} |
|
||||||
}() |
|
||||||
|
|
||||||
// todo 能不能复用 sync.WaitGroup
|
|
||||||
g := &sync.WaitGroup{} |
|
||||||
if f != nil { |
|
||||||
Async(f) |
|
||||||
} |
|
||||||
for task := range tasks { |
|
||||||
Async(task) |
|
||||||
} |
|
||||||
g.Wait() |
|
||||||
} |
|
||||||
|
|
||||||
// 异步执行函数
|
|
||||||
func async(wg *sync.WaitGroup, fn func(context.Context)) { |
|
||||||
wg.Add(1) |
|
||||||
go func() { |
|
||||||
defer wg.Done() |
|
||||||
if ctx == nil { |
|
||||||
return |
|
||||||
} |
|
||||||
select { |
|
||||||
case <-ctx.Done(): |
|
||||||
// 上下文一旦结束,任务将被忽略
|
|
||||||
log.Error("context done") |
|
||||||
default: |
|
||||||
fn(ctx) |
|
||||||
} |
|
||||||
}() |
|
||||||
} |
|
@ -1,90 +0,0 @@ |
|||||||
package runtime |
|
||||||
|
|
||||||
import ( |
|
||||||
"context" |
|
||||||
"errors" |
|
||||||
"os" |
|
||||||
"os/signal" |
|
||||||
"sorbet/pkg/log" |
|
||||||
"syscall" |
|
||||||
) |
|
||||||
|
|
||||||
// Init 初始化
|
|
||||||
func Init() error { |
|
||||||
return InitWith(context.Background()) |
|
||||||
} |
|
||||||
|
|
||||||
func InitWith(root context.Context) error { |
|
||||||
if ctx != nil || cancel != nil { |
|
||||||
return errors.New("not stopped") |
|
||||||
} |
|
||||||
// 初始化上下文相关
|
|
||||||
ctx, cancel = context.WithCancel(root) |
|
||||||
// 初始化退出通道
|
|
||||||
if exit == nil { |
|
||||||
exit = make(chan chan error) |
|
||||||
} |
|
||||||
// 初始化信号量
|
|
||||||
if semaphore == nil { |
|
||||||
semaphore = make(chan struct{}, 256) |
|
||||||
} |
|
||||||
// 初始化任务通道
|
|
||||||
if tasks != nil { |
|
||||||
tasks = make(chan func(context.Context), 256) |
|
||||||
} |
|
||||||
// 初始化网络服务组件
|
|
||||||
for _, servlet := range servlets { |
|
||||||
if err := servlet.Init(ctx); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
} |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func Start() error { |
|
||||||
// 创建服务器
|
|
||||||
srv, ln, err := createServer() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
// 启动网络服务器
|
|
||||||
go func() { |
|
||||||
err := srv.Serve(ln) |
|
||||||
if err != nil { |
|
||||||
log.Error("encountered an error while serving listener: ", err) |
|
||||||
} |
|
||||||
}() |
|
||||||
log.Info("Listening on %s", ln.Addr().String()) |
|
||||||
|
|
||||||
// 监听停止命令,停止网络服务
|
|
||||||
go func() { |
|
||||||
errChan := <-exit |
|
||||||
tryCancel() |
|
||||||
errChan <- ln.Close() // stop the listener
|
|
||||||
}() |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func Run() error { |
|
||||||
if err := Start(); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
// parse address for host, port
|
|
||||||
ch := make(chan os.Signal, 1) |
|
||||||
signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT) |
|
||||||
log.Info("Received signal %s", <-ch) |
|
||||||
return Stop() |
|
||||||
} |
|
||||||
|
|
||||||
// Stop 停止运行
|
|
||||||
func Stop() error { |
|
||||||
if ctx == nil { |
|
||||||
return errors.New("already stopped") |
|
||||||
} |
|
||||||
ch := make(chan error) |
|
||||||
exit <- ch |
|
||||||
wg.Wait() |
|
||||||
return <-ch |
|
||||||
} |
|
@ -1,121 +0,0 @@ |
|||||||
package runtime |
|
||||||
|
|
||||||
import ( |
|
||||||
"fmt" |
|
||||||
"github.com/labstack/echo/v4" |
|
||||||
"github.com/labstack/echo/v4/middleware" |
|
||||||
"net" |
|
||||||
"net/http" |
|
||||||
"sorbet/internal/util" |
|
||||||
"sorbet/pkg/env" |
|
||||||
"sorbet/pkg/rsp" |
|
||||||
"time" |
|
||||||
) |
|
||||||
|
|
||||||
var ( |
|
||||||
// maxHeaderBytes is used by the http server to limit the size of request headers.
|
|
||||||
// This may need to be increased if accepting cookies from the public.
|
|
||||||
maxHeaderBytes = 1 << 20 |
|
||||||
// readTimeout is used by the http server to set a maximum duration before
|
|
||||||
// timing out read of the request. The default timeout is 10 seconds.
|
|
||||||
readTimeout = 10 * time.Second |
|
||||||
// writeTimeout is used by the http server to set a maximum duration before
|
|
||||||
// timing out write of the response. The default timeout is 10 seconds.
|
|
||||||
writeTimeout = 10 * time.Second |
|
||||||
// idleTimeout is used by the http server to set a maximum duration for
|
|
||||||
// keep-alive connections.
|
|
||||||
idleTimeout = 120 * time.Second |
|
||||||
) |
|
||||||
|
|
||||||
func newEchoFramework() (*echo.Echo, error) { |
|
||||||
e := echo.New() |
|
||||||
e.HideBanner = true |
|
||||||
e.HidePort = true |
|
||||||
e.HTTPErrorHandler = rsp.HTTPErrorHandler |
|
||||||
e.Debug = !env.IsEnv("prod") |
|
||||||
e.Logger = util.NewEchoLogger() |
|
||||||
e.Use(middleware.Recover()) |
|
||||||
e.Use(middleware.CORS()) |
|
||||||
e.Use(middleware.Logger()) |
|
||||||
for _, servlet := range servlets { |
|
||||||
group := e.Group("") |
|
||||||
for _, routable := range servlet.Routes() { |
|
||||||
routable.InitRoutes(group) |
|
||||||
} |
|
||||||
} |
|
||||||
routes := e.Routes() |
|
||||||
e.GET("/_routes", func(c echo.Context) error { |
|
||||||
return c.JSON(http.StatusOK, routes) |
|
||||||
}) |
|
||||||
return e, nil |
|
||||||
} |
|
||||||
|
|
||||||
func createServer() (*http.Server, net.Listener, error) { |
|
||||||
e, err := newEchoFramework() |
|
||||||
if err != nil { |
|
||||||
return nil, nil, err |
|
||||||
} |
|
||||||
|
|
||||||
l, err := createTCPListener() |
|
||||||
if err != nil { |
|
||||||
return nil, nil, err |
|
||||||
} |
|
||||||
|
|
||||||
s := &http.Server{ |
|
||||||
Handler: e, |
|
||||||
MaxHeaderBytes: env.Int("SERVER_MAX_HEADER_BYTES", maxHeaderBytes), |
|
||||||
ReadTimeout: env.Duration("SERVER_READ_TIMEOUT", readTimeout), |
|
||||||
WriteTimeout: env.Duration("SERVER_WRITE_TIMEOUT", writeTimeout), |
|
||||||
IdleTimeout: env.Duration("SERVER_IDLE_TIMEOUT", idleTimeout), |
|
||||||
} |
|
||||||
|
|
||||||
return s, l, nil |
|
||||||
} |
|
||||||
|
|
||||||
func createTCPListener() (net.Listener, error) { |
|
||||||
l, err := net.Listen( |
|
||||||
env.String("SERVER_NETWORK", "tcp"), |
|
||||||
fmt.Sprintf("%s:%d", |
|
||||||
env.String("SERVER_ADDRESS", "0.0.0.0"), |
|
||||||
env.Int("SERVER_PORT", 1324), |
|
||||||
), |
|
||||||
) |
|
||||||
if err == nil { |
|
||||||
l = net.Listener(TCPKeepAliveListener{ |
|
||||||
TCPListener: l.(*net.TCPListener), |
|
||||||
}) |
|
||||||
} |
|
||||||
return l, err |
|
||||||
} |
|
||||||
|
|
||||||
// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted
|
|
||||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
|
||||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
|
||||||
// go away.
|
|
||||||
//
|
|
||||||
// This is here because it is not exposed in the stdlib and
|
|
||||||
// we'd prefer to have a hold of the http.Server's net.Listener so we can close it
|
|
||||||
// on shutdown.
|
|
||||||
//
|
|
||||||
// Taken from here: https://golang.org/src/net/http/server.go?s=63121:63175#L2120
|
|
||||||
type TCPKeepAliveListener struct { |
|
||||||
*net.TCPListener |
|
||||||
} |
|
||||||
|
|
||||||
// Accept accepts the next incoming call and returns the new
|
|
||||||
// connection. KeepAlivePeriod is set properly.
|
|
||||||
func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) { |
|
||||||
tc, err := ln.AcceptTCP() |
|
||||||
if err != nil { |
|
||||||
return |
|
||||||
} |
|
||||||
err = tc.SetKeepAlive(true) |
|
||||||
if err != nil { |
|
||||||
return |
|
||||||
} |
|
||||||
err = tc.SetKeepAlivePeriod(3 * time.Minute) |
|
||||||
if err != nil { |
|
||||||
return |
|
||||||
} |
|
||||||
return tc, nil |
|
||||||
} |
|
@ -1,69 +0,0 @@ |
|||||||
package runtime |
|
||||||
|
|
||||||
import ( |
|
||||||
"context" |
|
||||||
"errors" |
|
||||||
"github.com/labstack/echo/v4" |
|
||||||
"slices" |
|
||||||
) |
|
||||||
|
|
||||||
// 注册的网络服务组件
|
|
||||||
var servlets []Servlet |
|
||||||
|
|
||||||
// Servlet 网络服务组件接口
|
|
||||||
type Servlet interface { |
|
||||||
// Name 服务名称
|
|
||||||
Name() string |
|
||||||
// Priority 优先级,用于启动和销毁的执行顺序
|
|
||||||
Priority() int |
|
||||||
// Init 初始化服务
|
|
||||||
Init(ctx context.Context) error |
|
||||||
// Routes 注册路由
|
|
||||||
Routes() []Routable |
|
||||||
// Start 非阻塞方式启动服务
|
|
||||||
Start() error |
|
||||||
// Stop 停止服务
|
|
||||||
Stop() error |
|
||||||
} |
|
||||||
|
|
||||||
func Reset() error { |
|
||||||
if len(servlets) == 0 { |
|
||||||
return nil |
|
||||||
} |
|
||||||
if ctx == nil { |
|
||||||
return errors.New("servlets is running") |
|
||||||
} |
|
||||||
servlets = nil |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
// Use 注册服务
|
|
||||||
func Use(servlets ...Servlet) error { |
|
||||||
if len(servlets) == 0 { |
|
||||||
return nil |
|
||||||
} |
|
||||||
for i := 0; i < len(servlets); i++ { |
|
||||||
if !use(servlets[i]) { |
|
||||||
return errors.New("service already registered") |
|
||||||
} |
|
||||||
} |
|
||||||
slices.SortFunc(servlets, func(a, b Servlet) int { |
|
||||||
return b.Priority() - a.Priority() // 按优先级排序
|
|
||||||
}) |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func use(servlet Servlet) bool { |
|
||||||
exists := slices.ContainsFunc(servlets, func(s Servlet) bool { |
|
||||||
return s.Name() == servlet.Name() |
|
||||||
}) |
|
||||||
if !exists { |
|
||||||
servlets = append(servlets, servlet) |
|
||||||
return true |
|
||||||
} |
|
||||||
return false |
|
||||||
} |
|
||||||
|
|
||||||
type Routable interface { |
|
||||||
InitRoutes(*echo.Group) |
|
||||||
} |
|
@ -0,0 +1,72 @@ |
|||||||
|
package services |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"errors" |
||||||
|
"github.com/labstack/echo/v4" |
||||||
|
"slices" |
||||||
|
"sorbet/internal/services/company" |
||||||
|
"sorbet/internal/services/config" |
||||||
|
"sorbet/internal/services/feature" |
||||||
|
"sorbet/internal/services/resource" |
||||||
|
"sorbet/internal/services/system" |
||||||
|
"sorbet/pkg/crud" |
||||||
|
) |
||||||
|
|
||||||
|
var services []crud.Service |
||||||
|
|
||||||
|
const ContextEchoKey = "echo_framework" |
||||||
|
|
||||||
|
func Register(service crud.Service) error { |
||||||
|
for _, applet := range services { |
||||||
|
if applet.Name() == service.Name() { |
||||||
|
return errors.New("service already registered") |
||||||
|
} |
||||||
|
} |
||||||
|
services = append(services, service) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func Init(ctx context.Context) error { |
||||||
|
services = []crud.Service{ |
||||||
|
&config.Service{}, |
||||||
|
&company.Service{}, |
||||||
|
&resource.Service{}, |
||||||
|
&feature.Service{}, |
||||||
|
&system.Service{}, |
||||||
|
} |
||||||
|
|
||||||
|
// 按优先级排序
|
||||||
|
slices.SortFunc(services, func(a, b crud.Service) int { |
||||||
|
return b.Priority() - a.Priority() |
||||||
|
}) |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func Bootstrap(ctx context.Context) error { |
||||||
|
e := ctx.Value(ContextEchoKey).(*echo.Echo) |
||||||
|
for _, service := range services { |
||||||
|
err := service.Init(crud.NewContext(ctx, e.Group(""))) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
err = service.Bootstrap() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func Destroy() error { |
||||||
|
for i := len(services) - 1; i >= 0; i++ { |
||||||
|
service := services[i] |
||||||
|
err := service.Destroy() |
||||||
|
if err != nil { |
||||||
|
// TODO(hupeh): 是否需要销毁策略,比如可以继续或者中断等行为
|
||||||
|
return err |
||||||
|
} |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,128 @@ |
|||||||
|
package app |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"errors" |
||||||
|
"os" |
||||||
|
"sorbet/pkg/app/appkit" |
||||||
|
) |
||||||
|
|
||||||
|
func Init(args ...string) (err error) { |
||||||
|
if len(args) == 0 { |
||||||
|
args = os.Args[1:] |
||||||
|
} |
||||||
|
|
||||||
|
switch Status() { |
||||||
|
case Initialized: |
||||||
|
return nil |
||||||
|
case Starting: |
||||||
|
return errors.New("starting") |
||||||
|
case Running: |
||||||
|
return errors.New("running") |
||||||
|
case Stopping: |
||||||
|
return errors.New("stopping") |
||||||
|
case Idle, Stopped: |
||||||
|
// continue
|
||||||
|
} |
||||||
|
|
||||||
|
initLifecycle() |
||||||
|
initStats(args) |
||||||
|
initUdp(args) |
||||||
|
initBus() |
||||||
|
initBuiltins() |
||||||
|
|
||||||
|
setStatus(Initialized) |
||||||
|
|
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
func initBuiltins() { |
||||||
|
// 订阅应用启动
|
||||||
|
var sub appkit.Subscriber |
||||||
|
sub = Sub("start", func([]byte) []byte { |
||||||
|
sub.Cancel() |
||||||
|
return nil |
||||||
|
}) |
||||||
|
|
||||||
|
// 订阅应用停止
|
||||||
|
Sub("stop", func(bytes []byte) []byte { |
||||||
|
quit := exit |
||||||
|
if quit != nil { |
||||||
|
exit = nil |
||||||
|
quit() |
||||||
|
} |
||||||
|
return nil |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func Loop() error { |
||||||
|
switch Status() { |
||||||
|
case Starting, Running: |
||||||
|
return nil |
||||||
|
case Stopping: |
||||||
|
return errors.New("stopping") |
||||||
|
case Idle: |
||||||
|
return errors.New("idle, you maybe forgot call app.Init()") |
||||||
|
case Stopped: |
||||||
|
return errors.New("stopped, you maybe forgot call app.Init()") |
||||||
|
case Initialized: |
||||||
|
// nothing
|
||||||
|
} |
||||||
|
|
||||||
|
// 释放内存
|
||||||
|
defer free() |
||||||
|
|
||||||
|
wg.Add(1) |
||||||
|
setStatus(Starting) |
||||||
|
|
||||||
|
// In a goroutine, we wait on for all goroutines to complete (for example
|
||||||
|
// timers). We use this to signal to the main thread to exit.
|
||||||
|
// wg.Add(1) basically translates to uv_ref, if this was Node.
|
||||||
|
// wg.Done() basically translates to uv_unref
|
||||||
|
go func() { |
||||||
|
wg.Wait() |
||||||
|
setStatus(Stopping) |
||||||
|
stop() |
||||||
|
}() |
||||||
|
|
||||||
|
// 启动应用指令
|
||||||
|
Pub("start", nil) |
||||||
|
|
||||||
|
for { |
||||||
|
select { |
||||||
|
case msg := <-pubs: |
||||||
|
err := dispatch(msg) |
||||||
|
exitOnError(err) |
||||||
|
wg.Done() // Corresponds to the wg.Add(1) in Pub().
|
||||||
|
|
||||||
|
case <-ctx.Done(): |
||||||
|
// 只有等待所有协程全部完成后,
|
||||||
|
// 然后才可以退出主循环
|
||||||
|
checkPubsEmpty() |
||||||
|
setStatus(Stopped) |
||||||
|
err := ctx.Err() |
||||||
|
if errors.Is(err, context.Canceled) { |
||||||
|
return nil |
||||||
|
} |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
// We don't want to exit until we've received at least one message.
|
||||||
|
// This is so the program doesn't exit after sending the "start"
|
||||||
|
// message.
|
||||||
|
if Status() == Starting { |
||||||
|
wg.Done() |
||||||
|
start() |
||||||
|
setStatus(Running) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func Stop() { |
||||||
|
Pub("stop", nil) |
||||||
|
} |
||||||
|
|
||||||
|
func Go(fn func(context.Context)) { |
||||||
|
assert(fn != nil, "invalid fn") |
||||||
|
async(fn) |
||||||
|
} |
@ -0,0 +1,5 @@ |
|||||||
|
package appkit |
||||||
|
|
||||||
|
type Subscriber interface { |
||||||
|
Cancel() |
||||||
|
} |
@ -0,0 +1,81 @@ |
|||||||
|
package appkit |
||||||
|
|
||||||
|
import ( |
||||||
|
"bytes" |
||||||
|
"encoding/binary" |
||||||
|
"errors" |
||||||
|
"math" |
||||||
|
) |
||||||
|
|
||||||
|
// 数据协议如下:
|
||||||
|
//
|
||||||
|
// +-----------+--------+--------------+-------+------+
|
||||||
|
// | 0x01,0x02 | length | topic-length | topic | data |
|
||||||
|
// +-----------+--------+--------------+-------+------+
|
||||||
|
// 2 2 2 n n
|
||||||
|
//
|
||||||
|
// * 头部有两个模数 0x01、0x02;
|
||||||
|
// * length 表示消息的总长度;
|
||||||
|
// * 然后分别是 topic、data,包含其长度和数据;
|
||||||
|
// * 长度使用 uint16 类型,所以占用 2 个字节;
|
||||||
|
//
|
||||||
|
// todo 实现数据校验码,防止篡改
|
||||||
|
|
||||||
|
// Encode 数据编码
|
||||||
|
func Encode(topic, data []byte) ([]byte, error) { |
||||||
|
topicLen := len(topic) |
||||||
|
dataLen := len(data) |
||||||
|
size := 2 + topicLen + dataLen // 不包括模数和数据长度
|
||||||
|
if topicLen == 0 { |
||||||
|
return nil, errors.New("topic too small") |
||||||
|
} |
||||||
|
if size+4 > math.MaxUint16 { |
||||||
|
return nil, errors.New("data too big") |
||||||
|
} |
||||||
|
buf := &bytes.Buffer{} |
||||||
|
// 由于 模数、topic 和 data 都是基本类型,所以当
|
||||||
|
// 执行 binary.Write() 时是不会出错的
|
||||||
|
_ = binary.Write(buf, binary.BigEndian, []byte{0x01, 0x02}) |
||||||
|
_ = binary.Write(buf, binary.BigEndian, uint16(size)) |
||||||
|
_ = binary.Write(buf, binary.BigEndian, uint16(topicLen)) |
||||||
|
_ = binary.Write(buf, binary.BigEndian, topic) |
||||||
|
_ = binary.Write(buf, binary.BigEndian, data) // 如果没有数据会不会报错
|
||||||
|
return buf.Bytes(), nil |
||||||
|
} |
||||||
|
|
||||||
|
// Decode 数据解码
|
||||||
|
func Decode(buf []byte) (topic, data []byte, err error) { |
||||||
|
r := bytes.NewReader(buf) |
||||||
|
// 读取模数
|
||||||
|
var v [2]byte |
||||||
|
err = binary.Read(r, binary.BigEndian, &v) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
if v[0] != 0x01 || v[1] != 0x02 { |
||||||
|
err = errors.New("protocol error") |
||||||
|
return |
||||||
|
} |
||||||
|
// 读取数据长度
|
||||||
|
var size uint16 |
||||||
|
err = binary.Read(r, binary.BigEndian, &size) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
// 读取主题长度
|
||||||
|
var topicLen uint16 |
||||||
|
err = binary.Read(r, binary.BigEndian, &topicLen) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
// 读取主题
|
||||||
|
topic = make([]byte, topicLen) |
||||||
|
err = binary.Read(r, binary.BigEndian, &topic) |
||||||
|
if err != nil { |
||||||
|
return |
||||||
|
} |
||||||
|
// 读取数据
|
||||||
|
data = make([]byte, size-2-topicLen) |
||||||
|
err = binary.Read(r, binary.BigEndian, &data) |
||||||
|
return |
||||||
|
} |
@ -0,0 +1,215 @@ |
|||||||
|
package app |
||||||
|
|
||||||
|
import ( |
||||||
|
"errors" |
||||||
|
"net" |
||||||
|
"slices" |
||||||
|
"sorbet/pkg/app/appkit" |
||||||
|
"sync" |
||||||
|
"sync/atomic" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
// 操作锁
|
||||||
|
//
|
||||||
|
// * 添加订阅器
|
||||||
|
lock sync.RWMutex |
||||||
|
|
||||||
|
// 用于控制主体程序并发,涉及以下几个方面:
|
||||||
|
//
|
||||||
|
// * 任务调度
|
||||||
|
// * 定时器
|
||||||
|
// * 发布订阅
|
||||||
|
// * 主程序轮询
|
||||||
|
wg sync.WaitGroup |
||||||
|
|
||||||
|
// 消息发布通道
|
||||||
|
pubs chan Msg |
||||||
|
|
||||||
|
// 消息订阅器
|
||||||
|
subs map[string][]*subscriber |
||||||
|
|
||||||
|
nextSubId int32 |
||||||
|
) |
||||||
|
|
||||||
|
func initBus() { |
||||||
|
if pubs == nil { |
||||||
|
pubs = make(chan Msg, 256) |
||||||
|
} |
||||||
|
if subs == nil { |
||||||
|
subs = make(map[string][]*subscriber) |
||||||
|
} |
||||||
|
OnFree(freeBus) |
||||||
|
} |
||||||
|
|
||||||
|
func freeBus() { |
||||||
|
clear(subs) |
||||||
|
|
||||||
|
select { |
||||||
|
case _, ok := <-pubs: |
||||||
|
if !ok { |
||||||
|
// 通道已经关闭了,需要重新构建
|
||||||
|
pubs = make(chan Msg, 256) |
||||||
|
return |
||||||
|
} |
||||||
|
// 清空 pubs 缓存的数据,复用它
|
||||||
|
l := len(pubs) |
||||||
|
for i := 0; i < l; i++ { |
||||||
|
<-pubs |
||||||
|
} |
||||||
|
default: |
||||||
|
// 通道里面没有值,可以复用
|
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func Sub(topic string, handle func([]byte) []byte) appkit.Subscriber { |
||||||
|
lock.Lock() |
||||||
|
defer lock.Unlock() |
||||||
|
sub := &subscriber{ |
||||||
|
id: atomic.AddInt32(&nextSubId, 1), |
||||||
|
topic: topic, |
||||||
|
handle: handle, |
||||||
|
} |
||||||
|
if _, ok := subs[topic]; !ok { |
||||||
|
subs[topic] = make([]*subscriber, 0) |
||||||
|
} |
||||||
|
subs[topic] = append(subs[topic], sub) |
||||||
|
return sub |
||||||
|
} |
||||||
|
|
||||||
|
func Pub(topic string, data []byte) { |
||||||
|
pubMsg(Msg{Data: data, topic: topic}) |
||||||
|
} |
||||||
|
|
||||||
|
func pubMsg(msg Msg) { |
||||||
|
wg.Add(1) |
||||||
|
pubs <- msg |
||||||
|
} |
||||||
|
|
||||||
|
type Msg struct { |
||||||
|
Data []byte |
||||||
|
|
||||||
|
topic string |
||||||
|
addr net.Addr |
||||||
|
pc net.PacketConn |
||||||
|
} |
||||||
|
|
||||||
|
// Size 消息长度
|
||||||
|
func (m Msg) Size() int { |
||||||
|
return 6 + len([]byte(m.topic)) + len(m.Data) |
||||||
|
} |
||||||
|
|
||||||
|
type subscriber struct { |
||||||
|
id int32 |
||||||
|
topic string |
||||||
|
handle func([]byte) []byte |
||||||
|
} |
||||||
|
|
||||||
|
func (s *subscriber) Active() bool { |
||||||
|
return atomic.LoadInt32(&s.id) > 0 |
||||||
|
} |
||||||
|
|
||||||
|
func (s *subscriber) Cancel() { |
||||||
|
id := atomic.SwapInt32(&s.id, 0) |
||||||
|
if id == 0 { |
||||||
|
return |
||||||
|
} |
||||||
|
lock.Lock() |
||||||
|
defer lock.Unlock() |
||||||
|
subscribers, ok := subs[s.topic] |
||||||
|
if !ok { |
||||||
|
return |
||||||
|
} |
||||||
|
if len(subscribers) > 0 { |
||||||
|
subscribers = slices.DeleteFunc(subscribers, func(sub *subscriber) bool { |
||||||
|
return sub.id == id |
||||||
|
}) |
||||||
|
} |
||||||
|
if len(subscribers) > 0 { |
||||||
|
subs[s.topic] = subscribers |
||||||
|
} else { |
||||||
|
delete(subs, s.topic) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (s *subscriber) invoke(data []byte) []byte { |
||||||
|
if s.Active() { |
||||||
|
return s.handle(data) |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func dispatch(msg Msg) error { |
||||||
|
lock.RLock() |
||||||
|
var cbs []*subscriber |
||||||
|
if subscribers, ok := subs[msg.topic]; ok { |
||||||
|
for _, s := range subscribers { |
||||||
|
if s.Active() { |
||||||
|
cbs = append(cbs, s) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
lock.RUnlock() |
||||||
|
|
||||||
|
if len(cbs) == 0 { |
||||||
|
return errors.New("no subscribers for topic " + msg.topic) |
||||||
|
} |
||||||
|
|
||||||
|
var response []byte |
||||||
|
for _, cb := range cbs { |
||||||
|
res := cb.invoke(msg.Data) |
||||||
|
if res != nil { |
||||||
|
response = res |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if len(response) > 0 && msg.pc != nil { |
||||||
|
stats.v8workerRespond++ |
||||||
|
stats.v8workerBytesSent += len(response) |
||||||
|
|
||||||
|
n, ex := msg.pc.WriteTo(response, msg.addr) |
||||||
|
if n > 0 { |
||||||
|
stats.v8workerSend++ |
||||||
|
stats.v8workerBytesSent += n |
||||||
|
} |
||||||
|
if ex != nil { |
||||||
|
// todo handle the error
|
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func recv(pc net.PacketConn, addr net.Addr, buf []byte) { |
||||||
|
stats.v8workerRecv++ |
||||||
|
stats.v8workerBytesRecv += len(buf) |
||||||
|
|
||||||
|
topic, data, err := appkit.Decode(buf) |
||||||
|
if err != nil { |
||||||
|
// todo handle the error
|
||||||
|
// return errors.New("invalid payload")
|
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
pubMsg(Msg{ |
||||||
|
Data: data, |
||||||
|
topic: string(topic), |
||||||
|
addr: addr, |
||||||
|
pc: pc, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func checkPubsEmpty() { |
||||||
|
// We've received a done event. As a sanity check, make sure that resChan is
|
||||||
|
// empty.
|
||||||
|
select { |
||||||
|
case _, ok := <-pubs: |
||||||
|
if ok { |
||||||
|
panic("Read a message from pubs after context deadlined.") |
||||||
|
} else { |
||||||
|
panic("pubs closed. Unexpected.") |
||||||
|
} |
||||||
|
default: |
||||||
|
// No value ready, moving on.
|
||||||
|
} |
||||||
|
} |
@ -0,0 +1,83 @@ |
|||||||
|
package app |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"sync/atomic" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
Idle int32 = iota |
||||||
|
Initialized |
||||||
|
Starting |
||||||
|
Running |
||||||
|
Stopping |
||||||
|
Stopped |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
// 当前状态
|
||||||
|
status int32 |
||||||
|
|
||||||
|
// 应用上下文
|
||||||
|
ctx context.Context |
||||||
|
exit context.CancelFunc |
||||||
|
|
||||||
|
start func() // 应用启动钩子
|
||||||
|
stop func() // 应用停止钩子
|
||||||
|
free func() // 内存释放钩子
|
||||||
|
) |
||||||
|
|
||||||
|
func initLifecycle() { |
||||||
|
ctx, exit = context.WithCancel(context.Background()) |
||||||
|
|
||||||
|
start = func() { |
||||||
|
// nothing
|
||||||
|
} |
||||||
|
|
||||||
|
stop = func() { |
||||||
|
if exit != nil { |
||||||
|
exit() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
free = func() { |
||||||
|
ctx = nil |
||||||
|
start = nil |
||||||
|
stop = nil |
||||||
|
exit = nil |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// setStatus 设置应用状态
|
||||||
|
func setStatus(newStatus int32) { |
||||||
|
atomic.StoreInt32(&status, newStatus) |
||||||
|
} |
||||||
|
|
||||||
|
// Status 返回当前状态
|
||||||
|
func Status() int32 { |
||||||
|
return atomic.LoadInt32(&status) |
||||||
|
} |
||||||
|
|
||||||
|
func OnStart(fn func()) { |
||||||
|
oldStart := start |
||||||
|
start = func() { |
||||||
|
oldStart() |
||||||
|
fn() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func OnStop(fn func()) { |
||||||
|
oldStop := stop |
||||||
|
stop = func() { |
||||||
|
fn() |
||||||
|
oldStop() |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func OnFree(fn func()) { |
||||||
|
oldFree := free |
||||||
|
free = func() { |
||||||
|
fn() |
||||||
|
oldFree() |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,44 @@ |
|||||||
|
package app |
||||||
|
|
||||||
|
import ( |
||||||
|
"os" |
||||||
|
"reflect" |
||||||
|
"runtime/pprof" |
||||||
|
) |
||||||
|
|
||||||
|
var stats struct { |
||||||
|
v8workerSend int |
||||||
|
v8workerRespond int |
||||||
|
v8workerRecv int |
||||||
|
v8workerBytesSent int |
||||||
|
v8workerBytesRecv int |
||||||
|
} |
||||||
|
|
||||||
|
func initStats(args []string) { |
||||||
|
initCpuProfile(strArg(args, "--cpuprof")) |
||||||
|
initMemoryProfile(strArg(args, "--memprof")) |
||||||
|
OnFree(freeStats) |
||||||
|
} |
||||||
|
|
||||||
|
func initCpuProfile(cpuprof string) { |
||||||
|
if cpuprof != "" { |
||||||
|
cpuProfile, err := os.Create(cpuprof) |
||||||
|
check(err) |
||||||
|
check(pprof.StartCPUProfile(cpuProfile)) |
||||||
|
OnStop(pprof.StopCPUProfile) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func initMemoryProfile(memprof string) { |
||||||
|
if memprof != "" { |
||||||
|
memProfile, err := os.Create(memprof) |
||||||
|
check(err) |
||||||
|
check(pprof.WriteHeapProfile(memProfile)) |
||||||
|
OnStop(func() { check(memProfile.Close()) }) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func freeStats() { |
||||||
|
newStats := reflect.New(reflect.TypeOf(stats)) |
||||||
|
reflect.ValueOf(&stats).Elem().Set(newStats.Elem()) |
||||||
|
} |
@ -0,0 +1,63 @@ |
|||||||
|
package app |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"errors" |
||||||
|
"io" |
||||||
|
"net" |
||||||
|
) |
||||||
|
|
||||||
|
// 通过 UDP 协议实现外部通信
|
||||||
|
var pc net.PacketConn |
||||||
|
|
||||||
|
func initUdp(args []string) { |
||||||
|
// 未开启 udp
|
||||||
|
if !boolArg(args, "--udp") { |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
// 这一步属于初始化,切不可放到 OnStart 里面
|
||||||
|
sizeLimit := intArg(args, "--payload", 1024<<2) |
||||||
|
addr := strArg(args, "--addr", ":1324") |
||||||
|
|
||||||
|
// 当程序启动时开启 udp 服务器
|
||||||
|
OnStart(func() { |
||||||
|
listenUDP(addr, sizeLimit) |
||||||
|
}) |
||||||
|
|
||||||
|
OnStop(freeUdp) |
||||||
|
} |
||||||
|
|
||||||
|
func freeUdp() { |
||||||
|
if pc != nil { |
||||||
|
err := pc.Close() |
||||||
|
pc = nil |
||||||
|
check(err) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func listenUDP(address string, maxSize int) { |
||||||
|
var err error |
||||||
|
pc, err = net.ListenPacket("udp", address) |
||||||
|
check(err) |
||||||
|
|
||||||
|
async(func(ctx context.Context) { |
||||||
|
for { |
||||||
|
select { |
||||||
|
case <-ctx.Done(): |
||||||
|
return |
||||||
|
default: |
||||||
|
} |
||||||
|
|
||||||
|
buf := make([]byte, maxSize) |
||||||
|
n, addr, err := pc.ReadFrom(buf) |
||||||
|
if err != nil { |
||||||
|
if errors.Is(err, io.EOF) { |
||||||
|
return |
||||||
|
} |
||||||
|
continue |
||||||
|
} |
||||||
|
recv(pc, addr, buf[:n]) |
||||||
|
} |
||||||
|
}) |
||||||
|
} |
@ -0,0 +1,101 @@ |
|||||||
|
package app |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"errors" |
||||||
|
"os" |
||||||
|
"slices" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
func assert(cond bool, msg string) { |
||||||
|
if !cond { |
||||||
|
panic(msg) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func check(e error) { |
||||||
|
if e != nil { |
||||||
|
panic(e) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func exitOnError(err error) { |
||||||
|
if err != nil { |
||||||
|
os.Stderr.WriteString(err.Error()) |
||||||
|
os.Exit(1) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func async(cb func(ctx context.Context)) { |
||||||
|
wg.Add(1) |
||||||
|
go func() { |
||||||
|
defer wg.Done() |
||||||
|
select { |
||||||
|
case <-ctx.Done(): |
||||||
|
return |
||||||
|
default: |
||||||
|
cb(ctx) |
||||||
|
} |
||||||
|
}() |
||||||
|
} |
||||||
|
|
||||||
|
func getArg[T any](args []string, name string, f func(string, bool) T) T { |
||||||
|
i := slices.Index(args, name) |
||||||
|
if i == -1 { |
||||||
|
return f("", false) |
||||||
|
} |
||||||
|
if i == len(args)-1 { |
||||||
|
panic(errors.New("invalid " + name + " option")) |
||||||
|
} |
||||||
|
s := args[i+1] |
||||||
|
if strings.HasPrefix(s, "--") { |
||||||
|
return f("", true) |
||||||
|
} |
||||||
|
return f(s, true) |
||||||
|
} |
||||||
|
|
||||||
|
func strArg(args []string, name string, def ...string) string { |
||||||
|
return getArg(args, name, func(s string, b bool) string { |
||||||
|
if b { |
||||||
|
return s |
||||||
|
} |
||||||
|
for _, v := range def { |
||||||
|
return v |
||||||
|
} |
||||||
|
return "" |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func intArg(args []string, name string, def ...int) int { |
||||||
|
return getArg(args, name, func(s string, b bool) int { |
||||||
|
if b { |
||||||
|
size, err := strconv.Atoi(s) |
||||||
|
check(err) |
||||||
|
assert(size > 0, "invalid "+name+" option") |
||||||
|
return size |
||||||
|
} |
||||||
|
for _, v := range def { |
||||||
|
return v |
||||||
|
} |
||||||
|
return 0 |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
func boolArg(args []string, name string, def ...bool) bool { |
||||||
|
return getArg(args, name, func(s string, b bool) bool { |
||||||
|
if b { |
||||||
|
if s == "" { |
||||||
|
return true |
||||||
|
} |
||||||
|
v, err := strconv.ParseBool(s) |
||||||
|
check(err) |
||||||
|
return v |
||||||
|
} |
||||||
|
for _, v := range def { |
||||||
|
return v |
||||||
|
} |
||||||
|
return false |
||||||
|
}) |
||||||
|
} |
@ -0,0 +1,53 @@ |
|||||||
|
package crud |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"github.com/labstack/echo/v4" |
||||||
|
"sync" |
||||||
|
) |
||||||
|
|
||||||
|
type Context struct { |
||||||
|
context.Context |
||||||
|
store map[any]any |
||||||
|
router *echo.Group |
||||||
|
mu sync.RWMutex |
||||||
|
} |
||||||
|
|
||||||
|
func NewContext(ctx context.Context, router *echo.Group) *Context { |
||||||
|
return &Context{ |
||||||
|
Context: ctx, |
||||||
|
store: make(map[any]any), |
||||||
|
router: router, |
||||||
|
mu: sync.RWMutex{}, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Routes 注册路由
|
||||||
|
func (c *Context) Routes(routes Routable) { |
||||||
|
c.mu.Lock() |
||||||
|
defer c.mu.Unlock() |
||||||
|
routes.InitRoutes(c.router) |
||||||
|
} |
||||||
|
|
||||||
|
// Set 设置值
|
||||||
|
func (c *Context) Set(key, val any) { |
||||||
|
c.mu.Lock() |
||||||
|
defer c.mu.Unlock() |
||||||
|
c.store[key] = val |
||||||
|
} |
||||||
|
|
||||||
|
// Get 获取值,只会获取通过 Set 方法设置的值
|
||||||
|
func (c *Context) Get(key any) (any, bool) { |
||||||
|
c.mu.RLock() |
||||||
|
defer c.mu.RUnlock() |
||||||
|
val, ok := c.store[key] |
||||||
|
return val, ok |
||||||
|
} |
||||||
|
|
||||||
|
// Value 获取值
|
||||||
|
func (c *Context) Value(key any) any { |
||||||
|
if val, ok := c.Get(key); ok { |
||||||
|
return val |
||||||
|
} |
||||||
|
return c.Value(key) |
||||||
|
} |
@ -0,0 +1,20 @@ |
|||||||
|
package crud |
||||||
|
|
||||||
|
import "github.com/labstack/echo/v4" |
||||||
|
|
||||||
|
type Service interface { |
||||||
|
// Name 服务名称
|
||||||
|
Name() string |
||||||
|
// Priority 优先级,用于启动和销毁的执行顺序
|
||||||
|
Priority() int |
||||||
|
// Init 初始化服务
|
||||||
|
Init(ctx *Context) error |
||||||
|
// Bootstrap 启动服务
|
||||||
|
Bootstrap() error |
||||||
|
// Destroy 销毁服务
|
||||||
|
Destroy() error |
||||||
|
} |
||||||
|
|
||||||
|
type Routable interface { |
||||||
|
InitRoutes(r *echo.Group) |
||||||
|
} |
Loading…
Reference in new issue