@@ -158,11 +158,40 @@ partial void OnReceiveException(
158158 private static string GenerateWebSocketAuthorizationConstructors ( WebSocketClient wsClient )
159159 {
160160 var result = new System . Text . StringBuilder ( ) ;
161+ var hasStoredHeaderAuth = wsClient . Authorizations . Any ( static auth =>
162+ auth . Type == SecuritySchemeType . Http &&
163+ ! string . Equals ( auth . Scheme , "basic" , StringComparison . OrdinalIgnoreCase ) ||
164+ auth . Type == SecuritySchemeType . ApiKey &&
165+ auth . In == ParameterLocation . Header ) ;
166+ var hasSubprotocolAuth = wsClient . Authorizations . Any ( static auth => ! auth . WebSocketSubProtocols . IsEmpty ) ;
167+ var bridgeApiKeyToSubprotocol = hasSubprotocolAuth &&
168+ wsClient . Authorizations . Any ( static auth =>
169+ ! auth . WebSocketSubProtocols . IsEmpty &&
170+ auth . Parameters . Contains ( "apiKey" ) ) ;
171+
172+ if ( hasStoredHeaderAuth )
173+ {
174+ result . AppendLine ( @"
175+ private string? _storedAuthorizationHeaderName;
176+ private string? _storedAuthorizationHeaderScheme;
177+ private string? _storedAuthorizationApiKey;" ) ;
178+ }
179+
180+ if ( hasSubprotocolAuth )
181+ {
182+ result . AppendLine ( @"
183+ private readonly global::System.Collections.Generic.Dictionary<string, string> _subprotocolAuthorizationValues = new global::System.Collections.Generic.Dictionary<string, string>(global::System.StringComparer.Ordinal);
184+ private bool _preferSubprotocolAuth;" ) ;
185+ }
161186
162187 foreach ( var auth in wsClient . Authorizations )
163188 {
164- if ( auth . Type == SecuritySchemeType . Http &&
165- ! string . Equals ( auth . Scheme , "basic" , StringComparison . OrdinalIgnoreCase ) )
189+ if ( ! auth . WebSocketSubProtocols . IsEmpty )
190+ {
191+ result . AppendLine ( GenerateSubprotocolAuthorizationMethod ( auth ) ) ;
192+ }
193+ else if ( auth . Type == SecuritySchemeType . Http &&
194+ ! string . Equals ( auth . Scheme , "basic" , StringComparison . OrdinalIgnoreCase ) )
166195 {
167196 var friendlyName = string . IsNullOrWhiteSpace ( auth . FriendlyName )
168197 ? auth . Scheme . ToPropertyName ( )
@@ -181,8 +210,10 @@ private static string GenerateWebSocketAuthorizationConstructors(WebSocketClient
181210 {{
182211 apiKey = apiKey ?? throw new global::System.ArgumentNullException(nameof(apiKey));
183212
184- _clientWebSocket.Options.SetRequestHeader(""Authorization"", $""{ schemeName } {{apiKey}}"");
185- }}
213+ _storedAuthorizationApiKey = apiKey;
214+ _storedAuthorizationHeaderName = ""Authorization"";
215+ _storedAuthorizationHeaderScheme = ""{ schemeName } "";
216+ { ( hasSubprotocolAuth ? " _preferSubprotocolAuth = false;\n " : string . Empty ) } { ( bridgeApiKeyToSubprotocol ? " _subprotocolAuthorizationValues[\" apiKey\" ] = apiKey;\n " : string . Empty ) } }}
186217
187218 /// <summary>
188219 /// Creates a new instance with { schemeName } token authentication.
@@ -214,8 +245,10 @@ public void AuthorizeUsingApiKeyInHeader(
214245 {{
215246 apiKey = apiKey ?? throw new global::System.ArgumentNullException(nameof(apiKey));
216247
217- _clientWebSocket.Options.SetRequestHeader(""{ auth . Name } "", apiKey);
218- }}
248+ _storedAuthorizationApiKey = apiKey;
249+ _storedAuthorizationHeaderName = ""{ auth . Name } "";
250+ _storedAuthorizationHeaderScheme = null;
251+ { ( hasSubprotocolAuth ? " _preferSubprotocolAuth = false;\n " : string . Empty ) } { ( bridgeApiKeyToSubprotocol ? " _subprotocolAuthorizationValues[\" apiKey\" ] = apiKey;\n " : string . Empty ) } }}
219252
220253 /// <summary>
221254 /// Creates a new instance with API key header authentication.
@@ -251,7 +284,7 @@ public void AuthorizeUsingApiKeyInQuery(
251284 // Query parameter auth is handled during ConnectAsync
252285 _queryApiKey = apiKey;
253286 _queryApiKeyName = ""{ auth . Name } "";
254- }}
287+ { ( bridgeApiKeyToSubprotocol ? " _subprotocolAuthorizationValues[ \" apiKey \" ] = apiKey; \n " : string . Empty ) } }}
255288
256289 private string? _queryApiKey;
257290 private string? _queryApiKeyName;" ;
@@ -262,13 +295,49 @@ public void AuthorizeUsingApiKeyInQuery(
262295 return result . ToString ( ) ;
263296 }
264297
298+ private static string GenerateSubprotocolAuthorizationMethod ( Authorization auth )
299+ {
300+ var xmlParams = auth . Parameters . Select ( parameter =>
301+ $@ " /// <param name=""{ parameter . ToParameterName ( ) } ""></param>") ;
302+ var parameterDeclarations = auth . Parameters . Select ( parameter =>
303+ $@ " string { parameter . ToParameterName ( ) } ") ;
304+ var assignments = auth . Parameters . Select ( parameter =>
305+ {
306+ var parameterName = parameter . ToParameterName ( ) ;
307+ return $@ " var { parameterName } Value = { parameterName } ?? throw new global::System.ArgumentNullException(nameof({ parameterName } ));
308+ _subprotocolAuthorizationValues[""{ parameter } ""] = { parameterName } Value;" ;
309+ } ) ;
310+ var signature = auth . Parameters . Length == 0
311+ ? string . Empty
312+ : $@ "
313+ { parameterDeclarations . Inject ( ) . TrimEnd ( ',' ) }
314+ " ;
315+
316+ return $@ "
317+ /// <summary>
318+ /// Authorize using WebSocket subprotocol authentication.
319+ /// </summary>
320+ { xmlParams . Inject ( emptyValue : string . Empty ) }
321+ public void { auth . MethodName } ({ signature } )
322+ {{
323+ { assignments . Inject ( emptyValue : string . Empty ) }
324+ _preferSubprotocolAuth = true;
325+ }}" ;
326+ }
327+
265328 private static string GenerateConnectAsync ( WebSocketClient wsClient )
266329 {
267330 var serverVariables = wsClient . ServerVariables . ToArray ( ) ;
268331 var signatureParams = wsClient . QueryParameters . ToArray ( ) ;
269332 var serializedParams = wsClient . SerializedQueryParameters . ToArray ( ) ;
270333 var hasTypedParameters = serverVariables . Length > 0 || signatureParams . Length > 0 ;
271334 var hasRequiredParameters = serverVariables . Any ( static x => x . IsRequired ) || signatureParams . Any ( static x => x . IsRequired ) ;
335+ var hasSubprotocolAuth = wsClient . Authorizations . Any ( static x => ! x . WebSocketSubProtocols . IsEmpty ) ;
336+ var hasStoredHeaderAuth = wsClient . Authorizations . Any ( static x =>
337+ x . Type == SecuritySchemeType . Http &&
338+ ! string . Equals ( x . Scheme , "basic" , StringComparison . OrdinalIgnoreCase ) ||
339+ x . Type == SecuritySchemeType . ApiKey &&
340+ x . In == ParameterLocation . Header ) ;
272341 var supportsQueryApiKeyAuth = wsClient . Authorizations . Any (
273342 static x => x . Type == SecuritySchemeType . ApiKey &&
274343 x . In == ParameterLocation . Query ) ;
@@ -279,18 +348,30 @@ private static string GenerateConnectAsync(WebSocketClient wsClient)
279348 __pathBuilder.AddRequiredParameter(_queryApiKeyName, _queryApiKey);
280349 }"
281350 : string . Empty ;
351+ var applyStoredAuthorization = GenerateStoredWebSocketAuthorization (
352+ wsClient ,
353+ hasStoredHeaderAuth ,
354+ hasSubprotocolAuth ) ;
355+ var connectionOptionsExtraParameter = hasSubprotocolAuth
356+ ? ",\n bool useSubprotocolAuth"
357+ : string . Empty ;
358+ var applyStoredAuthorizationCall = hasStoredHeaderAuth || hasSubprotocolAuth
359+ ? $ " ApplyStoredAuthorization({ ( hasSubprotocolAuth ? "useSubprotocolAuth" : "false" ) } );\n "
360+ : string . Empty ;
282361
283362 var connectionHelpers = @"
363+ " + applyStoredAuthorization + @"
284364 private void ApplyConnectionOptions(
285365 global::System.Collections.Generic.IDictionary<string, string>? additionalHeaders,
286366 global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols,
287- global::System.TimeSpan? keepAliveInterval)
367+ global::System.TimeSpan? keepAliveInterval" + connectionOptionsExtraParameter + @" )
288368 {
289369 if (keepAliveInterval is not null)
290370 {
291371 _clientWebSocket.Options.KeepAliveInterval = keepAliveInterval.Value;
292372 }
293373
374+ " + applyStoredAuthorizationCall + @"
294375 if (additionalHeaders is not null)
295376 {
296377 foreach (var header in additionalHeaders)
@@ -345,7 +426,7 @@ private void ApplyConnectionOptions(
345426 global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols = null,
346427 global::System.TimeSpan? keepAliveInterval = null,
347428 global::System.TimeSpan? connectTimeout = null,
348- global::System.Threading.CancellationToken cancellationToken = default)
429+ { ( hasSubprotocolAuth ? "bool useSubprotocolAuth = false, \n " : string . Empty ) } global::System.Threading.CancellationToken cancellationToken = default)
349430 {{
350431 global::System.Uri __uri;
351432 if (uri is not null)
@@ -360,7 +441,7 @@ private void ApplyConnectionOptions(
360441 __uri = new global::System.Uri(__pathBuilder.ToString());
361442 }}
362443
363- ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval);
444+ ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval{ ( hasSubprotocolAuth ? ", useSubprotocolAuth" : string . Empty ) } );
364445 await ConnectAsyncCore(__uri, connectTimeout, cancellationToken).ConfigureAwait(false);
365446 }}" ;
366447 }
@@ -449,6 +530,10 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
449530 xmlDoc . AppendLine ( " /// <param name=\" additionalSubProtocols\" >Additional WebSocket subprotocols applied before connecting.</param>" ) ;
450531 xmlDoc . AppendLine ( " /// <param name=\" keepAliveInterval\" >Optional keep-alive interval.</param>" ) ;
451532 xmlDoc . AppendLine ( " /// <param name=\" connectTimeout\" >Optional connect timeout.</param>" ) ;
533+ if ( hasSubprotocolAuth )
534+ {
535+ xmlDoc . AppendLine ( " /// <param name=\" useSubprotocolAuth\" >When true, applies stored subprotocol authentication instead of header authentication.</param>" ) ;
536+ }
452537 xmlDoc . AppendLine ( " /// <param name=\" cancellationToken\" >A cancellation token.</param>" ) ;
453538
454539 foreach ( var param in serializedParams )
@@ -492,7 +577,7 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
492577 global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols = null,
493578 global::System.TimeSpan? keepAliveInterval = null,
494579 global::System.TimeSpan? connectTimeout = null,
495- global::System.Threading.CancellationToken cancellationToken = default)
580+ { ( hasSubprotocolAuth ? "bool useSubprotocolAuth = false, \n " : string . Empty ) } global::System.Threading.CancellationToken cancellationToken = default)
496581 {{
497582 global::System.Uri __uri;
498583 if (uri is not null)
@@ -507,7 +592,7 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
507592 __uri = new global::System.Uri(__pathBuilder.ToString());
508593 }}
509594
510- ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval);
595+ ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval{ ( hasSubprotocolAuth ? ", useSubprotocolAuth" : string . Empty ) } );
511596 await ConnectAsyncCore(__uri, connectTimeout, cancellationToken).ConfigureAwait(false);
512597 }}" : string . Empty ;
513598
@@ -526,7 +611,7 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
526611 global::System.Collections.Generic.IEnumerable<string>? additionalSubProtocols = null,
527612 global::System.TimeSpan? keepAliveInterval = null,
528613 global::System.TimeSpan? connectTimeout = null,
529- global::System.Threading.CancellationToken cancellationToken = default)
614+ { ( hasSubprotocolAuth ? "bool useSubprotocolAuth = false, \n " : string . Empty ) } global::System.Threading.CancellationToken cancellationToken = default)
530615 {{
531616 global::System.Uri __uri;
532617 if (uri is not null)
@@ -544,8 +629,78 @@ void AppendParameter(MethodParameter param, bool preferNonNullableDefaultType)
544629 __uri = new global::System.Uri(__pathBuilder.ToString());
545630 }}
546631
547- ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval);
632+ ApplyConnectionOptions(additionalHeaders, additionalSubProtocols, keepAliveInterval{ ( hasSubprotocolAuth ? ", useSubprotocolAuth" : string . Empty ) } );
548633 await ConnectAsyncCore(__uri, connectTimeout, cancellationToken).ConfigureAwait(false);
549634 }}" ;
550635 }
636+
637+ private static string GenerateStoredWebSocketAuthorization (
638+ WebSocketClient wsClient ,
639+ bool hasStoredHeaderAuth ,
640+ bool hasSubprotocolAuth )
641+ {
642+ if ( ! hasStoredHeaderAuth && ! hasSubprotocolAuth )
643+ {
644+ return string . Empty ;
645+ }
646+
647+ var subprotocolBlocks = wsClient . Authorizations
648+ . Where ( static auth => ! auth . WebSocketSubProtocols . IsEmpty )
649+ . Select ( auth =>
650+ {
651+ var condition = auth . Parameters . Length == 0
652+ ? "true"
653+ : string . Join ( " && " , auth . Parameters . Select ( parameter =>
654+ $@ "_subprotocolAuthorizationValues.ContainsKey(""{ parameter } "")") ) ;
655+ var parameterReads = auth . Parameters . Select ( parameter =>
656+ {
657+ var parameterName = parameter . ToParameterName ( ) ;
658+ return $@ " var __{ parameterName } = _subprotocolAuthorizationValues[""{ parameter } ""];
659+ " ;
660+ } ) . Inject ( emptyValue : string . Empty ) ;
661+ var addProtocols = auth . WebSocketSubProtocols
662+ . Distinct ( StringComparer . Ordinal )
663+ . Select ( template =>
664+ {
665+ var replacements = auth . Parameters . Select ( parameter =>
666+ {
667+ var parameterName = parameter . ToParameterName ( ) ;
668+ return $@ " __subProtocol = __subProtocol.Replace(""{{{parameter}}}"", __{ parameterName } );
669+ " ;
670+ } ) . Inject ( emptyValue : string . Empty ) ;
671+ return $@ " var __subProtocol = { template . ToCSharpStringLiteral ( ) } ;
672+ { replacements } _clientWebSocket.Options.AddSubProtocol(__subProtocol);
673+ " ;
674+ } ) . Inject ( emptyValue : string . Empty ) ;
675+
676+ return $@ " if ({ condition } )
677+ {{
678+ { parameterReads } { addProtocols } return;
679+ }}
680+ " ;
681+ } ) . Inject ( emptyValue : string . Empty ) ;
682+
683+ var headerBlock = hasStoredHeaderAuth
684+ ? @"
685+ if (_storedAuthorizationApiKey is not null &&
686+ _storedAuthorizationHeaderName is not null)
687+ {
688+ var __authorizationValue = _storedAuthorizationHeaderScheme is not null
689+ ? $""{_storedAuthorizationHeaderScheme} {_storedAuthorizationApiKey}""
690+ : _storedAuthorizationApiKey;
691+ _clientWebSocket.Options.SetRequestHeader(_storedAuthorizationHeaderName, __authorizationValue);
692+ }"
693+ : string . Empty ;
694+
695+ return $@ "
696+ private void ApplyStoredAuthorization(
697+ bool useSubprotocolAuth)
698+ {{
699+ { ( hasSubprotocolAuth ? $@ " if (useSubprotocolAuth || _preferSubprotocolAuth)
700+ {{
701+ { subprotocolBlocks } return;
702+ }}
703+ " : string . Empty ) } { headerBlock }
704+ }}" ;
705+ }
551706}
0 commit comments