package db import ( "database/sql" "gorm.io/gorm" "math" ) // QueryBuilder 查询构造器 // TODO(hupeh):实现 joins 和表别名 type QueryBuilder[T any] struct { db *gorm.DB selects []string omits []string expr *Expr orders []string limit int offset int distinct []any preloads []preload } func NewQueryBuilder[T any](db *gorm.DB) *QueryBuilder[T] { return &QueryBuilder[T]{expr: &Expr{}, db: db} } type preload struct { query string args []any } type Pager[T any] struct { Total int `json:"total" xml:"total"` // 数据总数 Page int `json:"page" xml:"page"` // 当前页码 Limit int `json:"limit" xml:"limit"` // 数据容量 Items []*T `json:"items" xml:"items"` // 数据列表 } func (q *QueryBuilder[T]) Select(columns ...string) *QueryBuilder[T] { q.selects = append(q.selects, columns...) return q } func (q *QueryBuilder[T]) Omit(columns ...string) *QueryBuilder[T] { q.omits = append(q.omits, columns...) return q } func (q *QueryBuilder[T]) Eq(col string, val any) *QueryBuilder[T] { q.expr.Eq(col, val) return q } func (q *QueryBuilder[T]) Neq(col string, val any) *QueryBuilder[T] { q.expr.Neq(col, val) return q } func (q *QueryBuilder[T]) Lt(col string, val any) *QueryBuilder[T] { q.expr.Lt(col, val) return q } func (q *QueryBuilder[T]) Lte(col string, val any) *QueryBuilder[T] { q.expr.Lte(col, val) return q } func (q *QueryBuilder[T]) Gt(col string, val any) *QueryBuilder[T] { q.expr.Gt(col, val) return q } func (q *QueryBuilder[T]) Gte(col string, val any) *QueryBuilder[T] { q.expr.Gte(col, val) return q } func (q *QueryBuilder[T]) Between(col string, less, more any) *QueryBuilder[T] { q.expr.Between(col, less, more) return q } func (q *QueryBuilder[T]) NotBetween(col string, less, more any) *QueryBuilder[T] { q.expr.NotBetween(col, less, more) return q } func (q *QueryBuilder[T]) IsNull(col string) *QueryBuilder[T] { q.expr.IsNull(col) return q } func (q *QueryBuilder[T]) NotNull(col string) *QueryBuilder[T] { q.expr.NotNull(col) return q } func (q *QueryBuilder[T]) Like(col, tpl string) *QueryBuilder[T] { q.expr.Like(col, tpl) return q } func (q *QueryBuilder[T]) NotLike(col, tpl string) *QueryBuilder[T] { q.expr.NotLike(col, tpl) return q } func (q *QueryBuilder[T]) In(col string, values ...any) *QueryBuilder[T] { q.expr.In(col, values...) return q } func (q *QueryBuilder[T]) NotIn(col string, values ...any) *QueryBuilder[T] { q.expr.NotIn(col, values...) return q } func (q *QueryBuilder[T]) When(condition bool, then func(ex *Expr), elses ...func(ex *Expr)) *QueryBuilder[T] { q.expr.When(condition, then, elses...) return q } func (q *QueryBuilder[T]) Or(or func(ex *Expr)) *QueryBuilder[T] { q.expr.Or(or) return q } func (q *QueryBuilder[T]) And(and func(ex *Expr)) *QueryBuilder[T] { q.expr.And(and) return q } func (q *QueryBuilder[T]) Not(not func(ex *Expr)) *QueryBuilder[T] { q.expr.Not(not) return q } func (q *QueryBuilder[T]) DescentBy(columns ...string) *QueryBuilder[T] { for _, col := range columns { q.orders = append(q.orders, col+" DESC") } return q } func (q *QueryBuilder[T]) AscentBy(columns ...string) *QueryBuilder[T] { for _, col := range columns { q.orders = append(q.orders, col) } return q } func (q *QueryBuilder[T]) Limit(limit int) *QueryBuilder[T] { q.limit = limit return q } func (q *QueryBuilder[T]) Offset(offset int) *QueryBuilder[T] { q.offset = offset return q } func (q *QueryBuilder[T]) Distinct(columns ...any) *QueryBuilder[T] { q.distinct = append(q.distinct, columns...) return q } func (q *QueryBuilder[T]) Preload(query string, args ...any) *QueryBuilder[T] { q.preloads = append(q.preloads, preload{query, args}) return q } func (q *QueryBuilder[T]) Scopes(tx *gorm.DB) *gorm.DB { tx = q.scopesWithoutEffect(tx) if q.orders != nil { for _, order := range q.orders { tx = tx.Order(order) } } if q.limit > 0 { tx = tx.Limit(q.limit) } if q.offset > 0 { tx = tx.Offset(q.offset) } if q.preloads != nil { for _, pl := range q.preloads { tx = tx.Preload(pl.query, pl.args...) } } return tx } func (q *QueryBuilder[T]) scopesWithoutEffect(tx *gorm.DB) *gorm.DB { var entity T tx = tx.Model(&entity) if q.selects != nil { tx = tx.Select(q.selects) } if q.omits != nil { tx = tx.Omit(q.omits...) } if len(q.distinct) > 0 { tx = tx.Distinct(q.distinct...) } return q.expr.Scopes(tx) } func (q *QueryBuilder[T]) Count() (int64, error) { var count int64 err := q.db.Scopes(q.scopesWithoutEffect).Count(&count).Error return count, err } func (q *QueryBuilder[T]) First(entity *T) error { return q.db.Scopes(q.Scopes).First(entity).Error } func (q *QueryBuilder[T]) Take(entity *T) error { return q.db.Scopes(q.Scopes).Take(entity).Error } func (q *QueryBuilder[T]) Last(entity *T) error { return q.db.Scopes(q.Scopes).Last(entity).Error } func (q *QueryBuilder[T]) Find(entities *[]*T) error { return q.db.Scopes(q.Scopes).Find(entities).Error } func (q *QueryBuilder[T]) Paginate() (*Pager[T], error) { if q.limit <= 0 { q.limit = 30 } if q.offset < 0 { q.offset = 0 } count, err := q.Count() if err != nil { return nil, err } var items []*T err = q.Find(&items) if err != nil { return nil, err } return &Pager[T]{ Total: int(count), Page: int(math.Ceil(float64(q.offset)/float64(q.limit))) + 1, Limit: q.limit, Items: items, }, nil } // Rows 返回行数据迭代器 // // 使用示例: // // rows, err := q.Eq("name", "jack").Rows() // if err != nil { // panic(err) // } // defer rows.Close() // for rows.Next() { // var user User // db.ScanRows(rows, &user) // // do something // } func (q *QueryBuilder[T]) Rows() (*sql.Rows, error) { return q.db.Scopes(q.Scopes).Rows() } // Pluck 获取指定列的值 // // 示例: // // var names []string // q.Pluck("name", &names) func (q *QueryBuilder[T]) Pluck(column string, dest any) error { return q.db.Scopes(q.Scopes).Pluck(column, dest).Error }