@@ -2,16 +2,27 @@ import { createImplementation } from "@webiny/feature/api";
22import { generateText } from "ai" ;
33import { streamText } from "ai" ;
44import { Ai as AiAbstraction } from "./abstractions.js" ;
5- import { AiGateway } from "./abstractions.js" ;
5+ import { AiSdkFactory } from "./abstractions.js" ;
6+ import { AiConnectionFactory } from "./abstractions.js" ;
67import type { AiGenerateTextParams } from "./abstractions.js" ;
78import type { AiStreamTextParams } from "./abstractions.js" ;
9+ import type { IAiSdk } from "./abstractions.js" ;
10+ import type { IAiConnection } from "./abstractions.js" ;
11+ import type { IAiConnectionInline } from "./abstractions.js" ;
12+ import type { LanguageModel } from "ai" ;
813
914class AiImpl implements AiAbstraction . Interface {
10- constructor ( private readonly aiGateway : AiGateway . Interface ) { }
15+ private sdkCache = new Map < string , IAiSdk > ( ) ;
16+ private resolvedConnections : IAiConnection [ ] | null = null ;
17+
18+ constructor (
19+ private readonly sdkFactories : AiSdkFactory . Interface [ ] ,
20+ private readonly connectionFactories : AiConnectionFactory . Interface [ ]
21+ ) { }
1122
1223 generateText ( params : AiGenerateTextParams ) : ReturnType < typeof generateText > {
13- const { model, ...rest } = params ;
14- return this . aiGateway . languageModel ( model ) . then ( resolvedModel => {
24+ const { model, connection , ...rest } = params ;
25+ return this . resolveLanguageModel ( model , connection ) . then ( resolvedModel => {
1526 // Cast required: spreading the discriminated Prompt union loses its narrowing.
1627 return generateText ( { model : resolvedModel , ...rest } as Parameters <
1728 typeof generateText
@@ -20,15 +31,118 @@ class AiImpl implements AiAbstraction.Interface {
2031 }
2132
2233 async streamText ( params : AiStreamTextParams ) : Promise < ReturnType < typeof streamText > > {
23- const { model, ...rest } = params ;
24- const resolvedModel = await this . aiGateway . languageModel ( model ) ;
34+ const { model, connection , ...rest } = params ;
35+ const resolvedModel = await this . resolveLanguageModel ( model , connection ) ;
2536 // Cast required: spreading the discriminated Prompt union loses its narrowing.
2637 return streamText ( { model : resolvedModel , ...rest } as Parameters < typeof streamText > [ 0 ] ) ;
2738 }
39+
40+ async listModels ( connection ?: string | IAiConnectionInline ) : Promise < string [ ] > {
41+ if ( connection !== undefined ) {
42+ const conn = await this . resolveConnection ( undefined , connection ) ;
43+ const sdk = await this . getSdk ( conn ) ;
44+ return sdk . listModels ( ) . map ( m => `${ conn . sdkName } /${ m } ` ) ;
45+ }
46+
47+ const connections = await this . getConnections ( ) ;
48+ const results = await Promise . all (
49+ connections . map ( async conn => {
50+ const sdk = await this . getSdk ( conn ) ;
51+ return sdk . listModels ( ) . map ( m => `${ conn . sdkName } /${ m } ` ) ;
52+ } )
53+ ) ;
54+ return results . flat ( ) ;
55+ }
56+
57+ private async resolveLanguageModel (
58+ modelId : string ,
59+ connection ?: string | IAiConnectionInline
60+ ) : Promise < LanguageModel > {
61+ const slashIndex = modelId . indexOf ( "/" ) ;
62+ if ( slashIndex === - 1 ) {
63+ throw new Error (
64+ `Invalid model ID "${ modelId } ". Expected format: "<sdkName>/<modelId>" (e.g. "openai/gpt-4o").`
65+ ) ;
66+ }
67+
68+ const sdkName = modelId . slice ( 0 , slashIndex ) ;
69+ const modelName = modelId . slice ( slashIndex + 1 ) ;
70+
71+ const conn = await this . resolveConnection ( sdkName , connection ) ;
72+ const sdk = await this . getSdk ( conn ) ;
73+ return sdk . languageModel ( modelName ) ;
74+ }
75+
76+ private async getConnections ( ) : Promise < IAiConnection [ ] > {
77+ if ( ! this . resolvedConnections ) {
78+ this . resolvedConnections = await Promise . all (
79+ this . connectionFactories . map ( f => f . execute ( ) )
80+ ) ;
81+ }
82+ return this . resolvedConnections ;
83+ }
84+
85+ private async resolveConnection (
86+ sdkName : string | undefined ,
87+ connection ?: string | IAiConnectionInline
88+ ) : Promise < IAiConnectionInline > {
89+ if ( typeof connection === "object" ) {
90+ return connection ;
91+ }
92+
93+ const connections = await this . getConnections ( ) ;
94+
95+ if ( typeof connection === "string" ) {
96+ const found = connections . find ( c => c . id === connection ) ;
97+ if ( ! found ) {
98+ const known = connections . map ( c => `"${ c . id } "` ) . join ( ", " ) ;
99+ throw new Error (
100+ `Unknown AI connection "${ connection } ". Registered connections: ${ known } .`
101+ ) ;
102+ }
103+ return found ;
104+ }
105+
106+ const found = connections . find ( c => c . sdkName === sdkName ) ;
107+ if ( ! found ) {
108+ const known = connections . map ( c => `"${ c . id } " (${ c . sdkName } )` ) . join ( ", " ) ;
109+ throw new Error (
110+ `No AI connection found for SDK "${ sdkName } ". Registered connections: ${ known } .`
111+ ) ;
112+ }
113+ return found ;
114+ }
115+
116+ private async getSdk ( connection : IAiConnectionInline ) : Promise < IAiSdk > {
117+ const cacheKey =
118+ "id" in connection
119+ ? ( connection as IAiConnection ) . id
120+ : `${ connection . sdkName } :${ connection . apiKey ?? "__env__" } ` ;
121+
122+ const cached = this . sdkCache . get ( cacheKey ) ;
123+ if ( cached ) {
124+ return cached ;
125+ }
126+
127+ const factory = this . sdkFactories . find ( f => f . name === connection . sdkName ) ;
128+ if ( ! factory ) {
129+ const known = this . sdkFactories . map ( f => `"${ f . name } "` ) . join ( ", " ) ;
130+ throw new Error (
131+ `No AI SDK factory found for "${ connection . sdkName } ". Registered factories: ${ known } .`
132+ ) ;
133+ }
134+
135+ const sdk = await factory . execute ( connection . apiKey ) ;
136+ this . sdkCache . set ( cacheKey , sdk ) ;
137+ return sdk ;
138+ }
28139}
29140
30141export const Ai = createImplementation ( {
31142 abstraction : AiAbstraction ,
32143 implementation : AiImpl ,
33- dependencies : [ AiGateway ]
144+ dependencies : [
145+ [ AiSdkFactory , { multiple : true } ] ,
146+ [ AiConnectionFactory , { multiple : true } ]
147+ ]
34148} ) ;
0 commit comments