@@ -40,11 +40,11 @@ protocol ParamInfo: CustomStringConvertible {
40
40
var dependencies : [ LifetimeDependence ] { get set }
41
41
42
42
func getBoundsCheckedThunkBuilder(
43
- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
43
+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
44
44
) -> BoundsCheckedThunkBuilder
45
45
}
46
46
47
- func tryGetParamName( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
47
+ func tryGetParamName( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
48
48
switch expr {
49
49
case . param( let i) :
50
50
let funcParam = getParam ( funcDecl, i - 1 )
@@ -55,7 +55,7 @@ func tryGetParamName(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -> To
55
55
}
56
56
}
57
57
58
- func getSwiftifyExprType( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TypeSyntax {
58
+ func getSwiftifyExprType( _ funcDecl: FunctionParts , _ expr: SwiftifyExpr ) -> TypeSyntax {
59
59
switch expr {
60
60
case . param( let i) :
61
61
let funcParam = getParam ( funcDecl, i - 1 )
@@ -79,7 +79,7 @@ struct CxxSpan: ParamInfo {
79
79
}
80
80
81
81
func getBoundsCheckedThunkBuilder(
82
- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
82
+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
83
83
) -> BoundsCheckedThunkBuilder {
84
84
switch pointerIndex {
85
85
case . param( let i) :
@@ -115,7 +115,7 @@ struct CountedBy: ParamInfo {
115
115
}
116
116
117
117
func getBoundsCheckedThunkBuilder(
118
- _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax
118
+ _ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionParts
119
119
) -> BoundsCheckedThunkBuilder {
120
120
switch pointerIndex {
121
121
case . param( let i) :
@@ -400,14 +400,14 @@ func getParam(_ signature: FunctionSignatureSyntax, _ paramIndex: Int) -> Functi
400
400
}
401
401
}
402
402
403
- func getParam( _ funcDecl: FunctionDeclSyntax , _ paramIndex: Int ) -> FunctionParameterSyntax {
403
+ func getParam( _ funcDecl: FunctionParts , _ paramIndex: Int ) -> FunctionParameterSyntax {
404
404
return getParam ( funcDecl. signature, paramIndex)
405
405
}
406
406
407
407
struct FunctionCallBuilder : BoundsCheckedThunkBuilder {
408
- let base : FunctionDeclSyntax
408
+ let base : FunctionParts
409
409
410
- init ( _ function: FunctionDeclSyntax ) {
410
+ init ( _ function: FunctionParts ) {
411
411
base = function
412
412
}
413
413
@@ -467,14 +467,18 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
467
467
FunctionCallExprSyntax (
468
468
calledExpression: functionRef, leftParen: . leftParenToken( ) ,
469
469
arguments: LabeledExprListSyntax ( labeledArgs) , rightParen: . rightParenToken( ) ) )
470
- return " unsafe \( call) "
470
+ if base. name. tokenKind == . keyword( . `init`) {
471
+ return " unsafe self. \( call) "
472
+ } else {
473
+ return " unsafe \( call) "
474
+ }
471
475
}
472
476
}
473
477
474
478
struct CxxSpanThunkBuilder : SpanBoundsThunkBuilder , ParamBoundsThunkBuilder {
475
479
public let base : BoundsCheckedThunkBuilder
476
480
public let index : Int
477
- public let funcDecl : FunctionDeclSyntax
481
+ public let funcDecl : FunctionParts
478
482
public let typeMappings : [ String : String ]
479
483
public let node : SyntaxProtocol
480
484
public let nonescaping : Bool
@@ -525,7 +529,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
525
529
526
530
struct CxxSpanReturnThunkBuilder : SpanBoundsThunkBuilder {
527
531
public let base : BoundsCheckedThunkBuilder
528
- public let funcDecl : FunctionDeclSyntax
532
+ public let funcDecl : FunctionParts
529
533
public let typeMappings : [ String : String ]
530
534
public let node : SyntaxProtocol
531
535
let isParameter : Bool = false
@@ -564,7 +568,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
564
568
protocol BoundsThunkBuilder : BoundsCheckedThunkBuilder {
565
569
var oldType : TypeSyntax { get }
566
570
var newType : TypeSyntax { get throws }
567
- var funcDecl : FunctionDeclSyntax { get }
571
+ var funcDecl : FunctionParts { get }
568
572
}
569
573
570
574
extension BoundsThunkBuilder {
@@ -675,7 +679,7 @@ extension ParamBoundsThunkBuilder {
675
679
struct CountedOrSizedReturnPointerThunkBuilder : PointerBoundsThunkBuilder {
676
680
public let base : BoundsCheckedThunkBuilder
677
681
public let countExpr : ExprSyntax
678
- public let funcDecl : FunctionDeclSyntax
682
+ public let funcDecl : FunctionParts
679
683
public let nonescaping : Bool
680
684
public let isSizedBy : Bool
681
685
public let dependencies : [ LifetimeDependence ]
@@ -743,7 +747,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
743
747
public let base : BoundsCheckedThunkBuilder
744
748
public let index : Int
745
749
public let countExpr : ExprSyntax
746
- public let funcDecl : FunctionDeclSyntax
750
+ public let funcDecl : FunctionParts
747
751
public let nonescaping : Bool
748
752
public let isSizedBy : Bool
749
753
let isParameter : Bool = true
@@ -1237,22 +1241,22 @@ func parseMacroParam(
1237
1241
}
1238
1242
}
1239
1243
1240
- func checkArgs( _ args: [ ParamInfo ] , _ funcDecl : FunctionDeclSyntax ) throws {
1244
+ func checkArgs( _ args: [ ParamInfo ] , _ funcComponents : FunctionParts ) throws {
1241
1245
var argByIndex : [ Int : ParamInfo ] = [ : ]
1242
1246
var ret : ParamInfo ? = nil
1243
- let paramCount = funcDecl . signature. parameterClause. parameters. count
1247
+ let paramCount = funcComponents . signature. parameterClause. parameters. count
1244
1248
try args. forEach { pointerInfo in
1245
1249
switch pointerInfo. pointerIndex {
1246
1250
case . param( let i) :
1247
1251
if i < 1 || i > paramCount {
1248
1252
let noteMessage =
1249
1253
paramCount > 0
1250
- ? " function \( funcDecl . name) has parameter indices 1.. \( paramCount) "
1251
- : " function \( funcDecl . name) has no parameters "
1254
+ ? " function \( funcComponents . name) has parameter indices 1.. \( paramCount) "
1255
+ : " function \( funcComponents . name) has no parameters "
1252
1256
throw DiagnosticError (
1253
1257
" pointer index out of bounds " , node: pointerInfo. original,
1254
1258
notes: [
1255
- Note ( node: Syntax ( funcDecl . name) , message: MacroExpansionNoteMessage ( noteMessage) )
1259
+ Note ( node: Syntax ( funcComponents . name) , message: MacroExpansionNoteMessage ( noteMessage) )
1256
1260
] )
1257
1261
}
1258
1262
if argByIndex [ i] != nil {
@@ -1316,7 +1320,7 @@ func isInout(_ type: TypeSyntax) -> Bool {
1316
1320
}
1317
1321
1318
1322
func getReturnLifetimeAttribute(
1319
- _ funcDecl: FunctionDeclSyntax ,
1323
+ _ funcDecl: FunctionParts ,
1320
1324
_ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ]
1321
1325
) -> [ AttributeListSyntax . Element ] {
1322
1326
let returnDependencies = dependencies [ . `return`, default: [ ] ]
@@ -1473,9 +1477,9 @@ class CountExprRewriter: SyntaxRewriter {
1473
1477
}
1474
1478
}
1475
1479
1476
- func renameParameterNamesIfNeeded( _ funcDecl : FunctionDeclSyntax ) -> ( FunctionDeclSyntax , CountExprRewriter ) {
1477
- let params = funcDecl . signature. parameterClause. parameters
1478
- let funcName = funcDecl . name. withoutBackticks. trimmed. text
1480
+ func renameParameterNamesIfNeeded( _ funcComponents : FunctionParts ) -> ( FunctionParts , CountExprRewriter ) {
1481
+ let params = funcComponents . signature. parameterClause. parameters
1482
+ let funcName = funcComponents . name. withoutBackticks. trimmed. text
1479
1483
let shouldRename = params. contains ( where: { param in
1480
1484
let paramName = param. name. trimmed. text
1481
1485
return paramName == " _ " || paramName == funcName || " ` \( paramName) ` " == funcName
@@ -1499,13 +1503,32 @@ func renameParameterNamesIfNeeded(_ funcDecl: FunctionDeclSyntax) -> (FunctionDe
1499
1503
}
1500
1504
return newParam
1501
1505
}
1502
- let newDecl = if renamedParams. count > 0 {
1503
- funcDecl . with ( \. signature . parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1506
+ let newSig = if renamedParams. count > 0 {
1507
+ funcComponents . signature . with ( \. parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1504
1508
} else {
1505
1509
// Keeps source locations for diagnostics, in the common case where nothing was renamed
1506
- funcDecl
1510
+ funcComponents. signature
1511
+ }
1512
+ return ( FunctionParts ( signature: newSig, name: funcComponents. name, attributes: funcComponents. attributes) ,
1513
+ CountExprRewriter ( renamedParams) )
1514
+ }
1515
+
1516
+ struct FunctionParts {
1517
+ let signature : FunctionSignatureSyntax
1518
+ let name : TokenSyntax
1519
+ let attributes : AttributeListSyntax
1520
+ }
1521
+
1522
+ func deconstructFunction( _ declaration: some DeclSyntaxProtocol ) throws -> FunctionParts {
1523
+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1524
+ return FunctionParts ( signature: origFuncDecl. signature, name: origFuncDecl. name,
1525
+ attributes: origFuncDecl. attributes)
1526
+ }
1527
+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1528
+ return FunctionParts ( signature: origInitDecl. signature, name: origInitDecl. initKeyword,
1529
+ attributes: origInitDecl. attributes)
1507
1530
}
1508
- return ( newDecl , CountExprRewriter ( renamedParams ) )
1531
+ throw DiagnosticError ( " @_SwiftifyImport only works on functions and initializers " , node : declaration )
1509
1532
}
1510
1533
1511
1534
/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
@@ -1521,10 +1544,8 @@ public struct SwiftifyImportMacro: PeerMacro {
1521
1544
in context: some MacroExpansionContext
1522
1545
) throws -> [ DeclSyntax ] {
1523
1546
do {
1524
- guard let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) else {
1525
- throw DiagnosticError ( " @_SwiftifyImport only works on functions " , node: declaration)
1526
- }
1527
- let ( funcDecl, rewriter) = renameParameterNamesIfNeeded ( origFuncDecl)
1547
+ let origFuncComponents = try deconstructFunction ( declaration)
1548
+ let ( funcComponents, rewriter) = renameParameterNamesIfNeeded ( origFuncComponents)
1528
1549
1529
1550
let argumentList = node. arguments!. as ( LabeledExprListSyntax . self) !
1530
1551
var arguments = [ LabeledExprSyntax] ( argumentList)
@@ -1540,10 +1561,10 @@ public struct SwiftifyImportMacro: PeerMacro {
1540
1561
var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
1541
1562
var parsedArgs = try arguments. compactMap {
1542
1563
try parseMacroParam (
1543
- $0, funcDecl . signature, rewriter, nonescapingPointers: & nonescapingPointers,
1564
+ $0, funcComponents . signature, rewriter, nonescapingPointers: & nonescapingPointers,
1544
1565
lifetimeDependencies: & lifetimeDependencies)
1545
1566
}
1546
- parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl . signature, typeMappings) )
1567
+ parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcComponents . signature, typeMappings) )
1547
1568
setNonescapingPointers ( & parsedArgs, nonescapingPointers)
1548
1569
setLifetimeDependencies ( & parsedArgs, lifetimeDependencies)
1549
1570
// We only transform non-escaping spans.
@@ -1554,7 +1575,7 @@ public struct SwiftifyImportMacro: PeerMacro {
1554
1575
return true
1555
1576
}
1556
1577
}
1557
- try checkArgs ( parsedArgs, funcDecl )
1578
+ try checkArgs ( parsedArgs, funcComponents )
1558
1579
parsedArgs. sort { a, b in
1559
1580
// make sure return value cast to Span happens last so that withUnsafeBufferPointer
1560
1581
// doesn't return a ~Escapable type
@@ -1566,12 +1587,12 @@ public struct SwiftifyImportMacro: PeerMacro {
1566
1587
}
1567
1588
return paramOrReturnIndex ( a. pointerIndex) < paramOrReturnIndex ( b. pointerIndex)
1568
1589
}
1569
- let baseBuilder = FunctionCallBuilder ( funcDecl )
1590
+ let baseBuilder = FunctionCallBuilder ( funcComponents )
1570
1591
1571
1592
let builder : BoundsCheckedThunkBuilder = parsedArgs. reduce (
1572
1593
baseBuilder,
1573
1594
{ ( prev, parsedArg) in
1574
- parsedArg. getBoundsCheckedThunkBuilder ( prev, funcDecl )
1595
+ parsedArg. getBoundsCheckedThunkBuilder ( prev, funcComponents )
1575
1596
} )
1576
1597
let newSignature = try builder. buildFunctionSignature ( [ : ] , nil )
1577
1598
var eliminatedArgs = Set < Int > ( )
@@ -1580,15 +1601,22 @@ public struct SwiftifyImportMacro: PeerMacro {
1580
1601
let checks = ( basicChecks + compoundChecks) . map { e in
1581
1602
CodeBlockItemSyntax ( leadingTrivia: " \n " , item: e)
1582
1603
}
1583
- let call = CodeBlockItemSyntax (
1584
- item: CodeBlockItemSyntax . Item (
1585
- ReturnStmtSyntax (
1586
- returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1587
- expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1604
+ var call : CodeBlockItemSyntax
1605
+ if declaration. is ( InitializerDeclSyntax . self) {
1606
+ call = CodeBlockItemSyntax (
1607
+ item: CodeBlockItemSyntax . Item (
1608
+ try builder. buildFunctionCall ( [ : ] ) ) )
1609
+ } else {
1610
+ call = CodeBlockItemSyntax (
1611
+ item: CodeBlockItemSyntax . Item (
1612
+ ReturnStmtSyntax (
1613
+ returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1614
+ expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1615
+ }
1588
1616
let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1589
- let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcDecl , lifetimeDependencies)
1617
+ let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcComponents , lifetimeDependencies)
1590
1618
let lifetimeAttrs =
1591
- returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcDecl . attributes)
1619
+ returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcComponents . attributes)
1592
1620
let availabilityAttr = try getAvailability ( newSignature, spanAvailability)
1593
1621
let disfavoredOverload : [ AttributeListSyntax . Element ] =
1594
1622
[
@@ -1597,13 +1625,7 @@ public struct SwiftifyImportMacro: PeerMacro {
1597
1625
atSign: . atSignToken( ) ,
1598
1626
attributeName: IdentifierTypeSyntax ( name: " _disfavoredOverload " ) ) )
1599
1627
]
1600
- let newFunc =
1601
- funcDecl
1602
- . with ( \. signature, newSignature)
1603
- . with ( \. body, body)
1604
- . with (
1605
- \. attributes,
1606
- funcDecl. attributes. filter { e in
1628
+ let attributes = funcComponents. attributes. filter { e in
1607
1629
switch e {
1608
1630
case . attribute( let attr) :
1609
1631
// don't apply this macro recursively, and avoid dupe _alwaysEmitIntoClient
@@ -1619,9 +1641,23 @@ public struct SwiftifyImportMacro: PeerMacro {
1619
1641
]
1620
1642
+ availabilityAttr
1621
1643
+ lifetimeAttrs
1622
- + disfavoredOverload)
1623
- . with ( \. leadingTrivia, node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " ) )
1624
- return [ DeclSyntax ( newFunc) ]
1644
+ + disfavoredOverload
1645
+ let trivia = node. leadingTrivia + . docLineComment( " /// This is an auto-generated wrapper for safer interop \n " )
1646
+ if let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) {
1647
+ return [ DeclSyntax ( origFuncDecl
1648
+ . with ( \. signature, newSignature)
1649
+ . with ( \. body, body)
1650
+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1651
+ . with ( \. leadingTrivia, trivia) ) ]
1652
+ }
1653
+ if let origInitDecl = declaration. as ( InitializerDeclSyntax . self) {
1654
+ return [ DeclSyntax ( origInitDecl
1655
+ . with ( \. signature, newSignature)
1656
+ . with ( \. body, body)
1657
+ . with ( \. attributes, AttributeListSyntax ( attributes) )
1658
+ . with ( \. leadingTrivia, trivia) ) ]
1659
+ }
1660
+ return [ ]
1625
1661
} catch let error as DiagnosticError {
1626
1662
context. diagnose (
1627
1663
Diagnostic (
@@ -1686,6 +1722,9 @@ extension FunctionParameterSyntax {
1686
1722
1687
1723
extension TokenSyntax {
1688
1724
public var withoutBackticks : TokenSyntax {
1725
+ if self . identifier == nil {
1726
+ return self
1727
+ }
1689
1728
return . identifier( self . identifier!. name)
1690
1729
}
1691
1730
0 commit comments