go项目脚手架
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
sorbet/pkg/db/query_builder.go

281 lines
5.9 KiB

package db
import (
"database/sql"
"gorm.io/gorm"
"math"
)
// QueryBuilder is a chainable query builder over gorm for the model type T.
//
// Builder methods (Select, Eq, DescentBy, Limit, Preload, ...) record state
// on the receiver and return it so calls can be chained; Scopes and
// scopesWithoutEffect translate that state into gorm calls, and the finisher
// methods (Count, First, Find, Paginate, ...) execute it. Methods mutate the
// receiver without any synchronization.
//
// TODO(hupeh): implement joins and table aliases.
type QueryBuilder[T any] struct {
	db *gorm.DB // the *gorm.DB every finisher starts from

	selects, omits []string // columns recorded by Select / Omit

	expr *Expr // condition expression fed by Eq, Neq, ..., Not

	orders        []string // ORDER BY fragments, applied in call order
	limit, offset int      // applied by Scopes only when > 0

	distinct []any     // arguments forwarded to gorm's Distinct
	preloads []preload // Preload calls, replayed in order by Scopes
}
// NewQueryBuilder returns a QueryBuilder for model T whose queries run on db.
// The condition expression starts out as a zero-value Expr.
func NewQueryBuilder[T any](db *gorm.DB) *QueryBuilder[T] {
	qb := &QueryBuilder[T]{db: db}
	qb.expr = new(Expr)
	return qb
}
// preload records one QueryBuilder.Preload call; Scopes replays it as
// tx.Preload(query, args...).
type preload struct {
	query string // first argument to gorm's Preload
	args  []any  // remaining arguments, forwarded unchanged
}
// Pager holds one page of query results, as produced by QueryBuilder.Paginate.
type Pager[T any] struct {
	// Total is the number of matching rows, ignoring limit and offset.
	Total int `json:"total" xml:"total"`
	// Page is the 1-based number of the current page.
	Page int `json:"page" xml:"page"`
	// Limit is the page size.
	Limit int `json:"limit" xml:"limit"`
	// Items are the rows on this page.
	Items []*T `json:"items" xml:"items"`
}
// Select adds columns to the SELECT list; repeated calls accumulate.
// The columns are passed to gorm's Select by scopesWithoutEffect.
func (q *QueryBuilder[T]) Select(columns ...string) *QueryBuilder[T] {
	q.selects = append(q.selects, columns...)
	return q
}

// Omit adds columns to exclude from the query; repeated calls accumulate.
// The columns are passed to gorm's Omit by scopesWithoutEffect.
func (q *QueryBuilder[T]) Omit(columns ...string) *QueryBuilder[T] {
	q.omits = append(q.omits, columns...)
	return q
}
// onExpr applies fn to the builder's condition expression and returns q for
// chaining. Each condition method below is a thin wrapper around it; the SQL
// each one produces is defined by the corresponding Expr method.
func (q *QueryBuilder[T]) onExpr(fn func(ex *Expr)) *QueryBuilder[T] {
	fn(q.expr)
	return q
}

// Eq applies Expr.Eq(col, val) to the builder's conditions.
func (q *QueryBuilder[T]) Eq(col string, val any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Eq(col, val) })
}

// Neq applies Expr.Neq(col, val) to the builder's conditions.
func (q *QueryBuilder[T]) Neq(col string, val any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Neq(col, val) })
}

// Lt applies Expr.Lt(col, val) to the builder's conditions.
func (q *QueryBuilder[T]) Lt(col string, val any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Lt(col, val) })
}

// Lte applies Expr.Lte(col, val) to the builder's conditions.
func (q *QueryBuilder[T]) Lte(col string, val any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Lte(col, val) })
}

// Gt applies Expr.Gt(col, val) to the builder's conditions.
func (q *QueryBuilder[T]) Gt(col string, val any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Gt(col, val) })
}

// Gte applies Expr.Gte(col, val) to the builder's conditions.
func (q *QueryBuilder[T]) Gte(col string, val any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Gte(col, val) })
}

// Between applies Expr.Between(col, less, more) to the builder's conditions.
func (q *QueryBuilder[T]) Between(col string, less, more any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Between(col, less, more) })
}

// NotBetween applies Expr.NotBetween(col, less, more) to the builder's conditions.
func (q *QueryBuilder[T]) NotBetween(col string, less, more any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.NotBetween(col, less, more) })
}

// IsNull applies Expr.IsNull(col) to the builder's conditions.
func (q *QueryBuilder[T]) IsNull(col string) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.IsNull(col) })
}

// NotNull applies Expr.NotNull(col) to the builder's conditions.
func (q *QueryBuilder[T]) NotNull(col string) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.NotNull(col) })
}

// Like applies Expr.Like(col, tpl) to the builder's conditions.
func (q *QueryBuilder[T]) Like(col, tpl string) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Like(col, tpl) })
}

// NotLike applies Expr.NotLike(col, tpl) to the builder's conditions.
func (q *QueryBuilder[T]) NotLike(col, tpl string) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.NotLike(col, tpl) })
}

// In applies Expr.In(col, values...) to the builder's conditions.
func (q *QueryBuilder[T]) In(col string, values ...any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.In(col, values...) })
}

// NotIn applies Expr.NotIn(col, values...) to the builder's conditions.
func (q *QueryBuilder[T]) NotIn(col string, values ...any) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.NotIn(col, values...) })
}

// When applies Expr.When(condition, then, elses...) to the builder's
// conditions; presumably then runs when condition is true and elses
// otherwise — see Expr.When.
func (q *QueryBuilder[T]) When(condition bool, then func(ex *Expr), elses ...func(ex *Expr)) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.When(condition, then, elses...) })
}

// Or applies Expr.Or(or), letting or build a nested group of conditions.
func (q *QueryBuilder[T]) Or(or func(ex *Expr)) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Or(or) })
}

// And applies Expr.And(and), letting and build a nested group of conditions.
func (q *QueryBuilder[T]) And(and func(ex *Expr)) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.And(and) })
}

// Not applies Expr.Not(not), letting not build a nested group of conditions.
func (q *QueryBuilder[T]) Not(not func(ex *Expr)) *QueryBuilder[T] {
	return q.onExpr(func(ex *Expr) { ex.Not(not) })
}
// DescentBy appends a descending ORDER BY term ("<col> DESC") for each
// column, in the order given.
//
// Scopes passes every term to gorm's Order verbatim, so columns must never
// come from unvalidated user input.
func (q *QueryBuilder[T]) DescentBy(columns ...string) *QueryBuilder[T] {
	for i := range columns {
		q.orders = append(q.orders, columns[i]+" DESC")
	}
	return q
}

// AscentBy appends an ORDER BY term for each column exactly as given. With
// no direction keyword, the database default (ascending) applies.
//
// Scopes passes every term to gorm's Order verbatim, so columns must never
// come from unvalidated user input.
func (q *QueryBuilder[T]) AscentBy(columns ...string) *QueryBuilder[T] {
	q.orders = append(q.orders, columns...)
	return q
}
// Limit sets the maximum number of rows to return. Scopes applies it only when
// it is > 0; Paginate replaces a non-positive value with its default.
func (q *QueryBuilder[T]) Limit(limit int) *QueryBuilder[T] {
	q.limit = limit
	return q
}

// Offset sets the number of rows to skip. Scopes applies it only when it is > 0.
func (q *QueryBuilder[T]) Offset(offset int) *QueryBuilder[T] {
	q.offset = offset
	return q
}

// Distinct records arguments for gorm's Distinct; repeated calls accumulate.
//
// NOTE(review): scopesWithoutEffect only calls gorm's Distinct when at least
// one argument was recorded, so Distinct() with no columns currently has no
// effect.
func (q *QueryBuilder[T]) Distinct(columns ...any) *QueryBuilder[T] {
	q.distinct = append(q.distinct, columns...)
	return q
}

// Preload records a gorm Preload(query, args...) call, which Scopes replays
// in the order the calls were made.
func (q *QueryBuilder[T]) Preload(query string, args ...any) *QueryBuilder[T] {
	q.preloads = append(q.preloads, preload{query: query, args: args})
	return q
}
// Scopes applies the complete query to tx. That is everything
// scopesWithoutEffect applies (model, columns, DISTINCT, conditions), plus
// ORDER BY terms in call order, LIMIT and OFFSET (each only when > 0), and
// the recorded preloads. Its signature matches func(*gorm.DB) *gorm.DB, so it
// can be passed to gorm's DB.Scopes.
func (q *QueryBuilder[T]) Scopes(tx *gorm.DB) *gorm.DB {
	tx = q.scopesWithoutEffect(tx)

	// Ranging over a nil slice is a no-op, so no nil guards are needed.
	for _, ord := range q.orders {
		tx = tx.Order(ord)
	}

	if q.limit > 0 {
		tx = tx.Limit(q.limit)
	}
	if q.offset > 0 {
		tx = tx.Offset(q.offset)
	}

	for _, p := range q.preloads {
		tx = tx.Preload(p.query, p.args...)
	}
	return tx
}
// scopesWithoutEffect applies the model, column selection (Select, Omit,
// Distinct) and condition expression to tx. It does not apply ORDER BY,
// LIMIT, OFFSET or preloads. Count uses it on its own, and Scopes builds
// on it.
//
// NOTE(review): gorm's Distinct with arguments also replaces the SELECT list,
// so columns recorded via Distinct likely take precedence over those from
// Select — confirm this is intended.
func (q *QueryBuilder[T]) scopesWithoutEffect(tx *gorm.DB) *gorm.DB {
	// A pointer to a zero T tells gorm which model (and table) to query.
	var model T
	tx = tx.Model(&model)

	// append never turns a nil slice into an empty non-nil one, so a length
	// check is equivalent to a nil check here.
	if len(q.selects) > 0 {
		tx = tx.Select(q.selects)
	}
	if len(q.omits) > 0 {
		tx = tx.Omit(q.omits...)
	}
	if len(q.distinct) > 0 {
		tx = tx.Distinct(q.distinct...)
	}

	return q.expr.Scopes(tx)
}
// newSession returns a fresh gorm session derived from q.db.
//
// A *gorm.DB that is not a new session (e.g. the result of db.Where(...))
// shares one Statement with everything chained from it. Without a new
// session, running several finishers on the same builder would re-apply the
// model and conditions onto an already-mutated Statement. Paginate does
// exactly that with Count and then Find. Session(&gorm.Session{}) gives each
// finisher its own copy of the Statement.
func (q *QueryBuilder[T]) newSession() *gorm.DB {
	return q.db.Session(&gorm.Session{})
}

// Count returns the number of rows matching the builder's model, column
// selection and conditions. Ordering, limit, offset and preloads are not
// applied.
func (q *QueryBuilder[T]) Count() (int64, error) {
	var count int64
	err := q.newSession().Scopes(q.scopesWithoutEffect).Count(&count).Error
	return count, err
}

// First runs the full query (see Scopes) through gorm's First and scans the
// result into entity. It returns gorm's error, e.g. gorm.ErrRecordNotFound
// when nothing matches.
func (q *QueryBuilder[T]) First(entity *T) error {
	return q.newSession().Scopes(q.Scopes).First(entity).Error
}

// Take runs the full query (see Scopes) through gorm's Take and scans the
// result into entity. It returns gorm's error, e.g. gorm.ErrRecordNotFound
// when nothing matches.
func (q *QueryBuilder[T]) Take(entity *T) error {
	return q.newSession().Scopes(q.Scopes).Take(entity).Error
}

// Last runs the full query (see Scopes) through gorm's Last and scans the
// result into entity. It returns gorm's error, e.g. gorm.ErrRecordNotFound
// when nothing matches.
func (q *QueryBuilder[T]) Last(entity *T) error {
	return q.newSession().Scopes(q.Scopes).Last(entity).Error
}

// Find runs the full query (see Scopes) and scans all resulting rows into
// entities.
func (q *QueryBuilder[T]) Find(entities *[]*T) error {
	return q.newSession().Scopes(q.Scopes).Find(entities).Error
}

// Paginate runs Count and Find with the builder's settings and returns the
// results as a Pager.
//
// A non-positive limit is replaced by a default of 30, and a negative offset
// is clamped to 0. Both adjustments are written back to the builder. Page is
// the 1-based page containing the first returned row, i.e. floor(offset/limit)+1.
func (q *QueryBuilder[T]) Paginate() (*Pager[T], error) {
	const defaultLimit = 30

	if q.limit <= 0 {
		q.limit = defaultLimit
	}
	if q.offset < 0 {
		q.offset = 0
	}

	count, err := q.Count()
	if err != nil {
		return nil, err
	}

	var items []*T
	if err := q.Find(&items); err != nil {
		return nil, err
	}

	return &Pager[T]{
		Total: int(count),
		// Floor, not Ceil: an offset of 1 with limit 30 starts inside page 1,
		// not page 2. For page-aligned offsets both agree.
		Page:  int(math.Floor(float64(q.offset)/float64(q.limit))) + 1,
		Limit: q.limit,
		Items: items,
	}, nil
}

// Rows runs the full query (see Scopes) and returns a row iterator.
// The caller must Close the returned rows.
//
// Example:
//
//	rows, err := q.Eq("name", "jack").Rows()
//	if err != nil {
//		panic(err)
//	}
//	defer rows.Close()
//	for rows.Next() {
//		var user User
//		db.ScanRows(rows, &user)
//		// do something
//	}
func (q *QueryBuilder[T]) Rows() (*sql.Rows, error) {
	return q.newSession().Scopes(q.Scopes).Rows()
}

// Pluck runs the full query (see Scopes) and scans the single column into
// dest, which should be a pointer to a slice.
//
// Example:
//
//	var names []string
//	q.Pluck("name", &names)
func (q *QueryBuilder[T]) Pluck(column string, dest any) error {
	return q.newSession().Scopes(q.Scopes).Pluck(column, dest).Error
}