mirror of
https://github.com/Buriburizaem0n/nezha_domains.git
synced 2026-02-04 04:30:05 +00:00
ddns: store configuation in database (#435)
* ddns: store configuation in database Co-authored-by: nap0o <144927971+nap0o@users.noreply.github.com> * feat: split domain with soa lookup * switch to libdns interface * ddns: add unit test * ddns: skip TestSplitDomainSOA on ci network is not steady * fix error handling * fix error handling --------- Co-authored-by: nap0o <144927971+nap0o@users.noreply.github.com>
This commit is contained in:
@@ -1,190 +0,0 @@
|
||||
package ddns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
)
|
||||
|
||||
const baseEndpoint = "https://api.cloudflare.com/client/v4/zones"
|
||||
|
||||
type ProviderCloudflare struct {
|
||||
isIpv4 bool
|
||||
domainConfig *DomainConfig
|
||||
secret string
|
||||
zoneId string
|
||||
ipAddr string
|
||||
recordId string
|
||||
recordType string
|
||||
}
|
||||
|
||||
type cfReq struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content"`
|
||||
TTL uint32 `json:"ttl"`
|
||||
Proxied bool `json:"proxied"`
|
||||
}
|
||||
|
||||
func NewProviderCloudflare(s string) *ProviderCloudflare {
|
||||
return &ProviderCloudflare{
|
||||
secret: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *ProviderCloudflare) UpdateDomain(domainConfig *DomainConfig) error {
|
||||
if domainConfig == nil {
|
||||
return fmt.Errorf("获取 DDNS 配置失败")
|
||||
}
|
||||
provider.domainConfig = domainConfig
|
||||
|
||||
err := provider.getZoneID()
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法获取 zone ID: %s", err)
|
||||
}
|
||||
|
||||
// 当IPv4和IPv6同时成功才算作成功
|
||||
if provider.domainConfig.EnableIPv4 {
|
||||
provider.isIpv4 = true
|
||||
provider.recordType = getRecordString(provider.isIpv4)
|
||||
provider.ipAddr = provider.domainConfig.Ipv4Addr
|
||||
if err = provider.addDomainRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if provider.domainConfig.EnableIpv6 {
|
||||
provider.isIpv4 = false
|
||||
provider.recordType = getRecordString(provider.isIpv4)
|
||||
provider.ipAddr = provider.domainConfig.Ipv6Addr
|
||||
if err = provider.addDomainRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *ProviderCloudflare) addDomainRecord() error {
|
||||
err := provider.findDNSRecord()
|
||||
if err != nil {
|
||||
if errors.Is(err, utils.ErrGjsonNotFound) {
|
||||
// 添加 DNS 记录
|
||||
return provider.createDNSRecord()
|
||||
}
|
||||
return fmt.Errorf("查找 DNS 记录时出错: %s", err)
|
||||
}
|
||||
|
||||
// 更新 DNS 记录
|
||||
return provider.updateDNSRecord()
|
||||
}
|
||||
|
||||
func (provider *ProviderCloudflare) getZoneID() error {
|
||||
_, realDomain := splitDomain(provider.domainConfig.FullDomain)
|
||||
zu, _ := url.Parse(baseEndpoint)
|
||||
|
||||
q := zu.Query()
|
||||
q.Set("name", realDomain)
|
||||
zu.RawQuery = q.Encode()
|
||||
|
||||
body, err := provider.sendRequest("GET", zu.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := utils.GjsonGet(body, "result.0.id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider.zoneId = result.String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *ProviderCloudflare) findDNSRecord() error {
|
||||
de, _ := url.JoinPath(baseEndpoint, provider.zoneId, "dns_records")
|
||||
du, _ := url.Parse(de)
|
||||
|
||||
q := du.Query()
|
||||
q.Set("name", provider.domainConfig.FullDomain)
|
||||
q.Set("type", provider.recordType)
|
||||
du.RawQuery = q.Encode()
|
||||
|
||||
body, err := provider.sendRequest("GET", du.String(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := utils.GjsonGet(body, "result.0.id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider.recordId = result.String()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *ProviderCloudflare) createDNSRecord() error {
|
||||
de, _ := url.JoinPath(baseEndpoint, provider.zoneId, "dns_records")
|
||||
data := &cfReq{
|
||||
Name: provider.domainConfig.FullDomain,
|
||||
Type: provider.recordType,
|
||||
Content: provider.ipAddr,
|
||||
TTL: 60,
|
||||
Proxied: false,
|
||||
}
|
||||
|
||||
jsonData, _ := utils.Json.Marshal(data)
|
||||
_, err := provider.sendRequest("POST", de, jsonData)
|
||||
return err
|
||||
}
|
||||
|
||||
func (provider *ProviderCloudflare) updateDNSRecord() error {
|
||||
de, _ := url.JoinPath(baseEndpoint, provider.zoneId, "dns_records", provider.recordId)
|
||||
data := &cfReq{
|
||||
Name: provider.domainConfig.FullDomain,
|
||||
Type: provider.recordType,
|
||||
Content: provider.ipAddr,
|
||||
TTL: 60,
|
||||
Proxied: false,
|
||||
}
|
||||
|
||||
jsonData, _ := utils.Json.Marshal(data)
|
||||
_, err := provider.sendRequest("PATCH", de, jsonData)
|
||||
return err
|
||||
}
|
||||
|
||||
// 以下为辅助方法,如发送 HTTP 请求等
|
||||
func (provider *ProviderCloudflare) sendRequest(method string, url string, data []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest(method, url, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", provider.secret))
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
resp, err := utils.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
log.Printf("NEZHA>> 无法关闭HTTP响应体流: %s", err.Error())
|
||||
}
|
||||
}(resp.Body)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
125
pkg/ddns/ddns.go
125
pkg/ddns/ddns.go
@@ -1,24 +1,121 @@
|
||||
package ddns
|
||||
|
||||
import "golang.org/x/net/publicsuffix"
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
type DomainConfig struct {
|
||||
EnableIPv4 bool
|
||||
EnableIpv6 bool
|
||||
FullDomain string
|
||||
Ipv4Addr string
|
||||
Ipv6Addr string
|
||||
"github.com/libdns/libdns"
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
)
|
||||
|
||||
var dnsTimeOut = 10 * time.Second
|
||||
|
||||
type IP struct {
|
||||
Ipv4Addr string
|
||||
Ipv6Addr string
|
||||
}
|
||||
|
||||
type Provider interface {
|
||||
// UpdateDomain Return is updated
|
||||
UpdateDomain(*DomainConfig) error
|
||||
type Provider struct {
|
||||
ctx context.Context
|
||||
ipAddr string
|
||||
recordType string
|
||||
domain string
|
||||
prefix string
|
||||
zone string
|
||||
|
||||
DDNSProfile *model.DDNSProfile
|
||||
IPAddrs *IP
|
||||
Setter libdns.RecordSetter
|
||||
}
|
||||
|
||||
func splitDomain(domain string) (prefix string, realDomain string) {
|
||||
realDomain, _ = publicsuffix.EffectiveTLDPlusOne(domain)
|
||||
prefix = domain[:len(domain)-len(realDomain)-1]
|
||||
return prefix, realDomain
|
||||
func (provider *Provider) UpdateDomain(ctx context.Context) {
|
||||
provider.ctx = ctx
|
||||
for _, domain := range provider.DDNSProfile.Domains {
|
||||
for retries := 0; retries < int(provider.DDNSProfile.MaxRetries); retries++ {
|
||||
provider.domain = domain
|
||||
log.Printf("NEZHA>> 正在尝试更新域名(%s)DDNS(%d/%d)", provider.domain, retries+1, provider.DDNSProfile.MaxRetries)
|
||||
if err := provider.updateDomain(); err != nil {
|
||||
log.Printf("NEZHA>> 尝试更新域名(%s)DDNS失败: %v", provider.domain, err)
|
||||
} else {
|
||||
log.Printf("NEZHA>> 尝试更新域名(%s)DDNS成功", provider.domain)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *Provider) updateDomain() error {
|
||||
var err error
|
||||
provider.prefix, provider.zone, err = splitDomainSOA(provider.domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 当IPv4和IPv6同时成功才算作成功
|
||||
if *provider.DDNSProfile.EnableIPv4 {
|
||||
provider.recordType = getRecordString(true)
|
||||
provider.ipAddr = provider.IPAddrs.Ipv4Addr
|
||||
if err = provider.addDomainRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if *provider.DDNSProfile.EnableIPv6 {
|
||||
provider.recordType = getRecordString(false)
|
||||
provider.ipAddr = provider.IPAddrs.Ipv6Addr
|
||||
if err = provider.addDomainRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *Provider) addDomainRecord() error {
|
||||
_, err := provider.Setter.SetRecords(provider.ctx, provider.zone,
|
||||
[]libdns.Record{
|
||||
{
|
||||
Type: provider.recordType,
|
||||
Name: provider.prefix,
|
||||
Value: provider.ipAddr,
|
||||
TTL: time.Minute,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func splitDomainSOA(domain string) (prefix string, zone string, err error) {
|
||||
c := &dns.Client{Timeout: dnsTimeOut}
|
||||
|
||||
domain += "."
|
||||
indexes := dns.Split(domain)
|
||||
|
||||
var r *dns.Msg
|
||||
for _, idx := range indexes {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(domain[idx:], dns.TypeSOA)
|
||||
|
||||
for _, server := range utils.DNSServers {
|
||||
r, _, err = c.Exchange(m, server)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(r.Answer) > 0 {
|
||||
if soa, ok := r.Answer[0].(*dns.SOA); ok {
|
||||
zone = soa.Hdr.Name
|
||||
prefix = domain[:len(domain)-len(zone)-1]
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("SOA record not found for domain: %s", domain)
|
||||
}
|
||||
|
||||
func getRecordString(isIpv4 bool) string {
|
||||
|
||||
44
pkg/ddns/ddns_test.go
Normal file
44
pkg/ddns/ddns_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package ddns
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testSt struct {
|
||||
domain string
|
||||
zone string
|
||||
prefix string
|
||||
}
|
||||
|
||||
func TestSplitDomainSOA(t *testing.T) {
|
||||
if ci := os.Getenv("CI"); ci != "" { // skip if test on CI
|
||||
return
|
||||
}
|
||||
|
||||
cases := []testSt{
|
||||
{
|
||||
domain: "www.example.co.uk",
|
||||
zone: "example.co.uk.",
|
||||
prefix: "www",
|
||||
},
|
||||
{
|
||||
domain: "abc.example.com",
|
||||
zone: "example.com.",
|
||||
prefix: "abc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
prefix, zone, err := splitDomainSOA(c.domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %s", err)
|
||||
}
|
||||
if prefix != c.prefix {
|
||||
t.Fatalf("Expected prefix %s, but got %s", c.prefix, prefix)
|
||||
}
|
||||
if zone != c.zone {
|
||||
t.Fatalf("Expected zone %s, but got %s", c.zone, zone)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package ddns
|
||||
|
||||
type ProviderDummy struct{}
|
||||
|
||||
func (provider *ProviderDummy) UpdateDomain(domainConfig *DomainConfig) error {
|
||||
return nil
|
||||
}
|
||||
16
pkg/ddns/dummy/dummy.go
Normal file
16
pkg/ddns/dummy/dummy.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package dummy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/libdns/libdns"
|
||||
)
|
||||
|
||||
// Internal use
|
||||
type Provider struct {
|
||||
}
|
||||
|
||||
func (provider *Provider) SetRecords(ctx context.Context, zone string,
|
||||
recs []libdns.Record) ([]libdns.Record, error) {
|
||||
return recs, nil
|
||||
}
|
||||
@@ -1,243 +0,0 @@
|
||||
package ddns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
)
|
||||
|
||||
const te = "https://dnspod.tencentcloudapi.com"
|
||||
|
||||
type ProviderTencentCloud struct {
|
||||
isIpv4 bool
|
||||
domainConfig *DomainConfig
|
||||
recordID uint64
|
||||
recordType string
|
||||
secretID string
|
||||
secretKey string
|
||||
errCode string
|
||||
ipAddr string
|
||||
}
|
||||
|
||||
type tcReq struct {
|
||||
RecordType string `json:"RecordType"`
|
||||
Domain string `json:"Domain"`
|
||||
RecordLine string `json:"RecordLine"`
|
||||
Subdomain string `json:"Subdomain,omitempty"`
|
||||
SubDomain string `json:"SubDomain,omitempty"` // As is
|
||||
Value string `json:"Value,omitempty"`
|
||||
TTL uint32 `json:"TTL,omitempty"`
|
||||
RecordId uint64 `json:"RecordId,omitempty"`
|
||||
}
|
||||
|
||||
func NewProviderTencentCloud(id, key string) *ProviderTencentCloud {
|
||||
return &ProviderTencentCloud{
|
||||
secretID: id,
|
||||
secretKey: key,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) UpdateDomain(domainConfig *DomainConfig) error {
|
||||
if domainConfig == nil {
|
||||
return fmt.Errorf("获取 DDNS 配置失败")
|
||||
}
|
||||
provider.domainConfig = domainConfig
|
||||
|
||||
// 当IPv4和IPv6同时成功才算作成功
|
||||
var err error
|
||||
if provider.domainConfig.EnableIPv4 {
|
||||
provider.isIpv4 = true
|
||||
provider.recordType = getRecordString(provider.isIpv4)
|
||||
provider.ipAddr = provider.domainConfig.Ipv4Addr
|
||||
if err = provider.addDomainRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if provider.domainConfig.EnableIpv6 {
|
||||
provider.isIpv4 = false
|
||||
provider.recordType = getRecordString(provider.isIpv4)
|
||||
provider.ipAddr = provider.domainConfig.Ipv6Addr
|
||||
if err = provider.addDomainRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) addDomainRecord() error {
|
||||
err := provider.findDNSRecord()
|
||||
if err != nil {
|
||||
return fmt.Errorf("查找 DNS 记录时出错: %s", err)
|
||||
}
|
||||
|
||||
if provider.errCode == "ResourceNotFound.NoDataOfRecord" { // 没有找到 DNS 记录
|
||||
return provider.createDNSRecord()
|
||||
} else if provider.errCode != "" {
|
||||
return fmt.Errorf("查询 DNS 记录时出错,错误代码为: %s", provider.errCode)
|
||||
}
|
||||
|
||||
// 默认情况下更新 DNS 记录
|
||||
return provider.updateDNSRecord()
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) findDNSRecord() error {
|
||||
prefix, realDomain := splitDomain(provider.domainConfig.FullDomain)
|
||||
data := &tcReq{
|
||||
RecordType: provider.recordType,
|
||||
Domain: realDomain,
|
||||
RecordLine: "默认",
|
||||
Subdomain: prefix,
|
||||
}
|
||||
|
||||
jsonData, _ := utils.Json.Marshal(data)
|
||||
body, err := provider.sendRequest("DescribeRecordList", jsonData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := utils.GjsonGet(body, "Response.RecordList.0.RecordId")
|
||||
if err != nil {
|
||||
if errors.Is(err, utils.ErrGjsonNotFound) {
|
||||
if errCode, err := utils.GjsonGet(body, "Response.Error.Code"); err == nil {
|
||||
provider.errCode = errCode.String()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
provider.recordID = result.Uint()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) createDNSRecord() error {
|
||||
prefix, realDomain := splitDomain(provider.domainConfig.FullDomain)
|
||||
data := &tcReq{
|
||||
RecordType: provider.recordType,
|
||||
RecordLine: "默认",
|
||||
Domain: realDomain,
|
||||
SubDomain: prefix,
|
||||
Value: provider.ipAddr,
|
||||
TTL: 600,
|
||||
}
|
||||
|
||||
jsonData, _ := utils.Json.Marshal(data)
|
||||
_, err := provider.sendRequest("CreateRecord", jsonData)
|
||||
return err
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) updateDNSRecord() error {
|
||||
prefix, realDomain := splitDomain(provider.domainConfig.FullDomain)
|
||||
data := &tcReq{
|
||||
RecordType: provider.recordType,
|
||||
RecordLine: "默认",
|
||||
Domain: realDomain,
|
||||
SubDomain: prefix,
|
||||
Value: provider.ipAddr,
|
||||
TTL: 600,
|
||||
RecordId: provider.recordID,
|
||||
}
|
||||
|
||||
jsonData, _ := utils.Json.Marshal(data)
|
||||
_, err := provider.sendRequest("ModifyRecord", jsonData)
|
||||
return err
|
||||
}
|
||||
|
||||
// 以下为辅助方法,如发送 HTTP 请求等
|
||||
func (provider *ProviderTencentCloud) sendRequest(action string, data []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest("POST", te, bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-TC-Version", "2021-03-23")
|
||||
|
||||
provider.signRequest(provider.secretID, provider.secretKey, req, action, string(data))
|
||||
resp, err := utils.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
err := Body.Close()
|
||||
if err != nil {
|
||||
log.Printf("NEZHA>> 无法关闭HTTP响应体流: %s\n", err.Error())
|
||||
}
|
||||
}(resp.Body)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// https://github.com/jeessy2/ddns-go/blob/master/util/tencent_cloud_signer.go
|
||||
|
||||
func (provider *ProviderTencentCloud) sha256hex(s string) string {
|
||||
b := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) hmacsha256(s, key string) string {
|
||||
hashed := hmac.New(sha256.New, []byte(key))
|
||||
hashed.Write([]byte(s))
|
||||
return string(hashed.Sum(nil))
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) WriteString(strs ...string) string {
|
||||
var b strings.Builder
|
||||
for _, str := range strs {
|
||||
b.WriteString(str)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (provider *ProviderTencentCloud) signRequest(secretId string, secretKey string, r *http.Request, action string, payload string) {
|
||||
algorithm := "TC3-HMAC-SHA256"
|
||||
service := "dnspod"
|
||||
host := provider.WriteString(service, ".tencentcloudapi.com")
|
||||
timestamp := time.Now().Unix()
|
||||
timestampStr := strconv.FormatInt(timestamp, 10)
|
||||
|
||||
// 步骤 1:拼接规范请求串
|
||||
canonicalHeaders := provider.WriteString("content-type:application/json\nhost:", host, "\nx-tc-action:", strings.ToLower(action), "\n")
|
||||
signedHeaders := "content-type;host;x-tc-action"
|
||||
hashedRequestPayload := provider.sha256hex(payload)
|
||||
canonicalRequest := provider.WriteString("POST\n/\n\n", canonicalHeaders, "\n", signedHeaders, "\n", hashedRequestPayload)
|
||||
|
||||
// 步骤 2:拼接待签名字符串
|
||||
date := time.Unix(timestamp, 0).UTC().Format("2006-01-02")
|
||||
credentialScope := provider.WriteString(date, "/", service, "/tc3_request")
|
||||
hashedCanonicalRequest := provider.sha256hex(canonicalRequest)
|
||||
string2sign := provider.WriteString(algorithm, "\n", timestampStr, "\n", credentialScope, "\n", hashedCanonicalRequest)
|
||||
|
||||
// 步骤 3:计算签名
|
||||
secretDate := provider.hmacsha256(date, provider.WriteString("TC3", secretKey))
|
||||
secretService := provider.hmacsha256(service, secretDate)
|
||||
secretSigning := provider.hmacsha256("tc3_request", secretService)
|
||||
signature := hex.EncodeToString([]byte(provider.hmacsha256(string2sign, secretSigning)))
|
||||
|
||||
// 步骤 4:拼接 Authorization
|
||||
authorization := provider.WriteString(algorithm, " Credential=", secretId, "/", credentialScope, ", SignedHeaders=", signedHeaders, ", Signature=", signature)
|
||||
|
||||
r.Header.Add("Authorization", authorization)
|
||||
r.Header.Set("Host", host)
|
||||
r.Header.Set("X-TC-Action", action)
|
||||
r.Header.Add("X-TC-Timestamp", timestampStr)
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
package ddns
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
)
|
||||
|
||||
type ProviderWebHook struct {
|
||||
url string
|
||||
requestMethod string
|
||||
requestBody string
|
||||
requestHeader string
|
||||
domainConfig *DomainConfig
|
||||
}
|
||||
|
||||
func NewProviderWebHook(s, rm, rb, rh string) *ProviderWebHook {
|
||||
return &ProviderWebHook{
|
||||
url: s,
|
||||
requestMethod: rm,
|
||||
requestBody: rb,
|
||||
requestHeader: rh,
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *ProviderWebHook) UpdateDomain(domainConfig *DomainConfig) error {
|
||||
if domainConfig == nil {
|
||||
return fmt.Errorf("获取 DDNS 配置失败")
|
||||
}
|
||||
provider.domainConfig = domainConfig
|
||||
|
||||
if provider.domainConfig.FullDomain == "" {
|
||||
return fmt.Errorf("failed to update an empty domain")
|
||||
}
|
||||
|
||||
if provider.domainConfig.EnableIPv4 && provider.domainConfig.Ipv4Addr != "" {
|
||||
req, err := provider.prepareRequest(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update a domain: %s. Cause by: %v", provider.domainConfig.FullDomain, err)
|
||||
}
|
||||
if _, err := utils.HttpClient.Do(req); err != nil {
|
||||
return fmt.Errorf("failed to update a domain: %s. Cause by: %v", provider.domainConfig.FullDomain, err)
|
||||
}
|
||||
}
|
||||
|
||||
if provider.domainConfig.EnableIpv6 && provider.domainConfig.Ipv6Addr != "" {
|
||||
req, err := provider.prepareRequest(false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update a domain: %s. Cause by: %v", provider.domainConfig.FullDomain, err)
|
||||
}
|
||||
if _, err := utils.HttpClient.Do(req); err != nil {
|
||||
return fmt.Errorf("failed to update a domain: %s. Cause by: %v", provider.domainConfig.FullDomain, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (provider *ProviderWebHook) prepareRequest(isIPv4 bool) (*http.Request, error) {
|
||||
u, err := url.Parse(provider.url)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed parsing url: %v", err)
|
||||
}
|
||||
|
||||
// Only handle queries here
|
||||
q := u.Query()
|
||||
for p, vals := range q {
|
||||
for n, v := range vals {
|
||||
vals[n] = provider.formatWebhookString(v, isIPv4)
|
||||
}
|
||||
q[p] = vals
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
body := provider.formatWebhookString(provider.requestBody, isIPv4)
|
||||
header := provider.formatWebhookString(provider.requestHeader, isIPv4)
|
||||
headers := strings.Split(header, "\n")
|
||||
|
||||
req, err := http.NewRequest(provider.requestMethod, u.String(), bytes.NewBufferString(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed creating new request: %v", err)
|
||||
}
|
||||
|
||||
utils.SetStringHeadersToRequest(req, headers)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (provider *ProviderWebHook) formatWebhookString(s string, isIPv4 bool) string {
|
||||
var ipAddr, ipType string
|
||||
if isIPv4 {
|
||||
ipAddr = provider.domainConfig.Ipv4Addr
|
||||
ipType = "ipv4"
|
||||
} else {
|
||||
ipAddr = provider.domainConfig.Ipv6Addr
|
||||
ipType = "ipv6"
|
||||
}
|
||||
|
||||
r := strings.NewReplacer(
|
||||
"{ip}", ipAddr,
|
||||
"{domain}", provider.domainConfig.FullDomain,
|
||||
"{type}", ipType,
|
||||
"\r", "",
|
||||
)
|
||||
|
||||
result := r.Replace(strings.TrimSpace(s))
|
||||
return result
|
||||
}
|
||||
178
pkg/ddns/webhook/webhook.go
Normal file
178
pkg/ddns/webhook/webhook.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/libdns/libdns"
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/pkg/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
_ = iota
|
||||
methodGET
|
||||
methodPOST
|
||||
methodPATCH
|
||||
methodDELETE
|
||||
methodPUT
|
||||
)
|
||||
|
||||
const (
|
||||
_ = iota
|
||||
requestTypeJSON
|
||||
requestTypeForm
|
||||
)
|
||||
|
||||
var requestTypes = map[uint8]string{
|
||||
methodGET: "GET",
|
||||
methodPOST: "POST",
|
||||
methodPATCH: "PATCH",
|
||||
methodDELETE: "DELETE",
|
||||
methodPUT: "PUT",
|
||||
}
|
||||
|
||||
// Internal use
|
||||
type Provider struct {
|
||||
ipAddr string
|
||||
ipType string
|
||||
recordType string
|
||||
domain string
|
||||
|
||||
DDNSProfile *model.DDNSProfile
|
||||
}
|
||||
|
||||
func (provider *Provider) SetRecords(ctx context.Context, zone string,
|
||||
recs []libdns.Record) ([]libdns.Record, error) {
|
||||
for _, rec := range recs {
|
||||
provider.recordType = rec.Type
|
||||
provider.ipType = recordToIPType(provider.recordType)
|
||||
provider.ipAddr = rec.Value
|
||||
provider.domain = fmt.Sprintf("%s.%s", rec.Name, strings.TrimSuffix(zone, "."))
|
||||
|
||||
req, err := provider.prepareRequest(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update a domain: %s. Cause by: %v", provider.domain, err)
|
||||
}
|
||||
if _, err := utils.HttpClient.Do(req); err != nil {
|
||||
return nil, fmt.Errorf("failed to update a domain: %s. Cause by: %v", provider.domain, err)
|
||||
}
|
||||
}
|
||||
|
||||
return recs, nil
|
||||
}
|
||||
|
||||
func (provider *Provider) prepareRequest(ctx context.Context) (*http.Request, error) {
|
||||
u, err := provider.reqUrl()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err := provider.reqBody()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headers, err := utils.GjsonParseStringMap(
|
||||
provider.formatWebhookString(provider.DDNSProfile.WebhookHeaders))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, requestTypes[provider.DDNSProfile.WebhookMethod], u.String(), strings.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
provider.setContentType(req)
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (provider *Provider) setContentType(req *http.Request) {
|
||||
if provider.DDNSProfile.WebhookMethod == methodGET {
|
||||
return
|
||||
}
|
||||
if provider.DDNSProfile.WebhookRequestType == requestTypeForm {
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
} else {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *Provider) reqUrl() (*url.URL, error) {
|
||||
formattedUrl := strings.ReplaceAll(provider.DDNSProfile.WebhookURL, "#", "%23")
|
||||
|
||||
u, err := url.Parse(formattedUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Only handle queries here
|
||||
q := u.Query()
|
||||
for p, vals := range q {
|
||||
for n, v := range vals {
|
||||
vals[n] = provider.formatWebhookString(v)
|
||||
}
|
||||
q[p] = vals
|
||||
}
|
||||
|
||||
u.RawQuery = q.Encode()
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (provider *Provider) reqBody() (string, error) {
|
||||
if provider.DDNSProfile.WebhookMethod == methodGET ||
|
||||
provider.DDNSProfile.WebhookMethod == methodDELETE {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
switch provider.DDNSProfile.WebhookRequestType {
|
||||
case requestTypeJSON:
|
||||
return provider.formatWebhookString(provider.DDNSProfile.WebhookRequestBody), nil
|
||||
case requestTypeForm:
|
||||
data, err := utils.GjsonParseStringMap(provider.DDNSProfile.WebhookRequestBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
params := url.Values{}
|
||||
for k, v := range data {
|
||||
params.Add(k, provider.formatWebhookString(v))
|
||||
}
|
||||
return params.Encode(), nil
|
||||
default:
|
||||
return "", errors.New("request type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *Provider) formatWebhookString(s string) string {
|
||||
r := strings.NewReplacer(
|
||||
"#ip#", provider.ipAddr,
|
||||
"#domain#", provider.domain,
|
||||
"#type#", provider.ipType,
|
||||
"#record#", provider.recordType,
|
||||
"\r", "",
|
||||
)
|
||||
|
||||
result := r.Replace(strings.TrimSpace(s))
|
||||
return result
|
||||
}
|
||||
|
||||
func recordToIPType(record string) string {
|
||||
switch record {
|
||||
case "A":
|
||||
return "ipv4"
|
||||
case "AAAA":
|
||||
return "ipv6"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
116
pkg/ddns/webhook/webhook_test.go
Normal file
116
pkg/ddns/webhook/webhook_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package webhook
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/naiba/nezha/model"
|
||||
)
|
||||
|
||||
var (
|
||||
reqTypeForm = "application/x-www-form-urlencoded"
|
||||
reqTypeJSON = "application/json"
|
||||
)
|
||||
|
||||
type testSt struct {
|
||||
profile model.DDNSProfile
|
||||
expectURL string
|
||||
expectBody string
|
||||
expectContentType string
|
||||
expectHeader map[string]string
|
||||
}
|
||||
|
||||
func execCase(t *testing.T, item testSt) {
|
||||
pw := Provider{DDNSProfile: &item.profile}
|
||||
pw.ipAddr = "1.1.1.1"
|
||||
pw.domain = item.profile.Domains[0]
|
||||
pw.ipType = "ipv4"
|
||||
pw.recordType = "A"
|
||||
pw.DDNSProfile = &item.profile
|
||||
|
||||
reqUrl, err := pw.reqUrl()
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %s", err)
|
||||
}
|
||||
if item.expectURL != reqUrl.String() {
|
||||
t.Fatalf("Expected %s, but got %s", item.expectURL, reqUrl.String())
|
||||
}
|
||||
|
||||
reqBody, err := pw.reqBody()
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %s", err)
|
||||
}
|
||||
if item.expectBody != reqBody {
|
||||
t.Fatalf("Expected %s, but got %s", item.expectBody, reqBody)
|
||||
}
|
||||
|
||||
req, err := pw.prepareRequest(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Error: %s", err)
|
||||
}
|
||||
|
||||
if item.expectContentType != req.Header.Get("Content-Type") {
|
||||
t.Fatalf("Expected %s, but got %s", item.expectContentType, req.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
for k, v := range item.expectHeader {
|
||||
if v != req.Header.Get(k) {
|
||||
t.Fatalf("Expected %s, but got %s", v, req.Header.Get(k))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookRequest(t *testing.T) {
|
||||
ipv4 := true
|
||||
|
||||
cases := []testSt{
|
||||
{
|
||||
profile: model.DDNSProfile{
|
||||
Domains: []string{"www.example.com"},
|
||||
MaxRetries: 1,
|
||||
EnableIPv4: &ipv4,
|
||||
WebhookURL: "http://ddns.example.com/?ip=#ip#",
|
||||
WebhookMethod: methodGET,
|
||||
WebhookHeaders: `{"ip":"#ip#","record":"#record#"}`,
|
||||
},
|
||||
expectURL: "http://ddns.example.com/?ip=1.1.1.1",
|
||||
expectContentType: "",
|
||||
expectHeader: map[string]string{
|
||||
"ip": "1.1.1.1",
|
||||
"record": "A",
|
||||
},
|
||||
},
|
||||
{
|
||||
profile: model.DDNSProfile{
|
||||
Domains: []string{"www.example.com"},
|
||||
MaxRetries: 1,
|
||||
EnableIPv4: &ipv4,
|
||||
WebhookURL: "http://ddns.example.com/api",
|
||||
WebhookMethod: methodPOST,
|
||||
WebhookRequestType: requestTypeJSON,
|
||||
WebhookRequestBody: `{"ip":"#ip#","record":"#record#"}`,
|
||||
},
|
||||
expectURL: "http://ddns.example.com/api",
|
||||
expectContentType: reqTypeJSON,
|
||||
expectBody: `{"ip":"1.1.1.1","record":"A"}`,
|
||||
},
|
||||
{
|
||||
profile: model.DDNSProfile{
|
||||
Domains: []string{"www.example.com"},
|
||||
MaxRetries: 1,
|
||||
EnableIPv4: &ipv4,
|
||||
WebhookURL: "http://ddns.example.com/api",
|
||||
WebhookMethod: methodPOST,
|
||||
WebhookRequestType: requestTypeForm,
|
||||
WebhookRequestBody: `{"ip":"#ip#","record":"#record#"}`,
|
||||
},
|
||||
expectURL: "http://ddns.example.com/api",
|
||||
expectContentType: reqTypeForm,
|
||||
expectBody: "ip=1.1.1.1&record=A",
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
execCase(t, c)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user