update ddns on server update (#1050)

This commit is contained in:
UUBulb
2025-03-31 19:41:04 +08:00
committed by GitHub
parent 62ea87da74
commit 67c129635e
7 changed files with 69 additions and 48 deletions

View File

@@ -20,7 +20,6 @@ const (
) )
type Provider struct { type Provider struct {
ctx context.Context
ipAddr string ipAddr string
recordType string recordType string
prefix string prefix string
@@ -36,11 +35,10 @@ func (provider *Provider) GetProfileID() uint64 {
} }
func (provider *Provider) UpdateDomain(ctx context.Context, overrideDomains ...string) { func (provider *Provider) UpdateDomain(ctx context.Context, overrideDomains ...string) {
provider.ctx = ctx
for _, domain := range utils.IfOr(len(overrideDomains) > 0, overrideDomains, provider.DDNSProfile.Domains) { for _, domain := range utils.IfOr(len(overrideDomains) > 0, overrideDomains, provider.DDNSProfile.Domains) {
for retries := 0; retries < int(provider.DDNSProfile.MaxRetries); retries++ { for retries := 0; retries < int(provider.DDNSProfile.MaxRetries); retries++ {
log.Printf("NEZHA>> Updating DNS Record of domain %s: %d/%d", domain, retries+1, provider.DDNSProfile.MaxRetries) log.Printf("NEZHA>> Updating DNS Record of domain %s: %d/%d", domain, retries+1, provider.DDNSProfile.MaxRetries)
if err := provider.updateDomain(domain); err != nil { if err := provider.updateDomain(ctx, domain); err != nil {
log.Printf("NEZHA>> Failed to update DNS record of domain %s: %v", domain, err) log.Printf("NEZHA>> Failed to update DNS record of domain %s: %v", domain, err)
} else { } else {
log.Printf("NEZHA>> Update DNS record of domain %s succeeded", domain) log.Printf("NEZHA>> Update DNS record of domain %s succeeded", domain)
@@ -50,9 +48,9 @@ func (provider *Provider) UpdateDomain(ctx context.Context, overrideDomains ...s
} }
} }
func (provider *Provider) updateDomain(domain string) error { func (provider *Provider) updateDomain(ctx context.Context, domain string) error {
var err error var err error
provider.prefix, provider.zone, err = provider.splitDomainSOA(domain) provider.prefix, provider.zone, err = provider.splitDomainSOA(ctx, domain)
if err != nil { if err != nil {
return err return err
} }
@@ -61,7 +59,7 @@ func (provider *Provider) updateDomain(domain string) error {
if *provider.DDNSProfile.EnableIPv4 { if *provider.DDNSProfile.EnableIPv4 {
provider.recordType = getRecordString(true) provider.recordType = getRecordString(true)
provider.ipAddr = provider.IPAddrs.IPv4Addr provider.ipAddr = provider.IPAddrs.IPv4Addr
if err = provider.addDomainRecord(); err != nil { if err = provider.addDomainRecord(ctx); err != nil {
return err return err
} }
} }
@@ -69,7 +67,7 @@ func (provider *Provider) updateDomain(domain string) error {
if *provider.DDNSProfile.EnableIPv6 { if *provider.DDNSProfile.EnableIPv6 {
provider.recordType = getRecordString(false) provider.recordType = getRecordString(false)
provider.ipAddr = provider.IPAddrs.IPv6Addr provider.ipAddr = provider.IPAddrs.IPv6Addr
if err = provider.addDomainRecord(); err != nil { if err = provider.addDomainRecord(ctx); err != nil {
return err return err
} }
} }
@@ -77,8 +75,8 @@ func (provider *Provider) updateDomain(domain string) error {
return nil return nil
} }
func (provider *Provider) addDomainRecord() error { func (provider *Provider) addDomainRecord(ctx context.Context) error {
_, err := provider.Setter.SetRecords(provider.ctx, provider.zone, _, err := provider.Setter.SetRecords(ctx, provider.zone,
[]libdns.Record{ []libdns.Record{
{ {
Type: provider.recordType, Type: provider.recordType,
@@ -90,7 +88,7 @@ func (provider *Provider) addDomainRecord() error {
return err return err
} }
func (provider *Provider) splitDomainSOA(domain string) (prefix string, zone string, err error) { func (provider *Provider) splitDomainSOA(ctx context.Context, domain string) (prefix string, zone string, err error) {
c := &dns.Client{Timeout: dnsTimeOut} c := &dns.Client{Timeout: dnsTimeOut}
domain += "." domain += "."
@@ -98,26 +96,26 @@ func (provider *Provider) splitDomainSOA(domain string) (prefix string, zone str
servers := utils.DNSServers servers := utils.DNSServers
customDNSServers, _ := provider.ctx.Value(DNSServerKey{}).([]string) customDNSServers, _ := ctx.Value(DNSServerKey{}).([]string)
if len(customDNSServers) > 0 { if len(customDNSServers) > 0 {
servers = customDNSServers servers = customDNSServers
} }
var r *dns.Msg for _, server := range servers {
for _, idx := range indexes { for _, idx := range indexes {
var m dns.Msg var m dns.Msg
m.SetQuestion(domain[idx:], dns.TypeSOA) m.SetQuestion(domain[idx:], dns.TypeSOA)
for _, server := range servers { r, _, err := c.Exchange(&m, server)
r, _, err = c.Exchange(&m, server)
if err != nil { if err != nil {
return continue
} }
if len(r.Answer) > 0 { if len(r.Answer) > 0 {
if soa, ok := r.Answer[0].(*dns.SOA); ok { if soa, ok := r.Answer[0].(*dns.SOA); ok {
zone = soa.Hdr.Name zone := soa.Hdr.Name
prefix = libdns.RelativeName(domain, zone) prefix := libdns.RelativeName(domain, zone)
return return prefix, zone, nil
} }
} }
} }

View File

@@ -35,9 +35,10 @@ func TestSplitDomainSOA(t *testing.T) {
}, },
} }
provider := &Provider{ctx: context.WithValue(context.Background(), DNSServerKey{}, []string{"1.1.1.1:53"})} ctx := context.WithValue(context.Background(), DNSServerKey{}, []string{"1.1.1.1:53"})
provider := &Provider{}
for _, c := range cases { for _, c := range cases {
prefix, zone, err := provider.splitDomainSOA(c.domain) prefix, zone, err := provider.splitDomainSOA(ctx, c.domain)
if err != nil { if err != nil {
t.Fatalf("Error: %s", err) t.Fatalf("Error: %s", err)
} }

View File

@@ -18,7 +18,9 @@ import (
) )
var ( var (
DNSServers = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"} DNSServersV4 = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"}
DNSServersV6 = []string{"[2001:4860:4860::8888]:53", "[2001:4860:4860::8844]:53", "[2606:4700:4700::1111]:53", "[2606:4700:4700::1001]:53"}
DNSServers = append(DNSServersV4, DNSServersV6...)
ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`) ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`)
ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`) ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`)

View File

@@ -6,15 +6,12 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
"github.com/nezhahq/nezha/pkg/ddns"
geoipx "github.com/nezhahq/nezha/pkg/geoip" geoipx "github.com/nezhahq/nezha/pkg/geoip"
"github.com/nezhahq/nezha/pkg/grpcx" "github.com/nezhahq/nezha/pkg/grpcx"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/model" "github.com/nezhahq/nezha/model"
pb "github.com/nezhahq/nezha/proto" pb "github.com/nezhahq/nezha/proto"
@@ -239,19 +236,8 @@ func (s *NezhaHandler) ReportGeoIP(c context.Context, r *pb.GeoIP) (*pb.GeoIP, e
ipv4 := geoip.IP.IPv4Addr ipv4 := geoip.IP.IPv4Addr
ipv6 := geoip.IP.IPv6Addr ipv6 := geoip.IP.IPv6Addr
dnsServers := strings.Split(singleton.Conf.DNSServers, ",") if err := singleton.ServerShared.UpdateDDNS(server, &model.IP{IPv4Addr: ipv4, IPv6Addr: ipv6}); err != nil {
ctx := context.WithValue(context.Background(), ddns.DNSServerKey{}, utils.IfOr(dnsServers[0] != "", dnsServers, utils.DNSServers)) log.Printf("NEZHA>> Failed to update DDNS for server %d: %v", err, server.ID)
providers, err := singleton.DDNSShared.GetDDNSProvidersFromProfiles(server.DDNSProfiles, &model.IP{IPv4Addr: ipv4, IPv6Addr: ipv6})
if err == nil {
for _, provider := range providers {
domains := server.OverrideDDNSDomains[provider.GetProfileID()]
go func(provider *ddns.Provider) {
provider.UpdateDomain(ctx, domains...)
}(provider)
}
} else {
log.Printf("NEZHA>> Failed to retrieve DDNS configuration: %v", err)
} }
} }

View File

@@ -65,7 +65,7 @@ func (c *DDNSClass) GetDDNSProvidersFromProfiles(profileId []uint64, ip *model.I
profiles = append(profiles, profile) profiles = append(profiles, profile)
} else { } else {
c.listMu.RUnlock() c.listMu.RUnlock()
return nil, fmt.Errorf("无法找到DDNS配置 ID %d", id) return nil, fmt.Errorf("cannot find DDNS profile %d", id)
} }
} }
c.listMu.RUnlock() c.listMu.RUnlock()
@@ -90,7 +90,7 @@ func (c *DDNSClass) GetDDNSProvidersFromProfiles(profileId []uint64, ip *model.I
provider.Setter = &he.Provider{APIKey: profile.AccessSecret} provider.Setter = &he.Provider{APIKey: profile.AccessSecret}
providers = append(providers, provider) providers = append(providers, provider)
default: default:
return nil, fmt.Errorf("无法找到配置的DDNS提供者 %s", profile.Provider) return nil, fmt.Errorf("cannot find DDNS provider %s", profile.Provider)
} }
} }
return providers, nil return providers, nil

View File

@@ -2,9 +2,13 @@ package singleton
import ( import (
"cmp" "cmp"
"context"
"log"
"slices" "slices"
"strings"
"github.com/nezhahq/nezha/model" "github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/ddns"
"github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/utils"
) )
@@ -14,14 +18,19 @@ type ServerClass struct {
uuidToID map[string]uint64 uuidToID map[string]uint64
sortedListForGuest []*model.Server sortedListForGuest []*model.Server
conf *ConfigClass
dc *DDNSClass
} }
func NewServerClass() *ServerClass { func NewServerClass(conf *ConfigClass, dc *DDNSClass) *ServerClass {
sc := &ServerClass{ sc := &ServerClass{
class: class[uint64, *model.Server]{ class: class[uint64, *model.Server]{
list: make(map[uint64]*model.Server), list: make(map[uint64]*model.Server),
}, },
uuidToID: make(map[string]uint64), uuidToID: make(map[string]uint64),
conf: conf,
dc: dc,
} }
var servers []model.Server var servers []model.Server
@@ -47,6 +56,12 @@ func (c *ServerClass) Update(s *model.Server, uuid string) {
c.listMu.Unlock() c.listMu.Unlock()
if s.EnableDDNS {
if err := c.UpdateDDNS(s, nil); err != nil {
log.Printf("NEZHA>> Failed to update DDNS for server %d: %v", err, s.ID)
}
}
c.sortList() c.sortList()
} }
@@ -79,6 +94,25 @@ func (c *ServerClass) UUIDToID(uuid string) (id uint64, ok bool) {
return return
} }
func (c *ServerClass) UpdateDDNS(server *model.Server, ip *model.IP) error {
confServers := strings.Split(c.conf.DNSServers, ",")
ctx := context.WithValue(context.Background(), ddns.DNSServerKey{}, utils.IfOr(confServers[0] != "", confServers, utils.DNSServers))
providers, err := c.dc.GetDDNSProvidersFromProfiles(server.DDNSProfiles, utils.IfOr(ip != nil, ip, &server.GeoIP.IP))
if err != nil {
return err
}
for _, provider := range providers {
domains := server.OverrideDDNSDomains[provider.GetProfileID()]
go func(provider *ddns.Provider) {
provider.UpdateDomain(ctx, domains...)
}(provider)
}
return nil
}
func (c *ServerClass) sortList() { func (c *ServerClass) sortList() {
c.listMu.RLock() c.listMu.RLock()
defer c.listMu.RUnlock() defer c.listMu.RUnlock()

View File

@@ -52,13 +52,13 @@ func InitTimezoneAndCache() error {
// LoadSingleton 加载子服务并执行 // LoadSingleton 加载子服务并执行
func LoadSingleton(bus chan<- *model.Service) (err error) { func LoadSingleton(bus chan<- *model.Service) (err error) {
initUser() // 加载用户ID绑定表 initUser() // 加载用户ID绑定表
initI18n() // 加载本地化服务 initI18n() // 加载本地化服务
NotificationShared = NewNotificationClass() // 加载通知服务
ServerShared = NewServerClass() // 加载服务器列表
CronShared = NewCronClass() // 加载定时任务
NATShared = NewNATClass() NATShared = NewNATClass()
DDNSShared = NewDDNSClass() DDNSShared = NewDDNSClass()
NotificationShared = NewNotificationClass() // 加载通知服务
ServerShared = NewServerClass(Conf, DDNSShared) // 加载服务器列表
CronShared = NewCronClass() // 加载定时任务
ServiceSentinelShared, err = NewServiceSentinel(bus, ServerShared, NotificationShared, CronShared) ServiceSentinelShared, err = NewServiceSentinel(bus, ServerShared, NotificationShared, CronShared)
return return
} }