You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
281 lines
5.9 KiB
281 lines
5.9 KiB
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"gorm.io/gorm"
|
|
"math"
|
|
)
|
|
|
|
// QueryBuilder 查询构造器
|
|
// TODO(hupeh):实现 joins 和表别名
|
|
type QueryBuilder[T any] struct {
|
|
db *gorm.DB
|
|
selects []string
|
|
omits []string
|
|
expr *Expr
|
|
orders []string
|
|
limit int
|
|
offset int
|
|
distinct []any
|
|
preloads []preload
|
|
}
|
|
|
|
func NewQueryBuilder[T any](db *gorm.DB) *QueryBuilder[T] {
|
|
return &QueryBuilder[T]{expr: &Expr{}, db: db}
|
|
}
|
|
|
|
// preload records one Preload call so it can be replayed onto a gorm
// statement. Field order matters: QueryBuilder.Preload builds it positionally.
type preload struct {
	query string // association name/path handed to gorm's Preload
	args  []any  // extra arguments forwarded verbatim to gorm's Preload
}
|
|
|
|
// Pager is one page of query results together with paging metadata, as
// produced by QueryBuilder.Paginate.
type Pager[T any] struct {
	// Total is the number of records matching the query across all pages.
	Total int `json:"total" xml:"total"`
	// Page is the 1-based number of the current page.
	Page int `json:"page" xml:"page"`
	// Limit is the page size (maximum number of items per page).
	Limit int `json:"limit" xml:"limit"`
	// Items holds the records on this page.
	Items []*T `json:"items" xml:"items"`
}
|
|
|
|
func (q *QueryBuilder[T]) Select(columns ...string) *QueryBuilder[T] {
|
|
q.selects = append(q.selects, columns...)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Omit(columns ...string) *QueryBuilder[T] {
|
|
q.omits = append(q.omits, columns...)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Eq(col string, val any) *QueryBuilder[T] {
|
|
q.expr.Eq(col, val)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Neq(col string, val any) *QueryBuilder[T] {
|
|
q.expr.Neq(col, val)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Lt(col string, val any) *QueryBuilder[T] {
|
|
q.expr.Lt(col, val)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Lte(col string, val any) *QueryBuilder[T] {
|
|
q.expr.Lte(col, val)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Gt(col string, val any) *QueryBuilder[T] {
|
|
q.expr.Gt(col, val)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Gte(col string, val any) *QueryBuilder[T] {
|
|
q.expr.Gte(col, val)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Between(col string, less, more any) *QueryBuilder[T] {
|
|
q.expr.Between(col, less, more)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) NotBetween(col string, less, more any) *QueryBuilder[T] {
|
|
q.expr.NotBetween(col, less, more)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) IsNull(col string) *QueryBuilder[T] {
|
|
q.expr.IsNull(col)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) NotNull(col string) *QueryBuilder[T] {
|
|
q.expr.NotNull(col)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Like(col, tpl string) *QueryBuilder[T] {
|
|
q.expr.Like(col, tpl)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) NotLike(col, tpl string) *QueryBuilder[T] {
|
|
q.expr.NotLike(col, tpl)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) In(col string, values ...any) *QueryBuilder[T] {
|
|
q.expr.In(col, values...)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) NotIn(col string, values ...any) *QueryBuilder[T] {
|
|
q.expr.NotIn(col, values...)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) When(condition bool, then func(ex *Expr), elses ...func(ex *Expr)) *QueryBuilder[T] {
|
|
q.expr.When(condition, then, elses...)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Or(or func(ex *Expr)) *QueryBuilder[T] {
|
|
q.expr.Or(or)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) And(and func(ex *Expr)) *QueryBuilder[T] {
|
|
q.expr.And(and)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Not(not func(ex *Expr)) *QueryBuilder[T] {
|
|
q.expr.Not(not)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) DescentBy(columns ...string) *QueryBuilder[T] {
|
|
for _, col := range columns {
|
|
q.orders = append(q.orders, col+" DESC")
|
|
}
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) AscentBy(columns ...string) *QueryBuilder[T] {
|
|
for _, col := range columns {
|
|
q.orders = append(q.orders, col)
|
|
}
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Limit(limit int) *QueryBuilder[T] {
|
|
q.limit = limit
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Offset(offset int) *QueryBuilder[T] {
|
|
q.offset = offset
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Distinct(columns ...any) *QueryBuilder[T] {
|
|
q.distinct = append(q.distinct, columns...)
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Preload(query string, args ...any) *QueryBuilder[T] {
|
|
q.preloads = append(q.preloads, preload{query, args})
|
|
return q
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Scopes(tx *gorm.DB) *gorm.DB {
|
|
tx = q.scopesWithoutEffect(tx)
|
|
if q.orders != nil {
|
|
for _, order := range q.orders {
|
|
tx = tx.Order(order)
|
|
}
|
|
}
|
|
if q.limit > 0 {
|
|
tx = tx.Limit(q.limit)
|
|
}
|
|
if q.offset > 0 {
|
|
tx = tx.Offset(q.offset)
|
|
}
|
|
if q.preloads != nil {
|
|
for _, pl := range q.preloads {
|
|
tx = tx.Preload(pl.query, pl.args...)
|
|
}
|
|
}
|
|
return tx
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) scopesWithoutEffect(tx *gorm.DB) *gorm.DB {
|
|
var entity T
|
|
tx = tx.Model(&entity)
|
|
if q.selects != nil {
|
|
tx = tx.Select(q.selects)
|
|
}
|
|
if q.omits != nil {
|
|
tx = tx.Omit(q.omits...)
|
|
}
|
|
if len(q.distinct) > 0 {
|
|
tx = tx.Distinct(q.distinct...)
|
|
}
|
|
return q.expr.Scopes(tx)
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Count() (int64, error) {
|
|
var count int64
|
|
err := q.db.Scopes(q.scopesWithoutEffect).Count(&count).Error
|
|
return count, err
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) First(entity *T) error {
|
|
return q.db.Scopes(q.Scopes).First(entity).Error
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Take(entity *T) error {
|
|
return q.db.Scopes(q.Scopes).Take(entity).Error
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Last(entity *T) error {
|
|
return q.db.Scopes(q.Scopes).Last(entity).Error
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Find(entities *[]*T) error {
|
|
return q.db.Scopes(q.Scopes).Find(entities).Error
|
|
}
|
|
|
|
func (q *QueryBuilder[T]) Paginate() (*Pager[T], error) {
|
|
if q.limit <= 0 {
|
|
q.limit = 30
|
|
}
|
|
if q.offset < 0 {
|
|
q.offset = 0
|
|
}
|
|
count, err := q.Count()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var items []*T
|
|
err = q.Find(&items)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Pager[T]{
|
|
Total: int(count),
|
|
Page: int(math.Ceil(float64(q.offset)/float64(q.limit))) + 1,
|
|
Limit: q.limit,
|
|
Items: items,
|
|
}, nil
|
|
}
|
|
|
|
// Rows 返回行数据迭代器
|
|
//
|
|
// 使用示例:
|
|
//
|
|
// rows, err := q.Eq("name", "jack").Rows()
|
|
// if err != nil {
|
|
// panic(err)
|
|
// }
|
|
// defer rows.Close()
|
|
// for rows.Next() {
|
|
// var user User
|
|
// db.ScanRows(rows, &user)
|
|
// // do something
|
|
// }
|
|
func (q *QueryBuilder[T]) Rows() (*sql.Rows, error) {
|
|
return q.db.Scopes(q.Scopes).Rows()
|
|
}
|
|
|
|
// Pluck 获取指定列的值
|
|
//
|
|
// 示例:
|
|
//
|
|
// var names []string
|
|
// q.Pluck("name", &names)
|
|
func (q *QueryBuilder[T]) Pluck(column string, dest any) error {
|
|
return q.db.Scopes(q.Scopes).Pluck(column, dest).Error
|
|
}
|
|
|