zorldo

Goofing around with Ebiten
git clone git://bsandro.tech/zorldo
Log | Files | Refs | README

metal.go (15845B)


      1 // Copyright 2020 The Ebiten Authors
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 package metal
     16 
     17 import (
     18 	"fmt"
     19 	"go/constant"
     20 	"go/token"
     21 	"regexp"
     22 	"strings"
     23 
     24 	"github.com/hajimehoshi/ebiten/v2/internal/shaderir"
     25 )
     26 
     27 const (
     28 	vertexOut   = "varyings"
     29 	fragmentOut = "out"
     30 )
     31 
// compileContext holds the state accumulated while compiling one program.
// Struct types get their Metal names lazily, as they are encountered during conversion.
type compileContext struct {
	// structNames maps a struct type's String() form to its generated Metal name (S0, S1, ...).
	structNames map[string]string
	// structTypes lists the registered struct types in registration order; the i-th entry
	// is the type named S<i>.
	structTypes []shaderir.Type
}
     36 
     37 func (c *compileContext) structName(p *shaderir.Program, t *shaderir.Type) string {
     38 	if t.Main != shaderir.Struct {
     39 		panic("metal: the given type at structName must be a struct")
     40 	}
     41 	s := t.String()
     42 	if n, ok := c.structNames[s]; ok {
     43 		return n
     44 	}
     45 	n := fmt.Sprintf("S%d", len(c.structNames))
     46 	c.structNames[s] = n
     47 	c.structTypes = append(c.structTypes, *t)
     48 	return n
     49 }
     50 
     51 const Prelude = `#include <metal_stdlib>
     52 
     53 using namespace metal;
     54 
     55 constexpr sampler texture_sampler{filter::nearest};`
     56 
     57 func Compile(p *shaderir.Program, vertex, fragment string) (shader string) {
     58 	c := &compileContext{
     59 		structNames: map[string]string{},
     60 	}
     61 
     62 	var lines []string
     63 	lines = append(lines, strings.Split(Prelude, "\n")...)
     64 	lines = append(lines, "", "{{.Structs}}")
     65 
     66 	if len(p.Attributes) > 0 {
     67 		lines = append(lines, "")
     68 		lines = append(lines, "struct Attributes {")
     69 		for i, a := range p.Attributes {
     70 			lines = append(lines, fmt.Sprintf("\t%s;", c.metalVarDecl(p, &a, fmt.Sprintf("M%d", i), true, false)))
     71 		}
     72 		lines = append(lines, "};")
     73 	}
     74 
     75 	if len(p.Varyings) > 0 {
     76 		lines = append(lines, "")
     77 		lines = append(lines, "struct Varyings {")
     78 		lines = append(lines, "\tfloat4 Position [[position]];")
     79 		for i, v := range p.Varyings {
     80 			lines = append(lines, fmt.Sprintf("\t%s;", c.metalVarDecl(p, &v, fmt.Sprintf("M%d", i), false, false)))
     81 		}
     82 		lines = append(lines, "};")
     83 	}
     84 
     85 	if len(p.Funcs) > 0 {
     86 		lines = append(lines, "")
     87 		for _, f := range p.Funcs {
     88 			lines = append(lines, c.metalFunc(p, &f, true)...)
     89 		}
     90 		for _, f := range p.Funcs {
     91 			if len(lines) > 0 && lines[len(lines)-1] != "" {
     92 				lines = append(lines, "")
     93 			}
     94 			lines = append(lines, c.metalFunc(p, &f, false)...)
     95 		}
     96 	}
     97 
     98 	if p.VertexFunc.Block != nil && len(p.VertexFunc.Block.Stmts) > 0 {
     99 		lines = append(lines, "")
    100 		lines = append(lines,
    101 			fmt.Sprintf("vertex Varyings %s(", vertex),
    102 			"\tuint vid [[vertex_id]],",
    103 			"\tconst device Attributes* attributes [[buffer(0)]]")
    104 		for i, u := range p.Uniforms {
    105 			lines[len(lines)-1] += ","
    106 			lines = append(lines, fmt.Sprintf("\tconstant %s [[buffer(%d)]]", c.metalVarDecl(p, &u, fmt.Sprintf("U%d", i), false, true), i+1))
    107 		}
    108 		for i := 0; i < p.TextureNum; i++ {
    109 			lines[len(lines)-1] += ","
    110 			lines = append(lines, fmt.Sprintf("\ttexture2d<float> T%[1]d [[texture(%[1]d)]]", i))
    111 		}
    112 		lines[len(lines)-1] += ") {"
    113 		lines = append(lines, fmt.Sprintf("\tVaryings %s = {};", vertexOut))
    114 		lines = append(lines, c.metalBlock(p, p.VertexFunc.Block, p.VertexFunc.Block, 0)...)
    115 		if last := fmt.Sprintf("\treturn %s;", vertexOut); lines[len(lines)-1] != last {
    116 			lines = append(lines, last)
    117 		}
    118 		lines = append(lines, "}")
    119 	}
    120 
    121 	if p.FragmentFunc.Block != nil && len(p.FragmentFunc.Block.Stmts) > 0 {
    122 		lines = append(lines, "")
    123 		lines = append(lines,
    124 			fmt.Sprintf("fragment float4 %s(", fragment),
    125 			"\tVaryings varyings [[stage_in]]")
    126 		for i, u := range p.Uniforms {
    127 			lines[len(lines)-1] += ","
    128 			lines = append(lines, fmt.Sprintf("\tconstant %s [[buffer(%d)]]", c.metalVarDecl(p, &u, fmt.Sprintf("U%d", i), false, true), i+1))
    129 		}
    130 		for i := 0; i < p.TextureNum; i++ {
    131 			lines[len(lines)-1] += ","
    132 			lines = append(lines, fmt.Sprintf("\ttexture2d<float> T%[1]d [[texture(%[1]d)]]", i))
    133 		}
    134 		lines[len(lines)-1] += ") {"
    135 		lines = append(lines, fmt.Sprintf("\tfloat4 %s = float4(0);", fragmentOut))
    136 		lines = append(lines, c.metalBlock(p, p.FragmentFunc.Block, p.FragmentFunc.Block, 0)...)
    137 		if last := fmt.Sprintf("\treturn %s;", fragmentOut); lines[len(lines)-1] != last {
    138 			lines = append(lines, last)
    139 		}
    140 		lines = append(lines, "}")
    141 	}
    142 
    143 	ls := strings.Join(lines, "\n")
    144 
    145 	// Struct types are determined after converting the program.
    146 	if len(c.structTypes) > 0 {
    147 		var stlines []string
    148 		for i, t := range c.structTypes {
    149 			stlines = append(stlines, fmt.Sprintf("struct S%d {", i))
    150 			for j, st := range t.Sub {
    151 				stlines = append(stlines, fmt.Sprintf("\t%s;", c.metalVarDecl(p, &st, fmt.Sprintf("M%d", j), false, false)))
    152 			}
    153 			stlines = append(stlines, "};")
    154 		}
    155 		ls = strings.ReplaceAll(ls, "{{.Structs}}", strings.Join(stlines, "\n"))
    156 	} else {
    157 		ls = strings.ReplaceAll(ls, "{{.Structs}}", "")
    158 	}
    159 
    160 	nls := regexp.MustCompile(`\n\n+`)
    161 	ls = nls.ReplaceAllString(ls, "\n\n")
    162 	ls = strings.TrimSpace(ls) + "\n"
    163 
    164 	return ls
    165 }
    166 
    167 func (c *compileContext) metalType(p *shaderir.Program, t *shaderir.Type, packed bool, ref bool) string {
    168 	switch t.Main {
    169 	case shaderir.None:
    170 		return "void"
    171 	case shaderir.Struct:
    172 		return c.structName(p, t)
    173 	default:
    174 		return typeString(t, packed, ref)
    175 	}
    176 }
    177 
    178 func (c *compileContext) metalVarDecl(p *shaderir.Program, t *shaderir.Type, varname string, packed bool, ref bool) string {
    179 	switch t.Main {
    180 	case shaderir.None:
    181 		return "?(none)"
    182 	case shaderir.Struct:
    183 		s := c.structName(p, t)
    184 		if ref {
    185 			s += "&"
    186 		}
    187 		return fmt.Sprintf("%s %s", s, varname)
    188 	default:
    189 		t := typeString(t, packed, ref)
    190 		return fmt.Sprintf("%s %s", t, varname)
    191 	}
    192 }
    193 
    194 func (c *compileContext) metalVarInit(p *shaderir.Program, t *shaderir.Type) string {
    195 	switch t.Main {
    196 	case shaderir.None:
    197 		return "?(none)"
    198 	case shaderir.Array:
    199 		return "{}"
    200 	case shaderir.Struct:
    201 		return "{}"
    202 	case shaderir.Bool:
    203 		return "false"
    204 	case shaderir.Int:
    205 		return "0"
    206 	case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4:
    207 		return fmt.Sprintf("%s(0)", basicTypeString(t.Main, false))
    208 	default:
    209 		t := c.metalType(p, t, false, false)
    210 		panic(fmt.Sprintf("?(unexpected type: %s)", t))
    211 	}
    212 }
    213 
    214 func (c *compileContext) metalFunc(p *shaderir.Program, f *shaderir.Func, prototype bool) []string {
    215 	var args []string
    216 
    217 	// Uniform variables and texture variables. In Metal, non-const global variables are not available.
    218 	for i, u := range p.Uniforms {
    219 		args = append(args, "constant "+c.metalVarDecl(p, &u, fmt.Sprintf("U%d", i), false, true))
    220 	}
    221 	for i := 0; i < p.TextureNum; i++ {
    222 		args = append(args, fmt.Sprintf("texture2d<float> T%d", i))
    223 	}
    224 
    225 	var idx int
    226 	for _, t := range f.InParams {
    227 		args = append(args, c.metalVarDecl(p, &t, fmt.Sprintf("l%d", idx), false, false))
    228 		idx++
    229 	}
    230 	for _, t := range f.OutParams {
    231 		args = append(args, "thread "+c.metalVarDecl(p, &t, fmt.Sprintf("l%d", idx), false, true))
    232 		idx++
    233 	}
    234 	argsstr := "void"
    235 	if len(args) > 0 {
    236 		argsstr = strings.Join(args, ", ")
    237 	}
    238 
    239 	t := c.metalType(p, &f.Return, false, false)
    240 	sig := fmt.Sprintf("%s F%d(%s)", t, f.Index, argsstr)
    241 
    242 	var lines []string
    243 	if prototype {
    244 		lines = append(lines, fmt.Sprintf("%s;", sig))
    245 		return lines
    246 	}
    247 	lines = append(lines, fmt.Sprintf("%s {", sig))
    248 	lines = append(lines, c.metalBlock(p, f.Block, f.Block, 0)...)
    249 	lines = append(lines, "}")
    250 
    251 	return lines
    252 }
    253 
    254 func constantToNumberLiteral(t shaderir.ConstType, v constant.Value) string {
    255 	switch t {
    256 	case shaderir.ConstTypeNone:
    257 		if v.Kind() == constant.Bool {
    258 			if constant.BoolVal(v) {
    259 				return "true"
    260 			}
    261 			return "false"
    262 		}
    263 		fallthrough
    264 	case shaderir.ConstTypeFloat:
    265 		if i := constant.ToInt(v); i.Kind() == constant.Int {
    266 			x, _ := constant.Int64Val(i)
    267 			return fmt.Sprintf("%d.0", x)
    268 		}
    269 		if i := constant.ToFloat(v); i.Kind() == constant.Float {
    270 			x, _ := constant.Float64Val(i)
    271 			return fmt.Sprintf("%.10e", x)
    272 		}
    273 	case shaderir.ConstTypeInt:
    274 		if i := constant.ToInt(v); i.Kind() == constant.Int {
    275 			x, _ := constant.Int64Val(i)
    276 			return fmt.Sprintf("%d", x)
    277 		}
    278 	}
    279 	return fmt.Sprintf("?(unexpected literal: %s)", v)
    280 }
    281 
    282 func localVariableName(p *shaderir.Program, topBlock *shaderir.Block, idx int) string {
    283 	switch topBlock {
    284 	case p.VertexFunc.Block:
    285 		na := len(p.Attributes)
    286 		nv := len(p.Varyings)
    287 		switch {
    288 		case idx < na:
    289 			return fmt.Sprintf("attributes[vid].M%d", idx)
    290 		case idx == na:
    291 			return fmt.Sprintf("%s.Position", vertexOut)
    292 		case idx < na+nv+1:
    293 			return fmt.Sprintf("%s.M%d", vertexOut, idx-na-1)
    294 		default:
    295 			return fmt.Sprintf("l%d", idx-(na+nv+1))
    296 		}
    297 	case p.FragmentFunc.Block:
    298 		nv := len(p.Varyings)
    299 		switch {
    300 		case idx == 0:
    301 			return fmt.Sprintf("varyings.Position")
    302 		case idx < nv+1:
    303 			return fmt.Sprintf("varyings.M%d", idx-1)
    304 		case idx == nv+1:
    305 			return fragmentOut
    306 		default:
    307 			return fmt.Sprintf("l%d", idx-(nv+2))
    308 		}
    309 	default:
    310 		return fmt.Sprintf("l%d", idx)
    311 	}
    312 }
    313 
    314 func (c *compileContext) initVariable(p *shaderir.Program, topBlock, block *shaderir.Block, index int, decl bool, level int) []string {
    315 	idt := strings.Repeat("\t", level+1)
    316 	name := localVariableName(p, topBlock, index)
    317 	t := p.LocalVariableType(topBlock, block, index)
    318 
    319 	var lines []string
    320 	if decl {
    321 		lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, c.metalVarDecl(p, &t, name, false, false), c.metalVarInit(p, &t)))
    322 	} else {
    323 		lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, name, c.metalVarInit(p, &t)))
    324 	}
    325 	return lines
    326 }
    327 
    328 func (c *compileContext) metalBlock(p *shaderir.Program, topBlock, block *shaderir.Block, level int) []string {
    329 	if block == nil {
    330 		return nil
    331 	}
    332 
    333 	idt := strings.Repeat("\t", level+1)
    334 
    335 	var lines []string
    336 	for i, t := range block.LocalVars {
    337 		// The type is None e.g., when the variable is a for-loop counter.
    338 		if t.Main != shaderir.None {
    339 			lines = append(lines, c.initVariable(p, topBlock, block, block.LocalVarIndexOffset+i, true, level)...)
    340 		}
    341 	}
    342 
    343 	var metalExpr func(e *shaderir.Expr) string
    344 	metalExpr = func(e *shaderir.Expr) string {
    345 		switch e.Type {
    346 		case shaderir.NumberExpr:
    347 			return constantToNumberLiteral(e.ConstType, e.Const)
    348 		case shaderir.UniformVariable:
    349 			return fmt.Sprintf("U%d", e.Index)
    350 		case shaderir.TextureVariable:
    351 			return fmt.Sprintf("T%d", e.Index)
    352 		case shaderir.LocalVariable:
    353 			return localVariableName(p, topBlock, e.Index)
    354 		case shaderir.StructMember:
    355 			return fmt.Sprintf("M%d", e.Index)
    356 		case shaderir.BuiltinFuncExpr:
    357 			return builtinFuncString(e.BuiltinFunc)
    358 		case shaderir.SwizzlingExpr:
    359 			if !shaderir.IsValidSwizzling(e.Swizzling) {
    360 				return fmt.Sprintf("?(unexpected swizzling: %s)", e.Swizzling)
    361 			}
    362 			return e.Swizzling
    363 		case shaderir.FunctionExpr:
    364 			return fmt.Sprintf("F%d", e.Index)
    365 		case shaderir.Unary:
    366 			var op string
    367 			switch e.Op {
    368 			case shaderir.Add, shaderir.Sub, shaderir.NotOp:
    369 				op = string(e.Op)
    370 			default:
    371 				op = fmt.Sprintf("?(unexpected op: %s)", string(e.Op))
    372 			}
    373 			return fmt.Sprintf("%s(%s)", op, metalExpr(&e.Exprs[0]))
    374 		case shaderir.Binary:
    375 			return fmt.Sprintf("(%s) %s (%s)", metalExpr(&e.Exprs[0]), e.Op, metalExpr(&e.Exprs[1]))
    376 		case shaderir.Selection:
    377 			return fmt.Sprintf("(%s) ? (%s) : (%s)", metalExpr(&e.Exprs[0]), metalExpr(&e.Exprs[1]), metalExpr(&e.Exprs[2]))
    378 		case shaderir.Call:
    379 			callee := e.Exprs[0]
    380 			var args []string
    381 			if callee.Type != shaderir.BuiltinFuncExpr {
    382 				for i := range p.Uniforms {
    383 					args = append(args, fmt.Sprintf("U%d", i))
    384 				}
    385 				for i := 0; i < p.TextureNum; i++ {
    386 					args = append(args, fmt.Sprintf("T%d", i))
    387 				}
    388 			}
    389 			for _, exp := range e.Exprs[1:] {
    390 				args = append(args, metalExpr(&exp))
    391 			}
    392 			if callee.Type == shaderir.BuiltinFuncExpr && callee.BuiltinFunc == shaderir.Texture2DF {
    393 				return fmt.Sprintf("%s.sample(texture_sampler, %s)", args[0], strings.Join(args[1:], ", "))
    394 			}
    395 			return fmt.Sprintf("%s(%s)", metalExpr(&callee), strings.Join(args, ", "))
    396 		case shaderir.FieldSelector:
    397 			return fmt.Sprintf("(%s).%s", metalExpr(&e.Exprs[0]), metalExpr(&e.Exprs[1]))
    398 		case shaderir.Index:
    399 			return fmt.Sprintf("(%s)[%s]", metalExpr(&e.Exprs[0]), metalExpr(&e.Exprs[1]))
    400 		default:
    401 			return fmt.Sprintf("?(unexpected expr: %d)", e.Type)
    402 		}
    403 	}
    404 
    405 	for _, s := range block.Stmts {
    406 		switch s.Type {
    407 		case shaderir.ExprStmt:
    408 			lines = append(lines, fmt.Sprintf("%s%s;", idt, metalExpr(&s.Exprs[0])))
    409 		case shaderir.BlockStmt:
    410 			lines = append(lines, idt+"{")
    411 			lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[0], level+1)...)
    412 			lines = append(lines, idt+"}")
    413 		case shaderir.Assign:
    414 			lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, metalExpr(&s.Exprs[0]), metalExpr(&s.Exprs[1])))
    415 		case shaderir.Init:
    416 			init := true
    417 			if topBlock == p.VertexFunc.Block {
    418 				// In the vertex function, varying values are the output parameters.
    419 				// These values are represented as a struct and not needed to be initialized.
    420 				na := len(p.Attributes)
    421 				nv := len(p.Varyings)
    422 				if s.InitIndex < na+nv+1 {
    423 					init = false
    424 				}
    425 			}
    426 			if init {
    427 				lines = append(lines, c.initVariable(p, topBlock, block, s.InitIndex, false, level)...)
    428 			}
    429 		case shaderir.If:
    430 			lines = append(lines, fmt.Sprintf("%sif (%s) {", idt, metalExpr(&s.Exprs[0])))
    431 			lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[0], level+1)...)
    432 			if len(s.Blocks) > 1 {
    433 				lines = append(lines, fmt.Sprintf("%s} else {", idt))
    434 				lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[1], level+1)...)
    435 			}
    436 			lines = append(lines, fmt.Sprintf("%s}", idt))
    437 		case shaderir.For:
    438 			var ct shaderir.ConstType
    439 			switch s.ForVarType.Main {
    440 			case shaderir.Int:
    441 				ct = shaderir.ConstTypeInt
    442 			case shaderir.Float:
    443 				ct = shaderir.ConstTypeFloat
    444 			}
    445 
    446 			v := localVariableName(p, topBlock, s.ForVarIndex)
    447 			var delta string
    448 			switch val, _ := constant.Float64Val(s.ForDelta); val {
    449 			case 0:
    450 				delta = fmt.Sprintf("?(unexpected delta: %v)", s.ForDelta)
    451 			case 1:
    452 				delta = fmt.Sprintf("%s++", v)
    453 			case -1:
    454 				delta = fmt.Sprintf("%s--", v)
    455 			default:
    456 				d := s.ForDelta
    457 				if val > 0 {
    458 					delta = fmt.Sprintf("%s += %s", v, constantToNumberLiteral(ct, d))
    459 				} else {
    460 					d = constant.UnaryOp(token.SUB, d, 0)
    461 					delta = fmt.Sprintf("%s -= %s", v, constantToNumberLiteral(ct, d))
    462 				}
    463 			}
    464 			var op string
    465 			switch s.ForOp {
    466 			case shaderir.LessThanOp, shaderir.LessThanEqualOp, shaderir.GreaterThanOp, shaderir.GreaterThanEqualOp, shaderir.EqualOp, shaderir.NotEqualOp:
    467 				op = string(s.ForOp)
    468 			default:
    469 				op = fmt.Sprintf("?(unexpected op: %s)", string(s.ForOp))
    470 			}
    471 
    472 			t := s.ForVarType
    473 			init := constantToNumberLiteral(ct, s.ForInit)
    474 			end := constantToNumberLiteral(ct, s.ForEnd)
    475 			ts := typeString(&t, false, false)
    476 			lines = append(lines, fmt.Sprintf("%sfor (%s %s = %s; %s %s %s; %s) {", idt, ts, v, init, v, op, end, delta))
    477 			lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[0], level+1)...)
    478 			lines = append(lines, fmt.Sprintf("%s}", idt))
    479 		case shaderir.Continue:
    480 			lines = append(lines, idt+"continue;")
    481 		case shaderir.Break:
    482 			lines = append(lines, idt+"break;")
    483 		case shaderir.Return:
    484 			switch {
    485 			case topBlock == p.VertexFunc.Block:
    486 				lines = append(lines, fmt.Sprintf("%sreturn %s;", idt, vertexOut))
    487 			case topBlock == p.FragmentFunc.Block:
    488 				lines = append(lines, fmt.Sprintf("%sreturn %s;", idt, fragmentOut))
    489 			case len(s.Exprs) == 0:
    490 				lines = append(lines, idt+"return;")
    491 			default:
    492 				lines = append(lines, fmt.Sprintf("%sreturn %s;", idt, metalExpr(&s.Exprs[0])))
    493 			}
    494 		case shaderir.Discard:
    495 			lines = append(lines, idt+"discard;")
    496 		default:
    497 			lines = append(lines, fmt.Sprintf("%s?(unexpected stmt: %d)", idt, s.Type))
    498 		}
    499 	}
    500 
    501 	return lines
    502 }