package db import ( "context" "gorm.io/gorm" ) type Repository[T any] struct { db *gorm.DB pk string // 默认 id } func NewRepository[T any](db ...*gorm.DB) *Repository[T] { for _, d := range db { return NewRepositoryWith[T](d) } return NewRepositoryWith[T](DB()) } func NewRepositoryWith[T any](db *gorm.DB, pk ...string) *Repository[T] { r := &Repository[T]{db: db, pk: "id"} for _, s := range pk { if s != "" { r.pk = s return r } } return r } func (r *Repository[T]) DB(ctx context.Context) *gorm.DB { if ctx != nil { return r.db.WithContext(ctx) } return r.db } // Create 创建数据 func (r *Repository[T]) Create(ctx context.Context, entity *T) error { return r.DB(ctx).Model(&entity).Create(&entity).Error } func (r *Repository[T]) Delete(ctx context.Context, expr *Expr) (int64, error) { var entity T res := r.DB(ctx).Model(&entity).Scopes(expr.Scopes).Delete(&entity) return res.RowsAffected, res.Error } func (r *Repository[T]) DeleteByID(ctx context.Context, id any) error { var entity T return r.DB(ctx).Delete(&entity, r.pk, id).Error } func (r *Repository[T]) Update(ctx context.Context, expr *Expr, values map[string]any) (int64, error) { res := r.DB(ctx).Scopes(expr.Scopes).Updates(values) return res.RowsAffected, res.Error } func (r *Repository[T]) UpdateByID(ctx context.Context, id any, values map[string]any) error { var entity T return r.DB(ctx).Model(&entity).Where(r.pk, id).Updates(values).Error } func (r *Repository[T]) GetByID(ctx context.Context, id any) (*T, error) { var entity T err := r.DB(ctx).Model(&entity).Where(r.pk, id).First(&entity).Error if err != nil { return nil, err } return &entity, nil } func (r *Repository[T]) GetBy(ctx context.Context, expr ...*Expr) (*T, error) { var entity T err := r.DB(ctx).Model(&entity).Scopes(func(tx *gorm.DB) *gorm.DB { for _, e := range expr { tx = e.Scopes(tx) } return tx }).First(&entity).Error if err != nil { return nil, err } return &entity, nil } func (r *Repository[T]) Find(ctx context.Context, expr ...*Expr) ([]*T, error) { var entity T var items []*T err := r.DB(ctx).Model(&entity).Scopes(func(tx *gorm.DB) *gorm.DB { for _, e := range expr { tx = e.Scopes(tx) } return tx }).Find(&items).Error if err != nil { return nil, err } return items, nil } func (r *Repository[T]) Paginate(ctx context.Context, expr ...*Expr) (*Pager[T], error) { qb := NewQueryBuilder[T](r.DB(ctx)) for _, e := range expr { qb.expr = e } return qb.Paginate() } func (r *Repository[T]) NewDeleteBuilder(ctx context.Context) *DeleteBuilder[T] { return NewDeleteBuilder[T](r.DB(ctx)) } func (r *Repository[T]) NewUpdateBuilder(ctx context.Context) *UpdateBuilder[T] { return NewUpdateBuilder[T](r.DB(ctx)) } func (r *Repository[T]) NewQueryBuilder(ctx context.Context) *QueryBuilder[T] { return NewQueryBuilder[T](r.DB(ctx)) }