diff --git a/pkg/model/github_context.go b/pkg/model/github_context.go index e9918a8..86172df 100644 --- a/pkg/model/github_context.go +++ b/pkg/model/github_context.go @@ -95,7 +95,7 @@ func (ghc *GithubContext) SetRefAndSha(ctx context.Context, defaultBranch string // https://docs.github.com/en/developers/webhooks-and-events/webhooks/webhook-events-and-payloads switch ghc.EventName { case "pull_request_target": - ghc.Ref = ghc.BaseRef + ghc.Ref = fmt.Sprintf("refs/heads/%s", ghc.BaseRef) ghc.Sha = asString(nestedMapLookup(ghc.Event, "pull_request", "base", "sha")) case "pull_request", "pull_request_review", "pull_request_review_comment": ghc.Ref = fmt.Sprintf("refs/pull/%.0f/merge", ghc.Event["number"]) @@ -110,7 +110,10 @@ func (ghc *GithubContext) SetRefAndSha(ctx context.Context, defaultBranch string ghc.Sha = asString(ghc.Event["after"]) } default: - ghc.Ref = asString(nestedMapLookup(ghc.Event, "repository", "default_branch")) + defaultBranch := asString(nestedMapLookup(ghc.Event, "repository", "default_branch")) + if defaultBranch != "" { + ghc.Ref = fmt.Sprintf("refs/heads/%s", defaultBranch) + } } if ghc.Ref == "" { @@ -130,7 +133,7 @@ func (ghc *GithubContext) SetRefAndSha(ctx context.Context, defaultBranch string } if ghc.Ref == "" { - ghc.Ref = asString(nestedMapLookup(ghc.Event, "repository", "default_branch")) + ghc.Ref = fmt.Sprintf("refs/heads/%s", asString(nestedMapLookup(ghc.Event, "repository", "default_branch"))) } } diff --git a/pkg/model/github_context_test.go b/pkg/model/github_context_test.go index 5cffcde..a290094 100644 --- a/pkg/model/github_context_test.go +++ b/pkg/model/github_context_test.go @@ -40,7 +40,7 @@ func TestSetRefAndSha(t *testing.T) { }, }, }, - ref: "master", + ref: "refs/heads/master", sha: "pr-base-sha", }, { @@ -89,7 +89,7 @@ func TestSetRefAndSha(t *testing.T) { "default_branch": "main", }, }, - ref: "main", + ref: "refs/heads/main", sha: "1234fakesha", }, { @@ -127,7 +127,7 @@ func TestSetRefAndSha(t *testing.T) { ghc.SetRefAndSha(context.Background(), "", "/some/dir") - assert.Equal(t, "master", ghc.Ref) + assert.Equal(t, "refs/heads/master", ghc.Ref) assert.Equal(t, "1234fakesha", ghc.Sha) }) } diff --git a/pkg/runner/run_context.go b/pkg/runner/run_context.go index 89608ad..8c2de75 100644 --- a/pkg/runner/run_context.go +++ b/pkg/runner/run_context.go @@ -486,7 +486,7 @@ func (rc *RunContext) getGithubContext(ctx context.Context) *model.GithubContext } } - if ghc.EventName == "pull_request" { + if ghc.EventName == "pull_request" || ghc.EventName == "pull_request_target" { ghc.BaseRef = asString(nestedMapLookup(ghc.Event, "pull_request", "base", "ref")) ghc.HeadRef = asString(nestedMapLookup(ghc.Event, "pull_request", "head", "ref")) } diff --git a/pkg/runner/run_context_test.go b/pkg/runner/run_context_test.go index 5980a8c..cdafb8e 100644 --- a/pkg/runner/run_context_test.go +++ b/pkg/runner/run_context_test.go @@ -396,6 +396,48 @@ func TestGetGitHubContext(t *testing.T) { assert.Equal(t, ghc.Token, rc.Config.Secrets["GITHUB_TOKEN"]) } +func TestGetGithubContextRef(t *testing.T) { + table := []struct { + event string + json string + ref string + }{ + {event: "push", json: `{"ref":"0000000000000000000000000000000000000000"}`, ref: "0000000000000000000000000000000000000000"}, + {event: "create", json: `{"ref":"0000000000000000000000000000000000000000"}`, ref: "0000000000000000000000000000000000000000"}, + {event: "workflow_dispatch", json: `{"ref":"0000000000000000000000000000000000000000"}`, ref: "0000000000000000000000000000000000000000"}, + {event: "delete", json: `{"repository":{"default_branch": "main"}}`, ref: "refs/heads/main"}, + {event: "pull_request", json: `{"number":123}`, ref: "refs/pull/123/merge"}, + {event: "pull_request_review", json: `{"number":123}`, ref: "refs/pull/123/merge"}, + {event: "pull_request_review_comment", json: `{"number":123}`, ref: "refs/pull/123/merge"}, + {event: "pull_request_target", json: `{"pull_request":{"base":{"ref": "main"}}}`, ref: "refs/heads/main"}, + {event: "deployment", json: `{"deployment": {"ref": "tag-name"}}`, ref: "tag-name"}, + {event: "deployment_status", json: `{"deployment": {"ref": "tag-name"}}`, ref: "tag-name"}, + {event: "release", json: `{"release": {"tag_name": "tag-name"}}`, ref: "tag-name"}, + } + + for _, data := range table { + data := data + t.Run(data.event, func(t *testing.T) { + rc := &RunContext{ + EventJSON: data.json, + Config: &Config{ + EventName: data.event, + Workdir: "", + }, + Run: &model.Run{ + Workflow: &model.Workflow{ + Name: "GitHubContextTest", + }, + }, + } + + ghc := rc.getGithubContext(context.Background()) + + assert.Equal(t, data.ref, ghc.Ref) + }) + } +} + func createIfTestRunContext(jobs map[string]*model.Job) *RunContext { rc := &RunContext{ Config: &Config{