-- httpc.lua
local skynet = require "skynet"
local socket = require "http.sockethelper"
local url = require "http.url"
local internal = require "http.internal"
local dns = require "skynet.dns"
local string = string
local table = table
-- Module table: the public functions below are attached to it and it is
-- returned at the end of the file.
local httpc = {}
-- Flag set to true by httpc.dns(); while set, httpc.request resolves host names
-- through skynet.dns before connecting.
local async_dns
--- Turn on DNS resolution for later requests and configure the resolver.
-- The flag is raised before the resolver is configured, so it stays set even
-- if dns.server raises.
-- @param server forwarded unchanged to dns.server
-- @param port forwarded unchanged to dns.server
httpc.dns = function(server, port)
	async_dns = true
	dns.server(server, port)
end
--- Split an optional scheme prefix off a host string.
-- "http://" and "https://" are recognised case-insensitively, and only at the
-- very start of `host`. Anything else is treated as a bare host.
-- @param host string such as "https://example.com:8443" or "example.com"
-- @return "http" or "https" ("http" when no scheme prefix is present)
-- @return the host string with the scheme prefix removed
local function check_protocol(host)
	local prefix = host:match("^[Hh][Tt][Tt][Pp][Ss]?://")
	if not prefix then
		-- No scheme given: default to plain http and leave the host untouched.
		return "http", host
	end

	-- The prefix was matched at position 1, so dropping its length strips it.
	local rest = host:sub(#prefix + 1)
	local scheme = prefix:lower()

	if scheme == "https://" then
		return "https", rest
	end
	if scheme == "http://" then
		return "http", rest
	end
	error(string.format("Invalid protocol: %s", scheme))
end
-- TLS client context shared by every https connection; created on first use.
local SSLCTX_CLIENT = nil

--- Build the table of I/O callbacks used to talk over an already-connected fd.
-- For "http" the table has `read`, `write` and `readall` backed by the socket
-- helper, and no `init`/`close` entries. For "https" it has `init`, `close`,
-- `read`, `write` and `readall` backed by http.tlshelper and a fresh TLS object.
-- @param protocol "http" or "https"
-- @param fd connected socket descriptor
-- @return interface table
-- Raises "Invalid protocol: ..." for any other protocol value.
local function gen_interface(protocol, fd)
	if protocol == "https" then
		-- Loaded lazily so plain-http users never need the TLS helper.
		local tls = require "http.tlshelper"
		if not SSLCTX_CLIENT then
			SSLCTX_CLIENT = tls.newctx()
		end
		local ctx = tls.newtls("client", SSLCTX_CLIENT)

		local iface = {}
		iface.init = tls.init_requestfunc(fd, ctx)
		iface.close = tls.closefunc(ctx)
		iface.read = tls.readfunc(fd, ctx)
		iface.write = tls.writefunc(fd, ctx)
		iface.readall = tls.readallfunc(fd, ctx)
		return iface
	end

	if protocol == "http" then
		local iface = {}
		iface.read = socket.readfunc(fd)
		iface.write = socket.writefunc(fd)
		iface.readall = function ()
			return socket.readall(fd)
		end
		return iface
	end

	error(string.format("Invalid protocol: %s", protocol))
end
--- Perform one HTTP(S) request on a fresh connection and return its result.
-- `host` may carry an "http://" or "https://" prefix and an optional ":port"
-- suffix; the port defaults to 80 for http and 443 for https. When
-- httpc.dns() has been called, a host name that does not end in a digit is
-- resolved through skynet.dns first.
-- If httpc.timeout is set it is used both as the connect timeout and as the
-- delay of a timer that shuts the socket down when the request has not
-- finished by then.
-- @param method request method, e.g. "GET"
-- @param host target host, optionally with scheme and port
-- @param url, recvheader, header, content forwarded unchanged to internal.request
-- @return statuscode, body as returned by internal.request
-- Raises on an unparsable host, on connect failure, and on any error raised
-- while setting up the interface or running the request. The socket is closed
-- before such an error is re-raised.
function httpc.request(method, host, url, recvheader, header, content)
	local protocol
	local timeout = httpc.timeout -- get httpc.timeout before any blocked api
	protocol, host = check_protocol(host)
	local hostname, port = host:match"([^:]+):?(%d*)$"
	if not hostname then
		-- Fail with a clear message instead of a later "index a nil value".
		error(string.format("Invalid host: %s", host))
	end
	if port == "" then
		port = protocol == "https" and 443 or 80
	else
		port = tonumber(port)
	end
	if async_dns and not hostname:match(".*%d+$") then
		hostname = dns.resolve(hostname)
	end
	local fd = socket.connect(hostname, port, timeout)
	if not fd then
		error(string.format("%s connect error host:%s, port:%s, timeout:%s", protocol, hostname, port, timeout))
	end

	-- Building the interface can raise (e.g. the TLS helper fails to load);
	-- close the socket in that case instead of leaking it.
	local built, interface = pcall(gen_interface, protocol, fd)
	if not built then
		socket.close(fd)
		error(interface)
	end

	local finish
	if timeout then
		skynet.timeout(timeout, function()
			if not finish then
				socket.shutdown(fd) -- shutdown the socket fd, need close later.
				if interface.close then
					interface.close()
				end
			end
		end)
	end

	-- Run init under pcall as well, so a failure there still reaches the
	-- cleanup below.
	local ok, statuscode, body = true, nil, nil
	if interface.init then
		ok, statuscode = pcall(interface.init)
	end
	if ok then
		ok, statuscode, body = pcall(internal.request, interface, method, host, url, recvheader, header, content)
	end

	finish = true
	socket.close(fd)
	if interface.close then
		interface.close()
	end
	if ok then
		return statuscode, body
	else
		error(statuscode)
	end
end
--- Convenience wrapper: issue a "GET" request.
-- All arguments are forwarded to httpc.request after the method, and all of
-- its return values are passed back.
httpc.get = function(...)
	return httpc.request("GET", ...)
end
--- Percent-encode a value for use in a form body.
-- Every byte other than ASCII letters, digits and "_" is replaced by "%XX"
-- (two upper-case hex digits).
-- @param s string (or number) to encode
-- @return the encoded string (a single value)
local function escape(s)
	local function to_hex(ch)
		return string.format("%%%02X", string.byte(ch))
	end
	-- Assigning to one local drops gsub's second result (the match count).
	local encoded = string.gsub(s, "([^A-Za-z0-9_])", to_hex)
	return encoded
end
--- Send `form` as an application/x-www-form-urlencoded "POST" request.
-- Each key and value of `form` is percent-encoded with escape() and the
-- resulting "key=value" pairs are joined with "&". The pairs follow pairs()
-- iteration order, which is unspecified.
-- @param host forwarded to httpc.request
-- @param url forwarded to httpc.request
-- @param form table of fields to send
-- @param recvheader forwarded to httpc.request
-- @return whatever httpc.request returns
function httpc.post(host, url, form, recvheader)
	local fields = {}
	for key, value in pairs(form) do
		fields[#fields + 1] = string.format("%s=%s", escape(key), escape(value))
	end
	local payload = table.concat(fields, "&")

	local header = { ["content-type"] = "application/x-www-form-urlencoded" }
	return httpc.request("POST", host, url, recvheader, header, payload)
end
return httpc