You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
85 lines
2.2 KiB
85 lines
2.2 KiB
2 months ago
|
package db
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"log/slog"
|
||
|
"time"
|
||
|
|
||
|
"gorm.io/gorm/logger"
|
||
|
"zestack.dev/slim"
|
||
|
)
|
||
|
|
||
|
// TODO: support tenant isolation.
|
||
|
func getLogger(ctx context.Context) *slog.Logger {
|
||
|
if ctx != nil {
|
||
|
l, ok := ctx.Value("db:logger").(*slim.Logger)
|
||
|
if ok && l != nil {
|
||
|
return l.Logger
|
||
|
}
|
||
|
k, ok := ctx.Value("db:logger").(*slog.Logger)
|
||
|
if ok && k != nil {
|
||
|
return k
|
||
|
}
|
||
|
}
|
||
|
return slog.Default()
|
||
|
}
|
||
|
|
||
|
// dbLogger routes gorm's log output to the slog logger found in the request
// context (falling back to the default slog logger).
type dbLogger struct {
	// SlowThreshold is the statement duration above which Trace reports a
	// query as slow. A zero value disables slow-query reporting.
	SlowThreshold time.Duration
}
|
||
|
|
||
|
// LogMode log mode
|
||
|
func (l *dbLogger) LogMode(logger.LogLevel) logger.Interface {
|
||
|
return l
|
||
|
}
|
||
|
|
||
|
// Info print info
|
||
|
func (l *dbLogger) Info(ctx context.Context, msg string, data ...any) {
|
||
|
getLogger(ctx).Info(fmt.Sprintf(msg, data...))
|
||
|
}
|
||
|
|
||
|
// Warn print warn messages
|
||
|
func (l *dbLogger) Warn(ctx context.Context, msg string, data ...any) {
|
||
|
getLogger(ctx).Warn(fmt.Sprintf(msg, data...))
|
||
|
}
|
||
|
|
||
|
// Error print error messages
|
||
|
func (l *dbLogger) Error(ctx context.Context, msg string, data ...any) {
|
||
|
getLogger(ctx).Error(fmt.Sprintf(msg, data...))
|
||
|
}
|
||
|
|
||
|
// Trace print sql message
|
||
|
func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
||
|
elapsed := time.Since(begin)
|
||
|
switch {
|
||
|
case err != nil && !errors.Is(err, logger.ErrRecordNotFound):
|
||
|
sql, rows := fc()
|
||
|
if rows == -1 {
|
||
|
l.Error(ctx, "%s [rows:%v] %s [%.3fms]", err, "-", sql, float64(elapsed.Nanoseconds())/1e6)
|
||
|
} else {
|
||
|
l.Error(ctx, "%s [rows:%v] %s [%.3fms]", err, rows, sql, float64(elapsed.Nanoseconds())/1e6)
|
||
|
}
|
||
|
case elapsed > l.SlowThreshold && l.SlowThreshold != 0:
|
||
|
sql, rows := fc()
|
||
|
slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold)
|
||
|
if rows == -1 {
|
||
|
l.Warn(ctx, "%s [rows:%v] %s [%.3fms]", slowLog, "-", sql, float64(elapsed.Nanoseconds())/1e6)
|
||
|
} else {
|
||
|
l.Warn(ctx, "%s [rows:%v] %s [%.3fms]", slowLog, rows, sql, float64(elapsed.Nanoseconds())/1e6)
|
||
|
}
|
||
|
default:
|
||
|
sql, rows := fc()
|
||
|
if rows == -1 {
|
||
|
l.Info(ctx, "[rows:%v] %s [%.3fms]", "-", sql, float64(elapsed.Nanoseconds())/1e6)
|
||
|
} else {
|
||
|
l.Info(ctx, "[rows:%v] %s [%.3fms]", rows, sql, float64(elapsed.Nanoseconds())/1e6)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (l *dbLogger) ParamsFilter(_ context.Context, sql string, params ...any) (string, []any) {
|
||
|
return sql, params
|
||
|
}
|