ddns: retreive dns servers from context (#1034)

This commit is contained in:
UUBulb
2025-03-17 23:11:40 +08:00
committed by GitHub
parent c3ec52e392
commit 38c2374bad
5 changed files with 17 additions and 21 deletions

View File

@@ -109,7 +109,6 @@ func updateConfig(c *gin.Context) (any, error) {
return nil, newGormError("%v", err) return nil, newGormError("%v", err)
} }
singleton.OnNameserverUpdate()
singleton.OnUpdateLang(singleton.Conf.Language) singleton.OnUpdateLang(singleton.Conf.Language)
return nil, nil return nil, nil
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"strings"
"time" "time"
"github.com/libdns/libdns" "github.com/libdns/libdns"
@@ -14,9 +13,10 @@ import (
"github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/pkg/utils"
) )
var ( type DNSServerKey struct{}
const (
dnsTimeOut = 10 * time.Second dnsTimeOut = 10 * time.Second
customDNSServers []string
) )
type Provider struct { type Provider struct {
@@ -31,12 +31,6 @@ type Provider struct {
Setter libdns.RecordSetter Setter libdns.RecordSetter
} }
func InitDNSServers(s string) {
if s != "" {
customDNSServers = strings.Split(s, ",")
}
}
func (provider *Provider) GetProfileID() uint64 { func (provider *Provider) GetProfileID() uint64 {
return provider.DDNSProfile.ID return provider.DDNSProfile.ID
} }
@@ -58,7 +52,7 @@ func (provider *Provider) UpdateDomain(ctx context.Context, overrideDomains ...s
func (provider *Provider) updateDomain(domain string) error { func (provider *Provider) updateDomain(domain string) error {
var err error var err error
provider.prefix, provider.zone, err = splitDomainSOA(domain) provider.prefix, provider.zone, err = provider.splitDomainSOA(domain)
if err != nil { if err != nil {
return err return err
} }
@@ -96,13 +90,15 @@ func (provider *Provider) addDomainRecord() error {
return err return err
} }
func splitDomainSOA(domain string) (prefix string, zone string, err error) { func (provider *Provider) splitDomainSOA(domain string) (prefix string, zone string, err error) {
c := &dns.Client{Timeout: dnsTimeOut} c := &dns.Client{Timeout: dnsTimeOut}
domain += "." domain += "."
indexes := dns.Split(domain) indexes := dns.Split(domain)
servers := utils.DNSServers servers := utils.DNSServers
customDNSServers, _ := provider.ctx.Value(DNSServerKey{}).([]string)
if len(customDNSServers) > 0 { if len(customDNSServers) > 0 {
servers = customDNSServers servers = customDNSServers
} }

View File

@@ -1,6 +1,7 @@
package ddns package ddns
import ( import (
"context"
"os" "os"
"testing" "testing"
) )
@@ -34,8 +35,9 @@ func TestSplitDomainSOA(t *testing.T) {
}, },
} }
provider := &Provider{ctx: context.WithValue(context.Background(), DNSServerKey{}, []string{"1.1.1.1:53"})}
for _, c := range cases { for _, c := range cases {
prefix, zone, err := splitDomainSOA(c.domain) prefix, zone, err := provider.splitDomainSOA(c.domain)
if err != nil { if err != nil {
t.Fatalf("Error: %s", err) t.Fatalf("Error: %s", err)
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"log" "log"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@@ -13,6 +14,7 @@ import (
"github.com/nezhahq/nezha/pkg/ddns" "github.com/nezhahq/nezha/pkg/ddns"
geoipx "github.com/nezhahq/nezha/pkg/geoip" geoipx "github.com/nezhahq/nezha/pkg/geoip"
"github.com/nezhahq/nezha/pkg/grpcx" "github.com/nezhahq/nezha/pkg/grpcx"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/model" "github.com/nezhahq/nezha/model"
pb "github.com/nezhahq/nezha/proto" pb "github.com/nezhahq/nezha/proto"
@@ -236,12 +238,15 @@ func (s *NezhaHandler) ReportGeoIP(c context.Context, r *pb.GeoIP) (*pb.GeoIP, e
ipv4 := geoip.IP.IPv4Addr ipv4 := geoip.IP.IPv4Addr
ipv6 := geoip.IP.IPv6Addr ipv6 := geoip.IP.IPv6Addr
dnsServers := strings.Split(singleton.Conf.DNSServers, ",")
ctx := context.WithValue(context.Background(), ddns.DNSServerKey{}, utils.IfOr(len(dnsServers) > 0, dnsServers, utils.DNSServers))
providers, err := singleton.DDNSShared.GetDDNSProvidersFromProfiles(server.DDNSProfiles, &model.IP{IPv4Addr: ipv4, IPv6Addr: ipv6}) providers, err := singleton.DDNSShared.GetDDNSProvidersFromProfiles(server.DDNSProfiles, &model.IP{IPv4Addr: ipv4, IPv6Addr: ipv6})
if err == nil { if err == nil {
for _, provider := range providers { for _, provider := range providers {
domains := server.OverrideDDNSDomains[provider.GetProfileID()] domains := server.OverrideDDNSDomains[provider.GetProfileID()]
go func(provider *ddns.Provider) { go func(provider *ddns.Provider) {
provider.UpdateDomain(context.Background(), domains...) provider.UpdateDomain(ctx, domains...)
}(provider) }(provider)
} }
} else { } else {

View File

@@ -34,8 +34,6 @@ func NewDDNSClass() *DDNSClass {
sortedList: sortedList, sortedList: sortedList,
}, },
} }
OnNameserverUpdate()
return dc return dc
} }
@@ -107,7 +105,3 @@ func (c *DDNSClass) sortList() {
defer c.sortedListMu.Unlock() defer c.sortedListMu.Unlock()
c.sortedList = sortedList c.sortedList = sortedList
} }
func OnNameserverUpdate() {
ddns2.InitDNSServers(Conf.DNSServers)
}