package rsp import ( "bytes" "encoding/json" "errors" "fmt" "github.com/labstack/echo/v4" "gorm.io/gorm" "net/http" "runtime" "sorbet/pkg/ticket" "sorbet/pkg/v" ) var ( TextMarshaller func(map[string]any) (string, error) HtmlMarshaller func(map[string]any) (string, error) JsonpCallbacks []string DefaultJsonpCallback string negotiator *Negotiator ) func init() { TextMarshaller = toText HtmlMarshaller = toText JsonpCallbacks = []string{"callback", "cb", "jsonp"} DefaultJsonpCallback = "callback" negotiator = NewNegotiator(10) } func toText(m map[string]any) (string, error) { buf := &bytes.Buffer{} enc := json.NewEncoder(buf) enc.SetEscapeHTML(true) if err := enc.Encode(m); err != nil { return "", err } else { return buf.String(), nil } } type RespondValuer interface { RespondValue() any } type RespondData[T any] struct { data T } func (r *RespondData[T]) RespondValue() any { return r.data } type response struct { status int headers map[string]string cookies []*http.Cookie err error message string data any } type Option func(o *response) func StatusCode(status int) Option { return func(o *response) { o.status = status } } func Header(key, value string) Option { return func(o *response) { if o.headers == nil { o.headers = make(map[string]string) } o.headers[key] = value } } func Cookie(cookie *http.Cookie) Option { return func(o *response) { if o.cookies != nil { for i, h := range o.cookies { if h.Name == cookie.Name { o.cookies[i] = cookie return } } } o.cookies = append(o.cookies, cookie) } } func Message(msg string) Option { return func(o *response) { o.message = msg } } func Data(data any) Option { return func(o *response) { o.data = data } } func (r *response) result(c echo.Context) (m map[string]any, status int) { status = r.status m = map[string]any{ "code": nil, "success": false, "message": r.message, } isDebug := c.Echo().Debug var err error if ee, ok := r.err.(*v.Errors); ok { pb := &Problem{} for _, e := range ee.All() { pb.AddSubproblem(ErrBadParams.WithText(e.Error()).AsProblem(e.Field())) } m["code"] = ErrBadParams.code m["message"] = ErrBadParams.text m["problems"] = pb.Problems if status < 400 { status = ErrBadParams.status } } else if ex, yes := r.err.(*Error); yes { err = ex.internal m["code"] = ex.code m["success"] = errors.Is(ex, ErrOK) m["message"] = ex.text if data := ex.Data(); data != nil { m["data"] = data } if status < 400 { status = ex.status } } else if pb, okay := r.err.(*Problem); okay { m["code"] = pb.Code m["message"] = pb.Message m["problems"] = pb.Problems if status < 400 { status = ErrInternal.status } } else if r.err == nil { m["code"] = 0 m["success"] = true if r.data != nil { if val, ok := r.data.(RespondValuer); ok { m["data"] = val.RespondValue() } else { m["data"] = r.data } } } else { err = r.err switch { case errors.Is(r.err, ticket.ErrNoTicketFound), errors.Is(r.err, ticket.ErrInvalidTicket), errors.Is(r.err, ticket.ErrBadAudience), errors.Is(r.err, ticket.ErrTicketExpired), errors.Is(r.err, ticket.ErrUnauthorized): m["code"] = ErrUnauthorized.Code() if isDebug { m["message"] = r.err.Error() } else { m["message"] = ErrUnauthorized.Code() } if status < 400 { status = ErrUnauthorized.status } default: var he *echo.HTTPError if errors.As(err, &he) { status = he.Code err = he.Internal } m["code"] = ErrInternal.Code() m["message"] = r.err.Error() if status < 400 { status = ErrInternal.status } } } if c.Echo().Debug && m["success"] != true { if err != nil { m["error"] = err.Error() } m["stack"] = relevantCaller() } if m["message"] == "" { m["message"] = http.StatusText(status) } return } func respond(c echo.Context, o *response) error { defer func() { if o.err != nil { c.Logger().Error(o.err) } }() // 如果已经输出过,就忽略 if c.Response().Committed { return nil } // 设置报头 if o.headers != nil { header := c.Response().Header() for key, value := range o.headers { header.Set(key, value) } } // 设置 cookie if o.cookies != nil { for _, cookie := range o.cookies { c.SetCookie(cookie) } } m, status := o.result(c) // HEAD 请求没有结果 r := c.Request() if r.Method == http.MethodHead { return c.NoContent(status) } // 根据报头响应不同的格式 accept := r.Header.Get(echo.HeaderAccept) switch negotiator.Type(accept, "html", "json", "jsonp", "xml", "text", "text/*") { case "html": if html, err := HtmlMarshaller(m); err != nil { return err } else { return c.HTML(status, html) } case "json": return c.JSON(status, m) case "jsonp": qs := c.Request().URL.Query() for _, name := range JsonpCallbacks { if cb := qs.Get(name); cb != "" { return c.JSONP(status, cb, m) } } return c.JSONP(status, DefaultJsonpCallback, m) case "xml": return c.XML(status, m) case "text", "text/*": if text, err := TextMarshaller(m); err != nil { return err } else { return c.String(status, text) } } return c.JSON(status, m) } func relevantCaller() []string { pc := make([]uintptr, 16) n := runtime.Callers(3, pc) frames := runtime.CallersFrames(pc[:n]) var traces []string for { frame, more := frames.Next() traces = append(traces, fmt.Sprintf("%s:%s:%d", frame.File, frame.Func.Name(), frame.Line)) if !more { return traces } } } func Respond(c echo.Context, opts ...Option) error { o := response{status: http.StatusOK} for _, option := range opts { option(&o) } return respond(c, &o) } func Ok(c echo.Context, data any) error { return Respond(c, Data(data)) } func Created(c echo.Context, data any) error { return Respond(c, Data(data), StatusCode(http.StatusCreated)) } // Fail 响应一个错误 func Fail(c echo.Context, err error, opts ...Option) error { o := response{status: http.StatusInternalServerError} for _, option := range opts { option(&o) } if errors.Is(err, gorm.ErrRecordNotFound) { o.err = ErrRecordNotFound } else { o.err = err } return respond(c, &o) } // InternalError 响应一个服务器内部错误 func InternalError(c echo.Context, message ...string) error { return Fail(c, ErrInternal.WithText(message...)) } // ServiceUnavailable 响应一个服务暂不可用的错误 func ServiceUnavailable(c echo.Context, message ...string) error { return Fail(c, ErrServiceUnavailable.WithText(message...), StatusCode(http.StatusServiceUnavailable)) } // Unauthorized 需要一个身份验证凭据异常的错误 func Unauthorized(c echo.Context, message ...string) error { return Fail(c, ErrUnauthorized.WithText(message...), StatusCode(http.StatusUnauthorized)) } // Forbidden 响应一个不具有访问资源所需权限的错误(用户通过了身份验证) func Forbidden(c echo.Context, message ...string) error { return Fail(c, ErrForbidden.WithText(message...), StatusCode(http.StatusForbidden)) } // UnprocessableEntity 响应一个处理客户端上传失败的错误 func UnprocessableEntity(c echo.Context, message ...string) error { return Fail(c, ErrUnprocessableEntity.WithText(message...), StatusCode(http.StatusUnprocessableEntity)) } // BadRequest 响应一个服务器不理解客户端请求的错误 func BadRequest(c echo.Context, message ...string) error { return Fail(c, ErrBadRequest.WithText(message...), StatusCode(http.StatusBadRequest)) } // BadParams 响应一个客户端提交的参数不符合要求的错误 func BadParams(c echo.Context, message ...string) error { return Fail(c, ErrBadParams.WithText(message...), StatusCode(http.StatusBadRequest)) } // RecordNotFound 响应一个数据不存在的错误 func RecordNotFound(c echo.Context, message ...string) error { return Fail(c, ErrRecordNotFound.WithText(message...), StatusCode(http.StatusNotFound)) }