diff --git a/pkg/ddns/ddns.go b/pkg/ddns/ddns.go index 430faf4..3df190e 100644 --- a/pkg/ddns/ddns.go +++ b/pkg/ddns/ddns.go @@ -20,7 +20,6 @@ const ( ) type Provider struct { - ctx context.Context ipAddr string recordType string prefix string @@ -36,11 +35,10 @@ func (provider *Provider) GetProfileID() uint64 { } func (provider *Provider) UpdateDomain(ctx context.Context, overrideDomains ...string) { - provider.ctx = ctx for _, domain := range utils.IfOr(len(overrideDomains) > 0, overrideDomains, provider.DDNSProfile.Domains) { for retries := 0; retries < int(provider.DDNSProfile.MaxRetries); retries++ { log.Printf("NEZHA>> Updating DNS Record of domain %s: %d/%d", domain, retries+1, provider.DDNSProfile.MaxRetries) - if err := provider.updateDomain(domain); err != nil { + if err := provider.updateDomain(ctx, domain); err != nil { log.Printf("NEZHA>> Failed to update DNS record of domain %s: %v", domain, err) } else { log.Printf("NEZHA>> Update DNS record of domain %s succeeded", domain) @@ -50,9 +48,9 @@ func (provider *Provider) UpdateDomain(ctx context.Context, overrideDomains ...s } } -func (provider *Provider) updateDomain(domain string) error { +func (provider *Provider) updateDomain(ctx context.Context, domain string) error { var err error - provider.prefix, provider.zone, err = provider.splitDomainSOA(domain) + provider.prefix, provider.zone, err = provider.splitDomainSOA(ctx, domain) if err != nil { return err } @@ -61,7 +59,7 @@ func (provider *Provider) updateDomain(domain string) error { if *provider.DDNSProfile.EnableIPv4 { provider.recordType = getRecordString(true) provider.ipAddr = provider.IPAddrs.IPv4Addr - if err = provider.addDomainRecord(); err != nil { + if err = provider.addDomainRecord(ctx); err != nil { return err } } @@ -69,7 +67,7 @@ func (provider *Provider) updateDomain(domain string) error { if *provider.DDNSProfile.EnableIPv6 { provider.recordType = getRecordString(false) provider.ipAddr = provider.IPAddrs.IPv6Addr - if err = provider.addDomainRecord(); err != nil { + if err = provider.addDomainRecord(ctx); err != nil { return err } } @@ -77,8 +75,8 @@ func (provider *Provider) updateDomain(domain string) error { return nil } -func (provider *Provider) addDomainRecord() error { - _, err := provider.Setter.SetRecords(provider.ctx, provider.zone, +func (provider *Provider) addDomainRecord(ctx context.Context) error { + _, err := provider.Setter.SetRecords(ctx, provider.zone, []libdns.Record{ { Type: provider.recordType, @@ -90,7 +88,7 @@ func (provider *Provider) addDomainRecord() error { return err } -func (provider *Provider) splitDomainSOA(domain string) (prefix string, zone string, err error) { +func (provider *Provider) splitDomainSOA(ctx context.Context, domain string) (prefix string, zone string, err error) { c := &dns.Client{Timeout: dnsTimeOut} domain += "." @@ -98,26 +96,26 @@ func (provider *Provider) splitDomainSOA(domain string) (prefix string, zone str servers := utils.DNSServers - customDNSServers, _ := provider.ctx.Value(DNSServerKey{}).([]string) + customDNSServers, _ := ctx.Value(DNSServerKey{}).([]string) if len(customDNSServers) > 0 { servers = customDNSServers } - var r *dns.Msg - for _, idx := range indexes { - var m dns.Msg - m.SetQuestion(domain[idx:], dns.TypeSOA) + for _, server := range servers { + for _, idx := range indexes { + var m dns.Msg + m.SetQuestion(domain[idx:], dns.TypeSOA) - for _, server := range servers { - r, _, err = c.Exchange(&m, server) + r, _, err := c.Exchange(&m, server) if err != nil { - return + continue } + if len(r.Answer) > 0 { if soa, ok := r.Answer[0].(*dns.SOA); ok { - zone = soa.Hdr.Name - prefix = libdns.RelativeName(domain, zone) - return + zone := soa.Hdr.Name + prefix := libdns.RelativeName(domain, zone) + return prefix, zone, nil } } } diff --git a/pkg/ddns/ddns_test.go b/pkg/ddns/ddns_test.go index cd846f3..c802d2b 100644 --- a/pkg/ddns/ddns_test.go +++ b/pkg/ddns/ddns_test.go @@ -35,9 +35,10 @@ func TestSplitDomainSOA(t *testing.T) { }, } - provider := &Provider{ctx: context.WithValue(context.Background(), DNSServerKey{}, []string{"1.1.1.1:53"})} + ctx := context.WithValue(context.Background(), DNSServerKey{}, []string{"1.1.1.1:53"}) + provider := &Provider{} for _, c := range cases { - prefix, zone, err := provider.splitDomainSOA(c.domain) + prefix, zone, err := provider.splitDomainSOA(ctx, c.domain) if err != nil { t.Fatalf("Error: %s", err) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 3a285a1..8a3d922 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -18,7 +18,9 @@ import ( ) var ( - DNSServers = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"} + DNSServersV4 = []string{"8.8.8.8:53", "8.8.4.4:53", "1.1.1.1:53", "1.0.0.1:53"} + DNSServersV6 = []string{"[2001:4860:4860::8888]:53", "[2001:4860:4860::8844]:53", "[2606:4700:4700::1111]:53", "[2606:4700:4700::1001]:53"} + DNSServers = append(DNSServersV4, DNSServersV6...) ipv4Re = regexp.MustCompile(`(\d*\.).*(\.\d*)`) ipv6Re = regexp.MustCompile(`(\w*:\w*:).*(:\w*:\w*)`) diff --git a/service/rpc/nezha.go b/service/rpc/nezha.go index 6a4a346..69c1c26 100644 --- a/service/rpc/nezha.go +++ b/service/rpc/nezha.go @@ -6,15 +6,12 @@ import ( "fmt" "log" "net" - "strings" "sync" "time" "github.com/jinzhu/copier" - "github.com/nezhahq/nezha/pkg/ddns" geoipx "github.com/nezhahq/nezha/pkg/geoip" "github.com/nezhahq/nezha/pkg/grpcx" - "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/model" pb "github.com/nezhahq/nezha/proto" @@ -239,19 +236,8 @@ func (s *NezhaHandler) ReportGeoIP(c context.Context, r *pb.GeoIP) (*pb.GeoIP, e ipv4 := geoip.IP.IPv4Addr ipv6 := geoip.IP.IPv6Addr - dnsServers := strings.Split(singleton.Conf.DNSServers, ",") - ctx := context.WithValue(context.Background(), ddns.DNSServerKey{}, utils.IfOr(dnsServers[0] != "", dnsServers, utils.DNSServers)) - - providers, err := singleton.DDNSShared.GetDDNSProvidersFromProfiles(server.DDNSProfiles, &model.IP{IPv4Addr: ipv4, IPv6Addr: ipv6}) - if err == nil { - for _, provider := range providers { - domains := server.OverrideDDNSDomains[provider.GetProfileID()] - go func(provider *ddns.Provider) { - provider.UpdateDomain(ctx, domains...) - }(provider) - } - } else { - log.Printf("NEZHA>> Failed to retrieve DDNS configuration: %v", err) + if err := singleton.ServerShared.UpdateDDNS(server, &model.IP{IPv4Addr: ipv4, IPv6Addr: ipv6}); err != nil { + log.Printf("NEZHA>> Failed to update DDNS for server %d: %v", err, server.ID) } } diff --git a/service/singleton/ddns.go b/service/singleton/ddns.go index 863ccaf..93619cd 100644 --- a/service/singleton/ddns.go +++ b/service/singleton/ddns.go @@ -65,7 +65,7 @@ func (c *DDNSClass) GetDDNSProvidersFromProfiles(profileId []uint64, ip *model.I profiles = append(profiles, profile) } else { c.listMu.RUnlock() - return nil, fmt.Errorf("无法找到DDNS配置 ID %d", id) + return nil, fmt.Errorf("cannot find DDNS profile %d", id) } } c.listMu.RUnlock() @@ -90,7 +90,7 @@ func (c *DDNSClass) GetDDNSProvidersFromProfiles(profileId []uint64, ip *model.I provider.Setter = &he.Provider{APIKey: profile.AccessSecret} providers = append(providers, provider) default: - return nil, fmt.Errorf("无法找到配置的DDNS提供者 %s", profile.Provider) + return nil, fmt.Errorf("cannot find DDNS provider %s", profile.Provider) } } return providers, nil diff --git a/service/singleton/server.go b/service/singleton/server.go index e4e9e8c..332b54b 100644 --- a/service/singleton/server.go +++ b/service/singleton/server.go @@ -2,9 +2,13 @@ package singleton import ( "cmp" + "context" + "log" "slices" + "strings" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/ddns" "github.com/nezhahq/nezha/pkg/utils" ) @@ -14,14 +18,19 @@ type ServerClass struct { uuidToID map[string]uint64 sortedListForGuest []*model.Server + + conf *ConfigClass + dc *DDNSClass } -func NewServerClass() *ServerClass { +func NewServerClass(conf *ConfigClass, dc *DDNSClass) *ServerClass { sc := &ServerClass{ class: class[uint64, *model.Server]{ list: make(map[uint64]*model.Server), }, uuidToID: make(map[string]uint64), + conf: conf, + dc: dc, } var servers []model.Server @@ -47,6 +56,12 @@ func (c *ServerClass) Update(s *model.Server, uuid string) { c.listMu.Unlock() + if s.EnableDDNS { + if err := c.UpdateDDNS(s, nil); err != nil { + log.Printf("NEZHA>> Failed to update DDNS for server %d: %v", err, s.ID) + } + } + c.sortList() } @@ -79,6 +94,25 @@ func (c *ServerClass) UUIDToID(uuid string) (id uint64, ok bool) { return } +func (c *ServerClass) UpdateDDNS(server *model.Server, ip *model.IP) error { + confServers := strings.Split(c.conf.DNSServers, ",") + ctx := context.WithValue(context.Background(), ddns.DNSServerKey{}, utils.IfOr(confServers[0] != "", confServers, utils.DNSServers)) + + providers, err := c.dc.GetDDNSProvidersFromProfiles(server.DDNSProfiles, utils.IfOr(ip != nil, ip, &server.GeoIP.IP)) + if err != nil { + return err + } + + for _, provider := range providers { + domains := server.OverrideDDNSDomains[provider.GetProfileID()] + go func(provider *ddns.Provider) { + provider.UpdateDomain(ctx, domains...) + }(provider) + } + + return nil +} + func (c *ServerClass) sortList() { c.listMu.RLock() defer c.listMu.RUnlock() diff --git a/service/singleton/singleton.go b/service/singleton/singleton.go index 743a1f5..57fd779 100644 --- a/service/singleton/singleton.go +++ b/service/singleton/singleton.go @@ -52,13 +52,13 @@ func InitTimezoneAndCache() error { // LoadSingleton 加载子服务并执行 func LoadSingleton(bus chan<- *model.Service) (err error) { - initUser() // 加载用户ID绑定表 - initI18n() // 加载本地化服务 - NotificationShared = NewNotificationClass() // 加载通知服务 - ServerShared = NewServerClass() // 加载服务器列表 - CronShared = NewCronClass() // 加载定时任务 + initUser() // 加载用户ID绑定表 + initI18n() // 加载本地化服务 NATShared = NewNATClass() DDNSShared = NewDDNSClass() + NotificationShared = NewNotificationClass() // 加载通知服务 + ServerShared = NewServerClass(Conf, DDNSShared) // 加载服务器列表 + CronShared = NewCronClass() // 加载定时任务 ServiceSentinelShared, err = NewServiceSentinel(bus, ServerShared, NotificationShared, CronShared) return }