mirror of
https://github.com/Buriburizaem0n/nezha_domains.git
synced 2026-02-04 04:30:05 +00:00
[agent] splitting the agent into separate repositories
This commit is contained in:
@@ -1,589 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
bpc "github.com/DaRealFreak/cloudflare-bp-go"
|
||||
"github.com/blang/semver"
|
||||
"github.com/go-ping/ping"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/nezhahq/go-github-selfupdate/selfupdate"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/host"
|
||||
psnet "github.com/shirou/gopsutil/v3/net"
|
||||
flag "github.com/spf13/pflag"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/naiba/nezha/cmd/agent/monitor"
|
||||
"github.com/naiba/nezha/cmd/agent/processgroup"
|
||||
"github.com/naiba/nezha/cmd/agent/pty"
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
pb "github.com/naiba/nezha/proto"
|
||||
"github.com/naiba/nezha/service/rpc"
|
||||
)
|
||||
|
||||
type AgentCliParam struct {
|
||||
SkipConnectionCount bool // 跳过连接数检查
|
||||
SkipProcsCount bool // 跳过进程数量检查
|
||||
DisableAutoUpdate bool // 关闭自动更新
|
||||
DisableForceUpdate bool // 关闭强制更新
|
||||
DisableCommandExecute bool // 关闭命令执行
|
||||
Debug bool // debug模式
|
||||
Server string // 服务器地址
|
||||
ClientSecret string // 客户端密钥
|
||||
ReportDelay int // 报告间隔
|
||||
TLS bool // 是否使用TLS加密传输至服务端
|
||||
}
|
||||
|
||||
var (
|
||||
version string
|
||||
arch string
|
||||
client pb.NezhaServiceClient
|
||||
inited bool
|
||||
)
|
||||
|
||||
var (
|
||||
agentCliParam AgentCliParam
|
||||
agentConfig model.AgentConfig
|
||||
httpClient = &http.Client{
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
Timeout: time.Second * 30,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
delayWhenError = time.Second * 10 // Agent 重连间隔
|
||||
networkTimeOut = time.Second * 5 // 普通网络超时
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.CommandLine.ParseErrorsWhitelist.UnknownFlags = true
|
||||
|
||||
http.DefaultClient.Timeout = time.Second * 5
|
||||
httpClient.Transport = bpc.AddCloudFlareByPass(httpClient.Transport)
|
||||
|
||||
ex, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
agentConfig.Read(filepath.Dir(ex) + "/config.yml")
|
||||
}
|
||||
|
||||
func main() {
|
||||
// windows环境处理
|
||||
if runtime.GOOS == "windows" {
|
||||
hostArch, err := host.KernelArch()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if hostArch == "i386" {
|
||||
hostArch = "386"
|
||||
}
|
||||
if hostArch == "i686" || hostArch == "ia64" || hostArch == "x86_64" {
|
||||
hostArch = "amd64"
|
||||
}
|
||||
if hostArch == "aarch64" {
|
||||
hostArch = "arm64"
|
||||
}
|
||||
if arch != hostArch {
|
||||
panic(fmt.Sprintf("与当前系统不匹配,当前运行 %s_%s, 需要下载 %s_%s", runtime.GOOS, arch, runtime.GOOS, hostArch))
|
||||
}
|
||||
}
|
||||
|
||||
// 来自于 GoReleaser 的版本号
|
||||
monitor.Version = version
|
||||
|
||||
// 初始化运行参数
|
||||
var isEditAgentConfig bool
|
||||
flag.BoolVarP(&agentCliParam.Debug, "debug", "d", false, "开启调试信息")
|
||||
flag.BoolVarP(&isEditAgentConfig, "edit-agent-config", "", false, "修改要监控的网卡/分区白名单")
|
||||
flag.StringVarP(&agentCliParam.Server, "server", "s", "localhost:5555", "管理面板RPC端口")
|
||||
flag.StringVarP(&agentCliParam.ClientSecret, "password", "p", "", "Agent连接Secret")
|
||||
flag.IntVar(&agentCliParam.ReportDelay, "report-delay", 1, "系统状态上报间隔")
|
||||
flag.BoolVar(&agentCliParam.SkipConnectionCount, "skip-conn", false, "不监控连接数")
|
||||
flag.BoolVar(&agentCliParam.SkipProcsCount, "skip-procs", false, "不监控进程数")
|
||||
flag.BoolVar(&agentCliParam.DisableCommandExecute, "disable-command-execute", false, "禁止在此机器上执行命令")
|
||||
flag.BoolVar(&agentCliParam.DisableAutoUpdate, "disable-auto-update", false, "禁用自动升级")
|
||||
flag.BoolVar(&agentCliParam.DisableForceUpdate, "disable-force-update", false, "禁用强制升级")
|
||||
flag.BoolVar(&agentCliParam.TLS, "tls", false, "启用SSL/TLS加密")
|
||||
flag.Parse()
|
||||
|
||||
if isEditAgentConfig {
|
||||
editAgentConfig()
|
||||
return
|
||||
}
|
||||
|
||||
if agentCliParam.ClientSecret == "" {
|
||||
flag.Usage()
|
||||
return
|
||||
}
|
||||
|
||||
if agentCliParam.ReportDelay < 1 || agentCliParam.ReportDelay > 4 {
|
||||
println("report-delay 的区间为 1-4")
|
||||
return
|
||||
}
|
||||
|
||||
run()
|
||||
}
|
||||
|
||||
func run() {
|
||||
auth := rpc.AuthHandler{
|
||||
ClientSecret: agentCliParam.ClientSecret,
|
||||
}
|
||||
|
||||
// 下载远程命令执行需要的终端
|
||||
if !agentCliParam.DisableCommandExecute {
|
||||
go pty.DownloadDependency()
|
||||
}
|
||||
// 上报服务器信息
|
||||
go reportState()
|
||||
// 更新IP信息
|
||||
go monitor.UpdateIP()
|
||||
|
||||
// 定时检查更新
|
||||
if _, err := semver.Parse(version); err == nil && !agentCliParam.DisableAutoUpdate {
|
||||
doSelfUpdate(true)
|
||||
go func() {
|
||||
for range time.Tick(20 * time.Minute) {
|
||||
doSelfUpdate(true)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var err error
|
||||
var conn *grpc.ClientConn
|
||||
|
||||
retry := func() {
|
||||
inited = false
|
||||
println("Error to close connection ...")
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
time.Sleep(delayWhenError)
|
||||
println("Try to reconnect ...")
|
||||
}
|
||||
|
||||
for {
|
||||
timeOutCtx, cancel := context.WithTimeout(context.Background(), networkTimeOut)
|
||||
var securityOption grpc.DialOption
|
||||
if agentCliParam.TLS {
|
||||
securityOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{MinVersion: tls.VersionTLS12}))
|
||||
} else {
|
||||
securityOption = grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
}
|
||||
conn, err = grpc.DialContext(timeOutCtx, agentCliParam.Server, securityOption, grpc.WithPerRPCCredentials(&auth))
|
||||
if err != nil {
|
||||
println("与面板建立连接失败:", err)
|
||||
cancel()
|
||||
retry()
|
||||
continue
|
||||
}
|
||||
cancel()
|
||||
client = pb.NewNezhaServiceClient(conn)
|
||||
// 第一步注册
|
||||
timeOutCtx, cancel = context.WithTimeout(context.Background(), networkTimeOut)
|
||||
_, err = client.ReportSystemInfo(timeOutCtx, monitor.GetHost(&agentConfig).PB())
|
||||
if err != nil {
|
||||
println("上报系统信息失败:", err)
|
||||
cancel()
|
||||
retry()
|
||||
continue
|
||||
}
|
||||
cancel()
|
||||
inited = true
|
||||
// 执行 Task
|
||||
tasks, err := client.RequestTask(context.Background(), monitor.GetHost(&agentConfig).PB())
|
||||
if err != nil {
|
||||
println("请求任务失败:", err)
|
||||
retry()
|
||||
continue
|
||||
}
|
||||
err = receiveTasks(tasks)
|
||||
println("receiveTasks exit to main:", err)
|
||||
retry()
|
||||
}
|
||||
}
|
||||
|
||||
func receiveTasks(tasks pb.NezhaService_RequestTaskClient) error {
|
||||
var err error
|
||||
defer println("receiveTasks exit", time.Now(), "=>", err)
|
||||
for {
|
||||
var task *pb.Task
|
||||
task, err = tasks.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
println("task panic", task, err)
|
||||
}
|
||||
}()
|
||||
doTask(task)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func doTask(task *pb.Task) {
|
||||
var result pb.TaskResult
|
||||
result.Id = task.GetId()
|
||||
result.Type = task.GetType()
|
||||
switch task.GetType() {
|
||||
case model.TaskTypeTerminal:
|
||||
handleTerminalTask(task)
|
||||
case model.TaskTypeHTTPGET:
|
||||
handleHttpGetTask(task, &result)
|
||||
case model.TaskTypeICMPPing:
|
||||
handleIcmpPingTask(task, &result)
|
||||
case model.TaskTypeTCPPing:
|
||||
handleTcpPingTask(task, &result)
|
||||
case model.TaskTypeCommand:
|
||||
handleCommandTask(task, &result)
|
||||
case model.TaskTypeUpgrade:
|
||||
handleUpgradeTask(task, &result)
|
||||
case model.TaskTypeKeepalive:
|
||||
return
|
||||
default:
|
||||
println("不支持的任务:", task)
|
||||
}
|
||||
client.ReportTask(context.Background(), &result)
|
||||
}
|
||||
|
||||
// reportState 向server上报状态信息
|
||||
func reportState() {
|
||||
var lastReportHostInfo time.Time
|
||||
var err error
|
||||
defer println("reportState exit", time.Now(), "=>", err)
|
||||
for {
|
||||
// 为了更准确的记录时段流量,inited 后再上传状态信息
|
||||
if client != nil && inited {
|
||||
monitor.TrackNetworkSpeed(&agentConfig)
|
||||
timeOutCtx, cancel := context.WithTimeout(context.Background(), networkTimeOut)
|
||||
_, err = client.ReportSystemState(timeOutCtx, monitor.GetState(&agentConfig, agentCliParam.SkipConnectionCount, agentCliParam.SkipProcsCount).PB())
|
||||
cancel()
|
||||
if err != nil {
|
||||
println("reportState error", err)
|
||||
time.Sleep(delayWhenError)
|
||||
}
|
||||
// 每10分钟重新获取一次硬件信息
|
||||
if lastReportHostInfo.Before(time.Now().Add(-10 * time.Minute)) {
|
||||
lastReportHostInfo = time.Now()
|
||||
client.ReportSystemInfo(context.Background(), monitor.GetHost(&agentConfig).PB())
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * time.Duration(agentCliParam.ReportDelay))
|
||||
}
|
||||
}
|
||||
|
||||
// doSelfUpdate 执行更新检查 如果更新成功则会结束进程
|
||||
func doSelfUpdate(useLocalVersion bool) {
|
||||
v := semver.MustParse("0.1.0")
|
||||
if useLocalVersion {
|
||||
v = semver.MustParse(version)
|
||||
}
|
||||
println("检查更新:", v)
|
||||
latest, err := selfupdate.UpdateSelf(v, "nezhahq/agent")
|
||||
if err != nil {
|
||||
println("更新失败:", err)
|
||||
return
|
||||
}
|
||||
if !latest.Version.Equals(v) {
|
||||
println("已经更新至:", latest.Version, " 正在结束进程")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func handleUpgradeTask(task *pb.Task, result *pb.TaskResult) {
|
||||
if agentCliParam.DisableForceUpdate {
|
||||
return
|
||||
}
|
||||
doSelfUpdate(false)
|
||||
}
|
||||
|
||||
func handleTcpPingTask(task *pb.Task, result *pb.TaskResult) {
|
||||
start := time.Now()
|
||||
conn, err := net.DialTimeout("tcp", task.GetData(), time.Second*10)
|
||||
if err == nil {
|
||||
conn.Write([]byte("ping\n"))
|
||||
conn.Close()
|
||||
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
|
||||
result.Successful = true
|
||||
} else {
|
||||
result.Data = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func handleIcmpPingTask(task *pb.Task, result *pb.TaskResult) {
|
||||
pinger, err := ping.NewPinger(task.GetData())
|
||||
if err == nil {
|
||||
pinger.SetPrivileged(true)
|
||||
pinger.Count = 5
|
||||
pinger.Timeout = time.Second * 20
|
||||
err = pinger.Run() // Blocks until finished.
|
||||
}
|
||||
if err == nil {
|
||||
stat := pinger.Statistics()
|
||||
if stat.PacketsRecv == 0 {
|
||||
result.Data = "pockets recv 0"
|
||||
return
|
||||
}
|
||||
result.Delay = float32(stat.AvgRtt.Microseconds()) / 1000.0
|
||||
result.Successful = true
|
||||
} else {
|
||||
result.Data = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func handleHttpGetTask(task *pb.Task, result *pb.TaskResult) {
|
||||
start := time.Now()
|
||||
resp, err := httpClient.Get(task.GetData())
|
||||
if err == nil {
|
||||
// 检查 HTTP Response 状态
|
||||
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
|
||||
if resp.StatusCode > 399 || resp.StatusCode < 200 {
|
||||
err = errors.New("\n应用错误:" + resp.Status)
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
// 检查 SSL 证书信息
|
||||
if resp.TLS != nil && len(resp.TLS.PeerCertificates) > 0 {
|
||||
c := resp.TLS.PeerCertificates[0]
|
||||
result.Data = c.Issuer.CommonName + "|" + c.NotAfter.String()
|
||||
}
|
||||
result.Successful = true
|
||||
} else {
|
||||
// HTTP 请求失败
|
||||
result.Data = err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
func handleCommandTask(task *pb.Task, result *pb.TaskResult) {
|
||||
if agentCliParam.DisableCommandExecute {
|
||||
result.Data = "此 Agent 已禁止命令执行"
|
||||
return
|
||||
}
|
||||
startedAt := time.Now()
|
||||
var cmd *exec.Cmd
|
||||
var endCh = make(chan struct{})
|
||||
pg, err := processgroup.NewProcessExitGroup()
|
||||
if err != nil {
|
||||
// 进程组创建失败,直接退出
|
||||
result.Data = err.Error()
|
||||
return
|
||||
}
|
||||
timeout := time.NewTimer(time.Hour * 2)
|
||||
if utils.IsWindows() {
|
||||
cmd = exec.Command("cmd", "/c", task.GetData()) // #nosec
|
||||
} else {
|
||||
cmd = exec.Command("sh", "-c", task.GetData()) // #nosec
|
||||
}
|
||||
cmd.Env = os.Environ()
|
||||
pg.AddProcess(cmd)
|
||||
go func() {
|
||||
select {
|
||||
case <-timeout.C:
|
||||
result.Data = "任务执行超时\n"
|
||||
close(endCh)
|
||||
pg.Dispose()
|
||||
case <-endCh:
|
||||
timeout.Stop()
|
||||
}
|
||||
}()
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
result.Data += fmt.Sprintf("%s\n%s", string(output), err.Error())
|
||||
} else {
|
||||
close(endCh)
|
||||
result.Data = string(output)
|
||||
result.Successful = true
|
||||
}
|
||||
pg.Dispose()
|
||||
result.Delay = float32(time.Since(startedAt).Seconds())
|
||||
}
|
||||
|
||||
type WindowSize struct {
|
||||
Cols uint32
|
||||
Rows uint32
|
||||
}
|
||||
|
||||
func handleTerminalTask(task *pb.Task) {
|
||||
if agentCliParam.DisableCommandExecute {
|
||||
println("此 Agent 已禁止命令执行")
|
||||
return
|
||||
}
|
||||
var terminal model.TerminalTask
|
||||
err := utils.Json.Unmarshal([]byte(task.GetData()), &terminal)
|
||||
if err != nil {
|
||||
println("Terminal 任务解析错误:", err)
|
||||
return
|
||||
}
|
||||
protocol := "ws"
|
||||
if terminal.UseSSL {
|
||||
protocol += "s"
|
||||
}
|
||||
header := http.Header{}
|
||||
header.Add("Secret", agentCliParam.ClientSecret)
|
||||
// 目前只兼容Cloudflare验证
|
||||
// 后续可能需要兼容更多的Cookie验证情况
|
||||
if terminal.Cookie != "" {
|
||||
cfCookie := fmt.Sprintf("CF_Authorization=%s", terminal.Cookie)
|
||||
header.Add("Cookie", cfCookie)
|
||||
}
|
||||
conn, _, err := websocket.DefaultDialer.Dial(fmt.Sprintf("%s://%s/terminal/%s", protocol, terminal.Host, terminal.Session), header)
|
||||
if err != nil {
|
||||
println("Terminal 连接失败:", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
tty, err := pty.Start()
|
||||
if err != nil {
|
||||
println("Terminal pty.Start失败:", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := tty.Close()
|
||||
conn.Close()
|
||||
println("terminal exit", terminal.Session, err)
|
||||
}()
|
||||
println("terminal init", terminal.Session)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
buf := make([]byte, 1024)
|
||||
read, err := tty.Read(buf)
|
||||
if err != nil {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte(err.Error()))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn.WriteMessage(websocket.BinaryMessage, buf[:read])
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
messageType, reader, err := conn.NextReader()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if messageType == websocket.TextMessage {
|
||||
continue
|
||||
}
|
||||
|
||||
dataTypeBuf := make([]byte, 1)
|
||||
read, err := reader.Read(dataTypeBuf)
|
||||
if err != nil {
|
||||
conn.WriteMessage(websocket.TextMessage, []byte("Unable to read message type from reader"))
|
||||
return
|
||||
}
|
||||
|
||||
if read != 1 {
|
||||
return
|
||||
}
|
||||
|
||||
switch dataTypeBuf[0] {
|
||||
case 0:
|
||||
io.Copy(tty, reader)
|
||||
case 1:
|
||||
decoder := utils.Json.NewDecoder(reader)
|
||||
var resizeMessage WindowSize
|
||||
err := decoder.Decode(&resizeMessage)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
tty.Setsize(resizeMessage.Cols, resizeMessage.Rows)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 修改Agent要监控的网卡与硬盘分区
|
||||
func editAgentConfig() {
|
||||
nc, err := psnet.IOCounters(true)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
var nicAllowlistOptions []string
|
||||
for _, v := range nc {
|
||||
nicAllowlistOptions = append(nicAllowlistOptions, v.Name)
|
||||
}
|
||||
|
||||
var diskAllowlistOptions []string
|
||||
diskList, err := disk.Partitions(false)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, p := range diskList {
|
||||
diskAllowlistOptions = append(diskAllowlistOptions, fmt.Sprintf("%s\t%s\t%s", p.Mountpoint, p.Fstype, p.Device))
|
||||
}
|
||||
|
||||
var qs = []*survey.Question{
|
||||
{
|
||||
Name: "nic",
|
||||
Prompt: &survey.MultiSelect{
|
||||
Message: "选择要监控的网卡",
|
||||
Options: nicAllowlistOptions,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "disk",
|
||||
Prompt: &survey.MultiSelect{
|
||||
Message: "选择要监控的硬盘分区",
|
||||
Options: diskAllowlistOptions,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
answers := struct {
|
||||
Nic []string
|
||||
Disk []string
|
||||
}{}
|
||||
|
||||
err = survey.Ask(qs, &answers, survey.WithValidator(survey.Required))
|
||||
if err != nil {
|
||||
fmt.Println("选择错误", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
agentConfig.HardDrivePartitionAllowlist = []string{}
|
||||
for _, v := range answers.Disk {
|
||||
agentConfig.HardDrivePartitionAllowlist = append(agentConfig.HardDrivePartitionAllowlist, strings.Split(v, "\t")[0])
|
||||
}
|
||||
|
||||
agentConfig.NICAllowlist = make(map[string]bool)
|
||||
for _, v := range answers.Nic {
|
||||
agentConfig.NICAllowlist[v] = true
|
||||
}
|
||||
|
||||
if err = agentConfig.Save(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fmt.Println("修改自定义配置成功,重启 Agent 后生效")
|
||||
}
|
||||
|
||||
func println(v ...interface{}) {
|
||||
if agentCliParam.Debug {
|
||||
fmt.Printf("NEZHA@%s>> ", time.Now().Format("2006-01-02 15:04:05"))
|
||||
fmt.Println(v...)
|
||||
}
|
||||
}
|
||||
@@ -1,290 +0,0 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/dean2021/goss"
|
||||
"github.com/shirou/gopsutil/v3/cpu"
|
||||
"github.com/shirou/gopsutil/v3/disk"
|
||||
"github.com/shirou/gopsutil/v3/host"
|
||||
"github.com/shirou/gopsutil/v3/load"
|
||||
"github.com/shirou/gopsutil/v3/mem"
|
||||
"github.com/shirou/gopsutil/v3/net"
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
|
||||
"github.com/naiba/nezha/model"
|
||||
)
|
||||
|
||||
var (
|
||||
Version string
|
||||
expectDiskFsTypes = []string{
|
||||
"apfs", "ext4", "ext3", "ext2", "f2fs", "reiserfs", "jfs", "btrfs",
|
||||
"fuseblk", "zfs", "simfs", "ntfs", "fat32", "exfat", "xfs", "fuse.rclone",
|
||||
}
|
||||
excludeNetInterfaces = []string{
|
||||
"lo", "tun", "docker", "veth", "br-", "vmbr", "vnet", "kube",
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
netInSpeed, netOutSpeed, netInTransfer, netOutTransfer, lastUpdateNetStats uint64
|
||||
cachedBootTime time.Time
|
||||
)
|
||||
|
||||
// GetHost 获取主机硬件信息
|
||||
func GetHost(agentConfig *model.AgentConfig) *model.Host {
|
||||
var ret model.Host
|
||||
|
||||
var cpuType string
|
||||
hi, err := host.Info()
|
||||
if err != nil {
|
||||
println("host.Info error:", err)
|
||||
} else {
|
||||
if hi.VirtualizationSystem != "" {
|
||||
cpuType = "Virtual"
|
||||
} else {
|
||||
cpuType = "Physical"
|
||||
}
|
||||
ret.Platform = hi.Platform
|
||||
ret.PlatformVersion = hi.PlatformVersion
|
||||
ret.Arch = hi.KernelArch
|
||||
ret.Virtualization = hi.VirtualizationSystem
|
||||
ret.BootTime = hi.BootTime
|
||||
}
|
||||
|
||||
cpuModelCount := make(map[string]int)
|
||||
ci, err := cpu.Info()
|
||||
if err != nil {
|
||||
println("cpu.Info error:", err)
|
||||
} else {
|
||||
for i := 0; i < len(ci); i++ {
|
||||
cpuModelCount[ci[i].ModelName]++
|
||||
}
|
||||
for model, count := range cpuModelCount {
|
||||
ret.CPU = append(ret.CPU, fmt.Sprintf("%s %d %s Core", model, count, cpuType))
|
||||
}
|
||||
}
|
||||
|
||||
ret.DiskTotal, _ = getDiskTotalAndUsed(agentConfig)
|
||||
|
||||
mv, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
println("mem.VirtualMemory error:", err)
|
||||
} else {
|
||||
ret.MemTotal = mv.Total
|
||||
if runtime.GOOS != "windows" {
|
||||
ret.SwapTotal = mv.SwapTotal
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
ms, err := mem.SwapMemory()
|
||||
if err != nil {
|
||||
println("mem.SwapMemory error:", err)
|
||||
} else {
|
||||
ret.SwapTotal = ms.Total
|
||||
}
|
||||
}
|
||||
|
||||
cachedBootTime = time.Unix(int64(hi.BootTime), 0)
|
||||
|
||||
ret.IP = CachedIP
|
||||
ret.CountryCode = strings.ToLower(cachedCountry)
|
||||
ret.Version = Version
|
||||
|
||||
return &ret
|
||||
}
|
||||
|
||||
func GetState(agentConfig *model.AgentConfig, skipConnectionCount bool, skipProcsCount bool) *model.HostState {
|
||||
var ret model.HostState
|
||||
|
||||
cp, err := cpu.Percent(0, false)
|
||||
if err != nil {
|
||||
println("cpu.Percent error:", err)
|
||||
} else {
|
||||
ret.CPU = cp[0]
|
||||
}
|
||||
|
||||
vm, err := mem.VirtualMemory()
|
||||
if err != nil {
|
||||
println("mem.VirtualMemory error:", err)
|
||||
} else {
|
||||
ret.MemUsed = vm.Total - vm.Available
|
||||
if runtime.GOOS != "windows" {
|
||||
ret.SwapUsed = vm.SwapTotal - vm.SwapFree
|
||||
}
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
// gopsutil 在 Windows 下不能正确取 swap
|
||||
ms, err := mem.SwapMemory()
|
||||
if err != nil {
|
||||
println("mem.SwapMemory error:", err)
|
||||
} else {
|
||||
ret.SwapUsed = ms.Used
|
||||
}
|
||||
}
|
||||
|
||||
_, ret.DiskUsed = getDiskTotalAndUsed(agentConfig)
|
||||
|
||||
loadStat, err := load.Avg()
|
||||
if err != nil {
|
||||
println("load.Avg error:", err)
|
||||
} else {
|
||||
ret.Load1 = loadStat.Load1
|
||||
ret.Load5 = loadStat.Load5
|
||||
ret.Load15 = loadStat.Load15
|
||||
}
|
||||
|
||||
var procs []int32
|
||||
if !skipProcsCount {
|
||||
procs, err = process.Pids()
|
||||
if err != nil {
|
||||
println("process.Pids error:", err)
|
||||
} else {
|
||||
ret.ProcessCount = uint64(len(procs))
|
||||
}
|
||||
}
|
||||
|
||||
var tcpConnCount, udpConnCount uint64
|
||||
if !skipConnectionCount {
|
||||
ss_err := true
|
||||
if runtime.GOOS == "linux" {
|
||||
tcpStat, err_tcp := goss.ConnectionsWithProtocol(goss.AF_INET, syscall.IPPROTO_TCP)
|
||||
udpStat, err_udp := goss.ConnectionsWithProtocol(goss.AF_INET, syscall.IPPROTO_UDP)
|
||||
if err_tcp == nil && err_udp == nil {
|
||||
ss_err = false
|
||||
tcpConnCount = uint64(len(tcpStat))
|
||||
udpConnCount = uint64(len(udpStat))
|
||||
}
|
||||
if strings.Contains(CachedIP, ":") {
|
||||
tcpStat6, err_tcp := goss.ConnectionsWithProtocol(goss.AF_INET6, syscall.IPPROTO_TCP)
|
||||
udpStat6, err_udp := goss.ConnectionsWithProtocol(goss.AF_INET6, syscall.IPPROTO_UDP)
|
||||
if err_tcp == nil && err_udp == nil {
|
||||
ss_err = false
|
||||
tcpConnCount += uint64(len(tcpStat6))
|
||||
udpConnCount += uint64(len(udpStat6))
|
||||
}
|
||||
}
|
||||
}
|
||||
if ss_err {
|
||||
conns, _ := net.Connections("all")
|
||||
for i := 0; i < len(conns); i++ {
|
||||
switch conns[i].Type {
|
||||
case syscall.SOCK_STREAM:
|
||||
tcpConnCount++
|
||||
case syscall.SOCK_DGRAM:
|
||||
udpConnCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ret.NetInTransfer, ret.NetOutTransfer = netInTransfer, netOutTransfer
|
||||
ret.NetInSpeed, ret.NetOutSpeed = netInSpeed, netOutSpeed
|
||||
ret.Uptime = uint64(time.Since(cachedBootTime).Seconds())
|
||||
ret.TcpConnCount, ret.UdpConnCount = tcpConnCount, udpConnCount
|
||||
|
||||
return &ret
|
||||
}
|
||||
|
||||
// TrackNetworkSpeed NIC监控,统计流量与速度
|
||||
func TrackNetworkSpeed(agentConfig *model.AgentConfig) {
|
||||
var innerNetInTransfer, innerNetOutTransfer uint64
|
||||
nc, err := net.IOCounters(true)
|
||||
if err == nil {
|
||||
for _, v := range nc {
|
||||
if len(agentConfig.NICAllowlist) > 0 {
|
||||
if !agentConfig.NICAllowlist[v.Name] {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if isListContainsStr(excludeNetInterfaces, v.Name) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
innerNetInTransfer += v.BytesRecv
|
||||
innerNetOutTransfer += v.BytesSent
|
||||
}
|
||||
now := uint64(time.Now().Unix())
|
||||
diff := now - lastUpdateNetStats
|
||||
if diff > 0 {
|
||||
netInSpeed = (innerNetInTransfer - netInTransfer) / diff
|
||||
netOutSpeed = (innerNetOutTransfer - netOutTransfer) / diff
|
||||
}
|
||||
netInTransfer = innerNetInTransfer
|
||||
netOutTransfer = innerNetOutTransfer
|
||||
lastUpdateNetStats = now
|
||||
}
|
||||
}
|
||||
|
||||
func getDiskTotalAndUsed(agentConfig *model.AgentConfig) (total uint64, used uint64) {
|
||||
devices := make(map[string]string)
|
||||
|
||||
if len(agentConfig.HardDrivePartitionAllowlist) > 0 {
|
||||
// 如果配置了白名单,使用白名单的列表
|
||||
for i, v := range agentConfig.HardDrivePartitionAllowlist {
|
||||
devices[strconv.Itoa(i)] = v
|
||||
}
|
||||
} else {
|
||||
// 否则使用默认过滤规则
|
||||
diskList, _ := disk.Partitions(false)
|
||||
for _, d := range diskList {
|
||||
fsType := strings.ToLower(d.Fstype)
|
||||
// 不统计 K8s 的虚拟挂载点:https://github.com/shirou/gopsutil/issues/1007
|
||||
if devices[d.Device] == "" && isListContainsStr(expectDiskFsTypes, fsType) && !strings.Contains(d.Mountpoint, "/var/lib/kubelet") {
|
||||
devices[d.Device] = d.Mountpoint
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, mountPath := range devices {
|
||||
diskUsageOf, err := disk.Usage(mountPath)
|
||||
if err == nil {
|
||||
total += diskUsageOf.Total
|
||||
used += diskUsageOf.Used
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback 到这个方法,仅统计根路径,适用于OpenVZ之类的.
|
||||
if runtime.GOOS == "linux" && total == 0 && used == 0 {
|
||||
cmd := exec.Command("df")
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err == nil {
|
||||
s := strings.Split(string(out), "\n")
|
||||
for _, c := range s {
|
||||
info := strings.Fields(c)
|
||||
if len(info) == 6 {
|
||||
if info[5] == "/" {
|
||||
total, _ = strconv.ParseUint(info[1], 0, 64)
|
||||
used, _ = strconv.ParseUint(info[2], 0, 64)
|
||||
// 默认获取的是1K块为单位的.
|
||||
total = total * 1024
|
||||
used = used * 1024
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func isListContainsStr(list []string, str string) bool {
|
||||
for i := 0; i < len(list); i++ {
|
||||
if strings.Contains(str, list[i]) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func println(v ...interface{}) {
|
||||
fmt.Printf("NEZHA@%s>> ", time.Now().Format("2006-01-02 15:04:05"))
|
||||
fmt.Println(v...)
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
fakeUa "github.com/EDDYCJY/fake-useragent"
|
||||
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
)
|
||||
|
||||
type geoIP struct {
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
CountryCode2 string `json:"countryCode,omitempty"`
|
||||
IP string `json:"ip,omitempty"`
|
||||
Query string `json:"query,omitempty"`
|
||||
Location struct {
|
||||
CountryCode string `json:"country_code,omitempty"`
|
||||
} `json:"location,omitempty"`
|
||||
}
|
||||
|
||||
func (ip *geoIP) Unmarshal(body []byte) error {
|
||||
if err := utils.Json.Unmarshal(body, ip); err != nil {
|
||||
return err
|
||||
}
|
||||
if ip.IP == "" && ip.Query != "" {
|
||||
ip.IP = ip.Query
|
||||
}
|
||||
if ip.CountryCode == "" && ip.CountryCode2 != "" {
|
||||
ip.CountryCode = ip.CountryCode2
|
||||
}
|
||||
if ip.CountryCode == "" && ip.Location.CountryCode != "" {
|
||||
ip.CountryCode = ip.Location.CountryCode
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
geoIPApiList = []string{
|
||||
"http://api.myip.la/en?json",
|
||||
"https://api.ip.sb/geoip",
|
||||
"https://ipapi.co/json",
|
||||
"http://ip-api.com/json/",
|
||||
// "https://extreme-ip-lookup.com/json/", // 不精确
|
||||
// "https://ip.seeip.org/geoip", // 不精确
|
||||
// "https://freegeoip.app/json/", // 需要 Key
|
||||
}
|
||||
CachedIP, cachedCountry string
|
||||
httpClientV4 = utils.NewSingleStackHTTPClient(time.Second*20, time.Second*5, time.Second*10, false)
|
||||
httpClientV6 = utils.NewSingleStackHTTPClient(time.Second*20, time.Second*5, time.Second*10, true)
|
||||
)
|
||||
|
||||
// UpdateIP 每30分钟更新一次IP地址与国家码的缓存
|
||||
func UpdateIP() {
|
||||
for {
|
||||
ipv4 := fetchGeoIP(geoIPApiList, false)
|
||||
ipv6 := fetchGeoIP(geoIPApiList, true)
|
||||
if ipv4.IP == "" && ipv6.IP == "" {
|
||||
time.Sleep(time.Minute)
|
||||
continue
|
||||
}
|
||||
if ipv4.IP == "" || ipv6.IP == "" {
|
||||
CachedIP = fmt.Sprintf("%s%s", ipv4.IP, ipv6.IP)
|
||||
} else {
|
||||
CachedIP = fmt.Sprintf("%s/%s", ipv4.IP, ipv6.IP)
|
||||
}
|
||||
if ipv4.CountryCode != "" {
|
||||
cachedCountry = ipv4.CountryCode
|
||||
} else if ipv6.CountryCode != "" {
|
||||
cachedCountry = ipv6.CountryCode
|
||||
}
|
||||
time.Sleep(time.Minute * 30)
|
||||
}
|
||||
}
|
||||
|
||||
func fetchGeoIP(servers []string, isV6 bool) geoIP {
|
||||
var ip geoIP
|
||||
var resp *http.Response
|
||||
var err error
|
||||
|
||||
// 双栈支持参差不齐,不能随机请求,有些 IPv6 取不到 IP
|
||||
for i := 0; i < len(servers); i++ {
|
||||
if isV6 {
|
||||
resp, err = httpGetWithUA(httpClientV6, servers[i])
|
||||
} else {
|
||||
resp, err = httpGetWithUA(httpClientV4, servers[i])
|
||||
}
|
||||
if err == nil {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
if err := ip.Unmarshal(body); err != nil {
|
||||
continue
|
||||
}
|
||||
// 没取到 v6 IP
|
||||
if isV6 && !strings.Contains(ip.IP, ":") {
|
||||
continue
|
||||
}
|
||||
// 没取到 v4 IP
|
||||
if !isV6 && !strings.Contains(ip.IP, ".") {
|
||||
continue
|
||||
}
|
||||
// 未获取到国家码
|
||||
if ip.CountryCode == "" {
|
||||
continue
|
||||
}
|
||||
return ip
|
||||
}
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
func httpGetWithUA(client *http.Client, url string) (*http.Response, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("User-Agent", fakeUa.Random())
|
||||
return client.Do(req)
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
)
|
||||
|
||||
func TestGeoIPApi(t *testing.T) {
|
||||
for i := 0; i < len(geoIPApiList); i++ {
|
||||
resp, err := httpGetWithUA(httpClientV4, geoIPApiList[i])
|
||||
assert.Nil(t, err)
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
assert.Nil(t, err)
|
||||
resp.Body.Close()
|
||||
var ip geoIP
|
||||
err = ip.Unmarshal(body)
|
||||
assert.Nil(t, err)
|
||||
t.Logf("%s %s %s", geoIPApiList[i], ip.CountryCode, utils.IPDesensitize(ip.IP))
|
||||
assert.True(t, ip.IP != "")
|
||||
assert.True(t, ip.CountryCode != "")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchGeoIP(t *testing.T) {
|
||||
ip := fetchGeoIP(geoIPApiList, false)
|
||||
assert.NotEmpty(t, ip.IP)
|
||||
assert.NotEmpty(t, ip.CountryCode)
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package processgroup
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"sync"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type ProcessExitGroup struct {
|
||||
cmds []*exec.Cmd
|
||||
}
|
||||
|
||||
func NewProcessExitGroup() (ProcessExitGroup, error) {
|
||||
return ProcessExitGroup{}, nil
|
||||
}
|
||||
|
||||
func (g *ProcessExitGroup) killChildProcess(c *exec.Cmd) error {
|
||||
pgid, err := syscall.Getpgid(c.Process.Pid)
|
||||
if err != nil {
|
||||
// Fall-back on error. Kill the main process only.
|
||||
c.Process.Kill()
|
||||
}
|
||||
// Kill the whole process group.
|
||||
syscall.Kill(-pgid, syscall.SIGTERM)
|
||||
return c.Wait()
|
||||
}
|
||||
|
||||
func (g *ProcessExitGroup) Dispose() []error {
|
||||
var errors []error
|
||||
mutex := new(sync.Mutex)
|
||||
wg := new(sync.WaitGroup)
|
||||
wg.Add(len(g.cmds))
|
||||
for _, c := range g.cmds {
|
||||
go func(c *exec.Cmd) {
|
||||
defer wg.Done()
|
||||
if err := g.killChildProcess(c); err != nil {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
errors = append(errors, err)
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
wg.Wait()
|
||||
return errors
|
||||
}
|
||||
|
||||
func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
|
||||
g.cmds = append(g.cmds, cmd)
|
||||
return nil
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package processgroup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
type ProcessExitGroup struct {
|
||||
cmds []*exec.Cmd
|
||||
}
|
||||
|
||||
func NewProcessExitGroup() (ProcessExitGroup, error) {
|
||||
return ProcessExitGroup{}, nil
|
||||
}
|
||||
|
||||
func (g *ProcessExitGroup) Dispose() error {
|
||||
for _, c := range g.cmds {
|
||||
if err := exec.Command("taskkill", "/F", "/T", "/PID", fmt.Sprint(c.Process.Pid)).Run(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *ProcessExitGroup) AddProcess(cmd *exec.Cmd) error {
|
||||
g.cmds = append(g.cmds, cmd)
|
||||
return nil
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package pty
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
|
||||
opty "github.com/creack/pty"
|
||||
)
|
||||
|
||||
var defaultShells = []string{"zsh", "fish", "bash", "sh"}
|
||||
|
||||
type Pty struct {
|
||||
tty *os.File
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func DownloadDependency() {
|
||||
}
|
||||
|
||||
func Start() (*Pty, error) {
|
||||
var shellPath string
|
||||
for i := 0; i < len(defaultShells); i++ {
|
||||
shellPath, _ = exec.LookPath(defaultShells[i])
|
||||
if shellPath != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if shellPath == "" {
|
||||
return nil, errors.New("没有可用终端")
|
||||
}
|
||||
cmd := exec.Command(shellPath) // #nosec
|
||||
cmd.Env = append(os.Environ(), "TERM=xterm")
|
||||
tty, err := opty.Start(cmd)
|
||||
return &Pty{tty: tty, cmd: cmd}, err
|
||||
}
|
||||
|
||||
func (pty *Pty) Write(p []byte) (n int, err error) {
|
||||
return pty.tty.Write(p)
|
||||
}
|
||||
|
||||
func (pty *Pty) Read(p []byte) (n int, err error) {
|
||||
return pty.tty.Read(p)
|
||||
}
|
||||
|
||||
func (pty *Pty) Setsize(cols, rows uint32) error {
|
||||
return opty.Setsize(pty.tty, &opty.Winsize{
|
||||
Cols: uint16(cols),
|
||||
Rows: uint16(rows),
|
||||
})
|
||||
}
|
||||
|
||||
func (pty *Pty) killChildProcess(c *exec.Cmd) error {
|
||||
pgid, err := syscall.Getpgid(c.Process.Pid)
|
||||
if err != nil {
|
||||
// Fall-back on error. Kill the main process only.
|
||||
c.Process.Kill()
|
||||
}
|
||||
// Kill the whole process group.
|
||||
syscall.Kill(-pgid, syscall.SIGKILL) // SIGKILL 直接杀掉 SIGTERM 发送信号,等待进程自己退出
|
||||
return c.Wait()
|
||||
}
|
||||
|
||||
func (pty *Pty) Close() error {
|
||||
if err := pty.tty.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return pty.killChildProcess(pty.cmd)
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package pty
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/artdarek/go-unzip"
|
||||
"github.com/iamacarpet/go-winpty"
|
||||
)
|
||||
|
||||
type Pty struct {
|
||||
tty *winpty.WinPTY
|
||||
}
|
||||
|
||||
func DownloadDependency() {
|
||||
executablePath, err := getExecutableFilePath()
|
||||
if err != nil {
|
||||
fmt.Println("NEZHA>> wintty 获取文件路径失败", err)
|
||||
return
|
||||
}
|
||||
|
||||
winptyAgentExe := filepath.Join(executablePath, "winpty-agent.exe")
|
||||
winptyAgentDll := filepath.Join(executablePath, "winpty.dll")
|
||||
|
||||
fe, errFe := os.Stat(winptyAgentExe)
|
||||
fd, errFd := os.Stat(winptyAgentDll)
|
||||
if errFe == nil && fe.Size() > 300000 && errFd == nil && fd.Size() > 300000 {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Get("https://dn-dao-github-mirror.daocloud.io/rprichard/winpty/releases/download/0.4.3/winpty-0.4.3-msvc2015.zip")
|
||||
if err != nil {
|
||||
log.Println("NEZHA>> wintty 下载失败", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
content, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Println("NEZHA>> wintty 下载失败", err)
|
||||
return
|
||||
}
|
||||
if err := os.WriteFile("./wintty.zip", content, os.FileMode(0777)); err != nil {
|
||||
log.Println("NEZHA>> wintty 写入失败", err)
|
||||
return
|
||||
}
|
||||
if err := unzip.New("./wintty.zip", "./wintty").Extract(); err != nil {
|
||||
fmt.Println("NEZHA>> wintty 解压失败", err)
|
||||
return
|
||||
}
|
||||
arch := "x64"
|
||||
if runtime.GOARCH != "amd64" {
|
||||
arch = "ia32"
|
||||
}
|
||||
|
||||
os.Rename("./wintty/"+arch+"/bin/winpty-agent.exe", winptyAgentExe)
|
||||
os.Rename("./wintty/"+arch+"/bin/winpty.dll", winptyAgentDll)
|
||||
os.RemoveAll("./wintty")
|
||||
os.RemoveAll("./wintty.zip")
|
||||
}
|
||||
|
||||
func getExecutableFilePath() (string, error) {
|
||||
ex, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(ex), nil
|
||||
}
|
||||
|
||||
func Start() (*Pty, error) {
|
||||
shellPath, err := exec.LookPath("powershell.exe")
|
||||
if err != nil || shellPath == "" {
|
||||
shellPath = "cmd.exe"
|
||||
}
|
||||
path, err := getExecutableFilePath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tty, err := winpty.OpenDefault(path, shellPath)
|
||||
return &Pty{tty: tty}, err
|
||||
}
|
||||
|
||||
func (pty *Pty) Write(p []byte) (n int, err error) {
|
||||
return pty.tty.StdIn.Write(p)
|
||||
}
|
||||
|
||||
func (pty *Pty) Read(p []byte) (n int, err error) {
|
||||
return pty.tty.StdOut.Read(p)
|
||||
}
|
||||
|
||||
func (pty *Pty) Setsize(cols, rows uint32) error {
|
||||
pty.tty.SetSize(cols, rows)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pty *Pty) Close() error {
|
||||
pty.tty.Close()
|
||||
return nil
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/host"
|
||||
)
|
||||
|
||||
func main() {
|
||||
info, err := host.Info()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
log.Printf("%#v", info)
|
||||
}
|
||||
Reference in New Issue
Block a user