diff --git a/initialize/config.go b/initialize/config.go index 890fd5d..76a326f 100644 --- a/initialize/config.go +++ b/initialize/config.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "log" - "path" "github.com/coreos/coreos-cloudinit/third_party/launchpad.net/goyaml" @@ -221,11 +220,11 @@ func Apply(cfg CloudConfig, env *Environment) error { } for _, file := range cfg.WriteFiles { - file.Path = path.Join(env.Root(), file.Path) - if err := system.WriteFile(&file); err != nil { + path, err := system.WriteFile(&file, env.Root()) + if err != nil { return err } - log.Printf("Wrote file %s to filesystem", file.Path) + log.Printf("Wrote file %s to filesystem", path) } commands := make(map[string]string, 0) diff --git a/initialize/manage_etc_hosts_test.go b/initialize/manage_etc_hosts_test.go index 1fbac07..3e23b10 100644 --- a/initialize/manage_etc_hosts_test.go +++ b/initialize/manage_etc_hosts_test.go @@ -50,9 +50,7 @@ func TestEtcHostsWrittenToDisk(t *testing.T) { t.Fatalf("manageEtcHosts returned nil file unexpectedly") } - f.Path = path.Join(dir, f.Path) - - if err := system.WriteFile(f); err != nil { + if _, err := system.WriteFile(f, dir); err != nil { t.Fatalf("Error writing EtcHosts: %v", err) } diff --git a/initialize/oem_test.go b/initialize/oem_test.go index 9f46215..a2eae46 100644 --- a/initialize/oem_test.go +++ b/initialize/oem_test.go @@ -31,8 +31,7 @@ func TestOEMReleaseWrittenToDisk(t *testing.T) { t.Fatalf("OEMRelease returned nil file unexpectedly") } - f.Path = path.Join(dir, f.Path) - if err := system.WriteFile(f); err != nil { + if _, err := system.WriteFile(f, dir); err != nil { t.Fatalf("Writing of OEMRelease failed: %v", err) } diff --git a/initialize/update_test.go b/initialize/update_test.go index 1d2106d..ae72eaf 100644 --- a/initialize/update_test.go +++ b/initialize/update_test.go @@ -205,8 +205,7 @@ func TestUpdateConfWrittenToDisk(t *testing.T) { t.Fatal("Unexpectedly got nil updateconfig file") } - f.Path = path.Join(dir, f.Path) - if err := system.WriteFile(f); err != nil { + if _, err := system.WriteFile(f, dir); err != nil { t.Fatalf("Error writing update config: %v", err) } diff --git a/initialize/workspace.go b/initialize/workspace.go index 3f21d5a..ed1cf93 100644 --- a/initialize/workspace.go +++ b/initialize/workspace.go @@ -3,6 +3,7 @@ package initialize import ( "io/ioutil" "path" + "strings" "github.com/coreos/coreos-cloudinit/system" ) @@ -28,21 +29,23 @@ func PersistScriptInWorkspace(script system.Script, workspace string) (string, e } tmp.Close() + relpath := strings.TrimPrefix(tmp.Name(), workspace) + file := system.File{ - Path: tmp.Name(), + Path: relpath, RawFilePermissions: "0744", - Content: string(script), + Content: string(script), } - err = system.WriteFile(&file) - return file.Path, err + return system.WriteFile(&file, workspace) } func PersistUnitNameInWorkspace(name string, workspace string) error { file := system.File{ - Path: path.Join(workspace, "scripts", "unit-name"), + Path: path.Join("scripts", "unit-name"), RawFilePermissions: "0644", - Content: name, + Content: name, } - return system.WriteFile(&file) + _, err := system.WriteFile(&file, workspace) + return err } diff --git a/system/file.go b/system/file.go index d8d224f..e9f40b2 100644 --- a/system/file.go +++ b/system/file.go @@ -31,33 +31,55 @@ func (f *File) Permissions() (os.FileMode, error) { return os.FileMode(perm), nil } -func WriteFile(f *File) error { +func WriteFile(f *File, root string) (string, error) { if f.Encoding != "" { - return fmt.Errorf("Unable to write file with encoding %s", f.Encoding) + return "", fmt.Errorf("Unable to write file with encoding %s", f.Encoding) } - if err := os.MkdirAll(path.Dir(f.Path), os.FileMode(0755)); err != nil { - return err + fullpath := path.Join(root, f.Path) + dir := path.Dir(fullpath) + + if err := EnsureDirectoryExists(dir); err != nil { + return "", err } perm, err := f.Permissions() if err != nil { - return err + return "", err } - if err := ioutil.WriteFile(f.Path, []byte(f.Content), perm); err != nil { - return err + var tmp *os.File + // Create a temporary file in the same directory to ensure it's on the same filesystem + if tmp, err = ioutil.TempFile(dir, "cloudinit-temp"); err != nil { + return "", err + } + + if err := ioutil.WriteFile(tmp.Name(), []byte(f.Content), perm); err != nil { + return "", err + } + + if err := tmp.Close(); err != nil { + return "", err + } + + // Ensure the permissions are as requested (since WriteFile can be affected by sticky bit) + if err := os.Chmod(tmp.Name(), perm); err != nil { + return "", err } if f.Owner != "" { // We shell out since we don't have a way to look up unix groups natively - cmd := exec.Command("chown", f.Owner, f.Path) + cmd := exec.Command("chown", f.Owner, tmp.Name()) if err := cmd.Run(); err != nil { - return err + return "", err } } - return nil + if err := os.Rename(tmp.Name(), fullpath); err != nil { + return "", err + } + + return fullpath, nil } func EnsureDirectoryExists(dir string) error { diff --git a/system/file_test.go b/system/file_test.go index c8cd3d0..949ac2a 100644 --- a/system/file_test.go +++ b/system/file_test.go @@ -4,7 +4,6 @@ import ( "io/ioutil" "os" "path" - "syscall" "testing" ) @@ -13,18 +12,22 @@ func TestWriteFileUnencodedContent(t *testing.T) { if err != nil { t.Fatalf("Unable to create tempdir: %v", err) } - defer syscall.Rmdir(dir) + defer os.RemoveAll(dir) - fullPath := path.Join(dir, "tmp", "foo") + fn := "foo" + fullPath := path.Join(dir, fn) wf := File{ - Path: fullPath, - Content: "bar", + Path: fn, + Content: "bar", RawFilePermissions: "0644", } - if err := WriteFile(&wf); err != nil { + path, err := WriteFile(&wf, dir) + if err != nil { t.Fatalf("Processing of WriteFile failed: %v", err) + } else if path != fullPath { + t.Fatalf("WriteFile returned bad path: want %s, got %s", fullPath, path) } fi, err := os.Stat(fullPath) @@ -51,15 +54,15 @@ func TestWriteFileInvalidPermission(t *testing.T) { if err != nil { t.Fatalf("Unable to create tempdir: %v", err) } - defer syscall.Rmdir(dir) + defer os.RemoveAll(dir) wf := File{ - Path: path.Join(dir, "tmp", "foo"), - Content: "bar", + Path: path.Join(dir, "tmp", "foo"), + Content: "bar", RawFilePermissions: "pants", } - if err := WriteFile(&wf); err == nil { + if _, err := WriteFile(&wf, dir); err == nil { t.Fatalf("Expected error to be raised when writing file with invalid permission") } } @@ -69,17 +72,21 @@ func TestWriteFilePermissions(t *testing.T) { if err != nil { t.Fatalf("Unable to create tempdir: %v", err) } - defer syscall.Rmdir(dir) + defer os.RemoveAll(dir) - fullPath := path.Join(dir, "tmp", "foo") + fn := "foo" + fullPath := path.Join(dir, fn) wf := File{ - Path: fullPath, + Path: fn, RawFilePermissions: "0755", } - if err := WriteFile(&wf); err != nil { + path, err := WriteFile(&wf, dir) + if err != nil { t.Fatalf("Processing of WriteFile failed: %v", err) + } else if path != fullPath { + t.Fatalf("WriteFile returned bad path: want %s, got %s", fullPath, path) } fi, err := os.Stat(fullPath) @@ -97,15 +104,15 @@ func TestWriteFileEncodedContent(t *testing.T) { if err != nil { t.Fatalf("Unable to create tempdir: %v", err) } - defer syscall.Rmdir(dir) + defer os.RemoveAll(dir) wf := File{ - Path: path.Join(dir, "tmp", "foo"), - Content: "", + Path: path.Join(dir, "tmp", "foo"), + Content: "", Encoding: "base64", } - if err := WriteFile(&wf); err == nil { + if _, err := WriteFile(&wf, dir); err == nil { t.Fatalf("Expected error to be raised when writing file with encoding") } } diff --git a/system/systemd.go b/system/systemd.go index a867219..0e81420 100644 --- a/system/systemd.go +++ b/system/systemd.go @@ -78,12 +78,12 @@ func PlaceUnit(u *Unit, dst string) error { } file := File{ - Path: dst, + Path: filepath.Base(dst), Content: u.Content, RawFilePermissions: "0644", } - err := WriteFile(&file) + _, err := WriteFile(&file, dir) if err != nil { return err }