Skip to content

Commit ed541e6

Browse files
committed
fix: add asyncapi websocket subprotocol auth
1 parent 6ff9217 commit ed541e6

File tree

26 files changed

+821
-37
lines changed

26 files changed

+821
-37
lines changed

src/libs/AutoSDK.CSharp/Pipeline/AsyncApiData.cs

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -956,8 +956,10 @@ private static Authorization[] BuildAuthorizations(
956956
var secScheme = kvp.Value;
957957
var (friendlyName, parameters, paramLocation, securitySchemeType) =
958958
MapSecurityScheme(secScheme);
959-
960-
authorizations.Add(CSharpAuthorizationFactory.Create(
959+
var subProtocolTemplates = secScheme.SubProtocols
960+
.Where(static x => !string.IsNullOrWhiteSpace(x))
961+
.ToImmutableArray();
962+
var authorization = CSharpAuthorizationFactory.Create(
961963
friendlyName: friendlyName,
962964
schemeId: kvp.Key,
963965
type: securitySchemeType,
@@ -970,7 +972,14 @@ private static Authorization[] BuildAuthorizations(
970972
flows: ImmutableArray<OAuthFlow>.Empty.AsEquatableArray(),
971973
openIdConnectUrl: string.Empty,
972974
oAuth2MetadataUrl: string.Empty,
973-
isDeprecated: false));
975+
isDeprecated: false);
976+
977+
if (!subProtocolTemplates.IsEmpty)
978+
{
979+
authorization.WebSocketSubProtocols = subProtocolTemplates.AsEquatableArray();
980+
}
981+
982+
authorizations.Add(authorization);
974983
}
975984

976985
return authorizations.ToArray();
@@ -979,6 +988,15 @@ private static Authorization[] BuildAuthorizations(
979988
private static (string FriendlyName, string[] Parameters, ParameterLocation? Location,
980989
SecuritySchemeType? SchemeType) MapSecurityScheme(AsyncApiSecurityScheme scheme)
981990
{
991+
if (scheme.SubProtocols.Count > 0)
992+
{
993+
return (
994+
"Subprotocol",
995+
ExtractSubProtocolParameters(scheme.SubProtocols).ToArray(),
996+
null,
997+
SecuritySchemeType.ApiKey);
998+
}
999+
9821000
return (scheme.Type.ToLowerInvariant(), scheme.Scheme?.ToUpperInvariant(), scheme.In?.ToLowerInvariant()) switch
9831001
{
9841002
("http", "BEARER", _) =>
@@ -996,6 +1014,45 @@ private static (string FriendlyName, string[] Parameters, ParameterLocation? Loc
9961014
};
9971015
}
9981016

1017+
private static IEnumerable<string> ExtractSubProtocolParameters(
1018+
IEnumerable<string> templates)
1019+
{
1020+
var seen = new HashSet<string>(StringComparer.Ordinal);
1021+
1022+
foreach (var template in templates)
1023+
{
1024+
if (string.IsNullOrWhiteSpace(template))
1025+
{
1026+
continue;
1027+
}
1028+
1029+
var startIndex = 0;
1030+
while (startIndex < template.Length)
1031+
{
1032+
var openBrace = template.IndexOf('{', startIndex);
1033+
if (openBrace < 0)
1034+
{
1035+
break;
1036+
}
1037+
1038+
var closeBrace = template.IndexOf('}', openBrace + 1);
1039+
if (closeBrace < 0)
1040+
{
1041+
break;
1042+
}
1043+
1044+
var parameterName = template.Substring(openBrace + 1, closeBrace - openBrace - 1).Trim();
1045+
if (!string.IsNullOrWhiteSpace(parameterName) &&
1046+
seen.Add(parameterName))
1047+
{
1048+
yield return parameterName;
1049+
}
1050+
1051+
startIndex = closeBrace + 1;
1052+
}
1053+
}
1054+
}
1055+
9991056
/// <summary>
10001057
/// When multiple receive message types exist, injects synthetic oneOf wrapper schemas
10011058
/// into the AsyncAPI components for discriminated union deserialization.

src/libs/AutoSDK.CSharp/Sources/Sources.WebSocketClient.cs

Lines changed: 169 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,40 @@ partial void OnReceiveException(
158158
private static string GenerateWebSocketAuthorizationConstructors(WebSocketClient wsClient)
159159
{
160160
var result = new System.Text.StringBuilder();
161+
var hasStoredHeaderAuth = wsClient.Authorizations.Any(static auth =>
162+
auth.Type == SecuritySchemeType.Http &&
163+
!string.Equals(auth.Scheme, "basic", StringComparison.OrdinalIgnoreCase) ||
164+
auth.Type == SecuritySchemeType.ApiKey &&
165+
auth.In == ParameterLocation.Header);
166+
var hasSubprotocolAuth = wsClient.Authorizations.Any(static auth => !auth.WebSocketSubProtocols.IsEmpty);
167+
var bridgeApiKeyToSubprotocol = hasSubprotocolAuth &&
168+
wsClient.Authorizations.Any(static auth =>
169+
!auth.WebSocketSubProtocols.IsEmpty &&
170+
auth.Parameters.Contains("apiKey"));
171+
172+
if (hasStoredHeaderAuth)
173+
{
174+
result.AppendLine(@"
175+
private string? _storedAuthorizationHeaderName;
176+
private string? _storedAuthorizationHeaderScheme;
177+
private string? _storedAuthorizationApiKey;");
178+
}
179+
180+
if (hasSubprotocolAuth)
181+
{
182+
result.AppendLine(@"
183+
private readonly global::System.Collections.Generic.Dictionary<string, string> _subprotocolAuthorizationValues = new global::System.Collections.Generic.Dictionary<string, string>(global::System.StringComparer.Ordinal);
184+
private bool _preferSubprotocolAuth;");
185+
}
161186

162187
foreach (var auth in wsClient.Authorizations)
163188
{
164-
if (auth.Type == SecuritySchemeType.Http &&
165-
!string.Equals(auth.Scheme, "basic", StringComparison.OrdinalIgnoreCase))
189+
if (!auth.WebSocketSubProtocols.IsEmpty)
190+
{
191+
result.AppendLine(GenerateSubprotocolAuthorizationMethod(auth));
192+
}
193+
else if (auth.Type == SecuritySchemeType.Http &&
194+
!string.Equals(auth.Scheme, "basic", StringComparison.OrdinalIgnoreCase))
166195
{
167196
var friendlyName = string.IsNullOrWhiteSpace(auth.FriendlyName)
168197
? auth.Scheme.ToPropertyName()
@@ -181,8 +210,10 @@ private static string GenerateWebSocketAuthorizationConstructors(WebSocketClient
181210
{{
182211
apiKey = apiKey ?? throw new global::System.ArgumentNullException(nameof(apiKey));
183212
184-
_clientWebSocket.Options.SetRequestHeader(""Authorization"", $""{schemeName} {{apiKey}}"");
185-
}}
213+
_storedAuthorizationApiKey = apiKey;
214+
_storedAuthorizationHeaderName = ""Authorization"";
215+
_storedAuthorizationHeaderScheme = ""{schemeName}"";
216+
{(hasSubprotocolAuth ? " _preferSubprotocolAuth = false;\n" : string.Empty)}{(bridgeApiKeyToSubprotocol ? " _subprotocolAuthorizationValues[\"apiKey\"] = apiKey;\n" : string.Empty)} }}
186217
187218
/// <summary>
188219
/// Creates a new instance with {schemeName} token authentication.
@@ -214,8 +245,10 @@ public void AuthorizeUsingApiKeyInHeader(
214245
{{
215246
apiKey = apiKey ?? throw new global::System.ArgumentNullException(nameof(apiKey));
216247
217-
_clientWebSocket.Options.SetRequestHeader(""{auth.Name}"", apiKey);
218-
}}
248+
_storedAuthorizationApiKey = apiKey;
249+
_storedAuthorizationHeaderName = ""{auth.Name}"";
250+
_storedAuthorizationHeaderScheme = null;
251+
{(hasSubprotocolAuth ? " _preferSubprotocolAuth = false;\n" : string.Empty)}{(bridgeApiKeyToSubprotocol ? " _subprotocolAuthorizationValues[\"apiKey\"] = apiKey;\n" : string.Empty)} }}
219252
220253
/// <summary>
221254
/// Creates a new instance with API key header authentication.
@@ -251,7 +284,7 @@ public void AuthorizeUsingApiKeyInQuery(
251284
// Query parameter auth is handled during ConnectAsync
252285
_queryApiKey = apiKey;
253286
_queryApiKeyName = ""{auth.Name}"";
254-
}}
287+
{(bridgeApiKeyToSubprotocol ? " _subprotocolAuthorizationValues[\"apiKey\"] = apiKey;\n" : string.Empty)} }}
255288
256289
private string? _queryApiKey;
257290
private string? _queryApiKeyName;";
@@ -262,13 +295,49 @@ public void AuthorizeUsingApiKeyInQuery(
262295
return result.ToString();
263296
}
264297

298+
private static string GenerateSubprotocolAuthorizationMethod(Authorization auth)
299+
{
300+
var xmlParams = auth.Parameters.Select(parameter =>
301+
$@" /// <param name=""{parameter.ToParameterName()}""></param>");
302+
var parameterDeclarations = auth.Parameters.Select(parameter =>
303+
$@" string {parameter.ToParameterName()}");
304+
var assignments = auth.Parameters.Select(parameter =>
305+
{
306+
var parameterName = parameter.ToParameterName();
307+
return $@" var {parameterName}Value = {parameterName} ?? throw new global::System.ArgumentNullException(nameof({parameterName}));
308+
_subprotocolAuthorizationValues[""{parameter}""] = {parameterName}Value;";
309+
});
310+
var signature = auth.Parameters.Length == 0
311+
? string.Empty
312+
: $@"
313+
{parameterDeclarations.Inject().TrimEnd(',')}
314+
";
315+
316+
return $@"
317+
/// <summary>
318+
/// Authorize using WebSocket subprotocol authentication.
319+
/// </summary>
320+
{xmlParams.Inject(emptyValue: string.Empty)}
321+
public void {auth.MethodName}({signature} )
322+
{{
323+
{assignments.Inject(emptyValue: string.Empty)}
324+
_preferSubprotocolAuth = true;
325+
}}";
326+
}
327+
265328
private static string GenerateConnectAsync(WebSocketClient wsClient)
266329
{
267330
var serverVariables = wsClient.ServerVariables.ToArray();
268331
var signatureParams = wsClient.QueryParameters.ToArray();
269332
var serializedParams = wsClient.SerializedQueryParameters.ToArray();
270333
var hasTypedParameters = serverVariables.Length > 0 || signatureParams.Length > 0;
271334
var hasRequiredParameters = serverVariables.Any(static x => x.IsRequired) || signatureParams.Any(static x => x.IsRequired);
335+
var hasSubprotocolAuth = wsClient.Authorizations.Any(static x => !x.WebSocketSubProtocols.IsEmpty);
336+
var hasStoredHeaderAuth = wsClient.Authorizations.Any(static x =>
337+
x.Type == SecuritySchemeType.Http &&
338+
!string.Equals(x.Scheme, "basic", StringComparison.OrdinalIgnoreCase) ||
339+
x.Type == SecuritySchemeType.ApiKey &&
340+
x.In == ParameterLocation.Header);
272341
var supportsQueryApiKeyAuth = wsClient.Authorizations.Any(
273342
static x => x.Type == SecuritySchemeType.ApiKey &&
274343
x.In == ParameterLocation.Query);
@@ -279,18 +348,30 @@ private static string GenerateConnectAsync(WebSocketClient wsClient)
279348
__pathBuilder.AddRequiredParameter(_queryApiKeyName, _queryApiKey);
280349
}"
281350
: string.Empty;
351+
var applyStoredAuthorization = GenerateStoredWebSocketAuthorization(
352+
wsClient,
353+
hasStoredHeaderAuth,
354+
hasSubprotocolAuth);
355+
var connectionOptionsExtraParameter = hasSubprotocolAuth
356+
? ",\n bool useSubprotocolAuth"
357+
: string.Empty;
358+
var applyStoredAuthorizationCall = hasStoredHeaderAuth || hasSubprotocolAuth
359+
? $" ApplyStoredAuthorization({(hasSubprotocolAuth ? "useSubprotocolAuth" : "false")});\n"
360+
: string.Empty;
282361

283362
var connectionHelpers = @"
363+
" + applyStoredAuthorization + @"
284364
private void ApplyConnectionOptions(
285365
global::System.Collections.Generic.IDictionary<string, string>? additionalHeaders,
286366
global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols,
287-
global::System.TimeSpan? keepAliveInterval)
367+
global::System.TimeSpan? keepAliveInterval" + connectionOptionsExtraParameter + @")
288368
{
289369
if (keepAliveInterval is not null)
290370
{
291371
_clientWebSocket.Options.KeepAliveInterval = keepAliveInterval.Value;
292372
}
293373
374+
" + applyStoredAuthorizationCall + @"
294375
if (additionalHeaders is not null)
295376
{
296377
foreach (var header in additionalHeaders)
@@ -345,7 +426,7 @@ private void ApplyConnectionOptions(
345426
global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols = null,
346427
global::System.TimeSpan? keepAliveInterval = null,
347428
global::System.TimeSpan? connectTimeout = null,
348-
global::System.Threading.CancellationToken cancellationToken = default)
429+
{(hasSubprotocolAuth ? "bool useSubprotocolAuth = false,\n " : string.Empty)}global::System.Threading.CancellationToken cancellationToken = default)
349430
{{
350431
global::System.Uri __uri;
351432
if (uri is not null)
@@ -360,7 +441,7 @@ private void ApplyConnectionOptions(
360441
__uri = new global::System.Uri(__pathBuilder.ToString());
361442
}}
362443
363-
ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval);
444+
ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval{(hasSubprotocolAuth ? ", useSubprotocolAuth" : string.Empty)});
364445
await ConnectAsyncCore(__uri, connectTimeout, cancellationToken).ConfigureAwait(false);
365446
}}";
366447
}
@@ -449,6 +530,10 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
449530
xmlDoc.AppendLine(" /// <param name=\"additionalSubProtocols\">Additional WebSocket subprotocols applied before connecting.</param>");
450531
xmlDoc.AppendLine(" /// <param name=\"keepAliveInterval\">Optional keep-alive interval.</param>");
451532
xmlDoc.AppendLine(" /// <param name=\"connectTimeout\">Optional connect timeout.</param>");
533+
if (hasSubprotocolAuth)
534+
{
535+
xmlDoc.AppendLine(" /// <param name=\"useSubprotocolAuth\">When true, applies stored subprotocol authentication instead of header authentication.</param>");
536+
}
452537
xmlDoc.AppendLine(" /// <param name=\"cancellationToken\">A cancellation token.</param>");
453538

454539
foreach (var param in serializedParams)
@@ -492,7 +577,7 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
492577
global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols = null,
493578
global::System.TimeSpan? keepAliveInterval = null,
494579
global::System.TimeSpan? connectTimeout = null,
495-
global::System.Threading.CancellationToken cancellationToken = default)
580+
{(hasSubprotocolAuth ? "bool useSubprotocolAuth = false,\n " : string.Empty)}global::System.Threading.CancellationToken cancellationToken = default)
496581
{{
497582
global::System.Uri __uri;
498583
if (uri is not null)
@@ -507,7 +592,7 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
507592
__uri = new global::System.Uri(__pathBuilder.ToString());
508593
}}
509594
510-
ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval);
595+
ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval{(hasSubprotocolAuth ? ", useSubprotocolAuth" : string.Empty)});
511596
await ConnectAsyncCore(__uri, connectTimeout, cancellationToken).ConfigureAwait(false);
512597
}}" : string.Empty;
513598

@@ -526,7 +611,7 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
526611
global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols = null,
527612
global::System.TimeSpan? keepAliveInterval = null,
528613
global::System.TimeSpan? connectTimeout = null,
529-
global::System.Threading.CancellationToken cancellationToken = default)
614+
{(hasSubprotocolAuth ? "bool useSubprotocolAuth = false,\n " : string.Empty)}global::System.Threading.CancellationToken cancellationToken = default)
530615
{{
531616
global::System.Uri __uri;
532617
if (uri is not null)
@@ -544,8 +629,78 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
544629
__uri = new global::System.Uri(__pathBuilder.ToString());
545630
}}
546631
547-
ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval);
632+
ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval{(hasSubprotocolAuth ? ", useSubprotocolAuth" : string.Empty)});
548633
await ConnectAsyncCore(__uri, connectTimeout, cancellationToken).ConfigureAwait(false);
549634
}}";
550635
}
636+
637+
private static string GenerateStoredWebSocketAuthorization(
638+
WebSocketClient wsClient,
639+
bool hasStoredHeaderAuth,
640+
bool hasSubprotocolAuth)
641+
{
642+
if (!hasStoredHeaderAuth && !hasSubprotocolAuth)
643+
{
644+
return string.Empty;
645+
}
646+
647+
var subprotocolBlocks = wsClient.Authorizations
648+
.Where(static auth => !auth.WebSocketSubProtocols.IsEmpty)
649+
.Select(auth =>
650+
{
651+
var condition = auth.Parameters.Length == 0
652+
? "true"
653+
: string.Join(" && ", auth.Parameters.Select(parameter =>
654+
$@"_subprotocolAuthorizationValues.ContainsKey(""{parameter}"")"));
655+
var parameterReads = auth.Parameters.Select(parameter =>
656+
{
657+
var parameterName = parameter.ToParameterName();
658+
return $@" var __{parameterName} = _subprotocolAuthorizationValues[""{parameter}""];
659+
";
660+
}).Inject(emptyValue: string.Empty);
661+
var addProtocols = auth.WebSocketSubProtocols
662+
.Distinct(StringComparer.Ordinal)
663+
.Select(template =>
664+
{
665+
var replacements = auth.Parameters.Select(parameter =>
666+
{
667+
var parameterName = parameter.ToParameterName();
668+
return $@" __subProtocol = __subProtocol.Replace(""{{{parameter}}}"", __{parameterName});
669+
";
670+
}).Inject(emptyValue: string.Empty);
671+
return $@" var __subProtocol = {template.ToCSharpStringLiteral()};
672+
{replacements} _clientWebSocket.Options.AddSubProtocol(__subProtocol);
673+
";
674+
}).Inject(emptyValue: string.Empty);
675+
676+
return $@" if ({condition})
677+
{{
678+
{parameterReads}{addProtocols} return;
679+
}}
680+
";
681+
}).Inject(emptyValue: string.Empty);
682+
683+
var headerBlock = hasStoredHeaderAuth
684+
? @"
685+
if (_storedAuthorizationApiKey is not null &&
686+
_storedAuthorizationHeaderName is not null)
687+
{
688+
var __authorizationValue = _storedAuthorizationHeaderScheme is not null
689+
? $""{_storedAuthorizationHeaderScheme} {_storedAuthorizationApiKey}""
690+
: _storedAuthorizationApiKey;
691+
_clientWebSocket.Options.SetRequestHeader(_storedAuthorizationHeaderName, __authorizationValue);
692+
}"
693+
: string.Empty;
694+
695+
return $@"
696+
private void ApplyStoredAuthorization(
697+
bool useSubprotocolAuth)
698+
{{
699+
{(hasSubprotocolAuth ? $@" if (useSubprotocolAuth || _preferSubprotocolAuth)
700+
{{
701+
{subprotocolBlocks} return;
702+
}}
703+
" : string.Empty)}{headerBlock}
704+
}}";
705+
}
551706
}

0 commit comments

Comments
 (0)