Fix int extract_host(const char *header) function

This commit is contained in:
aixiao 2022-01-12 09:37:15 +08:00
parent 9a28c71d6f
commit 6403ab9004
9 changed files with 148 additions and 89 deletions

0
Makefile Normal file → Executable file
View File

3
README.md Normal file → Executable file
View File

@ -2,8 +2,7 @@
修改自mproxy(https://github.com/examplecode/mproxy), 作为CProxy服务端. 仅代理TCP. 修改自mproxy(https://github.com/examplecode/mproxy), 作为CProxy服务端. 仅代理TCP.
支持客户端IP白名单. 支持客户端IP白名单.
支持IPV4. 支持IPV4/IPV6.
支持IPV6.
## 参数 ## 参数

106
ais.c Normal file → Executable file
View File

@ -70,7 +70,9 @@ enum {
static int io_flag; /* 网络io的一些标志位 */ static int io_flag; /* 网络io的一些标志位 */
static int m_pid; /* 保存主进程id */ static int m_pid; /* 保存主进程id */
void server_loop(int signal, char *conffile); conf *configure = NULL;
void server_loop(int signal, conf *p);
void stop_server(); void stop_server();
void handle_client(int client_sock, struct sockaddr_in client_addr); void handle_client(int client_sock, struct sockaddr_in client_addr);
void forward_header(int destination_sock); void forward_header(int destination_sock);
@ -84,6 +86,7 @@ const char *get_work_mode();
int create_connection6(char *remote_host, int remote_port); int create_connection6(char *remote_host, int remote_port);
int _main(int argc, char *argv[]); int _main(int argc, char *argv[]);
ssize_t readLine(int fd, void *buffer, size_t n) ssize_t readLine(int fd, void *buffer, size_t n)
{ {
ssize_t numRead; ssize_t numRead;
@ -155,6 +158,7 @@ int read_header(int fd, void *buffer)
} }
} }
return 0; return 0;
} }
@ -222,9 +226,11 @@ int extract_host(const char *header)
return 0; return 0;
} else { // 在非 CONNECT 方法中解析主机名称及端口号 } else { // 在非 CONNECT 方法中解析主机名称及端口号
char *p = strstr(header, "Host:"); char *p = strstr(header, "Host:");
if (!p) { char *p0 = strstr(header, "host:");
if (!p && !p0) {
return -1; return -1;
} }
char *p1 = strchr(p, '\n'); char *p1 = strchr(p, '\n');
if (!p1) { if (!p1) {
return -1; return -1;
@ -240,8 +246,8 @@ int extract_host(const char *header)
char *p4 = NULL; char *p4 = NULL;
if (p3) if (p3)
p4 = strchr(p3 + 1, ':'); p4 = strchr(p3 + 1, ':');
{ // IPV6
if (p4 != NULL) { // IPV6 if (p4 != NULL) {
char *p5 = NULL; char *p5 = NULL;
char *p6 = NULL; char *p6 = NULL;
p5 = strchr(header, ' '); p5 = strchr(header, ' ');
@ -292,8 +298,9 @@ int extract_host(const char *header)
return -1; return -1;
} }
}
if (p2 && p2 < p1) { if (p2 && p2 < p1) { // http请求头带端口
int p_len = (int)(p1 - p2 - 1); int p_len = (int)(p1 - p2 - 1);
char s_port[p_len]; char s_port[p_len];
strncpy(s_port, p2 + 1, p_len); strncpy(s_port, p2 + 1, p_len);
@ -303,8 +310,8 @@ int extract_host(const char *header)
int h_len = (int)(p2 - p - 5 - 1); int h_len = (int)(p2 - p - 5 - 1);
strncpy(remote_host, p + 5 + 1, h_len); strncpy(remote_host, p + 5 + 1, h_len);
remote_host[h_len] = '\0'; remote_host[h_len] = '\0';
} else { } else { // http请求头不带端口
int h_len = (int)(p1 - p - 5 - 1 - 1); int h_len = (int)(p1 - p - 5 - 1);
strncpy(remote_host, p + 5 + 1, h_len); strncpy(remote_host, p + 5 + 1, h_len);
remote_host[h_len] = '\0'; remote_host[h_len] = '\0';
remote_port = 80; remote_port = 80;
@ -405,32 +412,32 @@ void handle_client(int client_sock, struct sockaddr_in client_addr)
int is_http_tunnel = 0; int is_http_tunnel = 0;
if (strlen(remote_host) == 0) { /* 未指定远端主机名称从http 请求 HOST 字段中获取 */ if (strlen(remote_host) == 0) { /* 未指定远端主机名称从http 请求 HOST 字段中获取 */
#ifdef DEBUG #ifdef DEBUG
LOG(" ============ handle new client ============\n"); LOG(RED " ============ handle new client ============\n" NONE);
LOG(">>>Header:%s\n", header_buffer); LOG(RED ">>>Header:%s\n" NONE, header_buffer);
#endif #endif
if (read_header(client_sock, header_buffer) < 0) { if (read_header(client_sock, header_buffer) < 0) {
LOG("Read Http header failed\n"); LOG(RED "Read Http header failed\n" NONE);
return; return;
} else { } else {
char *p = strstr(header_buffer, "CONNECT"); /* 判断是否是http 隧道请求 */ char *p = strstr(header_buffer, "CONNECT"); /* 判断是否是http 隧道请求 */
if (p) { if (p) {
LOG("receive CONNECT request\n"); LOG(RED "receive CONNECT request\n" NONE);
is_http_tunnel = 1; is_http_tunnel = 1;
} }
if (strstr(header_buffer, "GET /AIS") > 0) { if (strstr(header_buffer, "GET /AIS") > 0) {
LOG("====== hand AIS info request ===="); LOG(RED "====== hand AIS info request ====" NONE);
hand_mproxy_info_req(client_sock, header_buffer); hand_mproxy_info_req(client_sock, header_buffer);
return; return;
} }
if (extract_host(header_buffer) < 0) { if (extract_host(header_buffer) < 0) {
LOG("Cannot extract host field,bad http protrotol"); LOG(RED "Cannot extract host field,bad http protrotol" NONE);
return; return;
} }
LOG("Host:%s port: %d io_flag:%d\n", remote_host, remote_port, io_flag); LOG(RED "Host:%s port: %d io_flag:%d\n" NONE, remote_host, remote_port, io_flag);
} }
} }
@ -440,7 +447,7 @@ void handle_client(int client_sock, struct sockaddr_in client_addr)
//printf("%d\n", remote_port); //printf("%d\n", remote_port);
if ((remote_sock = create_connection6(remote_host, remote_port)) < 0) { if ((remote_sock = create_connection6(remote_host, remote_port)) < 0) {
LOG("Cannot connect to host [%s:%d]\n", remote_host, remote_port); LOG(RED "Cannot connect to host [%s:%d]\n" NONE, remote_host, remote_port);
return; return;
} }
@ -477,8 +484,8 @@ void forward_header(int destination_sock)
{ {
rewrite_header(); rewrite_header();
#ifdef DEBUG #ifdef DEBUG
LOG("================ The Forward HEAD ================="); LOG(RED "================ The Forward HEAD =================" NONE);
LOG("%s\n", header_buffer); LOG(RED "%s\n" NONE, header_buffer);
#endif #endif
int len = strlen(header_buffer); int len = strlen(header_buffer);
@ -629,7 +636,7 @@ int whitelist(char *client_ip, char (*whitelist_ip)[WHITELIST_IP_NUM])
return 0; return 0;
} }
void server_loop(int signal, char *conffile) void server_loop(int signals, conf *configure)
{ {
int i; int i;
char ipstr[WHITELIST_IP_NUM]; char ipstr[WHITELIST_IP_NUM];
@ -640,27 +647,25 @@ void server_loop(int signal, char *conffile)
socklen_t addrlen6 = sizeof(client_addr6); socklen_t addrlen6 = sizeof(client_addr6);
char whitelist_ip[WHITELIST_IP_NUM][WHITELIST_IP_NUM] = { { 0 }, { 0 } }; char whitelist_ip[WHITELIST_IP_NUM][WHITELIST_IP_NUM] = { { 0 }, { 0 } };
conf *configure = (struct CONF *)malloc(sizeof(struct CONF)); //printf("%s\n", configure->IP_SEGMENT);
read_conf(conffile, configure);
printf("%s\n", configure->IP_SEGMENT);
split_string(configure->IP_SEGMENT, " ", whitelist_ip); split_string(configure->IP_SEGMENT, " ", whitelist_ip);
for (i = 1; i <= WHITELIST_IP_NUM - 1; i++) { for (i = 1; i <= WHITELIST_IP_NUM - 1; i++) {
if (*whitelist_ip[i] != '\0') if (*whitelist_ip[i] != '\0') ;
printf("%s\n", whitelist_ip[i]); //printf("%s\n", whitelist_ip[i]);
} }
while (1) { while (1) {
if (signal == 4) { if (signals == 4) {
client_sock = accept(server_sock, (struct sockaddr *)&client_addr, &addrlen); client_sock = accept(server_sock, (struct sockaddr *)&client_addr, &addrlen);
if (client_sock > 0) { if (client_sock > 0) {
LOG("Client Ip %s Client Port %d\n", inet_ntop(AF_INET, &client_addr.sin_addr.s_addr, ipstr, sizeof(ipstr)), ntohs(client_addr.sin_port)); LOG(RED "Client Ip %s Client Port %d\n" NONE, inet_ntop(AF_INET, &client_addr.sin_addr.s_addr, ipstr, sizeof(ipstr)), ntohs(client_addr.sin_port));
strcpy(client_ip, inet_ntop(AF_INET, &client_addr.sin_addr.s_addr, ipstr, sizeof(ipstr))); // 复制客户端IP到client_ip strcpy(client_ip, inet_ntop(AF_INET, &client_addr.sin_addr.s_addr, ipstr, sizeof(ipstr))); // 复制客户端IP到client_ip
if (configure->IP_RESTRICTION == 1) { if (configure->IP_RESTRICTION == 1) {
if (whitelist(client_ip, whitelist_ip) == 0) { if (whitelist(client_ip, whitelist_ip) == 0) {
LOG("非法IPV4客户端, 拒绝连接\n"); LOG(RED "非法IPV4客户端, 拒绝连接\n" NONE);
continue; continue;
} }
} }
@ -674,15 +679,15 @@ void server_loop(int signal, char *conffile)
close(client_sock); // 关闭父进程 client_sock close(client_sock); // 关闭父进程 client_sock
} }
if (signal == 6) { if (signals == 6) {
client_sock6 = accept(server_sock6, (struct sockaddr *)&client_addr6, &addrlen6); client_sock6 = accept(server_sock6, (struct sockaddr *)&client_addr6, &addrlen6);
if (client_sock6 > 0) { if (client_sock6 > 0) {
LOG("Client Ip %s Client Port %d\n", inet_ntop(AF_INET6, &client_addr6.sin6_addr, ipstr, sizeof(ipstr)), ntohs(client_addr6.sin6_port)); LOG(RED "Client Ip %s Client Port %d\n" NONE, inet_ntop(AF_INET6, &client_addr6.sin6_addr, ipstr, sizeof(ipstr)), ntohs(client_addr6.sin6_port));
strcpy(client_ip, inet_ntop(AF_INET6, &client_addr6.sin6_addr, ipstr, sizeof(ipstr))); // 复制客户端IP到client_ip strcpy(client_ip, inet_ntop(AF_INET6, &client_addr6.sin6_addr, ipstr, sizeof(ipstr))); // 复制客户端IP到client_ip
if (configure->IP_RESTRICTION == 1) { if (configure->IP_RESTRICTION == 1) {
if (whitelist(client_ip, whitelist_ip) == 0) { if (whitelist(client_ip, whitelist_ip) == 0) {
LOG("非法IPV6客户端, 拒绝连接\n"); LOG(RED "非法IPV6客户端, 拒绝连接\n" NONE);
continue; continue;
} }
} }
@ -735,8 +740,7 @@ int create_server_socket(int port)
server_addr.sin_family = AF_INET; server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(port); server_addr.sin_port = htons(port);
server_addr.sin_addr.s_addr = INADDR_ANY; server_addr.sin_addr.s_addr = INADDR_ANY;
if (bind(server_sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) if (bind(server_sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) != 0) {
!= 0) {
return SERVER_BIND_ERROR; return SERVER_BIND_ERROR;
} }
@ -784,11 +788,12 @@ int create_server_socket6(int port)
return server_sock; return server_sock;
} }
void start_server(int SIGNAL, char *conffile) void start_server(int SIGNAL, conf *p)
{ {
//初始化全局变量 //初始化全局变量
header_buffer = (char *)malloc(MAX_HEADER_SIZE); header_buffer = (char *)malloc(MAX_HEADER_SIZE);
// ipv4
if (SIGNAL == 4) { if (SIGNAL == 4) {
if ((server_sock = create_server_socket(local_port)) < 0) { // start server if ((server_sock = create_server_socket(local_port)) < 0) { // start server
LOG("Cannot run server on %d\n", local_port); LOG("Cannot run server on %d\n", local_port);
@ -796,6 +801,7 @@ void start_server(int SIGNAL, char *conffile)
} }
} }
// ipv6
if (SIGNAL == 6) { if (SIGNAL == 6) {
if ((server_sock6 = create_server_socket6(local_port)) < 0) { // start server if ((server_sock6 = create_server_socket6(local_port)) < 0) { // start server
LOG("Cannot run server on %d\n", local_port); LOG("Cannot run server on %d\n", local_port);
@ -803,7 +809,16 @@ void start_server(int SIGNAL, char *conffile)
} }
} }
server_loop(SIGNAL, conffile); server_loop(SIGNAL, p);
}
void signal_free(int signal)
{
printf("PID:%d, signal = %d\n", getpid(), signal);
free(header_buffer);
kill(getpid(), SIGKILL);
exit(0);
} }
int _main(int argc, char *argv[]) int _main(int argc, char *argv[])
@ -820,8 +835,10 @@ int _main(int argc, char *argv[])
char *p = NULL; char *p = NULL;
char *conffile = "./ais.conf"; char *conffile = "./ais.conf";
conf *configure = (struct CONF *)malloc(sizeof(struct CONF)); //conf *configure = (struct CONF *)malloc(sizeof(struct CONF));
configure = (struct CONF *)malloc(sizeof(struct CONF));
read_conf(conffile, configure); read_conf(conffile, configure);
printf("%d\n", configure->local_port); printf("%d\n", configure->local_port);
printf("%s\n", configure->io_flag); printf("%s\n", configure->io_flag);
printf("%d\n", configure->encode); printf("%d\n", configure->encode);
@ -855,6 +872,8 @@ int _main(int argc, char *argv[])
break; break;
case 'c': case 'c':
conffile = optarg; conffile = optarg;
free_conf(configure);
read_conf(conffile, configure);
break; break;
case 'E': case 'E':
io_flag = W_S_ENC; io_flag = W_S_ENC;
@ -892,23 +911,32 @@ int _main(int argc, char *argv[])
} }
} }
printf("sslEncodeCode: %d\n", sslEncodeCode); //printf("sslEncodeCode: %d\n", sslEncodeCode);
get_info(info_buf); get_info(info_buf);
LOG("%s\n", info_buf); LOG(RED "%s\n" NONE, info_buf);
{
signal(SIGTERM, signal_free);
signal(SIGHUP, signal_free);
signal(SIGINT, signal_free);
signal(SIGABRT, signal_free);
signal(SIGILL, signal_free);
signal(SIGSEGV, signal_free);
}
if (fork() == 0) { // IPV4 进程 if (fork() == 0) { // IPV4 进程
start_server(4, conffile); start_server(4, configure);
} }
if (fork() == 0) { // IPV6 进程 if (fork() == 0) { // IPV6 进程
start_server(6, conffile); start_server(6, configure);
} }
free_conf(configure); free_conf(configure);
return 0; return 0;
} }
int main(int argc, char *argv[]) int main(int argc, char *argv[], char **envp)
{ {
return _main(argc, argv); return _main(argc, argv);
} }

1
ais.conf Normal file → Executable file
View File

@ -5,3 +5,4 @@ global {
IP_RESTRICTION = 0; IP_RESTRICTION = 0;
IP_SEGMENT= 127.0.0.1; IP_SEGMENT= 127.0.0.1;
} }

18
ais.h Normal file → Executable file
View File

@ -3,4 +3,22 @@
#define WHITELIST_IP_NUM 2700 #define WHITELIST_IP_NUM 2700
// 字体颜色
#define NONE "\033[m"
#define RED "\033[0;32;31m"
#define LIGHT_RED "\033[1;31m"
#define GREEN "\033[0;32;32m"
#define LIGHT_GREEN "\033[1;32m"
#define BLUE "\033[0;32;34m"
#define LIGHT_BLUE "\033[1;34m"
#define DARY_GRAY "\033[1;30m"
#define CYAN "\033[0;36m"
#define LIGHT_CYAN "\033[1;36m"
#define PURPLE "\033[0;35m"
#define LIGHT_PURPLE "\033[1;35m"
#define BROWN "\033[0;33m"
#define YELLOW "\033[1;33m"
#define LIGHT_GRAY "\033[0;37m"
#define WHITE "\033[1;37m"
#endif #endif

17
conf.c Normal file → Executable file
View File

@ -1,5 +1,16 @@
#include "conf.h" #include "conf.h"
int8_t copy_new_mem(char *src, int src_len, char **dest)
{
*dest = (char *)malloc(src_len + 1);
if (*dest == NULL)
return 1;
memcpy(*dest, src, src_len);
*((*dest) + src_len) = '\0';
return 0;
}
/* 在content中设置变量(var)的首地址,值(val)的位置首地址和末地址,返回下一行指针 */ /* 在content中设置变量(var)的首地址,值(val)的位置首地址和末地址,返回下一行指针 */
static char *set_var_val_lineEnd(char *content, char **var, char **val_begin, char **val_end) static char *set_var_val_lineEnd(char *content, char **var, char **val_begin, char **val_end)
{ {
@ -74,7 +85,8 @@ static void parse_global_module(char *content, conf * p)
if (strcasecmp(var, "io_flag") == 0) { if (strcasecmp(var, "io_flag") == 0) {
val_begin_len = strlen(val_begin) + 1; val_begin_len = strlen(val_begin) + 1;
p->io_flag = (char *)malloc(val_begin_len); //val_begin_len = val_end - val_begin;
p->io_flag = (char *)malloc(val_begin_len + 1);
memset(p->io_flag, 0, val_begin_len); memset(p->io_flag, 0, val_begin_len);
memcpy(p->io_flag, val_begin, val_begin_len); memcpy(p->io_flag, val_begin, val_begin_len);
} }
@ -85,7 +97,8 @@ static void parse_global_module(char *content, conf * p)
if (strcasecmp(var, "IP_SEGMENT") == 0) { if (strcasecmp(var, "IP_SEGMENT") == 0) {
val_begin_len = strlen(val_begin) + 1; val_begin_len = strlen(val_begin) + 1;
p->IP_SEGMENT = (char *)malloc(val_begin_len); //val_begin_len = val_end - val_begin;
p->IP_SEGMENT = (char *)malloc(val_begin_len + 1);
memset(p->IP_SEGMENT, 0, val_begin_len); memset(p->IP_SEGMENT, 0, val_begin_len);
memcpy(p->IP_SEGMENT, val_begin, val_begin_len); memcpy(p->IP_SEGMENT, val_begin, val_begin_len);
} }

0
conf.h Normal file → Executable file
View File

0
stript/start.sh Normal file → Executable file
View File

0
stript/stop.sh Normal file → Executable file
View File