diff --git a/pkg/app/pubsub.go b/pkg/app/pubsub.go new file mode 100644 index 0000000..2e8b6fb --- /dev/null +++ b/pkg/app/pubsub.go @@ -0,0 +1,179 @@ +package app + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "io" + "os" + "sync" +) + +// Callback 订阅回调 +type Callback func(payload []byte) + +// Runnable 子协程运行 +type Runnable func(ctx context.Context) + +var wg sync.WaitGroup +var resChan chan []byte +var doneChan chan struct{} +var channels = make(map[string][]Callback) +var runtime = context.Background() + +var stats struct { + send int + respond int + receive int +} + +func Async(callback func()) { + wg.Add(1) + go func() { + defer wg.Done() + callback() + }() +} + +func process(buf []byte) (response []byte, err error) { + stats.receive++ + var n int + var channel []byte + n, channel, err = decode(buf) + if err != nil { + return + } + var payload []byte + _, payload, err = decode(buf[n:]) + if err != nil { + return + } + subscribers, ok := channels[string(channel)] + if !ok { + err = fmt.Errorf("no subscribers for channel %s", string(channel)) + return + } + for _, subscriber := range subscribers[:] { + // fixme: 是否需要使用 payload 的副本? + recovery(payload, subscriber) + } + if response != nil { + stats.respond++ + } + return +} + +func recovery(payload []byte, callback Callback) { + defer func() { + err := recover() + if err != nil { + fmt.Fprintf(os.Stderr, "%v", err) + } + }() + callback(payload) +} + +func Sub(channel string, callback Callback) { + subscribers, ok := channels[channel] + if !ok { + subscribers = make([]Callback, 0) + } + subscribers = append(subscribers, callback) + channels[channel] = subscribers +} + +// 定义写数据的格式 +func encode(data []byte) []byte { + // 4字节头部 + 可变体的长度 + buf := make([]byte, 4+len(data)) + // 写入头部,记录数据长度 + binary.BigEndian.PutUint32(buf[:4], uint32(len(data))) + // 将整个数据,放到4后边 + copy(buf[4:], data) + // 返回结果 + return buf +} + +func decode(buf []byte) (int, []byte, error) { + r := bytes.NewBuffer(buf) + // 读取头部记录的长度 + header := make([]byte, 4) + // 按长度读取消息 + _, err := io.ReadFull(r, header) + if err != nil { + return 0, nil, err + } + // 读取数据 + dataLen := binary.BigEndian.Uint32(header) + data := make([]byte, dataLen) + _, err = io.ReadFull(r, data) + if err != nil { + return 0, nil, err + } + return 4 + int(dataLen), data, nil +} + +func Pub(channel string, payload []byte) { + // fixme: 在时间和空间上,是否有比 append 性能更好的操作 + buf := append(encode([]byte(channel)), encode(payload)...) + wg.Add(1) + resChan <- buf +} + +func Run(runner Runnable) { + Async(func() { + runner(runtime) + }) +} + +func Loop() { + wg.Add(1) + first := true + + var cancel context.CancelFunc + runtime, cancel = context.WithCancel(context.Background()) + + go func() { + wg.Wait() + cancel() + doneChan <- struct{}{} + }() + + for { + select { + case msg := <-resChan: + process(msg) + wg.Done() // Corresponds to the wg.Add(1) in Pub(). + case <-doneChan: + // All goroutines have completed. Now we can exit main(). + checkChanEmpty() + return + } + + // We don't want to exit until we've received at least one message. + // This is so the program doesn't exit after sending the "start" + // message. + if first { + wg.Done() + first = false + } + } + + // todo 清理工作 +} + +func checkChanEmpty() { + // We've received a done event. As a sanity check, make sure that resChan is + // empty. + select { + case _, ok := <-resChan: + if ok { + panic("Read a message from resChan after doneChan closed.") + } else { + panic("resChan closed. Unexpected.") + } + default: + // No value ready, moving on. + } +}