Compare commits
2 Commits
17eefa152f
...
faded4e634
Author | SHA1 | Date |
---|---|---|
熊二 | faded4e634 | 1 year ago |
熊二 | e8edf6c801 | 1 year ago |
@ -0,0 +1 @@ |
||||
* text=auto eol=lf |
@ -1,74 +0,0 @@ |
||||
package middleware |
||||
|
||||
import ( |
||||
"github.com/labstack/echo/v4" |
||||
"github.com/labstack/echo/v4/middleware" |
||||
) |
||||
|
||||
// KeyAuthValidator defines a function to validate KeyAuth credentials.
|
||||
type KeyAuthValidator = middleware.KeyAuthValidator |
||||
|
||||
// KeyAuthErrorHandler defines a function which is executed for an invalid key.
|
||||
type KeyAuthErrorHandler = middleware.KeyAuthErrorHandler |
||||
|
||||
// KeyAuthConfig defines the config for KeyAuth middleware.
|
||||
type KeyAuthConfig struct { |
||||
Skipper Skipper |
||||
|
||||
// KeyLookup is a string in the form of "<source>:<name>" or "<source>:<name>,<source>:<name>" that is used
|
||||
// to extract key from the request.
|
||||
// Optional. Default value "header:Authorization".
|
||||
// Possible values:
|
||||
// - "header:<name>" or "header:<name>:<cut-prefix>"
|
||||
// `<cut-prefix>` is argument value to cut/trim prefix of the extracted value. This is useful if header
|
||||
// value has static prefix like `Authorization: <auth-scheme> <authorisation-parameters>` where part that we
|
||||
// want to cut is `<auth-scheme> ` note the space at the end.
|
||||
// In case of basic authentication `Authorization: Basic <credentials>` prefix we want to remove is `Basic `.
|
||||
// - "query:<name>"
|
||||
// - "form:<name>"
|
||||
// - "cookie:<name>"
|
||||
// Multiple sources example:
|
||||
// - "header:Authorization,header:X-Api-Key"
|
||||
KeyLookup string |
||||
|
||||
// AuthScheme to be used in the Authorization header.
|
||||
// Optional. Default value "Bearer".
|
||||
AuthScheme string |
||||
|
||||
// Validator is a function to validate key.
|
||||
// Required.
|
||||
Validator KeyAuthValidator |
||||
|
||||
// ErrorHandler defines a function which is executed for an invalid key.
|
||||
// It may be used to define a custom error.
|
||||
ErrorHandler KeyAuthErrorHandler |
||||
|
||||
// ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to
|
||||
// ignore the error (by returning `nil`).
|
||||
// This is useful when parts of your site/api allow public access and some authorized routes provide extra functionality.
|
||||
// In that case you can use ErrorHandler to set a default public key auth value in the request context
|
||||
// and continue. Some logic down the remaining execution chain needs to check that (public) key auth value then.
|
||||
ContinueOnIgnoredError bool |
||||
} |
||||
|
||||
// DefaultKeyAuthConfig is the default KeyAuth middleware config.
|
||||
var DefaultKeyAuthConfig = KeyAuthConfig{ |
||||
Skipper: DefaultSkipper, |
||||
KeyLookup: "header:" + echo.HeaderAuthorization, |
||||
AuthScheme: "Bearer", |
||||
} |
||||
|
||||
func (a *KeyAuthConfig) ToMiddleware() echo.MiddlewareFunc { |
||||
return middleware.KeyAuthWithConfig(middleware.KeyAuthConfig{ |
||||
Skipper: a.Skipper, |
||||
KeyLookup: a.KeyLookup, |
||||
AuthScheme: a.AuthScheme, |
||||
Validator: a.Validator, |
||||
ErrorHandler: a.ErrorHandler, |
||||
ContinueOnIgnoredError: a.ContinueOnIgnoredError, |
||||
}) |
||||
} |
||||
|
||||
func KeyAuth() echo.MiddlewareFunc { |
||||
return DefaultKeyAuthConfig.ToMiddleware() |
||||
} |
@ -1,24 +0,0 @@ |
||||
package middleware |
||||
|
||||
import ( |
||||
"github.com/labstack/echo/v4" |
||||
"github.com/labstack/echo/v4/middleware" |
||||
) |
||||
|
||||
// Skipper defines a function to skip middleware. Returning true skips processing
|
||||
// the middleware.
|
||||
type Skipper = middleware.Skipper |
||||
|
||||
// BeforeFunc defines a function which is executed just before the middleware.
|
||||
type BeforeFunc = middleware.BeforeFunc |
||||
|
||||
type ValuesExtractor = middleware.ValuesExtractor |
||||
|
||||
type ToMiddleware interface { |
||||
ToMiddleware() echo.MiddlewareFunc |
||||
} |
||||
|
||||
// DefaultSkipper returns false which processes the middleware.
|
||||
func DefaultSkipper(echo.Context) bool { |
||||
return false |
||||
} |
@ -1,268 +0,0 @@ |
||||
package middleware |
||||
|
||||
import ( |
||||
"github.com/labstack/echo/v4/middleware" |
||||
"net/http" |
||||
"sync" |
||||
"time" |
||||
|
||||
"github.com/labstack/echo/v4" |
||||
"golang.org/x/time/rate" |
||||
) |
||||
|
||||
// RateLimiterStore is the interface to be implemented by custom stores.
|
||||
type RateLimiterStore interface { |
||||
// Allow Stores for the rate limiter have to implement the Allow method
|
||||
Allow(identifier string) (bool, error) |
||||
} |
||||
|
||||
type ( |
||||
// RateLimiterConfig defines the configuration for the rate limiter
|
||||
RateLimiterConfig struct { |
||||
Skipper Skipper |
||||
BeforeFunc middleware.BeforeFunc |
||||
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
|
||||
IdentifierExtractor Extractor |
||||
// Store defines a store for the rate limiter
|
||||
Store RateLimiterStore |
||||
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
|
||||
ErrorHandler func(context echo.Context, err error) error |
||||
// DenyHandler provides a handler to be called when RateLimiter denies access
|
||||
DenyHandler func(context echo.Context, identifier string, err error) error |
||||
} |
||||
// Extractor is used to extract data from echo.Context
|
||||
Extractor func(context echo.Context) (string, error) |
||||
) |
||||
|
||||
// errors
|
||||
var ( |
||||
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
|
||||
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") |
||||
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
|
||||
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier") |
||||
) |
||||
|
||||
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
|
||||
var DefaultRateLimiterConfig = RateLimiterConfig{ |
||||
Skipper: middleware.DefaultSkipper, |
||||
IdentifierExtractor: func(ctx echo.Context) (string, error) { |
||||
id := ctx.RealIP() |
||||
return id, nil |
||||
}, |
||||
ErrorHandler: func(context echo.Context, err error) error { |
||||
return &echo.HTTPError{ |
||||
Code: ErrExtractorError.Code, |
||||
Message: ErrExtractorError.Message, |
||||
Internal: err, |
||||
} |
||||
}, |
||||
DenyHandler: func(context echo.Context, identifier string, err error) error { |
||||
return &echo.HTTPError{ |
||||
Code: ErrRateLimitExceeded.Code, |
||||
Message: ErrRateLimitExceeded.Message, |
||||
Internal: err, |
||||
} |
||||
}, |
||||
} |
||||
|
||||
/* |
||||
RateLimiter returns a rate limiting middleware |
||||
|
||||
e := echo.New() |
||||
|
||||
limiterStore := middleware.NewRateLimiterMemoryStore(20) |
||||
|
||||
e.GET("/rate-limited", func(c echo.Context) error { |
||||
return c.String(http.StatusOK, "test") |
||||
}, RateLimiter(limiterStore)) |
||||
*/ |
||||
func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc { |
||||
config := DefaultRateLimiterConfig |
||||
config.Store = store |
||||
return config.ToMiddleware() |
||||
} |
||||
|
||||
/* |
||||
ToMiddleware returns a rate limiting middleware |
||||
|
||||
e := echo.New() |
||||
|
||||
config := middleware.RateLimiterConfig{ |
||||
Skipper: DefaultSkipper, |
||||
Store: middleware.NewRateLimiterMemoryStore( |
||||
middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} |
||||
) |
||||
IdentifierExtractor: func(ctx echo.Context) (string, error) { |
||||
id := ctx.RealIP() |
||||
return id, nil |
||||
}, |
||||
ErrorHandler: func(context echo.Context, err error) error { |
||||
return context.JSON(http.StatusTooManyRequests, nil) |
||||
}, |
||||
DenyHandler: func(context echo.Context, identifier string) error { |
||||
return context.JSON(http.StatusForbidden, nil) |
||||
}, |
||||
} |
||||
|
||||
e.GET("/rate-limited", func(c echo.Context) error { |
||||
return c.String(http.StatusOK, "test") |
||||
}, middleware.RateLimiterWithConfig(config)) |
||||
*/ |
||||
func (config *RateLimiterConfig) ToMiddleware() echo.MiddlewareFunc { |
||||
if config.Skipper == nil { |
||||
config.Skipper = DefaultSkipper |
||||
} |
||||
if config.IdentifierExtractor == nil { |
||||
config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor |
||||
} |
||||
if config.ErrorHandler == nil { |
||||
config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler |
||||
} |
||||
if config.DenyHandler == nil { |
||||
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler |
||||
} |
||||
if config.Store == nil { |
||||
panic("Store configuration must be provided") |
||||
} |
||||
return func(next echo.HandlerFunc) echo.HandlerFunc { |
||||
return func(c echo.Context) error { |
||||
if config.Skipper(c) { |
||||
return next(c) |
||||
} |
||||
if config.BeforeFunc != nil { |
||||
config.BeforeFunc(c) |
||||
} |
||||
|
||||
identifier, err := config.IdentifierExtractor(c) |
||||
if err != nil { |
||||
c.Error(config.ErrorHandler(c, err)) |
||||
return nil |
||||
} |
||||
|
||||
if allow, err := config.Store.Allow(identifier); !allow { |
||||
c.Error(config.DenyHandler(c, identifier, err)) |
||||
return nil |
||||
} |
||||
return next(c) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type ( |
||||
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
|
||||
RateLimiterMemoryStore struct { |
||||
visitors map[string]*Visitor |
||||
mutex sync.Mutex |
||||
rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
|
||||
|
||||
burst int |
||||
expiresIn time.Duration |
||||
lastCleanup time.Time |
||||
|
||||
timeNow func() time.Time |
||||
} |
||||
// Visitor signifies a unique user's limiter details
|
||||
Visitor struct { |
||||
*rate.Limiter |
||||
lastSeen time.Time |
||||
} |
||||
) |
||||
|
||||
/* |
||||
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with |
||||
the provided rate (as req/s). |
||||
for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
|
||||
|
||||
Burst and ExpiresIn will be set to default values. |
||||
|
||||
Note that if the provided rate is a float number and Burst is zero, Burst will be treated as the rounded down value of the rate. |
||||
|
||||
Example (with 20 requests/sec): |
||||
|
||||
limiterStore := middleware.NewRateLimiterMemoryStore(20) |
||||
*/ |
||||
func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { |
||||
return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ |
||||
Rate: rate, |
||||
}) |
||||
} |
||||
|
||||
/* |
||||
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore |
||||
with the provided configuration. Rate must be provided. Burst will be set to the rounded down value of |
||||
the configured rate if not provided or set to 0. |
||||
|
||||
The build-in memory store is usually capable for modest loads. For higher loads other |
||||
store implementations should be considered. |
||||
|
||||
Characteristics: |
||||
* Concurrency above 100 parallel requests may causes measurable lock contention |
||||
* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map |
||||
* A high number of requests from a single IP address may cause lock contention |
||||
|
||||
Example: |
||||
|
||||
limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig( |
||||
middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minute}, |
||||
) |
||||
*/ |
||||
func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) { |
||||
store = &RateLimiterMemoryStore{} |
||||
|
||||
store.rate = config.Rate |
||||
store.burst = config.Burst |
||||
store.expiresIn = config.ExpiresIn |
||||
if config.ExpiresIn == 0 { |
||||
store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn |
||||
} |
||||
if config.Burst == 0 { |
||||
store.burst = int(config.Rate) |
||||
} |
||||
store.visitors = make(map[string]*Visitor) |
||||
store.timeNow = time.Now |
||||
store.lastCleanup = store.timeNow() |
||||
return |
||||
} |
||||
|
||||
// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
|
||||
type RateLimiterMemoryStoreConfig struct { |
||||
Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit.
|
||||
Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached.
|
||||
ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
|
||||
} |
||||
|
||||
// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore
|
||||
var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{ |
||||
ExpiresIn: 3 * time.Minute, |
||||
} |
||||
|
||||
// Allow implements RateLimiterStore.Allow
|
||||
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { |
||||
store.mutex.Lock() |
||||
limiter, exists := store.visitors[identifier] |
||||
if !exists { |
||||
limiter = new(Visitor) |
||||
limiter.Limiter = rate.NewLimiter(store.rate, store.burst) |
||||
store.visitors[identifier] = limiter |
||||
} |
||||
now := store.timeNow() |
||||
limiter.lastSeen = now |
||||
if now.Sub(store.lastCleanup) > store.expiresIn { |
||||
store.cleanupStaleVisitors() |
||||
} |
||||
store.mutex.Unlock() |
||||
return limiter.AllowN(store.timeNow(), 1), nil |
||||
} |
||||
|
||||
/* |
||||
cleanupStaleVisitors helps manage the size of the visitors map by removing stale records |
||||
of users who haven't visited again after the configured expiry time has elapsed |
||||
*/ |
||||
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { |
||||
for id, visitor := range store.visitors { |
||||
if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { |
||||
delete(store.visitors, id) |
||||
} |
||||
} |
||||
store.lastCleanup = store.timeNow() |
||||
} |
@ -0,0 +1,108 @@ |
||||
package runtime |
||||
|
||||
import ( |
||||
"context" |
||||
"sorbet/pkg/log" |
||||
"sync" |
||||
) |
||||
|
||||
var ( |
||||
// 并发组控制器
|
||||
wg sync.WaitGroup |
||||
// 运行时上下文对象
|
||||
ctx context.Context |
||||
// 信号量,控制任务并发数量
|
||||
semaphore chan struct{} |
||||
// 异步任务通道
|
||||
tasks chan func(context.Context) |
||||
// 上下文取消函数
|
||||
cancel context.CancelFunc |
||||
// 程序退出通道
|
||||
exit chan chan error |
||||
) |
||||
|
||||
func tryCancel() { |
||||
cancelFunc := cancel |
||||
if cancelFunc == nil { |
||||
log.Panic("context already cancelled") |
||||
} |
||||
cancel = nil |
||||
cancelFunc() |
||||
// 等待任务执行完成
|
||||
for len(semaphore) > 0 { |
||||
work(nil) |
||||
} |
||||
// 理论上在没有信号量的情况下,是没有任务队列的。
|
||||
for len(tasks) > 0 { |
||||
log.Panic("nonempty task channel") |
||||
} |
||||
// 销毁注册的服务
|
||||
for i := len(servlets) - 1; i >= 0; i-- { |
||||
if err := servlets[i].Stop(); err != nil { |
||||
log.Warn("servlet Stop returned with error: ", err) |
||||
} |
||||
} |
||||
ctx = nil |
||||
} |
||||
|
||||
// Go 异步任务
|
||||
func Go(fn func(context.Context)) { |
||||
if semaphore == nil || ctx == nil { |
||||
log.Panic("not initialized") |
||||
return |
||||
} |
||||
select { |
||||
case semaphore <- struct{}{}: |
||||
// If we are below our limit, spawn a new worker rather
|
||||
// than waiting for one to become available.
|
||||
async(&wg, func(context.Context) { |
||||
work(fn) |
||||
}) |
||||
case tasks <- fn: |
||||
// 放到任务队列中
|
||||
// A worker is available and has accepted the task.
|
||||
} |
||||
} |
||||
|
||||
// Async 异步执行函数
|
||||
func Async(fn func(context.Context)) { |
||||
async(&wg, fn) |
||||
} |
||||
|
||||
// 执行任务
|
||||
func work(f func(context.Context)) { |
||||
defer func() { |
||||
if semaphore != nil { |
||||
// 释放
|
||||
<-semaphore |
||||
} |
||||
}() |
||||
|
||||
// todo 能不能复用 sync.WaitGroup
|
||||
g := &sync.WaitGroup{} |
||||
if f != nil { |
||||
Async(f) |
||||
} |
||||
for task := range tasks { |
||||
Async(task) |
||||
} |
||||
g.Wait() |
||||
} |
||||
|
||||
// 异步执行函数
|
||||
func async(wg *sync.WaitGroup, fn func(context.Context)) { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
if ctx == nil { |
||||
return |
||||
} |
||||
select { |
||||
case <-ctx.Done(): |
||||
// 上下文一旦结束,任务将被忽略
|
||||
log.Error("context done") |
||||
default: |
||||
fn(ctx) |
||||
} |
||||
}() |
||||
} |
@ -0,0 +1,90 @@ |
||||
package runtime |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"os" |
||||
"os/signal" |
||||
"sorbet/pkg/log" |
||||
"syscall" |
||||
) |
||||
|
||||
// Init 初始化
|
||||
func Init() error { |
||||
return InitWith(context.Background()) |
||||
} |
||||
|
||||
func InitWith(root context.Context) error { |
||||
if ctx != nil || cancel != nil { |
||||
return errors.New("not stopped") |
||||
} |
||||
// 初始化上下文相关
|
||||
ctx, cancel = context.WithCancel(root) |
||||
// 初始化退出通道
|
||||
if exit == nil { |
||||
exit = make(chan chan error) |
||||
} |
||||
// 初始化信号量
|
||||
if semaphore == nil { |
||||
semaphore = make(chan struct{}, 256) |
||||
} |
||||
// 初始化任务通道
|
||||
if tasks != nil { |
||||
tasks = make(chan func(context.Context), 256) |
||||
} |
||||
// 初始化网络服务组件
|
||||
for _, servlet := range servlets { |
||||
if err := servlet.Init(ctx); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func Start() error { |
||||
// 创建服务器
|
||||
srv, ln, err := createServer() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// 启动网络服务器
|
||||
go func() { |
||||
err := srv.Serve(ln) |
||||
if err != nil { |
||||
log.Error("encountered an error while serving listener: ", err) |
||||
} |
||||
}() |
||||
log.Info("Listening on %s", ln.Addr().String()) |
||||
|
||||
// 监听停止命令,停止网络服务
|
||||
go func() { |
||||
errChan := <-exit |
||||
tryCancel() |
||||
errChan <- ln.Close() // stop the listener
|
||||
}() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func Run() error { |
||||
if err := Start(); err != nil { |
||||
return err |
||||
} |
||||
// parse address for host, port
|
||||
ch := make(chan os.Signal, 1) |
||||
signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT) |
||||
log.Info("Received signal %s", <-ch) |
||||
return Stop() |
||||
} |
||||
|
||||
// Stop 停止运行
|
||||
func Stop() error { |
||||
if ctx == nil { |
||||
return errors.New("already stopped") |
||||
} |
||||
ch := make(chan error) |
||||
exit <- ch |
||||
wg.Wait() |
||||
return <-ch |
||||
} |
@ -0,0 +1,121 @@ |
||||
package runtime |
||||
|
||||
import ( |
||||
"fmt" |
||||
"github.com/labstack/echo/v4" |
||||
"github.com/labstack/echo/v4/middleware" |
||||
"net" |
||||
"net/http" |
||||
"sorbet/internal/util" |
||||
"sorbet/pkg/env" |
||||
"sorbet/pkg/rsp" |
||||
"time" |
||||
) |
||||
|
||||
var ( |
||||
// maxHeaderBytes is used by the http server to limit the size of request headers.
|
||||
// This may need to be increased if accepting cookies from the public.
|
||||
maxHeaderBytes = 1 << 20 |
||||
// readTimeout is used by the http server to set a maximum duration before
|
||||
// timing out read of the request. The default timeout is 10 seconds.
|
||||
readTimeout = 10 * time.Second |
||||
// writeTimeout is used by the http server to set a maximum duration before
|
||||
// timing out write of the response. The default timeout is 10 seconds.
|
||||
writeTimeout = 10 * time.Second |
||||
// idleTimeout is used by the http server to set a maximum duration for
|
||||
// keep-alive connections.
|
||||
idleTimeout = 120 * time.Second |
||||
) |
||||
|
||||
func newEchoFramework() (*echo.Echo, error) { |
||||
e := echo.New() |
||||
e.HideBanner = true |
||||
e.HidePort = true |
||||
e.HTTPErrorHandler = rsp.HTTPErrorHandler |
||||
e.Debug = !env.IsEnv("prod") |
||||
e.Logger = util.NewEchoLogger() |
||||
e.Use(middleware.Recover()) |
||||
e.Use(middleware.CORS()) |
||||
e.Use(middleware.Logger()) |
||||
for _, servlet := range servlets { |
||||
group := e.Group("") |
||||
for _, routable := range servlet.Routes() { |
||||
routable.InitRoutes(group) |
||||
} |
||||
} |
||||
routes := e.Routes() |
||||
e.GET("/_routes", func(c echo.Context) error { |
||||
return c.JSON(http.StatusOK, routes) |
||||
}) |
||||
return e, nil |
||||
} |
||||
|
||||
func createServer() (*http.Server, net.Listener, error) { |
||||
e, err := newEchoFramework() |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
l, err := createTCPListener() |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
s := &http.Server{ |
||||
Handler: e, |
||||
MaxHeaderBytes: env.Int("SERVER_MAX_HEADER_BYTES", maxHeaderBytes), |
||||
ReadTimeout: env.Duration("SERVER_READ_TIMEOUT", readTimeout), |
||||
WriteTimeout: env.Duration("SERVER_WRITE_TIMEOUT", writeTimeout), |
||||
IdleTimeout: env.Duration("SERVER_IDLE_TIMEOUT", idleTimeout), |
||||
} |
||||
|
||||
return s, l, nil |
||||
} |
||||
|
||||
func createTCPListener() (net.Listener, error) { |
||||
l, err := net.Listen( |
||||
env.String("SERVER_NETWORK", "tcp"), |
||||
fmt.Sprintf("%s:%d", |
||||
env.String("SERVER_ADDRESS", "0.0.0.0"), |
||||
env.Int("SERVER_PORT", 1324), |
||||
), |
||||
) |
||||
if err == nil { |
||||
l = net.Listener(TCPKeepAliveListener{ |
||||
TCPListener: l.(*net.TCPListener), |
||||
}) |
||||
} |
||||
return l, err |
||||
} |
||||
|
||||
// TCPKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections. It's used by ListenAndServe and ListenAndServeTLS so
|
||||
// dead TCP connections (e.g. closing laptop mid-download) eventually
|
||||
// go away.
|
||||
//
|
||||
// This is here because it is not exposed in the stdlib and
|
||||
// we'd prefer to have a hold of the http.Server's net.Listener so we can close it
|
||||
// on shutdown.
|
||||
//
|
||||
// Taken from here: https://golang.org/src/net/http/server.go?s=63121:63175#L2120
|
||||
type TCPKeepAliveListener struct { |
||||
*net.TCPListener |
||||
} |
||||
|
||||
// Accept accepts the next incoming call and returns the new
|
||||
// connection. KeepAlivePeriod is set properly.
|
||||
func (ln TCPKeepAliveListener) Accept() (c net.Conn, err error) { |
||||
tc, err := ln.AcceptTCP() |
||||
if err != nil { |
||||
return |
||||
} |
||||
err = tc.SetKeepAlive(true) |
||||
if err != nil { |
||||
return |
||||
} |
||||
err = tc.SetKeepAlivePeriod(3 * time.Minute) |
||||
if err != nil { |
||||
return |
||||
} |
||||
return tc, nil |
||||
} |
@ -0,0 +1,69 @@ |
||||
package runtime |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"github.com/labstack/echo/v4" |
||||
"slices" |
||||
) |
||||
|
||||
// 注册的网络服务组件
|
||||
var servlets []Servlet |
||||
|
||||
// Servlet 网络服务组件接口
|
||||
type Servlet interface { |
||||
// Name 服务名称
|
||||
Name() string |
||||
// Priority 优先级,用于启动和销毁的执行顺序
|
||||
Priority() int |
||||
// Init 初始化服务
|
||||
Init(ctx context.Context) error |
||||
// Routes 注册路由
|
||||
Routes() []Routable |
||||
// Start 非阻塞方式启动服务
|
||||
Start() error |
||||
// Stop 停止服务
|
||||
Stop() error |
||||
} |
||||
|
||||
func Reset() error { |
||||
if len(servlets) == 0 { |
||||
return nil |
||||
} |
||||
if ctx == nil { |
||||
return errors.New("servlets is running") |
||||
} |
||||
servlets = nil |
||||
return nil |
||||
} |
||||
|
||||
// Use 注册服务
|
||||
func Use(servlets ...Servlet) error { |
||||
if len(servlets) == 0 { |
||||
return nil |
||||
} |
||||
for i := 0; i < len(servlets); i++ { |
||||
if !use(servlets[i]) { |
||||
return errors.New("service already registered") |
||||
} |
||||
} |
||||
slices.SortFunc(servlets, func(a, b Servlet) int { |
||||
return b.Priority() - a.Priority() // 按优先级排序
|
||||
}) |
||||
return nil |
||||
} |
||||
|
||||
func use(servlet Servlet) bool { |
||||
exists := slices.ContainsFunc(servlets, func(s Servlet) bool { |
||||
return s.Name() == servlet.Name() |
||||
}) |
||||
if !exists { |
||||
servlets = append(servlets, servlet) |
||||
return true |
||||
} |
||||
return false |
||||
} |
||||
|
||||
type Routable interface { |
||||
InitRoutes(*echo.Group) |
||||
} |
@ -1,72 +0,0 @@ |
||||
package services |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"github.com/labstack/echo/v4" |
||||
"slices" |
||||
"sorbet/internal/services/company" |
||||
"sorbet/internal/services/config" |
||||
"sorbet/internal/services/feature" |
||||
"sorbet/internal/services/resource" |
||||
"sorbet/internal/services/system" |
||||
"sorbet/pkg/crud" |
||||
) |
||||
|
||||
var services []crud.Service |
||||
|
||||
const ContextEchoKey = "echo_framework" |
||||
|
||||
func Register(service crud.Service) error { |
||||
for _, applet := range services { |
||||
if applet.Name() == service.Name() { |
||||
return errors.New("service already registered") |
||||
} |
||||
} |
||||
services = append(services, service) |
||||
return nil |
||||
} |
||||
|
||||
func Init(ctx context.Context) error { |
||||
services = []crud.Service{ |
||||
&config.Service{}, |
||||
&company.Service{}, |
||||
&resource.Service{}, |
||||
&feature.Service{}, |
||||
&system.Service{}, |
||||
} |
||||
|
||||
// 按优先级排序
|
||||
slices.SortFunc(services, func(a, b crud.Service) int { |
||||
return b.Priority() - a.Priority() |
||||
}) |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func Bootstrap(ctx context.Context) error { |
||||
e := ctx.Value(ContextEchoKey).(*echo.Echo) |
||||
for _, service := range services { |
||||
err := service.Init(crud.NewContext(ctx, e.Group(""))) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
err = service.Bootstrap() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func Destroy() error { |
||||
for i := len(services) - 1; i >= 0; i++ { |
||||
service := services[i] |
||||
err := service.Destroy() |
||||
if err != nil { |
||||
// TODO(hupeh): 是否需要销毁策略,比如可以继续或者中断等行为
|
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
@ -1,128 +0,0 @@ |
||||
package app |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"os" |
||||
"sorbet/pkg/app/appkit" |
||||
) |
||||
|
||||
func Init(args ...string) (err error) { |
||||
if len(args) == 0 { |
||||
args = os.Args[1:] |
||||
} |
||||
|
||||
switch Status() { |
||||
case Initialized: |
||||
return nil |
||||
case Starting: |
||||
return errors.New("starting") |
||||
case Running: |
||||
return errors.New("running") |
||||
case Stopping: |
||||
return errors.New("stopping") |
||||
case Idle, Stopped: |
||||
// continue
|
||||
} |
||||
|
||||
initLifecycle() |
||||
initStats(args) |
||||
initUdp(args) |
||||
initBus() |
||||
initBuiltins() |
||||
|
||||
setStatus(Initialized) |
||||
|
||||
return |
||||
} |
||||
|
||||
func initBuiltins() { |
||||
// 订阅应用启动
|
||||
var sub appkit.Subscriber |
||||
sub = Sub("start", func([]byte) []byte { |
||||
sub.Cancel() |
||||
return nil |
||||
}) |
||||
|
||||
// 订阅应用停止
|
||||
Sub("stop", func(bytes []byte) []byte { |
||||
quit := exit |
||||
if quit != nil { |
||||
exit = nil |
||||
quit() |
||||
} |
||||
return nil |
||||
}) |
||||
} |
||||
|
||||
func Loop() error { |
||||
switch Status() { |
||||
case Starting, Running: |
||||
return nil |
||||
case Stopping: |
||||
return errors.New("stopping") |
||||
case Idle: |
||||
return errors.New("idle, you maybe forgot call app.Init()") |
||||
case Stopped: |
||||
return errors.New("stopped, you maybe forgot call app.Init()") |
||||
case Initialized: |
||||
// nothing
|
||||
} |
||||
|
||||
// 释放内存
|
||||
defer free() |
||||
|
||||
wg.Add(1) |
||||
setStatus(Starting) |
||||
|
||||
// In a goroutine, we wait on for all goroutines to complete (for example
|
||||
// timers). We use this to signal to the main thread to exit.
|
||||
// wg.Add(1) basically translates to uv_ref, if this was Node.
|
||||
// wg.Done() basically translates to uv_unref
|
||||
go func() { |
||||
wg.Wait() |
||||
setStatus(Stopping) |
||||
stop() |
||||
}() |
||||
|
||||
// 启动应用指令
|
||||
Pub("start", nil) |
||||
|
||||
for { |
||||
select { |
||||
case msg := <-pubs: |
||||
err := dispatch(msg) |
||||
exitOnError(err) |
||||
wg.Done() // Corresponds to the wg.Add(1) in Pub().
|
||||
|
||||
case <-ctx.Done(): |
||||
// 只有等待所有协程全部完成后,
|
||||
// 然后才可以退出主循环
|
||||
checkPubsEmpty() |
||||
setStatus(Stopped) |
||||
err := ctx.Err() |
||||
if errors.Is(err, context.Canceled) { |
||||
return nil |
||||
} |
||||
return err |
||||
} |
||||
|
||||
// We don't want to exit until we've received at least one message.
|
||||
// This is so the program doesn't exit after sending the "start"
|
||||
// message.
|
||||
if Status() == Starting { |
||||
wg.Done() |
||||
start() |
||||
setStatus(Running) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func Stop() { |
||||
Pub("stop", nil) |
||||
} |
||||
|
||||
func Go(fn func(context.Context)) { |
||||
assert(fn != nil, "invalid fn") |
||||
async(fn) |
||||
} |
@ -1,5 +0,0 @@ |
||||
package appkit |
||||
|
||||
type Subscriber interface { |
||||
Cancel() |
||||
} |
@ -1,81 +0,0 @@ |
||||
package appkit |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/binary" |
||||
"errors" |
||||
"math" |
||||
) |
||||
|
||||
// 数据协议如下:
|
||||
//
|
||||
// +-----------+--------+--------------+-------+------+
|
||||
// | 0x01,0x02 | length | topic-length | topic | data |
|
||||
// +-----------+--------+--------------+-------+------+
|
||||
// 2 2 2 n n
|
||||
//
|
||||
// * 头部有两个模数 0x01、0x02;
|
||||
// * length 表示消息的总长度;
|
||||
// * 然后分别是 topic、data,包含其长度和数据;
|
||||
// * 长度使用 uint16 类型,所以占用 2 个字节;
|
||||
//
|
||||
// todo 实现数据校验码,防止篡改
|
||||
|
||||
// Encode 数据编码
|
||||
func Encode(topic, data []byte) ([]byte, error) { |
||||
topicLen := len(topic) |
||||
dataLen := len(data) |
||||
size := 2 + topicLen + dataLen // 不包括模数和数据长度
|
||||
if topicLen == 0 { |
||||
return nil, errors.New("topic too small") |
||||
} |
||||
if size+4 > math.MaxUint16 { |
||||
return nil, errors.New("data too big") |
||||
} |
||||
buf := &bytes.Buffer{} |
||||
// 由于 模数、topic 和 data 都是基本类型,所以当
|
||||
// 执行 binary.Write() 时是不会出错的
|
||||
_ = binary.Write(buf, binary.BigEndian, []byte{0x01, 0x02}) |
||||
_ = binary.Write(buf, binary.BigEndian, uint16(size)) |
||||
_ = binary.Write(buf, binary.BigEndian, uint16(topicLen)) |
||||
_ = binary.Write(buf, binary.BigEndian, topic) |
||||
_ = binary.Write(buf, binary.BigEndian, data) // 如果没有数据会不会报错
|
||||
return buf.Bytes(), nil |
||||
} |
||||
|
||||
// Decode 数据解码
|
||||
func Decode(buf []byte) (topic, data []byte, err error) { |
||||
r := bytes.NewReader(buf) |
||||
// 读取模数
|
||||
var v [2]byte |
||||
err = binary.Read(r, binary.BigEndian, &v) |
||||
if err != nil { |
||||
return |
||||
} |
||||
if v[0] != 0x01 || v[1] != 0x02 { |
||||
err = errors.New("protocol error") |
||||
return |
||||
} |
||||
// 读取数据长度
|
||||
var size uint16 |
||||
err = binary.Read(r, binary.BigEndian, &size) |
||||
if err != nil { |
||||
return |
||||
} |
||||
// 读取主题长度
|
||||
var topicLen uint16 |
||||
err = binary.Read(r, binary.BigEndian, &topicLen) |
||||
if err != nil { |
||||
return |
||||
} |
||||
// 读取主题
|
||||
topic = make([]byte, topicLen) |
||||
err = binary.Read(r, binary.BigEndian, &topic) |
||||
if err != nil { |
||||
return |
||||
} |
||||
// 读取数据
|
||||
data = make([]byte, size-2-topicLen) |
||||
err = binary.Read(r, binary.BigEndian, &data) |
||||
return |
||||
} |
@ -1,215 +0,0 @@ |
||||
package app |
||||
|
||||
import ( |
||||
"errors" |
||||
"net" |
||||
"slices" |
||||
"sorbet/pkg/app/appkit" |
||||
"sync" |
||||
"sync/atomic" |
||||
) |
||||
|
||||
var ( |
||||
// 操作锁
|
||||
//
|
||||
// * 添加订阅器
|
||||
lock sync.RWMutex |
||||
|
||||
// 用于控制主体程序并发,涉及以下几个方面:
|
||||
//
|
||||
// * 任务调度
|
||||
// * 定时器
|
||||
// * 发布订阅
|
||||
// * 主程序轮询
|
||||
wg sync.WaitGroup |
||||
|
||||
// 消息发布通道
|
||||
pubs chan Msg |
||||
|
||||
// 消息订阅器
|
||||
subs map[string][]*subscriber |
||||
|
||||
nextSubId int32 |
||||
) |
||||
|
||||
func initBus() { |
||||
if pubs == nil { |
||||
pubs = make(chan Msg, 256) |
||||
} |
||||
if subs == nil { |
||||
subs = make(map[string][]*subscriber) |
||||
} |
||||
OnFree(freeBus) |
||||
} |
||||
|
||||
func freeBus() { |
||||
clear(subs) |
||||
|
||||
select { |
||||
case _, ok := <-pubs: |
||||
if !ok { |
||||
// 通道已经关闭了,需要重新构建
|
||||
pubs = make(chan Msg, 256) |
||||
return |
||||
} |
||||
// 清空 pubs 缓存的数据,复用它
|
||||
l := len(pubs) |
||||
for i := 0; i < l; i++ { |
||||
<-pubs |
||||
} |
||||
default: |
||||
// 通道里面没有值,可以复用
|
||||
} |
||||
} |
||||
|
||||
func Sub(topic string, handle func([]byte) []byte) appkit.Subscriber { |
||||
lock.Lock() |
||||
defer lock.Unlock() |
||||
sub := &subscriber{ |
||||
id: atomic.AddInt32(&nextSubId, 1), |
||||
topic: topic, |
||||
handle: handle, |
||||
} |
||||
if _, ok := subs[topic]; !ok { |
||||
subs[topic] = make([]*subscriber, 0) |
||||
} |
||||
subs[topic] = append(subs[topic], sub) |
||||
return sub |
||||
} |
||||
|
||||
func Pub(topic string, data []byte) { |
||||
pubMsg(Msg{Data: data, topic: topic}) |
||||
} |
||||
|
||||
func pubMsg(msg Msg) { |
||||
wg.Add(1) |
||||
pubs <- msg |
||||
} |
||||
|
||||
type Msg struct { |
||||
Data []byte |
||||
|
||||
topic string |
||||
addr net.Addr |
||||
pc net.PacketConn |
||||
} |
||||
|
||||
// Size 消息长度
|
||||
func (m Msg) Size() int { |
||||
return 6 + len([]byte(m.topic)) + len(m.Data) |
||||
} |
||||
|
||||
type subscriber struct { |
||||
id int32 |
||||
topic string |
||||
handle func([]byte) []byte |
||||
} |
||||
|
||||
func (s *subscriber) Active() bool { |
||||
return atomic.LoadInt32(&s.id) > 0 |
||||
} |
||||
|
||||
func (s *subscriber) Cancel() { |
||||
id := atomic.SwapInt32(&s.id, 0) |
||||
if id == 0 { |
||||
return |
||||
} |
||||
lock.Lock() |
||||
defer lock.Unlock() |
||||
subscribers, ok := subs[s.topic] |
||||
if !ok { |
||||
return |
||||
} |
||||
if len(subscribers) > 0 { |
||||
subscribers = slices.DeleteFunc(subscribers, func(sub *subscriber) bool { |
||||
return sub.id == id |
||||
}) |
||||
} |
||||
if len(subscribers) > 0 { |
||||
subs[s.topic] = subscribers |
||||
} else { |
||||
delete(subs, s.topic) |
||||
} |
||||
} |
||||
|
||||
func (s *subscriber) invoke(data []byte) []byte { |
||||
if s.Active() { |
||||
return s.handle(data) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func dispatch(msg Msg) error { |
||||
lock.RLock() |
||||
var cbs []*subscriber |
||||
if subscribers, ok := subs[msg.topic]; ok { |
||||
for _, s := range subscribers { |
||||
if s.Active() { |
||||
cbs = append(cbs, s) |
||||
} |
||||
} |
||||
} |
||||
lock.RUnlock() |
||||
|
||||
if len(cbs) == 0 { |
||||
return errors.New("no subscribers for topic " + msg.topic) |
||||
} |
||||
|
||||
var response []byte |
||||
for _, cb := range cbs { |
||||
res := cb.invoke(msg.Data) |
||||
if res != nil { |
||||
response = res |
||||
} |
||||
} |
||||
|
||||
if len(response) > 0 && msg.pc != nil { |
||||
stats.v8workerRespond++ |
||||
stats.v8workerBytesSent += len(response) |
||||
|
||||
n, ex := msg.pc.WriteTo(response, msg.addr) |
||||
if n > 0 { |
||||
stats.v8workerSend++ |
||||
stats.v8workerBytesSent += n |
||||
} |
||||
if ex != nil { |
||||
// todo handle the error
|
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func recv(pc net.PacketConn, addr net.Addr, buf []byte) { |
||||
stats.v8workerRecv++ |
||||
stats.v8workerBytesRecv += len(buf) |
||||
|
||||
topic, data, err := appkit.Decode(buf) |
||||
if err != nil { |
||||
// todo handle the error
|
||||
// return errors.New("invalid payload")
|
||||
return |
||||
} |
||||
|
||||
pubMsg(Msg{ |
||||
Data: data, |
||||
topic: string(topic), |
||||
addr: addr, |
||||
pc: pc, |
||||
}) |
||||
} |
||||
|
||||
func checkPubsEmpty() { |
||||
// We've received a done event. As a sanity check, make sure that resChan is
|
||||
// empty.
|
||||
select { |
||||
case _, ok := <-pubs: |
||||
if ok { |
||||
panic("Read a message from pubs after context deadlined.") |
||||
} else { |
||||
panic("pubs closed. Unexpected.") |
||||
} |
||||
default: |
||||
// No value ready, moving on.
|
||||
} |
||||
} |
@ -1,83 +0,0 @@ |
||||
package app |
||||
|
||||
import ( |
||||
"context" |
||||
"sync/atomic" |
||||
) |
||||
|
||||
const ( |
||||
Idle int32 = iota |
||||
Initialized |
||||
Starting |
||||
Running |
||||
Stopping |
||||
Stopped |
||||
) |
||||
|
||||
var ( |
||||
// 当前状态
|
||||
status int32 |
||||
|
||||
// 应用上下文
|
||||
ctx context.Context |
||||
exit context.CancelFunc |
||||
|
||||
start func() // 应用启动钩子
|
||||
stop func() // 应用停止钩子
|
||||
free func() // 内存释放钩子
|
||||
) |
||||
|
||||
func initLifecycle() { |
||||
ctx, exit = context.WithCancel(context.Background()) |
||||
|
||||
start = func() { |
||||
// nothing
|
||||
} |
||||
|
||||
stop = func() { |
||||
if exit != nil { |
||||
exit() |
||||
} |
||||
} |
||||
|
||||
free = func() { |
||||
ctx = nil |
||||
start = nil |
||||
stop = nil |
||||
exit = nil |
||||
} |
||||
} |
||||
|
||||
// setStatus 设置应用状态
|
||||
func setStatus(newStatus int32) { |
||||
atomic.StoreInt32(&status, newStatus) |
||||
} |
||||
|
||||
// Status 返回当前状态
|
||||
func Status() int32 { |
||||
return atomic.LoadInt32(&status) |
||||
} |
||||
|
||||
func OnStart(fn func()) { |
||||
oldStart := start |
||||
start = func() { |
||||
oldStart() |
||||
fn() |
||||
} |
||||
} |
||||
|
||||
func OnStop(fn func()) { |
||||
oldStop := stop |
||||
stop = func() { |
||||
fn() |
||||
oldStop() |
||||
} |
||||
} |
||||
|
||||
func OnFree(fn func()) { |
||||
oldFree := free |
||||
free = func() { |
||||
fn() |
||||
oldFree() |
||||
} |
||||
} |
@ -1,44 +0,0 @@ |
||||
package app |
||||
|
||||
import ( |
||||
"os" |
||||
"reflect" |
||||
"runtime/pprof" |
||||
) |
||||
|
||||
var stats struct { |
||||
v8workerSend int |
||||
v8workerRespond int |
||||
v8workerRecv int |
||||
v8workerBytesSent int |
||||
v8workerBytesRecv int |
||||
} |
||||
|
||||
func initStats(args []string) { |
||||
initCpuProfile(strArg(args, "--cpuprof")) |
||||
initMemoryProfile(strArg(args, "--memprof")) |
||||
OnFree(freeStats) |
||||
} |
||||
|
||||
func initCpuProfile(cpuprof string) { |
||||
if cpuprof != "" { |
||||
cpuProfile, err := os.Create(cpuprof) |
||||
check(err) |
||||
check(pprof.StartCPUProfile(cpuProfile)) |
||||
OnStop(pprof.StopCPUProfile) |
||||
} |
||||
} |
||||
|
||||
func initMemoryProfile(memprof string) { |
||||
if memprof != "" { |
||||
memProfile, err := os.Create(memprof) |
||||
check(err) |
||||
check(pprof.WriteHeapProfile(memProfile)) |
||||
OnStop(func() { check(memProfile.Close()) }) |
||||
} |
||||
} |
||||
|
||||
func freeStats() { |
||||
newStats := reflect.New(reflect.TypeOf(stats)) |
||||
reflect.ValueOf(&stats).Elem().Set(newStats.Elem()) |
||||
} |
@ -1,63 +0,0 @@ |
||||
package app |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"io" |
||||
"net" |
||||
) |
||||
|
||||
// 通过 UDP 协议实现外部通信
|
||||
var pc net.PacketConn |
||||
|
||||
func initUdp(args []string) { |
||||
// 未开启 udp
|
||||
if !boolArg(args, "--udp") { |
||||
return |
||||
} |
||||
|
||||
// 这一步属于初始化,切不可放到 OnStart 里面
|
||||
sizeLimit := intArg(args, "--payload", 1024<<2) |
||||
addr := strArg(args, "--addr", ":1324") |
||||
|
||||
// 当程序启动时开启 udp 服务器
|
||||
OnStart(func() { |
||||
listenUDP(addr, sizeLimit) |
||||
}) |
||||
|
||||
OnStop(freeUdp) |
||||
} |
||||
|
||||
func freeUdp() { |
||||
if pc != nil { |
||||
err := pc.Close() |
||||
pc = nil |
||||
check(err) |
||||
} |
||||
} |
||||
|
||||
func listenUDP(address string, maxSize int) { |
||||
var err error |
||||
pc, err = net.ListenPacket("udp", address) |
||||
check(err) |
||||
|
||||
async(func(ctx context.Context) { |
||||
for { |
||||
select { |
||||
case <-ctx.Done(): |
||||
return |
||||
default: |
||||
} |
||||
|
||||
buf := make([]byte, maxSize) |
||||
n, addr, err := pc.ReadFrom(buf) |
||||
if err != nil { |
||||
if errors.Is(err, io.EOF) { |
||||
return |
||||
} |
||||
continue |
||||
} |
||||
recv(pc, addr, buf[:n]) |
||||
} |
||||
}) |
||||
} |
@ -1,101 +0,0 @@ |
||||
package app |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"os" |
||||
"slices" |
||||
"strconv" |
||||
"strings" |
||||
) |
||||
|
||||
func assert(cond bool, msg string) { |
||||
if !cond { |
||||
panic(msg) |
||||
} |
||||
} |
||||
|
||||
func check(e error) { |
||||
if e != nil { |
||||
panic(e) |
||||
} |
||||
} |
||||
|
||||
func exitOnError(err error) { |
||||
if err != nil { |
||||
os.Stderr.WriteString(err.Error()) |
||||
os.Exit(1) |
||||
} |
||||
} |
||||
|
||||
func async(cb func(ctx context.Context)) { |
||||
wg.Add(1) |
||||
go func() { |
||||
defer wg.Done() |
||||
select { |
||||
case <-ctx.Done(): |
||||
return |
||||
default: |
||||
cb(ctx) |
||||
} |
||||
}() |
||||
} |
||||
|
||||
func getArg[T any](args []string, name string, f func(string, bool) T) T { |
||||
i := slices.Index(args, name) |
||||
if i == -1 { |
||||
return f("", false) |
||||
} |
||||
if i == len(args)-1 { |
||||
panic(errors.New("invalid " + name + " option")) |
||||
} |
||||
s := args[i+1] |
||||
if strings.HasPrefix(s, "--") { |
||||
return f("", true) |
||||
} |
||||
return f(s, true) |
||||
} |
||||
|
||||
func strArg(args []string, name string, def ...string) string { |
||||
return getArg(args, name, func(s string, b bool) string { |
||||
if b { |
||||
return s |
||||
} |
||||
for _, v := range def { |
||||
return v |
||||
} |
||||
return "" |
||||
}) |
||||
} |
||||
|
||||
func intArg(args []string, name string, def ...int) int { |
||||
return getArg(args, name, func(s string, b bool) int { |
||||
if b { |
||||
size, err := strconv.Atoi(s) |
||||
check(err) |
||||
assert(size > 0, "invalid "+name+" option") |
||||
return size |
||||
} |
||||
for _, v := range def { |
||||
return v |
||||
} |
||||
return 0 |
||||
}) |
||||
} |
||||
|
||||
func boolArg(args []string, name string, def ...bool) bool { |
||||
return getArg(args, name, func(s string, b bool) bool { |
||||
if b { |
||||
if s == "" { |
||||
return true |
||||
} |
||||
v, err := strconv.ParseBool(s) |
||||
check(err) |
||||
return v |
||||
} |
||||
for _, v := range def { |
||||
return v |
||||
} |
||||
return false |
||||
}) |
||||
} |
@ -1,53 +0,0 @@ |
||||
package crud |
||||
|
||||
import ( |
||||
"context" |
||||
"github.com/labstack/echo/v4" |
||||
"sync" |
||||
) |
||||
|
||||
type Context struct { |
||||
context.Context |
||||
store map[any]any |
||||
router *echo.Group |
||||
mu sync.RWMutex |
||||
} |
||||
|
||||
func NewContext(ctx context.Context, router *echo.Group) *Context { |
||||
return &Context{ |
||||
Context: ctx, |
||||
store: make(map[any]any), |
||||
router: router, |
||||
mu: sync.RWMutex{}, |
||||
} |
||||
} |
||||
|
||||
// Routes 注册路由
|
||||
func (c *Context) Routes(routes Routable) { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
routes.InitRoutes(c.router) |
||||
} |
||||
|
||||
// Set 设置值
|
||||
func (c *Context) Set(key, val any) { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
c.store[key] = val |
||||
} |
||||
|
||||
// Get 获取值,只会获取通过 Set 方法设置的值
|
||||
func (c *Context) Get(key any) (any, bool) { |
||||
c.mu.RLock() |
||||
defer c.mu.RUnlock() |
||||
val, ok := c.store[key] |
||||
return val, ok |
||||
} |
||||
|
||||
// Value 获取值
|
||||
func (c *Context) Value(key any) any { |
||||
if val, ok := c.Get(key); ok { |
||||
return val |
||||
} |
||||
return c.Value(key) |
||||
} |
@ -1,20 +0,0 @@ |
||||
package crud |
||||
|
||||
import "github.com/labstack/echo/v4" |
||||
|
||||
type Service interface { |
||||
// Name 服务名称
|
||||
Name() string |
||||
// Priority 优先级,用于启动和销毁的执行顺序
|
||||
Priority() int |
||||
// Init 初始化服务
|
||||
Init(ctx *Context) error |
||||
// Bootstrap 启动服务
|
||||
Bootstrap() error |
||||
// Destroy 销毁服务
|
||||
Destroy() error |
||||
} |
||||
|
||||
type Routable interface { |
||||
InitRoutes(r *echo.Group) |
||||
} |
Loading…
Reference in new issue