add conditional compilation support
add multi core support
This commit is contained in:
Yuzuki616
2023-06-08 01:18:56 +08:00
parent 925542b515
commit d76c6a73eb
36 changed files with 1721 additions and 279 deletions

74
core/hy/acme.go Normal file
View File

@@ -0,0 +1,74 @@
package hy
import (
"context"
"crypto/tls"
"os"
"path/filepath"
"runtime"
"go.uber.org/zap"
"github.com/caddyserver/certmagic"
)
func acmeTLSConfig(domains []string, email string, disableHTTP bool, disableTLSALPN bool,
altHTTPPort int, altTLSALPNPort int,
) (*tls.Config, error) {
cfg := &certmagic.Config{
RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio,
KeySource: certmagic.DefaultKeyGenerator,
Storage: &certmagic.FileStorage{Path: dataDir()},
Logger: zap.NewNop(),
}
issuer := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
CA: certmagic.LetsEncryptProductionCA,
TestCA: certmagic.LetsEncryptStagingCA,
Email: email,
Agreed: true,
DisableHTTPChallenge: disableHTTP,
DisableTLSALPNChallenge: disableTLSALPN,
AltHTTPPort: altHTTPPort,
AltTLSALPNPort: altTLSALPNPort,
Logger: zap.NewNop(),
})
cfg.Issuers = []certmagic.Issuer{issuer}
cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
return cfg, nil
},
Logger: zap.NewNop(),
})
cfg = certmagic.New(cache, *cfg)
err := cfg.ManageSync(context.Background(), domains)
if err != nil {
return nil, err
}
return cfg.TLSConfig(), nil
}
func homeDir() string {
home := os.Getenv("HOME")
if home == "" && runtime.GOOS == "windows" {
drive := os.Getenv("HOMEDRIVE")
path := os.Getenv("HOMEPATH")
home = drive + path
if drive == "" || path == "" {
home = os.Getenv("USERPROFILE")
}
}
if home == "" {
home = "."
}
return home
}
func dataDir() string {
baseDir := filepath.Join(homeDir(), ".local", "share")
if xdgData := os.Getenv("XDG_DATA_HOME"); xdgData != "" {
baseDir = xdgData
}
return filepath.Join(baseDir, "certmagic")
}

307
core/hy/config.go Normal file
View File

@@ -0,0 +1,307 @@
package hy
import (
"errors"
"fmt"
"github.com/yosuke-furukawa/json5/encoding/json5"
"regexp"
"strconv"
)
const (
mbpsToBps = 125000
minSpeedBPS = 16384
DefaultALPN = "hysteria"
DefaultStreamReceiveWindow = 16777216 // 16 MB
DefaultConnectionReceiveWindow = DefaultStreamReceiveWindow * 5 / 2 // 40 MB
DefaultMaxIncomingStreams = 1024
DefaultMMDBFilename = "GeoLite2-Country.mmdb"
ServerMaxIdleTimeoutSec = 60
DefaultClientIdleTimeoutSec = 20
DefaultClientHopIntervalSec = 10
)
var rateStringRegexp = regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`)
type serverConfig struct {
Listen string `json:"listen"`
Protocol string `json:"protocol"`
ACME struct {
Domains []string `json:"domains"`
Email string `json:"email"`
DisableHTTPChallenge bool `json:"disable_http"`
DisableTLSALPNChallenge bool `json:"disable_tlsalpn"`
AltHTTPPort int `json:"alt_http_port"`
AltTLSALPNPort int `json:"alt_tlsalpn_port"`
} `json:"acme"`
CertFile string `json:"cert"`
KeyFile string `json:"key"`
// Optional below
Up string `json:"up"`
UpMbps int `json:"up_mbps"`
Down string `json:"down"`
DownMbps int `json:"down_mbps"`
DisableUDP bool `json:"disable_udp"`
ACL string `json:"acl"`
MMDB string `json:"mmdb"`
Obfs string `json:"obfs"`
Auth struct {
Mode string `json:"mode"`
Config json5.RawMessage `json:"config"`
} `json:"auth"`
ALPN string `json:"alpn"`
PrometheusListen string `json:"prometheus_listen"`
ReceiveWindowConn uint64 `json:"recv_window_conn"`
ReceiveWindowClient uint64 `json:"recv_window_client"`
MaxConnClient int `json:"max_conn_client"`
DisableMTUDiscovery bool `json:"disable_mtu_discovery"`
Resolver string `json:"resolver"`
ResolvePreference string `json:"resolve_preference"`
SOCKS5Outbound struct {
Server string `json:"server"`
User string `json:"user"`
Password string `json:"password"`
} `json:"socks5_outbound"`
BindOutbound struct {
Address string `json:"address"`
Device string `json:"device"`
} `json:"bind_outbound"`
}
func (c *serverConfig) Speed() (uint64, uint64, error) {
var up, down uint64
if len(c.Up) > 0 {
up = stringToBps(c.Up)
if up == 0 {
return 0, 0, errors.New("invalid speed format")
}
} else {
up = uint64(c.UpMbps) * mbpsToBps
}
if len(c.Down) > 0 {
down = stringToBps(c.Down)
if down == 0 {
return 0, 0, errors.New("invalid speed format")
}
} else {
down = uint64(c.DownMbps) * mbpsToBps
}
return up, down, nil
}
func (c *serverConfig) Check() error {
if len(c.Listen) == 0 {
return errors.New("missing listen address")
}
if len(c.ACME.Domains) == 0 && (len(c.CertFile) == 0 || len(c.KeyFile) == 0) {
return errors.New("need either ACME info or cert/key files")
}
if len(c.ACME.Domains) > 0 && (len(c.CertFile) > 0 || len(c.KeyFile) > 0) {
return errors.New("cannot use both ACME and cert/key files, they are mutually exclusive")
}
if up, down, err := c.Speed(); err != nil || (up != 0 && up < minSpeedBPS) || (down != 0 && down < minSpeedBPS) {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
return errors.New("invalid receive window size")
}
if c.MaxConnClient < 0 {
return errors.New("invalid max connections per client")
}
return nil
}
func (c *serverConfig) Fill() {
if len(c.ALPN) == 0 {
c.ALPN = DefaultALPN
}
if c.ReceiveWindowConn == 0 {
c.ReceiveWindowConn = DefaultStreamReceiveWindow
}
if c.ReceiveWindowClient == 0 {
c.ReceiveWindowClient = DefaultConnectionReceiveWindow
}
if c.MaxConnClient == 0 {
c.MaxConnClient = DefaultMaxIncomingStreams
}
if len(c.MMDB) == 0 {
c.MMDB = DefaultMMDBFilename
}
}
func (c *serverConfig) String() string {
return fmt.Sprintf("%+v", *c)
}
type Relay struct {
Listen string `json:"listen"`
Remote string `json:"remote"`
Timeout int `json:"timeout"`
}
func (r *Relay) Check() error {
if len(r.Listen) == 0 {
return errors.New("missing relay listen address")
}
if len(r.Remote) == 0 {
return errors.New("missing relay remote address")
}
if r.Timeout != 0 && r.Timeout < 4 {
return errors.New("invalid relay timeout")
}
return nil
}
type clientConfig struct {
Server string `json:"server"`
Protocol string `json:"protocol"`
Up string `json:"up"`
UpMbps int `json:"up_mbps"`
Down string `json:"down"`
DownMbps int `json:"down_mbps"`
// Optional below
Retry int `json:"retry"`
RetryInterval *int `json:"retry_interval"`
QuitOnDisconnect bool `json:"quit_on_disconnect"`
HandshakeTimeout int `json:"handshake_timeout"`
IdleTimeout int `json:"idle_timeout"`
HopInterval int `json:"hop_interval"`
SOCKS5 struct {
Listen string `json:"listen"`
Timeout int `json:"timeout"`
DisableUDP bool `json:"disable_udp"`
User string `json:"user"`
Password string `json:"password"`
} `json:"socks5"`
HTTP struct {
Listen string `json:"listen"`
Timeout int `json:"timeout"`
User string `json:"user"`
Password string `json:"password"`
Cert string `json:"cert"`
Key string `json:"key"`
} `json:"http"`
TUN struct {
Name string `json:"name"`
Timeout int `json:"timeout"`
MTU uint32 `json:"mtu"`
TCPSendBufferSize string `json:"tcp_sndbuf"`
TCPReceiveBufferSize string `json:"tcp_rcvbuf"`
TCPModerateReceiveBuffer bool `json:"tcp_autotuning"`
} `json:"tun"`
TCPRelays []Relay `json:"relay_tcps"`
TCPRelay Relay `json:"relay_tcp"` // deprecated, but we still support it for backward compatibility
UDPRelays []Relay `json:"relay_udps"`
UDPRelay Relay `json:"relay_udp"` // deprecated, but we still support it for backward compatibility
TCPTProxy struct {
Listen string `json:"listen"`
Timeout int `json:"timeout"`
} `json:"tproxy_tcp"`
UDPTProxy struct {
Listen string `json:"listen"`
Timeout int `json:"timeout"`
} `json:"tproxy_udp"`
TCPRedirect struct {
Listen string `json:"listen"`
Timeout int `json:"timeout"`
} `json:"redirect_tcp"`
ACL string `json:"acl"`
MMDB string `json:"mmdb"`
Obfs string `json:"obfs"`
Auth []byte `json:"auth"`
AuthString string `json:"auth_str"`
ALPN string `json:"alpn"`
ServerName string `json:"server_name"`
Insecure bool `json:"insecure"`
CustomCA string `json:"ca"`
ReceiveWindowConn uint64 `json:"recv_window_conn"`
ReceiveWindow uint64 `json:"recv_window"`
DisableMTUDiscovery bool `json:"disable_mtu_discovery"`
FastOpen bool `json:"fast_open"`
LazyStart bool `json:"lazy_start"`
Resolver string `json:"resolver"`
ResolvePreference string `json:"resolve_preference"`
}
func (c *clientConfig) Speed() (uint64, uint64, error) {
var up, down uint64
if len(c.Up) > 0 {
up = stringToBps(c.Up)
if up == 0 {
return 0, 0, errors.New("invalid speed format")
}
} else {
up = uint64(c.UpMbps) * mbpsToBps
}
if len(c.Down) > 0 {
down = stringToBps(c.Down)
if down == 0 {
return 0, 0, errors.New("invalid speed format")
}
} else {
down = uint64(c.DownMbps) * mbpsToBps
}
return up, down, nil
}
func (c *clientConfig) Fill() {
if len(c.ALPN) == 0 {
c.ALPN = DefaultALPN
}
if c.ReceiveWindowConn == 0 {
c.ReceiveWindowConn = DefaultStreamReceiveWindow
}
if c.ReceiveWindow == 0 {
c.ReceiveWindow = DefaultConnectionReceiveWindow
}
if len(c.MMDB) == 0 {
c.MMDB = DefaultMMDBFilename
}
if c.IdleTimeout == 0 {
c.IdleTimeout = DefaultClientIdleTimeoutSec
}
if c.HopInterval == 0 {
c.HopInterval = DefaultClientHopIntervalSec
}
}
func (c *clientConfig) String() string {
return fmt.Sprintf("%+v", *c)
}
func stringToBps(s string) uint64 {
if s == "" {
return 0
}
m := rateStringRegexp.FindStringSubmatch(s)
if m == nil {
return 0
}
var n uint64
switch m[2] {
case "K":
n = 1 << 10
case "M":
n = 1 << 20
case "G":
n = 1 << 30
case "T":
n = 1 << 40
default:
n = 1
}
v, _ := strconv.ParseUint(m[1], 10, 64)
n = v * n
if m[3] == "b" {
// Bits, need to convert to bytes
n = n >> 3
}
return n
}

306
core/hy/hy.go Normal file
View File

@@ -0,0 +1,306 @@
package hy
import (
"crypto/tls"
"io"
"net"
"net/http"
"time"
"github.com/oschwald/geoip2-golang"
"github.com/quic-go/quic-go"
"github.com/apernet/hysteria/app/auth"
"github.com/apernet/hysteria/core/acl"
"github.com/apernet/hysteria/core/cs"
"github.com/apernet/hysteria/core/pktconns"
"github.com/apernet/hysteria/core/pmtud"
"github.com/apernet/hysteria/core/sockopt"
"github.com/apernet/hysteria/core/transport"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
"github.com/yosuke-furukawa/json5/encoding/json5"
)
var serverPacketConnFuncFactoryMap = map[string]pktconns.ServerPacketConnFuncFactory{
"": pktconns.NewServerUDPConnFunc,
"udp": pktconns.NewServerUDPConnFunc,
"wechat": pktconns.NewServerWeChatConnFunc,
"wechat-video": pktconns.NewServerWeChatConnFunc,
"faketcp": pktconns.NewServerFakeTCPConnFunc,
}
func server(config *serverConfig) {
logrus.WithField("config", config.String()).Info("Server configuration loaded")
config.Fill() // Fill default values
// Resolver
if len(config.Resolver) > 0 {
err := setResolver(config.Resolver)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to set resolver")
}
}
// Load TLS config
var tlsConfig *tls.Config
if len(config.ACME.Domains) > 0 {
// ACME mode
tc, err := acmeTLSConfig(config.ACME.Domains, config.ACME.Email,
config.ACME.DisableHTTPChallenge, config.ACME.DisableTLSALPNChallenge,
config.ACME.AltHTTPPort, config.ACME.AltTLSALPNPort)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to get a certificate with ACME")
}
tc.NextProtos = []string{config.ALPN}
tc.MinVersion = tls.VersionTLS13
tlsConfig = tc
} else {
// Local cert mode
kpl, err := newKeypairLoader(config.CertFile, config.KeyFile)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"cert": config.CertFile,
"key": config.KeyFile,
}).Fatal("Failed to load the certificate")
}
tlsConfig = &tls.Config{
GetCertificate: kpl.GetCertificateFunc(),
NextProtos: []string{config.ALPN},
MinVersion: tls.VersionTLS13,
}
}
// QUIC config
quicConfig := &quic.Config{
InitialStreamReceiveWindow: config.ReceiveWindowConn,
MaxStreamReceiveWindow: config.ReceiveWindowConn,
InitialConnectionReceiveWindow: config.ReceiveWindowClient,
MaxConnectionReceiveWindow: config.ReceiveWindowClient,
MaxIncomingStreams: int64(config.MaxConnClient),
MaxIdleTimeout: ServerMaxIdleTimeoutSec * time.Second,
KeepAlivePeriod: 0, // Keep alive should solely be client's responsibility
DisablePathMTUDiscovery: config.DisableMTUDiscovery,
EnableDatagrams: true,
}
if !quicConfig.DisablePathMTUDiscovery && pmtud.DisablePathMTUDiscovery {
logrus.Info("Path MTU Discovery is not yet supported on this platform")
}
// Auth
var authFunc cs.ConnectFunc
var err error
switch authMode := config.Auth.Mode; authMode {
case "", "none":
if len(config.Obfs) == 0 {
logrus.Warn("Neither authentication nor obfuscation is turned on. " +
"Your server could be used by anyone! Are you sure this is what you want?")
}
authFunc = func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
return true, "Welcome"
}
case "password", "passwords":
authFunc, err = auth.PasswordAuthFunc(config.Auth.Config)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to enable password authentication")
} else {
logrus.Info("Password authentication enabled")
}
case "external":
authFunc, err = auth.ExternalAuthFunc(config.Auth.Config)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to enable external authentication")
} else {
logrus.Info("External authentication enabled")
}
default:
logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode")
}
connectFunc := func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
ok, msg := authFunc(addr, auth, sSend, sRecv)
if !ok {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"msg": msg,
}).Info("Authentication failed, client rejected")
} else {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
}).Info("Client connected")
}
return ok, msg
}
// Resolve preference
if len(config.ResolvePreference) > 0 {
pref, err := transport.ResolvePreferenceFromString(config.ResolvePreference)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to parse the resolve preference")
}
transport.DefaultServerTransport.ResolvePreference = pref
}
// SOCKS5 outbound
if config.SOCKS5Outbound.Server != "" {
transport.DefaultServerTransport.SOCKS5Client = transport.NewSOCKS5Client(config.SOCKS5Outbound.Server,
config.SOCKS5Outbound.User, config.SOCKS5Outbound.Password)
}
// Bind outbound
if config.BindOutbound.Device != "" {
iface, err := net.InterfaceByName(config.BindOutbound.Device)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to find the interface")
}
transport.DefaultServerTransport.LocalUDPIntf = iface
sockopt.BindDialer(transport.DefaultServerTransport.Dialer, iface)
}
if config.BindOutbound.Address != "" {
ip := net.ParseIP(config.BindOutbound.Address)
if ip == nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to parse the address")
}
transport.DefaultServerTransport.Dialer.LocalAddr = &net.TCPAddr{IP: ip}
transport.DefaultServerTransport.LocalUDPAddr = &net.UDPAddr{IP: ip}
}
// ACL
var aclEngine *acl.Engine
if len(config.ACL) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACL, func(addr string) (*net.IPAddr, error) {
ipAddr, _, err := transport.DefaultServerTransport.ResolveIPAddr(addr)
return ipAddr, err
},
func() (*geoip2.Reader, error) {
return loadMMDBReader(config.MMDB)
})
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"file": config.ACL,
}).Fatal("Failed to parse ACL")
}
aclEngine.DefaultAction = acl.ActionDirect
}
// Prometheus
var trafficCounter cs.TrafficCounter
if len(config.PrometheusListen) > 0 {
promReg := prometheus.NewRegistry()
trafficCounter = NewPrometheusTrafficCounter(promReg)
go func() {
http.Handle("/metrics", promhttp.HandlerFor(promReg, promhttp.HandlerOpts{}))
err := http.ListenAndServe(config.PrometheusListen, nil)
logrus.WithField("error", err).Fatal("Prometheus HTTP server error")
}()
}
// Packet conn
pktConnFuncFactory := serverPacketConnFuncFactoryMap[config.Protocol]
if pktConnFuncFactory == nil {
logrus.WithField("protocol", config.Protocol).Fatal("Unsupported protocol")
}
pktConnFunc := pktConnFuncFactory(config.Obfs)
pktConn, err := pktConnFunc(config.Listen)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"addr": config.Listen,
}).Fatal("Failed to listen on the UDP address")
}
// Server
up, down, _ := config.Speed()
server, err := cs.NewServer(tlsConfig, quicConfig, pktConn,
transport.DefaultServerTransport, up, down, config.DisableUDP, aclEngine,
connectFunc, disconnectFunc, tcpRequestFunc, tcpErrorFunc, udpRequestFunc, udpErrorFunc, trafficCounter)
if err != nil {
logrus.WithField("error", err).Fatal("Failed to initialize server")
}
defer server.Close()
logrus.WithField("addr", config.Listen).Info("Server up and running")
err = server.Serve()
logrus.WithField("error", err).Fatal("Server shutdown")
}
func disconnectFunc(addr net.Addr, auth []byte, err error) {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"error": err,
}).Info("Client disconnected")
}
func tcpRequestFunc(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"dst": defaultIPMasker.Mask(reqAddr),
"action": actionToString(action, arg),
}).Debug("TCP request")
}
func tcpErrorFunc(addr net.Addr, auth []byte, reqAddr string, err error) {
if err != io.EOF {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"dst": defaultIPMasker.Mask(reqAddr),
"error": err,
}).Info("TCP error")
} else {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"dst": defaultIPMasker.Mask(reqAddr),
}).Debug("TCP EOF")
}
}
func udpRequestFunc(addr net.Addr, auth []byte, sessionID uint32) {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"session": sessionID,
}).Debug("UDP request")
}
func udpErrorFunc(addr net.Addr, auth []byte, sessionID uint32, err error) {
if err != io.EOF {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"session": sessionID,
"error": err,
}).Info("UDP error")
} else {
logrus.WithFields(logrus.Fields{
"src": defaultIPMasker.Mask(addr.String()),
"session": sessionID,
}).Debug("UDP EOF")
}
}
func actionToString(action acl.Action, arg string) string {
switch action {
case acl.ActionDirect:
return "Direct"
case acl.ActionProxy:
return "Proxy"
case acl.ActionBlock:
return "Block"
case acl.ActionHijack:
return "Hijack to " + arg
default:
return "Unknown"
}
}
func parseServerConfig(cb []byte) (*serverConfig, error) {
var c serverConfig
err := json5.Unmarshal(cb, &c)
if err != nil {
return nil, err
}
return &c, c.Check()
}

43
core/hy/ipmasker.go Normal file
View File

@@ -0,0 +1,43 @@
package hy
import (
"net"
)
type ipMasker struct {
IPv4Mask net.IPMask
IPv6Mask net.IPMask
}
// Mask masks an address with the configured CIDR.
// addr can be "host:port" or just host.
func (m *ipMasker) Mask(addr string) string {
if m.IPv4Mask == nil && m.IPv6Mask == nil {
return addr
}
host, port, err := net.SplitHostPort(addr)
if err != nil {
// just host
host, port = addr, ""
}
ip := net.ParseIP(host)
if ip == nil {
// not an IP address, return as is
return addr
}
if ip4 := ip.To4(); ip4 != nil && m.IPv4Mask != nil {
// IPv4
host = ip4.Mask(m.IPv4Mask).String()
} else if ip6 := ip.To16(); ip6 != nil && m.IPv6Mask != nil {
// IPv6
host = ip6.Mask(m.IPv6Mask).String()
}
if port != "" {
return net.JoinHostPort(host, port)
} else {
return host
}
}
var defaultIPMasker = &ipMasker{}

95
core/hy/kploader.go Normal file
View File

@@ -0,0 +1,95 @@
package hy
import (
"crypto/tls"
"sync"
"github.com/fsnotify/fsnotify"
"github.com/sirupsen/logrus"
)
type keypairLoader struct {
certMu sync.RWMutex
cert *tls.Certificate
certPath string
keyPath string
}
func newKeypairLoader(certPath, keyPath string) (*keypairLoader, error) {
loader := &keypairLoader{
certPath: certPath,
keyPath: keyPath,
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}
loader.cert = &cert
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
go func() {
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
switch event.Op {
case fsnotify.Create, fsnotify.Write, fsnotify.Rename, fsnotify.Chmod:
logrus.WithFields(logrus.Fields{
"file": event.Name,
}).Info("Keypair change detected, reloading...")
if err := loader.load(); err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Error("Failed to reload keypair")
} else {
logrus.Info("Keypair successfully reloaded")
}
case fsnotify.Remove:
_ = watcher.Add(event.Name) // Workaround for vim
// https://github.com/fsnotify/fsnotify/issues/92
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
logrus.WithFields(logrus.Fields{
"error": err,
}).Error("Failed to watch keypair files for changes")
}
}
}()
err = watcher.Add(certPath)
if err != nil {
_ = watcher.Close()
return nil, err
}
err = watcher.Add(keyPath)
if err != nil {
_ = watcher.Close()
return nil, err
}
return loader, nil
}
func (kpr *keypairLoader) load() error {
cert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath)
if err != nil {
return err
}
kpr.certMu.Lock()
kpr.cert = &cert
kpr.certMu.Unlock()
return nil
}
func (kpr *keypairLoader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
kpr.certMu.RLock()
defer kpr.certMu.RUnlock()
return kpr.cert, nil
}
}

49
core/hy/mmdb.go Normal file
View File

@@ -0,0 +1,49 @@
package hy
import (
"io"
"net/http"
"os"
"github.com/oschwald/geoip2-golang"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
func downloadMMDB(filename string) error {
resp, err := http.Get(viper.GetString("mmdb-url"))
if err != nil {
return err
}
defer resp.Body.Close()
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
_, err = io.Copy(file, resp.Body)
return err
}
func loadMMDBReader(filename string) (*geoip2.Reader, error) {
if _, err := os.Stat(filename); err != nil {
if os.IsNotExist(err) {
logrus.Info("GeoLite2 database not found, downloading...")
if err := downloadMMDB(filename); err != nil {
return nil, err
}
logrus.WithFields(logrus.Fields{
"file": filename,
}).Info("GeoLite2 database downloaded")
return geoip2.Open(filename)
} else {
// some other error
return nil, err
}
} else {
// file exists, just open it
return geoip2.Open(filename)
}
}

71
core/hy/prom.go Normal file
View File

@@ -0,0 +1,71 @@
package hy
import (
"github.com/apernet/hysteria/core/cs"
"github.com/prometheus/client_golang/prometheus"
)
type prometheusTrafficCounter struct {
reg *prometheus.Registry
upCounterVec *prometheus.CounterVec
downCounterVec *prometheus.CounterVec
connGaugeVec *prometheus.GaugeVec
counterMap map[string]counters
}
type counters struct {
UpCounter prometheus.Counter
DownCounter prometheus.Counter
ConnGauge prometheus.Gauge
}
func NewPrometheusTrafficCounter(reg *prometheus.Registry) cs.TrafficCounter {
c := &prometheusTrafficCounter{
reg: reg,
upCounterVec: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "hysteria_traffic_uplink_bytes_total",
}, []string{"auth"}),
downCounterVec: prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "hysteria_traffic_downlink_bytes_total",
}, []string{"auth"}),
connGaugeVec: prometheus.NewGaugeVec(prometheus.GaugeOpts{
Name: "hysteria_active_conn",
}, []string{"auth"}),
counterMap: make(map[string]counters),
}
reg.MustRegister(c.upCounterVec, c.downCounterVec, c.connGaugeVec)
return c
}
func (c *prometheusTrafficCounter) getCounters(auth string) counters {
cts, ok := c.counterMap[auth]
if !ok {
cts = counters{
UpCounter: c.upCounterVec.WithLabelValues(auth),
DownCounter: c.downCounterVec.WithLabelValues(auth),
ConnGauge: c.connGaugeVec.WithLabelValues(auth),
}
c.counterMap[auth] = cts
}
return cts
}
func (c *prometheusTrafficCounter) Rx(auth string, n int) {
cts := c.getCounters(auth)
cts.DownCounter.Add(float64(n))
}
func (c *prometheusTrafficCounter) Tx(auth string, n int) {
cts := c.getCounters(auth)
cts.UpCounter.Add(float64(n))
}
func (c *prometheusTrafficCounter) IncConn(auth string) {
cts := c.getCounters(auth)
cts.ConnGauge.Inc()
}
func (c *prometheusTrafficCounter) DecConn(auth string) {
cts := c.getCounters(auth)
cts.ConnGauge.Dec()
}

123
core/hy/resolver.go Normal file
View File

@@ -0,0 +1,123 @@
package hy
import (
"crypto/tls"
"errors"
"net"
"net/url"
"strings"
"github.com/apernet/hysteria/core/utils"
rdns "github.com/folbricht/routedns"
)
var errInvalidSyntax = errors.New("invalid syntax")
func setResolver(dns string) error {
if net.ParseIP(dns) != nil {
// Just an IP address, treat as UDP 53
dns = "udp://" + net.JoinHostPort(dns, "53")
}
var r rdns.Resolver
if strings.HasPrefix(dns, "udp://") {
// Standard UDP DNS resolver
dns = strings.TrimPrefix(dns, "udp://")
if dns == "" {
return errInvalidSyntax
}
if _, _, err := utils.SplitHostPort(dns); err != nil {
// Append the default DNS port
dns = net.JoinHostPort(dns, "53")
}
client, err := rdns.NewDNSClient("dns-udp", dns, "udp", rdns.DNSClientOptions{})
if err != nil {
return err
}
r = client
} else if strings.HasPrefix(dns, "tcp://") {
// Standard TCP DNS resolver
dns = strings.TrimPrefix(dns, "tcp://")
if dns == "" {
return errInvalidSyntax
}
if _, _, err := utils.SplitHostPort(dns); err != nil {
// Append the default DNS port
dns = net.JoinHostPort(dns, "53")
}
client, err := rdns.NewDNSClient("dns-tcp", dns, "tcp", rdns.DNSClientOptions{})
if err != nil {
return err
}
r = client
} else if strings.HasPrefix(dns, "https://") {
// DoH resolver
if dohURL, err := url.Parse(dns); err != nil {
return err
} else {
// Need to set bootstrap address to avoid loopback DNS lookup
dohIPAddr, err := net.ResolveIPAddr("ip", dohURL.Hostname())
if err != nil {
return err
}
client, err := rdns.NewDoHClient("doh", dns, rdns.DoHClientOptions{
BootstrapAddr: dohIPAddr.String(),
})
if err != nil {
return err
}
r = client
}
} else if strings.HasPrefix(dns, "tls://") {
// DoT resolver
dns = strings.TrimPrefix(dns, "tls://")
if dns == "" {
return errInvalidSyntax
}
dotHost, _, err := utils.SplitHostPort(dns)
if err != nil {
// Append the default DNS port
dns = net.JoinHostPort(dns, "853")
}
// Need to set bootstrap address to avoid loopback DNS lookup
dotIPAddr, err := net.ResolveIPAddr("ip", dotHost)
if err != nil {
return err
}
client, err := rdns.NewDoTClient("dot", dns, rdns.DoTClientOptions{
BootstrapAddr: dotIPAddr.String(),
TLSConfig: new(tls.Config),
})
if err != nil {
return err
}
r = client
} else if strings.HasPrefix(dns, "quic://") {
// DoQ resolver
dns = strings.TrimPrefix(dns, "quic://")
if dns == "" {
return errInvalidSyntax
}
doqHost, _, err := utils.SplitHostPort(dns)
if err != nil {
// Append the default DNS port
dns = net.JoinHostPort(dns, "853")
}
// Need to set bootstrap address to avoid loopback DNS lookup
doqIPAddr, err := net.ResolveIPAddr("ip", doqHost)
if err != nil {
return err
}
client, err := rdns.NewDoQClient("doq", dns, rdns.DoQClientOptions{
BootstrapAddr: doqIPAddr.String(),
})
if err != nil {
return err
}
r = client
} else {
return errInvalidSyntax
}
cache := rdns.NewCache("cache", r, rdns.CacheOptions{})
net.DefaultResolver = rdns.NewNetResolver(cache)
return nil
}