// agent.go 3.12 KB
package service

import (
	"github.com/golang/protobuf/proto"
	"math"
	"pro2d/common"
	"pro2d/common/components"
	"pro2d/common/logger"
	"pro2d/models"
	"pro2d/pb"
	"sync"
	"sync/atomic"
)

// Agent binds one client connection to the server that owns it and, once
// SetSchema has been called, to the role model of the logged-in player.
// It also carries the bookkeeping used by the periodic heartbeat check.
type Agent struct {
	// IConnection is the underlying connection; it is assigned in
	// OnConnection and cleared in OnClose.
	components.IConnection
	// Server is the server this agent was created for.
	Server components.IServer
	components.IAgent

	// Role is the role model attached through SetSchema; nil until then.
	Role               *models.RoleModel
	// nextCheckTime is the common.Timex() value at or after which OnTimer
	// runs the next heartbeat check. Accessed atomically in OnTimer.
	nextCheckTime      int64
	// lastHeartCheckTime is the common.Timex() value recorded when the agent
	// was created or when the most recent message arrived. Accessed
	// atomically in OnMessage and checkHeartBeat.
	lastHeartCheckTime int64
	// heartTimeoutCount is the number of consecutive heartbeat checks that
	// found the connection silent for too long.
	heartTimeoutCount  int
}

// agentPool recycles Agent values: NewAgent takes agents from it and OnClose
// hands them back.
var agentPool = sync.Pool{
	New: func() interface{} {
		return &Agent{}
	},
}

// NewAgent takes an Agent from agentPool, attaches it to server s and resets
// its heartbeat bookkeeping: the last-heartbeat time becomes the current
// common.Timex() value, the timeout counter is cleared and the next check
// time is zeroed.
func NewAgent(s components.IServer) *Agent {
	agent := agentPool.Get().(*Agent)

	agent.Server = s
	agent.heartTimeoutCount = 0
	agent.lastHeartCheckTime = common.Timex()
	agent.nextCheckTime = 0

	return agent
}

// SetSchema attaches the given schema to the agent as its role and registers
// the role ID → connection ID mapping with the server's connection manager.
//
// schema must hold a *models.RoleModel; any other dynamic type makes the type
// assertion panic.
func (c *Agent) SetSchema(schema components.ISchema) {
	role := schema.(*models.RoleModel)
	c.Role = role

	connMgr := c.Server.GetConnManage()
	connMgr.AddRID(role.Role.Id, c.IConnection.GetID())
}

// GetSchema returns the agent's role as a components.ISchema.
//
// The result always wraps a *models.RoleModel, even when no role has been
// set, so the returned interface value itself is never nil. Callers should
// type-assert to *models.RoleModel and nil-check the pointer rather than
// comparing the interface with nil.
func (c *Agent) GetSchema() components.ISchema {
	var schema components.ISchema = c.Role
	return schema
}

// SetServer replaces the server this agent is attached to.
func (c *Agent) SetServer(server components.IServer) {
	c.Server = server
}

// GetServer returns the server this agent is attached to.
func (c *Agent) GetServer() components.IServer {
	return c.Server
}

// OnConnection stores conn as the agent's underlying connection, making the
// embedded IConnection methods available on the agent.
func (c *Agent) OnConnection(conn components.IConnection) {
	c.IConnection = conn
}

// OnMessage handles one inbound message from the client.
//
// Every message, heartbeat requests included, refreshes lastHeartCheckTime.
// Heartbeat requests are then dropped without a reply. Any other message is
// dispatched to the handler the server has registered for its message ID and
// the handler's result is sent back:
//   - nil payload, or a non-zero error code: the code is sent with no body;
//   - otherwise the payload is protobuf-encoded and sent along with the code;
//   - a payload that is not a proto.Message, or that fails to encode, is
//     answered with code -100.
//
// A missing handler, or one whose signature is not the expected
// func(components.IAgent, components.IMessage) (int32, interface{}), is
// logged and the message is dropped without a reply.
func (c *Agent) OnMessage(msg components.IMessage) {
	atomic.StoreInt64(&c.lastHeartCheckTime, common.Timex())

	msgID := msg.GetHeader().GetMsgID()
	if msgID == uint32(pb.ProtoCode_HeartReq) {
		return
	}

	md := c.Server.GetAction(msgID)
	if md == nil {
		logger.Debug("cmd: %d, handler is nil", msgID)
		return
	}
	// Checked assertion: a mis-registered handler must not panic the
	// goroutine that is processing this connection.
	handler, ok := md.(func(agent components.IAgent, msg components.IMessage) (int32, interface{}))
	if !ok {
		logger.Error("cmd: %d, handler has unexpected type %T", msgID, md)
		return
	}
	logger.Debug("protocolID: %d", msgID)

	errCode, protoMsg := handler(c, msg)

	if protoMsg == nil {
		c.Send(errCode, msgID, nil)
		return
	}

	if errCode != 0 {
		logger.Error("errCode %d, msg: %v", errCode, protoMsg)
		c.Send(errCode, msgID, nil)
		return
	}

	// Checked assertion: a handler returning a non-protobuf payload is
	// reported to the client with the same code as an encoding failure.
	pbMsg, ok := protoMsg.(proto.Message)
	if !ok {
		logger.Error("cmd: %d, response has unexpected type %T", msgID, protoMsg)
		c.Send(-100, msgID, nil)
		return
	}

	rsp, err := proto.Marshal(pbMsg)
	if err != nil {
		logger.Error("cmd: %d, marshal response: %v", msgID, err)
		c.Send(-100, msgID, nil)
		return
	}
	c.Send(errCode, msgID, rsp)
}

// OnTimer is the agent's periodic tick.
//
// When the current time has reached nextCheckTime it runs the heartbeat check
// and schedules the next one common.HeartTimerInterval later. Independently
// of that, if a role is attached, the role's recover timer is invoked with
// the current time on every tick.
func (c *Agent) OnTimer() {
	due := atomic.LoadInt64(&c.nextCheckTime)
	now := common.Timex()

	if now >= due {
		// Check the heartbeat, then push the next check one interval out.
		c.checkHeartBeat(now)
		atomic.StoreInt64(&c.nextCheckTime, now+common.HeartTimerInterval)
	}

	if c.Role == nil {
		return
	}
	// Let the role recover its data.
	c.Role.OnRecoverTimer(now)
}

// OnClose tears the agent down when its connection is closed.
//
// If a role is attached, its role ID mapping is removed from the server's
// connection manager and the role's offline event is fired. Afterwards the
// connection and role references are cleared and the agent is returned to
// agentPool.
//
// The offline handling must run before the fields are cleared, and the agent
// must go back to the pool last: once it has been Put, another goroutine may
// obtain it through NewAgent, so it must not be touched any more.
func (c *Agent) OnClose() {
	if role := c.Role; role != nil {
		c.Server.GetConnManage().DelRID(role.Role.Id)
		role.OnOfflineEvent()
	}

	c.IConnection = nil
	c.Role = nil
	agentPool.Put(c)
}

// checkHeartBeat compares now with the time of the last received message.
//
// If the two are no more than common.HeartTimerInterval apart, the timeout
// counter is reset. Otherwise the counter is incremented, and once it reaches
// common.HeartTimeoutCountMax the connection is stopped.
func (c *Agent) checkHeartBeat(now int64) {
	last := atomic.LoadInt64(&c.lastHeartCheckTime)
	gap := math.Abs(float64(last - now))

	if gap <= common.HeartTimerInterval {
		// Heard from the client recently enough: start counting afresh.
		c.heartTimeoutCount = 0
		return
	}

	c.heartTimeoutCount++
	if c.heartTimeoutCount >= common.HeartTimeoutCountMax {
		c.Stop()
		return
	}
	logger.Debug("timeout count: %d", c.heartTimeoutCount)
}