v0.9.21 WebSSH

This commit is contained in:
naiba
2021-08-18 11:56:54 +08:00
parent 960266bf71
commit 9bf536b68a
13 changed files with 597 additions and 99 deletions

View File

@@ -3,20 +3,27 @@ package main
import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"os/exec"
"runtime"
"syscall"
"time"
"unsafe"
"github.com/blang/semver"
"github.com/genkiroid/cert"
"github.com/go-ping/ping"
"github.com/gorilla/websocket"
"github.com/kr/pty"
"github.com/p14yground/go-github-selfupdate/selfupdate"
"google.golang.org/grpc"
@@ -58,6 +65,13 @@ const (
networkTimeOut = time.Second * 5 // 普通网络超时
)
type windowSize struct {
Rows uint16 `json:"rows"`
Cols uint16 `json:"cols"`
X uint16
Y uint16
}
func main() {
// 来自于 GoReleaser 的版本号
monitor.Version = version
@@ -154,7 +168,14 @@ func receiveTasks(tasks pb.NezhaService_RequestTaskClient) error {
if err != nil {
return err
}
go doTask(task)
go func() {
defer func() {
if recover() != nil {
println("task panic", task)
}
}()
doTask(task)
}()
}
}
@@ -163,100 +184,16 @@ func doTask(task *pb.Task) {
result.Id = task.GetId()
result.Type = task.GetType()
switch task.GetType() {
case model.TaskTypeTerminal:
handleTerminalTask(task)
case model.TaskTypeHTTPGET:
start := time.Now()
resp, err := httpClient.Get(task.GetData())
if err == nil {
// 检查 HTTP Response 状态
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
if resp.StatusCode > 399 || resp.StatusCode < 200 {
err = errors.New("\n应用错误" + resp.Status)
}
}
if err == nil {
// 检查 SSL 证书信息
serviceUrl, err := url.Parse(task.GetData())
if err == nil {
if serviceUrl.Scheme == "https" {
c := cert.NewCert(serviceUrl.Host)
if c.Error != "" {
result.Data = "SSL证书错误" + c.Error
} else {
result.Data = c.Issuer + "|" + c.NotAfter
result.Successful = true
}
} else {
result.Successful = true
}
} else {
result.Data = "URL解析错误" + err.Error()
}
} else {
// HTTP 请求失败
result.Data = err.Error()
}
handleHttpGetTask(task, &result)
case model.TaskTypeICMPPing:
pinger, err := ping.NewPinger(task.GetData())
if err == nil {
pinger.SetPrivileged(true)
pinger.Count = 5
pinger.Timeout = time.Second * 20
err = pinger.Run() // Blocks until finished.
}
if err == nil {
result.Delay = float32(pinger.Statistics().AvgRtt.Microseconds()) / 1000.0
result.Successful = true
} else {
result.Data = err.Error()
}
handleIcmpPingTask(task, &result)
case model.TaskTypeTCPPing:
start := time.Now()
conn, err := net.DialTimeout("tcp", task.GetData(), time.Second*10)
if err == nil {
conn.Write([]byte("ping\n"))
conn.Close()
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
result.Successful = true
} else {
result.Data = err.Error()
}
handleTcpPingTask(task, &result)
case model.TaskTypeCommand:
startedAt := time.Now()
var cmd *exec.Cmd
var endCh = make(chan struct{})
pg, err := utils.NewProcessExitGroup()
if err != nil {
// 进程组创建失败,直接退出
result.Data = err.Error()
client.ReportTask(context.Background(), &result)
return
}
timeout := time.NewTimer(time.Hour * 2)
if utils.IsWindows() {
cmd = exec.Command("cmd", "/c", task.GetData())
} else {
cmd = exec.Command("sh", "-c", task.GetData())
}
pg.AddProcess(cmd)
go func() {
select {
case <-timeout.C:
result.Data = "任务执行超时\n"
close(endCh)
pg.Dispose()
case <-endCh:
timeout.Stop()
}
}()
output, err := cmd.Output()
if err != nil {
result.Data += fmt.Sprintf("%s\n%s", string(output), err.Error())
} else {
close(endCh)
result.Data = string(output)
result.Successful = true
}
result.Delay = float32(time.Since(startedAt).Seconds())
handleCommandTask(task, &result)
default:
println("Unknown action: ", task)
}
@@ -307,6 +244,211 @@ func doSelfUpdate() {
}
}
func handleTcpPingTask(task *pb.Task, result *pb.TaskResult) {
start := time.Now()
conn, err := net.DialTimeout("tcp", task.GetData(), time.Second*10)
if err == nil {
conn.Write([]byte("ping\n"))
conn.Close()
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
result.Successful = true
} else {
result.Data = err.Error()
}
}
func handleIcmpPingTask(task *pb.Task, result *pb.TaskResult) {
pinger, err := ping.NewPinger(task.GetData())
if err == nil {
pinger.SetPrivileged(true)
pinger.Count = 5
pinger.Timeout = time.Second * 20
err = pinger.Run() // Blocks until finished.
}
if err == nil {
result.Delay = float32(pinger.Statistics().AvgRtt.Microseconds()) / 1000.0
result.Successful = true
} else {
result.Data = err.Error()
}
}
func handleHttpGetTask(task *pb.Task, result *pb.TaskResult) {
start := time.Now()
resp, err := httpClient.Get(task.GetData())
if err == nil {
// 检查 HTTP Response 状态
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
if resp.StatusCode > 399 || resp.StatusCode < 200 {
err = errors.New("\n应用错误" + resp.Status)
}
}
if err == nil {
// 检查 SSL 证书信息
serviceUrl, err := url.Parse(task.GetData())
if err == nil {
if serviceUrl.Scheme == "https" {
c := cert.NewCert(serviceUrl.Host)
if c.Error != "" {
result.Data = "SSL证书错误" + c.Error
} else {
result.Data = c.Issuer + "|" + c.NotAfter
result.Successful = true
}
} else {
result.Successful = true
}
} else {
result.Data = "URL解析错误" + err.Error()
}
} else {
// HTTP 请求失败
result.Data = err.Error()
}
}
func handleCommandTask(task *pb.Task, result *pb.TaskResult) {
startedAt := time.Now()
var cmd *exec.Cmd
var endCh = make(chan struct{})
pg, err := utils.NewProcessExitGroup()
if err != nil {
// 进程组创建失败,直接退出
result.Data = err.Error()
return
}
timeout := time.NewTimer(time.Hour * 2)
if utils.IsWindows() {
cmd = exec.Command("cmd", "/c", task.GetData())
} else {
cmd = exec.Command("sh", "-c", task.GetData())
}
pg.AddProcess(cmd)
go func() {
select {
case <-timeout.C:
result.Data = "任务执行超时\n"
close(endCh)
pg.Dispose()
case <-endCh:
timeout.Stop()
}
}()
output, err := cmd.Output()
if err != nil {
result.Data += fmt.Sprintf("%s\n%s", string(output), err.Error())
} else {
close(endCh)
result.Data = string(output)
result.Successful = true
}
result.Delay = float32(time.Since(startedAt).Seconds())
}
func handleTerminalTask(task *pb.Task) {
var terminal model.TerminalTask
err := json.Unmarshal([]byte(task.GetData()), &terminal)
if err != nil {
println("Terminal 任务解析错误:", err)
return
}
protocol := "ws"
if terminal.UseSSL {
protocol += "s"
}
header := http.Header{}
header.Add("Secret", clientSecret)
conn, _, err := websocket.DefaultDialer.Dial(fmt.Sprintf("%s://%s/terminal/%s", protocol, terminal.Host, terminal.Session), header)
if err != nil {
println("Terminal 连接失败:", err)
return
}
defer conn.Close()
var cmd *exec.Cmd
var shellPath string
if runtime.GOOS == "windows" {
shellPath, err = exec.LookPath("powershell.exe")
if err != nil || shellPath == "" {
shellPath = "cmd.exe"
}
} else {
shellPath = os.Getenv("SHELL")
if shellPath == "" {
shellPath = "sh"
}
}
cmd = exec.Command(shellPath)
cmd.Env = append(os.Environ(), "TERM=xterm")
tty, err := pty.Start(cmd)
if err != nil {
println("Terminal pty.Start失败", err)
return
}
defer func() {
cmd.Process.Kill()
cmd.Process.Wait()
tty.Close()
conn.Close()
println("terminal exit", terminal.Session)
}()
println("terminal init", terminal.Session, shellPath)
go func() {
for {
buf := make([]byte, 1024)
read, err := tty.Read(buf)
if err != nil {
conn.WriteMessage(websocket.TextMessage, []byte(err.Error()))
return
}
conn.WriteMessage(websocket.BinaryMessage, buf[:read])
}
}()
for {
messageType, reader, err := conn.NextReader()
if err != nil {
return
}
if messageType == websocket.TextMessage {
continue
}
dataTypeBuf := make([]byte, 1)
read, err := reader.Read(dataTypeBuf)
if err != nil {
conn.WriteMessage(websocket.TextMessage, []byte("Unable to read message type from reader"))
return
}
if read != 1 {
return
}
switch dataTypeBuf[0] {
case 0:
io.Copy(tty, reader)
case 1:
decoder := json.NewDecoder(reader)
resizeMessage := windowSize{}
err := decoder.Decode(&resizeMessage)
if err != nil {
continue
}
syscall.Syscall(
syscall.SYS_IOCTL,
tty.Fd(),
syscall.TIOCSWINSZ,
uintptr(unsafe.Pointer(&resizeMessage)),
)
}
}
}
func println(v ...interface{}) {
if debug {
log.Println(v...)