From 9b751c066c8e147d313ed674cb3b43893fe2a7f5 Mon Sep 17 00:00:00 2001 From: hupeh Date: Wed, 25 Oct 2023 17:44:19 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=90=88=E5=B9=B6=E8=BF=90=E8=A1=8C?= =?UTF-8?q?=E6=97=B6=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/init.go | 29 +++- internal/middleware/cors.go | 136 ---------------- internal/runtime/server.go | 26 ++- .../logger.go => runtime/server_logger.go} | 2 +- .../recover.go => runtime/server_recover.go} | 6 +- pkg/rs/error.go | 154 ++++++++++++++++++ pkg/rsp/error.go | 27 ++- pkg/rsp/respond_utils.go | 12 +- 8 files changed, 244 insertions(+), 148 deletions(-) delete mode 100644 internal/middleware/cors.go rename internal/{middleware/logger.go => runtime/server_logger.go} (98%) rename internal/{middleware/recover.go => runtime/server_recover.go} (96%) create mode 100644 pkg/rs/error.go diff --git a/internal/init.go b/internal/init.go index 1f824a9..6de1e11 100644 --- a/internal/init.go +++ b/internal/init.go @@ -9,14 +9,14 @@ import ( "sorbet/internal/services/feature" "sorbet/internal/services/resource" "sorbet/internal/services/system" + "sorbet/internal/util" "sorbet/pkg/db" "sorbet/pkg/env" "sorbet/pkg/log" ) func Init() error { - err := syncEntities() - if err != nil { + if err := syncEntities(); err != nil { if !errors.Is(err, db.ErrNoCodeFirst) { return err } @@ -24,6 +24,9 @@ func Init() error { log.Warn("同步数据表结构需要开启 [DB_CODE_FIRST],在生产模式下请务必关闭。") } } + if err := initSystemUser(); err != nil { + return err + } return useServlets() } @@ -33,6 +36,7 @@ func syncEntities() error { &entities.Company{}, &entities.CompanyDepartment{}, &entities.CompanyEmployee{}, + &entities.CompanyAuthTicket{}, &entities.ConfigGroup{}, &entities.Config{}, &entities.Feature{}, @@ -52,6 +56,27 @@ func syncEntities() error { ) } +func initSystemUser() error { + var count int64 + err := db. + Model(&entities.SystemUser{}). + Where("username", "admin"). + Count(&count). + Error + if err != nil { + return err + } + if count > 0 { + return nil + } + hash, _ := util.PasswordHash("111111") + _, err = db.Create(&entities.SystemUser{ + Username: "admin", + Password: hash, + }) + return err +} + func useServlets() error { return runtime.Use( &config.Service{}, diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go deleted file mode 100644 index 351320d..0000000 --- a/internal/middleware/cors.go +++ /dev/null @@ -1,136 +0,0 @@ -package middleware - -import ( - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "net/http" -) - -type CORSConfig struct { - // Skipper defines a function to skip middleware. - Skipper func(c echo.Context) bool - - // AllowOrigins determines the value of the Access-Control-Allow-Origin - // response header. This header defines a list of origins that may access the - // resource. The wildcard characters '*' and '?' are supported and are - // converted to regex fragments '.*' and '.' accordingly. - // - // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile domain names. - // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // Optional. Default value []string{"*"}. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin - AllowOrigins []string - - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. - // - // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile domain names. - // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // Optional. - AllowOriginFunc func(origin string) (bool, error) - - // AllowMethods determines the value of the Access-Control-Allow-Methods - // response header. This header specified the list of methods allowed when - // accessing the resource. This is used in response to a preflight request. - // - // Optional. Default value DefaultCORSConfig.AllowMethods. - // If `allowMethods` is left empty, this middleware will fill for preflight - // request `Access-Control-Allow-Methods` header value - // from `Allow` header that echo.Router set into context. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - AllowMethods []string - - // AllowHeaders determines the value of the Access-Control-Allow-Headers - // response header. This header is used in response to a preflight request to - // indicate which HTTP headers can be used when making the actual request. - // - // Optional. Default value []string{}. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - AllowHeaders []string - - // AllowCredentials determines the value of the - // Access-Control-Allow-Credentials response header. This header indicates - // whether the response to the request can be exposed when the - // credentials mode (Request.credentials) is true. When used as part of a - // response to a preflight request, this indicates whether or not the actual - // request can be made using credentials. See also - // [MDN: Access-Control-Allow-Credentials]. - // - // Optional. Default value false, in which case the header is not set. - // - // Security: avoid using `AllowCredentials = true` with `AllowOrigins = *`. - // See "Exploiting CORS misconfigurations for Bitcoins and bounties", - // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - AllowCredentials bool - - // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials - // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. - // - // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) - // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. - // - // Optional. Default value is false. - UnsafeWildcardOriginWithAllowCredentials bool - - // ExposeHeaders determines the value of Access-Control-Expose-Headers, which - // defines a list of headers that clients are allowed to access. - // - // Optional. Default value []string{}, in which case the header is not set. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header - ExposeHeaders []string - - // MaxAge determines the value of the Access-Control-Max-Age response header. - // This header indicates how long (in seconds) the results of a preflight - // request can be cached. - // - // Optional. Default value 0. The header is set only if MaxAge > 0. - // - // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - MaxAge int -} - -// DefaultCORSConfig is the default CORS middleware config. -var DefaultCORSConfig = CORSConfig{ - Skipper: func(c echo.Context) bool { - return false - }, - AllowOrigins: []string{"*"}, - AllowMethods: []string{ - http.MethodGet, - http.MethodHead, - http.MethodPut, - http.MethodPatch, - http.MethodPost, - http.MethodDelete, - }, -} - -func (c *CORSConfig) ToMiddleware() echo.MiddlewareFunc { - return middleware.CORSWithConfig(middleware.CORSConfig{ - Skipper: c.Skipper, - AllowOrigins: c.AllowOrigins, - AllowOriginFunc: c.AllowOriginFunc, - AllowMethods: c.AllowMethods, - AllowHeaders: c.AllowHeaders, - AllowCredentials: c.AllowCredentials, - UnsafeWildcardOriginWithAllowCredentials: c.UnsafeWildcardOriginWithAllowCredentials, - ExposeHeaders: c.ExposeHeaders, - MaxAge: c.MaxAge, - }) -} - -func CORS() echo.MiddlewareFunc { - return DefaultCORSConfig.ToMiddleware() -} diff --git a/internal/runtime/server.go b/internal/runtime/server.go index 7b4be69..e41a767 100644 --- a/internal/runtime/server.go +++ b/internal/runtime/server.go @@ -34,15 +34,35 @@ func newEchoFramework() (*echo.Echo, error) { e.HTTPErrorHandler = rsp.HTTPErrorHandler e.Debug = !env.IsEnv("prod") e.Logger = util.NewEchoLogger() - e.Use(middleware.Recover()) - e.Use(middleware.CORS()) - e.Use(middleware.Logger()) + + // 配置错误捕获 + e.Use(Recover) + // 配置跨域 + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ + AllowOrigins: env.List("CORS_ALLOW_ORIGINS", []string{"*"}), + AllowMethods: env.List("CORS_ALLOW_METHODS", []string{"GET", "HEAD", "PUT", "PATCH", "POST", "DELETE"}), + AllowCredentials: env.Bool("CORS_ALLOW_CREDENTIALS", false), + AllowHeaders: env.List("CORS_ALLOW_HEADERS", []string{"X-TICKET"}), + ExposeHeaders: env.List("CORS_EXPOSE_HEADERS"), + MaxAge: env.Int("CORS_MAX_AGE", 0), + })) + // 配置日志 + e.Use(Logger) + // 配置静态服务 + e.Use(middleware.StaticWithConfig(middleware.StaticConfig{ + Root: env.String("STATIC_ROOT", "web"), + Index: env.String("STATIC_INDEX", "index.html"), + HTML5: env.Bool("STATIC_HTML5", false), + Browse: env.Bool("STATIC_DIRECTORY_BROWSE", false), + })) + for _, servlet := range servlets { group := e.Group("") for _, routable := range servlet.Routes() { routable.InitRoutes(group) } } + routes := e.Routes() e.GET("/_routes", func(c echo.Context) error { return c.JSON(http.StatusOK, routes) diff --git a/internal/middleware/logger.go b/internal/runtime/server_logger.go similarity index 98% rename from internal/middleware/logger.go rename to internal/runtime/server_logger.go index 6355907..55598c7 100644 --- a/internal/middleware/logger.go +++ b/internal/runtime/server_logger.go @@ -1,4 +1,4 @@ -package middleware +package runtime import ( "context" diff --git a/internal/middleware/recover.go b/internal/runtime/server_recover.go similarity index 96% rename from internal/middleware/recover.go rename to internal/runtime/server_recover.go index 929ed71..4a64ea5 100644 --- a/internal/middleware/recover.go +++ b/internal/runtime/server_recover.go @@ -1,4 +1,4 @@ -package middleware +package runtime import ( "fmt" @@ -55,8 +55,8 @@ var DefaultRecoverConfig = RecoverConfig{ // Recover returns a middleware which recovers from panics anywhere in the chain // and handles the control to the centralized HTTPErrorHandler. -func Recover() echo.MiddlewareFunc { - return RecoverWithConfig(DefaultRecoverConfig) +func Recover(next echo.HandlerFunc) echo.HandlerFunc { + return RecoverWithConfig(DefaultRecoverConfig)(next) } func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { diff --git a/pkg/rs/error.go b/pkg/rs/error.go new file mode 100644 index 0000000..e50c269 --- /dev/null +++ b/pkg/rs/error.go @@ -0,0 +1,154 @@ +package rs + +import ( + "fmt" + "net/http" + "strings" +) + +var ( + // ErrOK 表示没有任何错误。 + // 对应 HTTP 响应状态码为 500。 + ErrOK = NewError(http.StatusOK, 0, "OK") + + // ErrInternal 客户端请求有效,但服务器处理时发生了意外。 + // 对应 HTTP 响应状态码为 500。 + ErrInternal = NewError(http.StatusInternalServerError, -100, "系统内部错误") + + // ErrServiceUnavailable 服务器无法处理请求,一般用于网站维护状态。 + // 对应 HTTP 响应状态码为 503。 + ErrServiceUnavailable = NewError(http.StatusServiceUnavailable, -101, "服务不可用") + + // ErrUnauthorized 用户未提供身份验证凭据,或者没有通过身份验证。 + // 响应的 HTTP 状态码为 401。 + ErrUnauthorized = NewError(http.StatusUnauthorized, -102, "身份验证失败") + + // ErrForbidden 用户通过了身份验证,但是不具有访问资源所需的权限。 + // 响应的 HTTP 状态码为 403。 + ErrForbidden = NewError(http.StatusForbidden, -103, "不具有访问资源所需的权限") + + // ErrGone 所请求的资源已从这个地址转移,不再可用。 + // 响应的 HTTP 状态码为 410。 + ErrGone = NewError(http.StatusGone, -104, "所请求的资源不存在") + + // ErrUnsupportedMediaType 客户端要求的返回格式不支持。 + // 比如,API 只能返回 JSON 格式,但是客户端要求返回 XML 格式。 + // 响应的 HTTP 状态码为 415。 + ErrUnsupportedMediaType = NewError(http.StatusUnsupportedMediaType, -105, "请求的数据格式错误") + + // ErrUnprocessableEntity 无法处理客户端上传的附件,导致请求失败。 + // 响应的 HTTP 状态码为 422。 + ErrUnprocessableEntity = NewError(http.StatusUnprocessableEntity, -106, "上传了不被支持的附件") + + // ErrTooManyRequests 客户端的请求次数超过限额。 + // 响应的 HTTP 状态码为 429。 + ErrTooManyRequests = NewError(http.StatusTooManyRequests, -107, "请求次数超过限额") + + // ErrSeeOther 表示需要参考另一个 URL 才能完成接收的请求操作, + // 当请求方式使用 POST、PUT 和 DELETE 时,对应的 HTTP 状态码为 303, + // 其它的请求方式在大多数情况下应该使用 400 状态码。 + ErrSeeOther = NewError(http.StatusSeeOther, -108, "需要更进一步才能完成操作") + + // ErrBadRequest 服务器不理解客户端的请求。 + // 对应 HTTP 状态码为 400。 + ErrBadRequest = NewError(http.StatusBadRequest, -109, "请求错误") + + // ErrBadParams 客户端提交的参数不符合要求 + // 对应 HTTP 状态码为 400。 + ErrBadParams = NewError(http.StatusBadRequest, -110, "参数错误") + + // ErrRecordNotFound 访问的数据不存在 + // 对应 HTTP 状态码为 404。 + ErrRecordNotFound = NewError(http.StatusNotFound, -111, "访问的数据不存在") +) + +type Error struct { + // 被包装的错误对象 + internal error + // 响应的 HTTP 状态码 + status int + // 错误码 + code int + // 错误提示消息 + text string + // 错误携带的响应数据 + data any +} + +func NewError(status, code int, text string) *Error { + return &Error{nil, status, code, text, nil} +} + +// Code 返回错误码 +func (e *Error) Code() int { + return e.code +} + +// Text 返回错误提示 +func (e *Error) Text() string { + return e.text +} + +// Data 返回携带的数据 +func (e *Error) Data() any { + return e.data +} + +// Internal 返回原始错误 +func (e *Error) Internal() error { + return e.internal +} + +// Unwrap 支持 errors.Unwrap() 方法 +func (e *Error) Unwrap() error { + return e.Internal() +} + +// WithInternal 通过实际错误对象派生新的实例 +func (e *Error) WithInternal(err error) *Error { + // 由于错误比较复杂,不好做完全等于, + // 在这里就直接复制当前对象 + c := *e + c.internal = err + return &c +} + +// WithStatus 通过 HTTP 状态码派生新的实例 +func (e *Error) WithStatus(status int) *Error { + if e.status != status { + c := *e + c.status = status + return &c + } + return e +} + +// WithText 通过新的错误提示派生新的实例 +func (e *Error) WithText(text string) *Error { + if text != e.text { + c := *e + c.text = text + return &c + } + return e +} + +// WithData 通过携带数据派生新的实例 +func (e *Error) WithData(data any) *Error { + if e.data != data { + c := *e + c.data = data + return &c + } + return e +} + +// String 实现 fmt.Stringer 接口 +func (e *Error) String() string { + return strings.TrimSpace(fmt.Sprintf("%d %s", e.code, e.text)) +} + +// Error 实现错误接口 +func (e *Error) Error() string { + return e.String() +} diff --git a/pkg/rsp/error.go b/pkg/rsp/error.go index dc4fa59..e30d2af 100644 --- a/pkg/rsp/error.go +++ b/pkg/rsp/error.go @@ -67,10 +67,11 @@ type Error struct { status int // HTTP 状态码 code int // 请求错误码 text string // 响应提示消息 + data any // 错误携带的响应数据 } func NewError(status, code int, text string) *Error { - return &Error{nil, status, code, text} + return &Error{nil, status, code, text, nil} } func (e *Error) Code() int { @@ -81,6 +82,18 @@ func (e *Error) Text() string { return e.text } +func (e *Error) Data() any { + return e.data +} + +func (e *Error) Internal() error { + return e.internal +} + +func (e *Error) Unwrap() error { + return e.Internal() +} + func (e *Error) WithInternal(err error) *Error { c := *e c.internal = err @@ -88,6 +101,9 @@ func (e *Error) WithInternal(err error) *Error { } func (e *Error) WithStatus(status int) *Error { + if e.status == status { + return e + } c := *e c.status = status return &c @@ -104,6 +120,15 @@ func (e *Error) WithText(text ...string) *Error { return e } +func (e *Error) WithData(data any) *Error { + if e.data == data { + return e + } + c := *e + c.data = data + return &c +} + func (e *Error) AsProblem(label string) *Problem { return &Problem{ Label: label, diff --git a/pkg/rsp/respond_utils.go b/pkg/rsp/respond_utils.go index 7bacf96..9e34605 100644 --- a/pkg/rsp/respond_utils.go +++ b/pkg/rsp/respond_utils.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "github.com/labstack/echo/v4" + "gorm.io/gorm" "net/http" "runtime" "sorbet/pkg/ticket" @@ -129,6 +130,9 @@ func (r *response) result(c echo.Context) (m map[string]any, status int) { m["code"] = ex.code m["success"] = errors.Is(ex, ErrOK) m["message"] = ex.text + if data := ex.Data(); data != nil { + m["data"] = data + } if status < 400 { status = ex.status } @@ -253,7 +257,7 @@ func respond(c echo.Context, o *response) error { func relevantCaller() []string { pc := make([]uintptr, 16) - n := runtime.Callers(1, pc) + n := runtime.Callers(3, pc) frames := runtime.CallersFrames(pc[:n]) var traces []string for { @@ -287,7 +291,11 @@ func Fail(c echo.Context, err error, opts ...Option) error { for _, option := range opts { option(&o) } - o.err = err + if errors.Is(err, gorm.ErrRecordNotFound) { + o.err = ErrRecordNotFound + } else { + o.err = err + } return respond(c, &o) }