Files
DenyIP-go/main.go
2025-12-23 13:14:53 +08:00

535 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"encoding/json"
"flag"
"fmt"
"log"
"net"
"os"
"os/exec"
"os/signal"
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
)
var BuildDate = "unknown"
func init() {
// 强制使用 Go 的纯用户态 DNS 解析器
net.DefaultResolver = &net.Resolver{
PreferGo: true,
}
}
var (
daemon = flag.Bool("d", false, "守护进程模式")
child = flag.Bool("child", false, "子进程模式")
)
// 全局配置变量
var (
InterfacesList bool
InterfaceName *string
PcapFile *string
Protocol *string
IPSET_NUMBER int
MAX_IPSET_NAME = 100
IPSET_NAME string
// --- 优化部分开始 ---
// 替换 List 为带缓冲的 Channel
// 容量设为 5000足以应对大多数突发流量
IpChannel = make(chan string, 5000)
// 用于入队去重的 Map防止同一个 IP 在处理中时重复入队
// key: IP string, value: struct{}
PendingIPs sync.Map
PendingCount int64
// --- 优化部分结束 ---
ProcessedIPMap = map[string]time.Time{}
ProcessedMutex sync.Mutex
local_ipv4_addr string
)
// --- 核心逻辑优化:生产者 (入队) ---
// PushIPToQueue 安全地将 IP 放入队列
// 该函数应由 startPacketCapture 调用
func PushIPToQueue(ipStr string) {
// 0. 基础校验
if ipStr == "" {
return
}
// 1. 快速检查:如果已经在 ProcessedIPMap (已知的国内IP或白名单),直接忽略
ProcessedMutex.Lock()
_, processed := ProcessedIPMap[ipStr]
ProcessedMutex.Unlock()
if processed {
return
}
// 2. 去重检查:如果已经在 PendingIPs (队列中或正在处理),跳过
// LoadOrStore: 如果 key 存在,返回 true否则写入并返回 false
if _, loaded := PendingIPs.LoadOrStore(ipStr, struct{}{}); loaded {
return
}
// 【计数增加】只有确定要入队时才增加
atomic.AddInt64(&PendingCount, 1)
// 3. 非阻塞入队
select {
case IpChannel <- ipStr:
// 成功入队
default:
// 队列已满,丢弃该包
// 【回滚状态】从 Map 删除,并减少计数
PendingIPs.Delete(ipStr)
atomic.AddInt64(&PendingCount, -1)
// 可选log.Println("警告:处理队列已满,丢弃 IP:", ipStr)
}
}
// --- 核心逻辑优化:消费者 (处理 IP) ---
func processIP(ipStr string) {
// PushIPToQueue 已经拦截了空字符串,这里其实不需要再判断
// 但为了代码健壮性如果必须判断defer 必须在 return 之前定义
// 确保函数结束时从 Pending 状态移除并减少计数
defer func() {
PendingIPs.Delete(ipStr)
atomic.AddInt64(&PendingCount, -1)
}()
if ipStr == "" {
return
}
// 再次检查 ProcessedIPMap (防止排队期间被其他协程处理了)
ProcessedMutex.Lock()
_, processed := ProcessedIPMap[ipStr]
ProcessedMutex.Unlock()
if processed {
return
}
// 检查白名单
// --- 修改开始:使用读写锁检查白名单 ---
whiteListLock.RLock() // 加读锁
_, isWhitelisted := whiteList[ipStr] // 检查是否存在
whiteListLock.RUnlock() // 解读锁
if isWhitelisted {
log.Printf("\033[33m %s 在白名单中, 跳过 \033[0m\n", ipStr)
// 尝试从 ipset 移除
RemoveIPIfInSets("root", MAX_IPSET_NAME, ipStr)
return
}
// --- 修改结束 ---
REGION := "中国 内网"
// 如果 IP 已经在 ipset 中,通常无需处理
if Is_Ip_Ipset(ipStr) == 1 {
//log.Printf("\033[31m %s 已在 ipset 集合中 \033[0m\n", ipStr)
return
}
// 1. 离线库判断 (快速)
region, _ := ip2region(ipStr)
// 如果不包含 "中国" 也不包含 "内网",则判定为疑似国外
if !ContainsPart(region, REGION) {
log.Printf("\033[33m [%s %s] 离线库为国外, 进一步API判断\033[0m\n", ipStr, region)
// 2. 在线 API 判断 (慢速)
position, err := curl_(ipStr)
if err != nil {
log.Printf("获取IP地域出错: %v", err)
return // API 失败暂时跳过,等待下次重试
}
log.Printf("\033[31m [%s %s]\033[0m\n", ipStr, position)
if !ContainsPart(position, REGION) {
// --- 确认为国外 ---
AddIPSet(IPSET_NAME, ipStr)
log.Printf("\033[31m [封禁] 已添加国外 IP: %s \033[0m\n", ipStr)
} else {
// --- 确认为国内 ---
log.Printf("\033[32m %s API 修正为国内, 标记放行\033[0m\n", ipStr)
ProcessedMutex.Lock()
ProcessedIPMap[ipStr] = time.Now()
ProcessedMutex.Unlock()
}
} else {
// 离线库确认为国内,标记放行
ProcessedMutex.Lock()
ProcessedIPMap[ipStr] = time.Now()
ProcessedMutex.Unlock()
}
}
func RunMainProcess() {
log.Println(" 主进程启动...")
WriteLocalAddr()
cmd, err := StartChildProcess()
if err != nil {
log.Fatalf("子进程启动失败: %v", err)
}
IPSET_NUMBER = 0
IPSET_NAME = fmt.Sprintf("root%d", IPSET_NUMBER)
if Is_Name_Ipset(IPSET_NAME) == 0 { // 假设 0 表示不存在
createIPSet(IPSET_NAME)
}
// 1. 启动抓包 (生产者)
go startPacketCapture()
// 2. 启动 Worker Pool (消费者)
// 根据机器性能调整 worker 数量,建议 20-50
numWorkers := 40
log.Printf(" 启动 %d 个并发 Worker 处理 IP...", numWorkers)
for i := 0; i < numWorkers; i++ {
go func(id int) {
for ip := range IpChannel {
processIP(ip)
}
}(i)
}
// 定期保存 Map 数据 (替代原来在循环里保存)
go func() {
ticker := time.NewTicker(1 * time.Minute)
for range ticker.C {
if err := saveMapToFile("cn.json"); err != nil {
log.Printf(" 自动保存 Map 失败: %v", err)
}
}
}()
// 过期清理 ProcessedIPMap
go func() {
ticker := time.NewTicker(1 * time.Minute)
for range ticker.C {
now := time.Now()
ProcessedMutex.Lock()
count := 0
for ip, t := range ProcessedIPMap {
if t.Year() == 1971 {
continue
}
if now.Sub(t) > 30*time.Minute {
delete(ProcessedIPMap, ip)
count++
}
}
ProcessedMutex.Unlock()
if count > 0 {
log.Printf(" 已清理 %d 个过期 ProcessedIPMap 项", count)
}
}
}()
// 白名单刷新
go func() {
for {
time.Sleep(10 * time.Minute)
if err := LoadWhiteList("whitelist.txt"); err != nil {
log.Printf(" 刷新白名单失败: %v", err)
}
}
}()
// 防火墙扩容管理
// 防火墙扩容管理
go func() {
for {
time.Sleep(1 * time.Second)
// 1. 获取当前集合长度,必须处理错误
ipset_len, err := NumIPSet(IPSET_NAME)
if err != nil {
// 如果是因为集合不存在导致的错误,尝试创建它
if Is_Name_Ipset(IPSET_NAME) != 0 {
log.Printf("检测到集合 %s 不存在,正在初始化...", IPSET_NAME)
createIPSet(IPSET_NAME)
iptables_add(IPSET_NAME) // 重点:创建后必须同步添加 iptables 规则
}
continue
}
// 2. 检查是否需要扩容
// 注意ipset 默认 maxelem 是 65536达到这个值 add 操作就会失败
if ipset_len >= 65530 {
log.Printf("\033[31m ipset %s 列表已满 %d准备扩容... \033[0m\n", IPSET_NAME, ipset_len)
IPSET_NUMBER++
if IPSET_NUMBER >= MAX_IPSET_NAME {
log.Printf("\033[31m 警告:已达到最大集合数量限制!!! \033[0m\n")
// 这里可以根据需求决定是否 return 或采取其他措施
}
newSetName := fmt.Sprintf("root%d", IPSET_NUMBER)
// 3. 只有不存在时才创建 (注意判断逻辑: != 0 表示不存在)
if Is_Name_Ipset(newSetName) != 0 {
log.Printf("\033[32m 正在创建并应用新集合: %s \033[0m\n", newSetName)
if err := createIPSet(newSetName); err == nil {
// 4. 关键:创建新集合后,必须立刻将其加入 iptables 拦截规则
iptables_add(newSetName)
// 5. 切换全局变量
IPSET_NAME = newSetName
} else {
log.Printf("创建集合失败: %v", err)
}
} else {
// 如果集合已经存在(可能是上次运行留下的),直接切换过去
IPSET_NAME = newSetName
iptables_add(newSetName) // 确保规则存在
}
}
}
}()
// 打印日志
go func() {
for {
time.Sleep(7 * time.Second)
// 获取 ipset 数量
ipset_len, _ := NumIPSet(IPSET_NAME)
// 安全地获取 ProcessedIPMap 的长度(不要打印内容,不要在无锁状态下读取!)
ProcessedMutex.Lock()
processedLen := len(ProcessedIPMap)
ProcessedMutex.Unlock()
// 获取 Pending 数量
pendingCount := atomic.LoadInt64(&PendingCount)
log.Printf("\033[32m [状态监控] IPSet(%s): %d | 已处理缓存: %d | 待处理积压: %d \033[0m\n",
IPSET_NAME, ipset_len, processedLen, pendingCount)
}
}()
waitForSignalAndCleanUp(cmd)
}
func StartChildProcess() (*exec.Cmd, error) {
args := []string{}
for _, arg := range os.Args[1:] {
if !strings.HasPrefix(arg, "-child") {
args = append(args, arg)
}
}
args = append(args, "-child=true")
cmd := exec.Command(os.Args[0], args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("启动子进程失败: %w", err)
}
log.Printf(" 子进程已启动, PID: %d\n", cmd.Process.Pid)
return cmd, nil
}
func StopChildProcess(cmd *exec.Cmd) error {
if cmd == nil || cmd.Process == nil {
return fmt.Errorf("子进程无效")
}
if err := cmd.Process.Signal(syscall.SIGTERM); err != nil {
return fmt.Errorf("SIGTERM 失败: %w", err)
}
done := make(chan error, 1)
go func() { done <- cmd.Wait() }()
select {
case err := <-done:
return err
case <-time.After(2 * time.Second):
_ = cmd.Process.Kill()
}
return nil
}
func waitForSignalAndCleanUp(cmd *exec.Cmd) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
sig := <-sigChan
fmt.Printf("主进程收到信号: %v\n", sig)
if cmd != nil && cmd.Process != nil {
_ = cmd.Process.Signal(sig)
}
StopChildProcess(cmd)
saveMapToFile("cn.json")
}
func StartDaemon() {
args := []string{}
for _, arg := range os.Args[1:] {
if !strings.HasPrefix(arg, "-d") && !strings.HasPrefix(arg, "-child") {
args = append(args, arg)
}
}
args = append(args, "-d=false", "-child=false")
cmd := exec.Command(os.Args[0], args...)
cmd.Start()
os.Exit(0)
}
func RunChildProcess() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
}
func saveMapToFile(filePath string) error {
ProcessedMutex.Lock()
defer ProcessedMutex.Unlock()
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
return json.NewEncoder(file).Encode(ProcessedIPMap)
}
func loadFromFile(filePath string) error {
ProcessedMutex.Lock()
defer ProcessedMutex.Unlock()
file, err := os.Open(filePath)
if err != nil {
return nil
}
defer file.Close()
return json.NewDecoder(file).Decode(&ProcessedIPMap)
}
func InitMap() {
loadFromFile("cn.json")
}
func WriteLocalAddr() {
local_ipv4_addr = GetLocalIpv4Addr()
if local_ipv4_addr != "NULL" {
ProcessedMutex.Lock()
ProcessedIPMap[local_ipv4_addr] = time.Date(1971, 1, 1, 0, 0, 0, 0, time.UTC)
ProcessedMutex.Unlock()
}
}
func HandleCmd() {
// 定义命令行标志
var instruction string
var help bool
InterfaceName = flag.String("i", "", "指定要使用的网络接口")
flag.BoolVar(&InterfacesList, "l", false, "列出可用的网络接口")
Protocol = flag.String("f", "'tcp' or 'udp' or 'tcp or udp'", "指定 BPF 过滤器")
PcapFile = flag.String("o", "", "保存捕获数据的输出文件(可选)")
flag.StringVar(&instruction, "s", "",
"-s start 启动 Iptables 规则\n"+
"-s stop 停止 Iptables 规则\n"+
"-s list 打印 Iptables 规则\n"+
"-s reload 重启 Iptables 规则")
flag.BoolVar(&help, "h", false, "")
flag.BoolVar(&help, "help", false, "帮助信息")
flag.Parse()
if help {
fmt.Printf(
"\t\tDenyip firewall\n"+
"\tVersion 0.2\n"+
"\tE-mail: aixiao@aixiao.me\n"+
"\tBuild Date: %s\n", BuildDate)
flag.Usage()
fmt.Printf("\n")
os.Exit(0)
}
if instruction != "" {
switch instruction {
case "start":
for i := 0; i < MAX_IPSET_NAME; i++ {
_name := fmt.Sprintf("root%d", i)
iptables_add(_name)
}
os.Exit(0)
case "stop":
for i := 0; i < MAX_IPSET_NAME; i++ {
_name := fmt.Sprintf("root%d", i)
iptables_del(_name)
}
os.Exit(0)
case "r":
fallthrough
case "restart":
fallthrough
case "reload":
for i := 0; i < MAX_IPSET_NAME; i++ {
_name := fmt.Sprintf("root%d", i)
iptables_del(_name)
}
for i := 0; i < MAX_IPSET_NAME; i++ {
_name := fmt.Sprintf("root%d", i)
iptables_add(_name)
}
os.Exit(0)
case "l":
fallthrough
case "list":
_ = iptables_list()
os.Exit(0)
default:
log.Fatalf("未知的操作: %s. 请使用 'start' 或 'stop'.", instruction)
}
}
if InterfacesList {
printAvailableInterfaces()
os.Exit(0)
}
if *InterfaceName == "" {
log.Fatal("请使用 -i 标志指定网络接口,或者使用 -l 列出接口。")
}
}
func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
HandleCmd()
CheckCommandExists("iptables")
CheckCommandExists("ipset")
embed_ip2region()
if *daemon {
StartDaemon()
}
if *child {
RunChildProcess()
return
}
InitMap()
RunMainProcess()
}