From 6403ab90042f78e2e7968a7e4494eca16dc6af0e Mon Sep 17 00:00:00 2001 From: aixiao Date: Wed, 12 Jan 2022 09:37:15 +0800 Subject: [PATCH] Fix int extract_host(const char *header) function --- Makefile | 2 +- README.md | 3 +- ais.c | 196 +++++++++++++++++++++++++++--------------------- ais.conf | 1 + ais.h | 18 +++++ conf.c | 17 ++++- conf.h | 0 stript/start.sh | 0 stript/stop.sh | 0 9 files changed, 148 insertions(+), 89 deletions(-) mode change 100644 => 100755 Makefile mode change 100644 => 100755 README.md mode change 100644 => 100755 ais.c mode change 100644 => 100755 ais.conf mode change 100644 => 100755 ais.h mode change 100644 => 100755 conf.c mode change 100644 => 100755 conf.h mode change 100644 => 100755 stript/start.sh mode change 100644 => 100755 stript/stop.sh diff --git a/Makefile b/Makefile old mode 100644 new mode 100755 index c669027..1192da3 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -CROSS_COMPILE ?= +CROSS_COMPILE ?= CC := $(CROSS_COMPILE)gcc STRIP := $(CROSS_COMPILE)strip CFLAGS += -g -O2 -Wall diff --git a/README.md b/README.md old mode 100644 new mode 100755 index b5a351b..663a2c3 --- a/README.md +++ b/README.md @@ -2,8 +2,7 @@ 修改自mproxy(https://github.com/examplecode/mproxy), 作为CProxy服务端. 仅代理TCP. 支持客户端IP白名单. - 支持IPV4. - 支持IPV6. + 支持IPV4/IPV6. ## 参数 diff --git a/ais.c b/ais.c old mode 100644 new mode 100755 index 72b1913..93308dd --- a/ais.c +++ b/ais.c @@ -70,7 +70,9 @@ enum { static int io_flag; /* 网络io的一些标志位 */ static int m_pid; /* 保存主进程id */ -void server_loop(int signal, char *conffile); +conf *configure = NULL; + +void server_loop(int signal, conf *p); void stop_server(); void handle_client(int client_sock, struct sockaddr_in client_addr); void forward_header(int destination_sock); @@ -84,6 +86,7 @@ const char *get_work_mode(); int create_connection6(char *remote_host, int remote_port); int _main(int argc, char *argv[]); + ssize_t readLine(int fd, void *buffer, size_t n) { ssize_t numRead; @@ -155,6 +158,7 @@ int read_header(int fd, void *buffer) } } + return 0; } @@ -188,7 +192,7 @@ int isChar(char *string) // string字符串里有字符时返回1 int extract_host(const char *header) { - char *_p = strstr(header, "CONNECT"); // 在 CONNECT 方法中解析 隧道主机名称及端口号 + char *_p = strstr(header, "CONNECT"); // 在 CONNECT 方法中解析 隧道主机名称及端口号 if (_p) { if (strchr(header, '[') || strchr(header, ']')) { char *_p1 = strchr(header, '['); @@ -222,9 +226,11 @@ int extract_host(const char *header) return 0; } else { // 在非 CONNECT 方法中解析主机名称及端口号 char *p = strstr(header, "Host:"); - if (!p) { + char *p0 = strstr(header, "host:"); + if (!p && !p0) { return -1; } + char *p1 = strchr(p, '\n'); if (!p1) { return -1; @@ -240,60 +246,61 @@ int extract_host(const char *header) char *p4 = NULL; if (p3) p4 = strchr(p3 + 1, ':'); + { // IPV6 + if (p4 != NULL) { + char *p5 = NULL; + char *p6 = NULL; + p5 = strchr(header, ' '); + if (p5) + p6 = strchr(p5 + 1, ' '); - if (p4 != NULL) { // IPV6 - char *p5 = NULL; - char *p6 = NULL; - p5 = strchr(header, ' '); - if (p5) - p6 = strchr(p5 + 1, ' '); - - char url[p6 - p5 - 1]; - memset(url, 0, p6 - p5 - 1); - strncpy(url, p5 + 1, p6 - p5 - 1); - url[p6 - p5 - 1] = '\0'; - if (strstr(url, "http") != NULL) { - memcpy(url, url + 7, strlen(url) - 7); // 去除 'http://' - url[strlen(url) - 7] = '\0'; - char *p7 = strchr(url, '/'); - if (p7) { // 去除 uri - url[p7 - url] = '\0'; - } - printf("url: %s\n", url); - char *p8 = strchr(url, ']'); - if (p8) { - remote_port = atoi(p8 + 2); - strncpy(remote_host, url + 1, strlen(url) - strlen(p8) - 1); - - if (strlen(p8) < 3) { // 如果p8为 ']' 时, 长度为1 - remote_port = 80; + char url[p6 - p5 - 1]; + memset(url, 0, p6 - p5 - 1); + strncpy(url, p5 + 1, p6 - p5 - 1); + url[p6 - p5 - 1] = '\0'; + if (strstr(url, "http") != NULL) { + memcpy(url, url + 7, strlen(url) - 7); // 去除 'http://' + url[strlen(url) - 7] = '\0'; + char *p7 = strchr(url, '/'); + if (p7) { // 去除 uri + url[p7 - url] = '\0'; + } + printf("url: %s\n", url); + char *p8 = strchr(url, ']'); + if (p8) { + remote_port = atoi(p8 + 2); strncpy(remote_host, url + 1, strlen(url) - strlen(p8) - 1); - } - } else { // 不包含'['、']'时 - remote_port = 80; - strcpy(remote_host, url); - } - return 0; - } else { // 头为不规范的url时处理Host - char *_p1 = strchr(s_host, '['); - char *_p2 = strchr(_p1+1, ']'); - if (_p1 && _p2) { - strncpy(remote_host, _p1+1, _p2 - _p1 -1); - remote_port = atoi(_p2+2); - if (strlen(_p2) < 3) { + if (strlen(p8) < 3) { // 如果p8为 ']' 时, 长度为1 + remote_port = 80; + strncpy(remote_host, url + 1, strlen(url) - strlen(p8) - 1); + } + } else { // 不包含'['、']'时 remote_port = 80; - strncpy(remote_host, s_host+1, strlen(s_host)-strlen(_p2)-1); + strcpy(remote_host, url); } - } - - return 0; - } - return -1; + return 0; + } else { // 头为不规范的url时处理Host + char *_p1 = strchr(s_host, '['); + char *_p2 = strchr(_p1+1, ']'); + if (_p1 && _p2) { + strncpy(remote_host, _p1+1, _p2 - _p1 -1); + remote_port = atoi(_p2+2); + if (strlen(_p2) < 3) { + remote_port = 80; + strncpy(remote_host, s_host+1, strlen(s_host)-strlen(_p2)-1); + } + } + + return 0; + } + + return -1; + } } - if (p2 && p2 < p1) { + if (p2 && p2 < p1) { // http请求头带端口 int p_len = (int)(p1 - p2 - 1); char s_port[p_len]; strncpy(s_port, p2 + 1, p_len); @@ -303,8 +310,8 @@ int extract_host(const char *header) int h_len = (int)(p2 - p - 5 - 1); strncpy(remote_host, p + 5 + 1, h_len); remote_host[h_len] = '\0'; - } else { - int h_len = (int)(p1 - p - 5 - 1 - 1); + } else { // http请求头不带端口 + int h_len = (int)(p1 - p - 5 - 1); strncpy(remote_host, p + 5 + 1, h_len); remote_host[h_len] = '\0'; remote_port = 80; @@ -405,32 +412,32 @@ void handle_client(int client_sock, struct sockaddr_in client_addr) int is_http_tunnel = 0; if (strlen(remote_host) == 0) { /* 未指定远端主机名称从http 请求 HOST 字段中获取 */ #ifdef DEBUG - LOG(" ============ handle new client ============\n"); - LOG(">>>Header:%s\n", header_buffer); + LOG(RED " ============ handle new client ============\n" NONE); + LOG(RED ">>>Header:%s\n" NONE, header_buffer); #endif if (read_header(client_sock, header_buffer) < 0) { - LOG("Read Http header failed\n"); + LOG(RED "Read Http header failed\n" NONE); return; } else { char *p = strstr(header_buffer, "CONNECT"); /* 判断是否是http 隧道请求 */ if (p) { - LOG("receive CONNECT request\n"); + LOG(RED "receive CONNECT request\n" NONE); is_http_tunnel = 1; } if (strstr(header_buffer, "GET /AIS") > 0) { - LOG("====== hand AIS info request ===="); + LOG(RED "====== hand AIS info request ====" NONE); hand_mproxy_info_req(client_sock, header_buffer); return; } if (extract_host(header_buffer) < 0) { - LOG("Cannot extract host field,bad http protrotol"); + LOG(RED "Cannot extract host field,bad http protrotol" NONE); return; } - LOG("Host:%s port: %d io_flag:%d\n", remote_host, remote_port, io_flag); + LOG(RED "Host:%s port: %d io_flag:%d\n" NONE, remote_host, remote_port, io_flag); } } @@ -440,7 +447,7 @@ void handle_client(int client_sock, struct sockaddr_in client_addr) //printf("%d\n", remote_port); if ((remote_sock = create_connection6(remote_host, remote_port)) < 0) { - LOG("Cannot connect to host [%s:%d]\n", remote_host, remote_port); + LOG(RED "Cannot connect to host [%s:%d]\n" NONE, remote_host, remote_port); return; } @@ -477,8 +484,8 @@ void forward_header(int destination_sock) { rewrite_header(); #ifdef DEBUG - LOG("================ The Forward HEAD ================="); - LOG("%s\n", header_buffer); + LOG(RED "================ The Forward HEAD =================" NONE); + LOG(RED "%s\n" NONE, header_buffer); #endif int len = strlen(header_buffer); @@ -629,7 +636,7 @@ int whitelist(char *client_ip, char (*whitelist_ip)[WHITELIST_IP_NUM]) return 0; } -void server_loop(int signal, char *conffile) +void server_loop(int signals, conf *configure) { int i; char ipstr[WHITELIST_IP_NUM]; @@ -640,27 +647,25 @@ void server_loop(int signal, char *conffile) socklen_t addrlen6 = sizeof(client_addr6); char whitelist_ip[WHITELIST_IP_NUM][WHITELIST_IP_NUM] = { { 0 }, { 0 } }; - conf *configure = (struct CONF *)malloc(sizeof(struct CONF)); - read_conf(conffile, configure); - printf("%s\n", configure->IP_SEGMENT); + //printf("%s\n", configure->IP_SEGMENT); split_string(configure->IP_SEGMENT, " ", whitelist_ip); for (i = 1; i <= WHITELIST_IP_NUM - 1; i++) { - if (*whitelist_ip[i] != '\0') - printf("%s\n", whitelist_ip[i]); + if (*whitelist_ip[i] != '\0') ; + //printf("%s\n", whitelist_ip[i]); } while (1) { - if (signal == 4) { + if (signals == 4) { client_sock = accept(server_sock, (struct sockaddr *)&client_addr, &addrlen); if (client_sock > 0) { - LOG("Client Ip %s Client Port %d\n", inet_ntop(AF_INET, &client_addr.sin_addr.s_addr, ipstr, sizeof(ipstr)), ntohs(client_addr.sin_port)); + LOG(RED "Client Ip %s Client Port %d\n" NONE, inet_ntop(AF_INET, &client_addr.sin_addr.s_addr, ipstr, sizeof(ipstr)), ntohs(client_addr.sin_port)); strcpy(client_ip, inet_ntop(AF_INET, &client_addr.sin_addr.s_addr, ipstr, sizeof(ipstr))); // 复制客户端IP到client_ip if (configure->IP_RESTRICTION == 1) { if (whitelist(client_ip, whitelist_ip) == 0) { - LOG("非法IPV4客户端, 拒绝连接\n"); + LOG(RED "非法IPV4客户端, 拒绝连接\n" NONE); continue; } } @@ -674,15 +679,15 @@ void server_loop(int signal, char *conffile) close(client_sock); // 关闭父进程 client_sock } - if (signal == 6) { + if (signals == 6) { client_sock6 = accept(server_sock6, (struct sockaddr *)&client_addr6, &addrlen6); if (client_sock6 > 0) { - LOG("Client Ip %s Client Port %d\n", inet_ntop(AF_INET6, &client_addr6.sin6_addr, ipstr, sizeof(ipstr)), ntohs(client_addr6.sin6_port)); + LOG(RED "Client Ip %s Client Port %d\n" NONE, inet_ntop(AF_INET6, &client_addr6.sin6_addr, ipstr, sizeof(ipstr)), ntohs(client_addr6.sin6_port)); strcpy(client_ip, inet_ntop(AF_INET6, &client_addr6.sin6_addr, ipstr, sizeof(ipstr))); // 复制客户端IP到client_ip if (configure->IP_RESTRICTION == 1) { if (whitelist(client_ip, whitelist_ip) == 0) { - LOG("非法IPV6客户端, 拒绝连接\n"); + LOG(RED "非法IPV6客户端, 拒绝连接\n" NONE); continue; } } @@ -735,8 +740,7 @@ int create_server_socket(int port) server_addr.sin_family = AF_INET; server_addr.sin_port = htons(port); server_addr.sin_addr.s_addr = INADDR_ANY; - if (bind(server_sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) - != 0) { + if (bind(server_sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) != 0) { return SERVER_BIND_ERROR; } @@ -784,11 +788,12 @@ int create_server_socket6(int port) return server_sock; } -void start_server(int SIGNAL, char *conffile) +void start_server(int SIGNAL, conf *p) { //初始化全局变量 header_buffer = (char *)malloc(MAX_HEADER_SIZE); + // ipv4 if (SIGNAL == 4) { if ((server_sock = create_server_socket(local_port)) < 0) { // start server LOG("Cannot run server on %d\n", local_port); @@ -796,6 +801,7 @@ void start_server(int SIGNAL, char *conffile) } } + // ipv6 if (SIGNAL == 6) { if ((server_sock6 = create_server_socket6(local_port)) < 0) { // start server LOG("Cannot run server on %d\n", local_port); @@ -803,7 +809,16 @@ void start_server(int SIGNAL, char *conffile) } } - server_loop(SIGNAL, conffile); + server_loop(SIGNAL, p); +} + +void signal_free(int signal) +{ + printf("PID:%d, signal = %d\n", getpid(), signal); + free(header_buffer); + + kill(getpid(), SIGKILL); + exit(0); } int _main(int argc, char *argv[]) @@ -820,8 +835,10 @@ int _main(int argc, char *argv[]) char *p = NULL; char *conffile = "./ais.conf"; - conf *configure = (struct CONF *)malloc(sizeof(struct CONF)); + //conf *configure = (struct CONF *)malloc(sizeof(struct CONF)); + configure = (struct CONF *)malloc(sizeof(struct CONF)); read_conf(conffile, configure); + printf("%d\n", configure->local_port); printf("%s\n", configure->io_flag); printf("%d\n", configure->encode); @@ -855,6 +872,8 @@ int _main(int argc, char *argv[]) break; case 'c': conffile = optarg; + free_conf(configure); + read_conf(conffile, configure); break; case 'E': io_flag = W_S_ENC; @@ -884,7 +903,7 @@ int _main(int argc, char *argv[]) io_flag = R_C_DEC; } sslEncodeCode = configure->encode; - + if (DAEMON == 1) { // 守护进程 if (daemon(1, 1)) { perror("daemon"); @@ -892,23 +911,32 @@ int _main(int argc, char *argv[]) } } - printf("sslEncodeCode: %d\n", sslEncodeCode); + //printf("sslEncodeCode: %d\n", sslEncodeCode); get_info(info_buf); - LOG("%s\n", info_buf); + LOG(RED "%s\n" NONE, info_buf); + + { + signal(SIGTERM, signal_free); + signal(SIGHUP, signal_free); + signal(SIGINT, signal_free); + signal(SIGABRT, signal_free); + signal(SIGILL, signal_free); + signal(SIGSEGV, signal_free); + } if (fork() == 0) { // IPV4 进程 - start_server(4, conffile); + start_server(4, configure); } if (fork() == 0) { // IPV6 进程 - start_server(6, conffile); + start_server(6, configure); } free_conf(configure); return 0; } -int main(int argc, char *argv[]) +int main(int argc, char *argv[], char **envp) { return _main(argc, argv); } diff --git a/ais.conf b/ais.conf old mode 100644 new mode 100755 index 1c0c73e..14797ab --- a/ais.conf +++ b/ais.conf @@ -5,3 +5,4 @@ global { IP_RESTRICTION = 0; IP_SEGMENT= 127.0.0.1; } + diff --git a/ais.h b/ais.h old mode 100644 new mode 100755 index f9f9a45..061eeaf --- a/ais.h +++ b/ais.h @@ -3,4 +3,22 @@ #define WHITELIST_IP_NUM 2700 +// 字体颜色 +#define NONE "\033[m" +#define RED "\033[0;32;31m" +#define LIGHT_RED "\033[1;31m" +#define GREEN "\033[0;32;32m" +#define LIGHT_GREEN "\033[1;32m" +#define BLUE "\033[0;32;34m" +#define LIGHT_BLUE "\033[1;34m" +#define DARY_GRAY "\033[1;30m" +#define CYAN "\033[0;36m" +#define LIGHT_CYAN "\033[1;36m" +#define PURPLE "\033[0;35m" +#define LIGHT_PURPLE "\033[1;35m" +#define BROWN "\033[0;33m" +#define YELLOW "\033[1;33m" +#define LIGHT_GRAY "\033[0;37m" +#define WHITE "\033[1;37m" + #endif diff --git a/conf.c b/conf.c old mode 100644 new mode 100755 index 9d4fd0d..52e62a2 --- a/conf.c +++ b/conf.c @@ -1,5 +1,16 @@ #include "conf.h" +int8_t copy_new_mem(char *src, int src_len, char **dest) +{ + *dest = (char *)malloc(src_len + 1); + if (*dest == NULL) + return 1; + memcpy(*dest, src, src_len); + *((*dest) + src_len) = '\0'; + + return 0; +} + /* 在content中,设置变量(var)的首地址,值(val)的位置首地址和末地址,返回下一行指针 */ static char *set_var_val_lineEnd(char *content, char **var, char **val_begin, char **val_end) { @@ -74,7 +85,8 @@ static void parse_global_module(char *content, conf * p) if (strcasecmp(var, "io_flag") == 0) { val_begin_len = strlen(val_begin) + 1; - p->io_flag = (char *)malloc(val_begin_len); + //val_begin_len = val_end - val_begin; + p->io_flag = (char *)malloc(val_begin_len + 1); memset(p->io_flag, 0, val_begin_len); memcpy(p->io_flag, val_begin, val_begin_len); } @@ -85,7 +97,8 @@ static void parse_global_module(char *content, conf * p) if (strcasecmp(var, "IP_SEGMENT") == 0) { val_begin_len = strlen(val_begin) + 1; - p->IP_SEGMENT = (char *)malloc(val_begin_len); + //val_begin_len = val_end - val_begin; + p->IP_SEGMENT = (char *)malloc(val_begin_len + 1); memset(p->IP_SEGMENT, 0, val_begin_len); memcpy(p->IP_SEGMENT, val_begin, val_begin_len); } diff --git a/conf.h b/conf.h old mode 100644 new mode 100755 diff --git a/stript/start.sh b/stript/start.sh old mode 100644 new mode 100755 diff --git a/stript/stop.sh b/stript/stop.sh old mode 100644 new mode 100755