SpecialProxy

This commit is contained in:
mmmdbybyd 2018-10-14 11:58:56 +08:00 committed by GitHub
parent 4c843a01a0
commit 80a25f7740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 131 additions and 157 deletions

22
LICENSE
View File

@ -1,21 +1,7 @@
MIT License Copyright (c) <2017> mmmdbybyd
Copyright (c) 2017 mmmdbybyd Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
Permission is hereby granted, free of charge, to any person obtaining a copy The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -27,4 +27,12 @@ SpecialProxy
-w 工作进程数 -w 工作进程数
##### BUG ##### BUG
好像有些连接不关闭,需要定时重启代理 好像有些连接不关闭,需要定时重启代理
##### 编译:
~~~~~
Linux/Android:
make
Android-ndk:
ndk-build
~~~~~

44
dns.c
View File

@ -6,13 +6,13 @@ int dnsFd;
void read_dns_rsp() void read_dns_rsp()
{ {
static char rsp_data[512], *p, ips[16]; char rsp_data[512], *p, ip[16];
static unsigned char *_p; unsigned char *_p;
static struct dns *dns; struct dns *dns;
static conn_t *client; conn_t *client;
static int16_t len, dns_flag; int16_t len, dns_flag;
while ((len = read(dnsFd, rsp_data, BUFFER_SIZE)) > 11) while ((len = read(dnsFd, rsp_data, 512)) > 11)
{ {
memcpy(&dns_flag, rsp_data, 2); memcpy(&dns_flag, rsp_data, 2);
dns = dns_list + dns_flag; dns = dns_list + dns_flag;
@ -26,9 +26,9 @@ void read_dns_rsp()
continue; continue;
} }
/* get ips */ /* get domain ip */
p = rsp_data + dns->request_len + 11; p = rsp_data + dns->request_len + 11;
ips[0] = '\0'; ip[0] = 0;
while (p - rsp_data + 4 <= len) while (p - rsp_data + 4 <= len)
{ {
//type //type
@ -38,13 +38,12 @@ void read_dns_rsp()
continue; continue;
} }
_p = (unsigned char *)p + 1; _p = (unsigned char *)p + 1;
sprintf(ips, "%d.%d.%d.%d", _p[0], _p[1], _p[2], _p[3]); sprintf(ip, "%d.%d.%d.%d", _p[0], _p[1], _p[2], _p[3]);
break; break;
} }
if (ips[0]) if (ip[0])
{ {
//printf("ips %s\n", ips); if (connectionToServer(ip, client + 1) != 0)
if (connectionToServer(ips, client + 1) != 0)
{ {
close_connection(client); close_connection(client);
continue; continue;
@ -61,7 +60,7 @@ void read_dns_rsp()
/* 完全发送返回0发送部分返回1出错返回-1 */ /* 完全发送返回0发送部分返回1出错返回-1 */
static int8_t send_dns_req(struct dns *dns) static int8_t send_dns_req(struct dns *dns)
{ {
static int write_len; int write_len;
write_len = write(dnsFd, dns->request + dns->sent_len, dns->request_len - dns->sent_len); write_len = write(dnsFd, dns->request + dns->sent_len, dns->request_len - dns->sent_len);
if (write_len == dns->request_len - dns->sent_len) if (write_len == dns->request_len - dns->sent_len)
@ -82,7 +81,7 @@ static int8_t send_dns_req(struct dns *dns)
void dns_query() void dns_query()
{ {
static int16_t i, ret; int16_t i, ret;
for (i = 0; i < MAX_CONNECTION >> 1; i++) for (i = 0; i < MAX_CONNECTION >> 1; i++)
{ {
@ -95,18 +94,19 @@ void dns_query()
close_connection(cts + (i << 1)); close_connection(cts + (i << 1));
} }
} }
if (i == MAX_CONNECTION >> 1) //dnsFd的缓冲区以满
ev.events = EPOLLIN|EPOLLET; if (i < MAX_CONNECTION >> 1)
else {
ev.events = EPOLLIN|EPOLLOUT|EPOLLET; ev.events = EPOLLIN|EPOLLOUT|EPOLLET;
ev.data.fd = dnsFd; ev.data.fd = dnsFd;
epoll_ctl(efd, EPOLL_CTL_MOD, dnsFd, &ev); epoll_ctl(efd, EPOLL_CTL_MOD, dnsFd, &ev);
}
} }
int8_t build_dns_req(struct dns *dns, char *domain) int8_t build_dns_req(struct dns *dns, char *domain)
{ {
static char *p, *_p; char *p, *_p;
static int8_t domain_size; int8_t domain_size;
domain_size = strlen(domain); domain_size = strlen(domain);
p = dns->request + 12; p = dns->request + 12;
@ -128,8 +128,6 @@ int8_t build_dns_req(struct dns *dns, char *domain)
{ {
case 0: case 0:
ev.data.fd = dnsFd; ev.data.fd = dnsFd;
ev.events = EPOLLIN|EPOLLET;
epoll_ctl(efd, EPOLL_CTL_MOD, dnsFd, &ev);
return 0; return 0;
case 1: case 1:

177
http.c
View File

@ -10,61 +10,18 @@ char *local_header, *proxy_header, *ssl_proxy;
int lisFd, proxy_header_len, local_header_len; int lisFd, proxy_header_len, local_header_len;
uint8_t strict_spilce; uint8_t strict_spilce;
/*
glibc的memcpy不能src + len > dst
memcpy低
*/
#ifdef XMEMCPY
typedef struct byte256 {char data[256];} byte256_t;
typedef struct byte64 {char data[64];} byte64_t;
typedef struct byte16 {char data[16];} byte16_t;
static void xmemcpy(char *src, const char *dst, size_t len)
{
static byte256_t *to256, *from256;
static byte64_t *to64, *from64;
static byte16_t *to16, *from16;
to256 = (byte256_t *)src;
from256 = (byte256_t *)dst;
while (len >= sizeof(byte256_t))
{
*to256++ = *from256++;
len -= sizeof(byte256_t);
}
to64 = (byte64_t *)to256;
from64 = (byte64_t *)from256;
while (len >= sizeof(byte64_t))
{
*to64++ = *from64++;
len -= sizeof(byte64_t);
}
to16 = (byte16_t *)to64;
from16 = (byte16_t *)from64;
while (len >= sizeof(byte16_t))
{
*to16++ = *from16++;
len -= sizeof(byte16_t);
}
src = (char *)to16;
dst = (char *)from16;
while (len--)
*src++ = *dst++;
}
#else
#define xmemcpy memcpy
#endif
int8_t connectionToServer(char *ip, conn_t *server) int8_t connectionToServer(char *ip, conn_t *server)
{ {
server->fd = socket(AF_INET, SOCK_STREAM, 0); server->fd = socket(AF_INET, SOCK_STREAM, 0);
if (server->fd < 0) if (server->fd < 0)
return 1; return 1;
fcntl(server->fd, F_SETFL, fcntl(server->fd, F_GETFL)|O_NONBLOCK); fcntl(server->fd, F_SETFL, O_NONBLOCK);
addr.sin_addr.s_addr = inet_addr(ip); addr.sin_addr.s_addr = inet_addr(ip);
addr.sin_port = htons(server->destPort); addr.sin_port = htons(server->destPort);
connect(server->fd, (struct sockaddr *)&addr, sizeof(addr)); if (connect(server->fd, (struct sockaddr *)&addr, sizeof(addr)) != 0 && errno != EINPROGRESS)
return 1;
ev.data.ptr = server; ev.data.ptr = server;
ev.events = EPOLLIN|EPOLLOUT|EPOLLET; ev.events = EPOLLIN|EPOLLOUT|EPOLLERR|EPOLLHUP|EPOLLET;
epoll_ctl(efd, EPOLL_CTL_ADD, server->fd, &ev); epoll_ctl(efd, EPOLL_CTL_ADD, server->fd, &ev);
return 0; return 0;
@ -76,7 +33,7 @@ void close_connection(conn_t *conn)
close(conn->fd); close(conn->fd);
if ((conn - cts) & 1) if ((conn - cts) & 1)
{ {
static char *server_data; char *server_data;
server_data = conn->ready_data; server_data = conn->ready_data;
memset(conn, 0, sizeof(conn_t)); memset(conn, 0, sizeof(conn_t));
@ -85,7 +42,7 @@ void close_connection(conn_t *conn)
} }
else else
{ {
static struct dns *d; struct dns *d;
d = dns_list + ((conn - cts) >> 1); d = dns_list + ((conn - cts) >> 1);
d->request_len = d->sent_len = 0; d->request_len = d->sent_len = 0;
@ -122,8 +79,8 @@ static int8_t request_type(char *data)
static char *read_data(conn_t *in, char *data, int *data_len) static char *read_data(conn_t *in, char *data, int *data_len)
{ {
static char *new_data; char *new_data;
static int read_len; int read_len;
do { do {
new_data = (char *)realloc(data, *data_len + BUFFER_SIZE + 1); new_data = (char *)realloc(data, *data_len + BUFFER_SIZE + 1);
@ -153,12 +110,12 @@ static char *read_data(conn_t *in, char *data, int *data_len)
static char *get_host(char *data) static char *get_host(char *data)
{ {
static char *hostEnd, *host; char *hostEnd, *host;
host = strstr(data, local_header); host = strstr(data, local_header);
if (host != NULL) if (host != NULL)
{ {
static char *local_host; char *local_host;
host += local_header_len; host += local_header_len;
while (*host == ' ') while (*host == ' ')
@ -171,9 +128,8 @@ static char *get_host(char *data)
if (local_host == NULL) if (local_host == NULL)
return NULL; return NULL;
strcpy(local_host, "127.0.0.1:"); strcpy(local_host, "127.0.0.1:");
strncpy(local_host + 10, host, hostEnd - host); memcpy(local_host + 10, host, hostEnd - host);
local_host[10 + (hostEnd - host)] = '\0'; local_host[10 + (hostEnd - host)] = '\0';
puts(local_host);
return local_host; return local_host;
} }
host= strstr(data, proxy_header); host= strstr(data, proxy_header);
@ -192,8 +148,8 @@ static char *get_host(char *data)
/* 删除请求头中的头域 */ /* 删除请求头中的头域 */
static void del_hdr(char *header, int *header_len) static void del_hdr(char *header, int *header_len)
{ {
static char *key_end, *line_begin, *line_end; char *key_end, *line_begin, *line_end;
static int key_len; int key_len;
for (line_begin = strchr(header, '\n'); line_begin++ && *line_begin != '\r'; line_begin = line_end) for (line_begin = strchr(header, '\n'); line_begin++ && *line_begin != '\r'; line_begin = line_end)
{ {
@ -206,7 +162,7 @@ static void del_hdr(char *header, int *header_len)
{ {
if (line_end++) if (line_end++)
{ {
xmemcpy(line_begin, line_end, *header_len - (line_end - header) + 1); memmove(line_begin, line_end, *header_len - (line_end - header) + 1);
(*header_len) -= line_end - line_begin; (*header_len) -= line_end - line_begin;
line_end = line_begin - 1; //新行前一个字符 line_end = line_begin - 1; //新行前一个字符
} }
@ -223,8 +179,8 @@ static void del_hdr(char *header, int *header_len)
/* 构建新请求头 */ /* 构建新请求头 */
static char *build_request(char *client_data, int *data_len, char *host) static char *build_request(char *client_data, int *data_len, char *host)
{ {
static char *uri, *url, *p, *lf, *header, *new_data, *proxy_host; char *uri, *url, *p, *lf, *header, *new_data, *proxy_host;
static int len; int len;
header = client_data; header = client_data;
proxy_host = host; proxy_host = host;
@ -241,14 +197,14 @@ static char *build_request(char *client_data, int *data_len, char *host)
p = lf - 10; //指向HTTP版本前面的空格 p = lf - 10; //指向HTTP版本前面的空格
if (uri != NULL && uri < p) if (uri != NULL && uri < p)
{ {
xmemcpy(url, uri, *data_len - (uri - client_data) + 1); memmove(url, uri, *data_len - (uri - client_data) + 1);
*data_len -= uri - url; *data_len -= uri - url;
lf -= uri - url; lf -= uri - url;
} }
else else
{ {
*url++ = '/'; *url++ = '/';
xmemcpy(url, p, *data_len - (p - client_data) + 1); memmove(url, p, *data_len - (p - client_data) + 1);
*data_len -= p - url; *data_len -= p - url;
lf -= p - url; lf -= p - url;
} }
@ -294,7 +250,7 @@ static char *build_request(char *client_data, int *data_len, char *host)
/* 解析Host */ /* 解析Host */
int8_t parse_host(conn_t *server, char *host) int8_t parse_host(conn_t *server, char *host)
{ {
static char *port, *p; char *port, *p;
port = strchr(host, ':'); port = strchr(host, ':');
if (port) if (port)
@ -323,7 +279,7 @@ static int8_t copy_data(conn_t *ct)
{ {
if (ct->ready_data) if (ct->ready_data)
{ {
static char *new_data; char *new_data;
new_data = (char *)realloc(ct->ready_data, ct->ready_data_len + ct->incomplete_data_len); new_data = (char *)realloc(ct->ready_data, ct->ready_data_len + ct->incomplete_data_len);
if (new_data == NULL) if (new_data == NULL)
@ -344,11 +300,31 @@ static int8_t copy_data(conn_t *ct)
return 0; return 0;
} }
/* 判断请求是否为长连接 */
static int is_keepAlive(char *header)
{
char *ConnectionValue;
ConnectionValue = strstr(header, "\nConnection: ");
if (ConnectionValue)
{
ConnectionValue += 13;
if (*ConnectionValue == 'C' || *ConnectionValue == 'c')
return 0;
else
return 1;
}
if (strstr(header, "HTTP/1.1"))
return 1;
return 0;
}
static void serverToClient(conn_t *server) static void serverToClient(conn_t *server)
{ {
static conn_t *client; conn_t *client;
static int write_len; int write_len;
errno = 0;
client = server - 1; client = server - 1;
while ((server->ready_data_len = read(server->fd, server->ready_data, BUFFER_SIZE)) > 0) while ((server->ready_data_len = read(server->fd, server->ready_data, BUFFER_SIZE)) > 0)
{ {
@ -364,14 +340,23 @@ static void serverToClient(conn_t *server)
else if (write_len < server->ready_data_len) else if (write_len < server->ready_data_len)
{ {
server->sent_len = write_len; server->sent_len = write_len;
ev.events = EPOLLIN|EPOLLOUT|EPOLLET; ev.events = EPOLLIN|EPOLLOUT|EPOLLERR|EPOLLHUP|EPOLLET;
ev.data.ptr = client; ev.data.ptr = client;
epoll_ctl(efd, EPOLL_CTL_MOD, client->fd, &ev); epoll_ctl(efd, EPOLL_CTL_MOD, client->fd, &ev);
return; return;
} }
/* 判断服务端是否close */
if (client->request_type == HTTP_TYPE && client->is_ssl == 0)
{
server->ready_data[server->ready_data_len] = '\0';
if (strncmp(server->ready_data, "HTTP/1.", 7) == 0)
client->keep_alive = server->keep_alive = is_keepAlive(server->ready_data);
}
if (server->ready_data_len < BUFFER_SIZE)
break;
} }
//判断服务端是否关闭连接 //判断是否关闭连接
if (server->ready_data_len == 0 || errno != EAGAIN) if (server->ready_data_len == 0 || (errno != EAGAIN && errno != 0) || client->keep_alive == 0)
close_connection(server); close_connection(server);
else else
server->ready_data_len = server->sent_len = 0; server->ready_data_len = server->sent_len = 0;
@ -379,8 +364,8 @@ static void serverToClient(conn_t *server)
void tcp_out(conn_t *to) void tcp_out(conn_t *to)
{ {
static conn_t *from; conn_t *from;
static int write_len; int write_len;
if (to->fd == -1) if (to->fd == -1)
return; return;
@ -397,15 +382,14 @@ void tcp_out(conn_t *to)
serverToClient(from); serverToClient(from);
if (from->fd >= 0 && from->ready_data_len == 0) if (from->fd >= 0 && from->ready_data_len == 0)
{ {
ev.events = EPOLLIN|EPOLLET; ev.events = EPOLLIN|EPOLLERR|EPOLLHUP|EPOLLET;
ev.data.ptr = to; ev.data.ptr = to;
epoll_ctl(efd, EPOLL_CTL_MOD, to->fd, &ev); epoll_ctl(efd, EPOLL_CTL_MOD, to->fd, &ev);
} }
return;
} }
else else
{ {
ev.events = EPOLLIN|EPOLLET; ev.events = EPOLLIN|EPOLLERR|EPOLLHUP|EPOLLET;
ev.data.ptr = to; ev.data.ptr = to;
epoll_ctl(efd, EPOLL_CTL_MOD, to->fd, &ev); epoll_ctl(efd, EPOLL_CTL_MOD, to->fd, &ev);
free(from->ready_data); free(from->ready_data);
@ -416,7 +400,7 @@ void tcp_out(conn_t *to)
else if (write_len > 0) else if (write_len > 0)
{ {
from->sent_len += write_len; from->sent_len += write_len;
ev.events = EPOLLIN|EPOLLOUT|EPOLLET; ev.events = EPOLLIN|EPOLLOUT|EPOLLERR|EPOLLHUP|EPOLLET;
ev.data.ptr = to; ev.data.ptr = to;
epoll_ctl(efd, EPOLL_CTL_MOD, to->fd, &ev); epoll_ctl(efd, EPOLL_CTL_MOD, to->fd, &ev);
} }
@ -428,13 +412,11 @@ void tcp_out(conn_t *to)
void tcp_in(conn_t *in) void tcp_in(conn_t *in)
{ {
static int write_len; conn_t *server;
static conn_t *server; char *host, *headerEnd;
static char *host, *headerEnd;
if (in->fd < 0) if (in->fd < 0)
return; return;
//如果in - cts是奇数那么是服务端触发事件 //如果in - cts是奇数那么是服务端触发事件
if ((in - cts) & 1) if ((in - cts) & 1)
{ {
@ -450,7 +432,8 @@ void tcp_in(conn_t *in)
return; return;
} }
server = in + 1; server = in + 1;
if (request_type(in->incomplete_data) == OTHER_TYPE) server->request_type = in->request_type = request_type(in->incomplete_data);
if (in->request_type == OTHER_TYPE)
{ {
//如果是第一次读取数据并且不是HTTP请求的关闭连接。复制数据失败的也关闭连接 //如果是第一次读取数据并且不是HTTP请求的关闭连接。复制数据失败的也关闭连接
if (in->reread_data == 0 || copy_data(in) != 0) if (in->reread_data == 0 || copy_data(in) != 0)
@ -470,6 +453,8 @@ void tcp_in(conn_t *in)
close_connection(in); close_connection(in);
return; return;
} }
/* 判断是否长连接 */
server->keep_alive = in->keep_alive = is_keepAlive(in->incomplete_data);
/* 第一次读取数据 */ /* 第一次读取数据 */
if (in->reread_data == 0) if (in->reread_data == 0)
{ {
@ -482,16 +467,10 @@ void tcp_in(conn_t *in)
} }
if (strstr(in->incomplete_data, ssl_proxy)) if (strstr(in->incomplete_data, ssl_proxy))
{ {
write_len = write(in->fd, SSL_RSP, 39); server->keep_alive = in->keep_alive = 1;
if (write_len == 39) server->is_ssl = in->is_ssl = 1;
{ /* 这时候即使fd是非阻塞也只需要判断返回值是否小于0 */
; if (write(in->fd, SSL_RSP, 39) < 0)
}
else if (write_len > 0)
{
memcpy(server->ready_data, SSL_RSP + write_len, 39 - write_len);
}
else
{ {
free(host); free(host);
close_connection(in); close_connection(in);
@ -501,7 +480,7 @@ void tcp_in(conn_t *in)
if (headerEnd - in->incomplete_data < in->incomplete_data_len) if (headerEnd - in->incomplete_data < in->incomplete_data_len)
{ {
in->incomplete_data_len -= headerEnd - in->incomplete_data; in->incomplete_data_len -= headerEnd - in->incomplete_data;
xmemcpy(in->incomplete_data, headerEnd, in->incomplete_data_len + 1); memmove(in->incomplete_data, headerEnd, in->incomplete_data_len + 1);
if (request_type(in->incomplete_data) == OTHER_TYPE) if (request_type(in->incomplete_data) == OTHER_TYPE)
{ {
copy_data(in); copy_data(in);
@ -528,15 +507,17 @@ void tcp_in(conn_t *in)
} }
//数据处理完毕,可以发送 //数据处理完毕,可以发送
handle_data_complete: handle_data_complete:
//多次读取客户端数据,但是和服务端建立连接 //这个判断是防止 多次读取客户端数据,但是没有和服务端建立连接,导致报错
if (server->fd >= 0) if (server->fd >= 0)
tcp_out(server); tcp_out(server);
} }
void *accept_loop(void *ptr) void *accept_loop(void *ptr)
{ {
struct epoll_event epollEvent;
conn_t *client; conn_t *client;
epollEvent.events = EPOLLIN|EPOLLET;
while (1) while (1)
{ {
/* 偶数为客户端,奇数为服务端 */ /* 偶数为客户端,奇数为服务端 */
@ -548,14 +529,10 @@ void *accept_loop(void *ptr)
sleep(3); sleep(3);
continue; continue;
} }
client->fd = accept(lisFd, (struct sockaddr *)&addr, &addr_len); while ((client->fd = accept(lisFd, (struct sockaddr *)&addr, &addr_len)) < 0);
if (client->fd >= 0) fcntl(client->fd, F_SETFL, O_NONBLOCK);
{ epollEvent.data.ptr = client;
fcntl(client->fd, F_SETFL, fcntl(client->fd, F_GETFL)|O_NONBLOCK); epoll_ctl(efd, EPOLL_CTL_ADD, client->fd, &epollEvent);
ev.data.ptr = client;
ev.events = EPOLLIN|EPOLLET;
epoll_ctl(efd, EPOLL_CTL_ADD, client->fd, &ev);
}
} }
return NULL; return NULL;

3
http.h
View File

@ -8,6 +8,9 @@ typedef struct tcp_connection {
int fd, ready_data_len, incomplete_data_len, sent_len; int fd, ready_data_len, incomplete_data_len, sent_len;
uint16_t destPort; uint16_t destPort;
unsigned reread_data :1; unsigned reread_data :1;
unsigned request_type :1;
unsigned keep_alive :1;
unsigned is_ssl :1;
} conn_t; } conn_t;
extern void create_listen(char *ip, int port); extern void create_listen(char *ip, int port);

30
main.c
View File

@ -6,7 +6,7 @@
#define VERSION "0.1" #define VERSION "0.1"
#define DEFAULT_DNS_IP "114.114.114.114" #define DEFAULT_DNS_IP "114.114.114.114"
struct epoll_event evs[MAX_CONNECTION + 2], ev; struct epoll_event evs[MAX_CONNECTION + 1], ev;
struct sockaddr_in addr; struct sockaddr_in addr;
socklen_t addr_len; socklen_t addr_len;
int efd; int efd;
@ -37,7 +37,7 @@ static void server_loop()
epoll_ctl(efd, EPOLL_CTL_ADD, dnsFd, &ev); epoll_ctl(efd, EPOLL_CTL_ADD, dnsFd, &ev);
while (1) while (1)
{ {
n = epoll_wait(efd, evs, MAX_CONNECTION + 2, -1); n = epoll_wait(efd, evs, MAX_CONNECTION + 1, -1);
while (n-- > 0) while (n-- > 0)
{ {
if (evs[n].data.fd == dnsFd) if (evs[n].data.fd == dnsFd)
@ -49,6 +49,11 @@ static void server_loop()
} }
else else
{ {
if ((evs[n].events & EPOLLERR) || (evs[n].events & EPOLLHUP))
{
if (((conn_t *)evs[n].data.ptr)->fd >= 0)
close_connection((conn_t *)evs[n].data.ptr);
}
if (evs[n].events & EPOLLIN) if (evs[n].events & EPOLLIN)
tcp_in((conn_t *)evs[n].data.ptr); tcp_in((conn_t *)evs[n].data.ptr);
if (evs[n].events & EPOLLOUT) if (evs[n].events & EPOLLOUT)
@ -66,12 +71,6 @@ static void initializate(int argc, char **argv)
addr_len = sizeof(addr); addr_len = sizeof(addr);
lisFd = -1; lisFd = -1;
efd = epoll_create(MAX_CONNECTION + 2);
if (efd < 0)
{
perror("epoll_create");
exit(1);
}
dnsAddr.sin_family = addr.sin_family = AF_INET; dnsAddr.sin_family = addr.sin_family = AF_INET;
//默认dns地址 //默认dns地址
dnsAddr.sin_addr.s_addr = inet_addr(DEFAULT_DNS_IP); dnsAddr.sin_addr.s_addr = inet_addr(DEFAULT_DNS_IP);
@ -96,10 +95,6 @@ static void initializate(int argc, char **argv)
*p = '\0'; *p = '\0';
dnsAddr.sin_port = htons(atoi(p+1)); dnsAddr.sin_port = htons(atoi(p+1));
} }
else
{
dnsAddr.sin_port = htons(53);
}
dnsAddr.sin_addr.s_addr = inet_addr(optarg); dnsAddr.sin_addr.s_addr = inet_addr(optarg);
connect(dnsFd, (struct sockaddr *)&dnsAddr, sizeof(dnsAddr)); connect(dnsFd, (struct sockaddr *)&dnsAddr, sizeof(dnsAddr));
break; break;
@ -198,13 +193,20 @@ static void initializate(int argc, char **argv)
while (workers-- > 1 && fork() == 0) while (workers-- > 1 && fork() == 0)
//子进程中的dnsFd必须重新申请不然epoll监听可能读取到其他进程得到的数据 //子进程中的dnsFd必须重新申请不然epoll监听可能读取到其他进程得到的数据
dns_connect(&dnsAddr); dns_connect(&dnsAddr);
efd = epoll_create(MAX_CONNECTION + 1);
if (efd < 0)
{
perror("epoll_create");
exit(1);
}
} }
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
initializate(argc, argv); initializate(argc, argv);
//if (daemon(1, 1)) if (daemon(1, 1))
if (daemon(1, 0)) //if (daemon(1, 0))
{ {
perror("daemon"); perror("daemon");
return 1; return 1;

2
main.h
View File

@ -19,7 +19,7 @@
#define BUFFER_SIZE 10240 #define BUFFER_SIZE 10240
#define MAX_CONNECTION 1020 #define MAX_CONNECTION 1020
extern struct epoll_event evs[MAX_CONNECTION + 2], ev; extern struct epoll_event evs[MAX_CONNECTION + 1], ev;
extern struct sockaddr_in addr; extern struct sockaddr_in addr;
extern socklen_t addr_len; extern socklen_t addr_len;
extern int efd; extern int efd;