-- hotfix.lua (2.36 KB)
local si = require "snax.interface"

--- Return the identity of the `_ENV` upvalue captured by `f`.
-- Scans the upvalues of `f` in index order and stops at the first one
-- named "_ENV".
-- @tparam function f closure to inspect
-- @return the `debug.upvalueid` of its `_ENV` upvalue, or nothing when `f`
--   has no `_ENV` upvalue
local function envid(f)
	local idx = 0
	repeat
		idx = idx + 1
		local uv_name = debug.getupvalue(f, idx)
		if uv_name == "_ENV" then
			return debug.upvalueid(f, idx)
		end
	until uv_name == nil
end

--- Record every upvalue reachable from `f` into `uv`, keyed by upvalue name.
-- Each entry is `{ func = owner, index = i, id = upvalueid }`. The entry
-- describes the first closure/slot where that name was seen.
-- Function-valued upvalues are walked recursively, but only when they
-- capture the same `_ENV` upvalue as `env`.
-- Raises "ambiguity local value <name>" when two distinct upvalues share a
-- name, because a patch could not tell which one to bind to.
-- NOTE: the first free name referenced below is the global `debug`, so
-- upvalue #1 of this closure is `_ENV`; keep it that way in case this
-- function is used as an `_ENV` join target.
-- @tparam function f closure to scan
-- @tparam table uv accumulator, mutated in place
-- @param env `_ENV` upvalue id that nested functions must share (may be nil)
local function collect_uv(f, uv, env)
	local idx = 1
	local uv_name, uv_value = debug.getupvalue(f, idx)
	while uv_name ~= nil do
		local id = debug.upvalueid(f, idx)
		local known = uv[uv_name]

		if not known then
			uv[uv_name] = { func = f, index = idx, id = id }
			-- Descend only into helpers living in the same environment.
			if type(uv_value) == "function" and envid(uv_value) == env then
				collect_uv(uv_value, uv, env)
			end
		else
			assert(known.id == id, string.format("ambiguity local value %s", uv_name))
		end

		idx = idx + 1
		uv_name, uv_value = debug.getupvalue(f, idx)
	end
end

--- Build the name -> upvalue map of the running service.
-- Walks the implementation (`desc[4]`) of every descriptor in `funcs` with
-- `collect_uv`. Each walk is restricted to helpers that share the
-- implementation's own `_ENV`.
-- When nothing captured `_ENV`, an `_ENV` entry is synthesized that points
-- at this module's own `_ENV` upvalue. A patch's sandboxed globals can then
-- still be rebound somewhere.
-- @tparam table funcs descriptor list; entries are arrays whose [4] is the
--   implementation function or nil
-- @treturn table map of upvalue name -> `{ func = closure, index = slot, ... }`
local function collect_all_uv(funcs)
	local global = {}
	for _, desc in pairs(funcs) do
		local f = desc[4]
		if f then
			collect_uv(f, global, envid(f))
		end
	end
	if not global["_ENV"] then
		-- Join target for this module's _ENV. The closure's only free name
		-- is _ENV, so its upvalue #1 is guaranteed to be _ENV. It shares the
		-- same upvalue cell as every other closure in this chunk. This does
		-- not depend on the upvalue ordering of any other function.
		global["_ENV"] = { func = function() return _ENV end, index = 1 }
	end
	return global
end

--- Make a loader that compiles the in-memory patch `source`.
-- The returned function ignores its first two arguments. It compiles
-- `source` as a chunk named "patch", accepting text or binary, with `G` as
-- its environment.
-- @tparam string source patch code
-- @treturn function `(path, name, G) -> chunk | nil, err`, as returned by `load`
local function loader(source)
	local function load_patch(_path, _name, G)
		return load(source, "=patch", "bt", G)
	end
	return load_patch
end

--- Locate the descriptor for `group.name` in a descriptor list.
-- @tparam table funcs list of descriptors; [2] is the group, [3] the name
-- @tparam string group e.g. "system", "accept", "response"
-- @tparam string name method name within the group
-- @return the matching descriptor table, or nothing if absent
local function find_func(funcs, group, name)
	for _, entry in pairs(funcs) do
		if group == entry[2] and name == entry[3] then
			return entry
		end
	end
end

-- Shallow snapshot of this module's environment, taken once at load time.
-- Patches are compiled against this table, so `_patch` can recognise an
-- upvalue still bound to it (by identity) and rebind it to the service's
-- real environment.
local dummy_env
do
	local snapshot = {}
	for key, val in pairs(_ENV) do
		snapshot[key] = val
	end
	dummy_env = snapshot
end

--- Bind a freshly loaded patch function to the running service's upvalues.
-- For every upvalue of `f`:
--   * if it is nil, or still bound to the `dummy_env` sandbox, it is joined
--     (`debug.upvaluejoin`) to the same-named upvalue recorded in `global`,
--     when one exists;
--   * otherwise, if it holds a function, that function is processed the
--     same way.
-- `visited` guards the walk. Without it, a recursive local function in the
-- patch (whose upvalue is itself) or mutually recursive helpers would
-- recurse forever and overflow the stack.
-- @tparam table global map built by `collect_all_uv`
-- @tparam function f patch function to rebind (mutated in place)
-- @tparam[opt] table visited set of functions already processed; internal
local function _patch(global, f, visited)
	visited = visited or {}
	if visited[f] then
		return
	end
	visited[f] = true

	local i = 1
	while true do
		local name, value = debug.getupvalue(f, i)
		if name == nil then
			break
		elseif value == nil or value == dummy_env then
			local old_uv = global[name]
			if old_uv then
				debug.upvaluejoin(f, i, old_uv.func, old_uv.index)
			end
		elseif type(value) == "function" then
			_patch(global, value, visited)
		end
		i = i + 1
	end
end

--- Replace the implementation of `group.name` in `funcs` with patch `f`.
-- Raises "Patch mismatch <group>.<name>" when the service has no such entry.
-- Otherwise it rebinds `f`'s upvalues via `_patch` and stores `f` in the
-- descriptor's slot [4].
-- @tparam table funcs the service's descriptor list (mutated)
-- @tparam table global upvalue map from `collect_all_uv`
-- @tparam string group
-- @tparam string name
-- @tparam function f replacement implementation
local function patch_func(funcs, global, group, name, f)
	local desc = find_func(funcs, group, name)
	assert(desc, string.format("Patch mismatch %s.%s", group, name))
	_patch(global, f)
	desc[4] = f
end

--- Load a patch from `source` and splice it into the service's `funcs`.
-- The patch is loaded through `snax.interface` with `dummy_env` as its
-- globals. NOTE(review): this presumably returns the patch's own descriptor
-- list; confirm against snax.interface.
-- Every patch descriptor that carries an implementation replaces the
-- matching service entry.
-- Afterwards, if the patch defines `system.hotfix`, that function is
-- called with the extra arguments and its results are returned.
-- @tparam table funcs the service's descriptor list (mutated)
-- @tparam string source patch code
-- @param ... forwarded to the patch's `system.hotfix`, if any
local function inject(funcs, source, ...)
	local patch = si("patch", dummy_env, loader(source))
	local global = collect_all_uv(funcs)

	for _, entry in pairs(patch) do
		local impl = entry[4]
		if impl then
			patch_func(funcs, global, entry[2], entry[3], impl)
		end
	end

	local hotfix_desc = find_func(patch, "system", "hotfix")
	local hotfix_fn = hotfix_desc and hotfix_desc[4]
	if hotfix_fn then
		return hotfix_fn(...)
	end
end

--- Module entry point: apply a hotfix in protected mode.
-- @tparam table funcs the service's descriptor list
-- @tparam string source patch code
-- @param ... forwarded to the patch's `system.hotfix`
-- @return the results of `pcall(inject, funcs, source, ...)`: `true` plus
--   hotfix results on success, or `false` plus the error value
local function apply_hotfix(funcs, source, ...)
	return pcall(inject, funcs, source, ...)
end

return apply_hotfix