diff --git a/internal/init.go b/internal/init.go index 15cbccb..0cd1d11 100644 --- a/internal/init.go +++ b/internal/init.go @@ -7,6 +7,7 @@ import ( "sorbet/internal/entities" "sorbet/internal/middleware" "sorbet/internal/util" + "sorbet/pkg/app" "sorbet/pkg/db" "sorbet/pkg/env" "sorbet/pkg/log" @@ -23,6 +24,7 @@ func Init() error { log.Error("同步数据表结构需要开启 [DB_CODE_FIRST],在生产模式下请务必关闭。") } } + app.OnStart(startServer) return nil } @@ -31,16 +33,16 @@ func syncEntities() error { &entities.Company{}, &entities.CompanyDepartment{}, &entities.CompanyStaff{}, - &entities.ConfigGroup{}, &entities.Config{}, + &entities.ConfigGroup{}, &entities.Feature{}, &entities.FeatureCategory{}, &entities.FeatureConfig{}, &entities.FeatureContent{}, &entities.FeatureContentChapter{}, &entities.FeatureContentDetail{}, - &entities.ResourceCategory{}, &entities.Resource{}, + &entities.ResourceCategory{}, &entities.SystemLog{}, &entities.SystemMenu{}, &entities.SystemPermission{}, @@ -50,7 +52,7 @@ func syncEntities() error { ) } -func Start() error { +func startServer() { e := echo.New() e.HideBanner = true e.HidePort = true @@ -59,13 +61,12 @@ func Start() error { e.Use(middleware.Recover()) e.Use(middleware.CORS()) e.Use(middleware.Logger) - return run(e) -} -func run(e *echo.Echo) error { port := env.String("SERVER_PORT", "1324") addr := fmt.Sprintf(":%s", port) // TODO(hupeh): 验证 addr 是否合法 - e.Logger.Fatal(e.Start(addr)) - return nil + e.Logger.Error(e.Start(addr)) + + // 通知应用退出 + app.Stop() } diff --git a/main.go b/main.go index 571bc08..b8ac967 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "log" "sorbet/internal" + "sorbet/pkg/app" "sorbet/pkg/env" ) @@ -13,7 +14,10 @@ func main() { if err := internal.Init(); err != nil { panic(err) } - if err := internal.Start(); err != nil { + if err := app.Init("--udp", "false"); err != nil { + panic(err) + } + if err := app.Loop(); err != nil { log.Panicln(err) } } diff --git a/pkg/app/app.go b/pkg/app/app.go new file mode 100644 index 0000000..852e4f3 --- /dev/null +++ b/pkg/app/app.go @@ -0,0 +1,124 @@ +package app + +import ( + "context" + "errors" + "os" + "sorbet/pkg/app/appkit" +) + +func Init(args ...string) (err error) { + if len(args) == 0 { + args = os.Args[1:] + } + + switch Status() { + case Initialized: + return nil + case Starting: + return errors.New("starting") + case Running: + return errors.New("running") + case Stopping: + return errors.New("stopping") + case Idle, Stopped: + // continue + } + + initLifecycle() + initStats(args) + initUdp(args) + initBus() + initBuiltins() + + setStatus(Initialized) + + return +} + +func initBuiltins() { + // 订阅应用启动 + var sub appkit.Subscriber + sub = Sub("start", func([]byte) []byte { + sub.Cancel() + return nil + }) + + // 订阅应用停止 + Sub("stop", func(bytes []byte) []byte { + quit := exit + if quit != nil { + exit = nil + quit() + } + return nil + }) +} + +func Loop() error { + switch Status() { + case Starting, Running: + return nil + case Stopping: + return errors.New("stopping") + case Idle: + return errors.New("idle, you maybe forgot call app.Init()") + case Stopped: + return errors.New("stopped, you maybe forgot call app.Init()") + case Initialized: + // nothing + } + + // 释放内存 + defer free() + + wg.Add(1) + setStatus(Starting) + + // In a goroutine, we wait on for all goroutines to complete (for example + // timers). We use this to signal to the main thread to exit. + // wg.Add(1) basically translates to uv_ref, if this was Node. + // wg.Done() basically translates to uv_unref + go func() { + wg.Wait() + setStatus(Stopping) + stop() + }() + + // 启动应用指令 + Pub("start", nil) + + for { + select { + case msg := <-pubs: + err := dispatch(msg) + exitOnError(err) + wg.Done() // Corresponds to the wg.Add(1) in Pub(). + + case <-ctx.Done(): + // 只有等待所有协程全部完成后, + // 然后才可以退出主循环 + checkPubsEmpty() + setStatus(Stopped) + return ctx.Err() + } + + // We don't want to exit until we've received at least one message. + // This is so the program doesn't exit after sending the "start" + // message. + if Status() == Starting { + wg.Done() + start() + setStatus(Running) + } + } +} + +func Stop() { + Pub("stop", nil) +} + +func Go(fn func(context.Context)) { + assert(fn != nil, "invalid fn") + async(fn) +} diff --git a/pkg/app/appkit/bus.go b/pkg/app/appkit/bus.go new file mode 100644 index 0000000..03dd01a --- /dev/null +++ b/pkg/app/appkit/bus.go @@ -0,0 +1,5 @@ +package appkit + +type Subscriber interface { + Cancel() +} diff --git a/pkg/app/appkit/codec.go b/pkg/app/appkit/codec.go new file mode 100644 index 0000000..261ba23 --- /dev/null +++ b/pkg/app/appkit/codec.go @@ -0,0 +1,81 @@ +package appkit + +import ( + "bytes" + "encoding/binary" + "errors" + "math" +) + +// 数据协议如下: +// +// +-----------+--------+--------------+-------+------+ +// | 0x01,0x02 | length | topic-length | topic | data | +// +-----------+--------+--------------+-------+------+ +// 2 2 2 n n +// +// * 头部有两个模数 0x01、0x02; +// * length 表示消息的总长度; +// * 然后分别是 topic、data,包含其长度和数据; +// * 长度使用 uint16 类型,所以占用 2 个字节; +// +// todo 实现数据校验码,防止篡改 + +// Encode 数据编码 +func Encode(topic, data []byte) ([]byte, error) { + topicLen := len(topic) + dataLen := len(data) + size := 2 + topicLen + dataLen // 不包括模数和数据长度 + if topicLen == 0 { + return nil, errors.New("topic too small") + } + if size+4 > math.MaxUint16 { + return nil, errors.New("data too big") + } + buf := &bytes.Buffer{} + // 由于 模数、topic 和 data 都是基本类型,所以当 + // 执行 binary.Write() 时是不会出错的 + _ = binary.Write(buf, binary.BigEndian, []byte{0x01, 0x02}) + _ = binary.Write(buf, binary.BigEndian, uint16(size)) + _ = binary.Write(buf, binary.BigEndian, uint16(topicLen)) + _ = binary.Write(buf, binary.BigEndian, topic) + _ = binary.Write(buf, binary.BigEndian, data) // 如果没有数据会不会报错 + return buf.Bytes(), nil +} + +// Decode 数据解码 +func Decode(buf []byte) (topic, data []byte, err error) { + r := bytes.NewReader(buf) + // 读取模数 + var v [2]byte + err = binary.Read(r, binary.BigEndian, &v) + if err != nil { + return + } + if v[0] != 0x01 || v[1] != 0x02 { + err = errors.New("protocol error") + return + } + // 读取数据长度 + var size uint16 + err = binary.Read(r, binary.BigEndian, &size) + if err != nil { + return + } + // 读取主题长度 + var topicLen uint16 + err = binary.Read(r, binary.BigEndian, &topicLen) + if err != nil { + return + } + // 读取主题 + topic = make([]byte, topicLen) + err = binary.Read(r, binary.BigEndian, &topic) + if err != nil { + return + } + // 读取数据 + data = make([]byte, size-2-topicLen) + err = binary.Read(r, binary.BigEndian, &data) + return +} diff --git a/pkg/app/bus.go b/pkg/app/bus.go new file mode 100644 index 0000000..fa0225f --- /dev/null +++ b/pkg/app/bus.go @@ -0,0 +1,215 @@ +package app + +import ( + "errors" + "net" + "slices" + "sorbet/pkg/app/appkit" + "sync" + "sync/atomic" +) + +var ( + // 操作锁 + // + // * 添加订阅器 + lock sync.RWMutex + + // 用于控制主体程序并发,涉及以下几个方面: + // + // * 任务调度 + // * 定时器 + // * 发布订阅 + // * 主程序轮询 + wg sync.WaitGroup + + // 消息发布通道 + pubs chan Msg + + // 消息订阅器 + subs map[string][]*subscriber + + nextSubId int32 +) + +func initBus() { + if pubs == nil { + pubs = make(chan Msg, 256) + } + if subs == nil { + subs = make(map[string][]*subscriber) + } + OnFree(freeBus) +} + +func freeBus() { + clear(subs) + + select { + case _, ok := <-pubs: + if !ok { + // 通道已经关闭了,需要重新构建 + pubs = make(chan Msg, 256) + return + } + // 清空 pubs 缓存的数据,复用它 + l := len(pubs) + for i := 0; i < l; i++ { + <-pubs + } + default: + // 通道里面没有值,可以复用 + } +} + +func Sub(topic string, handle func([]byte) []byte) appkit.Subscriber { + lock.Lock() + defer lock.Unlock() + sub := &subscriber{ + id: atomic.AddInt32(&nextSubId, 1), + topic: topic, + handle: handle, + } + if _, ok := subs[topic]; !ok { + subs[topic] = make([]*subscriber, 0) + } + subs[topic] = append(subs[topic], sub) + return sub +} + +func Pub(topic string, data []byte) { + pubMsg(Msg{Data: data, topic: topic}) +} + +func pubMsg(msg Msg) { + wg.Add(1) + pubs <- msg +} + +type Msg struct { + Data []byte + + topic string + addr net.Addr + pc net.PacketConn +} + +// Size 消息长度 +func (m Msg) Size() int { + return 6 + len([]byte(m.topic)) + len(m.Data) +} + +type subscriber struct { + id int32 + topic string + handle func([]byte) []byte +} + +func (s *subscriber) Active() bool { + return atomic.LoadInt32(&s.id) > 0 +} + +func (s *subscriber) Cancel() { + id := atomic.SwapInt32(&s.id, 0) + if id == 0 { + return + } + lock.Lock() + defer lock.Unlock() + subscribers, ok := subs[s.topic] + if !ok { + return + } + if len(subscribers) > 0 { + subscribers = slices.DeleteFunc(subscribers, func(sub *subscriber) bool { + return sub.id == id + }) + } + if len(subscribers) > 0 { + subs[s.topic] = subscribers + } else { + delete(subs, s.topic) + } +} + +func (s *subscriber) invoke(data []byte) []byte { + if s.Active() { + return s.handle(data) + } + return nil +} + +func dispatch(msg Msg) error { + lock.RLock() + var cbs []*subscriber + if subscribers, ok := subs[msg.topic]; ok { + for _, s := range subscribers { + if s.Active() { + cbs = append(cbs, s) + } + } + } + lock.RUnlock() + + if len(cbs) == 0 { + return errors.New("no subscribers for topic " + msg.topic) + } + + var response []byte + for _, cb := range cbs { + res := cb.invoke(msg.Data) + if res != nil { + response = res + } + } + + if len(response) > 0 && msg.pc != nil { + stats.v8workerRespond++ + stats.v8workerBytesSent += len(response) + + n, ex := msg.pc.WriteTo(response, msg.addr) + if n > 0 { + stats.v8workerSend++ + stats.v8workerBytesSent += n + } + if ex != nil { + // todo handle the error + } + } + + return nil +} + +func recv(pc net.PacketConn, addr net.Addr, buf []byte) { + stats.v8workerRecv++ + stats.v8workerBytesRecv += len(buf) + + topic, data, err := appkit.Decode(buf) + if err != nil { + // todo handle the error + // return errors.New("invalid payload") + return + } + + pubMsg(Msg{ + Data: data, + topic: string(topic), + addr: addr, + pc: pc, + }) +} + +func checkPubsEmpty() { + // We've received a done event. As a sanity check, make sure that resChan is + // empty. + select { + case _, ok := <-pubs: + if ok { + panic("Read a message from pubs after context deadlined.") + } else { + panic("pubs closed. Unexpected.") + } + default: + // No value ready, moving on. + } +} diff --git a/pkg/app/lifecycle.go b/pkg/app/lifecycle.go new file mode 100644 index 0000000..ae1f8e8 --- /dev/null +++ b/pkg/app/lifecycle.go @@ -0,0 +1,83 @@ +package app + +import ( + "context" + "sync/atomic" +) + +const ( + Idle int32 = iota + Initialized + Starting + Running + Stopping + Stopped +) + +var ( + // 当前状态 + status int32 + + // 应用上下文 + ctx context.Context + exit context.CancelFunc + + start func() // 应用启动钩子 + stop func() // 应用停止钩子 + free func() // 内存释放钩子 +) + +func initLifecycle() { + ctx, exit = context.WithCancel(context.Background()) + + start = func() { + // nothing + } + + stop = func() { + if exit != nil { + exit() + } + } + + free = func() { + ctx = nil + start = nil + stop = nil + exit = nil + } +} + +// setStatus 设置应用状态 +func setStatus(newStatus int32) { + atomic.StoreInt32(&status, newStatus) +} + +// Status 返回当前状态 +func Status() int32 { + return atomic.LoadInt32(&status) +} + +func OnStart(fn func()) { + oldStart := start + start = func() { + oldStart() + fn() + } +} + +func OnStop(fn func()) { + oldStop := stop + stop = func() { + fn() + oldStop() + } +} + +func OnFree(fn func()) { + oldFree := free + free = func() { + fn() + oldFree() + } +} diff --git a/pkg/app/pubsub.go b/pkg/app/pubsub.go deleted file mode 100644 index 2e8b6fb..0000000 --- a/pkg/app/pubsub.go +++ /dev/null @@ -1,179 +0,0 @@ -package app - -import ( - "bytes" - "context" - "encoding/binary" - "fmt" - "io" - "os" - "sync" -) - -// Callback 订阅回调 -type Callback func(payload []byte) - -// Runnable 子协程运行 -type Runnable func(ctx context.Context) - -var wg sync.WaitGroup -var resChan chan []byte -var doneChan chan struct{} -var channels = make(map[string][]Callback) -var runtime = context.Background() - -var stats struct { - send int - respond int - receive int -} - -func Async(callback func()) { - wg.Add(1) - go func() { - defer wg.Done() - callback() - }() -} - -func process(buf []byte) (response []byte, err error) { - stats.receive++ - var n int - var channel []byte - n, channel, err = decode(buf) - if err != nil { - return - } - var payload []byte - _, payload, err = decode(buf[n:]) - if err != nil { - return - } - subscribers, ok := channels[string(channel)] - if !ok { - err = fmt.Errorf("no subscribers for channel %s", string(channel)) - return - } - for _, subscriber := range subscribers[:] { - // fixme: 是否需要使用 payload 的副本? - recovery(payload, subscriber) - } - if response != nil { - stats.respond++ - } - return -} - -func recovery(payload []byte, callback Callback) { - defer func() { - err := recover() - if err != nil { - fmt.Fprintf(os.Stderr, "%v", err) - } - }() - callback(payload) -} - -func Sub(channel string, callback Callback) { - subscribers, ok := channels[channel] - if !ok { - subscribers = make([]Callback, 0) - } - subscribers = append(subscribers, callback) - channels[channel] = subscribers -} - -// 定义写数据的格式 -func encode(data []byte) []byte { - // 4字节头部 + 可变体的长度 - buf := make([]byte, 4+len(data)) - // 写入头部,记录数据长度 - binary.BigEndian.PutUint32(buf[:4], uint32(len(data))) - // 将整个数据,放到4后边 - copy(buf[4:], data) - // 返回结果 - return buf -} - -func decode(buf []byte) (int, []byte, error) { - r := bytes.NewBuffer(buf) - // 读取头部记录的长度 - header := make([]byte, 4) - // 按长度读取消息 - _, err := io.ReadFull(r, header) - if err != nil { - return 0, nil, err - } - // 读取数据 - dataLen := binary.BigEndian.Uint32(header) - data := make([]byte, dataLen) - _, err = io.ReadFull(r, data) - if err != nil { - return 0, nil, err - } - return 4 + int(dataLen), data, nil -} - -func Pub(channel string, payload []byte) { - // fixme: 在时间和空间上,是否有比 append 性能更好的操作 - buf := append(encode([]byte(channel)), encode(payload)...) - wg.Add(1) - resChan <- buf -} - -func Run(runner Runnable) { - Async(func() { - runner(runtime) - }) -} - -func Loop() { - wg.Add(1) - first := true - - var cancel context.CancelFunc - runtime, cancel = context.WithCancel(context.Background()) - - go func() { - wg.Wait() - cancel() - doneChan <- struct{}{} - }() - - for { - select { - case msg := <-resChan: - process(msg) - wg.Done() // Corresponds to the wg.Add(1) in Pub(). - case <-doneChan: - // All goroutines have completed. Now we can exit main(). - checkChanEmpty() - return - } - - // We don't want to exit until we've received at least one message. - // This is so the program doesn't exit after sending the "start" - // message. - if first { - wg.Done() - first = false - } - } - - // todo 清理工作 -} - -func checkChanEmpty() { - // We've received a done event. As a sanity check, make sure that resChan is - // empty. - select { - case _, ok := <-resChan: - if ok { - panic("Read a message from resChan after doneChan closed.") - } else { - panic("resChan closed. Unexpected.") - } - default: - // No value ready, moving on. - } -} diff --git a/pkg/app/stats.go b/pkg/app/stats.go new file mode 100644 index 0000000..fdffc7f --- /dev/null +++ b/pkg/app/stats.go @@ -0,0 +1,44 @@ +package app + +import ( + "os" + "reflect" + "runtime/pprof" +) + +var stats struct { + v8workerSend int + v8workerRespond int + v8workerRecv int + v8workerBytesSent int + v8workerBytesRecv int +} + +func initStats(args []string) { + initCpuProfile(strArg(args, "--cpuprof")) + initMemoryProfile(strArg(args, "--memprof")) + OnFree(freeStats) +} + +func initCpuProfile(cpuprof string) { + if cpuprof != "" { + cpuProfile, err := os.Create(cpuprof) + check(err) + check(pprof.StartCPUProfile(cpuProfile)) + OnStop(pprof.StopCPUProfile) + } +} + +func initMemoryProfile(memprof string) { + if memprof != "" { + memProfile, err := os.Create(memprof) + check(err) + check(pprof.WriteHeapProfile(memProfile)) + OnStop(func() { check(memProfile.Close()) }) + } +} + +func freeStats() { + newStats := reflect.New(reflect.TypeOf(stats)) + reflect.ValueOf(stats).Set(newStats.Elem()) +} diff --git a/pkg/app/udp.go b/pkg/app/udp.go new file mode 100644 index 0000000..d09acd3 --- /dev/null +++ b/pkg/app/udp.go @@ -0,0 +1,63 @@ +package app + +import ( + "context" + "errors" + "io" + "net" +) + +// 通过 UDP 协议实现外部通信 +var pc net.PacketConn + +func initUdp(args []string) { + // 未开启 udp + if !boolArg(args, "--udp") { + return + } + + // 这一步属于初始化,切不可放到 OnStart 里面 + sizeLimit := intArg(args, "--payload", 1024<<2) + addr := strArg(args, "--addr", ":1324") + + // 当程序启动时开启 udp 服务器 + OnStart(func() { + listenUDP(addr, sizeLimit) + }) + + OnStop(freeUdp) +} + +func freeUdp() { + if pc != nil { + err := pc.Close() + pc = nil + check(err) + } +} + +func listenUDP(address string, maxSize int) { + var err error + pc, err = net.ListenPacket("udp", address) + check(err) + + async(func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + default: + } + + buf := make([]byte, maxSize) + n, addr, err := pc.ReadFrom(buf) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + continue + } + recv(pc, addr, buf[:n]) + } + }) +} diff --git a/pkg/app/util.go b/pkg/app/util.go new file mode 100644 index 0000000..39cf83e --- /dev/null +++ b/pkg/app/util.go @@ -0,0 +1,101 @@ +package app + +import ( + "context" + "errors" + "os" + "slices" + "strconv" + "strings" +) + +func assert(cond bool, msg string) { + if !cond { + panic(msg) + } +} + +func check(e error) { + if e != nil { + panic(e) + } +} + +func exitOnError(err error) { + if err != nil { + os.Stderr.WriteString(err.Error()) + os.Exit(1) + } +} + +func async(cb func(ctx context.Context)) { + wg.Add(1) + go func() { + defer wg.Done() + select { + case <-ctx.Done(): + return + default: + cb(ctx) + } + }() +} + +func getArg[T any](args []string, name string, f func(string, bool) T) T { + i := slices.Index(args, name) + if i == -1 { + return f("", false) + } + if i == len(args)-1 { + panic(errors.New("invalid " + name + " option")) + } + s := args[i+1] + if strings.HasPrefix(s, "--") { + return f("", true) + } + return f(s, true) +} + +func strArg(args []string, name string, def ...string) string { + return getArg(args, name, func(s string, b bool) string { + if b { + return s + } + for _, v := range def { + return v + } + return "" + }) +} + +func intArg(args []string, name string, def ...int) int { + return getArg(args, name, func(s string, b bool) int { + if b { + size, err := strconv.Atoi(s) + check(err) + assert(size > 0, "invalid "+name+" option") + return size + } + for _, v := range def { + return v + } + return 0 + }) +} + +func boolArg(args []string, name string, def ...bool) bool { + return getArg(args, name, func(s string, b bool) bool { + if b { + if s == "" { + return true + } + v, err := strconv.ParseBool(s) + check(err) + return v + } + for _, v := range def { + return v + } + return false + }) +}