Blame view

publish/skynet/lualib/snax/hotfix.lua 2.36 KB
4d6f285d   zhouhaihai   增加发布功能
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
  local si = require "snax.interface"
  
  --- Identify the _ENV upvalue of a function.
  -- Scans f's upvalues in slot order and returns debug.upvalueid of the first
  -- one named "_ENV"; returns no value when f has no such upvalue. Closures
  -- sharing an environment yield ids that compare equal.
  -- @param f function to inspect
  local function envid(f)
  	for slot = 1, math.huge do
  		local uv_name = debug.getupvalue(f, slot)
  		if uv_name == "_ENV" then
  			return debug.upvalueid(f, slot)
  		elseif uv_name == nil then
  			return
  		end
  	end
  end
  
  --- Recursively record the upvalues reachable from f into `uv`, keyed by name.
  -- A name seen for the first time is stored as
  -- { func = owner function, index = upvalue slot, id = debug.upvalueid }.
  -- A name already present must denote the very same upvalue (equal id);
  -- otherwise the assert raises "ambiguity local value <name>".
  -- Function-valued upvalues are descended into only when their envid()
  -- equals `env`, and only the first time their name is recorded.
  -- @param f   function whose upvalues are walked
  -- @param uv  accumulator table, mutated in place
  -- @param env envid() of the root function (may be nil)
  local function collect_uv(f , uv, env)
  	-- NOTE(review): the first name referenced below is the global `debug`, so
  	-- upvalue #1 of this function is _ENV; check for dependants before reordering.
  	local slot = 0
  	repeat
  		slot = slot + 1
  		local uv_name, uv_value = debug.getupvalue(f, slot)
  		if uv_name ~= nil then
  			local uv_id = debug.upvalueid(f, slot)
  			local known = uv[uv_name]
  			if not known then
  				-- Record before descending so self-references terminate.
  				uv[uv_name] = { func = f, index = slot, id = uv_id }
  				if type(uv_value) == "function" and envid(uv_value) == env then
  					collect_uv(uv_value, uv, env)
  				end
  			else
  				assert(known.id == uv_id, string.format("ambiguity local value %s", uv_name))
  			end
  		end
  	until uv_name == nil
  end
  
  --- Build the name -> upvalue map used to rebind patch functions.
  -- For every entry in `funcs` whose 4th element (the function) is set, walks
  -- its upvalues with collect_uv, restricting recursion to functions that share
  -- that entry's _ENV. Each recorded value is { func = owner, index = slot, id = upvalueid }.
  -- If no walked function exposes an _ENV upvalue, "_ENV" falls back to this
  -- module's own environment so a patch's environment upvalue still has a
  -- join target.
  -- @param funcs service function table; entries are unpacked elsewhere as
  --              (_, group, name, f)
  -- @return table mapping upvalue name -> { func = function, index = integer[, id] }
  local function collect_all_uv(funcs)
  	local global = {}
  	for _, v in pairs(funcs) do
  		if v[4] then
  			collect_uv(v[4], global, envid(v[4]))
  		end
  	end
  	if not global["_ENV"] then
  		-- A dedicated closure whose only upvalue (slot 1) is this module's _ENV.
  		-- This replaces pointing at collect_uv's slot 1, which only worked because
  		-- _ENV happened to be the first upvalue that function referenced.
  		global["_ENV"] = { func = function() return _ENV end, index = 1 }
  	end
  	return global
  end
  
  --- Make a snax.interface-compatible loader that compiles `source` verbatim.
  -- The returned function ignores its path and name arguments and compiles
  -- `source` as text or binary with chunk name "=patch" in environment `env`.
  -- It returns whatever `load` returns (chunk, or nil plus an error message).
  -- @param source patch source code
  local function loader(source)
  	local function load_patch(_path, _name, env)
  		return load(source, "=patch", "bt", env)
  	end
  	return load_patch
  end
  
  --- Look up the descriptor whose group and name match.
  -- Descriptors are unpacked as (_, group, name, ...). Iteration order follows
  -- `pairs`; the first match found is returned, otherwise nothing.
  -- @param funcs table of descriptors
  -- @param group group to match (e.g. "system")
  -- @param name  function name to match
  local function find_func(funcs, group , name)
  	for _, entry in pairs(funcs) do
  		local _, entry_group, entry_name = table.unpack(entry)
  		local matches = entry_group == group and entry_name == name
  		if matches then
  			return entry
  		end
  	end
  end
  
  -- Shallow copy of this module's globals, taken once when the module loads.
  -- inject hands it to snax.interface for the patch chunk (presumably as the
  -- patch's environment; see loader). _patch treats any upvalue holding this
  -- exact table as a placeholder to be rejoined.
  local dummy_env = {}
  for key, val in pairs(_ENV) do
  	dummy_env[key] = val
  end
  
  --- Rebind a freshly loaded patch function to the running service's upvalues.
  -- Walks f's upvalues in slot order:
  --  * a nil value (a placeholder local declared in the patch), or the value
  --    `dummy_env`, is joined via debug.upvaluejoin to the upvalue recorded
  --    under the same name in `global` (if any);
  --  * a function value is walked recursively, so helpers defined in the
  --    patch get rebound too.
  -- Each function is walked at most once per call tree. Without this, a
  -- self-recursive or mutually recursive local function in a patch (whose
  -- own name is a non-nil function upvalue) recursed forever.
  -- @param global  name -> { func = function, index = integer } map
  -- @param f       patch function, modified in place
  -- @param visited optional set of functions already walked (created if nil)
  local function _patch(global, f, visited)
  	visited = visited or {}
  	if visited[f] then
  		return
  	end
  	visited[f] = true

  	local i = 1
  	while true do
  		local name, value = debug.getupvalue(f, i)
  		if name == nil then
  			break
  		elseif value == nil or value == dummy_env then
  			local old_uv = global[name]
  			if old_uv then
  				debug.upvaluejoin(f, i, old_uv.func, old_uv.index)
  			end
  		elseif type(value) == "function" then
  			_patch(global, value, visited)
  		end
  		i = i + 1
  	end
  end
  
  --- Install patch function `f` as the implementation of group.name.
  -- Raises "Patch mismatch <group>.<name>" when `funcs` has no such entry.
  -- Otherwise rebinds f's upvalues with _patch and then stores f in slot 4 of
  -- the matching descriptor.
  -- @param funcs  service descriptor table
  -- @param global name -> upvalue map from collect_all_uv
  -- @param group  descriptor group
  -- @param name   descriptor name
  -- @param f      replacement function
  local function patch_func(funcs, global, group, name, f)
  	local target = find_func(funcs, group, name)
  	local mismatch_msg = string.format("Patch mismatch %s.%s", group, name)
  	assert(target, mismatch_msg)
  	_patch(global, f)
  	target[4] = f
  end
  
  --- Load patch `source` and splice its functions into the service table `funcs`.
  -- Steps:
  --  1. Load the patch through snax.interface with `dummy_env` and a loader.
  --  2. Collect the service's upvalues.
  --  3. Install every patch entry that has a function.
  --  4. If the patch defines system.hotfix, tail-call it with the extra
  --     arguments and return its results.
  -- @param funcs  service descriptor table (mutated)
  -- @param source patch source code
  -- @param ...    forwarded to the patch's system.hotfix, if present
  local function inject(funcs, source, ...)
  	local global = nil
  	local patch = si("patch", dummy_env, loader(source))
  	global = collect_all_uv(funcs)

  	for _, entry in pairs(patch) do
  		local _, group, name, new_f = table.unpack(entry)
  		if new_f then
  			patch_func(funcs, global, group, name, new_f)
  		end
  	end

  	local hotfix_desc = find_func(patch, "system", "hotfix")
  	local hotfix_fn = hotfix_desc and hotfix_desc[4]
  	if hotfix_fn then
  		return hotfix_fn(...)
  	end
  end
  
  --- Module entry point: apply a hotfix, catching any error.
  -- Returns the results of pcall(inject, ...): true plus system.hotfix's results
  -- on success, or false plus the error value on failure.
  local function hotfix(funcs, source, ...)
  	return pcall(inject, funcs, source, ...)
  end

  return hotfix