package main import ( "encoding/json" "flag" "fmt" "log" "net" "os" "os/exec" "os/signal" "runtime" "strings" "sync" "sync/atomic" "syscall" "time" ) var BuildDate = "unknown" func init() { // 强制使用 Go 的纯用户态 DNS 解析器 net.DefaultResolver = &net.Resolver{ PreferGo: true, } } var ( daemon = flag.Bool("d", false, "守护进程模式") child = flag.Bool("child", false, "子进程模式") ) // 全局配置变量 var ( InterfacesList bool InterfaceName *string PcapFile *string Protocol *string IPSET_NUMBER int MAX_IPSET_NAME = 100 IPSET_NAME string // --- 优化部分开始 --- // 替换 List 为带缓冲的 Channel // 容量设为 5000,足以应对大多数突发流量 IpChannel = make(chan string, 5000) // 用于入队去重的 Map,防止同一个 IP 在处理中时重复入队 // key: IP string, value: struct{} PendingIPs sync.Map PendingCount int64 // --- 优化部分结束 --- ProcessedIPMap = map[string]time.Time{} ProcessedMutex sync.Mutex local_ipv4_addr string ) // --- 核心逻辑优化:生产者 (入队) --- // PushIPToQueue 安全地将 IP 放入队列 // 该函数应由 startPacketCapture 调用 func PushIPToQueue(ipStr string) { // 0. 基础校验 if ipStr == "" { return } // 1. 快速检查:如果已经在 ProcessedIPMap (已知的国内IP或白名单),直接忽略 ProcessedMutex.Lock() _, processed := ProcessedIPMap[ipStr] ProcessedMutex.Unlock() if processed { return } // 2. 去重检查:如果已经在 PendingIPs (队列中或正在处理),跳过 // LoadOrStore: 如果 key 存在,返回 true;否则写入并返回 false if _, loaded := PendingIPs.LoadOrStore(ipStr, struct{}{}); loaded { return } // 【计数增加】只有确定要入队时才增加 atomic.AddInt64(&PendingCount, 1) // 3. 非阻塞入队 select { case IpChannel <- ipStr: // 成功入队 default: // 队列已满,丢弃该包 // 【回滚状态】从 Map 删除,并减少计数 PendingIPs.Delete(ipStr) atomic.AddInt64(&PendingCount, -1) // 可选:log.Println("警告:处理队列已满,丢弃 IP:", ipStr) } } // --- 核心逻辑优化:消费者 (处理 IP) --- func processIP(ipStr string) { // PushIPToQueue 已经拦截了空字符串,这里其实不需要再判断 // 但为了代码健壮性,如果必须判断,defer 必须在 return 之前定义 // 确保函数结束时从 Pending 状态移除并减少计数 defer func() { PendingIPs.Delete(ipStr) atomic.AddInt64(&PendingCount, -1) }() if ipStr == "" { return } // 再次检查 ProcessedIPMap (防止排队期间被其他协程处理了) ProcessedMutex.Lock() _, processed := ProcessedIPMap[ipStr] ProcessedMutex.Unlock() if processed { return } // 检查白名单 // --- 修改开始:使用读写锁检查白名单 --- whiteListLock.RLock() // 加读锁 _, isWhitelisted := whiteList[ipStr] // 检查是否存在 whiteListLock.RUnlock() // 解读锁 if isWhitelisted { log.Printf("\033[33m %s 在白名单中, 跳过 \033[0m\n", ipStr) // 尝试从 ipset 移除 RemoveIPIfInSets("root", MAX_IPSET_NAME, ipStr) return } // --- 修改结束 --- REGION := "中国 内网" // 如果 IP 已经在 ipset 中,通常无需处理 if Is_Ip_Ipset(ipStr) == 1 { //log.Printf("\033[31m %s 已在 ipset 集合中 \033[0m\n", ipStr) return } // 1. 离线库判断 (快速) region, _ := ip2region(ipStr) // 如果不包含 "中国" 也不包含 "内网",则判定为疑似国外 if !ContainsPart(region, REGION) { log.Printf("\033[33m [%s %s] 离线库为国外, 进一步API判断\033[0m\n", ipStr, region) // 2. 在线 API 判断 (慢速) position, err := curl_(ipStr) if err != nil { log.Printf("获取IP地域出错: %v", err) return // API 失败暂时跳过,等待下次重试 } log.Printf("\033[31m [%s %s]\033[0m\n", ipStr, position) if !ContainsPart(position, REGION) { // --- 确认为国外 --- AddIPSet(IPSET_NAME, ipStr) log.Printf("\033[31m [封禁] 已添加国外 IP: %s \033[0m\n", ipStr) } else { // --- 确认为国内 --- log.Printf("\033[32m %s API 修正为国内, 标记放行\033[0m\n", ipStr) ProcessedMutex.Lock() ProcessedIPMap[ipStr] = time.Now() ProcessedMutex.Unlock() } } else { // 离线库确认为国内,标记放行 ProcessedMutex.Lock() ProcessedIPMap[ipStr] = time.Now() ProcessedMutex.Unlock() } } func RunMainProcess() { log.Println(" 主进程启动...") WriteLocalAddr() cmd, err := StartChildProcess() if err != nil { log.Fatalf("子进程启动失败: %v", err) } IPSET_NUMBER = 0 IPSET_NAME = fmt.Sprintf("root%d", IPSET_NUMBER) if Is_Name_Ipset(IPSET_NAME) == 0 { // 假设 0 表示不存在 createIPSet(IPSET_NAME) } // 1. 启动抓包 (生产者) go startPacketCapture() // 2. 启动 Worker Pool (消费者) // 根据机器性能调整 worker 数量,建议 20-50 numWorkers := 40 log.Printf(" 启动 %d 个并发 Worker 处理 IP...", numWorkers) for i := 0; i < numWorkers; i++ { go func(id int) { for ip := range IpChannel { processIP(ip) } }(i) } // 定期保存 Map 数据 (替代原来在循环里保存) go func() { ticker := time.NewTicker(1 * time.Minute) for range ticker.C { if err := saveMapToFile("cn.json"); err != nil { log.Printf(" 自动保存 Map 失败: %v", err) } } }() // 过期清理 ProcessedIPMap go func() { ticker := time.NewTicker(1 * time.Minute) for range ticker.C { now := time.Now() ProcessedMutex.Lock() count := 0 for ip, t := range ProcessedIPMap { if t.Year() == 1971 { continue } if now.Sub(t) > 30*time.Minute { delete(ProcessedIPMap, ip) count++ } } ProcessedMutex.Unlock() if count > 0 { log.Printf(" 已清理 %d 个过期 ProcessedIPMap 项", count) } } }() // 白名单刷新 go func() { for { time.Sleep(10 * time.Minute) if err := LoadWhiteList("whitelist.txt"); err != nil { log.Printf(" 刷新白名单失败: %v", err) } } }() // 防火墙扩容管理 // 防火墙扩容管理 go func() { for { time.Sleep(1 * time.Second) // 1. 获取当前集合长度,必须处理错误 ipset_len, err := NumIPSet(IPSET_NAME) if err != nil { // 如果是因为集合不存在导致的错误,尝试创建它 if Is_Name_Ipset(IPSET_NAME) != 0 { log.Printf("检测到集合 %s 不存在,正在初始化...", IPSET_NAME) createIPSet(IPSET_NAME) iptables_add(IPSET_NAME) // 重点:创建后必须同步添加 iptables 规则 } continue } // 2. 检查是否需要扩容 // 注意:ipset 默认 maxelem 是 65536,达到这个值 add 操作就会失败 if ipset_len >= 65530 { log.Printf("\033[31m ipset %s 列表已满 %d,准备扩容... \033[0m\n", IPSET_NAME, ipset_len) IPSET_NUMBER++ if IPSET_NUMBER >= MAX_IPSET_NAME { log.Printf("\033[31m 警告:已达到最大集合数量限制!!! \033[0m\n") // 这里可以根据需求决定是否 return 或采取其他措施 } newSetName := fmt.Sprintf("root%d", IPSET_NUMBER) // 3. 只有不存在时才创建 (注意判断逻辑: != 0 表示不存在) if Is_Name_Ipset(newSetName) != 0 { log.Printf("\033[32m 正在创建并应用新集合: %s \033[0m\n", newSetName) if err := createIPSet(newSetName); err == nil { // 4. 关键:创建新集合后,必须立刻将其加入 iptables 拦截规则 iptables_add(newSetName) // 5. 切换全局变量 IPSET_NAME = newSetName } else { log.Printf("创建集合失败: %v", err) } } else { // 如果集合已经存在(可能是上次运行留下的),直接切换过去 IPSET_NAME = newSetName iptables_add(newSetName) // 确保规则存在 } } } }() // 打印日志 go func() { for { time.Sleep(7 * time.Second) // 获取 ipset 数量 ipset_len, _ := NumIPSet(IPSET_NAME) // 安全地获取 ProcessedIPMap 的长度(不要打印内容,不要在无锁状态下读取!) ProcessedMutex.Lock() processedLen := len(ProcessedIPMap) ProcessedMutex.Unlock() // 获取 Pending 数量 pendingCount := atomic.LoadInt64(&PendingCount) log.Printf("\033[32m [状态监控] IPSet(%s): %d | 已处理缓存: %d | 待处理积压: %d \033[0m\n", IPSET_NAME, ipset_len, processedLen, pendingCount) } }() waitForSignalAndCleanUp(cmd) } func StartChildProcess() (*exec.Cmd, error) { args := []string{} for _, arg := range os.Args[1:] { if !strings.HasPrefix(arg, "-child") { args = append(args, arg) } } args = append(args, "-child=true") cmd := exec.Command(os.Args[0], args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { return nil, fmt.Errorf("启动子进程失败: %w", err) } log.Printf(" 子进程已启动, PID: %d\n", cmd.Process.Pid) return cmd, nil } func StopChildProcess(cmd *exec.Cmd) error { if cmd == nil || cmd.Process == nil { return fmt.Errorf("子进程无效") } if err := cmd.Process.Signal(syscall.SIGTERM); err != nil { return fmt.Errorf("SIGTERM 失败: %w", err) } done := make(chan error, 1) go func() { done <- cmd.Wait() }() select { case err := <-done: return err case <-time.After(2 * time.Second): _ = cmd.Process.Kill() } return nil } func waitForSignalAndCleanUp(cmd *exec.Cmd) { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) sig := <-sigChan fmt.Printf("主进程收到信号: %v\n", sig) if cmd != nil && cmd.Process != nil { _ = cmd.Process.Signal(sig) } StopChildProcess(cmd) saveMapToFile("cn.json") } func StartDaemon() { args := []string{} for _, arg := range os.Args[1:] { if !strings.HasPrefix(arg, "-d") && !strings.HasPrefix(arg, "-child") { args = append(args, arg) } } args = append(args, "-d=false", "-child=false") cmd := exec.Command(os.Args[0], args...) cmd.Start() os.Exit(0) } func RunChildProcess() { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM) <-sigChan } func saveMapToFile(filePath string) error { ProcessedMutex.Lock() defer ProcessedMutex.Unlock() file, err := os.Create(filePath) if err != nil { return err } defer file.Close() return json.NewEncoder(file).Encode(ProcessedIPMap) } func loadFromFile(filePath string) error { ProcessedMutex.Lock() defer ProcessedMutex.Unlock() file, err := os.Open(filePath) if err != nil { return nil } defer file.Close() return json.NewDecoder(file).Decode(&ProcessedIPMap) } func InitMap() { loadFromFile("cn.json") } func WriteLocalAddr() { local_ipv4_addr = GetLocalIpv4Addr() if local_ipv4_addr != "NULL" { ProcessedMutex.Lock() ProcessedIPMap[local_ipv4_addr] = time.Date(1971, 1, 1, 0, 0, 0, 0, time.UTC) ProcessedMutex.Unlock() } } func HandleCmd() { // 定义命令行标志 var instruction string var help bool InterfaceName = flag.String("i", "", "指定要使用的网络接口") flag.BoolVar(&InterfacesList, "l", false, "列出可用的网络接口") Protocol = flag.String("f", "'tcp' or 'udp' or 'tcp or udp'", "指定 BPF 过滤器") PcapFile = flag.String("o", "", "保存捕获数据的输出文件(可选)") flag.StringVar(&instruction, "s", "", "-s start 启动 Iptables 规则\n"+ "-s stop 停止 Iptables 规则\n"+ "-s list 打印 Iptables 规则\n"+ "-s reload 重启 Iptables 规则") flag.BoolVar(&help, "h", false, "") flag.BoolVar(&help, "help", false, "帮助信息") flag.Parse() if help { fmt.Printf( "\t\tDenyip firewall\n"+ "\tVersion 0.2\n"+ "\tE-mail: aixiao@aixiao.me\n"+ "\tBuild Date: %s\n", BuildDate) flag.Usage() fmt.Printf("\n") os.Exit(0) } if instruction != "" { switch instruction { case "start": for i := 0; i < MAX_IPSET_NAME; i++ { _name := fmt.Sprintf("root%d", i) iptables_add(_name) } os.Exit(0) case "stop": for i := 0; i < MAX_IPSET_NAME; i++ { _name := fmt.Sprintf("root%d", i) iptables_del(_name) } os.Exit(0) case "r": fallthrough case "restart": fallthrough case "reload": for i := 0; i < MAX_IPSET_NAME; i++ { _name := fmt.Sprintf("root%d", i) iptables_del(_name) } for i := 0; i < MAX_IPSET_NAME; i++ { _name := fmt.Sprintf("root%d", i) iptables_add(_name) } os.Exit(0) case "l": fallthrough case "list": _ = iptables_list() os.Exit(0) default: log.Fatalf("未知的操作: %s. 请使用 'start' 或 'stop'.", instruction) } } if InterfacesList { printAvailableInterfaces() os.Exit(0) } if *InterfaceName == "" { log.Fatal("请使用 -i 标志指定网络接口,或者使用 -l 列出接口。") } } func main() { runtime.GOMAXPROCS(runtime.NumCPU()) HandleCmd() CheckCommandExists("iptables") CheckCommandExists("ipset") embed_ip2region() if *daemon { StartDaemon() } if *child { RunChildProcess() return } InitMap() RunMainProcess() }