package db import ( "database/sql" "gorm.io/gorm" "math" ) // QueryBuilder 查询构造器 // TODO(hupeh):实现 joins 和表别名 type QueryBuilder[T any] struct { Expr db *gorm.DB selects []string omits []string orders []string limit int offset int distinct []any preloads []preload } func NewQueryBuilder[T any](db *gorm.DB) *QueryBuilder[T] { return &QueryBuilder[T]{Expr: Expr{}, db: db} } type preload struct { query string args []any } type Pager[T any] struct { Total int `json:"total" xml:"total"` // 数据总数 Page int `json:"page" xml:"page"` // 当前页码 Limit int `json:"limit" xml:"limit"` // 数据容量 Items []*T `json:"items" xml:"items"` // 数据列表 } func (q *QueryBuilder[T]) Select(columns ...string) *QueryBuilder[T] { q.selects = append(q.selects, columns...) return q } func (q *QueryBuilder[T]) Omit(columns ...string) *QueryBuilder[T] { q.omits = append(q.omits, columns...) return q } func (q *QueryBuilder[T]) DescentBy(columns ...string) *QueryBuilder[T] { for _, col := range columns { q.orders = append(q.orders, col+" DESC") } return q } func (q *QueryBuilder[T]) AscentBy(columns ...string) *QueryBuilder[T] { for _, col := range columns { q.orders = append(q.orders, col) } return q } func (q *QueryBuilder[T]) Limit(limit int) *QueryBuilder[T] { q.limit = limit return q } func (q *QueryBuilder[T]) Offset(offset int) *QueryBuilder[T] { q.offset = offset return q } func (q *QueryBuilder[T]) Distinct(columns ...any) *QueryBuilder[T] { q.distinct = append(q.distinct, columns...) return q } func (q *QueryBuilder[T]) Preload(query string, args ...any) *QueryBuilder[T] { q.preloads = append(q.preloads, preload{query, args}) return q } func (q *QueryBuilder[T]) Scopes(tx *gorm.DB) *gorm.DB { tx = q.scopesWithoutEffect(tx) if q.orders != nil { for _, order := range q.orders { tx = tx.Order(order) } } if q.limit > 0 { tx = tx.Limit(q.limit) } if q.offset > 0 { tx = tx.Offset(q.offset) } if q.preloads != nil { for _, pl := range q.preloads { tx = tx.Preload(pl.query, pl.args...) } } return tx } func (q *QueryBuilder[T]) scopesWithoutEffect(tx *gorm.DB) *gorm.DB { var entity T tx = tx.Model(&entity) if q.selects != nil { tx = tx.Select(q.selects) } if q.omits != nil { tx = tx.Omit(q.omits...) } if len(q.distinct) > 0 { tx = tx.Distinct(q.distinct...) } return q.Expr.Scopes(tx) } func (q *QueryBuilder[T]) Count() (int64, error) { var count int64 err := q.db.Scopes(q.scopesWithoutEffect).Count(&count).Error return count, err } func (q *QueryBuilder[T]) First(entity any) error { return q.db.Scopes(q.Scopes).First(entity).Error } func (q *QueryBuilder[T]) Take(entity any) error { return q.db.Scopes(q.Scopes).Take(entity).Error } func (q *QueryBuilder[T]) Last(entity any) error { return q.db.Scopes(q.Scopes).Last(entity).Error } func (q *QueryBuilder[T]) Find(entity any) error { return q.db.Scopes(q.Scopes).Find(entity).Error } func (q *QueryBuilder[T]) Paginate() (*Pager[T], error) { if q.limit <= 0 { q.limit = 30 } if q.offset < 0 { q.offset = 0 } count, err := q.Count() if err != nil { return nil, err } var items []*T err = q.Find(&items) if err != nil { return nil, err } return &Pager[T]{ Total: int(count), Page: int(math.Ceil(float64(q.offset)/float64(q.limit))) + 1, Limit: q.limit, Items: items, }, nil } // Rows 返回行数据迭代器 // // 使用示例: // // rows, err := q.Eq("name", "jack").Rows() // if err != nil { // panic(err) // } // defer rows.Close() // for rows.Next() { // var user User // db.ScanRows(rows, &user) // // do something // } func (q *QueryBuilder[T]) Rows() (*sql.Rows, error) { return q.db.Scopes(q.Scopes).Rows() } // Pluck 获取指定列的值 // // 示例: // // var names []string // q.Pluck("name", &names) func (q *QueryBuilder[T]) Pluck(column string, dest any) error { return q.db.Scopes(q.Scopes).Pluck(column, dest).Error }