From b1626cd1980f2febdd46228ee422c2a169b88266 Mon Sep 17 00:00:00 2001 From: Adora Laura Kalb Date: Thu, 11 Jul 2024 16:22:58 +0200 Subject: [PATCH] refactor a lot --- internal/certificates/certificates.go | 210 ++++++++++++------- internal/certificates/certificates_test.go | 1 - internal/configuration/models.go | 34 ++- internal/configuration/substitutions.go | 17 ++ internal/configuration/substitutions_test.go | 65 ++++++ internal/configuration/validation.go | 37 ++-- internal/constants/constants.go | 3 +- 7 files changed, 262 insertions(+), 105 deletions(-) delete mode 100644 internal/certificates/certificates_test.go create mode 100644 internal/configuration/substitutions.go create mode 100644 internal/configuration/substitutions_test.go diff --git a/internal/certificates/certificates.go b/internal/certificates/certificates.go index 4d29e07..66108ec 100644 --- a/internal/certificates/certificates.go +++ b/internal/certificates/certificates.go @@ -17,126 +17,187 @@ import ( "code.lila.network/adoralaura/certwarden-deploy/internal/configuration" "code.lila.network/adoralaura/certwarden-deploy/internal/constants" - "github.com/getsentry/sentry-go" ) func HandleCertificates(logger *slog.Logger, config *configuration.ConfigFileData) { for _, cert := range config.Certificates { - certBytes, err := getCertFromServer( - logger, - cert.Name, - cert.ApiKey, - config.BaseURL, - config.DisableCertificateValidation, - ) - if err != nil { - logger.Error("Failed to get certificate from server", "cert-id", cert.Name, "error", err) - return + certInfos := GenericCertificate{ + Name: cert.Name, + FilePath: cert.CertificatePath, + Secret: cert.CertificateSecret, + IsKey: false, } - certIsDifferent, err := checkCertIsDifferent(logger, cert.FilePath, certBytes) - if err != nil { - logger.Error("failed to handle certificate", "cert-id", cert.Name, "error", err) - return + keyInfos := GenericCertificate{ + Name: cert.Name, + FilePath: cert.KeyPath, + Secret: cert.KeySecret, + IsKey: true, } - if certIsDifferent || configuration.Force { - if configuration.Force { - logger.Info("Forcing file system change due to --force", "cert-id", cert.Name) - } + // Rollout Certificate + certOnDiskChanged, err := certInfos.Rollout(logger, config.BaseURL, config.DisableCertificateValidation) + if err != nil { + logger.Error( + "Failed to roll out Certificate", "path", + certInfos.FilePath, "name", cert.Name, "error", err, + ) + continue + } - err = updateCertOnFS(logger, cert.FilePath, certBytes) - if err != nil { - logger.Error("failed to handle certificate", "cert-id", cert.Name, "error", err) - return - } + // Rollout Key + keyOnDiskChanged, err := keyInfos.Rollout(logger, config.BaseURL, config.DisableCertificateValidation) + if err != nil { + logger.Error( + "Failed to roll out Key", "path", + keyInfos.FilePath, "name", cert.Name, "error", err, + ) + continue + } + + // if cert OR key changed OR --force + if (certOnDiskChanged || keyOnDiskChanged) || configuration.Force { if configuration.Force { - logger.Info("Forcing file system change due to --force", "cert-id", cert.Name) + logger.Info("Forcing file system change due to --force", "name", cert.Name) } - err = handleCertificateAction(cert) + err = handleCertificateAction(cert.Action) if err != nil { - logger.Error("post certificate change command failed", "cert-id", cert.Name, "error", err) + logger.Error("Failed to execute post-rollout action", "name", cert.Name, "error", err) } } - if certIsDifferent { - logger.Info("New certificate rolled out", "cert-id", cert.Name) - } else { - logger.Info("Certificate not changed, skipping...", "cert-id", cert.Name) - } - } } -func getCertFromFile(path string) ([]byte, error) { - filebytes, err := os.ReadFile(path) +// Rollout handles getting the certificate/key data from the +// server and writing it to disk if the data differs. +// +// Returns error on error, true if certificate action needs to be executed, false if not +func (c *GenericCertificate) Rollout(logger *slog.Logger, baseUrl string, skipInsecure bool) (bool, error) { + err := c.fetchFromServer( + logger, + baseUrl, + skipInsecure, + ) + if err != nil { + return false, fmt.Errorf("failed to get certificate from server: %w", err) + } + + fileNeedsRollout, err := c.needsRollout(logger) + if err != nil { + return false, fmt.Errorf("failed to check certificate on disk: %w", err) + } + + if fileNeedsRollout || configuration.Force { + if configuration.Force { + logger.Info("Forcing file system change due to --force", "name", c.Name) + } + + err = c.writeToDisk(logger) + if err != nil { + return false, fmt.Errorf("failed to handle certificate: %w", err) + } + + } + if fileNeedsRollout { + logger.Info("New file deployed", "path", c.FilePath) + return true, nil + } else { + logger.Info("File not changed, skipping...", "path", c.FilePath) + return false, nil + } +} + +// readFromDisk reads file data from disk and populates the data []byte field. +// +// Returns error or nil on success +func (c *GenericCertificate) readFromDisk() error { + filebytes, err := os.ReadFile(c.FilePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - return []byte{}, err + return err } else { - return []byte{}, fmt.Errorf("failed to read certificate file on disk: %w", err) + return fmt.Errorf("failed to read file from disk: %w", err) } } - return filebytes, nil + + c.diskBytes = filebytes + return nil } -func checkCertIsDifferent(logger *slog.Logger, path string, data []byte) (bool, error) { - filebytes, err := getCertFromFile(path) +// needsRollout checks the data []bytes against the data on disk. +// +// Returns true if file needs rollout, false if not +func (c *GenericCertificate) needsRollout(logger *slog.Logger) (bool, error) { + err := c.readFromDisk() if err != nil { if errors.Is(err, fs.ErrNotExist) { return true, nil } else { - return false, fmt.Errorf("failed to compare certificates: %w", err) + return false, fmt.Errorf("failed to compare data to file on disk: %w", err) } } - existingSha256 := sha256.Sum256(filebytes) - newSha256 := sha256.Sum256(data) + diskHash := sha256.Sum256(c.diskBytes) + serverHash := sha256.Sum256(c.serverBytes) - sumsAreDifferent := existingSha256 != newSha256 - if sumsAreDifferent { - logger.Debug("Certificate on file differs from the certificate on the server", "cert-path", path) + hashesAreDifferent := diskHash != serverHash + if hashesAreDifferent { + logger.Debug("File on disk differs from server source", "path", c.FilePath) } else { - logger.Debug("Certificate on file is identical to the certificate on the server", "cert-path", path) + logger.Debug("File on disk is identical to server source", "path", c.FilePath) } - return sumsAreDifferent, nil + return hashesAreDifferent, nil } -func updateCertOnFS(logger *slog.Logger, path string, data []byte) error { +// writeToDisk flushes the certificate data to disk. +// +// Returns error or nil on success. +func (c *GenericCertificate) writeToDisk(logger *slog.Logger) error { if configuration.DryRun { - logger.Debug("DRY-RUN: writing certificate data to file", "cert-path", path) + logger.Debug("DRY-RUN: writing data to file", "path", c.FilePath) return nil } - file, err := os.Create(path) + file, err := os.Create(c.FilePath) if err != nil { - return fmt.Errorf("failed to open certificate for writing: %w", err) + return fmt.Errorf("failed to open file for writing: %w", err) } defer func(l *slog.Logger) { if err := file.Close(); err != nil { - l.Error("failed to close file", "file-path", path, "error", err) + l.Error("failed to close file", "path", c.FilePath, "error", err) } }(logger) w := bufio.NewWriter(file) - if _, err := w.Write(data); err != nil { - return fmt.Errorf("failed to write certificate data to file: %w", err) + if _, err := w.Write(c.serverBytes); err != nil { + return fmt.Errorf("failed to write data to file: %w", err) } if err = w.Flush(); err != nil { - return fmt.Errorf("failed to flush certificate data to file: %w", err) + return fmt.Errorf("failed to flush data to file: %w", err) } - logger.Debug("wrote certificate to file", "file-path", path) + logger.Debug("Successfully wrote to file", "path", c.FilePath) return nil } -func getCertFromServer(logger *slog.Logger, certName string, certKey string, baseUrl string, skipInsecure bool) ([]byte, error) { - url := baseUrl + constants.CertificateApiPath + certName +// fetchFromServer fetches the cert/key data from the CertWarden server and +// fills the serverBytes field. +// +// Returns error or nil on success. +func (c *GenericCertificate) fetchFromServer(logger *slog.Logger, baseUrl string, skipInsecure bool) error { + var url string + if c.IsKey { + url = baseUrl + constants.CertificateApiPath + c.Name + } else { + url = baseUrl + constants.CertificateApiPath + c.Name + } + logger.Debug("Certificate request URL: " + url) var transport http.RoundTripper @@ -155,17 +216,15 @@ func getCertFromServer(logger *slog.Logger, certName string, certKey string, bas } req, err := http.NewRequest("GET", url, nil) if err != nil { - return []byte{}, fmt.Errorf("failed to prepare to request certificate from server: %w", err) + return fmt.Errorf("failed to prepare to request certificate from server: %w", err) } req.Header.Set("User-Agent", constants.UserAgent) - req.Header.Add(constants.ApiKeyHeaderName, certKey) + req.Header.Add(constants.ApiKeyHeaderName, c.Secret) res, err := client.Do(req) if err != nil { - e := fmt.Errorf("failed to request certificate from server: %w", err) - sentry.CaptureException(e) - return []byte{}, e + return fmt.Errorf("failed to request certificate from server: %w", err) } defer func(l *slog.Logger) { @@ -175,30 +234,27 @@ func getCertFromServer(logger *slog.Logger, certName string, certKey string, bas }(logger) if res.StatusCode == http.StatusUnauthorized { - logger.Error("API-Key for Certificate is invalid, skipping certificate!", "cert-id", certName) - return []byte{}, errors.New("API-Key invalid") + logger.Error("API-Key for Certificate is invalid, skipping certificate!", "name", c.Name) + return errors.New("API-Key invalid") } else if res.StatusCode != http.StatusOK { - logger.Error("failed to get certificate from server", "cert-id", certName, "http-response", res.Status) + logger.Error("failed to get certificate from server", "name", c.Name, "http-response", res.Status) } - body, err := io.ReadAll(res.Body) + bodyBytes, err := io.ReadAll(res.Body) if err != nil { - e := fmt.Errorf("failed to read certificate response from server: %w", err) - sentry.CaptureException(e) - return []byte{}, e + return fmt.Errorf("failed to read certificate response from server: %w", err) } - return body, nil + c.serverBytes = bodyBytes + return nil } -func handleCertificateAction(cert configuration.CertificateData) error { - if cert.Action == "" { +// handleCertificateAction executes the user-defined action after successful certificate deployment +func handleCertificateAction(action string) error { + if action == "" { return nil } - action := strings.ReplaceAll(cert.Action, "{name}", cert.Name) - action = strings.ReplaceAll(action, "{path}", cert.FilePath) - sargs := strings.Split(action, " ") cmd := exec.Command(sargs[0], sargs[1:]...) diff --git a/internal/certificates/certificates_test.go b/internal/certificates/certificates_test.go deleted file mode 100644 index 1f9bfc3..0000000 --- a/internal/certificates/certificates_test.go +++ /dev/null @@ -1 +0,0 @@ -package certificates diff --git a/internal/configuration/models.go b/internal/configuration/models.go index cbd5bae..beaf2a0 100644 --- a/internal/configuration/models.go +++ b/internal/configuration/models.go @@ -1,5 +1,7 @@ package configuration +import "log/slog" + // Config file gets read into here var Config *ConfigFileData @@ -28,12 +30,36 @@ type ConfigFileData struct { // Struct that holds the details of a single managed certificate type CertificateData struct { - Name string `yaml:"name"` - ApiKey string `yaml:"api_key"` - Action string `yaml:"action"` - FilePath string `yaml:"file_path"` + Name string `yaml:"name"` + CertificateSecret string `yaml:"cert_secret"` + CertificatePath string `yaml:"cert_path"` + KeySecret string `yaml:"key_secret"` + KeyPath string `yaml:"key_path"` + Action string `yaml:"action"` } type SentryData struct { DSN string `yaml:"dsn"` } + +type ConfigValidationError struct { + ErrorMessages []string +} + +func (e *ConfigValidationError) Error() string { + return "Configuration file has errors! Application cannot start unless the errors are corrected." +} + +func (e *ConfigValidationError) Add(msg string) { + e.ErrorMessages = append(e.ErrorMessages, msg) +} + +func (e *ConfigValidationError) HasMessages() bool { + return len(e.ErrorMessages) == 0 +} + +func (e *ConfigValidationError) Print(logger *slog.Logger) { + for _, line := range e.ErrorMessages { + logger.Error(line) + } +} diff --git a/internal/configuration/substitutions.go b/internal/configuration/substitutions.go new file mode 100644 index 0000000..29927c7 --- /dev/null +++ b/internal/configuration/substitutions.go @@ -0,0 +1,17 @@ +package configuration + +import ( + "log/slog" + "strings" +) + +func (c *ConfigFileData) SubstituteKeys(logger *slog.Logger) { + for index, cert := range c.Certificates { + c.Certificates[index].CertificatePath = strings.ReplaceAll(cert.CertificatePath, "{name}", c.Certificates[index].Name) + c.Certificates[index].KeyPath = strings.ReplaceAll(cert.KeyPath, "{name}", c.Certificates[index].Name) + + c.Certificates[index].Action = strings.ReplaceAll(cert.Action, "{name}", c.Certificates[index].Name) + c.Certificates[index].Action = strings.ReplaceAll(c.Certificates[index].Action, "{cert_path}", c.Certificates[index].CertificatePath) + c.Certificates[index].Action = strings.ReplaceAll(c.Certificates[index].Action, "{key_path}", c.Certificates[index].KeyPath) + } +} diff --git a/internal/configuration/substitutions_test.go b/internal/configuration/substitutions_test.go new file mode 100644 index 0000000..6fcdbcb --- /dev/null +++ b/internal/configuration/substitutions_test.go @@ -0,0 +1,65 @@ +package configuration + +import ( + "testing" +) + +// TestStringSubstitutionWithPlaceholders tests the string substitution feature. +// It ensures that {name}, {cert_path} and {key_path} get substituted correctly. +func TestStringSubstitutionWithPlaceholders(t *testing.T) { + cert := CertificateData{ + Name: "qwer", + CertificatePath: "/fake/path/{name}", + KeyPath: "/fake/path/{name}-key", + Action: "./fake action {cert_path} {key_path}", + } + + cfg := ConfigFileData{ + Certificates: []CertificateData{cert}, + } + + cfg.SubstituteKeys(nil) + + if cfg.Certificates[0].CertificatePath != "/fake/path/qwer" { + t.Fail() + t.Logf(`CertificatePath = %q, want "/fake/path/qwer"`, cfg.Certificates[0].CertificatePath) + } + if cfg.Certificates[0].KeyPath != "/fake/path/qwer-key" { + t.Fail() + t.Logf(`KeyPath = %q, want "/fake/path/qwer-key"`, cfg.Certificates[0].KeyPath) + } + if cfg.Certificates[0].Action != "./fake action /fake/path/qwer /fake/path/qwer-key" { + t.Fail() + t.Logf(`Action = %q, want "./fake action /fake/path/qwer /fake/path/qwer-key"`, cfg.Certificates[0].Action) + } +} + +// TestStringSubstitutionWithPlaceholders tests the string substitution feature. +// It ensures that if no substitutes are present, the config values are not changed. +func TestStringSubstitutionWithoutPlaceholders(t *testing.T) { + cert := CertificateData{ + Name: "qwer", + CertificatePath: "/fake/path/asd", + KeyPath: "/fake/path/asdf-key", + Action: "./fake action abcd efgh", + } + + cfg := ConfigFileData{ + Certificates: []CertificateData{cert}, + } + + cfg.SubstituteKeys(nil) + + if cfg.Certificates[0].CertificatePath != "/fake/path/asd" { + t.Fail() + t.Logf(`CertificatePath = %q, want "/fake/path/asd"`, cfg.Certificates[0].CertificatePath) + } + if cfg.Certificates[0].KeyPath != "/fake/path/asdf-key" { + t.Fail() + t.Logf(`KeyPath = %q, want "/fake/path/asdf-key"`, cfg.Certificates[0].KeyPath) + } + if cfg.Certificates[0].Action != "./fake action abcd efgh" { + t.Fail() + t.Logf(`Action = %q, want "./fake action abcd efgh"`, cfg.Certificates[0].Action) + } +} diff --git a/internal/configuration/validation.go b/internal/configuration/validation.go index f743f54..5c9e665 100644 --- a/internal/configuration/validation.go +++ b/internal/configuration/validation.go @@ -1,45 +1,38 @@ package configuration import ( - "log/slog" - "os" "regexp" ) -func ValidateConfig(logger *slog.Logger, config ConfigFileData) { - validationFailed := false +// IsValid tests if the config read from file has all required parameters set. +// +// Exits the app if errors are detected +func (c *ConfigFileData) IsValid() ConfigValidationError { + err := ConfigValidationError{} - if config.BaseURL == "" { - logger.Error(`Field 'base_url' in config file is required!`) - validationFailed = true + if c.BaseURL == "" { + err.Add(`Field 'base_url' in config file is required!`) } - for _, cert := range config.Certificates { + for _, cert := range c.Certificates { if cert.Name == "" { cert.Name = "unnamed_certificate" - logger.Error(`Field 'name' for certificates cannot be blank!`) - validationFailed = true + err.Add(`Field 'name' for certificates cannot be blank!`) } - if cert.ApiKey == "" { - logger.Error(`Field 'api_key' for certificate ` + cert.Name + " cannot be blank!") - validationFailed = true + if cert.CertificateSecret == "" { + err.Add(`Field 'cert_secret' for certificate ` + cert.Name + " cannot be blank!") } - if cert.FilePath == "" { - logger.Error(`Field 'file_path' for certificate ` + cert.Name + " cannot be blank!") - validationFailed = true + if cert.CertificatePath == "" { + err.Add(`Field 'cert_path' for certificate ` + cert.Name + " cannot be blank!") } re := regexp.MustCompile(`^[a-zA-Z0-9._-]+$`) if !re.MatchString(cert.Name) { - logger.Error(`Field 'name' for certificate may only contain -_. and alphanumeric characters!`) - validationFailed = true + err.Add(`Field 'name' for certificate may only contain -_. and alphanumeric characters!`) } } - if validationFailed { - logger.Error("Config file has errors! Please fix errors above! Exiting...", "config-path", ConfigFile) - os.Exit(1) - } + return err } diff --git a/internal/constants/constants.go b/internal/constants/constants.go index 204e740..d010d0c 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -1,6 +1,7 @@ package constants -const Version = "0.1.1" +const Version = "0.2.0" const CertificateApiPath = "/certwarden/api/v1/download/certificates/" +const KeyApiPath = "/certwarden/api/v1/download/privatekeys/" const ApiKeyHeaderName = "X-API-Key" const UserAgent = "certwarden-deploy/" + Version + " +https://code.lila.network/adoralaura/certwarden-deploy"