diff --git a/system/ssh_key.go b/system/ssh_key.go index a75896c..5a76076 100644 --- a/system/ssh_key.go +++ b/system/ssh_key.go @@ -15,11 +15,65 @@ package system import ( + "bufio" "fmt" "os" "strings" ) +func diffLines(src, dst []string) []string { + var tgt []string + + mb := map[string]bool{} + + for _, x := range src { + mb[x] = true + } + + for _, x := range dst { + if _, ok := mb[x]; !ok { + mb[x] = true + } + } + + for k, _ := range mb { + tgt = append(tgt, k) + } + + return tgt +} + +func readLines(path string) ([]string, error) { + var lines []string + + file, err := os.Open(path) + if err != nil { + return lines, err + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + return lines, scanner.Err() +} + +// writeLines writes the lines to the given file. +func writeLines(lines []string, path string) error { + file, err := os.Create(path) + if err != nil { + return err + } + defer file.Close() + + w := bufio.NewWriter(file) + for _, line := range lines { + fmt.Fprintln(w, line) + } + return w.Flush() +} + // Add the provide SSH public key to the core user's list of // authorized keys func AuthorizeSSHKeys(user string, keysName string, keys []string) error { @@ -29,7 +83,7 @@ func AuthorizeSSHKeys(user string, keysName string, keys []string) error { // join all keys with newlines, ensuring the resulting string // also ends with a newline - joined := fmt.Sprintf("%s\n", strings.Join(keys, "\n")) + // joined := fmt.Sprintf("%s\n", strings.Join(keys, "\n")) home, err := UserHome(user) if err != nil { @@ -43,12 +97,12 @@ func AuthorizeSSHKeys(user string, keysName string, keys []string) error { } authorized_file := fmt.Sprintf("%s/.ssh/authorized_keys", home) - f, err := os.OpenFile(authorized_file, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return err + var newkeys []string + for _, x := range keys { + newkeys = append(newkeys, strings.Split(x, "\n")...) } - defer f.Close() - _, err = f.WriteString(joined) + oldkeys, _ := readLines(authorized_file) - return err + diffkeys := diffLines(oldkeys, newkeys) + return writeLines(diffkeys, authorized_file) }