commit b41fdd7816aa08ea4624ad46241c1c526a566d11 Author: hupeh Date: Sun Jan 1 23:26:48 2023 +0800 :tada: 初始化项目 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4ae3ed3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.idea +tmp +.air.toml +*.db +*.log diff --git a/app/db.go b/app/db.go new file mode 100644 index 0000000..22102db --- /dev/null +++ b/app/db.go @@ -0,0 +1,82 @@ +package app + +import ( + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "math" + "strings" +) + +var DB *gorm.DB + +func ConfigGormDB() { + var err error + DB, err = gorm.Open(sqlite.Open("gorm.db"), &gorm.Config{}) + if err != nil { + panic(err) + } + if err = DB.AutoMigrate(&User{}, &Goods{}, &Price{}); err != nil { + panic(err) + } +} + +type User struct { + gorm.Model + Name string `json:"name"` // 用户名称 + PhoneNumber string `json:"phone_number"` // 用户手机 + Password string `json:"password"` // 登录密码 +} + +// Goods 商品 +type Goods struct { + gorm.Model + Name string `json:"name"` // 商品名称 + Price float32 `json:"price"` // 商品当前价格 + Prices []Price `json:"prices,omitempty"` // 商品价格列表 +} + +// Price 商品价格 +type Price struct { + gorm.Model + GoodsID uint `json:"goods_id"` // 管理商品 + Price float32 `json:"price"` // 商品价格 +} + +func Paginate(r *Request) func(db *gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + page := r.Int("page", 1, func(p int) int { + return int(math.Max(float64(p), 1)) + }) + perPage := r.Int("per_page", 30, func(i int) int { + return int(math.Max(float64(i), 1)) + }) + offset := (page - 1) * perPage + return db.Offset(offset).Limit(perPage) + } +} + +func Search(r *Request, key, query string) func(db *gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + if keyword, ok := r.Get(key); ok { + return db.Where(query, keyword) + } else { + return db + } + } +} + +func TimeRange(r *Request, column string) func(db *gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + var queries []string + var args []any + if val, ok := r.Get("start_time"); ok { + queries = append(queries, column+" >= ?") + args = append(args, val) + } + if val, ok := r.Get("end_time"); ok { + queries = append(queries, column+" <= ?") + args = append(args, val) + } + return db.Where(strings.Join(queries, " AND "), args...) + } +} diff --git a/app/err.go b/app/err.go new file mode 100644 index 0000000..d43d721 --- /dev/null +++ b/app/err.go @@ -0,0 +1,15 @@ +package app + +type Error struct { + Status int `json:"-"` + Code int `json:"code"` + Message string `json:"message"` +} + +func NewError(code int, message string) *Error { + return &Error{Code: code, Message: message} +} + +func (e *Error) Error() string { + return e.Message +} diff --git a/app/log.go b/app/log.go new file mode 100644 index 0000000..4c580be --- /dev/null +++ b/app/log.go @@ -0,0 +1,122 @@ +package app + +import ( + "fmt" + "log" + "os" + "path" + "strings" + "sync" + "time" +) + +const ( + LogWhenSecond = iota + LogWhenMinute + LogWhenHour + LogWhenDay +) + +var ( + lg func(level string, data any) + fd *os.File +) + +func ConfigLogger(file string, when int8) { + // 解决 Windows 电脑路径问题 + file = strings.ReplaceAll(file, "\\", "/") + if err := os.MkdirAll(path.Dir(file), 0777); err != nil { + panic(err) + } + + var interval int64 + var suffix string + switch when { + case LogWhenSecond: + interval = 1 + suffix = "2006-01-02_15-04-05" + case LogWhenMinute: + interval = 60 + suffix = "2006-01-02_15-04" + case LogWhenHour: + interval = 3600 + suffix = "2006-01-02_15" + case LogWhenDay: + interval = 3600 * 24 + suffix = "2006-01-02" + default: + panic(fmt.Errorf("invalid when rotate: %d", when)) + } + + var err error + fd, err = os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + panic(err) + } + + fInfo, err := fd.Stat() + if err != nil { + panic(err) + } + + rolloverAt := fInfo.ModTime().Unix() + interval + locker := sync.RWMutex{} + + resolveRotatedFile := func(suffix string) string { + filenameWithSuffix := path.Base(file) + fileSuffix := path.Ext(filenameWithSuffix) + filename := strings.TrimSuffix(filenameWithSuffix, fileSuffix) + return path.Dir(file) + "/" + filename + "." + suffix + fileSuffix + } + + lg = func(level string, data any) { + now := time.Now() + + if rolloverAt <= now.Unix() { + locker.Lock() + defer locker.Unlock() + + fName := resolveRotatedFile(now.Format(suffix)) + //fName := file + now.Format(suffix) + _ = fd.Close() + e := os.Rename(file, fName) + if e != nil { + log.Println(e) + return + } + fd, err = os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + log.Println(e) + return + } + now = time.Now() + rolloverAt = now.Unix() + interval + } + + t := now.Format("2006-01-02 15:04:05") + _, err := fmt.Fprintf(fd, "%s [%s] %v", t, level, data) + if err != nil { + log.Println(err) + } + } +} + +func LogError(data any) { + lg("ERROR", data) +} + +func LogWarning(data any) { + lg("WARNING", data) +} + +func LogInfo(data any) { + lg("INFO", data) +} + +func LogoDebug(data any) { + lg("DEBUG", data) +} + +func LogFatal(data any) { + lg("FATAL", data) +} \ No newline at end of file diff --git a/app/net.go b/app/net.go new file mode 100644 index 0000000..32860c8 --- /dev/null +++ b/app/net.go @@ -0,0 +1,268 @@ +package app + +import ( + "encoding/json" + "github.com/go-chi/chi/v5" + "net/http" + "strconv" + "sync" +) + +type TapFunc[V any] func(V) V +type Convertor[V any] func(s string) (V, error) +type HandlerFunc func(w *ResponseWriter, r *Request) +type ParamGetter func(key string) (string, bool) + +func GetParam[V any](r *Params, key string, def V, convertor Convertor[V], taps []TapFunc[V]) V { + var v V + if str, ok := r.Get(key); ok { + if x, err := convertor(str); err != nil { + v = def + } else { + v = x + } + } else { + v = def + } + for _, tap := range taps { + v = tap(v) + } + return v +} + +type Params struct { + pg ParamGetter +} + +func NewParams(pg ParamGetter) *Params { + return &Params{pg } +} + +func NewPathParams(r *Request) *Params { + return NewParams(func(key string) (string, bool) { + ctx := chi.RouteContext(r.Context()) + for k := len(ctx.URLParams.Keys) - 1; k >= 0; k-- { + if ctx.URLParams.Keys[k] == key { + return ctx.URLParams.Values[k], true + } + } + return "", false + }) +} + +func NewQueryParams(r *Request) *Params { + return NewParams(func(key string) (string, bool) { + if values, ok := r.URL.Query()[key]; ok && len(values) > 0 { + return values[0], true + } else { + return "", false + } + }) +} + +func (u *Params) Get(key string) (string, bool) { + return u.pg(key) +} + +func (u *Params) Value(key string, taps ...TapFunc[string]) string { + val, _ := u.Get(key) + for _, tap := range taps { + val = tap(val) + } + return val +} + +func (u *Params) Int(key string, def int, taps ...TapFunc[int]) int { + return GetParam[int](u, key, def, func(s string) (int, error) { + n, e := strconv.ParseInt(s, 10, 64) + return int(n), e + }, taps) +} + +func (u *Params) Int32(key string, def int32, taps ...TapFunc[int32]) int32 { + return GetParam[int32](u, key, def, func(s string) (int32, error) { + n, e := strconv.ParseInt(s, 10, 64) + return int32(n), e + }, taps) +} + +func (u *Params) Int64(key string, def int64, taps ...TapFunc[int64]) int64 { + return GetParam[int64](u, key, def, func(s string) (int64, error) { + return strconv.ParseInt(s, 10, 64) + }, taps) +} + +func (u *Params) Uint(key string, def uint, taps ...TapFunc[uint]) uint { + return GetParam[uint](u, key, def, func(s string) (uint, error) { + n, e := strconv.ParseUint(s, 10, 64) + return uint(n), e + }, taps) +} + +func (u *Params) Uint32(key string, def uint32, taps ...TapFunc[uint32]) uint32 { + return GetParam[uint32](u, key, def, func(s string) (uint32, error) { + n,e := strconv.ParseUint(s, 10, 64) + return uint32(n), e + }, taps) +} + +func (u *Params) Uint64(key string, def uint64, taps ...TapFunc[uint64]) uint64 { + return GetParam[uint64](u, key, def, func(s string) (uint64, error) { + return strconv.ParseUint(s, 10, 64) + }, taps) +} + +func (u *Params) Float32(key string, def float32, taps ...TapFunc[float32]) float32 { + return GetParam[float32](u, key, def, func(s string) (float32, error) { + f, e := strconv.ParseFloat(s, 32) + return float32(f), e + }, taps) +} + +func (u *Params) Float64(key string, def float64, taps ...TapFunc[float64]) float64 { + return GetParam[float64](u, key, def, func(s string) (float64, error) { + return strconv.ParseFloat(s, 32) + }, taps) +} + + +type Request struct { + *http.Request + *Params + pathParams *Params + queryParams *Params +} + +func NewRequest(r *http.Request) *Request { + return &Request{ + Request: r, + Params: NewParams(func(key string) (string, bool) { + _ = r.ParseForm() + if values, ok := r.Form[key]; ok && len(values) > 0 { + return values[0], true + } + ctx := chi.RouteContext(r.Context()) + for k := len(ctx.URLParams.Keys) - 1; k >= 0; k-- { + if ctx.URLParams.Keys[k] == key { + return ctx.URLParams.Values[k], true + } + } + return "", false + }), + } +} + +func (r *Request) get(key string) (string, bool) { + _ = r.ParseForm() + values, ok := r.PostForm[key] + if !ok { + values, ok = r.Form[key] + } + if !ok { + values, ok = r.URL.Query()[key] + } + if ok { + return values[0], true + } + return "", false +} + +func (r *Request) PathParams() *Params { + if r.pathParams == nil { + r.pathParams = NewPathParams(r) + } + return r.pathParams +} + +func (r *Request) QueryParams() *Params { + if r.queryParams == nil { + r.queryParams = NewQueryParams(r) + } + return r.queryParams +} + +type ResponseWriter struct { + http.ResponseWriter + mutex sync.RWMutex + sent bool +} + +func NewResponseWriter(w http.ResponseWriter) *ResponseWriter { + return &ResponseWriter{ + ResponseWriter: w, + mutex: sync.RWMutex{}, + sent: false, + } +} + +func (w *ResponseWriter) IsSent() bool { + w.mutex.RLock() + defer w.mutex.RUnlock() + return w.sent +} + +func (w *ResponseWriter) Send(status int, body any) { + w.mutex.Lock() + defer w.mutex.Unlock() + + if w.sent { + LogWarning("the response writer was sent") + return + } + + buf, err := json.Marshal(body) + if err != nil { + LogError(err) + w.Fail(http.StatusInternalServerError, -1, err.Error()) + return + } + + w.sent = true + w.WriteHeader(status) + w.Header().Set("Content-Type", "application/json; charset=UTF-8") + if _, err := w.Write(buf); err != nil { + LogError(err) + } +} + +func (w *ResponseWriter) Fail(status int, code int, message ...string) { + info := map[string]any{"code": code, "message": ""} + if len(message) > 0 && len(message[0]) > 0 { + info["message"] = message[0] + } else { + info["message"] = http.StatusText(status) + } + w.Send(status, info) +} + +func (w *ResponseWriter) Error(err error) { + if ex, ok := err.(*Error); ok { + status := http.StatusBadRequest + if ex.Status > 0 { + status = ex.Status + } + w.Fail(status, ex.Code, ex.Message) + } else { + LogError(err) + w.Fail(http.StatusInternalServerError, -1, err.Error()) + } +} + +func (w *ResponseWriter) Ok(data any, message ...string) { + info := map[string]any{ + "code": 0, + "message": "ok", + "data": data, + } + if len(message) > 0 { + info["message"] = message[0] + } + w.Send(http.StatusOK, info) +} + +func Handler(hf HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + hf(NewResponseWriter(w), NewRequest(r)) + } +} + diff --git a/app/rts.go b/app/rts.go new file mode 100644 index 0000000..856d679 --- /dev/null +++ b/app/rts.go @@ -0,0 +1,220 @@ +package app + +import ( + "errors" + "github.com/go-chi/chi/v5" + "gorm.io/gorm" + "net/http" + "strconv" + "strings" +) + +func userInfoFromRequest(r *Request) (*User, error) { + var user User + + if name, ok := r.Get("name"); ok && len(name) > 0 { + user.Name = name + } else { + return nil, NewError(1, "缺少用户名称") + } + + if phoneNumber, ok := r.Get("phone_number"); ok && len(phoneNumber) > 0 { + if len(phoneNumber) != 11 { + return nil, NewError(2, "手机号码格式错误") + } + user.PhoneNumber = phoneNumber + } else { + return nil, NewError(2, "缺少手机号码") + } + + if password, ok := r.Get("password"); ok && len(password) > 0 { + if len(password) < 6 { + return nil, NewError(2, "密码太短") + } + user.Password = password + } else { + return nil, NewError(2, "缺少密码") + } + + return &user, nil +} + +// CreateUser 创建用户 +func CreateUser(w *ResponseWriter, r *Request) { + user, err := userInfoFromRequest(r) + if err != nil { + w.Error(err) + return + } + + var count int64 + if err = DB.Model(&User{}).Where("phone_number = ?", user.PhoneNumber).Count(&count).Error; err != nil { + w.Error(err) + return + } + if count > 0 { + w.Error(NewError(2, "手机号码已经被使用了")) + return + } + + if err = DB.Create(&user).Error; err != nil { + w.Error(err) + } else { + w.Ok(user, "创建用户成功") + } +} + +// UpdateUser 更新用户 +func UpdateUser(w *ResponseWriter, r *Request) { + id := chi.URLParam(r.Request, "id") + if len(id) == 0 { + w.Error(NewError(1, "缺少用户ID")) + } + user, err := userInfoFromRequest(r) + if err != nil { + w.Error(err) + return + } + +} + +// DeleteUser 删除用户 +func DeleteUser(w *ResponseWriter, r *Request) { + +} + +// ListUser 用户列表 +func ListUser(w *ResponseWriter, r *Request) { + +} + +// Login 用户登录 +func Login(w *ResponseWriter, r *Request) { + +} + +// CreateGoods 创建商品 +func CreateGoods(w *ResponseWriter, r *Request) { + name := r.Value("name") + price := r.Float32("price", 0) + if len(name) == 0 { + w.Fail(http.StatusBadRequest, 1, "商品名称错误") + return + } + if price <= 0 { + w.Fail(http.StatusBadRequest, 2, "商品价格错误") + return + } + var goods Goods + err := DB.First(&goods, "name = ?", name).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + goods = Goods{ + Name: name, + Price: price, + } + err = DB.Create(&goods).Error + if err != nil { + w.Fail(http.StatusBadRequest, 3, "创建商品失败") + } else { + w.Ok(goods) + } + } else if err != nil { + LogError(err) + w.Fail(http.StatusBadRequest, 4, "商品价格错误") + } else { + w.Fail(http.StatusBadRequest, 5, "商品已经存在") + } +} + +// UpdateGoods 更新产品信息 +func UpdateGoods(w *ResponseWriter, r *Request) { + name := r.Value("name") + price := r.Float32("price", 0) + id := uint(r.Uint64("id", 0)) + if len(name) == 0 { + w.Fail(http.StatusBadRequest, 1, "商品名称错误") + return + } + if price <= 0 { + w.Fail(http.StatusBadRequest, 2, "商品价格错误") + return + } + var goods Goods + err := DB.First(&goods, "id = ?", id).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + w.Fail(http.StatusBadRequest, 2, "商品不存在") + } else if err != nil { + LogError(err) + w.Fail(http.StatusBadRequest, 3, err.Error()) + } else { + // 商品名称不能重复 + err = DB.Where("id <> ?", id).First(&Goods{}, "name = ?", name).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + LogError(err) + w.Fail(http.StatusBadRequest, 4, err.Error()) + return + } + if goods.Name != name || goods.Price != price { + goods.Name = name + goods.Price = price + err = DB.Save(&goods).Error + if err != nil { + LogError(err) + w.Fail(http.StatusBadRequest, 5, err.Error()) + return + } + } + w.Ok(goods, "修改成功") + } +} + +// GetGoodsList 查询商品列表 +func GetGoodsList(w *ResponseWriter, r *Request) { + search := func(db *gorm.DB) *gorm.DB { + return db. + Model(&Goods{}). + Scopes(TimeRange(r, "created_at")). + Scopes(Paginate(r)). + Scopes(Search(r, "name", "name LIKE ?")) + } + var goodsList []Goods + var total int64 + var err error + if err = DB.Scopes(search).Count(&total).Error; err == nil { + err = DB.Scopes(search).Find(&goodsList).Error + } + if err != nil { + w.Fail(http.StatusInternalServerError, 1, err.Error()) + } else { + w.Ok(map[string]any{ + "list": goodsList, + "total": total, + }) + } +} + +func GetGoodsPrices(w *ResponseWriter, r *Request) { + id := uint(r.Uint64("id", 0)) + var goods Goods + if err := DB.Scopes(func(db *gorm.DB) *gorm.DB { + var queries []string + var args []any + if val, ok := r.Get("start_time"); ok { + queries = append(queries, "created_at >= ?") + args = append(args, val) + } + if val, ok := r.Get("end_time"); ok { + queries = append(queries, "created_at <= ?") + args = append(args, val) + } + if len(queries) == 0 { + return db.Preload("Prices") + } + args = append([]any{strings.Join(queries, " AND ")}, args...) + return db.Preload("Prices", args...) + }).First(&goods, "id = ?", id).Error; err != nil { + w.Fail(http.StatusInternalServerError, 1, err.Error()) + } else { + w.Ok(goods) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..225355f --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module hupeh.vip/pricing + +go 1.19 + +require ( + github.com/go-chi/chi/v5 v5.0.8 + gorm.io/driver/sqlite v1.4.3 + gorm.io/gorm v1.24.2 +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-sqlite3 v1.14.15 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a693732 --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= +github.com/go-chi/chi/v5 v5.0.8/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= +gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= +gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= +gorm.io/gorm v1.24.2 h1:9wR6CFD+G8nOusLdvkZelOEhpJVwwHzpQOUM+REd6U0= +gorm.io/gorm v1.24.2/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= diff --git a/main.go b/main.go new file mode 100644 index 0000000..8c23778 --- /dev/null +++ b/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "hupeh.vip/pricing/app" + "log" + "net/http" +) +func main() { + app.ConfigLogger("debug.log", app.LogWhenMinute) + app.ConfigGormDB() + + r := chi.NewRouter() + r.Use(middleware.Logger) + r.Get("/goods", app.Handler(app.GetGoodsList)) + r.Post("/goods", app.Handler(app.CreateGoods)) + r.Get("/goods/:id/prices", app.Handler(app.GetGoodsPrices)) + r.Post("/goods/:id", app.Handler(app.UpdateGoods)) + + if err := http.ListenAndServe(":3000", r); err != nil { + log.Fatalln(err) + } +}