You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
3.4 KiB
179 lines
3.4 KiB
package app
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"sync"
|
|
)
|
|
|
|
type (
	// Callback is a subscriber registered with Sub. It is called with the
	// payload of each message published on the channel it subscribed to.
	Callback func(data []byte)

	// Runnable is a unit of work started by Run on its own goroutine. It is
	// handed the package-level runtime context.
	Runnable func(ctx context.Context)
)
|
|
|
|
var wg sync.WaitGroup
|
|
var resChan chan []byte
|
|
var doneChan chan struct{}
|
|
var channels = make(map[string][]Callback)
|
|
var runtime = context.Background()
|
|
|
|
var stats struct {
|
|
send int
|
|
respond int
|
|
receive int
|
|
}
|
|
|
|
func Async(callback func()) {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
callback()
|
|
}()
|
|
}
|
|
|
|
func process(buf []byte) (response []byte, err error) {
|
|
stats.receive++
|
|
var n int
|
|
var channel []byte
|
|
n, channel, err = decode(buf)
|
|
if err != nil {
|
|
return
|
|
}
|
|
var payload []byte
|
|
_, payload, err = decode(buf[n:])
|
|
if err != nil {
|
|
return
|
|
}
|
|
subscribers, ok := channels[string(channel)]
|
|
if !ok {
|
|
err = fmt.Errorf("no subscribers for channel %s", string(channel))
|
|
return
|
|
}
|
|
for _, subscriber := range subscribers[:] {
|
|
// fixme: 是否需要使用 payload 的副本?
|
|
recovery(payload, subscriber)
|
|
}
|
|
if response != nil {
|
|
stats.respond++
|
|
}
|
|
return
|
|
}
|
|
|
|
func recovery(payload []byte, callback Callback) {
|
|
defer func() {
|
|
err := recover()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "%v", err)
|
|
}
|
|
}()
|
|
callback(payload)
|
|
}
|
|
|
|
func Sub(channel string, callback Callback) {
|
|
subscribers, ok := channels[channel]
|
|
if !ok {
|
|
subscribers = make([]Callback, 0)
|
|
}
|
|
subscribers = append(subscribers, callback)
|
|
channels[channel] = subscribers
|
|
}
|
|
|
|
// encode frames data for transport: a 4-byte big-endian length header
// followed by a copy of data. decode performs the reverse operation.
//
// NOTE(review): the length is converted to uint32, so data longer than
// math.MaxUint32 bytes would get a truncated, corrupt header.
func encode(data []byte) []byte {
	// Reserve room for the whole frame up front so the append below never
	// reallocates.
	frame := make([]byte, 4, 4+len(data))
	binary.BigEndian.PutUint32(frame, uint32(len(data)))
	return append(frame, data...)
}
|
|
|
|
// decode reads one frame, as written by encode, from the front of buf: a
// 4-byte big-endian length n followed by n bytes of data.
//
// It returns the number of bytes the frame occupies (4+n), so the caller can
// continue decoding at buf[4+n:], and a freshly allocated copy of the data
// that never aliases buf. Bytes after the frame are ignored.
//
// If buf is too short, the error matches io.ReadFull: io.EOF when no bytes
// are available for the part being read, io.ErrUnexpectedEOF when only some
// of them are.
func decode(buf []byte) (int, []byte, error) {
	r := bytes.NewBuffer(buf)

	// Read the length header.
	header := make([]byte, 4)
	if _, err := io.ReadFull(r, header); err != nil {
		return 0, nil, err
	}
	dataLen := binary.BigEndian.Uint32(header)

	// Check the declared length against the bytes actually present before
	// allocating. Otherwise a corrupt or hostile header could request up to
	// 4 GiB, only for ReadFull to fail afterwards. The errors are the ones
	// ReadFull would have returned. This also keeps int(dataLen) below
	// within range on 32-bit platforms.
	if remaining := r.Len(); uint64(dataLen) > uint64(remaining) {
		if remaining == 0 {
			return 0, nil, io.EOF
		}
		return 0, nil, io.ErrUnexpectedEOF
	}

	// Read the data.
	data := make([]byte, dataLen)
	if _, err := io.ReadFull(r, data); err != nil {
		return 0, nil, err
	}
	return 4 + int(dataLen), data, nil
}
|
|
|
|
func Pub(channel string, payload []byte) {
|
|
// fixme: 在时间和空间上,是否有比 append 性能更好的操作
|
|
buf := append(encode([]byte(channel)), encode(payload)...)
|
|
wg.Add(1)
|
|
resChan <- buf
|
|
}
|
|
|
|
func Run(runner Runnable) {
|
|
Async(func() {
|
|
runner(runtime)
|
|
})
|
|
}
|
|
|
|
func Loop() {
|
|
wg.Add(1)
|
|
first := true
|
|
|
|
var cancel context.CancelFunc
|
|
runtime, cancel = context.WithCancel(context.Background())
|
|
|
|
go func() {
|
|
wg.Wait()
|
|
cancel()
|
|
doneChan <- struct{}{}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case msg := <-resChan:
|
|
process(msg)
|
|
wg.Done() // Corresponds to the wg.Add(1) in Pub().
|
|
case <-doneChan:
|
|
// All goroutines have completed. Now we can exit main().
|
|
checkChanEmpty()
|
|
return
|
|
}
|
|
|
|
// We don't want to exit until we've received at least one message.
|
|
// This is so the program doesn't exit after sending the "start"
|
|
// message.
|
|
if first {
|
|
wg.Done()
|
|
first = false
|
|
}
|
|
}
|
|
|
|
// todo 清理工作
|
|
}
|
|
|
|
func checkChanEmpty() {
|
|
// We've received a done event. As a sanity check, make sure that resChan is
|
|
// empty.
|
|
select {
|
|
case _, ok := <-resChan:
|
|
if ok {
|
|
panic("Read a message from resChan after doneChan closed.")
|
|
} else {
|
|
panic("resChan closed. Unexpected.")
|
|
}
|
|
default:
|
|
// No value ready, moving on.
|
|
}
|
|
}
|
|
|