feat(db): 通过传递上下文区分不同请求的日志

main
熊二 1 year ago
parent 815fbd568f
commit 59ae7e72f0
  1. 2
      internal/middleware/logger.go
  2. 4
      main.go
  3. 57
      pkg/db/log.go
  4. 39
      pkg/db/logger.go
  5. 1
      pkg/db/repo.go
  6. 52
      pkg/db/repository.go
  7. 15
      pkg/ioc/ioc.go

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/rs/xid" "github.com/rs/xid"
@ -34,6 +35,7 @@ func Logger(next echo.HandlerFunc) echo.HandlerFunc {
req.Method, req.RequestURI, c.RealIP(), req.Method, req.RequestURI, c.RealIP(),
log.RawTime(start), log.RawTime(start),
) )
c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "logger", l)))
c.SetLogger(util.NewCustomLogger(l)) c.SetLogger(util.NewCustomLogger(l))
if err = next(c); err != nil { if err = next(c); err != nil {
c.Error(err) c.Error(err)

@ -71,8 +71,8 @@ func main() {
//db := ioc.MustGet[gorm.DB]().WithContext(c.Request().Context()) //db := ioc.MustGet[gorm.DB]().WithContext(c.Request().Context())
//ioc.Fork().Bind(db) //ioc.Fork().Bind(db)
//repo := ioc.MustGet[repositories.CompanyRepository]() //repo := ioc.MustGet[repositories.CompanyRepository]()
repo.Create(&entities.Company{Name: "海苔一诺"}) repo.Create(c.Request().Context(), &entities.Company{Name: "海苔一诺"})
pager, err := repo.Paginate() pager, err := repo.Paginate(c.Request().Context())
if err != nil { if err != nil {
return err return err
} }

@ -1,57 +0,0 @@
package db
//
//import (
// "context"
// "gorm.io/gorm/dbLogger"
// "io"
// "log"
// "os"
// "time"
//)
//
//type dbLogger struct {
// console dbLogger.Interface
// persist dbLogger.Interface
//}
//
//func NewLogger(persistWriter io.Writer) dbLogger.Interface {
// return &dbLogger{
// console: dbLogger.New(log.New(os.Stdout, "", log.Ltime|log.Lmicroseconds), dbLogger.Config{
// SlowThreshold: 200 * time.Millisecond,
// Colorful: true,
// LogLevel: dbLogger.Info,
// }),
// persist: dbLogger.New(log.New(persistWriter, "\r\n", log.LstdFlags), dbLogger.Config{
// SlowThreshold: 200 * time.Millisecond,
// LogLevel: dbLogger.Info,
// }),
// }
//}
//
//func (l *dbLogger) LogMode(level dbLogger.LogLevel) dbLogger.Interface {
// c := *l
// c.console = c.console.LogMode(level)
// c.persist = c.persist.LogMode(level)
// return &c
//}
//
//func (l *dbLogger) Info(ctx context.Context, s string, i ...interface{}) {
// l.console.Info(ctx, s, i...)
// l.persist.Info(ctx, s, i...)
//}
//
//func (l *dbLogger) Warn(ctx context.Context, s string, i ...interface{}) {
// l.console.Warn(ctx, s, i...)
// l.persist.Warn(ctx, s, i...)
//}
//
//func (l *dbLogger) Error(ctx context.Context, s string, i ...interface{}) {
// l.console.Error(ctx, s, i...)
// l.persist.Error(ctx, s, i...)
//}
//
//func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
// l.console.Trace(ctx, begin, fc, err)
// l.persist.Trace(ctx, begin, fc, err)
//}

@ -9,6 +9,17 @@ import (
"time" "time"
) )
func getLogger(ctx context.Context) log.Logger {
if ctx == nil {
return log.Default()
}
l, ok := ctx.Value("logger").(log.Logger)
if ok && l != nil {
return l
}
return log.Default()
}
type dbLogger struct { type dbLogger struct {
SlowThreshold time.Duration SlowThreshold time.Duration
} }
@ -19,49 +30,49 @@ func (l *dbLogger) LogMode(level glog.LogLevel) glog.Interface {
} }
// Info print info // Info print info
func (l dbLogger) Info(ctx context.Context, msg string, data ...any) { func (l *dbLogger) Info(ctx context.Context, msg string, data ...any) {
log.Info(msg, data...) getLogger(ctx).Info(msg, data...)
} }
// Warn print warn messages // Warn print warn messages
func (l dbLogger) Warn(ctx context.Context, msg string, data ...any) { func (l *dbLogger) Warn(ctx context.Context, msg string, data ...any) {
log.Warn(msg, data...) getLogger(ctx).Warn(msg, data...)
} }
// Error print error messages // Error print error messages
func (l dbLogger) Error(ctx context.Context, msg string, data ...any) { func (l *dbLogger) Error(ctx context.Context, msg string, data ...any) {
log.Error(msg, data...) getLogger(ctx).Error(msg, data...)
} }
// Trace print sql message // Trace print sql message
func (l dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
elapsed := time.Since(begin) elapsed := time.Since(begin)
switch { switch {
case err != nil && !errors.Is(err, glog.ErrRecordNotFound): case err != nil && !errors.Is(err, glog.ErrRecordNotFound):
sql, rows := fc() sql, rows := fc()
if rows == -1 { if rows == -1 {
log.Error("%s [rows:%v] %s [%.3fms]", err, "-", sql, float64(elapsed.Nanoseconds())/1e6) l.Error(ctx, "%s [rows:%v] %s [%.3fms]", err, "-", sql, float64(elapsed.Nanoseconds())/1e6)
} else { } else {
log.Error("%s [rows:%v] %s [%.3fms]", err, rows, sql, float64(elapsed.Nanoseconds())/1e6) l.Error(ctx, "%s [rows:%v] %s [%.3fms]", err, rows, sql, float64(elapsed.Nanoseconds())/1e6)
} }
case elapsed > l.SlowThreshold && l.SlowThreshold != 0: case elapsed > l.SlowThreshold && l.SlowThreshold != 0:
sql, rows := fc() sql, rows := fc()
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
if rows == -1 { if rows == -1 {
log.Warn("%s [rows:%v] %s [%.3fms]", slowLog, "-", sql, float64(elapsed.Nanoseconds())/1e6) l.Warn(ctx, "%s [rows:%v] %s [%.3fms]", slowLog, "-", sql, float64(elapsed.Nanoseconds())/1e6)
} else { } else {
log.Warn("%s [rows:%v] %s [%.3fms]", slowLog, rows, sql, float64(elapsed.Nanoseconds())/1e6) l.Warn(ctx, "%s [rows:%v] %s [%.3fms]", slowLog, rows, sql, float64(elapsed.Nanoseconds())/1e6)
} }
default: default:
sql, rows := fc() sql, rows := fc()
if rows == -1 { if rows == -1 {
log.Trace("[rows:%v] %s [%.3fms]", "-", sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) l.Info(ctx, "[rows:%v] %s [%.3fms]", "-", sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM"))
} else { } else {
log.Trace("[rows:%v] %s [%.3fms]", rows, sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) l.Info(ctx, "[rows:%v] %s [%.3fms]", rows, sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM"))
} }
} }
} }
func (l dbLogger) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) { func (l *dbLogger) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) {
return sql, params return sql, params
} }

@ -1 +0,0 @@
package db

@ -1,6 +1,7 @@
package db package db
import ( import (
"context"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -27,45 +28,52 @@ func NewRepositoryWith[T any](db *gorm.DB, pk ...string) *Repository[T] {
return r return r
} }
func (r *Repository[T]) DB(ctx context.Context) *gorm.DB {
if ctx != nil {
return r.db.WithContext(ctx)
}
return r.db
}
// Create 创建数据 // Create 创建数据
func (r *Repository[T]) Create(entity *T) error { func (r *Repository[T]) Create(ctx context.Context, entity *T) error {
return r.db.Model(&entity).Create(&entity).Error return r.DB(ctx).Model(&entity).Create(&entity).Error
} }
func (r *Repository[T]) Delete(expr *Expr) (int64, error) { func (r *Repository[T]) Delete(ctx context.Context, expr *Expr) (int64, error) {
var entity T var entity T
res := r.db.Model(&entity).Scopes(expr.Scopes).Delete(&entity) res := r.DB(ctx).Model(&entity).Scopes(expr.Scopes).Delete(&entity)
return res.RowsAffected, res.Error return res.RowsAffected, res.Error
} }
func (r *Repository[T]) DeleteByID(id any) error { func (r *Repository[T]) DeleteByID(ctx context.Context, id any) error {
var entity T var entity T
return r.db.Delete(&entity, r.pk, id).Error return r.DB(ctx).Delete(&entity, r.pk, id).Error
} }
func (r *Repository[T]) Update(expr *Expr, values map[string]any) (int64, error) { func (r *Repository[T]) Update(ctx context.Context, expr *Expr, values map[string]any) (int64, error) {
res := r.db.Scopes(expr.Scopes).Updates(values) res := r.DB(ctx).Scopes(expr.Scopes).Updates(values)
return res.RowsAffected, res.Error return res.RowsAffected, res.Error
} }
func (r *Repository[T]) UpdateByID(id any, values map[string]any) error { func (r *Repository[T]) UpdateByID(ctx context.Context, id any, values map[string]any) error {
var entity T var entity T
return r.db.Model(&entity).Where(r.pk, id).Updates(values).Error return r.DB(ctx).Model(&entity).Where(r.pk, id).Updates(values).Error
} }
func (r *Repository[T]) GetByID(id any) (*T, error) { func (r *Repository[T]) GetByID(ctx context.Context, id any) (*T, error) {
var entity T var entity T
err := r.db.Model(&entity).Where(r.pk, id).First(&entity).Error err := r.DB(ctx).Model(&entity).Where(r.pk, id).First(&entity).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &entity, nil return &entity, nil
} }
func (r *Repository[T]) Find(expr ...*Expr) ([]*T, error) { func (r *Repository[T]) Find(ctx context.Context, expr ...*Expr) ([]*T, error) {
var entity T var entity T
var items []*T var items []*T
err := r.db.Model(&entity).Scopes(func(tx *gorm.DB) *gorm.DB { err := r.DB(ctx).Model(&entity).Scopes(func(tx *gorm.DB) *gorm.DB {
for _, e := range expr { for _, e := range expr {
tx = e.Scopes(tx) tx = e.Scopes(tx)
} }
@ -77,22 +85,22 @@ func (r *Repository[T]) Find(expr ...*Expr) ([]*T, error) {
return items, nil return items, nil
} }
func (r *Repository[T]) Paginate(expr ...*Expr) (*Pager[T], error) { func (r *Repository[T]) Paginate(ctx context.Context, expr ...*Expr) (*Pager[T], error) {
qb := NewQueryBuilder[T](r.db) qb := NewQueryBuilder[T](r.DB(ctx))
for _, e := range expr { for _, e := range expr {
qb.Expr = *e qb.Expr = *e
} }
return qb.Paginate() return qb.Paginate()
} }
func (r *Repository[T]) NewDeleteBuilder() *DeleteBuilder[T] { func (r *Repository[T]) NewDeleteBuilder(ctx context.Context) *DeleteBuilder[T] {
return NewDeleteBuilder[T](r.db) return NewDeleteBuilder[T](r.DB(ctx))
} }
func (r *Repository[T]) NewUpdateBuilder() *UpdateBuilder[T] { func (r *Repository[T]) NewUpdateBuilder(ctx context.Context) *UpdateBuilder[T] {
return NewUpdateBuilder[T](r.db) return NewUpdateBuilder[T](r.DB(ctx))
} }
func (r *Repository[T]) NewQueryBuilder() *QueryBuilder[T] { func (r *Repository[T]) NewQueryBuilder(ctx context.Context) *QueryBuilder[T] {
return NewQueryBuilder[T](r.db) return NewQueryBuilder[T](r.DB(ctx))
} }

@ -1,7 +1,6 @@
package ioc package ioc
import ( import (
"context"
"errors" "errors"
"reflect" "reflect"
) )
@ -61,16 +60,16 @@ func Resolve(i any) error {
} }
// Get 获取指定类型的值 // Get 获取指定类型的值
func Get[T any](ctx context.Context) (*T, error) { func Get[T any]() (*T, error) {
return NamedGet[T](ctx, "") return NamedGet[T]("")
} }
func MustGet[T any](ctx context.Context) *T { func MustGet[T any]() *T {
return MustNamedGet[T](ctx, "") return MustNamedGet[T]("")
} }
// NamedGet 通过注入的名称获取指定类型的值 // NamedGet 通过注入的名称获取指定类型的值
func NamedGet[T any](ctx context.Context, name string) (*T, error) { func NamedGet[T any](name string) (*T, error) {
var abs T var abs T
t := reflect.TypeOf(&abs) t := reflect.TypeOf(&abs)
v := global.NamedGet(name, t) v := global.NamedGet(name, t)
@ -83,8 +82,8 @@ func NamedGet[T any](ctx context.Context, name string) (*T, error) {
return nil, ErrValueNotFound return nil, ErrValueNotFound
} }
func MustNamedGet[T any](ctx context.Context, name string) *T { func MustNamedGet[T any](name string) *T {
v, err := NamedGet[T](ctx, name) v, err := NamedGet[T](name)
if err != nil { if err != nil {
panic(err) panic(err)
} }

Loading…
Cancel
Save