Commit 58e37bfe35c35abf16c41fe07802ba09c33b0b1a

Authored by zhangqijia
1 parent 29b6d86f

add sync.Pool to conn, agent

cmd/gameserver/agent.go
... ... @@ -8,6 +8,7 @@ import (
8 8 "pro2d/common/logger"
9 9 "pro2d/models"
10 10 "pro2d/utils"
  11 + "sync"
11 12 "sync/atomic"
12 13 )
13 14  
... ... @@ -16,25 +17,21 @@ type Agent struct {
16 17 Server components.IServer
17 18  
18 19 Role *models.RoleModel
19   -
20   -
21   - Quit chan *Agent
22   -
23 20 nextCheckTime int64 //下一次检查的时间
24 21 lastHeartCheckTime int64
25 22 heartTimeoutCount int //超时次数
26 23 }
27 24  
28   -func NewAgent(s components.IServer) *Agent {
29   - return &Agent{
30   - Server: s,
  25 +var agentPool = sync.Pool{New: func() interface{} { return new(Agent)}}
31 26  
32   - Quit: make(chan *Agent),
  27 +func NewAgent(s components.IServer) *Agent {
  28 + a := agentPool.Get().(*Agent)
  29 + a.Server = s
33 30  
34   - nextCheckTime: 0,
35   - lastHeartCheckTime: utils.Timex(),
36   - heartTimeoutCount: 0,
37   - }
  31 + a.nextCheckTime = 0
  32 + a.lastHeartCheckTime = utils.Timex()
  33 + a.heartTimeoutCount= 0
  34 + return a
38 35 }
39 36  
40 37 func (c *Agent) OnConnection(conn components.IConnection) {
... ... @@ -45,16 +42,19 @@ func (c *Agent) OnMessage(msg components.IMessage) {
45 42 atomic.StoreInt64(&c.lastHeartCheckTime, utils.Timex())
46 43 md := c.Server.GetAction(msg.GetHeader().GetMsgID())
47 44 if md == nil {
48   - logger.Debug("cmd: %d, md is nil", msg.GetHeader().GetMsgID())
  45 + logger.Debug("cmd: %d, handler is nil", msg.GetHeader().GetMsgID())
49 46 return
50 47 }
51   -
52   - logger.Debug("protocode handler: %d", msg.GetHeader().GetMsgID())
53   - //fmt.Printf("errCode: %d, protomsg:%v\n", errCode, protomsg)
  48 + logger.Debug("protocolID: %d", msg.GetHeader().GetMsgID())
  49 + //fmt.Printf("errCode: %d, protoMsg:%v\n", errCode, protoMsg)
54 50  
55 51 f := md.(func (msg components.IMessage) (int32, interface{}))
56   - errCode, protomsg := f(msg)
57   - rsp, err := proto.Marshal(protomsg.(proto.Message))
  52 + errCode, protoMsg := f(msg)
  53 + if protoMsg == nil {
  54 + return
  55 + }
  56 +
  57 + rsp, err := proto.Marshal(protoMsg.(proto.Message))
58 58 if err != nil {
59 59 conn := msg.GetSession()
60 60 if conn != nil {
... ... @@ -65,9 +65,9 @@ func (c *Agent) OnMessage(msg components.IMessage) {
65 65 conn := msg.GetSession()
66 66 if conn != nil {
67 67 conn.Send(errCode, msg.GetHeader().GetMsgID(), rsp)
  68 + return
68 69 }
69   - return
70   - logger.Error("protocode not handler: %d", msg.GetHeader().GetMsgID())
  70 + logger.Error("protocol not handler: %d", msg.GetHeader().GetMsgID())
71 71 }
72 72  
73 73 func (c *Agent) OnTimer() {
... ... @@ -91,6 +91,8 @@ func (c *Agent) OnClose() {
91 91 }
92 92  
93 93 func (c *Agent) Close() {
  94 + agentPool.Put(c)
  95 +
94 96 if c.Role == nil {
95 97 return
96 98 }
... ... @@ -100,7 +102,7 @@ func (c *Agent) Close() {
100 102  
101 103 func (c *Agent) checkHeartBeat(now int64) {
102 104 lastHeartCheckTime := atomic.LoadInt64(&c.lastHeartCheckTime)
103   - logger.Debug("checkHeartBeat ID: %d, last: %d, now: %d", c.GetID(), lastHeartCheckTime, now)
  105 + //logger.Debug("checkHeartBeat ID: %d, last: %d, now: %d", c.GetID(), lastHeartCheckTime, now)
104 106 if math.Abs(float64(lastHeartCheckTime - now)) > common.HeartTimerInterval {
105 107 c.heartTimeoutCount++
106 108 if c.heartTimeoutCount >= common.HeartTimeoutCountMax {
... ...
cmd/gameserver/main.go
... ... @@ -20,7 +20,7 @@ func main() {
20 20  
21 21 s,err1 := NewGameServer(common.GlobalConf.GameConf)
22 22 if err1 != nil {
23   - fmt.Errorf(err1.Error())
  23 + logger.Error(err1)
24 24 return
25 25 }
26 26 go func() {
... ...
cmd/gameserver/plugin/plugin.go
... ... @@ -15,7 +15,7 @@ func init() {
15 15 func GetActionMap() map[interface{}]interface{} {
16 16 logger.Debug("init protocode...")
17 17 am := make(map[interface{}]interface{})
18   - am[uint32(pb.ProtoCode_LoginReq)] = "LoginRpc"
  18 + am[uint32(pb.ProtoCode_LoginReq)] = LoginRpc
19 19  
20 20 return am
21 21 }
... ...
common/components/conn.go
... ... @@ -6,6 +6,7 @@ import (
6 6 "net"
7 7 "pro2d/common"
8 8 "pro2d/common/logger"
  9 + "sync"
9 10 "sync/atomic"
10 11 "time"
11 12 )
... ... @@ -14,7 +15,7 @@ type Connection struct {
14 15 IConnection
15 16 net.Conn
16 17 Server IServer
17   - Id int
  18 + Id uint32
18 19  
19 20 scanner *bufio.Scanner
20 21 writer *bufio.Writer
... ... @@ -31,30 +32,49 @@ type Connection struct {
31 32 Status uint32
32 33 }
33 34  
  35 +var connectionPool = &sync.Pool{
  36 + New: func() interface{} { return new(Connection)},
  37 +}
  38 +
34 39 func NewConn(id int, conn net.Conn, s IServer) *Connection {
35   - c := &Connection{
36   - Id: id,
37   - Conn: conn,
38   - Server: s,
39   -
40   - scanner: bufio.NewScanner(conn),
41   - writer: bufio.NewWriter(conn),
42   - WBuffer: make(chan []byte, common.MaxMsgChan),
43   - Quit: make(chan *Connection),
44   - readFunc: make(chan func(), 10),
45   - timerFunc: make(chan func(), 10),
46   -
47   - Status: 0,
  40 + c := connectionPool.Get().(*Connection)
  41 + closed := atomic.LoadUint32(&c.Status)
  42 + if closed != 0 {
  43 + connectionPool.Put(c)
  44 + c = new(Connection)
48 45 }
49   - c.connectionCallback = c.defaultConnectionCallback
50   - c.messageCallback = c.defaultMessageCallback
51   - c.closeCallback = c.defaultCloseCallback
52   - c.timerCallback = c.defaultTimerCallback
  46 +
  47 + atomic.StoreUint32(&c.Id, uint32(id))
  48 + c.Conn = conn
  49 + c.Server = s
  50 +
  51 + c.scanner = bufio.NewScanner(conn)
  52 + c.writer = bufio.NewWriter(conn)
  53 +
  54 + c.reset()
  55 +
53 56 return c
54 57 }
55 58  
56   -func (c *Connection) GetID() int {
57   - return c.Id
  59 +func (c *Connection) reset() {
  60 + c.WBuffer = make(chan []byte, common.MaxMsgChan)
  61 + c.Quit = make(chan *Connection)
  62 +
  63 + if c.readFunc == nil {
  64 + c.readFunc = make(chan func(), 10)
  65 + }
  66 + if c.timerFunc == nil {
  67 + c.timerFunc = make(chan func(), 10)
  68 + }
  69 +
  70 + //c.connectionCallback = c.defaultConnectionCallback
  71 + //c.messageCallback = c.defaultMessageCallback
  72 + //c.closeCallback = c.defaultCloseCallback
  73 + //c.timerCallback = c.defaultTimerCallback
  74 +}
  75 +
  76 +func (c *Connection) GetID() uint32 {
  77 + return atomic.LoadUint32(&c.Id)
58 78 }
59 79  
60 80 func (c *Connection) SetConnectionCallback(cb ConnectionCallback) {
... ... @@ -78,12 +98,14 @@ func (c *Connection) Start() {
78 98 go c.read()
79 99 go c.listen()
80 100  
81   - c.Status = 1
  101 + atomic.StoreUint32(&c.Status, 1)
82 102 c.connectionCallback(c)
83 103 c.handleTimeOut()
84 104 }
85 105  
86 106 func (c *Connection) Stop() {
  107 + if atomic.LoadUint32(&c.Status) == 0 { return }
  108 +
87 109 sendTimeout := time.NewTimer(5 * time.Millisecond)
88 110 defer sendTimeout.Stop()
89 111 // 发送超时
... ... @@ -125,7 +147,10 @@ func (c *Connection) defaultTimerCallback(conn IConnection) {
125 147 }
126 148  
127 149 func (c *Connection) write() {
128   - defer c.quitting()
  150 + defer func() {
  151 + logger.Debug("write close")
  152 + c.Stop()
  153 + }()
129 154  
130 155 for msg := range c.WBuffer {
131 156 n, err := c.writer.Write(msg)
... ... @@ -141,7 +166,11 @@ func (c *Connection) write() {
141 166 }
142 167  
143 168 func (c *Connection) read() {
144   - defer c.quitting()
  169 + defer func() {
  170 + logger.Debug("read close")
  171 + c.Stop()
  172 + }()
  173 +
145 174 c.scanner.Split(c.Server.GetSplitter().ParseMsg)
146 175  
147 176 for c.scanner.Scan() {
... ... @@ -164,7 +193,10 @@ func (c *Connection) read() {
164 193  
165 194 //此设计目的是为了让网络数据与定时器处理都在一条协程里处理。不想加锁。。。
166 195 func (c *Connection) listen() {
167   - defer c.quitting()
  196 + defer func() {
  197 + logger.Debug("listen close")
  198 + c.quitting()
  199 + }()
168 200  
169 201 for {
170 202 select {
... ... @@ -179,6 +211,8 @@ func (c *Connection) listen() {
179 211 }
180 212  
181 213 func (c *Connection) handleTimeOut() {
  214 + if atomic.LoadUint32(&c.Status) == 0 { return }
  215 +
182 216 c.timerFunc <- func() {
183 217 c.timerCallback(c)
184 218 }
... ... @@ -186,16 +220,16 @@ func (c *Connection) handleTimeOut() {
186 220 }
187 221  
188 222 func (c *Connection) quitting() {
189   - closed := atomic.LoadUint32(&c.Status)
190   - if closed == 0 {
191   - return
192   - }
  223 + if atomic.LoadUint32(&c.Status) == 0 { return }
193 224 atomic.StoreUint32(&c.Status, 0)
194 225  
195 226 logger.Debug("ID: %d close", c.Id)
196 227 close(c.WBuffer)
197 228 close(c.Quit)
198   - close(c.readFunc)
  229 +
199 230 c.Conn.Close()
200 231 c.closeCallback(c)
  232 +
  233 + //放回到对象池
  234 + connectionPool.Put(c)
201 235 }
... ...
common/components/connmanage.go
... ... @@ -4,29 +4,29 @@ import &quot;sync&quot;
4 4  
5 5 type ConnManage struct {
6 6 mu sync.RWMutex
7   - conns map[int]IConnection
  7 + conns map[uint32]IConnection
8 8 }
9 9  
10 10 func NewConnManage() *ConnManage {
11 11 return &ConnManage{
12 12 mu: sync.RWMutex{},
13   - conns: make(map[int]IConnection),
  13 + conns: make(map[uint32]IConnection),
14 14 }
15 15 }
16 16  
17   -func (c *ConnManage) AddConn(id int, connection IConnection) {
  17 +func (c *ConnManage) AddConn(id uint32, connection IConnection) {
18 18 c.mu.Lock()
19 19 defer c.mu.Unlock()
20 20 c.conns[id] = connection
21 21 }
22 22  
23   -func (c *ConnManage) GetConn(id int) IConnection {
  23 +func (c *ConnManage) GetConn(id uint32) IConnection {
24 24 c.mu.RLock()
25 25 defer c.mu.RUnlock()
26 26 return c.conns[id]
27 27 }
28 28  
29   -func (c *ConnManage) DelConn(id int) IConnection {
  29 +func (c *ConnManage) DelConn(id uint32) IConnection {
30 30 c.mu.Lock()
31 31 defer c.mu.Unlock()
32 32 conn := c.conns[id]
... ... @@ -35,14 +35,12 @@ func (c *ConnManage) DelConn(id int) IConnection {
35 35 }
36 36  
37 37 func (c *ConnManage) Range(f func(key interface{}, value interface{}) bool) {
  38 + c.mu.Lock()
  39 + defer c.mu.Unlock()
38 40 for k, v := range c.conns {
39   - c.mu.Lock()
40 41 if ok := f(k, v); !ok {
41   - c.mu.Unlock()
42 42 return
43 43 }
44   - c.mu.Unlock()
45   -
46 44 }
47 45 }
48 46  
... ...
common/components/icompontents.go
... ... @@ -36,7 +36,7 @@ type (
36 36 TimerCallback func(IConnection)
37 37 //链接
38 38 IConnection interface {
39   - GetID() int
  39 + GetID() uint32
40 40 Start()
41 41 Stop()
42 42 Send(code int32, cmd uint32, b []byte) error
... ... @@ -48,9 +48,9 @@ type (
48 48 }
49 49 //connManage
50 50 IConnManage interface {
51   - AddConn(id int, connection IConnection)
52   - GetConn(id int) IConnection
53   - DelConn(id int) IConnection
  51 + AddConn(id uint32, connection IConnection)
  52 + GetConn(id uint32) IConnection
  53 + DelConn(id uint32) IConnection
54 54 Range(f func(key interface{}, value interface{}) bool)
55 55 StopAllConns()
56 56 }
... ...
common/components/server.go
... ... @@ -137,7 +137,7 @@ func (s *Server) Start() error {
137 137  
138 138 func (s *Server) Stop() {
139 139 StopTimer()
140   -
  140 + s.connManage.StopAllConns()
141 141 }
142 142  
143 143 func (s *Server) newConnection(conn IConnection) {
... ... @@ -146,7 +146,7 @@ func (s *Server) newConnection(conn IConnection) {
146 146 conn.SetMessageCallback(s.messageCallback)
147 147 conn.SetTimerCallback(s.timerCallback)
148 148  
149   - go conn.Start()
  149 + conn.Start()
150 150 }
151 151  
152 152 func (s *Server) removeConnection(conn IConnection) {
... ...
common/components/timerwheel.go
... ... @@ -2,7 +2,6 @@ package components
2 2  
3 3 import (
4 4 "container/list"
5   - "pro2d/common"
6 5 "sync"
7 6 "sync/atomic"
8 7 "time"
... ... @@ -16,6 +15,10 @@ const (
16 15 TimeLevel = 1 << TimeLevelShift
17 16 TimeNearMask = TimeNear - 1
18 17 TimeLevelMask = TimeLevel - 1
  18 +
  19 + //协程池 大小
  20 + WorkerPoolSize = 10
  21 + MaxTaskPerWorker = 20
19 22 )
20 23  
21 24 type bucket struct {
... ... @@ -74,14 +77,16 @@ type TimeWheel struct {
74 77  
75 78 WorkPool *WorkPool
76 79 exit chan struct{}
  80 + exitFlag uint32
77 81 }
78 82  
79 83 func NewTimeWheel() *TimeWheel {
80 84 tw := &TimeWheel{
81 85 tick: 10*time.Millisecond,
82 86 time: 0,
83   - WorkPool: NewWorkPool(common.WorkerPoolSize, common.MaxTaskPerWorker),
  87 + WorkPool: NewWorkPool(WorkerPoolSize, MaxTaskPerWorker),
84 88 exit: make(chan struct{}),
  89 + exitFlag: 0,
85 90 }
86 91 for i :=0; i < TimeNear; i++ {
87 92 tw.near[i] = newBucket()
... ... @@ -183,6 +188,12 @@ func (tw *TimeWheel) Start() {
183 188 }
184 189  
185 190 func (tw *TimeWheel) Stop() {
  191 + flag := atomic.LoadUint32(&tw.exitFlag)
  192 + if flag != 0 {
  193 + return
  194 + }
  195 +
  196 + atomic.StoreUint32(&tw.exitFlag, 1)
186 197 close(tw.exit)
187 198 }
188 199  
... ...
common/const.go
1 1 package common
2 2  
3 3 const (
4   - //协程池 大小
5   - WorkerPoolSize = 10
6   - MaxTaskPerWorker = 100
7   -
8 4 //最大包大
9 5 MaxPacketLength = 10 * 1024 * 1024
10 6 MaxMsgChan = 100
11 7  
12   - //jwt
13   - Pro2DTokenSignedString = "Pro2DSecret"
14   -
15   - //定时器
16   - TickMS = 10
17   - WheelSize = 3600
18   -
19 8 //心跳
20 9 HeartTimerInterval = 5 //s
21 10 HeartTimeoutCountMax = 20 //最大超时次数
... ...
doc/plugin.md
... ... @@ -33,6 +33,7 @@ am[uint32(1)] = HotRpc
33 33 2. 增加函数`HotRpc`
34 34 ```
35 35 func HotRpc(msg components.IMessage) (int32, interface{}) {
  36 + return 0, nil
36 37 }
37 38 ```
38 39  
... ...