Commit b499527ee6af3c7e6f65586cf765161f8f0da537

Authored by zhangqijia
1 parent d3faacd0

feat: 消息包用aes加密

cmd/gameserver/game.go
... ... @@ -9,6 +9,7 @@ import (
9 9 "pro2d/common/components"
10 10 "pro2d/common/db/mongoproxy"
11 11 "pro2d/common/db/redisproxy"
  12 + "pro2d/common/logger"
12 13 "pro2d/models"
13 14 "time"
14 15  
... ... @@ -25,12 +26,19 @@ func NewGameServer(sconf *common.SConf) (*GameServer, error) {
25 26  
26 27 options := []components.ServerOption{
27 28 components.WithPlugin(components.NewPlugin(sconf.PluginPath)),
28   - components.WithSplitter(components.NewPBSplitter()),
  29 +
29 30 components.WithConnCbk(s.OnConnection),
30 31 components.WithMsgCbk(s.OnMessage),
31 32 components.WithCloseCbk(s.OnClose),
32 33 components.WithTimerCbk(s.OnTimer),
33 34 }
  35 + //加密
  36 + if sconf.Encipher {
  37 + options = append(options, components.WithSplitter(components.NewPBSplitter(components.NewAesEncipher())))
  38 + logger.Debug("open encipher aes...")
  39 + } else {
  40 + options = append(options, components.WithSplitter(components.NewPBSplitter(nil)))
  41 + }
34 42  
35 43 iserver := components.NewServer(sconf.Port, options...)
36 44 iserver.SetActions(action.GetActionMap())
... ...
cmd/test/client.go
... ... @@ -19,14 +19,14 @@ func main() {
19 19 }
20 20  
21 21 loginReq := &pb.LoginReq{
22   - Token: "141815055745814528",
  22 + Token: "141815055745814528",
23 23 Device: "123123",
24 24 }
25   - l, _ :=proto.Marshal(loginReq)
  25 + l, _ := proto.Marshal(loginReq)
26 26  
27 27 options := []components.ConnectorOption{
28 28 components.WithCtorCount(common.GlobalConf.TestClient.Count),
29   - components.WithCtorSplitter(components.NewPBSplitter()),
  29 + components.WithCtorSplitter(components.NewPBSplitter(nil)),
30 30 }
31 31  
32 32 client := components.NewConnector(common.GlobalConf.TestClient.Ip, common.GlobalConf.TestClient.Port, options...)
... ... @@ -35,8 +35,8 @@ func main() {
35 35 return
36 36 }
37 37  
38   - for {
  38 + for {
39 39 client.Send(head.Cmd, l)
40   - time.Sleep(1*time.Second)
  40 + time.Sleep(1 * time.Second)
41 41 }
42   -}
43 42 \ No newline at end of file
  43 +}
... ...
common/components/aes.go 0 → 100644
... ... @@ -0,0 +1,108 @@
  1 +package components
  2 +
  3 +import (
  4 + "bytes"
  5 + "crypto/aes"
  6 + "crypto/cipher"
  7 + "encoding/base64"
  8 + "errors"
  9 +)
  10 +
  11 +var (
  12 + PwdKey = []byte("luluszhaozhaoAAA") //RC4 zhaolu2333
  13 +)
  14 +
  15 +type AesEncipher struct {
  16 +}
  17 +
  18 +func NewAesEncipher() IEncipher {
  19 + return &AesEncipher{}
  20 +}
  21 +
  22 +//pkcs7Padding 填充
  23 +func (a *AesEncipher) pkcs7Padding(data []byte, blockSize int) []byte {
  24 + //判断缺少几位长度。最少1,最多 blockSize
  25 + padding := blockSize - len(data)%blockSize
  26 + //补足位数。把切片[]byte{byte(padding)}复制padding个
  27 + padText := bytes.Repeat([]byte{byte(padding)}, padding)
  28 + return append(data, padText...)
  29 +}
  30 +
  31 +//pkcs7UnPadding 填充的反向操作
  32 +func (a *AesEncipher) pkcs7UnPadding(data []byte) ([]byte, error) {
  33 + length := len(data)
  34 + if length == 0 {
  35 + return nil, errors.New("加密字符串错误!")
  36 + }
  37 + //获取填充的个数
  38 + unPadding := int(data[length-1])
  39 + return data[:(length - unPadding)], nil
  40 +}
  41 +
  42 +//AesEncrypt 加密
  43 +func (a *AesEncipher) Encrypt(data []byte) ([]byte, error) {
  44 + //创建加密实例
  45 + block, err := aes.NewCipher(PwdKey)
  46 + if err != nil {
  47 + return nil, err
  48 + }
  49 + //判断加密快的大小
  50 + blockSize := block.BlockSize()
  51 + //填充
  52 + encryptBytes := a.pkcs7Padding(data, blockSize)
  53 + if len(encryptBytes)%blockSize != 0 {
  54 + return nil, errors.New("crypto/cipher: input not full blocks")
  55 + }
  56 + //初始化加密数据接收切片
  57 + crypted := make([]byte, len(encryptBytes))
  58 + //使用cbc加密模式
  59 + blockMode := cipher.NewCBCEncrypter(block, PwdKey[:blockSize])
  60 + //执行加密
  61 + blockMode.CryptBlocks(crypted, encryptBytes)
  62 + return crypted, nil
  63 +}
  64 +
  65 +//AesDecrypt 解密
  66 +func (a *AesEncipher) Decrypt(data []byte) ([]byte, error) {
  67 + //创建实例
  68 + block, err := aes.NewCipher(PwdKey)
  69 + if err != nil {
  70 + return nil, err
  71 + }
  72 + //获取块的大小
  73 + blockSize := block.BlockSize()
  74 + //使用cbc
  75 + blockMode := cipher.NewCBCDecrypter(block, PwdKey[:blockSize])
  76 + //初始化解密数据接收切片
  77 + dataLen := len(data)
  78 + if dataLen%blockSize != 0 {
  79 + return nil, errors.New("crypto/cipher: input not full blocks")
  80 + }
  81 + crypted := make([]byte, dataLen)
  82 + //执行解密
  83 + blockMode.CryptBlocks(crypted, data)
  84 + //去除填充
  85 + crypted, err = a.pkcs7UnPadding(crypted)
  86 + if err != nil {
  87 + return nil, err
  88 + }
  89 + return crypted, nil
  90 +}
  91 +
  92 +//EncryptByAes Aes加密 后 base64 再加
  93 +func (a *AesEncipher) EncryptByAes(data []byte) (string, error) {
  94 + res, err := a.Encrypt(data)
  95 + if err != nil {
  96 + return "", err
  97 + }
  98 + return base64.StdEncoding.EncodeToString(res), nil
  99 +}
  100 +
  101 +//DecryptByAes Aes 解密
  102 +func (a *AesEncipher) DecryptByAes(data string) ([]byte, error) {
  103 + dataByte, err := base64.StdEncoding.DecodeString(data)
  104 + if err != nil {
  105 + return nil, err
  106 + }
  107 + return a.Decrypt(dataByte)
  108 +}
... ...
common/components/icompontents.go
... ... @@ -31,6 +31,12 @@ type (
31 31 ParseMsg(data []byte, atEOF bool) (advance int, token []byte, err error)
32 32 GetHeadLen() uint32
33 33 }
  34 + //加解密
  35 + IEncipher interface {
  36 + Encrypt([]byte) ([]byte, error)
  37 + Decrypt([]byte) ([]byte, error)
  38 + }
  39 +
34 40 ConnectionCallback func(IConnection)
35 41 CloseCallback func(IConnection)
36 42 MessageCallback func(IMessage)
... ... @@ -127,7 +133,7 @@ type (
127 133  
128 134 SetProperty(key string, val interface{})
129 135 SetProperties(properties map[string]interface{})
130   - ParseFields(message protoreflect.Message ,properties map[string]interface{}) []int32
  136 + ParseFields(message protoreflect.Message, properties map[string]interface{}) []int32
131 137 }
132 138 )
133 139  
... ...
common/components/pbsplitter.go
... ... @@ -8,10 +8,10 @@ import (
8 8 )
9 9  
10 10 type PBHead struct {
11   - Length uint32
12   - Cmd uint32
13   - ErrCode int32
14   - PreField uint32
  11 + Length uint32
  12 + Cmd uint32
  13 + ErrCode int32
  14 + PreField uint32
15 15 }
16 16  
17 17 func (h *PBHead) GetDataLen() uint32 {
... ... @@ -38,12 +38,11 @@ type PBMessage struct {
38 38 conn IConnection
39 39 }
40 40  
41   -
42 41 func (m *PBMessage) GetHeader() IHead {
43 42 return m.Head
44 43 }
45 44  
46   -func (m *PBMessage) SetHeader(header IHead) {
  45 +func (m *PBMessage) SetHeader(header IHead) {
47 46 m.Head = header
48 47 }
49 48 func (m *PBMessage) GetData() []byte {
... ... @@ -62,31 +61,21 @@ func (m *PBMessage) GetSession() IConnection {
62 61 return m.conn
63 62 }
64 63  
  64 +type PBSplitter struct {
  65 + encipher IEncipher
  66 +}
65 67  
66   -type PBSplitter struct {}
67   -
68   -func NewPBSplitter() *PBSplitter {
69   - return &PBSplitter{}
  68 +func NewPBSplitter(encipher IEncipher) ISplitter {
  69 + return &PBSplitter{
  70 + encipher,
  71 + }
70 72 }
71 73  
72 74 func (m *PBSplitter) GetHeadLen() uint32 {
73 75 return uint32(binary.Size(PBHead{}))
74 76 }
75 77  
76   -func (m *PBSplitter) UnPack(data []byte) (IMessage,error) {
77   - h := &PBHead{}
78   - err := binary.Read(bytes.NewReader(data), binary.BigEndian, h)
79   - if err != nil {
80   - return nil, err
81   - }
82   -
83   - return &PBMessage{
84   - Head: h,
85   - Body: data[m.GetHeadLen():],
86   - },nil
87   -}
88   -
89   -func (m *PBSplitter) ParseMsg (data []byte, atEOF bool) (advance int, token []byte, err error) {
  78 +func (m *PBSplitter) ParseMsg(data []byte, atEOF bool) (advance int, token []byte, err error) {
90 79 // 表示我们已经扫描到结尾了
91 80 if atEOF && len(data) == 0 {
92 81 return 0, nil, nil
... ... @@ -102,9 +91,9 @@ func (m *PBSplitter) ParseMsg (data []byte, atEOF bool) (advance int, token []by
102 91 return 0, nil, fmt.Errorf("length exceeds maximum length")
103 92 }
104 93 if int(length) <= len(data) {
105   - return int(length) , data[:int(length)], nil
  94 + return int(length), data[:int(length)], nil
106 95 }
107   - return 0 , nil, nil
  96 + return 0, nil, nil
108 97 }
109 98 if atEOF {
110 99 return len(data), data, nil
... ... @@ -115,20 +104,56 @@ func (m *PBSplitter) ParseMsg (data []byte, atEOF bool) (advance int, token []by
115 104 func (m *PBSplitter) Pack(cmd uint32, data []byte, errcode int32, preserve uint32) ([]byte, error) {
116 105 buf := &bytes.Buffer{}
117 106 h := &PBHead{
118   - Length: m.GetHeadLen()+ uint32(len(data)),
  107 + Length: m.GetHeadLen(),
119 108 Cmd: cmd,
120 109 ErrCode: errcode,
121 110 PreField: preserve,
122 111 }
123   - err := binary.Write(buf, binary.BigEndian, h)
  112 + var dataEn []byte
  113 + var err error
  114 + if m.encipher != nil {
  115 + dataEn, err = m.encipher.Encrypt(data)
  116 + if err != nil {
  117 + return nil, err
  118 + }
  119 + } else {
  120 + dataEn = data
  121 + }
  122 +
  123 + h.Length += uint32(len(dataEn))
  124 +
  125 + err = binary.Write(buf, binary.BigEndian, h)
124 126 if err != nil {
125 127 return nil, err
126 128 }
127 129  
128   - err = binary.Write(buf, binary.BigEndian, data)
  130 + err = binary.Write(buf, binary.BigEndian, dataEn)
129 131 if err != nil {
130 132 return nil, err
131 133 }
132 134  
133 135 return buf.Bytes(), nil
134   -}
135 136 \ No newline at end of file
  137 +}
  138 +
  139 +func (m *PBSplitter) UnPack(data []byte) (IMessage, error) {
  140 + h := &PBHead{}
  141 + err := binary.Read(bytes.NewReader(data), binary.BigEndian, h)
  142 + if err != nil {
  143 + return nil, err
  144 + }
  145 +
  146 + var dataDe []byte
  147 + if m.encipher != nil {
  148 + dataDe, err = m.encipher.Decrypt(data[m.GetHeadLen():])
  149 + if err != nil {
  150 + return nil, err
  151 + }
  152 + } else {
  153 + dataDe = data[m.GetHeadLen():]
  154 + }
  155 +
  156 + return &PBMessage{
  157 + Head: h,
  158 + Body: dataDe,
  159 + }, nil
  160 +}
... ...
common/components/timewheel_test.go
... ... @@ -6,13 +6,25 @@ import (
6 6 "time"
7 7 )
8 8  
9   -func PRINT() {
  9 +func PRINT() {
10 10 fmt.Println("12312312312")
11 11 }
12 12  
13 13 func TestTimeWheel_Start(t *testing.T) {
14   - TimeOut(1 * time.Second, func() {
  14 + TimeOut(1*time.Second, func() {
15 15 fmt.Println("12312313123")
16 16 })
17   - select{}
  17 + select {}
  18 +}
  19 +
  20 +func TestAesEncipher_Decrypt(t *testing.T) {
  21 + aes := AesEncipher{}
  22 + encode, err := aes.Encrypt([]byte("123"))
  23 + if err != nil {
  24 + fmt.Println(err.Error())
  25 + return
  26 + }
  27 +
  28 + dec, err := aes.Decrypt(encode)
  29 + fmt.Printf("%s\n", dec)
18 30 }
... ...
common/conf.go
... ... @@ -12,86 +12,87 @@ import (
12 12  
13 13 type RedisConf struct {
14 14 Address string `json:"address"`
15   - Auth string `json:"auth"`
16   - DB int `json:"db"`
  15 + Auth string `json:"auth"`
  16 + DB int `json:"db"`
17 17 }
18 18  
19 19 type Etcd struct {
20   - Endpoints []string `json:"endpoints"`
21   - DialTimeout int `json:"dialtimeout"`
  20 + Endpoints []string `json:"endpoints"`
  21 + DialTimeout int `json:"dialtimeout"`
22 22 }
23 23  
24 24 type MongoConf struct {
25   - User string `yaml:"user"`
26   - Password string `yaml:"password"`
27   - Host string `yaml:"host"`
28   - Port int `yaml:"port"`
29   - TimeOut int `yaml:"timeout"`
30   - MaxNum int `yaml:"maxnum"`
31   - DBName string `yaml:"dbname"`
  25 + User string `yaml:"user"`
  26 + Password string `yaml:"password"`
  27 + Host string `yaml:"host"`
  28 + Port int `yaml:"port"`
  29 + TimeOut int `yaml:"timeout"`
  30 + MaxNum int `yaml:"maxnum"`
  31 + DBName string `yaml:"dbname"`
32 32 }
33 33  
34 34 type SConf struct {
35   - ID string `yaml:"id"`
36   - Name string `yaml:"name"`
37   - IP string `yaml:"ip"`
38   - Port int `yaml:"port"`
39   - DebugPort int `yaml:"debugport"`
40   - MongoConf *MongoConf `yaml:"mongo"`
41   - RedisConf *RedisConf `yaml:"redis"`
42   - WorkerPoolSize int `yaml:"pool_size"`
43   - PluginPath string `yaml:"plugin_path"`
  35 + ID string `yaml:"id"`
  36 + Name string `yaml:"name"`
  37 + IP string `yaml:"ip"`
  38 + Port int `yaml:"port"`
  39 + Encipher bool `yaml:"encipher"`
  40 + DebugPort int `yaml:"debugport"`
  41 + MongoConf *MongoConf `yaml:"mongo"`
  42 + RedisConf *RedisConf `yaml:"redis"`
  43 + WorkerPoolSize int `yaml:"pool_size"`
  44 + PluginPath string `yaml:"plugin_path"`
44 45 }
45 46  
46 47 type LogConsole struct {
47   - Level string `yaml:"level" json:"level"`
48   - Color bool `yaml:"color" json:"color"`
  48 + Level string `yaml:"level" json:"level"`
  49 + Color bool `yaml:"color" json:"color"`
49 50 }
50 51  
51 52 type LogFile struct {
52   - Level string `yaml:"level" json:"level"`
53   - Daily bool `yaml:"daily" json:"daily"`
54   - Maxlines int `yaml:"maxlines" json:"maxlines"`
55   - Maxsize int `yaml:"maxsize" json:"maxsize"`
56   - Maxdays int `yaml:"maxdays" json:"maxdays"`
57   - Append bool `yaml:"append" json:"append"`
58   - Permit string `yaml:"permit" json:"permit"`
  53 + Level string `yaml:"level" json:"level"`
  54 + Daily bool `yaml:"daily" json:"daily"`
  55 + Maxlines int `yaml:"maxlines" json:"maxlines"`
  56 + Maxsize int `yaml:"maxsize" json:"maxsize"`
  57 + Maxdays int `yaml:"maxdays" json:"maxdays"`
  58 + Append bool `yaml:"append" json:"append"`
  59 + Permit string `yaml:"permit" json:"permit"`
59 60 }
60 61  
61 62 type LogConn struct {
62   - Net string `yaml:"net" json:"net"`
63   - Addr string `yaml:"addr" json:"addr"`
64   - Level string `yaml:"level" json:"level"`
65   - Reconnect bool `yaml:"reconnect" json:"reconnect"`
66   - ReconnectOnMsg bool `yaml:"reconnectOnMsg" json:"reconnectOnMsg"`
  63 + Net string `yaml:"net" json:"net"`
  64 + Addr string `yaml:"addr" json:"addr"`
  65 + Level string `yaml:"level" json:"level"`
  66 + Reconnect bool `yaml:"reconnect" json:"reconnect"`
  67 + ReconnectOnMsg bool `yaml:"reconnectOnMsg" json:"reconnectOnMsg"`
67 68 }
68 69  
69 70 type LogConf struct {
70   - TimeFormat string `yaml:"TimeFormat" json:"TimeFormat"`
71   - LogConsole *LogConsole `yaml:"Console" json:"Console"`
72   - LogFile *LogFile `yaml:"File" json:"File"`
73   - LogConn *LogConn `yaml:"Conn" json:"Conn"`
  71 + TimeFormat string `yaml:"TimeFormat" json:"TimeFormat"`
  72 + LogConsole *LogConsole `yaml:"Console" json:"Console"`
  73 + LogFile *LogFile `yaml:"File" json:"File"`
  74 + LogConn *LogConn `yaml:"Conn" json:"Conn"`
74 75 }
75 76  
76 77 type TestClient struct {
77   - Ip string `yaml:"ip"`
78   - Port int`yaml:"port"`
79   - Count int `yaml:"count"`
  78 + Ip string `yaml:"ip"`
  79 + Port int `yaml:"port"`
  80 + Count int `yaml:"count"`
80 81 }
81 82  
82 83 type ServerConf struct {
83   - ID string `yaml:"id"`
84   - Name string `yaml:"name"`
85   - WorkerID int64 `yaml:"workerid"`
86   - DatacenterID int64 `yaml:"datacenterid"`
87   - AccountConf *SConf `yaml:"server_account"`
88   - GameConf *SConf `yaml:"server_game"`
89   - LogConf *LogConf `yaml:"logconf" json:"logconf"`
90   - TestClient *TestClient `yaml:"test_client"`
91   - Etcd *Etcd `yaml:"etcd"`
  84 + ID string `yaml:"id"`
  85 + Name string `yaml:"name"`
  86 + WorkerID int64 `yaml:"workerid"`
  87 + DatacenterID int64 `yaml:"datacenterid"`
  88 + AccountConf *SConf `yaml:"server_account"`
  89 + GameConf *SConf `yaml:"server_game"`
  90 + LogConf *LogConf `yaml:"logconf" json:"logconf"`
  91 + TestClient *TestClient `yaml:"test_client"`
  92 + Etcd *Etcd `yaml:"etcd"`
92 93 }
93 94  
94   -var(
  95 +var (
95 96 GlobalConf ServerConf
96 97 SnowFlack *snow.Snowflake
97 98 )
... ... @@ -122,4 +123,4 @@ func init() {
122 123  
123 124 //初始化雪花算法
124 125 SnowFlack = snow.NewSnowflake(GlobalConf.WorkerID, GlobalConf.DatacenterID)
125   -}
126 126 \ No newline at end of file
  127 +}
... ...
conf/conf.yaml
... ... @@ -35,6 +35,7 @@ server_game:
35 35 id: "2"
36 36 name: "game"
37 37 ip: "192.168.0.206"
  38 + encipher: true
38 39 port: 8850
39 40 pool_size: 1
40 41 debugport: 6061
... ...