commit 8285fef11f1e4d6d8a7aaa15dc362bfd021d9d01
parent 140fb2edef2a829c6857b7b268a27ca1aca15701
Author: bsandro <email@bsandro.tech>
Date: Sun, 26 Jun 2022 12:23:23 +0300
functions and nested contexts
Diffstat:
4 files changed, 127 insertions(+), 3 deletions(-)
diff --git a/eval/eval.go b/eval/eval.go
@@ -75,6 +75,20 @@ func Eval(node ast.Node, ctx *object.Context) object.Object {
ctx.Set(node.Name.Value, value)
case *ast.Identifier:
return evalIdentifier(node, ctx)
+ case *ast.FunctionLiteral:
+ params := node.Parameters
+ body := node.Body
+ return &object.Function{Parameters: params, Ctx: ctx, Body: body}
+ case *ast.CallExpression:
+ fn := Eval(node.Function, ctx)
+ if isError(fn) {
+ return fn
+ }
+ args := evalExpressions(node.Arguments, ctx)
+ if len(args) == 1 && isError(args[0]) {
+ return args[0]
+ }
+ return applyFunction(fn, args)
}
return nil
}
@@ -223,3 +237,40 @@ func evalIdentifier(node *ast.Identifier, ctx *object.Context) object.Object {
}
return value
}
+
+func evalExpressions(expressions []ast.Expression, ctx *object.Context) []object.Object {
+ var result []object.Object
+ for _, e := range expressions {
+ evaluated := Eval(e, ctx)
+ if isError(evaluated) {
+ return []object.Object{evaluated}
+ }
+ result = append(result, evaluated)
+ }
+ return result
+}
+
+func applyFunction(fn object.Object, args []object.Object) object.Object {
+ function, ok := fn.(*object.Function)
+ if !ok {
+ return newError("not a function: %s", fn.Type())
+ }
+ childCtx := childFunctionContext(function, args)
+ evaluated := Eval(function.Body, childCtx)
+ return unwrapReturnValue(evaluated)
+}
+
+func childFunctionContext(fn *object.Function, args []object.Object) *object.Context {
+ ctx := object.NewChildContext(fn.Ctx)
+ for paramIdx, param := range fn.Parameters {
+ ctx.Set(param.Value, args[paramIdx])
+ }
+ return ctx
+}
+
+func unwrapReturnValue(obj object.Object) object.Object {
+ if returnValue, ok := obj.(*object.ReturnValue); ok {
+ return returnValue.Value
+ }
+ return obj
+}
diff --git a/eval/eval_test.go b/eval/eval_test.go
@@ -192,3 +192,39 @@ func TestLetStatements(t *testing.T) {
testIntegerObject(t, testEval(tt.input), tt.expected)
}
}
+
+func TestFunctionObject(t *testing.T) {
+ input := "fn(x) { x + 2; };"
+ evaluated := testEval(input)
+ fn, ok := evaluated.(*object.Function)
+ if !ok {
+ t.Fatalf("object is not a function")
+ }
+ if len(fn.Parameters) != 1 {
+ t.Fatalf("wrong number of parameters in function")
+ }
+ if fn.Parameters[0].String() != "x" {
+ t.Fatalf("wrong function parameter name")
+ }
+ expectedBody := "(x + 2)"
+ if fn.Body.String() != expectedBody {
+ t.Fatalf("function body is not '%s' but '%s'", expectedBody, fn.Body.String())
+ }
+}
+
+func TestFunctionApplication(t *testing.T) {
+ tests := []struct {
+ input string
+ expected int64
+ }{
+ {"let ident = fn(x) { x; }; ident(8);", 8},
+ {"let ident = fn(x) { return x; }; ident(8);", 8},
+ {"let double = fn(y) { y*2; }; double(8);", 16},
+ {"let sum = fn(a,b) { a+b;}; sum(10,20);", 30},
+ {"let sum = fn(a,b) { a+b;}; sum(10+20,sum(30,40));", 100},
+ {"fn(k){ k; }(32)", 32},
+ }
+ for _, tt := range tests {
+ testIntegerObject(t, testEval(tt.input), tt.expected)
+ }
+}
diff --git a/object/context.go b/object/context.go
@@ -1,16 +1,20 @@
package object
type Context struct {
- store map[string]Object
+ store map[string]Object
+ parent *Context
}
func NewContext() *Context {
s := make(map[string]Object)
- return &Context{store: s}
+ return &Context{store: s, parent: nil}
}
func (ctx *Context) Get(name string) (Object, bool) {
obj, ok := ctx.store[name]
+ if !ok && ctx.parent != nil {
+ obj, ok = ctx.parent.Get(name)
+ }
return obj, ok
}
@@ -18,3 +22,9 @@ func (ctx *Context) Set(name string, obj Object) Object {
ctx.store[name] = obj
return obj
}
+
+func NewChildContext(parent *Context) *Context {
+ ctx := NewContext()
+ ctx.parent = parent
+ return ctx
+}
diff --git a/object/object.go b/object/object.go
@@ -1,6 +1,11 @@
package object
-import "fmt"
+import (
+ "bytes"
+ "fmt"
+ "interp/ast"
+ "strings"
+)
type ObjectType string
@@ -10,6 +15,7 @@ const (
NULL_OBJ = "NULL"
RETURN_VALUE_OBJ = "RETURN_VALUE"
ERROR_OBJ = "ERROR"
+ FUNCTION_OBJ = "FUNCTION"
)
type Object interface {
@@ -49,3 +55,24 @@ type Error struct {
func (e *Error) Inspect() string { return "error: " + e.Message }
func (e *Error) Type() ObjectType { return ERROR_OBJ }
+
+type Function struct {
+ Parameters []*ast.Identifier
+ Body *ast.BlockStatement
+ Ctx *Context
+}
+
+func (f *Function) Type() ObjectType { return FUNCTION_OBJ }
+func (f *Function) Inspect() string {
+ var out bytes.Buffer
+ params := []string{}
+ for _, p := range f.Parameters {
+ params = append(params, p.String())
+ }
+ out.WriteString("fn(")
+ out.WriteString(strings.Join(params, ", "))
+ out.WriteString(") {\n")
+ out.WriteString(f.Body.String())
+ out.WriteString("\n}")
+ return out.String()
+}