From 59ae7e72f0e4295af763633e78bd1321f7bbd0c1 Mon Sep 17 00:00:00 2001 From: hupeh Date: Fri, 15 Sep 2023 15:44:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(db):=20:sparkles:=20=E9=80=9A=E8=BF=87?= =?UTF-8?q?=E4=BC=A0=E9=80=92=E4=B8=8A=E4=B8=8B=E6=96=87=E5=8C=BA=E5=88=86?= =?UTF-8?q?=E4=B8=8D=E5=90=8C=E8=AF=B7=E6=B1=82=E7=9A=84=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/middleware/logger.go | 2 ++ main.go | 4 +-- pkg/db/log.go | 57 ----------------------------------- pkg/db/logger.go | 39 +++++++++++++++--------- pkg/db/repo.go | 1 - pkg/db/repository.go | 52 ++++++++++++++++++-------------- pkg/ioc/ioc.go | 15 +++++---- 7 files changed, 66 insertions(+), 104 deletions(-) delete mode 100644 pkg/db/log.go delete mode 100644 pkg/db/repo.go diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go index 6e25d64..6355907 100644 --- a/internal/middleware/logger.go +++ b/internal/middleware/logger.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "fmt" "github.com/labstack/echo/v4" "github.com/rs/xid" @@ -34,6 +35,7 @@ func Logger(next echo.HandlerFunc) echo.HandlerFunc { req.Method, req.RequestURI, c.RealIP(), log.RawTime(start), ) + c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "logger", l))) c.SetLogger(util.NewCustomLogger(l)) if err = next(c); err != nil { c.Error(err) diff --git a/main.go b/main.go index 34396bf..52c6deb 100644 --- a/main.go +++ b/main.go @@ -71,8 +71,8 @@ func main() { //db := ioc.MustGet[gorm.DB]().WithContext(c.Request().Context()) //ioc.Fork().Bind(db) //repo := ioc.MustGet[repositories.CompanyRepository]() - repo.Create(&entities.Company{Name: "海苔一诺"}) - pager, err := repo.Paginate() + repo.Create(c.Request().Context(), &entities.Company{Name: "海苔一诺"}) + pager, err := repo.Paginate(c.Request().Context()) if err != nil { return err } diff --git a/pkg/db/log.go b/pkg/db/log.go deleted file mode 100644 index ca9dc18..0000000 --- a/pkg/db/log.go +++ /dev/null @@ -1,57 +0,0 @@ -package db - -// -//import ( -// "context" -// "gorm.io/gorm/dbLogger" -// "io" -// "log" -// "os" -// "time" -//) -// -//type dbLogger struct { -// console dbLogger.Interface -// persist dbLogger.Interface -//} -// -//func NewLogger(persistWriter io.Writer) dbLogger.Interface { -// return &dbLogger{ -// console: dbLogger.New(log.New(os.Stdout, "", log.Ltime|log.Lmicroseconds), dbLogger.Config{ -// SlowThreshold: 200 * time.Millisecond, -// Colorful: true, -// LogLevel: dbLogger.Info, -// }), -// persist: dbLogger.New(log.New(persistWriter, "\r\n", log.LstdFlags), dbLogger.Config{ -// SlowThreshold: 200 * time.Millisecond, -// LogLevel: dbLogger.Info, -// }), -// } -//} -// -//func (l *dbLogger) LogMode(level dbLogger.LogLevel) dbLogger.Interface { -// c := *l -// c.console = c.console.LogMode(level) -// c.persist = c.persist.LogMode(level) -// return &c -//} -// -//func (l *dbLogger) Info(ctx context.Context, s string, i ...interface{}) { -// l.console.Info(ctx, s, i...) -// l.persist.Info(ctx, s, i...) -//} -// -//func (l *dbLogger) Warn(ctx context.Context, s string, i ...interface{}) { -// l.console.Warn(ctx, s, i...) -// l.persist.Warn(ctx, s, i...) -//} -// -//func (l *dbLogger) Error(ctx context.Context, s string, i ...interface{}) { -// l.console.Error(ctx, s, i...) -// l.persist.Error(ctx, s, i...) -//} -// -//func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { -// l.console.Trace(ctx, begin, fc, err) -// l.persist.Trace(ctx, begin, fc, err) -//} diff --git a/pkg/db/logger.go b/pkg/db/logger.go index 6d31b8b..14281d9 100644 --- a/pkg/db/logger.go +++ b/pkg/db/logger.go @@ -9,6 +9,17 @@ import ( "time" ) +func getLogger(ctx context.Context) log.Logger { + if ctx == nil { + return log.Default() + } + l, ok := ctx.Value("logger").(log.Logger) + if ok && l != nil { + return l + } + return log.Default() +} + type dbLogger struct { SlowThreshold time.Duration } @@ -19,49 +30,49 @@ func (l *dbLogger) LogMode(level glog.LogLevel) glog.Interface { } // Info print info -func (l dbLogger) Info(ctx context.Context, msg string, data ...any) { - log.Info(msg, data...) +func (l *dbLogger) Info(ctx context.Context, msg string, data ...any) { + getLogger(ctx).Info(msg, data...) } // Warn print warn messages -func (l dbLogger) Warn(ctx context.Context, msg string, data ...any) { - log.Warn(msg, data...) +func (l *dbLogger) Warn(ctx context.Context, msg string, data ...any) { + getLogger(ctx).Warn(msg, data...) } // Error print error messages -func (l dbLogger) Error(ctx context.Context, msg string, data ...any) { - log.Error(msg, data...) +func (l *dbLogger) Error(ctx context.Context, msg string, data ...any) { + getLogger(ctx).Error(msg, data...) } // Trace print sql message -func (l dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { +func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { elapsed := time.Since(begin) switch { case err != nil && !errors.Is(err, glog.ErrRecordNotFound): sql, rows := fc() if rows == -1 { - log.Error("%s [rows:%v] %s [%.3fms]", err, "-", sql, float64(elapsed.Nanoseconds())/1e6) + l.Error(ctx, "%s [rows:%v] %s [%.3fms]", err, "-", sql, float64(elapsed.Nanoseconds())/1e6) } else { - log.Error("%s [rows:%v] %s [%.3fms]", err, rows, sql, float64(elapsed.Nanoseconds())/1e6) + l.Error(ctx, "%s [rows:%v] %s [%.3fms]", err, rows, sql, float64(elapsed.Nanoseconds())/1e6) } case elapsed > l.SlowThreshold && l.SlowThreshold != 0: sql, rows := fc() slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) if rows == -1 { - log.Warn("%s [rows:%v] %s [%.3fms]", slowLog, "-", sql, float64(elapsed.Nanoseconds())/1e6) + l.Warn(ctx, "%s [rows:%v] %s [%.3fms]", slowLog, "-", sql, float64(elapsed.Nanoseconds())/1e6) } else { - log.Warn("%s [rows:%v] %s [%.3fms]", slowLog, rows, sql, float64(elapsed.Nanoseconds())/1e6) + l.Warn(ctx, "%s [rows:%v] %s [%.3fms]", slowLog, rows, sql, float64(elapsed.Nanoseconds())/1e6) } default: sql, rows := fc() if rows == -1 { - log.Trace("[rows:%v] %s [%.3fms]", "-", sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) + l.Info(ctx, "[rows:%v] %s [%.3fms]", "-", sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) } else { - log.Trace("[rows:%v] %s [%.3fms]", rows, sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) + l.Info(ctx, "[rows:%v] %s [%.3fms]", rows, sql, float64(elapsed.Nanoseconds())/1e6, log.RawLevel("GORM")) } } } -func (l dbLogger) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) { +func (l *dbLogger) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) { return sql, params } diff --git a/pkg/db/repo.go b/pkg/db/repo.go deleted file mode 100644 index 3a49c63..0000000 --- a/pkg/db/repo.go +++ /dev/null @@ -1 +0,0 @@ -package db diff --git a/pkg/db/repository.go b/pkg/db/repository.go index 458dda2..bff39f1 100644 --- a/pkg/db/repository.go +++ b/pkg/db/repository.go @@ -1,6 +1,7 @@ package db import ( + "context" "gorm.io/gorm" ) @@ -27,45 +28,52 @@ func NewRepositoryWith[T any](db *gorm.DB, pk ...string) *Repository[T] { return r } +func (r *Repository[T]) DB(ctx context.Context) *gorm.DB { + if ctx != nil { + return r.db.WithContext(ctx) + } + return r.db +} + // Create 创建数据 -func (r *Repository[T]) Create(entity *T) error { - return r.db.Model(&entity).Create(&entity).Error +func (r *Repository[T]) Create(ctx context.Context, entity *T) error { + return r.DB(ctx).Model(&entity).Create(&entity).Error } -func (r *Repository[T]) Delete(expr *Expr) (int64, error) { +func (r *Repository[T]) Delete(ctx context.Context, expr *Expr) (int64, error) { var entity T - res := r.db.Model(&entity).Scopes(expr.Scopes).Delete(&entity) + res := r.DB(ctx).Model(&entity).Scopes(expr.Scopes).Delete(&entity) return res.RowsAffected, res.Error } -func (r *Repository[T]) DeleteByID(id any) error { +func (r *Repository[T]) DeleteByID(ctx context.Context, id any) error { var entity T - return r.db.Delete(&entity, r.pk, id).Error + return r.DB(ctx).Delete(&entity, r.pk, id).Error } -func (r *Repository[T]) Update(expr *Expr, values map[string]any) (int64, error) { - res := r.db.Scopes(expr.Scopes).Updates(values) +func (r *Repository[T]) Update(ctx context.Context, expr *Expr, values map[string]any) (int64, error) { + res := r.DB(ctx).Scopes(expr.Scopes).Updates(values) return res.RowsAffected, res.Error } -func (r *Repository[T]) UpdateByID(id any, values map[string]any) error { +func (r *Repository[T]) UpdateByID(ctx context.Context, id any, values map[string]any) error { var entity T - return r.db.Model(&entity).Where(r.pk, id).Updates(values).Error + return r.DB(ctx).Model(&entity).Where(r.pk, id).Updates(values).Error } -func (r *Repository[T]) GetByID(id any) (*T, error) { +func (r *Repository[T]) GetByID(ctx context.Context, id any) (*T, error) { var entity T - err := r.db.Model(&entity).Where(r.pk, id).First(&entity).Error + err := r.DB(ctx).Model(&entity).Where(r.pk, id).First(&entity).Error if err != nil { return nil, err } return &entity, nil } -func (r *Repository[T]) Find(expr ...*Expr) ([]*T, error) { +func (r *Repository[T]) Find(ctx context.Context, expr ...*Expr) ([]*T, error) { var entity T var items []*T - err := r.db.Model(&entity).Scopes(func(tx *gorm.DB) *gorm.DB { + err := r.DB(ctx).Model(&entity).Scopes(func(tx *gorm.DB) *gorm.DB { for _, e := range expr { tx = e.Scopes(tx) } @@ -77,22 +85,22 @@ func (r *Repository[T]) Find(expr ...*Expr) ([]*T, error) { return items, nil } -func (r *Repository[T]) Paginate(expr ...*Expr) (*Pager[T], error) { - qb := NewQueryBuilder[T](r.db) +func (r *Repository[T]) Paginate(ctx context.Context, expr ...*Expr) (*Pager[T], error) { + qb := NewQueryBuilder[T](r.DB(ctx)) for _, e := range expr { qb.Expr = *e } return qb.Paginate() } -func (r *Repository[T]) NewDeleteBuilder() *DeleteBuilder[T] { - return NewDeleteBuilder[T](r.db) +func (r *Repository[T]) NewDeleteBuilder(ctx context.Context) *DeleteBuilder[T] { + return NewDeleteBuilder[T](r.DB(ctx)) } -func (r *Repository[T]) NewUpdateBuilder() *UpdateBuilder[T] { - return NewUpdateBuilder[T](r.db) +func (r *Repository[T]) NewUpdateBuilder(ctx context.Context) *UpdateBuilder[T] { + return NewUpdateBuilder[T](r.DB(ctx)) } -func (r *Repository[T]) NewQueryBuilder() *QueryBuilder[T] { - return NewQueryBuilder[T](r.db) +func (r *Repository[T]) NewQueryBuilder(ctx context.Context) *QueryBuilder[T] { + return NewQueryBuilder[T](r.DB(ctx)) } diff --git a/pkg/ioc/ioc.go b/pkg/ioc/ioc.go index 9320821..8b71253 100644 --- a/pkg/ioc/ioc.go +++ b/pkg/ioc/ioc.go @@ -1,7 +1,6 @@ package ioc import ( - "context" "errors" "reflect" ) @@ -61,16 +60,16 @@ func Resolve(i any) error { } // Get 获取指定类型的值 -func Get[T any](ctx context.Context) (*T, error) { - return NamedGet[T](ctx, "") +func Get[T any]() (*T, error) { + return NamedGet[T]("") } -func MustGet[T any](ctx context.Context) *T { - return MustNamedGet[T](ctx, "") +func MustGet[T any]() *T { + return MustNamedGet[T]("") } // NamedGet 通过注入的名称获取指定类型的值 -func NamedGet[T any](ctx context.Context, name string) (*T, error) { +func NamedGet[T any](name string) (*T, error) { var abs T t := reflect.TypeOf(&abs) v := global.NamedGet(name, t) @@ -83,8 +82,8 @@ func NamedGet[T any](ctx context.Context, name string) (*T, error) { return nil, ErrValueNotFound } -func MustNamedGet[T any](ctx context.Context, name string) *T { - v, err := NamedGet[T](ctx, name) +func MustNamedGet[T any](name string) *T { + v, err := NamedGet[T](name) if err != nil { panic(err) }