535 lines
13 KiB
Go
535 lines
13 KiB
Go
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"flag"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"os"
|
||
"os/exec"
|
||
"os/signal"
|
||
"runtime"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"syscall"
|
||
"time"
|
||
)
|
||
|
||
var BuildDate = "unknown"
|
||
|
||
func init() {
|
||
// 强制使用 Go 的纯用户态 DNS 解析器
|
||
net.DefaultResolver = &net.Resolver{
|
||
PreferGo: true,
|
||
}
|
||
}
|
||
|
||
var (
|
||
daemon = flag.Bool("d", false, "守护进程模式")
|
||
child = flag.Bool("child", false, "子进程模式")
|
||
)
|
||
|
||
// 全局配置变量
|
||
var (
|
||
InterfacesList bool
|
||
InterfaceName *string
|
||
PcapFile *string
|
||
Protocol *string
|
||
|
||
IPSET_NUMBER int
|
||
MAX_IPSET_NAME = 100
|
||
IPSET_NAME string
|
||
|
||
// --- 优化部分开始 ---
|
||
|
||
// 替换 List 为带缓冲的 Channel
|
||
// 容量设为 5000,足以应对大多数突发流量
|
||
IpChannel = make(chan string, 5000)
|
||
|
||
// 用于入队去重的 Map,防止同一个 IP 在处理中时重复入队
|
||
// key: IP string, value: struct{}
|
||
PendingIPs sync.Map
|
||
PendingCount int64
|
||
|
||
// --- 优化部分结束 ---
|
||
|
||
ProcessedIPMap = map[string]time.Time{}
|
||
ProcessedMutex sync.Mutex
|
||
|
||
local_ipv4_addr string
|
||
)
|
||
|
||
// --- 核心逻辑优化:生产者 (入队) ---
|
||
// PushIPToQueue 安全地将 IP 放入队列
|
||
// 该函数应由 startPacketCapture 调用
|
||
func PushIPToQueue(ipStr string) {
|
||
// 0. 基础校验
|
||
if ipStr == "" {
|
||
return
|
||
}
|
||
|
||
// 1. 快速检查:如果已经在 ProcessedIPMap (已知的国内IP或白名单),直接忽略
|
||
ProcessedMutex.Lock()
|
||
_, processed := ProcessedIPMap[ipStr]
|
||
ProcessedMutex.Unlock()
|
||
if processed {
|
||
return
|
||
}
|
||
|
||
// 2. 去重检查:如果已经在 PendingIPs (队列中或正在处理),跳过
|
||
// LoadOrStore: 如果 key 存在,返回 true;否则写入并返回 false
|
||
if _, loaded := PendingIPs.LoadOrStore(ipStr, struct{}{}); loaded {
|
||
return
|
||
}
|
||
|
||
// 【计数增加】只有确定要入队时才增加
|
||
atomic.AddInt64(&PendingCount, 1)
|
||
|
||
// 3. 非阻塞入队
|
||
select {
|
||
case IpChannel <- ipStr:
|
||
// 成功入队
|
||
default:
|
||
// 队列已满,丢弃该包
|
||
// 【回滚状态】从 Map 删除,并减少计数
|
||
PendingIPs.Delete(ipStr)
|
||
atomic.AddInt64(&PendingCount, -1)
|
||
// 可选:log.Println("警告:处理队列已满,丢弃 IP:", ipStr)
|
||
}
|
||
}
|
||
|
||
// --- 核心逻辑优化:消费者 (处理 IP) ---
|
||
func processIP(ipStr string) {
|
||
// PushIPToQueue 已经拦截了空字符串,这里其实不需要再判断
|
||
// 但为了代码健壮性,如果必须判断,defer 必须在 return 之前定义
|
||
|
||
// 确保函数结束时从 Pending 状态移除并减少计数
|
||
defer func() {
|
||
PendingIPs.Delete(ipStr)
|
||
atomic.AddInt64(&PendingCount, -1)
|
||
}()
|
||
|
||
if ipStr == "" {
|
||
return
|
||
}
|
||
|
||
// 再次检查 ProcessedIPMap (防止排队期间被其他协程处理了)
|
||
ProcessedMutex.Lock()
|
||
_, processed := ProcessedIPMap[ipStr]
|
||
ProcessedMutex.Unlock()
|
||
if processed {
|
||
return
|
||
}
|
||
|
||
// 检查白名单
|
||
// --- 修改开始:使用读写锁检查白名单 ---
|
||
whiteListLock.RLock() // 加读锁
|
||
_, isWhitelisted := whiteList[ipStr] // 检查是否存在
|
||
whiteListLock.RUnlock() // 解读锁
|
||
|
||
if isWhitelisted {
|
||
log.Printf("\033[33m %s 在白名单中, 跳过 \033[0m\n", ipStr)
|
||
// 尝试从 ipset 移除
|
||
RemoveIPIfInSets("root", MAX_IPSET_NAME, ipStr)
|
||
return
|
||
}
|
||
// --- 修改结束 ---
|
||
|
||
REGION := "中国 内网"
|
||
|
||
// 如果 IP 已经在 ipset 中,通常无需处理
|
||
if Is_Ip_Ipset(ipStr) == 1 {
|
||
//log.Printf("\033[31m %s 已在 ipset 集合中 \033[0m\n", ipStr)
|
||
return
|
||
}
|
||
|
||
// 1. 离线库判断 (快速)
|
||
region, _ := ip2region(ipStr)
|
||
|
||
// 如果不包含 "中国" 也不包含 "内网",则判定为疑似国外
|
||
if !ContainsPart(region, REGION) {
|
||
log.Printf("\033[33m [%s %s] 离线库为国外, 进一步API判断\033[0m\n", ipStr, region)
|
||
|
||
// 2. 在线 API 判断 (慢速)
|
||
position, err := curl_(ipStr)
|
||
if err != nil {
|
||
log.Printf("获取IP地域出错: %v", err)
|
||
return // API 失败暂时跳过,等待下次重试
|
||
}
|
||
|
||
log.Printf("\033[31m [%s %s]\033[0m\n", ipStr, position)
|
||
|
||
if !ContainsPart(position, REGION) {
|
||
// --- 确认为国外 ---
|
||
AddIPSet(IPSET_NAME, ipStr)
|
||
log.Printf("\033[31m [封禁] 已添加国外 IP: %s \033[0m\n", ipStr)
|
||
} else {
|
||
// --- 确认为国内 ---
|
||
log.Printf("\033[32m %s API 修正为国内, 标记放行\033[0m\n", ipStr)
|
||
ProcessedMutex.Lock()
|
||
ProcessedIPMap[ipStr] = time.Now()
|
||
ProcessedMutex.Unlock()
|
||
}
|
||
} else {
|
||
// 离线库确认为国内,标记放行
|
||
ProcessedMutex.Lock()
|
||
ProcessedIPMap[ipStr] = time.Now()
|
||
ProcessedMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
func RunMainProcess() {
|
||
log.Println(" 主进程启动...")
|
||
|
||
WriteLocalAddr()
|
||
|
||
cmd, err := StartChildProcess()
|
||
if err != nil {
|
||
log.Fatalf("子进程启动失败: %v", err)
|
||
}
|
||
|
||
IPSET_NUMBER = 0
|
||
IPSET_NAME = fmt.Sprintf("root%d", IPSET_NUMBER)
|
||
if Is_Name_Ipset(IPSET_NAME) == 0 { // 假设 0 表示不存在
|
||
createIPSet(IPSET_NAME)
|
||
}
|
||
|
||
// 1. 启动抓包 (生产者)
|
||
go startPacketCapture()
|
||
|
||
// 2. 启动 Worker Pool (消费者)
|
||
// 根据机器性能调整 worker 数量,建议 20-50
|
||
numWorkers := 40
|
||
log.Printf(" 启动 %d 个并发 Worker 处理 IP...", numWorkers)
|
||
for i := 0; i < numWorkers; i++ {
|
||
go func(id int) {
|
||
for ip := range IpChannel {
|
||
processIP(ip)
|
||
}
|
||
}(i)
|
||
}
|
||
|
||
// 定期保存 Map 数据 (替代原来在循环里保存)
|
||
go func() {
|
||
ticker := time.NewTicker(1 * time.Minute)
|
||
for range ticker.C {
|
||
if err := saveMapToFile("cn.json"); err != nil {
|
||
log.Printf(" 自动保存 Map 失败: %v", err)
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 过期清理 ProcessedIPMap
|
||
go func() {
|
||
ticker := time.NewTicker(1 * time.Minute)
|
||
for range ticker.C {
|
||
now := time.Now()
|
||
ProcessedMutex.Lock()
|
||
count := 0
|
||
for ip, t := range ProcessedIPMap {
|
||
if t.Year() == 1971 {
|
||
continue
|
||
}
|
||
if now.Sub(t) > 30*time.Minute {
|
||
delete(ProcessedIPMap, ip)
|
||
count++
|
||
}
|
||
}
|
||
ProcessedMutex.Unlock()
|
||
if count > 0 {
|
||
log.Printf(" 已清理 %d 个过期 ProcessedIPMap 项", count)
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 白名单刷新
|
||
go func() {
|
||
for {
|
||
time.Sleep(10 * time.Minute)
|
||
if err := LoadWhiteList("whitelist.txt"); err != nil {
|
||
log.Printf(" 刷新白名单失败: %v", err)
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 防火墙扩容管理
|
||
// 防火墙扩容管理
|
||
go func() {
|
||
for {
|
||
time.Sleep(1 * time.Second)
|
||
|
||
// 1. 获取当前集合长度,必须处理错误
|
||
ipset_len, err := NumIPSet(IPSET_NAME)
|
||
if err != nil {
|
||
// 如果是因为集合不存在导致的错误,尝试创建它
|
||
if Is_Name_Ipset(IPSET_NAME) != 0 {
|
||
log.Printf("检测到集合 %s 不存在,正在初始化...", IPSET_NAME)
|
||
createIPSet(IPSET_NAME)
|
||
iptables_add(IPSET_NAME) // 重点:创建后必须同步添加 iptables 规则
|
||
}
|
||
continue
|
||
}
|
||
|
||
// 2. 检查是否需要扩容
|
||
// 注意:ipset 默认 maxelem 是 65536,达到这个值 add 操作就会失败
|
||
if ipset_len >= 65530 {
|
||
log.Printf("\033[31m ipset %s 列表已满 %d,准备扩容... \033[0m\n", IPSET_NAME, ipset_len)
|
||
|
||
IPSET_NUMBER++
|
||
if IPSET_NUMBER >= MAX_IPSET_NAME {
|
||
log.Printf("\033[31m 警告:已达到最大集合数量限制!!! \033[0m\n")
|
||
// 这里可以根据需求决定是否 return 或采取其他措施
|
||
}
|
||
|
||
newSetName := fmt.Sprintf("root%d", IPSET_NUMBER)
|
||
|
||
// 3. 只有不存在时才创建 (注意判断逻辑: != 0 表示不存在)
|
||
if Is_Name_Ipset(newSetName) != 0 {
|
||
log.Printf("\033[32m 正在创建并应用新集合: %s \033[0m\n", newSetName)
|
||
if err := createIPSet(newSetName); err == nil {
|
||
// 4. 关键:创建新集合后,必须立刻将其加入 iptables 拦截规则
|
||
iptables_add(newSetName)
|
||
// 5. 切换全局变量
|
||
IPSET_NAME = newSetName
|
||
} else {
|
||
log.Printf("创建集合失败: %v", err)
|
||
}
|
||
} else {
|
||
// 如果集合已经存在(可能是上次运行留下的),直接切换过去
|
||
IPSET_NAME = newSetName
|
||
iptables_add(newSetName) // 确保规则存在
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 打印日志
|
||
go func() {
|
||
for {
|
||
time.Sleep(7 * time.Second)
|
||
|
||
// 获取 ipset 数量
|
||
ipset_len, _ := NumIPSet(IPSET_NAME)
|
||
|
||
// 安全地获取 ProcessedIPMap 的长度(不要打印内容,不要在无锁状态下读取!)
|
||
ProcessedMutex.Lock()
|
||
processedLen := len(ProcessedIPMap)
|
||
ProcessedMutex.Unlock()
|
||
|
||
// 获取 Pending 数量
|
||
pendingCount := atomic.LoadInt64(&PendingCount)
|
||
|
||
log.Printf("\033[32m [状态监控] IPSet(%s): %d | 已处理缓存: %d | 待处理积压: %d \033[0m\n",
|
||
IPSET_NAME, ipset_len, processedLen, pendingCount)
|
||
}
|
||
}()
|
||
|
||
waitForSignalAndCleanUp(cmd)
|
||
}
|
||
|
||
func StartChildProcess() (*exec.Cmd, error) {
|
||
args := []string{}
|
||
for _, arg := range os.Args[1:] {
|
||
if !strings.HasPrefix(arg, "-child") {
|
||
args = append(args, arg)
|
||
}
|
||
}
|
||
args = append(args, "-child=true")
|
||
cmd := exec.Command(os.Args[0], args...)
|
||
cmd.Stdout = os.Stdout
|
||
cmd.Stderr = os.Stderr
|
||
|
||
if err := cmd.Start(); err != nil {
|
||
return nil, fmt.Errorf("启动子进程失败: %w", err)
|
||
}
|
||
log.Printf(" 子进程已启动, PID: %d\n", cmd.Process.Pid)
|
||
return cmd, nil
|
||
}
|
||
|
||
func StopChildProcess(cmd *exec.Cmd) error {
|
||
if cmd == nil || cmd.Process == nil {
|
||
return fmt.Errorf("子进程无效")
|
||
}
|
||
if err := cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||
return fmt.Errorf("SIGTERM 失败: %w", err)
|
||
}
|
||
done := make(chan error, 1)
|
||
go func() { done <- cmd.Wait() }()
|
||
select {
|
||
case err := <-done:
|
||
return err
|
||
case <-time.After(2 * time.Second):
|
||
_ = cmd.Process.Kill()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func waitForSignalAndCleanUp(cmd *exec.Cmd) {
|
||
sigChan := make(chan os.Signal, 1)
|
||
signal.Notify(sigChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||
sig := <-sigChan
|
||
fmt.Printf("主进程收到信号: %v\n", sig)
|
||
if cmd != nil && cmd.Process != nil {
|
||
_ = cmd.Process.Signal(sig)
|
||
}
|
||
StopChildProcess(cmd)
|
||
saveMapToFile("cn.json")
|
||
}
|
||
|
||
func StartDaemon() {
|
||
args := []string{}
|
||
for _, arg := range os.Args[1:] {
|
||
if !strings.HasPrefix(arg, "-d") && !strings.HasPrefix(arg, "-child") {
|
||
args = append(args, arg)
|
||
}
|
||
}
|
||
args = append(args, "-d=false", "-child=false")
|
||
cmd := exec.Command(os.Args[0], args...)
|
||
cmd.Start()
|
||
os.Exit(0)
|
||
}
|
||
|
||
func RunChildProcess() {
|
||
sigChan := make(chan os.Signal, 1)
|
||
signal.Notify(sigChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
|
||
<-sigChan
|
||
}
|
||
|
||
func saveMapToFile(filePath string) error {
|
||
ProcessedMutex.Lock()
|
||
defer ProcessedMutex.Unlock()
|
||
file, err := os.Create(filePath)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer file.Close()
|
||
return json.NewEncoder(file).Encode(ProcessedIPMap)
|
||
}
|
||
|
||
func loadFromFile(filePath string) error {
|
||
ProcessedMutex.Lock()
|
||
defer ProcessedMutex.Unlock()
|
||
file, err := os.Open(filePath)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
defer file.Close()
|
||
return json.NewDecoder(file).Decode(&ProcessedIPMap)
|
||
}
|
||
|
||
func InitMap() {
|
||
loadFromFile("cn.json")
|
||
}
|
||
|
||
func WriteLocalAddr() {
|
||
local_ipv4_addr = GetLocalIpv4Addr()
|
||
if local_ipv4_addr != "NULL" {
|
||
ProcessedMutex.Lock()
|
||
ProcessedIPMap[local_ipv4_addr] = time.Date(1971, 1, 1, 0, 0, 0, 0, time.UTC)
|
||
ProcessedMutex.Unlock()
|
||
}
|
||
}
|
||
|
||
func HandleCmd() {
|
||
// 定义命令行标志
|
||
var instruction string
|
||
var help bool
|
||
InterfaceName = flag.String("i", "", "指定要使用的网络接口")
|
||
flag.BoolVar(&InterfacesList, "l", false, "列出可用的网络接口")
|
||
Protocol = flag.String("f", "'tcp' or 'udp' or 'tcp or udp'", "指定 BPF 过滤器")
|
||
PcapFile = flag.String("o", "", "保存捕获数据的输出文件(可选)")
|
||
flag.StringVar(&instruction, "s", "",
|
||
"-s start 启动 Iptables 规则\n"+
|
||
"-s stop 停止 Iptables 规则\n"+
|
||
"-s list 打印 Iptables 规则\n"+
|
||
"-s reload 重启 Iptables 规则")
|
||
flag.BoolVar(&help, "h", false, "")
|
||
flag.BoolVar(&help, "help", false, "帮助信息")
|
||
flag.Parse()
|
||
|
||
if help {
|
||
fmt.Printf(
|
||
"\t\tDenyip firewall\n"+
|
||
"\tVersion 0.2\n"+
|
||
"\tE-mail: aixiao@aixiao.me\n"+
|
||
"\tBuild Date: %s\n", BuildDate)
|
||
|
||
flag.Usage()
|
||
fmt.Printf("\n")
|
||
os.Exit(0)
|
||
}
|
||
|
||
if instruction != "" {
|
||
switch instruction {
|
||
case "start":
|
||
for i := 0; i < MAX_IPSET_NAME; i++ {
|
||
_name := fmt.Sprintf("root%d", i)
|
||
iptables_add(_name)
|
||
}
|
||
|
||
os.Exit(0)
|
||
case "stop":
|
||
for i := 0; i < MAX_IPSET_NAME; i++ {
|
||
_name := fmt.Sprintf("root%d", i)
|
||
iptables_del(_name)
|
||
}
|
||
|
||
os.Exit(0)
|
||
case "r":
|
||
fallthrough
|
||
case "restart":
|
||
fallthrough
|
||
case "reload":
|
||
for i := 0; i < MAX_IPSET_NAME; i++ {
|
||
_name := fmt.Sprintf("root%d", i)
|
||
iptables_del(_name)
|
||
}
|
||
|
||
for i := 0; i < MAX_IPSET_NAME; i++ {
|
||
_name := fmt.Sprintf("root%d", i)
|
||
iptables_add(_name)
|
||
}
|
||
|
||
os.Exit(0)
|
||
case "l":
|
||
fallthrough
|
||
case "list":
|
||
_ = iptables_list()
|
||
os.Exit(0)
|
||
default:
|
||
log.Fatalf("未知的操作: %s. 请使用 'start' 或 'stop'.", instruction)
|
||
}
|
||
}
|
||
|
||
if InterfacesList {
|
||
printAvailableInterfaces()
|
||
os.Exit(0)
|
||
}
|
||
|
||
if *InterfaceName == "" {
|
||
log.Fatal("请使用 -i 标志指定网络接口,或者使用 -l 列出接口。")
|
||
}
|
||
}
|
||
|
||
func main() {
|
||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||
HandleCmd()
|
||
|
||
CheckCommandExists("iptables")
|
||
CheckCommandExists("ipset")
|
||
embed_ip2region()
|
||
|
||
if *daemon {
|
||
StartDaemon()
|
||
}
|
||
if *child {
|
||
RunChildProcess()
|
||
return
|
||
}
|
||
|
||
InitMap()
|
||
RunMainProcess()
|
||
}
|