shader.go (18021B)
1 // Copyright 2020 The Ebiten Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package shader 16 17 import ( 18 "fmt" 19 "go/ast" 20 "go/token" 21 "strings" 22 23 "github.com/hajimehoshi/ebiten/v2/internal/shaderir" 24 ) 25 26 type variable struct { 27 name string 28 typ shaderir.Type 29 forLoopCounter bool 30 } 31 32 type constant struct { 33 name string 34 typ shaderir.Type 35 init ast.Expr 36 } 37 38 type function struct { 39 name string 40 block *block 41 42 ir shaderir.Func 43 } 44 45 type compileState struct { 46 fs *token.FileSet 47 48 vertexEntry string 49 fragmentEntry string 50 51 ir shaderir.Program 52 53 funcs []function 54 55 global block 56 57 varyingParsed bool 58 59 errs []string 60 } 61 62 func (cs *compileState) findFunction(name string) (int, bool) { 63 for i, f := range cs.funcs { 64 if f.name == name { 65 return i, true 66 } 67 } 68 return 0, false 69 } 70 71 func (cs *compileState) findUniformVariable(name string) (int, bool) { 72 for i, u := range cs.ir.UniformNames { 73 if u == name { 74 return i, true 75 } 76 } 77 return 0, false 78 } 79 80 type typ struct { 81 name string 82 ir shaderir.Type 83 } 84 85 type block struct { 86 types []typ 87 vars []variable 88 unusedVars map[int]token.Pos 89 consts []constant 90 pos token.Pos 91 outer *block 92 93 ir *shaderir.Block 94 } 95 96 func (b *block) totalLocalVariableNum() int { 97 c := len(b.vars) 98 if b.outer != nil { 99 c += b.outer.totalLocalVariableNum() 100 } 101 return c 102 } 103 104 func (b *block) addNamedLocalVariable(name string, typ shaderir.Type, pos token.Pos) { 105 b.vars = append(b.vars, variable{ 106 name: name, 107 typ: typ, 108 }) 109 if name == "_" { 110 return 111 } 112 idx := len(b.vars) - 1 113 if b.unusedVars == nil { 114 b.unusedVars = map[int]token.Pos{} 115 } 116 b.unusedVars[idx] = pos 117 } 118 119 func (b *block) findLocalVariable(name string, markLocalVariableUsed bool) (int, shaderir.Type, bool) { 120 if name == "" || name == "_" { 121 panic("shader: variable name must be non-empty and non-underscore") 122 } 123 124 idx := 0 125 for outer := b.outer; outer != nil; outer = outer.outer { 126 idx += len(outer.vars) 127 } 128 for i, v := range b.vars { 129 if v.name == name { 130 if markLocalVariableUsed { 131 delete(b.unusedVars, i) 132 } 133 return idx + i, v.typ, true 134 } 135 } 136 if b.outer != nil { 137 return b.outer.findLocalVariable(name, markLocalVariableUsed) 138 } 139 return 0, shaderir.Type{}, false 140 } 141 142 func (b *block) findLocalVariableByIndex(idx int) (shaderir.Type, bool) { 143 bs := []*block{b} 144 for outer := b.outer; outer != nil; outer = outer.outer { 145 bs = append(bs, outer) 146 } 147 for i := len(bs) - 1; i >= 0; i-- { 148 if len(bs[i].vars) <= idx { 149 idx -= len(bs[i].vars) 150 continue 151 } 152 return bs[i].vars[idx].typ, true 153 } 154 return shaderir.Type{}, false 155 } 156 157 type ParseError struct { 158 errs []string 159 } 160 161 func (p *ParseError) Error() string { 162 return strings.Join(p.errs, "\n") 163 } 164 165 func Compile(fs *token.FileSet, f *ast.File, vertexEntry, fragmentEntry string, textureNum int) (*shaderir.Program, error) { 166 s := &compileState{ 167 fs: fs, 168 vertexEntry: vertexEntry, 169 fragmentEntry: fragmentEntry, 170 } 171 s.global.ir = &shaderir.Block{} 172 s.parse(f) 173 174 if len(s.errs) > 0 { 175 return nil, &ParseError{s.errs} 176 } 177 178 // TODO: Resolve identifiers? 179 // TODO: Resolve constants 180 181 // TODO: Make a call graph and reorder the elements. 182 183 s.ir.TextureNum = textureNum 184 return &s.ir, nil 185 } 186 187 func (s *compileState) addError(pos token.Pos, str string) { 188 p := s.fs.Position(pos) 189 s.errs = append(s.errs, fmt.Sprintf("%s: %s", p, str)) 190 } 191 192 func (cs *compileState) parse(f *ast.File) { 193 // Parse GenDecl for global variables, and then parse functions. 194 for _, d := range f.Decls { 195 if _, ok := d.(*ast.FuncDecl); !ok { 196 ss, ok := cs.parseDecl(&cs.global, d) 197 if !ok { 198 return 199 } 200 cs.global.ir.Stmts = append(cs.global.ir.Stmts, ss...) 201 } 202 } 203 204 // Sort the uniform variable so that special variable starting with __ should come first. 205 var unames []string 206 var utypes []shaderir.Type 207 for i, u := range cs.ir.UniformNames { 208 if strings.HasPrefix(u, "__") { 209 unames = append(unames, u) 210 utypes = append(utypes, cs.ir.Uniforms[i]) 211 } 212 } 213 // TODO: Check len(unames) == graphics.PreservedUniformVariablesNum. Unfortunately this is not true on tests. 214 for i, u := range cs.ir.UniformNames { 215 if !strings.HasPrefix(u, "__") { 216 unames = append(unames, u) 217 utypes = append(utypes, cs.ir.Uniforms[i]) 218 } 219 } 220 cs.ir.UniformNames = unames 221 cs.ir.Uniforms = utypes 222 223 // Parse function names so that any other function call the others. 224 // The function data is provisional and will be updated soon. 225 for _, d := range f.Decls { 226 fd, ok := d.(*ast.FuncDecl) 227 if !ok { 228 continue 229 } 230 n := fd.Name.Name 231 if n == cs.vertexEntry { 232 continue 233 } 234 if n == cs.fragmentEntry { 235 continue 236 } 237 238 inParams, outParams := cs.parseFuncParams(&cs.global, fd) 239 var inT, outT []shaderir.Type 240 for _, v := range inParams { 241 inT = append(inT, v.typ) 242 } 243 for _, v := range outParams { 244 outT = append(outT, v.typ) 245 } 246 247 cs.funcs = append(cs.funcs, function{ 248 name: n, 249 ir: shaderir.Func{ 250 Index: len(cs.funcs), 251 InParams: inT, 252 OutParams: outT, 253 Block: &shaderir.Block{}, 254 }, 255 }) 256 } 257 258 // Parse functions. 259 for _, d := range f.Decls { 260 if _, ok := d.(*ast.FuncDecl); ok { 261 ss, ok := cs.parseDecl(&cs.global, d) 262 if !ok { 263 return 264 } 265 cs.global.ir.Stmts = append(cs.global.ir.Stmts, ss...) 266 } 267 } 268 269 if len(cs.errs) > 0 { 270 return 271 } 272 273 for _, f := range cs.funcs { 274 cs.ir.Funcs = append(cs.ir.Funcs, f.ir) 275 } 276 } 277 278 func (cs *compileState) parseDecl(b *block, d ast.Decl) ([]shaderir.Stmt, bool) { 279 var stmts []shaderir.Stmt 280 281 switch d := d.(type) { 282 case *ast.GenDecl: 283 switch d.Tok { 284 case token.TYPE: 285 // TODO: Parse other types 286 for _, s := range d.Specs { 287 s := s.(*ast.TypeSpec) 288 t, ok := cs.parseType(b, s.Type) 289 if !ok { 290 return nil, false 291 } 292 b.types = append(b.types, typ{ 293 name: s.Name.Name, 294 ir: t, 295 }) 296 } 297 case token.CONST: 298 for _, s := range d.Specs { 299 s := s.(*ast.ValueSpec) 300 cs := cs.parseConstant(b, s) 301 b.consts = append(b.consts, cs...) 302 } 303 case token.VAR: 304 for _, s := range d.Specs { 305 s := s.(*ast.ValueSpec) 306 vs, inits, ss, ok := cs.parseVariable(b, s) 307 if !ok { 308 return nil, false 309 } 310 stmts = append(stmts, ss...) 311 if b == &cs.global { 312 // TODO: Should rhs be ignored? 313 for i, v := range vs { 314 if !strings.HasPrefix(v.name, "__") { 315 if v.name[0] < 'A' || 'Z' < v.name[0] { 316 cs.addError(s.Names[i].Pos(), fmt.Sprintf("global variables must be exposed: %s", v.name)) 317 } 318 } 319 cs.ir.UniformNames = append(cs.ir.UniformNames, v.name) 320 cs.ir.Uniforms = append(cs.ir.Uniforms, v.typ) 321 } 322 continue 323 } 324 325 // base must be obtained before adding the variables. 326 base := b.totalLocalVariableNum() 327 for _, v := range vs { 328 b.addNamedLocalVariable(v.name, v.typ, d.Pos()) 329 } 330 331 if len(inits) > 0 { 332 for i := range vs { 333 stmts = append(stmts, shaderir.Stmt{ 334 Type: shaderir.Assign, 335 Exprs: []shaderir.Expr{ 336 { 337 Type: shaderir.LocalVariable, 338 Index: base + i, 339 }, 340 inits[i], 341 }, 342 }) 343 } 344 } 345 } 346 case token.IMPORT: 347 cs.addError(d.Pos(), "import is forbidden") 348 default: 349 cs.addError(d.Pos(), "unexpected token") 350 } 351 case *ast.FuncDecl: 352 f, ok := cs.parseFunc(b, d) 353 if !ok { 354 return nil, false 355 } 356 if b != &cs.global { 357 cs.addError(d.Pos(), "non-global function is not implemented") 358 return nil, false 359 } 360 switch d.Name.Name { 361 case cs.vertexEntry: 362 cs.ir.VertexFunc.Block = f.ir.Block 363 case cs.fragmentEntry: 364 cs.ir.FragmentFunc.Block = f.ir.Block 365 default: 366 // The function is already registered for their names. 367 for i := range cs.funcs { 368 if cs.funcs[i].name == d.Name.Name { 369 // Index is already determined by the provisional parsing. 370 f.ir.Index = cs.funcs[i].ir.Index 371 cs.funcs[i] = f 372 break 373 } 374 } 375 } 376 default: 377 cs.addError(d.Pos(), "unexpected decl") 378 return nil, false 379 } 380 381 return stmts, true 382 } 383 384 // functionReturnTypes returns the original returning value types, if the given expression is call. 385 // 386 // Note that parseExpr returns the returning types for IR, not the original function. 387 func (cs *compileState) functionReturnTypes(block *block, expr ast.Expr) ([]shaderir.Type, bool) { 388 call, ok := expr.(*ast.CallExpr) 389 if !ok { 390 return nil, false 391 } 392 393 ident, ok := call.Fun.(*ast.Ident) 394 if !ok { 395 return nil, false 396 } 397 398 for _, f := range cs.funcs { 399 if f.name == ident.Name { 400 // TODO: Is it correct to combine out-params and return param? 401 ts := f.ir.OutParams 402 if f.ir.Return.Main != shaderir.None { 403 ts = append(ts, f.ir.Return) 404 } 405 return ts, true 406 } 407 } 408 return nil, false 409 } 410 411 func (s *compileState) parseVariable(block *block, vs *ast.ValueSpec) ([]variable, []shaderir.Expr, []shaderir.Stmt, bool) { 412 if len(vs.Names) != len(vs.Values) && len(vs.Values) != 1 && len(vs.Values) != 0 { 413 s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) 414 return nil, nil, nil, false 415 } 416 417 var declt shaderir.Type 418 if vs.Type != nil { 419 var ok bool 420 declt, ok = s.parseType(block, vs.Type) 421 if !ok { 422 return nil, nil, nil, false 423 } 424 } 425 426 var ( 427 vars []variable 428 inits []shaderir.Expr 429 stmts []shaderir.Stmt 430 ) 431 432 // These variables are used only in multiple-value context. 433 var inittypes []shaderir.Type 434 var initexprs []shaderir.Expr 435 436 for i, n := range vs.Names { 437 t := declt 438 switch { 439 case len(vs.Values) == 0: 440 // No initialization 441 442 case len(vs.Names) == len(vs.Values): 443 // Single-value context 444 445 init := vs.Values[i] 446 447 es, origts, ss, ok := s.parseExpr(block, init, true) 448 if !ok { 449 return nil, nil, nil, false 450 } 451 452 if t.Main == shaderir.None { 453 ts, ok := s.functionReturnTypes(block, init) 454 if !ok { 455 ts = origts 456 } 457 if len(ts) > 1 { 458 s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) 459 } 460 t = ts[0] 461 } 462 463 if es[0].Type == shaderir.NumberExpr { 464 switch t.Main { 465 case shaderir.Int: 466 es[0].ConstType = shaderir.ConstTypeInt 467 case shaderir.Float: 468 es[0].ConstType = shaderir.ConstTypeFloat 469 } 470 } 471 472 inits = append(inits, es...) 473 stmts = append(stmts, ss...) 474 475 default: 476 // Multiple-value context 477 478 if i == 0 { 479 init := vs.Values[0] 480 481 var ss []shaderir.Stmt 482 var ok bool 483 initexprs, inittypes, ss, ok = s.parseExpr(block, init, true) 484 if !ok { 485 return nil, nil, nil, false 486 } 487 stmts = append(stmts, ss...) 488 489 if t.Main == shaderir.None { 490 ts, ok := s.functionReturnTypes(block, init) 491 if ok { 492 inittypes = ts 493 } 494 if len(ts) != len(vs.Names) { 495 s.addError(vs.Pos(), fmt.Sprintf("the numbers of lhs and rhs don't match")) 496 continue 497 } 498 } 499 } 500 if len(inittypes) > 0 { 501 t = inittypes[i] 502 } 503 504 // Add the same initexprs for each variable. 505 inits = append(inits, initexprs...) 506 } 507 508 name := n.Name 509 for _, v := range append(block.vars, vars...) { 510 if v.name == name { 511 s.addError(vs.Pos(), fmt.Sprintf("duplicated local variable name: %s", name)) 512 return nil, nil, nil, false 513 } 514 } 515 vars = append(vars, variable{ 516 name: name, 517 typ: t, 518 }) 519 } 520 521 return vars, inits, stmts, true 522 } 523 524 func (s *compileState) parseConstant(block *block, vs *ast.ValueSpec) []constant { 525 var t shaderir.Type 526 if vs.Type != nil { 527 var ok bool 528 t, ok = s.parseType(block, vs.Type) 529 if !ok { 530 return nil 531 } 532 } 533 534 var cs []constant 535 for i, n := range vs.Names { 536 cs = append(cs, constant{ 537 name: n.Name, 538 typ: t, 539 init: vs.Values[i], 540 }) 541 } 542 return cs 543 } 544 545 func (cs *compileState) parseFuncParams(block *block, d *ast.FuncDecl) (in, out []variable) { 546 for _, f := range d.Type.Params.List { 547 t, ok := cs.parseType(block, f.Type) 548 if !ok { 549 return 550 } 551 for _, n := range f.Names { 552 in = append(in, variable{ 553 name: n.Name, 554 typ: t, 555 }) 556 } 557 } 558 559 if d.Type.Results == nil { 560 return 561 } 562 563 for _, f := range d.Type.Results.List { 564 t, ok := cs.parseType(block, f.Type) 565 if !ok { 566 return 567 } 568 if len(f.Names) == 0 { 569 out = append(out, variable{ 570 name: "", 571 typ: t, 572 }) 573 } else { 574 for _, n := range f.Names { 575 out = append(out, variable{ 576 name: n.Name, 577 typ: t, 578 }) 579 } 580 } 581 } 582 return 583 } 584 585 func (cs *compileState) parseFunc(block *block, d *ast.FuncDecl) (function, bool) { 586 if d.Name == nil { 587 cs.addError(d.Pos(), "function must have a name") 588 return function{}, false 589 } 590 if d.Name.Name == "init" { 591 cs.addError(d.Pos(), "init function is not implemented") 592 return function{}, false 593 } 594 if d.Body == nil { 595 cs.addError(d.Pos(), "function must have a body") 596 return function{}, false 597 } 598 599 inParams, outParams := cs.parseFuncParams(block, d) 600 601 checkVaryings := func(vs []variable) { 602 if len(cs.ir.Varyings) != len(vs) { 603 cs.addError(d.Pos(), fmt.Sprintf("the number of vertex entry point's returning values and the number of framgent entry point's params must be the same")) 604 return 605 } 606 for i, t := range cs.ir.Varyings { 607 if t.Main != vs[i].typ.Main { 608 cs.addError(d.Pos(), fmt.Sprintf("vertex entry point's returning value types and framgent entry point's param types must match")) 609 } 610 } 611 } 612 613 if block == &cs.global { 614 switch d.Name.Name { 615 case cs.vertexEntry: 616 for _, v := range inParams { 617 cs.ir.Attributes = append(cs.ir.Attributes, v.typ) 618 } 619 620 // The first out-param is treated as gl_Position in GLSL. 621 if len(outParams) == 0 { 622 cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position")) 623 return function{}, false 624 } 625 if outParams[0].typ.Main != shaderir.Vec4 { 626 cs.addError(d.Pos(), fmt.Sprintf("vertex entry point must have at least one returning vec4 value for a position")) 627 return function{}, false 628 } 629 630 if cs.varyingParsed { 631 checkVaryings(outParams[1:]) 632 } else { 633 for _, v := range outParams[1:] { 634 // TODO: Check that these params are not arrays or structs 635 cs.ir.Varyings = append(cs.ir.Varyings, v.typ) 636 } 637 } 638 cs.varyingParsed = true 639 case cs.fragmentEntry: 640 if len(inParams) == 0 { 641 cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position")) 642 return function{}, false 643 } 644 if inParams[0].typ.Main != shaderir.Vec4 { 645 cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have at least one vec4 parameter for a position")) 646 return function{}, false 647 } 648 649 if len(outParams) != 1 { 650 cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color")) 651 return function{}, false 652 } 653 if outParams[0].typ.Main != shaderir.Vec4 { 654 cs.addError(d.Pos(), fmt.Sprintf("fragment entry point must have one returning vec4 value for a color")) 655 return function{}, false 656 } 657 658 if cs.varyingParsed { 659 checkVaryings(inParams[1:]) 660 } else { 661 for _, v := range inParams[1:] { 662 cs.ir.Varyings = append(cs.ir.Varyings, v.typ) 663 } 664 } 665 cs.varyingParsed = true 666 } 667 } 668 669 b, ok := cs.parseBlock(block, d.Name.Name, d.Body.List, inParams, outParams, true) 670 if !ok { 671 return function{}, false 672 } 673 674 if len(outParams) > 0 { 675 var hasReturn func(stmts []shaderir.Stmt) bool 676 hasReturn = func(stmts []shaderir.Stmt) bool { 677 for _, stmt := range stmts { 678 if stmt.Type == shaderir.Return { 679 return true 680 } 681 for _, b := range stmt.Blocks { 682 if hasReturn(b.Stmts) { 683 return true 684 } 685 } 686 } 687 return false 688 } 689 690 if !hasReturn(b.ir.Stmts) { 691 cs.addError(d.Pos(), fmt.Sprintf("function %s must have a return statement but not", d.Name)) 692 return function{}, false 693 } 694 } 695 696 var inT, outT []shaderir.Type 697 for _, v := range inParams { 698 inT = append(inT, v.typ) 699 } 700 for _, v := range outParams { 701 outT = append(outT, v.typ) 702 } 703 704 return function{ 705 name: d.Name.Name, 706 block: b, 707 ir: shaderir.Func{ 708 InParams: inT, 709 OutParams: outT, 710 Block: b.ir, 711 }, 712 }, true 713 } 714 715 func (cs *compileState) parseBlock(outer *block, fname string, stmts []ast.Stmt, inParams, outParams []variable, checkLocalVariableUsage bool) (*block, bool) { 716 var vars []variable 717 if outer == &cs.global { 718 vars = make([]variable, 0, len(inParams)+len(outParams)) 719 vars = append(vars, inParams...) 720 vars = append(vars, outParams...) 721 } 722 723 var offset int 724 for b := outer; b != nil; b = b.outer { 725 offset += len(b.vars) 726 } 727 if outer == &cs.global { 728 offset += len(inParams) + len(outParams) 729 } 730 731 block := &block{ 732 vars: vars, 733 outer: outer, 734 ir: &shaderir.Block{ 735 LocalVarIndexOffset: offset, 736 }, 737 } 738 739 defer func() { 740 var offset int 741 if outer == &cs.global { 742 offset = len(inParams) + len(outParams) 743 } 744 for _, v := range block.vars[offset:] { 745 if v.forLoopCounter { 746 block.ir.LocalVars = append(block.ir.LocalVars, shaderir.Type{}) 747 continue 748 } 749 block.ir.LocalVars = append(block.ir.LocalVars, v.typ) 750 } 751 }() 752 753 if outer.outer == nil && len(outParams) > 0 && outParams[0].name != "" { 754 for i := range outParams { 755 block.ir.Stmts = append(block.ir.Stmts, shaderir.Stmt{ 756 Type: shaderir.Init, 757 InitIndex: len(inParams) + i, 758 }) 759 } 760 } 761 762 for _, stmt := range stmts { 763 ss, ok := cs.parseStmt(block, fname, stmt, inParams, outParams) 764 if !ok { 765 return nil, false 766 } 767 block.ir.Stmts = append(block.ir.Stmts, ss...) 768 } 769 770 if checkLocalVariableUsage && len(block.unusedVars) > 0 { 771 for idx, pos := range block.unusedVars { 772 cs.addError(pos, fmt.Sprintf("local variable %s is not used", block.vars[idx].name)) 773 } 774 return nil, false 775 } 776 777 return block, true 778 }