From a6aa9f82b85a9c6418061d0319fe81ca55dc3999 Mon Sep 17 00:00:00 2001 From: Jonathan Boulle Date: Tue, 3 Jun 2014 16:49:26 -0700 Subject: [PATCH] fix(systemd): unmask runtime units when mask=False --- initialize/config.go | 9 +++- initialize/etcd_test.go | 4 +- system/systemd.go | 51 +++++++++++++++--- system/systemd_test.go | 113 ++++++++++++++++++++++++++++++++++++---- 4 files changed, 155 insertions(+), 22 deletions(-) diff --git a/initialize/config.go b/initialize/config.go index 76a326f..3f59e28 100644 --- a/initialize/config.go +++ b/initialize/config.go @@ -230,7 +230,7 @@ func Apply(cfg CloudConfig, env *Environment) error { commands := make(map[string]string, 0) reload := false for _, unit := range cfg.Coreos.Units { - dst := system.UnitDestination(&unit, env.Root()) + dst := unit.Destination(env.Root()) if unit.Content != "" { log.Printf("Writing unit %s to filesystem at path %s", unit.Name, dst) if err := system.PlaceUnit(&unit, dst); err != nil { @@ -242,7 +242,12 @@ func Apply(cfg CloudConfig, env *Environment) error { if unit.Mask { log.Printf("Masking unit file %s", unit.Name) - if err := system.MaskUnit(unit.Name, env.Root()); err != nil { + if err := system.MaskUnit(&unit, env.Root()); err != nil { + return err + } + } else if unit.Runtime { + log.Printf("Ensuring runtime unit file %s is unmasked", unit.Name) + if err := system.UnmaskUnit(&unit, env.Root()); err != nil { return err } } diff --git a/initialize/etcd_test.go b/initialize/etcd_test.go index 2282c9f..62fca6b 100644 --- a/initialize/etcd_test.go +++ b/initialize/etcd_test.go @@ -79,7 +79,7 @@ func TestEtcdEnvironmentWrittenToDisk(t *testing.T) { } u := uu[0] - dst := system.UnitDestination(&u, dir) + dst := u.Destination(dir) os.Stderr.WriteString("writing to " + dir + "\n") if err := system.PlaceUnit(&u, dst); err != nil { t.Fatalf("Writing of EtcdEnvironment failed: %v", err) @@ -134,7 +134,7 @@ func TestEtcdEnvironmentWrittenToDiskDefaultToMachineID(t *testing.T) { } u := uu[0] - dst := system.UnitDestination(&u, dir) + dst := u.Destination(dir) os.Stderr.WriteString("writing to " + dir + "\n") if err := system.PlaceUnit(&u, dst); err != nil { t.Fatalf("Writing of EtcdEnvironment failed: %v", err) diff --git a/system/systemd.go b/system/systemd.go index 0e81420..1b478fd 100644 --- a/system/systemd.go +++ b/system/systemd.go @@ -51,10 +51,10 @@ func (u *Unit) Group() (group string) { type Script []byte -// UnitDestination builds the appropriate absolute file path for -// the given Unit. The root argument indicates the effective base +// Destination builds the appropriate absolute file path for +// the Unit. The root argument indicates the effective base // directory of the system (similar to a chroot). -func UnitDestination(u *Unit, root string) string { +func (u *Unit) Destination(root string) string { dir := "etc" if u.Runtime { dir = "run" @@ -179,12 +179,12 @@ func MachineID(root string) string { return id } -// MaskUnit masks a Unit by the given name by symlinking its unit file (in -// /etc/systemd/system) to /dev/null, analogous to `systemctl mask` +// MaskUnit masks the given Unit by symlinking its unit file to +// /dev/null, analogous to `systemctl mask`. // N.B.: Unlike `systemctl mask`, this function will *remove any existing unit -// file* in /etc/systemd/system, to ensure that the mask will succeed. -func MaskUnit(unit string, root string) error { - masked := path.Join(root, "etc", "systemd", "system", unit) +// file at the location*, to ensure that the mask will succeed. +func MaskUnit(unit *Unit, root string) error { + masked := unit.Destination(root) if _, err := os.Stat(masked); os.IsNotExist(err) { if err := os.MkdirAll(path.Dir(masked), os.FileMode(0755)); err != nil { return err @@ -194,3 +194,38 @@ func MaskUnit(unit string, root string) error { } return os.Symlink("/dev/null", masked) } + +// UnmaskUnit is analogous to systemd's unit_file_unmask. If the file +// associated with the given Unit is empty or appears to be a symlink to +// /dev/null, it is removed. +func UnmaskUnit(unit *Unit, root string) error { + masked := unit.Destination(root) + ne, err := nullOrEmpty(masked) + if os.IsNotExist(err) { + return nil + } else if err != nil { + return err + } + if !ne { + log.Printf("%s is not null or empty, refusing to unmask", masked) + return nil + } + return os.Remove(masked) +} + +// nullOrEmpty checks whether a given path appears to be an empty regular file +// or a symlink to /dev/null +func nullOrEmpty(path string) (bool, error) { + fi, err := os.Stat(path) + if err != nil { + return false, err + } + m := fi.Mode() + if m.IsRegular() && fi.Size() <= 0 { + return true, nil + } + if m&os.ModeCharDevice > 0 { + return true, nil + } + return false, nil +} diff --git a/system/systemd_test.go b/system/systemd_test.go index 7bf3696..e715396 100644 --- a/system/systemd_test.go +++ b/system/systemd_test.go @@ -25,10 +25,10 @@ Address=10.209.171.177/19 } defer os.RemoveAll(dir) - dst := UnitDestination(&u, dir) + dst := u.Destination(dir) expectDst := path.Join(dir, "run", "systemd", "network", "50-eth0.network") if dst != expectDst { - t.Fatalf("UnitDestination returned %s, expected %s", dst, expectDst) + t.Fatalf("unit.Destination returned %s, expected %s", dst, expectDst) } if err := PlaceUnit(&u, dst); err != nil { @@ -69,18 +69,18 @@ func TestUnitDestination(t *testing.T) { DropIn: false, } - dst := UnitDestination(&u, dir) + dst := u.Destination(dir) expectDst := path.Join(dir, "etc", "systemd", "system", "foobar.service") if dst != expectDst { - t.Errorf("UnitDestination returned %s, expected %s", dst, expectDst) + t.Errorf("unit.Destination returned %s, expected %s", dst, expectDst) } u.DropIn = true - dst = UnitDestination(&u, dir) + dst = u.Destination(dir) expectDst = path.Join(dir, "etc", "systemd", "system", "foobar.service.d", cloudConfigDropIn) if dst != expectDst { - t.Errorf("UnitDestination returned %s, expected %s", dst, expectDst) + t.Errorf("unit.Destination returned %s, expected %s", dst, expectDst) } } @@ -100,10 +100,10 @@ Where=/media/state } defer os.RemoveAll(dir) - dst := UnitDestination(&u, dir) + dst := u.Destination(dir) expectDst := path.Join(dir, "etc", "systemd", "system", "media-state.mount") if dst != expectDst { - t.Fatalf("UnitDestination returned %s, expected %s", dst, expectDst) + t.Fatalf("unit.Destination returned %s, expected %s", dst, expectDst) } if err := PlaceUnit(&u, dst); err != nil { @@ -156,7 +156,8 @@ func TestMaskUnit(t *testing.T) { defer os.RemoveAll(dir) // Ensure mask works with units that do not currently exist - if err := MaskUnit("foo.service", dir); err != nil { + uf := &Unit{Name: "foo.service"} + if err := MaskUnit(uf, dir); err != nil { t.Fatalf("Unable to mask new unit: %v", err) } fooPath := path.Join(dir, "etc", "systemd", "system", "foo.service") @@ -169,11 +170,12 @@ func TestMaskUnit(t *testing.T) { } // Ensure mask works with unit files that already exist + ub := &Unit{Name: "bar.service"} barPath := path.Join(dir, "etc", "systemd", "system", "bar.service") if _, err := os.Create(barPath); err != nil { t.Fatalf("Error creating new unit file: %v", err) } - if err := MaskUnit("bar.service", dir); err != nil { + if err := MaskUnit(ub, dir); err != nil { t.Fatalf("Unable to mask existing unit: %v", err) } barTgt, err := os.Readlink(barPath) @@ -184,3 +186,94 @@ func TestMaskUnit(t *testing.T) { t.Fatalf("unit not masked, got unit target", barTgt) } } + +func TestUnmaskUnit(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "coreos-cloudinit-") + if err != nil { + t.Fatalf("Unable to create tempdir: %v", err) + } + defer os.RemoveAll(dir) + + nilUnit := &Unit{Name: "null.service"} + if err := UnmaskUnit(nilUnit, dir); err != nil { + t.Errorf("unexpected error from unmasking nonexistent unit: %v", err) + } + + uf := &Unit{Name: "foo.service", Content: "[Service]\nExecStart=/bin/true"} + dst := uf.Destination(dir) + if err := os.MkdirAll(path.Dir(dst), os.FileMode(0755)); err != nil { + t.Fatalf("Unable to create unit directory: %v", err) + } + if _, err := os.Create(dst); err != nil { + t.Fatalf("Unable to write unit file: %v", err) + } + + if err := ioutil.WriteFile(dst, []byte(uf.Content), 700); err != nil { + t.Fatalf("Unable to write unit file: %v", err) + } + if err := UnmaskUnit(uf, dir); err != nil { + t.Errorf("unmask of non-empty unit returned unexpected error: %v", err) + } + got, _ := ioutil.ReadFile(dst) + if string(got) != uf.Content { + t.Errorf("unmask of non-empty unit mutated unit contents unexpectedly") + } + + ub := &Unit{Name: "bar.service"} + dst = ub.Destination(dir) + if err := os.Symlink("/dev/null", dst); err != nil { + t.Fatalf("Unable to create masked unit: %v", err) + } + if err := UnmaskUnit(ub, dir); err != nil { + t.Errorf("unmask of unit returned unexpected error: %v", err) + } + if _, err := os.Stat(dst); !os.IsNotExist(err) { + t.Errorf("expected %s to not exist after unmask, but got err: %s", err) + } +} + +func TestNullOrEmpty(t *testing.T) { + dir, err := ioutil.TempDir(os.TempDir(), "coreos-cloudinit-") + if err != nil { + t.Fatalf("Unable to create tempdir: %v", err) + } + defer os.RemoveAll(dir) + + non := path.Join(dir, "does_not_exist") + ne, err := nullOrEmpty(non) + if !os.IsNotExist(err) { + t.Errorf("nullOrEmpty on nonexistent file returned bad error: %v", err) + } + if ne { + t.Errorf("nullOrEmpty returned true unxpectedly") + } + + regEmpty := path.Join(dir, "regular_empty_file") + _, err = os.Create(regEmpty) + if err != nil { + t.Fatalf("Unable to create tempfile: %v", err) + } + gotNe, gotErr := nullOrEmpty(regEmpty) + if !gotNe || gotErr != nil { + t.Errorf("nullOrEmpty of regular empty file returned %t, %v - want true, nil", gotNe, gotErr) + } + + reg := path.Join(dir, "regular_file") + if err := ioutil.WriteFile(reg, []byte("asdf"), 700); err != nil { + t.Fatalf("Unable to create tempfile: %v", err) + } + gotNe, gotErr = nullOrEmpty(reg) + if gotNe || gotErr != nil { + t.Errorf("nullOrEmpty of regular file returned %t, %v - want false, nil", gotNe, gotErr) + } + + null := path.Join(dir, "null") + if err := os.Symlink(os.DevNull, null); err != nil { + t.Fatalf("Unable to create /dev/null link: %s", err) + } + gotNe, gotErr = nullOrEmpty(null) + if !gotNe || gotErr != nil { + t.Errorf("nullOrEmpty of null symlink returned %t, %v - want true, nil", gotNe, gotErr) + } + +}