diff --git a/cmd/dir.go b/cmd/dir.go new file mode 100644 index 0000000..e1d24e9 --- /dev/null +++ b/cmd/dir.go @@ -0,0 +1,27 @@ +package cmd + +import ( + "os" + "path/filepath" + + log "github.com/sirupsen/logrus" +) + +var ( + UserHomeDir string + CacheHomeDir string +) + +func init() { + home, err := os.UserHomeDir() + if err != nil { + log.Fatal(err) + } + UserHomeDir = home + + if v := os.Getenv("XDG_CACHE_HOME"); v != "" { + CacheHomeDir = v + } else { + CacheHomeDir = filepath.Join(UserHomeDir, ".cache") + } +} diff --git a/cmd/input.go b/cmd/input.go index 9327de2..ed9655c 100644 --- a/cmd/input.go +++ b/cmd/input.go @@ -42,6 +42,10 @@ type Input struct { artifactServerPath string artifactServerAddr string artifactServerPort string + noCacheServer bool + cacheServerPath string + cacheServerAddr string + cacheServerPort uint16 jsonLogger bool noSkipCheckout bool remoteName string diff --git a/cmd/notices.go b/cmd/notices.go index 9ddcf6f..a912bd9 100644 --- a/cmd/notices.go +++ b/cmd/notices.go @@ -132,16 +132,7 @@ func saveNoticesEtag(etag string) { } func etagPath() string { - var xdgCache string - var ok bool - if xdgCache, ok = os.LookupEnv("XDG_CACHE_HOME"); !ok || xdgCache == "" { - if home, err := os.UserHomeDir(); err == nil { - xdgCache = filepath.Join(home, ".cache") - } else if xdgCache, err = filepath.Abs("."); err != nil { - log.Fatal(err) - } - } - dir := filepath.Join(xdgCache, "act") + dir := filepath.Join(CacheHomeDir, "act") if err := os.MkdirAll(dir, 0o777); err != nil { log.Fatal(err) } diff --git a/cmd/root.go b/cmd/root.go index 548d90c..d5b8c39 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -20,6 +20,7 @@ import ( "github.com/spf13/cobra" "gopkg.in/yaml.v3" + "github.com/nektos/act/pkg/artifactcache" "github.com/nektos/act/pkg/artifacts" "github.com/nektos/act/pkg/common" "github.com/nektos/act/pkg/container" @@ -87,6 +88,10 @@ func Execute(ctx context.Context, version string) { 
rootCmd.PersistentFlags().StringVarP(&input.artifactServerAddr, "artifact-server-addr", "", common.GetOutboundIP().String(), "Defines the address to which the artifact server binds.") rootCmd.PersistentFlags().StringVarP(&input.artifactServerPort, "artifact-server-port", "", "34567", "Defines the port where the artifact server listens.") rootCmd.PersistentFlags().BoolVarP(&input.noSkipCheckout, "no-skip-checkout", "", false, "Do not skip actions/checkout") + rootCmd.PersistentFlags().BoolVarP(&input.noCacheServer, "no-cache-server", "", false, "Disable cache server") + rootCmd.PersistentFlags().StringVarP(&input.cacheServerPath, "cache-server-path", "", filepath.Join(CacheHomeDir, "actcache"), "Defines the path where the cache server stores caches.") + rootCmd.PersistentFlags().StringVarP(&input.cacheServerAddr, "cache-server-addr", "", common.GetOutboundIP().String(), "Defines the address to which the cache server binds.") + rootCmd.PersistentFlags().Uint16VarP(&input.cacheServerPort, "cache-server-port", "", 0, "Defines the port where the cache server listens. 
0 means a randomly available port.") rootCmd.SetArgs(args()) if err := rootCmd.Execute(); err != nil { @@ -95,11 +100,6 @@ func Execute(ctx context.Context, version string) { } func configLocations() []string { - home, err := os.UserHomeDir() - if err != nil { - log.Fatal(err) - } - configFileName := ".actrc" // reference: https://specifications.freedesktop.org/basedir-spec/latest/ar01s03.html @@ -112,7 +112,7 @@ func configLocations() []string { } return []string{ - filepath.Join(home, configFileName), + filepath.Join(UserHomeDir, configFileName), actrcXdg, filepath.Join(".", configFileName), } @@ -609,6 +609,17 @@ func newRunCommand(ctx context.Context, input *Input) func(*cobra.Command, []str cancel := artifacts.Serve(ctx, input.artifactServerPath, input.artifactServerAddr, input.artifactServerPort) + const cacheURLKey = "ACTIONS_CACHE_URL" + var cacheHandler *artifactcache.Handler + if !input.noCacheServer && envs[cacheURLKey] == "" { + var err error + cacheHandler, err = artifactcache.StartHandler(input.cacheServerPath, input.cacheServerAddr, input.cacheServerPort, common.Logger(ctx)) + if err != nil { + return err + } + envs[cacheURLKey] = cacheHandler.ExternalURL() + "/" + } + ctx = common.WithDryrun(ctx, input.dryrun) if watch, err := cmd.Flags().GetBool("watch"); err != nil { return err @@ -622,6 +633,7 @@ func newRunCommand(ctx context.Context, input *Input) func(*cobra.Command, []str executor := r.NewPlanExecutor(plan).Finally(func(ctx context.Context) error { cancel() + _ = cacheHandler.Close() return nil }) err = executor(ctx) diff --git a/go.mod b/go.mod index 6073aab..4a2a62f 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.2 + github.com/timshannon/bolthold v0.0.0-20210913165410-232392fc8a6a + go.etcd.io/bbolt v1.3.7 golang.org/x/term v0.7.0 gopkg.in/yaml.v3 v3.0.1 gotest.tools/v3 v3.4.0 diff --git a/go.sum b/go.sum index 66cdab5..7db8dc1 100644 
--- a/go.sum +++ b/go.sum @@ -202,6 +202,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= +github.com/timshannon/bolthold v0.0.0-20210913165410-232392fc8a6a h1:oIi7H/bwFUYKYhzKbHc+3MvHRWqhQwXVB4LweLMiVy0= +github.com/timshannon/bolthold v0.0.0-20210913165410-232392fc8a6a/go.mod h1:iSvujNDmpZ6eQX+bg/0X3lF7LEmZ8N77g2a/J/+Zt2U= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE= github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df/go.mod h1:JP3t17pCcGlemwknint6hfoeCVQrEMVwxRLRjXpq+BU= @@ -216,6 +218,9 @@ github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= +go.etcd.io/bbolt v1.3.7 h1:j+zJOnnEjF/kyHlDDgGnVL/AIqIJPq8UoB2GSNfkUfQ= +go.etcd.io/bbolt v1.3.7/go.mod h1:N9Mkw9X8x5fupy0IKsmuqVtoGDyxsaDlbk4Rd05IAQw= golang.org/x/arch v0.1.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -259,6 +264,7 @@ golang.org/x/sys v0.0.0-20190606203320-7fc4e5ec1444/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys 
v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/artifactcache/doc.go b/pkg/artifactcache/doc.go new file mode 100644 index 0000000..13d2644 --- /dev/null +++ b/pkg/artifactcache/doc.go @@ -0,0 +1,8 @@ +// Package artifactcache provides a cache handler for the runner. +// +// Inspired by https://github.com/sp-ricard-valverde/github-act-cache-server +// +// TODO: Authorization +// TODO: Restrictions for accessing a cache, see https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#restrictions-for-accessing-a-cache +// TODO: Force deleting cache entries, see https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries +package artifactcache diff --git a/pkg/artifactcache/handler.go b/pkg/artifactcache/handler.go new file mode 100644 index 0000000..f11def6 --- /dev/null +++ b/pkg/artifactcache/handler.go @@ -0,0 +1,488 @@ +package artifactcache + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/julienschmidt/httprouter" + "github.com/sirupsen/logrus" + "github.com/timshannon/bolthold" + "go.etcd.io/bbolt" + + "github.com/nektos/act/pkg/common" +) + +const ( + urlBase 
= "/_apis/artifactcache" +) + +type Handler struct { + db *bolthold.Store + storage *Storage + router *httprouter.Router + listener net.Listener + server *http.Server + logger logrus.FieldLogger + + gcing int32 // TODO: use atomic.Bool when we can use Go 1.19 + gcAt time.Time + + outboundIP string +} + +func StartHandler(dir, outboundIP string, port uint16, logger logrus.FieldLogger) (*Handler, error) { + h := &Handler{} + + if logger == nil { + discard := logrus.New() + discard.Out = io.Discard + logger = discard + } + logger = logger.WithField("module", "artifactcache") + h.logger = logger + + if dir == "" { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + dir = filepath.Join(home, ".cache", "actcache") + } + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, err + } + + db, err := bolthold.Open(filepath.Join(dir, "bolt.db"), 0o644, &bolthold.Options{ + Encoder: json.Marshal, + Decoder: json.Unmarshal, + Options: &bbolt.Options{ + Timeout: 5 * time.Second, + NoGrowSync: bbolt.DefaultOptions.NoGrowSync, + FreelistType: bbolt.DefaultOptions.FreelistType, + }, + }) + if err != nil { + return nil, err + } + h.db = db + + storage, err := NewStorage(filepath.Join(dir, "cache")) + if err != nil { + return nil, err + } + h.storage = storage + + if outboundIP != "" { + h.outboundIP = outboundIP + } else if ip := common.GetOutboundIP(); ip == nil { + return nil, fmt.Errorf("unable to determine outbound IP address") + } else { + h.outboundIP = ip.String() + } + + router := httprouter.New() + router.GET(urlBase+"/cache", h.middleware(h.find)) + router.POST(urlBase+"/caches", h.middleware(h.reserve)) + router.PATCH(urlBase+"/caches/:id", h.middleware(h.upload)) + router.POST(urlBase+"/caches/:id", h.middleware(h.commit)) + router.GET(urlBase+"/artifacts/:id", h.middleware(h.get)) + router.POST(urlBase+"/clean", h.middleware(h.clean)) + + h.router = router + + h.gcCache() + + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) // 
listen on all interfaces + if err != nil { + return nil, err + } + server := &http.Server{ + ReadHeaderTimeout: 2 * time.Second, + Handler: router, + } + go func() { + if err := server.Serve(listener); err != nil && !errors.Is(err, net.ErrClosed) { + logger.Errorf("http serve: %v", err) + } + }() + h.listener = listener + h.server = server + + return h, nil +} + +func (h *Handler) ExternalURL() string { + // TODO: make the external url configurable if necessary + return fmt.Sprintf("http://%s:%d", + h.outboundIP, + h.listener.Addr().(*net.TCPAddr).Port) +} + +func (h *Handler) Close() error { + if h == nil { + return nil + } + var retErr error + if h.server != nil { + err := h.server.Close() + if err != nil { + retErr = err + } + h.server = nil + } + if h.listener != nil { + err := h.listener.Close() + if errors.Is(err, net.ErrClosed) { + err = nil + } + if err != nil { + retErr = err + } + h.listener = nil + } + if h.db != nil { + err := h.db.Close() + if err != nil { + retErr = err + } + h.db = nil + } + return retErr +} + +// GET /_apis/artifactcache/cache +func (h *Handler) find(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + keys := strings.Split(r.URL.Query().Get("keys"), ",") + // cache keys are case insensitive + for i, key := range keys { + keys[i] = strings.ToLower(key) + } + version := r.URL.Query().Get("version") + + cache, err := h.findCache(keys, version) + if err != nil { + h.responseJSON(w, r, 500, err) + return + } + if cache == nil { + h.responseJSON(w, r, 204) + return + } + + if ok, err := h.storage.Exist(cache.ID); err != nil { + h.responseJSON(w, r, 500, err) + return + } else if !ok { + _ = h.db.Delete(cache.ID, cache) + h.responseJSON(w, r, 204) + return + } + h.responseJSON(w, r, 200, map[string]any{ + "result": "hit", + "archiveLocation": fmt.Sprintf("%s%s/artifacts/%d", h.ExternalURL(), urlBase, cache.ID), + "cacheKey": cache.Key, + }) +} + +// POST /_apis/artifactcache/caches +func (h *Handler) reserve(w 
http.ResponseWriter, r *http.Request, _ httprouter.Params) { + api := &Request{} + if err := json.NewDecoder(r.Body).Decode(api); err != nil { + h.responseJSON(w, r, 400, err) + return + } + // cache keys are case insensitive + api.Key = strings.ToLower(api.Key) + + cache := api.ToCache() + cache.FillKeyVersionHash() + if err := h.db.FindOne(cache, bolthold.Where("KeyVersionHash").Eq(cache.KeyVersionHash)); err != nil { + if !errors.Is(err, bolthold.ErrNotFound) { + h.responseJSON(w, r, 500, err) + return + } + } else { + h.responseJSON(w, r, 400, fmt.Errorf("already exist")) + return + } + + now := time.Now().Unix() + cache.CreatedAt = now + cache.UsedAt = now + if err := h.db.Insert(bolthold.NextSequence(), cache); err != nil { + h.responseJSON(w, r, 500, err) + return + } + // write back id to db + if err := h.db.Update(cache.ID, cache); err != nil { + h.responseJSON(w, r, 500, err) + return + } + h.responseJSON(w, r, 200, map[string]any{ + "cacheId": cache.ID, + }) +} + +// PATCH /_apis/artifactcache/caches/:id +func (h *Handler) upload(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + id, err := strconv.ParseInt(params.ByName("id"), 10, 64) + if err != nil { + h.responseJSON(w, r, 400, err) + return + } + + cache := &Cache{} + if err := h.db.Get(id, cache); err != nil { + if errors.Is(err, bolthold.ErrNotFound) { + h.responseJSON(w, r, 400, fmt.Errorf("cache %d: not reserved", id)) + return + } + h.responseJSON(w, r, 500, err) + return + } + + if cache.Complete { + h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key)) + return + } + start, _, err := parseContentRange(r.Header.Get("Content-Range")) + if err != nil { + h.responseJSON(w, r, 400, err) + return + } + if err := h.storage.Write(cache.ID, start, r.Body); err != nil { + h.responseJSON(w, r, 500, err) + return + } + h.useCache(id) + h.responseJSON(w, r, 200) +} + +// POST /_apis/artifactcache/caches/:id +func (h *Handler) commit(w http.ResponseWriter, r 
*http.Request, params httprouter.Params) { + id, err := strconv.ParseInt(params.ByName("id"), 10, 64) + if err != nil { + h.responseJSON(w, r, 400, err) + return + } + + cache := &Cache{} + if err := h.db.Get(id, cache); err != nil { + if errors.Is(err, bolthold.ErrNotFound) { + h.responseJSON(w, r, 400, fmt.Errorf("cache %d: not reserved", id)) + return + } + h.responseJSON(w, r, 500, err) + return + } + + if cache.Complete { + h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key)) + return + } + + if err := h.storage.Commit(cache.ID, cache.Size); err != nil { + h.responseJSON(w, r, 500, err) + return + } + + cache.Complete = true + if err := h.db.Update(cache.ID, cache); err != nil { + h.responseJSON(w, r, 500, err) + return + } + + h.responseJSON(w, r, 200) +} + +// GET /_apis/artifactcache/artifacts/:id +func (h *Handler) get(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + id, err := strconv.ParseInt(params.ByName("id"), 10, 64) + if err != nil { + h.responseJSON(w, r, 400, err) + return + } + h.useCache(id) + h.storage.Serve(w, r, uint64(id)) +} + +// POST /_apis/artifactcache/clean +func (h *Handler) clean(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { + // TODO: don't support force deleting cache entries + // see: https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries + + h.responseJSON(w, r, 200) +} + +func (h *Handler) middleware(handler httprouter.Handle) httprouter.Handle { + return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) { + h.logger.Debugf("%s %s", r.Method, r.RequestURI) + handler(w, r, params) + go h.gcCache() + } +} + +// if not found, return (nil, nil) instead of an error. +func (h *Handler) findCache(keys []string, version string) (*Cache, error) { + if len(keys) == 0 { + return nil, nil + } + key := keys[0] // the first key is for exact match. 
+ + cache := &Cache{ + Key: key, + Version: version, + } + cache.FillKeyVersionHash() + + if err := h.db.FindOne(cache, bolthold.Where("KeyVersionHash").Eq(cache.KeyVersionHash)); err != nil { + if !errors.Is(err, bolthold.ErrNotFound) { + return nil, err + } + } else if cache.Complete { + return cache, nil + } + stop := fmt.Errorf("stop") + + for _, prefix := range keys[1:] { + found := false + if err := h.db.ForEach(bolthold.Where("Key").Ge(prefix).And("Version").Eq(version).SortBy("Key"), func(v *Cache) error { + if !strings.HasPrefix(v.Key, prefix) { + return stop + } + if v.Complete { + cache = v + found = true + return stop + } + return nil + }); err != nil { + if !errors.Is(err, stop) { + return nil, err + } + } + if found { + return cache, nil + } + } + return nil, nil +} + +func (h *Handler) useCache(id int64) { + cache := &Cache{} + if err := h.db.Get(id, cache); err != nil { + return + } + cache.UsedAt = time.Now().Unix() + _ = h.db.Update(cache.ID, cache) +} + +func (h *Handler) gcCache() { + if atomic.LoadInt32(&h.gcing) != 0 { + return + } + if !atomic.CompareAndSwapInt32(&h.gcing, 0, 1) { + return + } + defer atomic.StoreInt32(&h.gcing, 0) + + if time.Since(h.gcAt) < time.Hour { + h.logger.Debugf("skip gc: %v", h.gcAt.String()) + return + } + h.gcAt = time.Now() + h.logger.Debugf("gc: %v", h.gcAt.String()) + + const ( + keepUsed = 30 * 24 * time.Hour + keepUnused = 7 * 24 * time.Hour + keepTemp = 5 * time.Minute + ) + + var caches []*Cache + if err := h.db.Find(&caches, bolthold.Where("UsedAt").Lt(time.Now().Add(-keepTemp).Unix())); err != nil { + h.logger.Warnf("find caches: %v", err) + } else { + for _, cache := range caches { + if cache.Complete { + continue + } + h.storage.Remove(cache.ID) + if err := h.db.Delete(cache.ID, cache); err != nil { + h.logger.Warnf("delete cache: %v", err) + continue + } + h.logger.Infof("deleted cache: %+v", cache) + } + } + + caches = caches[:0] + if err := h.db.Find(&caches, 
bolthold.Where("UsedAt").Lt(time.Now().Add(-keepUnused).Unix())); err != nil { + h.logger.Warnf("find caches: %v", err) + } else { + for _, cache := range caches { + h.storage.Remove(cache.ID) + if err := h.db.Delete(cache.ID, cache); err != nil { + h.logger.Warnf("delete cache: %v", err) + continue + } + h.logger.Infof("deleted cache: %+v", cache) + } + } + + caches = caches[:0] + if err := h.db.Find(&caches, bolthold.Where("CreatedAt").Lt(time.Now().Add(-keepUsed).Unix())); err != nil { + h.logger.Warnf("find caches: %v", err) + } else { + for _, cache := range caches { + h.storage.Remove(cache.ID) + if err := h.db.Delete(cache.ID, cache); err != nil { + h.logger.Warnf("delete cache: %v", err) + continue + } + h.logger.Infof("deleted cache: %+v", cache) + } + } +} + +func (h *Handler) responseJSON(w http.ResponseWriter, r *http.Request, code int, v ...any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + var data []byte + if len(v) == 0 || v[0] == nil { + data, _ = json.Marshal(struct{}{}) + } else if err, ok := v[0].(error); ok { + h.logger.Errorf("%v %v: %v", r.Method, r.RequestURI, err) + data, _ = json.Marshal(map[string]any{ + "error": err.Error(), + }) + } else { + data, _ = json.Marshal(v[0]) + } + w.WriteHeader(code) + _, _ = w.Write(data) +} + +func parseContentRange(s string) (int64, int64, error) { + // support the format like "bytes 11-22/*" only + s, _, _ = strings.Cut(strings.TrimPrefix(s, "bytes "), "/") + s1, s2, _ := strings.Cut(s, "-") + + start, err := strconv.ParseInt(s1, 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("parse %q: %w", s, err) + } + stop, err := strconv.ParseInt(s2, 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("parse %q: %w", s, err) + } + return start, stop, nil +} diff --git a/pkg/artifactcache/handler_test.go b/pkg/artifactcache/handler_test.go new file mode 100644 index 0000000..7c6840a --- /dev/null +++ b/pkg/artifactcache/handler_test.go @@ -0,0 +1,469 @@ +package artifactcache + +import 
( + "bytes" + "crypto/rand" + "encoding/json" + "fmt" + "io" + "net/http" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.etcd.io/bbolt" +) + +func TestHandler(t *testing.T) { + dir := filepath.Join(t.TempDir(), "artifactcache") + handler, err := StartHandler(dir, "", 0, nil) + require.NoError(t, err) + + base := fmt.Sprintf("%s%s", handler.ExternalURL(), urlBase) + + defer func() { + t.Run("inspect db", func(t *testing.T) { + require.NoError(t, handler.db.Bolt().View(func(tx *bbolt.Tx) error { + return tx.Bucket([]byte("Cache")).ForEach(func(k, v []byte) error { + t.Logf("%s: %s", k, v) + return nil + }) + })) + }) + t.Run("close", func(t *testing.T) { + require.NoError(t, handler.Close()) + assert.Nil(t, handler.server) + assert.Nil(t, handler.listener) + assert.Nil(t, handler.db) + _, err := http.Post(fmt.Sprintf("%s/caches/%d", base, 1), "", nil) + assert.Error(t, err) + }) + }() + + t.Run("get not exist", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 204, resp.StatusCode) + }) + + t.Run("reserve and upload", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + content := make([]byte, 100) + _, err := rand.Read(content) + require.NoError(t, err) + uploadCacheNormally(t, base, key, version, content) + }) + + t.Run("clean", func(t *testing.T) { + resp, err := http.Post(fmt.Sprintf("%s/clean", base), "", nil) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + }) + + t.Run("reserve with bad request", func(t *testing.T) { + body := []byte(`invalid json`) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", 
bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + }) + + t.Run("duplicate reserve", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + { + body, err := json.Marshal(&Request{ + Key: key, + Version: version, + Size: 100, + }) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + got := struct { + CacheID uint64 `json:"cacheId"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + } + { + body, err := json.Marshal(&Request{ + Key: key, + Version: version, + Size: 100, + }) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + } + }) + + t.Run("upload with bad id", func(t *testing.T) { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/invalid_id", base), bytes.NewReader(nil)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes 0-99/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + }) + + t.Run("upload without reserve", func(t *testing.T) { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/%d", base, 1000), bytes.NewReader(nil)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes 0-99/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + }) + + t.Run("upload with complete", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + var id uint64 + content 
:= make([]byte, 100) + _, err := rand.Read(content) + require.NoError(t, err) + { + body, err := json.Marshal(&Request{ + Key: key, + Version: version, + Size: 100, + }) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + got := struct { + CacheID uint64 `json:"cacheId"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + id = got.CacheID + } + { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes 0-99/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } + { + resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } + { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes 0-99/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + } + }) + + t.Run("upload with invalid range", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + var id uint64 + content := make([]byte, 100) + _, err := rand.Read(content) + require.NoError(t, err) + { + body, err := json.Marshal(&Request{ + Key: key, + Version: version, + Size: 100, + }) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + got := struct { + CacheID uint64 
`json:"cacheId"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + id = got.CacheID + } + { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes xx-99/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + } + }) + + t.Run("commit with bad id", func(t *testing.T) { + { + resp, err := http.Post(fmt.Sprintf("%s/caches/invalid_id", base), "", nil) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + } + }) + + t.Run("commit with not exist id", func(t *testing.T) { + { + resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, 100), "", nil) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + } + }) + + t.Run("duplicate commit", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + var id uint64 + content := make([]byte, 100) + _, err := rand.Read(content) + require.NoError(t, err) + { + body, err := json.Marshal(&Request{ + Key: key, + Version: version, + Size: 100, + }) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + got := struct { + CacheID uint64 `json:"cacheId"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + id = got.CacheID + } + { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes 0-99/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } + { + resp, err := 
http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } + { + resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + require.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + } + }) + + t.Run("commit early", func(t *testing.T) { + key := strings.ToLower(t.Name()) + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + var id uint64 + content := make([]byte, 100) + _, err := rand.Read(content) + require.NoError(t, err) + { + body, err := json.Marshal(&Request{ + Key: key, + Version: version, + Size: 100, + }) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + got := struct { + CacheID uint64 `json:"cacheId"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + id = got.CacheID + } + { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content[:50])) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes 0-59/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } + { + resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + require.NoError(t, err) + assert.Equal(t, 500, resp.StatusCode) + } + }) + + t.Run("get with bad id", func(t *testing.T) { + resp, err := http.Get(fmt.Sprintf("%s/artifacts/invalid_id", base)) + require.NoError(t, err) + require.Equal(t, 400, resp.StatusCode) + }) + + t.Run("get with not exist id", func(t *testing.T) { + resp, err := http.Get(fmt.Sprintf("%s/artifacts/%d", base, 100)) + require.NoError(t, err) + require.Equal(t, 404, resp.StatusCode) + }) + + t.Run("get with not exist id", func(t *testing.T) { + resp, err := http.Get(fmt.Sprintf("%s/artifacts/%d", 
base, 100)) + require.NoError(t, err) + require.Equal(t, 404, resp.StatusCode) + }) + + t.Run("get with multiple keys", func(t *testing.T) { + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + key := strings.ToLower(t.Name()) + keys := [3]string{ + key + "_a", + key + "_a_b", + key + "_a_b_c", + } + contents := [3][]byte{ + make([]byte, 100), + make([]byte, 200), + make([]byte, 300), + } + for i := range contents { + _, err := rand.Read(contents[i]) + require.NoError(t, err) + uploadCacheNormally(t, base, keys[i], version, contents[i]) + } + + reqKeys := strings.Join([]string{ + key + "_a_b_x", + key + "_a_b", + key + "_a", + }, ",") + var archiveLocation string + { + resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKeys, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + got := struct { + Result string `json:"result"` + ArchiveLocation string `json:"archiveLocation"` + CacheKey string `json:"cacheKey"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + assert.Equal(t, "hit", got.Result) + assert.Equal(t, keys[1], got.CacheKey) + archiveLocation = got.ArchiveLocation + } + { + resp, err := http.Get(archiveLocation) //nolint:gosec + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + got, err := io.ReadAll(resp.Body) + require.NoError(t, err) + assert.Equal(t, contents[1], got) + } + }) + + t.Run("case insensitive", func(t *testing.T) { + version := "c19da02a2bd7e77277f1ac29ab45c09b7d46a4ee758284e26bb3045ad11d9d20" + key := strings.ToLower(t.Name()) + content := make([]byte, 100) + _, err := rand.Read(content) + require.NoError(t, err) + uploadCacheNormally(t, base, key+"_ABC", version, content) + + { + reqKey := key + "_aBc" + resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, reqKey, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + got := struct { + Result string `json:"result"` + ArchiveLocation 
string `json:"archiveLocation"` + CacheKey string `json:"cacheKey"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + assert.Equal(t, "hit", got.Result) + assert.Equal(t, key+"_abc", got.CacheKey) + } + }) +} + +func uploadCacheNormally(t *testing.T, base, key, version string, content []byte) { + var id uint64 + { + body, err := json.Marshal(&Request{ + Key: key, + Version: version, + Size: int64(len(content)), + }) + require.NoError(t, err) + resp, err := http.Post(fmt.Sprintf("%s/caches", base), "application/json", bytes.NewReader(body)) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + + got := struct { + CacheID uint64 `json:"cacheId"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + id = got.CacheID + } + { + req, err := http.NewRequest(http.MethodPatch, + fmt.Sprintf("%s/caches/%d", base, id), bytes.NewReader(content)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Range", "bytes 0-99/*") + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } + { + resp, err := http.Post(fmt.Sprintf("%s/caches/%d", base, id), "", nil) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + } + var archiveLocation string + { + resp, err := http.Get(fmt.Sprintf("%s/cache?keys=%s&version=%s", base, key, version)) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + got := struct { + Result string `json:"result"` + ArchiveLocation string `json:"archiveLocation"` + CacheKey string `json:"cacheKey"` + }{} + require.NoError(t, json.NewDecoder(resp.Body).Decode(&got)) + assert.Equal(t, "hit", got.Result) + assert.Equal(t, strings.ToLower(key), got.CacheKey) + archiveLocation = got.ArchiveLocation + } + { + resp, err := http.Get(archiveLocation) //nolint:gosec + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) + got, err := io.ReadAll(resp.Body) + 
require.NoError(t, err) + assert.Equal(t, content, got) + } +} diff --git a/pkg/artifactcache/model.go b/pkg/artifactcache/model.go new file mode 100644 index 0000000..5c28899 --- /dev/null +++ b/pkg/artifactcache/model.go @@ -0,0 +1,38 @@ +package artifactcache + +import ( + "crypto/sha256" + "fmt" +) + +type Request struct { + Key string `json:"key" ` + Version string `json:"version"` + Size int64 `json:"cacheSize"` +} + +func (c *Request) ToCache() *Cache { + if c == nil { + return nil + } + return &Cache{ + Key: c.Key, + Version: c.Version, + Size: c.Size, + } +} + +type Cache struct { + ID uint64 `json:"id" boltholdKey:"ID"` + Key string `json:"key" boltholdIndex:"Key"` + Version string `json:"version" boltholdIndex:"Version"` + KeyVersionHash string `json:"keyVersionHash" boltholdUnique:"KeyVersionHash"` + Size int64 `json:"cacheSize"` + Complete bool `json:"complete"` + UsedAt int64 `json:"usedAt" boltholdIndex:"UsedAt"` + CreatedAt int64 `json:"createdAt" boltholdIndex:"CreatedAt"` +} + +func (c *Cache) FillKeyVersionHash() { + c.KeyVersionHash = fmt.Sprintf("%x", sha256.Sum256([]byte(fmt.Sprintf("%s:%s", c.Key, c.Version)))) +} diff --git a/pkg/artifactcache/storage.go b/pkg/artifactcache/storage.go new file mode 100644 index 0000000..a49c94e --- /dev/null +++ b/pkg/artifactcache/storage.go @@ -0,0 +1,126 @@ +package artifactcache + +import ( + "fmt" + "io" + "net/http" + "os" + "path/filepath" +) + +type Storage struct { + rootDir string +} + +func NewStorage(rootDir string) (*Storage, error) { + if err := os.MkdirAll(rootDir, 0o755); err != nil { + return nil, err + } + return &Storage{ + rootDir: rootDir, + }, nil +} + +func (s *Storage) Exist(id uint64) (bool, error) { + name := s.filename(id) + if _, err := os.Stat(name); os.IsNotExist(err) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + +func (s *Storage) Write(id uint64, offset int64, reader io.Reader) error { + name := s.tempName(id, offset) + if err 
:= os.MkdirAll(filepath.Dir(name), 0o755); err != nil { + return err + } + file, err := os.Create(name) + if err != nil { + return err + } + defer file.Close() + + _, err = io.Copy(file, reader) + return err +} + +func (s *Storage) Commit(id uint64, size int64) error { + defer func() { + _ = os.RemoveAll(s.tempDir(id)) + }() + + name := s.filename(id) + tempNames, err := s.tempNames(id) + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(name), 0o755); err != nil { + return err + } + file, err := os.Create(name) + if err != nil { + return err + } + defer file.Close() + + var written int64 + for _, v := range tempNames { + f, err := os.Open(v) + if err != nil { + return err + } + n, err := io.Copy(file, f) + _ = f.Close() + if err != nil { + return err + } + written += n + } + + if written != size { + _ = file.Close() + _ = os.Remove(name) + return fmt.Errorf("broken file: %v != %v", written, size) + } + return nil +} + +func (s *Storage) Serve(w http.ResponseWriter, r *http.Request, id uint64) { + name := s.filename(id) + http.ServeFile(w, r, name) +} + +func (s *Storage) Remove(id uint64) { + _ = os.Remove(s.filename(id)) + _ = os.RemoveAll(s.tempDir(id)) +} + +func (s *Storage) filename(id uint64) string { + return filepath.Join(s.rootDir, fmt.Sprintf("%02x", id%0xff), fmt.Sprint(id)) +} + +func (s *Storage) tempDir(id uint64) string { + return filepath.Join(s.rootDir, "tmp", fmt.Sprint(id)) +} + +func (s *Storage) tempName(id uint64, offset int64) string { + return filepath.Join(s.tempDir(id), fmt.Sprintf("%016x", offset)) +} + +func (s *Storage) tempNames(id uint64) ([]string, error) { + dir := s.tempDir(id) + files, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + var names []string + for _, v := range files { + if !v.IsDir() { + names = append(names, filepath.Join(dir, v.Name())) + } + } + return names, nil +} diff --git a/pkg/artifactcache/testdata/example/example.yaml b/pkg/artifactcache/testdata/example/example.yaml 
new file mode 100644 index 0000000..5332e72 --- /dev/null +++ b/pkg/artifactcache/testdata/example/example.yaml @@ -0,0 +1,30 @@ +# Copied from https://github.com/actions/cache#example-cache-workflow +name: Caching Primes + +on: push + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - run: env + + - uses: actions/checkout@v3 + + - name: Cache Primes + id: cache-primes + uses: actions/cache@v3 + with: + path: prime-numbers + key: ${{ runner.os }}-primes-${{ github.run_id }} + restore-keys: | + ${{ runner.os }}-primes + ${{ runner.os }} + + - name: Generate Prime Numbers + if: steps.cache-primes.outputs.cache-hit != 'true' + run: cat /proc/sys/kernel/random/uuid > prime-numbers + + - name: Use Prime Numbers + run: cat prime-numbers