package runtime import ( "fmt" "net/http" "runtime" "sorbet/internal/util" "sorbet/pkg/log" "github.com/labstack/echo/v4" ) // LogErrorFunc defines a function for custom logging in the middleware. type LogErrorFunc func(c echo.Context, err error, stack []byte) error // RecoverConfig defines the config for Recover middleware. type RecoverConfig struct { // Size of the stack to be printed. // Optional. Default value 4KB. StackSize int // DisableStackAll disables formatting stack traces of all other goroutines // into buffer after the trace for the current goroutine. // Optional. Default value false. DisableStackAll bool // DisablePrintStack disables printing stack trace. // Optional. Default value as false. DisablePrintStack bool // LogLevel is log level to printing stack trace. // Optional. Default value 0 (Print). LogLevel log.Level // LogErrorFunc defines a function for custom logging in the middleware. // If it's set you don't need to provide LogLevel for config. // If this function returns nil, the centralized HTTPErrorHandler will not be called. LogErrorFunc LogErrorFunc // DisableErrorHandler disables the call to centralized HTTPErrorHandler. // The recovered error is then passed back to upstream middleware, instead of swallowing the error. // Optional. Default value false. DisableErrorHandler bool } // DefaultRecoverConfig is the default Recover middleware config. var DefaultRecoverConfig = RecoverConfig{ StackSize: 4 << 10, // 4 KB DisableStackAll: false, DisablePrintStack: false, LogLevel: log.LevelDebug, LogErrorFunc: nil, DisableErrorHandler: false, } // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. func Recover(next echo.HandlerFunc) echo.HandlerFunc { return RecoverWithConfig(DefaultRecoverConfig)(next) } func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if config.StackSize == 0 { config.StackSize = DefaultRecoverConfig.StackSize } switch config.LogLevel { case log.LevelTrace, log.LevelFatal, log.LevelPanic: panic("不应该将 LevelTrace、LevelFatal 和 LevelPanic 这三个日志作用在错误恢复中间件上") } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) (returnErr error) { defer func() { if r := recover(); r != nil { if r == http.ErrAbortHandler { panic(r) } err, ok := r.(error) if !ok { err = fmt.Errorf("%v", r) } var stack []byte var length int if !config.DisablePrintStack { stack = make([]byte, config.StackSize) length = runtime.Stack(stack, !config.DisableStackAll) stack = stack[:length] } if config.LogErrorFunc != nil { err = config.LogErrorFunc(c, err, stack) } else if !config.DisablePrintStack { var i []any if _, ok := c.Logger().(*util.EchoLogger); ok { i = append(i, fmt.Sprintf("%v %s\n", err, stack[:length]), log.RawLevel("PANIC RECOVER"), ) } else { i = append(i, fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length])) } switch config.LogLevel { case log.LevelDebug: c.Logger().Debug(i...) case log.LevelInfo: c.Logger().Info(i...) case log.LevelWarn: c.Logger().Warn(i...) case log.LevelError: c.Logger().Error(i...) case log.LevelOff: // None. default: c.Logger().Print(i...) } } if err != nil && !config.DisableErrorHandler { c.Error(err) } else { returnErr = err } } }() return next(c) } } }