package dns import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "net/url" "time" ) const cfBaseURL = "https://api.cloudflare.com/client/v4" // Cloudflare implements the Provider interface using the Cloudflare API v4. type Cloudflare struct { token string zoneID string client *http.Client } // NewCloudflare creates a new Cloudflare DNS provider. // token is required. zoneID can be empty for ListZones/TestConnection calls. func NewCloudflare(token, zoneID string) (*Cloudflare, error) { if token == "" { return nil, fmt.Errorf("cloudflare API token is required") } return &Cloudflare{ token: token, zoneID: zoneID, client: &http.Client{Timeout: 30 * time.Second}, }, nil } // --- Provider interface --- // EnsureRecord creates or updates an A record for the given FQDN. func (c *Cloudflare) EnsureRecord(ctx context.Context, fqdn, ip string) (string, error) { if c.zoneID == "" { return "", fmt.Errorf("zone ID is required for DNS operations") } // Check if a record already exists. existing, err := c.findRecord(ctx, fqdn) if err != nil { return "", fmt.Errorf("find existing record: %w", err) } if existing != nil { // Record exists — update if IP differs. if existing.Content == ip { return existing.ID, nil // already correct, no-op } updated, err := c.updateRecord(ctx, existing.ID, fqdn, ip) if err != nil { return "", fmt.Errorf("update record: %w", err) } return updated.ID, nil } // Record doesn't exist — create it. created, err := c.createRecord(ctx, fqdn, ip) if err != nil { return "", fmt.Errorf("create record: %w", err) } return created.ID, nil } // DeleteRecord removes an A record by FQDN. Returns nil if not found. func (c *Cloudflare) DeleteRecord(ctx context.Context, fqdn string) error { if c.zoneID == "" { return fmt.Errorf("zone ID is required for DNS operations") } existing, err := c.findRecord(ctx, fqdn) if err != nil { return fmt.Errorf("find record: %w", err) } if existing == nil { return nil // doesn't exist, nothing to delete } endpoint := fmt.Sprintf("%s/zones/%s/dns_records/%s", cfBaseURL, c.zoneID, existing.ID) if _, err := c.doRequest(ctx, http.MethodDelete, endpoint, nil); err != nil { return fmt.Errorf("delete record: %w", err) } return nil } // ListRecords returns all A records in the zone. func (c *Cloudflare) ListRecords(ctx context.Context) ([]Record, error) { if c.zoneID == "" { return nil, fmt.Errorf("zone ID is required for DNS operations") } var allRecords []Record page := 1 for { endpoint := fmt.Sprintf("%s/zones/%s/dns_records?type=A&page=%d&per_page=100", cfBaseURL, c.zoneID, page) body, err := c.doRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("list records: %w", err) } var resp cfListResponse if err := json.Unmarshal(body, &resp); err != nil { return nil, fmt.Errorf("decode list response: %w", err) } for _, r := range resp.Result { allRecords = append(allRecords, Record{ ID: r.ID, FQDN: r.Name, Type: r.Type, Content: r.Content, TTL: r.TTL, Proxied: r.Proxied, }) } if page >= resp.ResultInfo.TotalPages { break } page++ } return allRecords, nil } // TestConnection verifies the API token is valid. func (c *Cloudflare) TestConnection(ctx context.Context) error { endpoint := cfBaseURL + "/user/tokens/verify" body, err := c.doRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return fmt.Errorf("verify token: %w", err) } var resp cfBaseResponse if err := json.Unmarshal(body, &resp); err != nil { return fmt.Errorf("decode verify response: %w", err) } if !resp.Success { return fmt.Errorf("token verification failed: %s", formatErrors(resp.Errors)) } return nil } // --- Additional methods (not part of Provider interface) --- // ListZones returns all zones accessible by the token. func (c *Cloudflare) ListZones(ctx context.Context) ([]Zone, error) { var allZones []Zone page := 1 for { endpoint := fmt.Sprintf("%s/zones?page=%d&per_page=50&status=active", cfBaseURL, page) body, err := c.doRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, fmt.Errorf("list zones: %w", err) } var resp cfZonesResponse if err := json.Unmarshal(body, &resp); err != nil { return nil, fmt.Errorf("decode zones response: %w", err) } for _, z := range resp.Result { allZones = append(allZones, Zone{ ID: z.ID, Name: z.Name, }) } if page >= resp.ResultInfo.TotalPages { break } page++ } return allZones, nil } // --- Internal helpers --- func (c *Cloudflare) findRecord(ctx context.Context, fqdn string) (*cfDNSRecord, error) { endpoint := fmt.Sprintf("%s/zones/%s/dns_records?type=A&name=%s", cfBaseURL, c.zoneID, url.QueryEscape(fqdn)) body, err := c.doRequest(ctx, http.MethodGet, endpoint, nil) if err != nil { return nil, err } var resp cfListResponse if err := json.Unmarshal(body, &resp); err != nil { return nil, fmt.Errorf("decode find response: %w", err) } if len(resp.Result) == 0 { return nil, nil } return &resp.Result[0], nil } func (c *Cloudflare) createRecord(ctx context.Context, fqdn, ip string) (*cfDNSRecord, error) { payload := cfDNSRecordRequest{ Type: "A", Name: fqdn, Content: ip, TTL: 1, // auto Proxied: false, } data, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("marshal create payload: %w", err) } endpoint := fmt.Sprintf("%s/zones/%s/dns_records", cfBaseURL, c.zoneID) body, err := c.doRequest(ctx, http.MethodPost, endpoint, data) if err != nil { return nil, err } var resp cfSingleResponse if err := json.Unmarshal(body, &resp); err != nil { return nil, fmt.Errorf("decode create response: %w", err) } if !resp.Success { return nil, fmt.Errorf("create failed: %s", formatErrors(resp.Errors)) } return &resp.Result, nil } func (c *Cloudflare) updateRecord(ctx context.Context, recordID, fqdn, ip string) (*cfDNSRecord, error) { payload := cfDNSRecordRequest{ Type: "A", Name: fqdn, Content: ip, TTL: 1, Proxied: false, } data, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("marshal update payload: %w", err) } endpoint := fmt.Sprintf("%s/zones/%s/dns_records/%s", cfBaseURL, c.zoneID, recordID) body, err := c.doRequest(ctx, http.MethodPut, endpoint, data) if err != nil { return nil, err } var resp cfSingleResponse if err := json.Unmarshal(body, &resp); err != nil { return nil, fmt.Errorf("decode update response: %w", err) } if !resp.Success { return nil, fmt.Errorf("update failed: %s", formatErrors(resp.Errors)) } return &resp.Result, nil } func (c *Cloudflare) doRequest(ctx context.Context, method, endpoint string, payload []byte) ([]byte, error) { var bodyReader io.Reader if payload != nil { bodyReader = bytes.NewReader(payload) } req, err := http.NewRequestWithContext(ctx, method, endpoint, bodyReader) if err != nil { return nil, fmt.Errorf("create request: %w", err) } req.Header.Set("Authorization", "Bearer "+c.token) req.Header.Set("Content-Type", "application/json") resp, err := c.client.Do(req) if err != nil { return nil, fmt.Errorf("http request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("read response: %w", err) } if resp.StatusCode >= 400 { var errResp cfBaseResponse if json.Unmarshal(body, &errResp) == nil && len(errResp.Errors) > 0 { return nil, fmt.Errorf("cloudflare API error (%d): %s", resp.StatusCode, formatErrors(errResp.Errors)) } return nil, fmt.Errorf("cloudflare API error (%d): %s", resp.StatusCode, string(body)) } return body, nil } // --- Cloudflare API response types --- type cfBaseResponse struct { Success bool `json:"success"` Errors []cfError `json:"errors"` } type cfError struct { Code int `json:"code"` Message string `json:"message"` } type cfDNSRecord struct { ID string `json:"id"` Type string `json:"type"` Name string `json:"name"` Content string `json:"content"` TTL int `json:"ttl"` Proxied bool `json:"proxied"` } type cfDNSRecordRequest struct { Type string `json:"type"` Name string `json:"name"` Content string `json:"content"` TTL int `json:"ttl"` Proxied bool `json:"proxied"` } type cfResultInfo struct { TotalPages int `json:"total_pages"` } type cfListResponse struct { cfBaseResponse Result []cfDNSRecord `json:"result"` ResultInfo cfResultInfo `json:"result_info"` } type cfSingleResponse struct { cfBaseResponse Result cfDNSRecord `json:"result"` } type cfZone struct { ID string `json:"id"` Name string `json:"name"` } type cfZonesResponse struct { cfBaseResponse Result []cfZone `json:"result"` ResultInfo cfResultInfo `json:"result_info"` } func formatErrors(errs []cfError) string { if len(errs) == 0 { return "unknown error" } msg := errs[0].Message for _, e := range errs[1:] { msg += "; " + e.Message } return msg }