From b1b9df5ef40e778f50a6f7cd08f4964c69258924 Mon Sep 17 00:00:00 2001 From: Earl Warren Date: Thu, 31 Oct 2024 15:58:39 +0100 Subject: [PATCH] fix: return an error when the argument count is wrong Closes forgejo/runner#307 --- .forgejo/workflows/test.yml | 2 +- pkg/exprparser/functions_test.go | 21 +++++++++++++++++++ pkg/exprparser/interpreter.go | 35 ++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/.forgejo/workflows/test.yml b/.forgejo/workflows/test.yml index 59516f3..a549105 100644 --- a/.forgejo/workflows/test.yml +++ b/.forgejo/workflows/test.yml @@ -44,5 +44,5 @@ jobs: - name: build without docker run: go build -tags WITHOUT_DOCKER -v ./... - name: test - run: go test -v ./pkg/jobparser ./pkg/model + run: go test -v ./pkg/jobparser ./pkg/model ./pkg/exprparser # TODO test more packages diff --git a/pkg/exprparser/functions_test.go b/pkg/exprparser/functions_test.go index ea51a2b..c90b326 100644 --- a/pkg/exprparser/functions_test.go +++ b/pkg/exprparser/functions_test.go @@ -43,6 +43,9 @@ func TestFunctionContains(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("contains('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionStartsWith(t *testing.T) { @@ -72,6 +75,9 @@ func TestFunctionStartsWith(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("startsWith('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionEndsWith(t *testing.T) { @@ -101,6 +107,9 @@ func TestFunctionEndsWith(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("endsWith('one')", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionJoin(t *testing.T) { @@ -128,6 +137,9 @@ func TestFunctionJoin(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("join()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionToJSON(t *testing.T) { @@ -154,6 +166,9 @@ func TestFunctionToJSON(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("tojson()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionFromJSON(t *testing.T) { @@ -177,6 +192,9 @@ func TestFunctionFromJSON(t *testing.T) { assert.Equal(t, tt.expected, output) }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("fromjson()", DefaultStatusCheckNone) + assert.Error(t, err) } func TestFunctionHashFiles(t *testing.T) { @@ -248,4 +266,7 @@ func TestFunctionFormat(t *testing.T) { } }) } + + _, err := NewInterpeter(env, Config{}).Evaluate("format()", DefaultStatusCheckNone) + assert.Error(t, err) } diff --git a/pkg/exprparser/interpreter.go b/pkg/exprparser/interpreter.go index 29c5686..021e5c9 100644 --- a/pkg/exprparser/interpreter.go +++ b/pkg/exprparser/interpreter.go @@ -593,23 +593,58 @@ func (impl *interperterImpl) evaluateFuncCall(funcCallNode *actionlint.FuncCallN args = append(args, reflect.ValueOf(value)) } + argCountCheck := func(argCount int) error { + if len(args) != argCount { + return fmt.Errorf("'%s' expected %d arguments but got %d instead", funcCallNode.Callee, argCount, len(args)) + } + return nil + } + + argAtLeastCheck := func(atLeast int) error { + if len(args) < atLeast { + return fmt.Errorf("'%s' expected at least %d arguments but got %d instead", funcCallNode.Callee, atLeast, len(args)) + } + return nil + } + switch strings.ToLower(funcCallNode.Callee) { case "contains": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.contains(args[0], args[1]) case "startswith": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.startsWith(args[0], args[1]) case "endswith": + if err := argCountCheck(2); err != nil { + return nil, err + } return impl.endsWith(args[0], args[1]) case "format": + if err := argAtLeastCheck(1); err != nil { + return nil, err + } return impl.format(args[0], args[1:]...) case "join": + if err := argAtLeastCheck(1); err != nil { + return nil, err + } if len(args) == 1 { return impl.join(args[0], reflect.ValueOf(",")) } return impl.join(args[0], args[1]) case "tojson": + if err := argCountCheck(1); err != nil { + return nil, err + } return impl.toJSON(args[0]) case "fromjson": + if err := argCountCheck(1); err != nil { + return nil, err + } return impl.fromJSON(args[0]) case "hashfiles": if impl.env.HashFiles != nil {