From b499527ee6af3c7e6f65586cf765161f8f0da537 Mon Sep 17 00:00:00 2001 From: zqj <582132116@qq.com> Date: Sat, 2 Apr 2022 11:39:05 +0800 Subject: [PATCH] feat: 消息包用aes加密 --- cmd/gameserver/game.go | 10 +++++++++- cmd/test/client.go | 12 ++++++------ common/components/aes.go | 108 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ common/components/icompontents.go | 8 +++++++- common/components/pbsplitter.go | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------ common/components/timewheel_test.go | 18 +++++++++++++++--- common/conf.go | 105 +++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------------------------------- conf/conf.yaml | 1 + 8 files changed, 254 insertions(+), 93 deletions(-) create mode 100644 common/components/aes.go diff --git a/cmd/gameserver/game.go b/cmd/gameserver/game.go index b106efa..fd9faaf 100644 --- a/cmd/gameserver/game.go +++ b/cmd/gameserver/game.go @@ -9,6 +9,7 @@ import ( "pro2d/common/components" "pro2d/common/db/mongoproxy" "pro2d/common/db/redisproxy" + "pro2d/common/logger" "pro2d/models" "time" @@ -25,12 +26,19 @@ func NewGameServer(sconf *common.SConf) (*GameServer, error) { options := []components.ServerOption{ components.WithPlugin(components.NewPlugin(sconf.PluginPath)), - components.WithSplitter(components.NewPBSplitter()), + components.WithConnCbk(s.OnConnection), components.WithMsgCbk(s.OnMessage), components.WithCloseCbk(s.OnClose), components.WithTimerCbk(s.OnTimer), } + //加密 + if sconf.Encipher { + options = append(options, components.WithSplitter(components.NewPBSplitter(components.NewAesEncipher()))) + logger.Debug("open encipher aes...") + } else { + options = append(options, components.WithSplitter(components.NewPBSplitter(nil))) + } iserver := components.NewServer(sconf.Port, options...) iserver.SetActions(action.GetActionMap()) diff --git a/cmd/test/client.go b/cmd/test/client.go index e9b7029..3226a33 100644 --- a/cmd/test/client.go +++ b/cmd/test/client.go @@ -19,14 +19,14 @@ func main() { } loginReq := &pb.LoginReq{ - Token: "141815055745814528", + Token: "141815055745814528", Device: "123123", } - l, _ :=proto.Marshal(loginReq) + l, _ := proto.Marshal(loginReq) options := []components.ConnectorOption{ components.WithCtorCount(common.GlobalConf.TestClient.Count), - components.WithCtorSplitter(components.NewPBSplitter()), + components.WithCtorSplitter(components.NewPBSplitter(nil)), } client := components.NewConnector(common.GlobalConf.TestClient.Ip, common.GlobalConf.TestClient.Port, options...) @@ -35,8 +35,8 @@ func main() { return } - for { + for { client.Send(head.Cmd, l) - time.Sleep(1*time.Second) + time.Sleep(1 * time.Second) } -} \ No newline at end of file +} diff --git a/common/components/aes.go b/common/components/aes.go new file mode 100644 index 0000000..2e8abc6 --- /dev/null +++ b/common/components/aes.go @@ -0,0 +1,108 @@ +package components + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "errors" +) + +var ( + PwdKey = []byte("luluszhaozhaoAAA") //RC4 zhaolu2333 +) + +type AesEncipher struct { +} + +func NewAesEncipher() IEncipher { + return &AesEncipher{} +} + +//pkcs7Padding 填充 +func (a *AesEncipher) pkcs7Padding(data []byte, blockSize int) []byte { + //判断缺少几位长度。最少1,最多 blockSize + padding := blockSize - len(data)%blockSize + //补足位数。把切片[]byte{byte(padding)}复制padding个 + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + +//pkcs7UnPadding 填充的反向操作 +func (a *AesEncipher) pkcs7UnPadding(data []byte) ([]byte, error) { + length := len(data) + if length == 0 { + return nil, errors.New("加密字符串错误!") + } + //获取填充的个数 + unPadding := int(data[length-1]) + return data[:(length - unPadding)], nil +} + +//AesEncrypt 加密 +func (a *AesEncipher) Encrypt(data []byte) ([]byte, error) { + //创建加密实例 + block, err := aes.NewCipher(PwdKey) + if err != nil { + return nil, err + } + //判断加密快的大小 + blockSize := block.BlockSize() + //填充 + encryptBytes := a.pkcs7Padding(data, blockSize) + if len(encryptBytes)%blockSize != 0 { + return nil, errors.New("crypto/cipher: input not full blocks") + } + //初始化加密数据接收切片 + crypted := make([]byte, len(encryptBytes)) + //使用cbc加密模式 + blockMode := cipher.NewCBCEncrypter(block, PwdKey[:blockSize]) + //执行加密 + blockMode.CryptBlocks(crypted, encryptBytes) + return crypted, nil +} + +//AesDecrypt 解密 +func (a *AesEncipher) Decrypt(data []byte) ([]byte, error) { + //创建实例 + block, err := aes.NewCipher(PwdKey) + if err != nil { + return nil, err + } + //获取块的大小 + blockSize := block.BlockSize() + //使用cbc + blockMode := cipher.NewCBCDecrypter(block, PwdKey[:blockSize]) + //初始化解密数据接收切片 + dataLen := len(data) + if dataLen%blockSize != 0 { + return nil, errors.New("crypto/cipher: input not full blocks") + } + crypted := make([]byte, dataLen) + //执行解密 + blockMode.CryptBlocks(crypted, data) + //去除填充 + crypted, err = a.pkcs7UnPadding(crypted) + if err != nil { + return nil, err + } + return crypted, nil +} + +//EncryptByAes Aes加密 后 base64 再加 +func (a *AesEncipher) EncryptByAes(data []byte) (string, error) { + res, err := a.Encrypt(data) + if err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(res), nil +} + +//DecryptByAes Aes 解密 +func (a *AesEncipher) DecryptByAes(data string) ([]byte, error) { + dataByte, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return nil, err + } + return a.Decrypt(dataByte) +} diff --git a/common/components/icompontents.go b/common/components/icompontents.go index f90bd2e..907464a 100644 --- a/common/components/icompontents.go +++ b/common/components/icompontents.go @@ -31,6 +31,12 @@ type ( ParseMsg(data []byte, atEOF bool) (advance int, token []byte, err error) GetHeadLen() uint32 } + //加解密 + IEncipher interface { + Encrypt([]byte) ([]byte, error) + Decrypt([]byte) ([]byte, error) + } + ConnectionCallback func(IConnection) CloseCallback func(IConnection) MessageCallback func(IMessage) @@ -127,7 +133,7 @@ type ( SetProperty(key string, val interface{}) SetProperties(properties map[string]interface{}) - ParseFields(message protoreflect.Message ,properties map[string]interface{}) []int32 + ParseFields(message protoreflect.Message, properties map[string]interface{}) []int32 } ) diff --git a/common/components/pbsplitter.go b/common/components/pbsplitter.go index 4f44377..d99c1f6 100644 --- a/common/components/pbsplitter.go +++ b/common/components/pbsplitter.go @@ -8,10 +8,10 @@ import ( ) type PBHead struct { - Length uint32 - Cmd uint32 - ErrCode int32 - PreField uint32 + Length uint32 + Cmd uint32 + ErrCode int32 + PreField uint32 } func (h *PBHead) GetDataLen() uint32 { @@ -38,12 +38,11 @@ type PBMessage struct { conn IConnection } - func (m *PBMessage) GetHeader() IHead { return m.Head } -func (m *PBMessage) SetHeader(header IHead) { +func (m *PBMessage) SetHeader(header IHead) { m.Head = header } func (m *PBMessage) GetData() []byte { @@ -62,31 +61,21 @@ func (m *PBMessage) GetSession() IConnection { return m.conn } +type PBSplitter struct { + encipher IEncipher +} -type PBSplitter struct {} - -func NewPBSplitter() *PBSplitter { - return &PBSplitter{} +func NewPBSplitter(encipher IEncipher) ISplitter { + return &PBSplitter{ + encipher, + } } func (m *PBSplitter) GetHeadLen() uint32 { return uint32(binary.Size(PBHead{})) } -func (m *PBSplitter) UnPack(data []byte) (IMessage,error) { - h := &PBHead{} - err := binary.Read(bytes.NewReader(data), binary.BigEndian, h) - if err != nil { - return nil, err - } - - return &PBMessage{ - Head: h, - Body: data[m.GetHeadLen():], - },nil -} - -func (m *PBSplitter) ParseMsg (data []byte, atEOF bool) (advance int, token []byte, err error) { +func (m *PBSplitter) ParseMsg(data []byte, atEOF bool) (advance int, token []byte, err error) { // 表示我们已经扫描到结尾了 if atEOF && len(data) == 0 { return 0, nil, nil @@ -102,9 +91,9 @@ func (m *PBSplitter) ParseMsg (data []byte, atEOF bool) (advance int, token []by return 0, nil, fmt.Errorf("length exceeds maximum length") } if int(length) <= len(data) { - return int(length) , data[:int(length)], nil + return int(length), data[:int(length)], nil } - return 0 , nil, nil + return 0, nil, nil } if atEOF { return len(data), data, nil @@ -115,20 +104,56 @@ func (m *PBSplitter) ParseMsg (data []byte, atEOF bool) (advance int, token []by func (m *PBSplitter) Pack(cmd uint32, data []byte, errcode int32, preserve uint32) ([]byte, error) { buf := &bytes.Buffer{} h := &PBHead{ - Length: m.GetHeadLen()+ uint32(len(data)), + Length: m.GetHeadLen(), Cmd: cmd, ErrCode: errcode, PreField: preserve, } - err := binary.Write(buf, binary.BigEndian, h) + var dataEn []byte + var err error + if m.encipher != nil { + dataEn, err = m.encipher.Encrypt(data) + if err != nil { + return nil, err + } + } else { + dataEn = data + } + + h.Length += uint32(len(dataEn)) + + err = binary.Write(buf, binary.BigEndian, h) if err != nil { return nil, err } - err = binary.Write(buf, binary.BigEndian, data) + err = binary.Write(buf, binary.BigEndian, dataEn) if err != nil { return nil, err } return buf.Bytes(), nil -} \ No newline at end of file +} + +func (m *PBSplitter) UnPack(data []byte) (IMessage, error) { + h := &PBHead{} + err := binary.Read(bytes.NewReader(data), binary.BigEndian, h) + if err != nil { + return nil, err + } + + var dataDe []byte + if m.encipher != nil { + dataDe, err = m.encipher.Decrypt(data[m.GetHeadLen():]) + if err != nil { + return nil, err + } + } else { + dataDe = data[m.GetHeadLen():] + } + + return &PBMessage{ + Head: h, + Body: dataDe, + }, nil +} diff --git a/common/components/timewheel_test.go b/common/components/timewheel_test.go index aafe050..9dc339e 100644 --- a/common/components/timewheel_test.go +++ b/common/components/timewheel_test.go @@ -6,13 +6,25 @@ import ( "time" ) -func PRINT() { +func PRINT() { fmt.Println("12312312312") } func TestTimeWheel_Start(t *testing.T) { - TimeOut(1 * time.Second, func() { + TimeOut(1*time.Second, func() { fmt.Println("12312313123") }) - select{} + select {} +} + +func TestAesEncipher_Decrypt(t *testing.T) { + aes := AesEncipher{} + encode, err := aes.Encrypt([]byte("123")) + if err != nil { + fmt.Println(err.Error()) + return + } + + dec, err := aes.Decrypt(encode) + fmt.Printf("%s\n", dec) } diff --git a/common/conf.go b/common/conf.go index 5a43c3a..bd37c8d 100644 --- a/common/conf.go +++ b/common/conf.go @@ -12,86 +12,87 @@ import ( type RedisConf struct { Address string `json:"address"` - Auth string `json:"auth"` - DB int `json:"db"` + Auth string `json:"auth"` + DB int `json:"db"` } type Etcd struct { - Endpoints []string `json:"endpoints"` - DialTimeout int `json:"dialtimeout"` + Endpoints []string `json:"endpoints"` + DialTimeout int `json:"dialtimeout"` } type MongoConf struct { - User string `yaml:"user"` - Password string `yaml:"password"` - Host string `yaml:"host"` - Port int `yaml:"port"` - TimeOut int `yaml:"timeout"` - MaxNum int `yaml:"maxnum"` - DBName string `yaml:"dbname"` + User string `yaml:"user"` + Password string `yaml:"password"` + Host string `yaml:"host"` + Port int `yaml:"port"` + TimeOut int `yaml:"timeout"` + MaxNum int `yaml:"maxnum"` + DBName string `yaml:"dbname"` } type SConf struct { - ID string `yaml:"id"` - Name string `yaml:"name"` - IP string `yaml:"ip"` - Port int `yaml:"port"` - DebugPort int `yaml:"debugport"` - MongoConf *MongoConf `yaml:"mongo"` - RedisConf *RedisConf `yaml:"redis"` - WorkerPoolSize int `yaml:"pool_size"` - PluginPath string `yaml:"plugin_path"` + ID string `yaml:"id"` + Name string `yaml:"name"` + IP string `yaml:"ip"` + Port int `yaml:"port"` + Encipher bool `yaml:"encipher"` + DebugPort int `yaml:"debugport"` + MongoConf *MongoConf `yaml:"mongo"` + RedisConf *RedisConf `yaml:"redis"` + WorkerPoolSize int `yaml:"pool_size"` + PluginPath string `yaml:"plugin_path"` } type LogConsole struct { - Level string `yaml:"level" json:"level"` - Color bool `yaml:"color" json:"color"` + Level string `yaml:"level" json:"level"` + Color bool `yaml:"color" json:"color"` } type LogFile struct { - Level string `yaml:"level" json:"level"` - Daily bool `yaml:"daily" json:"daily"` - Maxlines int `yaml:"maxlines" json:"maxlines"` - Maxsize int `yaml:"maxsize" json:"maxsize"` - Maxdays int `yaml:"maxdays" json:"maxdays"` - Append bool `yaml:"append" json:"append"` - Permit string `yaml:"permit" json:"permit"` + Level string `yaml:"level" json:"level"` + Daily bool `yaml:"daily" json:"daily"` + Maxlines int `yaml:"maxlines" json:"maxlines"` + Maxsize int `yaml:"maxsize" json:"maxsize"` + Maxdays int `yaml:"maxdays" json:"maxdays"` + Append bool `yaml:"append" json:"append"` + Permit string `yaml:"permit" json:"permit"` } type LogConn struct { - Net string `yaml:"net" json:"net"` - Addr string `yaml:"addr" json:"addr"` - Level string `yaml:"level" json:"level"` - Reconnect bool `yaml:"reconnect" json:"reconnect"` - ReconnectOnMsg bool `yaml:"reconnectOnMsg" json:"reconnectOnMsg"` + Net string `yaml:"net" json:"net"` + Addr string `yaml:"addr" json:"addr"` + Level string `yaml:"level" json:"level"` + Reconnect bool `yaml:"reconnect" json:"reconnect"` + ReconnectOnMsg bool `yaml:"reconnectOnMsg" json:"reconnectOnMsg"` } type LogConf struct { - TimeFormat string `yaml:"TimeFormat" json:"TimeFormat"` - LogConsole *LogConsole `yaml:"Console" json:"Console"` - LogFile *LogFile `yaml:"File" json:"File"` - LogConn *LogConn `yaml:"Conn" json:"Conn"` + TimeFormat string `yaml:"TimeFormat" json:"TimeFormat"` + LogConsole *LogConsole `yaml:"Console" json:"Console"` + LogFile *LogFile `yaml:"File" json:"File"` + LogConn *LogConn `yaml:"Conn" json:"Conn"` } type TestClient struct { - Ip string `yaml:"ip"` - Port int`yaml:"port"` - Count int `yaml:"count"` + Ip string `yaml:"ip"` + Port int `yaml:"port"` + Count int `yaml:"count"` } type ServerConf struct { - ID string `yaml:"id"` - Name string `yaml:"name"` - WorkerID int64 `yaml:"workerid"` - DatacenterID int64 `yaml:"datacenterid"` - AccountConf *SConf `yaml:"server_account"` - GameConf *SConf `yaml:"server_game"` - LogConf *LogConf `yaml:"logconf" json:"logconf"` - TestClient *TestClient `yaml:"test_client"` - Etcd *Etcd `yaml:"etcd"` + ID string `yaml:"id"` + Name string `yaml:"name"` + WorkerID int64 `yaml:"workerid"` + DatacenterID int64 `yaml:"datacenterid"` + AccountConf *SConf `yaml:"server_account"` + GameConf *SConf `yaml:"server_game"` + LogConf *LogConf `yaml:"logconf" json:"logconf"` + TestClient *TestClient `yaml:"test_client"` + Etcd *Etcd `yaml:"etcd"` } -var( +var ( GlobalConf ServerConf SnowFlack *snow.Snowflake ) @@ -122,4 +123,4 @@ func init() { //初始化雪花算法 SnowFlack = snow.NewSnowflake(GlobalConf.WorkerID, GlobalConf.DatacenterID) -} \ No newline at end of file +} diff --git a/conf/conf.yaml b/conf/conf.yaml index 38cbbdf..9fadefb 100644 --- a/conf/conf.yaml +++ b/conf/conf.yaml @@ -35,6 +35,7 @@ server_game: id: "2" name: "game" ip: "192.168.0.206" + encipher: true port: 8850 pool_size: 1 debugport: 6061 -- libgit2 0.21.2