metal.go (15845B)
1 // Copyright 2020 The Ebiten Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package metal 16 17 import ( 18 "fmt" 19 "go/constant" 20 "go/token" 21 "regexp" 22 "strings" 23 24 "github.com/hajimehoshi/ebiten/v2/internal/shaderir" 25 ) 26 27 const ( 28 vertexOut = "varyings" 29 fragmentOut = "out" 30 ) 31 32 type compileContext struct { 33 structNames map[string]string 34 structTypes []shaderir.Type 35 } 36 37 func (c *compileContext) structName(p *shaderir.Program, t *shaderir.Type) string { 38 if t.Main != shaderir.Struct { 39 panic("metal: the given type at structName must be a struct") 40 } 41 s := t.String() 42 if n, ok := c.structNames[s]; ok { 43 return n 44 } 45 n := fmt.Sprintf("S%d", len(c.structNames)) 46 c.structNames[s] = n 47 c.structTypes = append(c.structTypes, *t) 48 return n 49 } 50 51 const Prelude = `#include <metal_stdlib> 52 53 using namespace metal; 54 55 constexpr sampler texture_sampler{filter::nearest};` 56 57 func Compile(p *shaderir.Program, vertex, fragment string) (shader string) { 58 c := &compileContext{ 59 structNames: map[string]string{}, 60 } 61 62 var lines []string 63 lines = append(lines, strings.Split(Prelude, "\n")...) 64 lines = append(lines, "", "{{.Structs}}") 65 66 if len(p.Attributes) > 0 { 67 lines = append(lines, "") 68 lines = append(lines, "struct Attributes {") 69 for i, a := range p.Attributes { 70 lines = append(lines, fmt.Sprintf("\t%s;", c.metalVarDecl(p, &a, fmt.Sprintf("M%d", i), true, false))) 71 } 72 lines = append(lines, "};") 73 } 74 75 if len(p.Varyings) > 0 { 76 lines = append(lines, "") 77 lines = append(lines, "struct Varyings {") 78 lines = append(lines, "\tfloat4 Position [[position]];") 79 for i, v := range p.Varyings { 80 lines = append(lines, fmt.Sprintf("\t%s;", c.metalVarDecl(p, &v, fmt.Sprintf("M%d", i), false, false))) 81 } 82 lines = append(lines, "};") 83 } 84 85 if len(p.Funcs) > 0 { 86 lines = append(lines, "") 87 for _, f := range p.Funcs { 88 lines = append(lines, c.metalFunc(p, &f, true)...) 89 } 90 for _, f := range p.Funcs { 91 if len(lines) > 0 && lines[len(lines)-1] != "" { 92 lines = append(lines, "") 93 } 94 lines = append(lines, c.metalFunc(p, &f, false)...) 95 } 96 } 97 98 if p.VertexFunc.Block != nil && len(p.VertexFunc.Block.Stmts) > 0 { 99 lines = append(lines, "") 100 lines = append(lines, 101 fmt.Sprintf("vertex Varyings %s(", vertex), 102 "\tuint vid [[vertex_id]],", 103 "\tconst device Attributes* attributes [[buffer(0)]]") 104 for i, u := range p.Uniforms { 105 lines[len(lines)-1] += "," 106 lines = append(lines, fmt.Sprintf("\tconstant %s [[buffer(%d)]]", c.metalVarDecl(p, &u, fmt.Sprintf("U%d", i), false, true), i+1)) 107 } 108 for i := 0; i < p.TextureNum; i++ { 109 lines[len(lines)-1] += "," 110 lines = append(lines, fmt.Sprintf("\ttexture2d<float> T%[1]d [[texture(%[1]d)]]", i)) 111 } 112 lines[len(lines)-1] += ") {" 113 lines = append(lines, fmt.Sprintf("\tVaryings %s = {};", vertexOut)) 114 lines = append(lines, c.metalBlock(p, p.VertexFunc.Block, p.VertexFunc.Block, 0)...) 115 if last := fmt.Sprintf("\treturn %s;", vertexOut); lines[len(lines)-1] != last { 116 lines = append(lines, last) 117 } 118 lines = append(lines, "}") 119 } 120 121 if p.FragmentFunc.Block != nil && len(p.FragmentFunc.Block.Stmts) > 0 { 122 lines = append(lines, "") 123 lines = append(lines, 124 fmt.Sprintf("fragment float4 %s(", fragment), 125 "\tVaryings varyings [[stage_in]]") 126 for i, u := range p.Uniforms { 127 lines[len(lines)-1] += "," 128 lines = append(lines, fmt.Sprintf("\tconstant %s [[buffer(%d)]]", c.metalVarDecl(p, &u, fmt.Sprintf("U%d", i), false, true), i+1)) 129 } 130 for i := 0; i < p.TextureNum; i++ { 131 lines[len(lines)-1] += "," 132 lines = append(lines, fmt.Sprintf("\ttexture2d<float> T%[1]d [[texture(%[1]d)]]", i)) 133 } 134 lines[len(lines)-1] += ") {" 135 lines = append(lines, fmt.Sprintf("\tfloat4 %s = float4(0);", fragmentOut)) 136 lines = append(lines, c.metalBlock(p, p.FragmentFunc.Block, p.FragmentFunc.Block, 0)...) 137 if last := fmt.Sprintf("\treturn %s;", fragmentOut); lines[len(lines)-1] != last { 138 lines = append(lines, last) 139 } 140 lines = append(lines, "}") 141 } 142 143 ls := strings.Join(lines, "\n") 144 145 // Struct types are determined after converting the program. 146 if len(c.structTypes) > 0 { 147 var stlines []string 148 for i, t := range c.structTypes { 149 stlines = append(stlines, fmt.Sprintf("struct S%d {", i)) 150 for j, st := range t.Sub { 151 stlines = append(stlines, fmt.Sprintf("\t%s;", c.metalVarDecl(p, &st, fmt.Sprintf("M%d", j), false, false))) 152 } 153 stlines = append(stlines, "};") 154 } 155 ls = strings.ReplaceAll(ls, "{{.Structs}}", strings.Join(stlines, "\n")) 156 } else { 157 ls = strings.ReplaceAll(ls, "{{.Structs}}", "") 158 } 159 160 nls := regexp.MustCompile(`\n\n+`) 161 ls = nls.ReplaceAllString(ls, "\n\n") 162 ls = strings.TrimSpace(ls) + "\n" 163 164 return ls 165 } 166 167 func (c *compileContext) metalType(p *shaderir.Program, t *shaderir.Type, packed bool, ref bool) string { 168 switch t.Main { 169 case shaderir.None: 170 return "void" 171 case shaderir.Struct: 172 return c.structName(p, t) 173 default: 174 return typeString(t, packed, ref) 175 } 176 } 177 178 func (c *compileContext) metalVarDecl(p *shaderir.Program, t *shaderir.Type, varname string, packed bool, ref bool) string { 179 switch t.Main { 180 case shaderir.None: 181 return "?(none)" 182 case shaderir.Struct: 183 s := c.structName(p, t) 184 if ref { 185 s += "&" 186 } 187 return fmt.Sprintf("%s %s", s, varname) 188 default: 189 t := typeString(t, packed, ref) 190 return fmt.Sprintf("%s %s", t, varname) 191 } 192 } 193 194 func (c *compileContext) metalVarInit(p *shaderir.Program, t *shaderir.Type) string { 195 switch t.Main { 196 case shaderir.None: 197 return "?(none)" 198 case shaderir.Array: 199 return "{}" 200 case shaderir.Struct: 201 return "{}" 202 case shaderir.Bool: 203 return "false" 204 case shaderir.Int: 205 return "0" 206 case shaderir.Float, shaderir.Vec2, shaderir.Vec3, shaderir.Vec4, shaderir.Mat2, shaderir.Mat3, shaderir.Mat4: 207 return fmt.Sprintf("%s(0)", basicTypeString(t.Main, false)) 208 default: 209 t := c.metalType(p, t, false, false) 210 panic(fmt.Sprintf("?(unexpected type: %s)", t)) 211 } 212 } 213 214 func (c *compileContext) metalFunc(p *shaderir.Program, f *shaderir.Func, prototype bool) []string { 215 var args []string 216 217 // Uniform variables and texture variables. In Metal, non-const global variables are not available. 218 for i, u := range p.Uniforms { 219 args = append(args, "constant "+c.metalVarDecl(p, &u, fmt.Sprintf("U%d", i), false, true)) 220 } 221 for i := 0; i < p.TextureNum; i++ { 222 args = append(args, fmt.Sprintf("texture2d<float> T%d", i)) 223 } 224 225 var idx int 226 for _, t := range f.InParams { 227 args = append(args, c.metalVarDecl(p, &t, fmt.Sprintf("l%d", idx), false, false)) 228 idx++ 229 } 230 for _, t := range f.OutParams { 231 args = append(args, "thread "+c.metalVarDecl(p, &t, fmt.Sprintf("l%d", idx), false, true)) 232 idx++ 233 } 234 argsstr := "void" 235 if len(args) > 0 { 236 argsstr = strings.Join(args, ", ") 237 } 238 239 t := c.metalType(p, &f.Return, false, false) 240 sig := fmt.Sprintf("%s F%d(%s)", t, f.Index, argsstr) 241 242 var lines []string 243 if prototype { 244 lines = append(lines, fmt.Sprintf("%s;", sig)) 245 return lines 246 } 247 lines = append(lines, fmt.Sprintf("%s {", sig)) 248 lines = append(lines, c.metalBlock(p, f.Block, f.Block, 0)...) 249 lines = append(lines, "}") 250 251 return lines 252 } 253 254 func constantToNumberLiteral(t shaderir.ConstType, v constant.Value) string { 255 switch t { 256 case shaderir.ConstTypeNone: 257 if v.Kind() == constant.Bool { 258 if constant.BoolVal(v) { 259 return "true" 260 } 261 return "false" 262 } 263 fallthrough 264 case shaderir.ConstTypeFloat: 265 if i := constant.ToInt(v); i.Kind() == constant.Int { 266 x, _ := constant.Int64Val(i) 267 return fmt.Sprintf("%d.0", x) 268 } 269 if i := constant.ToFloat(v); i.Kind() == constant.Float { 270 x, _ := constant.Float64Val(i) 271 return fmt.Sprintf("%.10e", x) 272 } 273 case shaderir.ConstTypeInt: 274 if i := constant.ToInt(v); i.Kind() == constant.Int { 275 x, _ := constant.Int64Val(i) 276 return fmt.Sprintf("%d", x) 277 } 278 } 279 return fmt.Sprintf("?(unexpected literal: %s)", v) 280 } 281 282 func localVariableName(p *shaderir.Program, topBlock *shaderir.Block, idx int) string { 283 switch topBlock { 284 case p.VertexFunc.Block: 285 na := len(p.Attributes) 286 nv := len(p.Varyings) 287 switch { 288 case idx < na: 289 return fmt.Sprintf("attributes[vid].M%d", idx) 290 case idx == na: 291 return fmt.Sprintf("%s.Position", vertexOut) 292 case idx < na+nv+1: 293 return fmt.Sprintf("%s.M%d", vertexOut, idx-na-1) 294 default: 295 return fmt.Sprintf("l%d", idx-(na+nv+1)) 296 } 297 case p.FragmentFunc.Block: 298 nv := len(p.Varyings) 299 switch { 300 case idx == 0: 301 return fmt.Sprintf("varyings.Position") 302 case idx < nv+1: 303 return fmt.Sprintf("varyings.M%d", idx-1) 304 case idx == nv+1: 305 return fragmentOut 306 default: 307 return fmt.Sprintf("l%d", idx-(nv+2)) 308 } 309 default: 310 return fmt.Sprintf("l%d", idx) 311 } 312 } 313 314 func (c *compileContext) initVariable(p *shaderir.Program, topBlock, block *shaderir.Block, index int, decl bool, level int) []string { 315 idt := strings.Repeat("\t", level+1) 316 name := localVariableName(p, topBlock, index) 317 t := p.LocalVariableType(topBlock, block, index) 318 319 var lines []string 320 if decl { 321 lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, c.metalVarDecl(p, &t, name, false, false), c.metalVarInit(p, &t))) 322 } else { 323 lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, name, c.metalVarInit(p, &t))) 324 } 325 return lines 326 } 327 328 func (c *compileContext) metalBlock(p *shaderir.Program, topBlock, block *shaderir.Block, level int) []string { 329 if block == nil { 330 return nil 331 } 332 333 idt := strings.Repeat("\t", level+1) 334 335 var lines []string 336 for i, t := range block.LocalVars { 337 // The type is None e.g., when the variable is a for-loop counter. 338 if t.Main != shaderir.None { 339 lines = append(lines, c.initVariable(p, topBlock, block, block.LocalVarIndexOffset+i, true, level)...) 340 } 341 } 342 343 var metalExpr func(e *shaderir.Expr) string 344 metalExpr = func(e *shaderir.Expr) string { 345 switch e.Type { 346 case shaderir.NumberExpr: 347 return constantToNumberLiteral(e.ConstType, e.Const) 348 case shaderir.UniformVariable: 349 return fmt.Sprintf("U%d", e.Index) 350 case shaderir.TextureVariable: 351 return fmt.Sprintf("T%d", e.Index) 352 case shaderir.LocalVariable: 353 return localVariableName(p, topBlock, e.Index) 354 case shaderir.StructMember: 355 return fmt.Sprintf("M%d", e.Index) 356 case shaderir.BuiltinFuncExpr: 357 return builtinFuncString(e.BuiltinFunc) 358 case shaderir.SwizzlingExpr: 359 if !shaderir.IsValidSwizzling(e.Swizzling) { 360 return fmt.Sprintf("?(unexpected swizzling: %s)", e.Swizzling) 361 } 362 return e.Swizzling 363 case shaderir.FunctionExpr: 364 return fmt.Sprintf("F%d", e.Index) 365 case shaderir.Unary: 366 var op string 367 switch e.Op { 368 case shaderir.Add, shaderir.Sub, shaderir.NotOp: 369 op = string(e.Op) 370 default: 371 op = fmt.Sprintf("?(unexpected op: %s)", string(e.Op)) 372 } 373 return fmt.Sprintf("%s(%s)", op, metalExpr(&e.Exprs[0])) 374 case shaderir.Binary: 375 return fmt.Sprintf("(%s) %s (%s)", metalExpr(&e.Exprs[0]), e.Op, metalExpr(&e.Exprs[1])) 376 case shaderir.Selection: 377 return fmt.Sprintf("(%s) ? (%s) : (%s)", metalExpr(&e.Exprs[0]), metalExpr(&e.Exprs[1]), metalExpr(&e.Exprs[2])) 378 case shaderir.Call: 379 callee := e.Exprs[0] 380 var args []string 381 if callee.Type != shaderir.BuiltinFuncExpr { 382 for i := range p.Uniforms { 383 args = append(args, fmt.Sprintf("U%d", i)) 384 } 385 for i := 0; i < p.TextureNum; i++ { 386 args = append(args, fmt.Sprintf("T%d", i)) 387 } 388 } 389 for _, exp := range e.Exprs[1:] { 390 args = append(args, metalExpr(&exp)) 391 } 392 if callee.Type == shaderir.BuiltinFuncExpr && callee.BuiltinFunc == shaderir.Texture2DF { 393 return fmt.Sprintf("%s.sample(texture_sampler, %s)", args[0], strings.Join(args[1:], ", ")) 394 } 395 return fmt.Sprintf("%s(%s)", metalExpr(&callee), strings.Join(args, ", ")) 396 case shaderir.FieldSelector: 397 return fmt.Sprintf("(%s).%s", metalExpr(&e.Exprs[0]), metalExpr(&e.Exprs[1])) 398 case shaderir.Index: 399 return fmt.Sprintf("(%s)[%s]", metalExpr(&e.Exprs[0]), metalExpr(&e.Exprs[1])) 400 default: 401 return fmt.Sprintf("?(unexpected expr: %d)", e.Type) 402 } 403 } 404 405 for _, s := range block.Stmts { 406 switch s.Type { 407 case shaderir.ExprStmt: 408 lines = append(lines, fmt.Sprintf("%s%s;", idt, metalExpr(&s.Exprs[0]))) 409 case shaderir.BlockStmt: 410 lines = append(lines, idt+"{") 411 lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[0], level+1)...) 412 lines = append(lines, idt+"}") 413 case shaderir.Assign: 414 lines = append(lines, fmt.Sprintf("%s%s = %s;", idt, metalExpr(&s.Exprs[0]), metalExpr(&s.Exprs[1]))) 415 case shaderir.Init: 416 init := true 417 if topBlock == p.VertexFunc.Block { 418 // In the vertex function, varying values are the output parameters. 419 // These values are represented as a struct and not needed to be initialized. 420 na := len(p.Attributes) 421 nv := len(p.Varyings) 422 if s.InitIndex < na+nv+1 { 423 init = false 424 } 425 } 426 if init { 427 lines = append(lines, c.initVariable(p, topBlock, block, s.InitIndex, false, level)...) 428 } 429 case shaderir.If: 430 lines = append(lines, fmt.Sprintf("%sif (%s) {", idt, metalExpr(&s.Exprs[0]))) 431 lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[0], level+1)...) 432 if len(s.Blocks) > 1 { 433 lines = append(lines, fmt.Sprintf("%s} else {", idt)) 434 lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[1], level+1)...) 435 } 436 lines = append(lines, fmt.Sprintf("%s}", idt)) 437 case shaderir.For: 438 var ct shaderir.ConstType 439 switch s.ForVarType.Main { 440 case shaderir.Int: 441 ct = shaderir.ConstTypeInt 442 case shaderir.Float: 443 ct = shaderir.ConstTypeFloat 444 } 445 446 v := localVariableName(p, topBlock, s.ForVarIndex) 447 var delta string 448 switch val, _ := constant.Float64Val(s.ForDelta); val { 449 case 0: 450 delta = fmt.Sprintf("?(unexpected delta: %v)", s.ForDelta) 451 case 1: 452 delta = fmt.Sprintf("%s++", v) 453 case -1: 454 delta = fmt.Sprintf("%s--", v) 455 default: 456 d := s.ForDelta 457 if val > 0 { 458 delta = fmt.Sprintf("%s += %s", v, constantToNumberLiteral(ct, d)) 459 } else { 460 d = constant.UnaryOp(token.SUB, d, 0) 461 delta = fmt.Sprintf("%s -= %s", v, constantToNumberLiteral(ct, d)) 462 } 463 } 464 var op string 465 switch s.ForOp { 466 case shaderir.LessThanOp, shaderir.LessThanEqualOp, shaderir.GreaterThanOp, shaderir.GreaterThanEqualOp, shaderir.EqualOp, shaderir.NotEqualOp: 467 op = string(s.ForOp) 468 default: 469 op = fmt.Sprintf("?(unexpected op: %s)", string(s.ForOp)) 470 } 471 472 t := s.ForVarType 473 init := constantToNumberLiteral(ct, s.ForInit) 474 end := constantToNumberLiteral(ct, s.ForEnd) 475 ts := typeString(&t, false, false) 476 lines = append(lines, fmt.Sprintf("%sfor (%s %s = %s; %s %s %s; %s) {", idt, ts, v, init, v, op, end, delta)) 477 lines = append(lines, c.metalBlock(p, topBlock, s.Blocks[0], level+1)...) 478 lines = append(lines, fmt.Sprintf("%s}", idt)) 479 case shaderir.Continue: 480 lines = append(lines, idt+"continue;") 481 case shaderir.Break: 482 lines = append(lines, idt+"break;") 483 case shaderir.Return: 484 switch { 485 case topBlock == p.VertexFunc.Block: 486 lines = append(lines, fmt.Sprintf("%sreturn %s;", idt, vertexOut)) 487 case topBlock == p.FragmentFunc.Block: 488 lines = append(lines, fmt.Sprintf("%sreturn %s;", idt, fragmentOut)) 489 case len(s.Exprs) == 0: 490 lines = append(lines, idt+"return;") 491 default: 492 lines = append(lines, fmt.Sprintf("%sreturn %s;", idt, metalExpr(&s.Exprs[0]))) 493 } 494 case shaderir.Discard: 495 lines = append(lines, idt+"discard;") 496 default: 497 lines = append(lines, fmt.Sprintf("%s?(unexpected stmt: %d)", idt, s.Type)) 498 } 499 } 500 501 return lines 502 }