Blame view

publish/skynet/lualib/skynet/socketchannel.lua 11.2 KB
4d6f285d   zhouhaihai   增加发布功能
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
  local skynet = require "skynet"
  local socket = require "skynet.socket"
  local socketdriver = require "skynet.socketdriver"
  
  -- A socket channel wraps one TCP connection to a server. It supports automatic
  -- reconnection and turns socket errors raised inside a request/response
  -- transaction into errors for the waiting callers.
  --
  -- Descriptor accepted by socket_channel.channel(desc):
  --   host, port : primary address (both required)
  --   backup     : optional list of fallback addresses; each entry is either a host
  --                (reusing the current port) or a { host = ..., port = ... } table
  --   auth       : optional function(channel) run after each (re)connect; it is
  --                called with the channel object itself, not the raw socket
  --   response   : optional function(sock) -> session, ok, data[, padding];
  --                when set, replies are matched by session, otherwise by order
  --   nodelay    : optional; when truthy, socketdriver.nodelay is applied to new fds
  --   overload   : optional function(boolean) called with true when the send buffer
  --                starts backing up and with false when it drains again
  
  local socket_channel = {}
  local channel = {}	-- methods of channel objects
  local channel_socket = {}	-- methods of the socket wrapper { fd } passed to response readers
  local channel_meta = { __index = channel }
  local channel_socket_meta = {
  	__index = channel_socket,
  	-- Finalizer for a socket wrapper collected while its fd slot is still set:
  	-- clear the slot, then shut the fd down.
  	__gc = function(cs)
  		local fd = cs[1]
  		cs[1] = false
  		if fd then
  			socket.shutdown(fd)
  		end
  	end
  }
  
  -- Unique error object raised for transport failures. It is exported as
  -- socket_channel.error so callers can distinguish it from other errors.
  local socket_error = setmetatable({}, {__tostring = function() return "[Error: socket]" end })	-- alias for error object
  socket_channel.error = socket_error
  
  -- Build a new, not yet connected, channel from a descriptor table.
  -- desc.host and desc.port are mandatory (asserted); every other field is optional.
  -- No socket is opened here: connection happens lazily on connect/request/response.
  function socket_channel.channel(desc)
  	local host = assert(desc.host)
  	local port = assert(desc.port)
  	local obj = {
  		-- connection target (backup fallback or changehost may rewrite it)
  		__host = host,
  		__port = port,
  		__backup = desc.backup,
  		__auth = desc.auth,
  		__nodelay = desc.nodelay,
  		-- set => session mode; nil => order mode
  		__response = desc.response,
  		__overload_notify = desc.overload,
  		__overload = false,
  		-- order mode: queue of response reader functions
  		__request = {},
  		-- order mode: queue of waiting coroutines; session mode: session -> coroutine
  		__thread = {},
  		-- per-coroutine result flag and payload handed over by the dispatch thread
  		__result = {},
  		__result_data = {},
  		-- slot 1 marks an in-progress connect; later slots are coroutines waiting on it
  		__connecting = {},
  		__sock = false,
  		__closed = false,
  		__authcoroutine = false,
  	}
  	return setmetatable(obj, channel_meta)
  end
  
  -- Close the channel's current socket, if any, and mark the channel as disconnected.
  -- The wrapper's fd slot is cleared too. Without that, the wrapper's __gc finalizer
  -- would later call socket.shutdown on an id that has already been closed (and
  -- that may since have been handed out to another socket).
  -- Errors raised by socket.close are deliberately swallowed.
  local function close_channel_socket(self)
  	local so = self.__sock
  	if so then
  		self.__sock = false
  		local fd = so[1]
  		so[1] = false
  		if fd then
  			-- never raise error
  			pcall(socket.close, fd)
  		end
  	end
  end
  
  -- Fail every outstanding request. Each waiting coroutine receives socket_error as
  -- its result and errmsg (possibly nil) as its payload, and is then woken up.
  -- In order mode the queued reader functions are discarded as well. `false`
  -- entries in the coroutine queue (close signals) are dropped without a wakeup.
  local function wakeup_all(self, errmsg)
  	local threads = self.__thread
  	local function fail(co)
  		self.__result[co] = socket_error
  		self.__result_data[co] = errmsg
  		skynet.wakeup(co)
  	end
  
  	if not self.__response then
  		local requests = self.__request
  		for idx = 1, #requests do
  			requests[idx] = nil
  		end
  		for idx = 1, #threads do
  			local co = threads[idx]
  			threads[idx] = nil
  			if co then	-- ignore the close signal
  				fail(co)
  			end
  		end
  		return
  	end
  
  	for session, co in pairs(threads) do
  		threads[session] = nil
  		fail(co)
  	end
  end
  
  -- Dispatch loop for session mode (desc.response set). It is forked by connect_once
  -- and runs while the channel has a socket. Each call to the response parser yields
  -- session, ok, data[, padding]. The reply goes to the coroutine registered under
  -- that session. Pieces flagged with padding (while ok) are accumulated into a table
  -- and delivered together with the final, non-padding reply. A parser error (or a
  -- falsy session) closes the socket and fails every pending request.
  local function dispatch_by_session(self)
  	local response = self.__response
  	-- response() return session
  	while self.__sock do
  		local ok , session, result_ok, result_data, padding = pcall(response, self.__sock)
  		if ok and session then
  			local co = self.__thread[session]
  			if co then
  				if padding and result_ok then
  					-- If padding is true, append result_data to a table (self.__result_data[co])
  					local result = self.__result_data[co] or {}
  					self.__result_data[co] = result
  					table.insert(result, result_data)
  				else
  					-- final piece: unregister the session and hand the result over
  					self.__thread[session] = nil
  					self.__result[co] = result_ok
  					if result_ok and self.__result_data[co] then
  						-- earlier padding pieces exist: append the last one to them
  						table.insert(self.__result_data[co], result_data)
  					else
  						self.__result_data[co] = result_data
  					end
  					skynet.wakeup(co)
  				end
  			else
  				-- nobody is waiting for this session; drop the reply and log it
  				self.__thread[session] = nil
  				skynet.error("socket: unknown session :", session)
  			end
  		else
  			-- parser raised (or gave no session): tear down and fail all waiters.
  			-- socket_error itself carries no message, so it is mapped to nil here.
  			close_channel_socket(self)
  			local errormsg
  			if session ~= socket_error then
  				errormsg = session
  			end
  			wakeup_all(self, errormsg)
  		end
  	end
  end
  
  -- Order mode: dequeue the next (reader function, coroutine) pair.
  -- When the queue is empty, the dispatch coroutine records itself in
  -- __wait_response and parks until push_response wakes it, then tries again.
  local function pop_response(self)
  	repeat
  		local func = table.remove(self.__request, 1)
  		local co = table.remove(self.__thread, 1)
  		if func then
  			return func, co
  		end
  		local me = coroutine.running()
  		self.__wait_response = me
  		skynet.wait(me)
  	until false
  end
  
  -- Register a pending reply for coroutine `co`.
  -- Session mode: `response` is the session id and maps directly to `co`.
  -- Order mode: `response` is the reader function. It is queued together with `co`,
  -- and a dispatch coroutine parked in pop_response (if any) is woken up.
  local function push_response(self, response, co)
  	if not self.__response then
  		local requests, threads = self.__request, self.__thread
  		requests[#requests + 1] = response
  		threads[#threads + 1] = co
  		local parked = self.__wait_response
  		if parked then
  			skynet.wakeup(parked)
  			self.__wait_response = nil
  		end
  		return
  	end
  	self.__thread[response] = co
  end
  
  -- Order mode: run one reader against `sock`. A reader returns ok, data[, padding].
  -- While it keeps answering ok with padding set, it is called again and the pieces
  -- are collected into an array (by position, so nil pieces leave holes). That array
  -- is returned as the data. Any failing call is returned unchanged.
  local function get_response(func, sock)
  	local ok, data, more = func(sock)
  	if not (ok and more) then
  		return ok, data
  	end
  	local parts, count = { data }, 1
  	while more do
  		ok, data, more = func(sock)
  		if not ok then
  			return ok, data
  		end
  		count = count + 1
  		parts[count] = data
  	end
  	return true, parts
  end
  
  -- Dispatch loop for order mode (no desc.response). It is forked by connect_once.
  -- Replies are matched to requests strictly in queue order: pop the next
  -- (reader, coroutine) pair, run the reader on the socket and hand its result to
  -- the coroutine. A `false` coroutine is the close signal queued by
  -- term_dispatch_thread. It ends the loop and fails anything still queued with
  -- "channel_closed". A reader error closes the socket and fails the current
  -- request plus every queued one.
  local function dispatch_by_order(self)
  	while self.__sock do
  		local func, co = pop_response(self)
  		if not co then
  			-- close signal
  			wakeup_all(self, "channel_closed")
  			break
  		end
  		local ok, result_ok, result_data = pcall(get_response, func, self.__sock)
  		if ok then
  			self.__result[co] = result_ok
  			if result_ok and self.__result_data[co] then
  				table.insert(self.__result_data[co], result_data)
  			else
  				self.__result_data[co] = result_data
  			end
  			skynet.wakeup(co)
  		else
  			-- reader raised: socket_error carries no message, other errors are passed on
  			close_channel_socket(self)
  			local errmsg
  			if result_ok ~= socket_error then
  				errmsg = result_ok
  			end
  			-- the current request was already popped, so fail it explicitly
  			self.__result[co] = socket_error
  			self.__result_data[co] = errmsg
  			skynet.wakeup(co)
  			wakeup_all(self, errmsg)
  		end
  	end
  end
  
  -- Pick the dispatch loop for this channel: order-based unless a session
  -- response parser was configured.
  local function dispatch_function(self)
  	if not self.__response then
  		return dispatch_by_order
  	end
  	return dispatch_by_session
  end
  
  -- Try each backup address in order. An entry is either a plain host (reusing the
  -- channel's current port) or a { host = ..., port = ... } table. On the first
  -- successful socket.open the channel is switched to that address and the new fd
  -- is returned. If no backup list is set or nothing connects, nothing is returned.
  local function connect_backup(self)
  	local list = self.__backup
  	if not list then
  		return
  	end
  	for _, entry in ipairs(list) do
  		local host, port
  		if type(entry) ~= "table" then
  			host, port = entry, self.__port
  		else
  			host, port = entry.host, entry.port
  		end
  		skynet.error("socket: connect to backup host", host, port)
  		local fd = socket.open(host, port)
  		if fd then
  			self.__host, self.__port = host, port
  			return fd
  		end
  	end
  end
  
  -- Order mode only: ask a running dispatch thread to stop by queueing the
  -- (true, false) close signal. dispatch_by_order exits when it pops a false coroutine.
  local function term_dispatch_thread(self)
  	if self.__response or not self.__dispatch_thread then
  		return
  	end
  	push_response(self, true, false)
  end
  
  -- Open one connection for the channel and start its dispatch thread.
  -- The primary host/port is tried first, then the backup list (connect_backup may
  -- switch __host/__port). If an auth callback is configured, it runs here in the
  -- connecting coroutine with __authcoroutine set, so that requests issued by auth
  -- itself pass check_connection.
  -- Returns:
  --   false        if the channel has been closed
  --   false, err   if no address could be opened (err comes from the primary attempt)
  --   ok           the pcall status of auth when auth is configured (false on failure)
  --   true         otherwise
  local function connect_once(self)
  	if self.__closed then
  		return false
  	end
  	assert(not self.__sock and not self.__authcoroutine)
  	-- term current dispatch thread (send a signal)
  	term_dispatch_thread(self)
  
  	local fd,err = socket.open(self.__host, self.__port)
  	if not fd then
  		fd = connect_backup(self)
  		if not fd then
  			return false, err
  		end
  	end
  	if self.__nodelay then
  		socketdriver.nodelay(fd)
  	end
  
  	-- register overload warning
  
  	local overload = self.__overload_notify
  	if overload then
  		local function overload_trigger(id, size)
  			-- __sock may be false here. The forked initial call below can run while
  			-- this coroutine yields waiting for the old dispatch thread (before __sock
  			-- is assigned), and a warning may arrive after the socket was closed.
  			-- Indexing false would raise, so check before comparing fds.
  			local sock = self.__sock
  			if sock and id == sock[1] then
  				if size == 0 then
  					if self.__overload then
  						self.__overload = false
  						overload(false)
  					end
  				else
  					if not self.__overload then
  						self.__overload = true
  						overload(true)
  					else
  						skynet.error(string.format("WARNING: %d K bytes need to send out (fd = %d %s:%s)", size, id, self.__host, self.__port))
  					end
  				end
  			end
  		end
  
  		skynet.fork(overload_trigger, fd, 0)
  		socket.warning(fd, overload_trigger)
  	end
  
  	while self.__dispatch_thread do
  		-- wait for dispatch thread exit
  		skynet.yield()
  	end
  
  	self.__sock = setmetatable( {fd} , channel_socket_meta )
  	self.__dispatch_thread = skynet.fork(function()
  		pcall(dispatch_function(self), self)
  		-- clear dispatch_thread
  		self.__dispatch_thread = nil
  	end)
  
  	if self.__auth then
  		self.__authcoroutine = coroutine.running()
  		local ok , message = pcall(self.__auth, self)
  		if not ok then
  			close_channel_socket(self)
  			if message ~= socket_error then
  				self.__authcoroutine = false
  				skynet.error("socket: auth failed", message)
  			end
  		end
  		self.__authcoroutine = false
  		if ok and not self.__sock then
  			-- auth may change host, so connect again
  			return connect_once(self)
  		end
  		return ok
  	end
  
  	return true
  end
  
  -- Call connect_once until it succeeds or the channel is closed.
  -- With `once` set, the first failure's error is returned instead of retrying.
  -- Between retries it sleeps, and the delay grows by 100 each time. Once the delay
  -- has exceeded 1000, a reconnect message is logged and the delay starts over.
  local function try_connect(self, once)
  	local delay = 0
  	while not self.__closed do
  		local ok, err = connect_once(self)
  		if ok then
  			if not once then
  				skynet.error("socket: connect to", self.__host, self.__port)
  			end
  			return
  		end
  		if once then
  			return err
  		end
  		skynet.error("socket: connect", err)
  		local restart = delay > 1000
  		if restart then
  			skynet.error("socket: try to reconnect", self.__host, self.__port)
  		end
  		skynet.sleep(delay)
  		if restart then
  			delay = 0
  		end
  		delay = delay + 100
  	end
  end
  
  -- Tri-state connection probe used by block_connect:
  --   true  -> a socket is present and either no auth is running or the caller
  --            is the auth coroutine itself
  --   false -> no usable socket and the channel has been closed
  --   nil   -> a (re)connect is needed, or another coroutine is still authenticating
  -- A socket found disconnected by the peer is closed here and reported as nil.
  local function check_connection(self)
  	local sock = self.__sock
  	if sock then
  		if socket.disconnected(sock[1]) then
  			-- closed by peer
  			skynet.error("socket: disconnect detected ", self.__host, self.__port)
  			close_channel_socket(self)
  			return
  		end
  		local authco = self.__authcoroutine
  		if not authco or authco == coroutine.running() then
  			return true
  		end
  	end
  	if self.__closed then
  		return false
  	end
  end
  
  -- Make sure the channel is connected, blocking the calling coroutine if needed.
  -- Only one coroutine performs the connect; slot 1 of __connecting is its marker.
  -- Other coroutines arriving meanwhile queue behind it and are woken when it
  -- finishes, whether the attempt succeeded, failed or raised.
  -- Returns true when the channel is usable and false when it is closed. Raises
  -- socket_error when no connection could be established, and re-raises any error
  -- thrown by the connect attempt itself.
  local function block_connect(self, once)
  	local r = check_connection(self)
  	if r ~= nil then
  		return r
  	end
  	local err
  
  	if #self.__connecting > 0 then
  		-- connecting in other coroutine
  		local co = coroutine.running()
  		table.insert(self.__connecting, co)
  		skynet.wait(co)
  	else
  		self.__connecting[1] = true
  		-- protect the attempt so the marker is always cleared and waiters always released
  		local ok, result = pcall(try_connect, self, once)
  		-- Count the waiters while slot 1 is still set. Once it is cleared the table has a
  		-- hole at [1], and the length operator would no longer be reliable.
  		local waiters = #self.__connecting
  		self.__connecting[1] = nil
  		for i = 2, waiters do
  			local co = self.__connecting[i]
  			self.__connecting[i] = nil
  			skynet.wakeup(co)
  		end
  		if not ok then
  			-- level 0: propagate the original error value/message unchanged
  			error(result, 0)
  		end
  		err = result
  	end
  
  	r = check_connection(self)
  	if r == nil then
  		skynet.error(string.format("Connect to %s:%d failed (%s)", self.__host, self.__port, err))
  		error(socket_error)
  	else
  		return r
  	end
  end
  
  -- (Re)open the channel: clear the closed flag, then connect if necessary.
  -- With `once` set only a single connect attempt is made (see block_connect).
  function channel:connect(once)
  	self.__closed = false
  	local status = block_connect(self, once)
  	return status
  end
  
  -- Register a pending reply for the running coroutine and park it until the dispatch
  -- thread delivers the reply. On success the payload is returned. On a transport
  -- failure the error message is raised (or socket_error when there is none). When the
  -- reader reported failure, it asserts with the payload as the message.
  local function wait_for_response(self, response)
  	local me = coroutine.running()
  	push_response(self, response, me)
  	skynet.wait(me)
  
  	local results, payloads = self.__result, self.__result_data
  	local status, payload = results[me], payloads[me]
  	results[me], payloads[me] = nil, nil
  
  	if status ~= socket_error then
  		assert(status, payload)
  		return payload
  	end
  	error(payload or socket_error)
  end
  
  -- cached socket writers: normal and low-priority queues
  local socket_write, socket_lwrite = socket.write, socket.lwrite
  
  -- Abort after a failed write: drop the socket, fail every pending waiter with
  -- no error message, and raise socket_error in the caller.
  local function sock_err(self)
  	close_channel_socket(self)
  	wakeup_all(self, nil)
  	error(socket_error)
  end
  
  -- Send `request` and, unless `response` is nil, wait for and return the reply.
  -- The connect step uses `once`, so a disconnected channel gets a single connect
  -- attempt instead of a retry loop.
  -- `padding`, when given, is an array of extra parts written after `request`; then
  -- every part goes through socket_lwrite, otherwise socket_write is used.
  -- `response` is the session id in session mode or a reader function in order mode.
  -- A failed write closes the socket, fails all pending requests and raises socket_error.
  function channel:request(request, response, padding)
  	local usable = block_connect(self, true)	-- connect once
  	assert(usable)
  	local fd = self.__sock[1]
  
  	if not padding then
  		if not socket_write(fd, request) then
  			sock_err(self)
  		end
  	else
  		-- multi part request: low priority writes
  		-- (socket_lwrite currently reports success the same way socket_write does)
  		if not socket_lwrite(fd, request) then
  			sock_err(self)
  		end
  		for _, part in ipairs(padding) do
  			if not socket_lwrite(fd, part) then
  				sock_err(self)
  			end
  		end
  	end
  
  	if response ~= nil then
  		return wait_for_response(self, response)
  	end
  	-- no response expected
  end
  
  -- Wait for a reply without sending anything first. Unlike request, block_connect
  -- is called without `once`, so a connect performed here keeps retrying until it
  -- succeeds or the channel is closed.
  function channel:response(response)
  	local usable = block_connect(self)
  	assert(usable)
  	return wait_for_response(self, response)
  end
  
  -- Close the channel (idempotent). In order mode the dispatch thread is signalled
  -- before the closed flag is set and the socket is dropped.
  function channel:close()
  	if self.__closed then
  		return
  	end
  	term_dispatch_thread(self)
  	self.__closed = true
  	close_channel_socket(self)
  end
  
  -- Point the channel at a new address; the current port is kept when none is given.
  -- An open channel drops its current socket so that the next connect uses the new
  -- address. A closed channel is left as it is.
  function channel:changehost(host, port)
  	self.__host = host
  	self.__port = port or self.__port
  	if self.__closed then
  		return
  	end
  	close_channel_socket(self)
  end
  
  -- Replace the backup address list consulted when the primary address fails.
  function channel.changebackup(self, backup)
  	self.__backup = backup
  end
  
  -- Channels are finalized with channel.close, the same routine as an explicit close().
  -- This is set at module load, before socket_channel.channel ever calls
  -- setmetatable; Lua 5.2+ only registers a table finalizer when __gc is present at
  -- that moment.
  channel_meta.__gc = channel.close
  
  -- Turn a socket function taking an fd into a method of the socket wrapper { fd }.
  -- A falsy result raises socket_error so the surrounding transaction is aborted.
  -- Otherwise only the first result is returned.
  local function wrapper_socket_function(f)
  	return function(self, ...)
  		local data = f(self[1], ...)
  		if data then
  			return data
  		end
  		error(socket_error)
  	end
  end
  
  -- Reader methods of the socket wrapper that is passed to response readers.
  channel_socket.read, channel_socket.readline =
  	wrapper_socket_function(socket.read), wrapper_socket_function(socket.readline)
  
  return socket_channel