Loading ast/ast.go +11 −6 Original line number Diff line number Diff line Loading @@ -2,6 +2,7 @@ package ast import ( "bytes" "fmt" "monkey/token" "strings" ) Loading Loading @@ -177,15 +178,15 @@ type InfixExpression struct { Right Expression } func (ie *InfixExpression) expressionNode() {} func (ie *InfixExpression) TokenLiteral() string { return ie.Token.Literal } func (ie *InfixExpression) String() string { func (oe *InfixExpression) expressionNode() {} func (oe *InfixExpression) TokenLiteral() string { return oe.Token.Literal } func (oe *InfixExpression) String() string { var out bytes.Buffer out.WriteString("(") out.WriteString(ie.Left.String()) out.WriteString(" " + ie.Operator + " ") out.WriteString(ie.Right.String()) out.WriteString(oe.Left.String()) out.WriteString(" " + oe.Operator + " ") out.WriteString(oe.Right.String()) out.WriteString(")") return out.String() Loading Loading @@ -220,6 +221,7 @@ type FunctionLiteral struct { Token token.Token // The 'fn' token Parameters []*Identifier Body *BlockStatement Name string } func (fl *FunctionLiteral) expressionNode() {} Loading @@ -233,6 +235,9 @@ func (fl *FunctionLiteral) String() string { } out.WriteString(fl.TokenLiteral()) if fl.Name != "" { out.WriteString(fmt.Sprintf("<%s>", fl.Name)) } out.WriteString("(") out.WriteString(strings.Join(params, ", ")) out.WriteString(") ") Loading benchmark/main.go 0 → 100644 +75 −0 Original line number Diff line number Diff line package main import ( "flag" "fmt" "time" "monkey/compiler" "monkey/evaluator" "monkey/lexer" "monkey/object" "monkey/parser" "monkey/vm" ) var engine = flag.String("engine", "vm", "use 'vm' or 'eval'") var input = ` let fibonacci = fn(x) { if (x == 0) { 0 } else { if (x == 1) { return 1; } else { fibonacci(x - 1) + fibonacci(x - 2); } } }; fibonacci(35); ` func main() { flag.Parse() var duration time.Duration var result object.Object l := lexer.New(input) p := parser.New(l) program := p.ParseProgram() if *engine == "vm" { comp := compiler.New() err := comp.Compile(program) if err != nil { fmt.Printf("compiler error: %s", err) return } machine := vm.New(comp.Bytecode()) start := time.Now() err = machine.Run() if err != nil { fmt.Printf("vm error: %s", err) return } duration = time.Since(start) result = machine.LastPoppedStackElem() } else { env := object.NewEnvironment() start := time.Now() result = evaluator.Eval(program, env) duration = time.Since(start) } fmt.Printf( "engine=%s, result=%s, duration=%s\n", *engine, result.Inspect(), duration) } code/code.go 0 → 100644 +219 −0 Original line number Diff line number Diff line package code import ( "bytes" "encoding/binary" "fmt" ) type Instructions []byte func (ins Instructions) String() string { var out bytes.Buffer i := 0 for i < len(ins) { def, err := Lookup(ins[i]) if err != nil { fmt.Fprintf(&out, "ERROR: %s\n", err) continue } operands, read := ReadOperands(def, ins[i+1:]) fmt.Fprintf(&out, "%04d %s\n", i, ins.fmtInstruction(def, operands)) i += 1 + read } return out.String() } func (ins Instructions) fmtInstruction(def *Definition, operands []int) string { operandCount := len(def.OperandWidths) if len(operands) != operandCount { return fmt.Sprintf("ERROR: operand len %d does not match defined %d\n", len(operands), operandCount) } switch operandCount { case 0: return def.Name case 1: return fmt.Sprintf("%s %d", def.Name, operands[0]) case 2: return fmt.Sprintf("%s %d %d", def.Name, operands[0], operands[1]) } return fmt.Sprintf("ERROR: unhandled operandCount for %s\n", def.Name) } type Opcode byte const ( OpConstant Opcode = iota OpAdd OpPop OpSub OpMul OpDiv OpTrue OpFalse OpEqual OpNotEqual OpGreaterThan OpMinus OpBang OpJumpNotTruthy OpJump OpNull OpGetGlobal OpSetGlobal OpArray OpHash OpIndex OpCall OpReturnValue OpReturn OpGetLocal OpSetLocal OpGetBuiltin OpClosure OpGetFree OpCurrentClosure ) type Definition struct { Name string OperandWidths []int } var definitions = map[Opcode]*Definition{ OpConstant: {"OpConstant", []int{2}}, OpAdd: {"OpAdd", []int{}}, OpPop: {"OpPop", []int{}}, OpSub: {"OpSub", []int{}}, OpMul: {"OpMul", []int{}}, OpDiv: {"OpDiv", []int{}}, OpTrue: {"OpTrue", []int{}}, OpFalse: {"OpFalse", []int{}}, OpEqual: {"OpEqual", []int{}}, OpNotEqual: {"OpNotEqual", []int{}}, OpGreaterThan: {"OpGreaterThan", []int{}}, OpMinus: {"OpMinus", []int{}}, OpBang: {"OpBang", []int{}}, OpJumpNotTruthy: {"OpJumpNotTruthy", []int{2}}, OpJump: {"OpJump", []int{2}}, OpNull: {"OpNull", []int{}}, OpGetGlobal: {"OpGetGlobal", []int{2}}, OpSetGlobal: {"OpSetGlobal", []int{2}}, OpArray: {"OpArray", []int{2}}, OpHash: {"OpHash", []int{2}}, OpIndex: {"OpIndex", []int{}}, OpCall: {"OpCall", []int{1}}, OpReturnValue: {"OpReturnValue", []int{}}, OpReturn: {"OpReturn", []int{}}, OpGetLocal: {"OpGetLocal", []int{1}}, OpSetLocal: {"OpSetLocal", []int{1}}, OpGetBuiltin: {"OpGetBuiltin", []int{1}}, OpClosure: {"OpClosure", []int{2, 1}}, OpGetFree: {"OpGetFree", []int{1}}, OpCurrentClosure: {"OpCurrentClosure", []int{}}, } func Lookup(op byte) (*Definition, error) { def, ok := definitions[Opcode(op)] if !ok { return nil, fmt.Errorf("opcode %d undefined", op) } return def, nil } func Make(op Opcode, operands ...int) []byte { def, ok := definitions[op] if !ok { return []byte{} } instructionLen := 1 for _, w := range def.OperandWidths { instructionLen += w } instruction := make([]byte, instructionLen) instruction[0] = byte(op) offset := 1 for i, o := range operands { width := def.OperandWidths[i] switch width { case 2: binary.BigEndian.PutUint16(instruction[offset:], uint16(o)) case 1: instruction[offset] = byte(o) } offset += width } return instruction } func ReadOperands(def *Definition, ins Instructions) ([]int, int) { operands := make([]int, len(def.OperandWidths)) offset := 0 for i, width := range def.OperandWidths { switch width { case 2: operands[i] = int(ReadUint16(ins[offset:])) case 1: operands[i] = int(ReadUint8(ins[offset:])) } offset += width } return operands, offset } func ReadUint8(ins Instructions) uint8 { return uint8(ins[0]) } func ReadUint16(ins Instructions) uint16 { return binary.BigEndian.Uint16(ins) } code/code_test.go 0 → 100644 +91 −0 Original line number Diff line number Diff line package code import "testing" func TestMake(t *testing.T) { tests := []struct { op Opcode operands []int expected []byte }{ {OpConstant, []int{65534}, []byte{byte(OpConstant), 255, 254}}, {OpAdd, []int{}, []byte{byte(OpAdd)}}, {OpGetLocal, []int{255}, []byte{byte(OpGetLocal), 255}}, {OpClosure, []int{65534, 255}, []byte{byte(OpClosure), 255, 254, 255}}, } for _, tt := range tests { instruction := Make(tt.op, tt.operands...) if len(instruction) != len(tt.expected) { t.Errorf("instruction has wrong length. want=%d, got=%d", len(tt.expected), len(instruction)) } for i, b := range tt.expected { if instruction[i] != tt.expected[i] { t.Errorf("wrong byte at pos %d. want=%d, got=%d", i, b, instruction[i]) } } } } func TestInstructionsString(t *testing.T) { instructions := []Instructions{ Make(OpAdd), Make(OpGetLocal, 1), Make(OpConstant, 2), Make(OpConstant, 65535), Make(OpClosure, 65535, 255), } expected := `0000 OpAdd 0001 OpGetLocal 1 0003 OpConstant 2 0006 OpConstant 65535 0009 OpClosure 65535 255 ` concatted := Instructions{} for _, ins := range instructions { concatted = append(concatted, ins...) } if concatted.String() != expected { t.Errorf("instructions wrongly formatted.\nwant=%q\ngot=%q", expected, concatted.String()) } } func TestReadOperands(t *testing.T) { tests := []struct { op Opcode operands []int bytesRead int }{ {OpConstant, []int{65535}, 2}, {OpGetLocal, []int{255}, 1}, {OpClosure, []int{65535, 255}, 3}, } for _, tt := range tests { instruction := Make(tt.op, tt.operands...) def, err := Lookup(byte(tt.op)) if err != nil { t.Fatalf("definition not found: %q\n", err) } operandsRead, n := ReadOperands(def, instruction[1:]) if n != tt.bytesRead { t.Fatalf("n wrong. want=%d, got=%d", tt.bytesRead, n) } for i, want := range tt.operands { if operandsRead[i] != want { t.Errorf("operand wrong. want=%d, got=%d", want, operandsRead[i]) } } } } compiler/compiler.go 0 → 100644 +456 −0 Original line number Diff line number Diff line package compiler import ( "fmt" "monkey/ast" "monkey/code" "monkey/object" "sort" ) type Compiler struct { constants []object.Object symbolTable *SymbolTable scopes []CompilationScope scopeIndex int } func New() *Compiler { mainScope := CompilationScope{ instructions: code.Instructions{}, lastInstruction: EmittedInstruction{}, previousInstruction: EmittedInstruction{}, } symbolTable := NewSymbolTable() for i, v := range object.Builtins { symbolTable.DefineBuiltin(i, v.Name) } return &Compiler{ constants: []object.Object{}, symbolTable: symbolTable, scopes: []CompilationScope{mainScope}, scopeIndex: 0, } } func NewWithState(s *SymbolTable, constants []object.Object) *Compiler { compiler := New() compiler.symbolTable = s compiler.constants = constants return compiler } func (c *Compiler) Compile(node ast.Node) error { switch node := node.(type) { case *ast.Program: for _, s := range node.Statements { err := c.Compile(s) if err != nil { return err } } case *ast.ExpressionStatement: err := c.Compile(node.Expression) if err != nil { return err } c.emit(code.OpPop) case *ast.InfixExpression: if node.Operator == "<" { err := c.Compile(node.Right) if err != nil { return err } err = c.Compile(node.Left) if err != nil { return err } c.emit(code.OpGreaterThan) return nil } err := c.Compile(node.Left) if err != nil { return err } err = c.Compile(node.Right) if err != nil { return err } switch node.Operator { case "+": c.emit(code.OpAdd) case "-": c.emit(code.OpSub) case "*": c.emit(code.OpMul) case "/": c.emit(code.OpDiv) case ">": c.emit(code.OpGreaterThan) case "==": c.emit(code.OpEqual) case "!=": c.emit(code.OpNotEqual) default: return fmt.Errorf("unknown operator %s", node.Operator) } case *ast.IntegerLiteral: integer := &object.Integer{Value: node.Value} c.emit(code.OpConstant, c.addConstant(integer)) case *ast.Boolean: if node.Value { c.emit(code.OpTrue) } else { c.emit(code.OpFalse) } case *ast.PrefixExpression: err := c.Compile(node.Right) if err != nil { return err } switch node.Operator { case "!": c.emit(code.OpBang) case "-": c.emit(code.OpMinus) default: return fmt.Errorf("unknown operator %s", node.Operator) } case *ast.IfExpression: err := c.Compile(node.Condition) if err != nil { return err } // Emit an `OpJumpNotTruthy` with a bogus value jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999) err = c.Compile(node.Consequence) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } // Emit an `OpJump` with a bogus value jumpPos := c.emit(code.OpJump, 9999) afterConsequencePos := len(c.currentInstructions()) c.changeOperand(jumpNotTruthyPos, afterConsequencePos) if node.Alternative == nil { c.emit(code.OpNull) } else { err := c.Compile(node.Alternative) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } } afterAlternativePos := len(c.currentInstructions()) c.changeOperand(jumpPos, afterAlternativePos) case *ast.BlockStatement: for _, s := range node.Statements { err := c.Compile(s) if err != nil { return err } } case *ast.LetStatement: symbol := c.symbolTable.Define(node.Name.Value) err := c.Compile(node.Value) if err != nil { return err } if symbol.Scope == GlobalScope { c.emit(code.OpSetGlobal, symbol.Index) } else { c.emit(code.OpSetLocal, symbol.Index) } case *ast.Identifier: symbol, ok := c.symbolTable.Resolve(node.Value) if !ok { return fmt.Errorf("undefined variable %s", node.Value) } c.loadSymbol(symbol) case *ast.StringLiteral: str := &object.String{Value: node.Value} c.emit(code.OpConstant, c.addConstant(str)) case *ast.ArrayLiteral: for _, el := range node.Elements { err := c.Compile(el) if err != nil { return err } } c.emit(code.OpArray, len(node.Elements)) case *ast.HashLiteral: keys := []ast.Expression{} for k := range node.Pairs { keys = append(keys, k) } sort.Slice(keys, func(i, j int) bool { return keys[i].String() < keys[j].String() }) for _, k := range keys { err := c.Compile(k) if err != nil { return err } err = c.Compile(node.Pairs[k]) if err != nil { return err } } c.emit(code.OpHash, len(node.Pairs)*2) case *ast.IndexExpression: err := c.Compile(node.Left) if err != nil { return err } err = c.Compile(node.Index) if err != nil { return err } c.emit(code.OpIndex) case *ast.FunctionLiteral: c.enterScope() if node.Name != "" { c.symbolTable.DefineFunctionName(node.Name) } for _, p := range node.Parameters { c.symbolTable.Define(p.Value) } err := c.Compile(node.Body) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.replaceLastPopWithReturn() } if !c.lastInstructionIs(code.OpReturnValue) { c.emit(code.OpReturn) } freeSymbols := c.symbolTable.FreeSymbols numLocals := c.symbolTable.numDefinitions instructions := c.leaveScope() for _, s := range freeSymbols { c.loadSymbol(s) } compiledFn := &object.CompiledFunction{ Instructions: instructions, NumLocals: numLocals, NumParameters: len(node.Parameters), } fnIndex := c.addConstant(compiledFn) c.emit(code.OpClosure, fnIndex, len(freeSymbols)) case *ast.ReturnStatement: err := c.Compile(node.ReturnValue) if err != nil { return err } c.emit(code.OpReturnValue) case *ast.CallExpression: err := c.Compile(node.Function) if err != nil { return err } for _, a := range node.Arguments { err := c.Compile(a) if err != nil { return err } } c.emit(code.OpCall, len(node.Arguments)) } return nil } func (c *Compiler) Bytecode() *Bytecode { return &Bytecode{ Instructions: c.currentInstructions(), Constants: c.constants, } } func (c *Compiler) addConstant(obj object.Object) int { c.constants = append(c.constants, obj) return len(c.constants) - 1 } func (c *Compiler) emit(op code.Opcode, operands ...int) int { ins := code.Make(op, operands...) pos := c.addInstruction(ins) c.setLastInstruction(op, pos) return pos } func (c *Compiler) addInstruction(ins []byte) int { posNewInstruction := len(c.currentInstructions()) updatedInstructions := append(c.currentInstructions(), ins...) c.scopes[c.scopeIndex].instructions = updatedInstructions return posNewInstruction } func (c *Compiler) setLastInstruction(op code.Opcode, pos int) { previous := c.scopes[c.scopeIndex].lastInstruction last := EmittedInstruction{Opcode: op, Position: pos} c.scopes[c.scopeIndex].previousInstruction = previous c.scopes[c.scopeIndex].lastInstruction = last } func (c *Compiler) lastInstructionIs(op code.Opcode) bool { if len(c.currentInstructions()) == 0 { return false } return c.scopes[c.scopeIndex].lastInstruction.Opcode == op } func (c *Compiler) removeLastPop() { last := c.scopes[c.scopeIndex].lastInstruction previous := c.scopes[c.scopeIndex].previousInstruction old := c.currentInstructions() new := old[:last.Position] c.scopes[c.scopeIndex].instructions = new c.scopes[c.scopeIndex].lastInstruction = previous } func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) { ins := c.currentInstructions() for i := 0; i < len(newInstruction); i++ { ins[pos+i] = newInstruction[i] } } func (c *Compiler) changeOperand(opPos int, operand int) { op := code.Opcode(c.currentInstructions()[opPos]) newInstruction := code.Make(op, operand) c.replaceInstruction(opPos, newInstruction) } func (c *Compiler) currentInstructions() code.Instructions { return c.scopes[c.scopeIndex].instructions } func (c *Compiler) enterScope() { scope := CompilationScope{ instructions: code.Instructions{}, lastInstruction: EmittedInstruction{}, previousInstruction: EmittedInstruction{}, } c.scopes = append(c.scopes, scope) c.scopeIndex++ c.symbolTable = NewEnclosedSymbolTable(c.symbolTable) } func (c *Compiler) leaveScope() code.Instructions { instructions := c.currentInstructions() c.scopes = c.scopes[:len(c.scopes)-1] c.scopeIndex-- c.symbolTable = c.symbolTable.Outer return instructions } func (c *Compiler) replaceLastPopWithReturn() { lastPos := c.scopes[c.scopeIndex].lastInstruction.Position c.replaceInstruction(lastPos, code.Make(code.OpReturnValue)) c.scopes[c.scopeIndex].lastInstruction.Opcode = code.OpReturnValue } func (c *Compiler) loadSymbol(s Symbol) { switch s.Scope { case GlobalScope: c.emit(code.OpGetGlobal, s.Index) case LocalScope: c.emit(code.OpGetLocal, s.Index) case BuiltinScope: c.emit(code.OpGetBuiltin, s.Index) case FreeScope: c.emit(code.OpGetFree, s.Index) case FunctionScope: c.emit(code.OpCurrentClosure) } } type Bytecode struct { Instructions code.Instructions Constants []object.Object } type EmittedInstruction struct { Opcode code.Opcode Position int } type CompilationScope struct { instructions code.Instructions lastInstruction EmittedInstruction previousInstruction EmittedInstruction } Loading
ast/ast.go +11 −6 Original line number Diff line number Diff line Loading @@ -2,6 +2,7 @@ package ast import ( "bytes" "fmt" "monkey/token" "strings" ) Loading Loading @@ -177,15 +178,15 @@ type InfixExpression struct { Right Expression } func (ie *InfixExpression) expressionNode() {} func (ie *InfixExpression) TokenLiteral() string { return ie.Token.Literal } func (ie *InfixExpression) String() string { func (oe *InfixExpression) expressionNode() {} func (oe *InfixExpression) TokenLiteral() string { return oe.Token.Literal } func (oe *InfixExpression) String() string { var out bytes.Buffer out.WriteString("(") out.WriteString(ie.Left.String()) out.WriteString(" " + ie.Operator + " ") out.WriteString(ie.Right.String()) out.WriteString(oe.Left.String()) out.WriteString(" " + oe.Operator + " ") out.WriteString(oe.Right.String()) out.WriteString(")") return out.String() Loading Loading @@ -220,6 +221,7 @@ type FunctionLiteral struct { Token token.Token // The 'fn' token Parameters []*Identifier Body *BlockStatement Name string } func (fl *FunctionLiteral) expressionNode() {} Loading @@ -233,6 +235,9 @@ func (fl *FunctionLiteral) String() string { } out.WriteString(fl.TokenLiteral()) if fl.Name != "" { out.WriteString(fmt.Sprintf("<%s>", fl.Name)) } out.WriteString("(") out.WriteString(strings.Join(params, ", ")) out.WriteString(") ") Loading
benchmark/main.go 0 → 100644 +75 −0 Original line number Diff line number Diff line package main import ( "flag" "fmt" "time" "monkey/compiler" "monkey/evaluator" "monkey/lexer" "monkey/object" "monkey/parser" "monkey/vm" ) var engine = flag.String("engine", "vm", "use 'vm' or 'eval'") var input = ` let fibonacci = fn(x) { if (x == 0) { 0 } else { if (x == 1) { return 1; } else { fibonacci(x - 1) + fibonacci(x - 2); } } }; fibonacci(35); ` func main() { flag.Parse() var duration time.Duration var result object.Object l := lexer.New(input) p := parser.New(l) program := p.ParseProgram() if *engine == "vm" { comp := compiler.New() err := comp.Compile(program) if err != nil { fmt.Printf("compiler error: %s", err) return } machine := vm.New(comp.Bytecode()) start := time.Now() err = machine.Run() if err != nil { fmt.Printf("vm error: %s", err) return } duration = time.Since(start) result = machine.LastPoppedStackElem() } else { env := object.NewEnvironment() start := time.Now() result = evaluator.Eval(program, env) duration = time.Since(start) } fmt.Printf( "engine=%s, result=%s, duration=%s\n", *engine, result.Inspect(), duration) }
code/code.go 0 → 100644 +219 −0 Original line number Diff line number Diff line package code import ( "bytes" "encoding/binary" "fmt" ) type Instructions []byte func (ins Instructions) String() string { var out bytes.Buffer i := 0 for i < len(ins) { def, err := Lookup(ins[i]) if err != nil { fmt.Fprintf(&out, "ERROR: %s\n", err) continue } operands, read := ReadOperands(def, ins[i+1:]) fmt.Fprintf(&out, "%04d %s\n", i, ins.fmtInstruction(def, operands)) i += 1 + read } return out.String() } func (ins Instructions) fmtInstruction(def *Definition, operands []int) string { operandCount := len(def.OperandWidths) if len(operands) != operandCount { return fmt.Sprintf("ERROR: operand len %d does not match defined %d\n", len(operands), operandCount) } switch operandCount { case 0: return def.Name case 1: return fmt.Sprintf("%s %d", def.Name, operands[0]) case 2: return fmt.Sprintf("%s %d %d", def.Name, operands[0], operands[1]) } return fmt.Sprintf("ERROR: unhandled operandCount for %s\n", def.Name) } type Opcode byte const ( OpConstant Opcode = iota OpAdd OpPop OpSub OpMul OpDiv OpTrue OpFalse OpEqual OpNotEqual OpGreaterThan OpMinus OpBang OpJumpNotTruthy OpJump OpNull OpGetGlobal OpSetGlobal OpArray OpHash OpIndex OpCall OpReturnValue OpReturn OpGetLocal OpSetLocal OpGetBuiltin OpClosure OpGetFree OpCurrentClosure ) type Definition struct { Name string OperandWidths []int } var definitions = map[Opcode]*Definition{ OpConstant: {"OpConstant", []int{2}}, OpAdd: {"OpAdd", []int{}}, OpPop: {"OpPop", []int{}}, OpSub: {"OpSub", []int{}}, OpMul: {"OpMul", []int{}}, OpDiv: {"OpDiv", []int{}}, OpTrue: {"OpTrue", []int{}}, OpFalse: {"OpFalse", []int{}}, OpEqual: {"OpEqual", []int{}}, OpNotEqual: {"OpNotEqual", []int{}}, OpGreaterThan: {"OpGreaterThan", []int{}}, OpMinus: {"OpMinus", []int{}}, OpBang: {"OpBang", []int{}}, OpJumpNotTruthy: {"OpJumpNotTruthy", []int{2}}, OpJump: {"OpJump", []int{2}}, OpNull: {"OpNull", []int{}}, OpGetGlobal: {"OpGetGlobal", []int{2}}, OpSetGlobal: {"OpSetGlobal", []int{2}}, OpArray: {"OpArray", []int{2}}, OpHash: {"OpHash", []int{2}}, OpIndex: {"OpIndex", []int{}}, OpCall: {"OpCall", []int{1}}, OpReturnValue: {"OpReturnValue", []int{}}, OpReturn: {"OpReturn", []int{}}, OpGetLocal: {"OpGetLocal", []int{1}}, OpSetLocal: {"OpSetLocal", []int{1}}, OpGetBuiltin: {"OpGetBuiltin", []int{1}}, OpClosure: {"OpClosure", []int{2, 1}}, OpGetFree: {"OpGetFree", []int{1}}, OpCurrentClosure: {"OpCurrentClosure", []int{}}, } func Lookup(op byte) (*Definition, error) { def, ok := definitions[Opcode(op)] if !ok { return nil, fmt.Errorf("opcode %d undefined", op) } return def, nil } func Make(op Opcode, operands ...int) []byte { def, ok := definitions[op] if !ok { return []byte{} } instructionLen := 1 for _, w := range def.OperandWidths { instructionLen += w } instruction := make([]byte, instructionLen) instruction[0] = byte(op) offset := 1 for i, o := range operands { width := def.OperandWidths[i] switch width { case 2: binary.BigEndian.PutUint16(instruction[offset:], uint16(o)) case 1: instruction[offset] = byte(o) } offset += width } return instruction } func ReadOperands(def *Definition, ins Instructions) ([]int, int) { operands := make([]int, len(def.OperandWidths)) offset := 0 for i, width := range def.OperandWidths { switch width { case 2: operands[i] = int(ReadUint16(ins[offset:])) case 1: operands[i] = int(ReadUint8(ins[offset:])) } offset += width } return operands, offset } func ReadUint8(ins Instructions) uint8 { return uint8(ins[0]) } func ReadUint16(ins Instructions) uint16 { return binary.BigEndian.Uint16(ins) }
code/code_test.go 0 → 100644 +91 −0 Original line number Diff line number Diff line package code import "testing" func TestMake(t *testing.T) { tests := []struct { op Opcode operands []int expected []byte }{ {OpConstant, []int{65534}, []byte{byte(OpConstant), 255, 254}}, {OpAdd, []int{}, []byte{byte(OpAdd)}}, {OpGetLocal, []int{255}, []byte{byte(OpGetLocal), 255}}, {OpClosure, []int{65534, 255}, []byte{byte(OpClosure), 255, 254, 255}}, } for _, tt := range tests { instruction := Make(tt.op, tt.operands...) if len(instruction) != len(tt.expected) { t.Errorf("instruction has wrong length. want=%d, got=%d", len(tt.expected), len(instruction)) } for i, b := range tt.expected { if instruction[i] != tt.expected[i] { t.Errorf("wrong byte at pos %d. want=%d, got=%d", i, b, instruction[i]) } } } } func TestInstructionsString(t *testing.T) { instructions := []Instructions{ Make(OpAdd), Make(OpGetLocal, 1), Make(OpConstant, 2), Make(OpConstant, 65535), Make(OpClosure, 65535, 255), } expected := `0000 OpAdd 0001 OpGetLocal 1 0003 OpConstant 2 0006 OpConstant 65535 0009 OpClosure 65535 255 ` concatted := Instructions{} for _, ins := range instructions { concatted = append(concatted, ins...) } if concatted.String() != expected { t.Errorf("instructions wrongly formatted.\nwant=%q\ngot=%q", expected, concatted.String()) } } func TestReadOperands(t *testing.T) { tests := []struct { op Opcode operands []int bytesRead int }{ {OpConstant, []int{65535}, 2}, {OpGetLocal, []int{255}, 1}, {OpClosure, []int{65535, 255}, 3}, } for _, tt := range tests { instruction := Make(tt.op, tt.operands...) def, err := Lookup(byte(tt.op)) if err != nil { t.Fatalf("definition not found: %q\n", err) } operandsRead, n := ReadOperands(def, instruction[1:]) if n != tt.bytesRead { t.Fatalf("n wrong. want=%d, got=%d", tt.bytesRead, n) } for i, want := range tt.operands { if operandsRead[i] != want { t.Errorf("operand wrong. want=%d, got=%d", want, operandsRead[i]) } } } }
compiler/compiler.go 0 → 100644 +456 −0 Original line number Diff line number Diff line package compiler import ( "fmt" "monkey/ast" "monkey/code" "monkey/object" "sort" ) type Compiler struct { constants []object.Object symbolTable *SymbolTable scopes []CompilationScope scopeIndex int } func New() *Compiler { mainScope := CompilationScope{ instructions: code.Instructions{}, lastInstruction: EmittedInstruction{}, previousInstruction: EmittedInstruction{}, } symbolTable := NewSymbolTable() for i, v := range object.Builtins { symbolTable.DefineBuiltin(i, v.Name) } return &Compiler{ constants: []object.Object{}, symbolTable: symbolTable, scopes: []CompilationScope{mainScope}, scopeIndex: 0, } } func NewWithState(s *SymbolTable, constants []object.Object) *Compiler { compiler := New() compiler.symbolTable = s compiler.constants = constants return compiler } func (c *Compiler) Compile(node ast.Node) error { switch node := node.(type) { case *ast.Program: for _, s := range node.Statements { err := c.Compile(s) if err != nil { return err } } case *ast.ExpressionStatement: err := c.Compile(node.Expression) if err != nil { return err } c.emit(code.OpPop) case *ast.InfixExpression: if node.Operator == "<" { err := c.Compile(node.Right) if err != nil { return err } err = c.Compile(node.Left) if err != nil { return err } c.emit(code.OpGreaterThan) return nil } err := c.Compile(node.Left) if err != nil { return err } err = c.Compile(node.Right) if err != nil { return err } switch node.Operator { case "+": c.emit(code.OpAdd) case "-": c.emit(code.OpSub) case "*": c.emit(code.OpMul) case "/": c.emit(code.OpDiv) case ">": c.emit(code.OpGreaterThan) case "==": c.emit(code.OpEqual) case "!=": c.emit(code.OpNotEqual) default: return fmt.Errorf("unknown operator %s", node.Operator) } case *ast.IntegerLiteral: integer := &object.Integer{Value: node.Value} c.emit(code.OpConstant, c.addConstant(integer)) case *ast.Boolean: if node.Value { c.emit(code.OpTrue) } else { c.emit(code.OpFalse) } case *ast.PrefixExpression: err := c.Compile(node.Right) if err != nil { return err } switch node.Operator { case "!": c.emit(code.OpBang) case "-": c.emit(code.OpMinus) default: return fmt.Errorf("unknown operator %s", node.Operator) } case *ast.IfExpression: err := c.Compile(node.Condition) if err != nil { return err } // Emit an `OpJumpNotTruthy` with a bogus value jumpNotTruthyPos := c.emit(code.OpJumpNotTruthy, 9999) err = c.Compile(node.Consequence) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } // Emit an `OpJump` with a bogus value jumpPos := c.emit(code.OpJump, 9999) afterConsequencePos := len(c.currentInstructions()) c.changeOperand(jumpNotTruthyPos, afterConsequencePos) if node.Alternative == nil { c.emit(code.OpNull) } else { err := c.Compile(node.Alternative) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.removeLastPop() } } afterAlternativePos := len(c.currentInstructions()) c.changeOperand(jumpPos, afterAlternativePos) case *ast.BlockStatement: for _, s := range node.Statements { err := c.Compile(s) if err != nil { return err } } case *ast.LetStatement: symbol := c.symbolTable.Define(node.Name.Value) err := c.Compile(node.Value) if err != nil { return err } if symbol.Scope == GlobalScope { c.emit(code.OpSetGlobal, symbol.Index) } else { c.emit(code.OpSetLocal, symbol.Index) } case *ast.Identifier: symbol, ok := c.symbolTable.Resolve(node.Value) if !ok { return fmt.Errorf("undefined variable %s", node.Value) } c.loadSymbol(symbol) case *ast.StringLiteral: str := &object.String{Value: node.Value} c.emit(code.OpConstant, c.addConstant(str)) case *ast.ArrayLiteral: for _, el := range node.Elements { err := c.Compile(el) if err != nil { return err } } c.emit(code.OpArray, len(node.Elements)) case *ast.HashLiteral: keys := []ast.Expression{} for k := range node.Pairs { keys = append(keys, k) } sort.Slice(keys, func(i, j int) bool { return keys[i].String() < keys[j].String() }) for _, k := range keys { err := c.Compile(k) if err != nil { return err } err = c.Compile(node.Pairs[k]) if err != nil { return err } } c.emit(code.OpHash, len(node.Pairs)*2) case *ast.IndexExpression: err := c.Compile(node.Left) if err != nil { return err } err = c.Compile(node.Index) if err != nil { return err } c.emit(code.OpIndex) case *ast.FunctionLiteral: c.enterScope() if node.Name != "" { c.symbolTable.DefineFunctionName(node.Name) } for _, p := range node.Parameters { c.symbolTable.Define(p.Value) } err := c.Compile(node.Body) if err != nil { return err } if c.lastInstructionIs(code.OpPop) { c.replaceLastPopWithReturn() } if !c.lastInstructionIs(code.OpReturnValue) { c.emit(code.OpReturn) } freeSymbols := c.symbolTable.FreeSymbols numLocals := c.symbolTable.numDefinitions instructions := c.leaveScope() for _, s := range freeSymbols { c.loadSymbol(s) } compiledFn := &object.CompiledFunction{ Instructions: instructions, NumLocals: numLocals, NumParameters: len(node.Parameters), } fnIndex := c.addConstant(compiledFn) c.emit(code.OpClosure, fnIndex, len(freeSymbols)) case *ast.ReturnStatement: err := c.Compile(node.ReturnValue) if err != nil { return err } c.emit(code.OpReturnValue) case *ast.CallExpression: err := c.Compile(node.Function) if err != nil { return err } for _, a := range node.Arguments { err := c.Compile(a) if err != nil { return err } } c.emit(code.OpCall, len(node.Arguments)) } return nil } func (c *Compiler) Bytecode() *Bytecode { return &Bytecode{ Instructions: c.currentInstructions(), Constants: c.constants, } } func (c *Compiler) addConstant(obj object.Object) int { c.constants = append(c.constants, obj) return len(c.constants) - 1 } func (c *Compiler) emit(op code.Opcode, operands ...int) int { ins := code.Make(op, operands...) pos := c.addInstruction(ins) c.setLastInstruction(op, pos) return pos } func (c *Compiler) addInstruction(ins []byte) int { posNewInstruction := len(c.currentInstructions()) updatedInstructions := append(c.currentInstructions(), ins...) c.scopes[c.scopeIndex].instructions = updatedInstructions return posNewInstruction } func (c *Compiler) setLastInstruction(op code.Opcode, pos int) { previous := c.scopes[c.scopeIndex].lastInstruction last := EmittedInstruction{Opcode: op, Position: pos} c.scopes[c.scopeIndex].previousInstruction = previous c.scopes[c.scopeIndex].lastInstruction = last } func (c *Compiler) lastInstructionIs(op code.Opcode) bool { if len(c.currentInstructions()) == 0 { return false } return c.scopes[c.scopeIndex].lastInstruction.Opcode == op } func (c *Compiler) removeLastPop() { last := c.scopes[c.scopeIndex].lastInstruction previous := c.scopes[c.scopeIndex].previousInstruction old := c.currentInstructions() new := old[:last.Position] c.scopes[c.scopeIndex].instructions = new c.scopes[c.scopeIndex].lastInstruction = previous } func (c *Compiler) replaceInstruction(pos int, newInstruction []byte) { ins := c.currentInstructions() for i := 0; i < len(newInstruction); i++ { ins[pos+i] = newInstruction[i] } } func (c *Compiler) changeOperand(opPos int, operand int) { op := code.Opcode(c.currentInstructions()[opPos]) newInstruction := code.Make(op, operand) c.replaceInstruction(opPos, newInstruction) } func (c *Compiler) currentInstructions() code.Instructions { return c.scopes[c.scopeIndex].instructions } func (c *Compiler) enterScope() { scope := CompilationScope{ instructions: code.Instructions{}, lastInstruction: EmittedInstruction{}, previousInstruction: EmittedInstruction{}, } c.scopes = append(c.scopes, scope) c.scopeIndex++ c.symbolTable = NewEnclosedSymbolTable(c.symbolTable) } func (c *Compiler) leaveScope() code.Instructions { instructions := c.currentInstructions() c.scopes = c.scopes[:len(c.scopes)-1] c.scopeIndex-- c.symbolTable = c.symbolTable.Outer return instructions } func (c *Compiler) replaceLastPopWithReturn() { lastPos := c.scopes[c.scopeIndex].lastInstruction.Position c.replaceInstruction(lastPos, code.Make(code.OpReturnValue)) c.scopes[c.scopeIndex].lastInstruction.Opcode = code.OpReturnValue } func (c *Compiler) loadSymbol(s Symbol) { switch s.Scope { case GlobalScope: c.emit(code.OpGetGlobal, s.Index) case LocalScope: c.emit(code.OpGetLocal, s.Index) case BuiltinScope: c.emit(code.OpGetBuiltin, s.Index) case FreeScope: c.emit(code.OpGetFree, s.Index) case FunctionScope: c.emit(code.OpCurrentClosure) } } type Bytecode struct { Instructions code.Instructions Constants []object.Object } type EmittedInstruction struct { Opcode code.Opcode Position int } type CompilationScope struct { instructions code.Instructions lastInstruction EmittedInstruction previousInstruction EmittedInstruction }