Skip to content

Commit bb05a2e

Browse files
committed
Add support for official models (server side)
1 parent 3155811 commit bb05a2e

2 files changed

Lines changed: 107 additions & 5 deletions

File tree

mod.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
     "huggingface-api-key": { "type": "string", "default": "" },
     "use-platinum": { "type": "bool", "default": false },
     "ollama-model": { "type": "string", "default": "entity12208/editorai:deepseek" },
+    "local-url": { "type": "string", "default": "http://localhost:11435" },
+    "local-model": { "type": "string", "default": "editorai" },
     "ollama-timeout": { "type": "int", "default": 600, "min": 60, "max": 1800 },
     "enable-rate-limiting": { "type": "bool", "default": true },
     "rate-limit-seconds": { "type": "int", "default": 3, "min": 1, "max": 60 },

src/main.cpp

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ static std::string getProviderApiKey(const std::string& provider) {
     if (provider == "openai") return Mod::get()->getSettingValue<std::string>("openai-api-key");
     if (provider == "ministral") return Mod::get()->getSettingValue<std::string>("ministral-api-key");
     if (provider == "huggingface") return Mod::get()->getSettingValue<std::string>("huggingface-api-key");
-    return ""; // ollama — no key needed
+    return ""; // ollama / local — no key needed
 }
 
 static std::string getProviderModel(const std::string& provider) {
@@ -275,6 +275,7 @@ static std::string getProviderModel(const std::string& provider) {
     if (provider == "ministral") return Mod::get()->getSettingValue<std::string>("ministral-model");
     if (provider == "huggingface") return Mod::get()->getSettingValue<std::string>("huggingface-model");
     if (provider == "ollama") return Mod::get()->getSettingValue<std::string>("ollama-model");
+    if (provider == "local") return Mod::get()->getSettingValue<std::string>("local-model");
     return "unknown";
 }
 
@@ -285,6 +286,10 @@ static std::string getOllamaUrl() {
         : "http://localhost:11434";
 }
 
+static std::string getLocalUrl() {
+    return Mod::get()->getSettingValue<std::string>("local-url");
+}
+
 // ─── Deferred object struct ───────────────────────────────────────────────────
 
 struct DeferredObject {
@@ -799,7 +804,9 @@ class AISettingsPopup : public Popup {
     std::vector<IntInputMeta> m_intInputs;
     std::vector<CCMenuItemSpriteExtra*> m_tabBtns;
     async::TaskHolder<web::WebResponse> m_ollamaListener;
+    async::TaskHolder<web::WebResponse> m_localListener;
     std::vector<std::string> m_ollamaModels; // auto-detected
+    std::vector<std::string> m_localModels; // auto-detected from local server
 
     float m_rowY = 0;
     float m_labelX = 0;
@@ -1023,7 +1030,7 @@ class AISettingsPopup : public Popup {
 
     void buildGeneralTab() {
         addCycler("Provider:", "ai-provider",
-            {"gemini", "claude", "openai", "ministral", "huggingface", "ollama"});
+            {"gemini", "claude", "openai", "ministral", "huggingface", "ollama", "local"});
         addCycler("Difficulty:", "difficulty",
             {"easy", "medium", "hard", "extreme"});
         addCycler("Style:", "style",
@@ -1074,10 +1081,30 @@ class AISettingsPopup : public Popup {
             }
 
             addIntRow("Timeout (s):", "ollama-timeout", 60, 1800, 600);
+        } else if (provider == "local") {
+            addTextRow("Server URL:", "local-url", "http://localhost:11435", 200);
+
+            // Auto-detect models from local server
+            if (m_localModels.empty()) {
+                addInfoRow("Model:", "Detecting...", {200, 200, 100});
+                fetchLocalModels();
+            } else if (m_localModels.size() == 1 && m_localModels[0].rfind("(", 0) == 0) {
+                addInfoRow("Model:", m_localModels[0], {255, 100, 100});
+            } else {
+                addCycler("Model:", "local-model", m_localModels);
+            }
+
+            auto hint = CCLabelBMFont::create(
+                "Your own trained model. Run: python training/server/serve.py",
+                "bigFont.fnt");
+            hint->setScale(0.18f);
+            hint->setColor({150, 150, 150});
+            hint->setPosition({this->m_size.width / 2, m_rowY - 8});
+            m_contentLayer->addChild(hint);
         }
 
         // Hint about key security
-        if (provider != "ollama") {
+        if (provider != "ollama" && provider != "local") {
             auto hint = CCLabelBMFont::create("Keys stored locally in Geode save data.", "bigFont.fnt");
             hint->setScale(0.2f);
             hint->setColor({150, 150, 150});
@@ -1162,6 +1189,56 @@ class AISettingsPopup : public Popup {
         );
     }
 
+    void fetchLocalModels() {
+        // Save text inputs first so local-url is committed
+        saveTextInputs();
+        std::string localUrl = Mod::get()->getSettingValue<std::string>("local-url");
+        log::info("Fetching models from Local AI server: {}/api/tags", localUrl);
+
+        auto request = web::WebRequest();
+        request.timeout(std::chrono::seconds(5));
+
+        m_localListener.spawn(
+            request.get(localUrl + "/api/tags"),
+            [this](web::WebResponse response) {
+                if (!response.ok()) {
+                    int code = response.code();
+                    log::warn("Local AI server not reachable: HTTP {}", code);
+                    m_localModels = code == 0
+                        ? std::vector<std::string>{"(server not running)"}
+                        : std::vector<std::string>{fmt::format("(error: HTTP {})", code)};
+                    buildTab(SettingsTab::Provider);
+                    return;
+                }
+
+                auto jsonRes = response.json();
+                if (!jsonRes) {
+                    m_localModels = {"(invalid response)"};
+                    buildTab(SettingsTab::Provider);
+                    return;
+                }
+
+                auto json = jsonRes.unwrap();
+                m_localModels.clear();
+
+                if (json.contains("models") && json["models"].isArray()) {
+                    for (size_t i = 0; i < json["models"].size(); ++i) {
+                        auto nameResult = json["models"][i]["name"].asString();
+                        if (nameResult)
+                            m_localModels.push_back(nameResult.unwrap());
+                    }
+                }
+
+                if (m_localModels.empty()) {
+                    m_localModels = {"(no models loaded)"};
+                }
+
+                log::info("Detected {} models from Local AI server", m_localModels.size());
+                buildTab(SettingsTab::Provider);
+            }
+        );
+    }
+
     // ── Cycler callbacks ────────────────────────────────────────────────
 
     void onCycleSetting(CCObject* sender) {
@@ -1196,6 +1273,7 @@ class AISettingsPopup : public Popup {
         // If provider changed, rebuild Provider tab data
         if (info.settingId == "ai-provider" && m_currentTab == SettingsTab::Provider) {
             m_ollamaModels.clear();
+            m_localModels.clear();
             buildTab(SettingsTab::Provider);
         }
 
@@ -1447,6 +1525,8 @@ class AIGeneratorPopup : public Popup {
         if (provider == "ollama") {
             bool usePlatinum = Mod::get()->getSettingValue<bool>("use-platinum");
             keyStatus = usePlatinum ? "<cg>Platinum cloud</c>" : "<cg>Local — no key needed</c>";
+        } else if (provider == "local") {
+            keyStatus = "<cg>Local AI — no key needed</c>";
         } else {
             keyStatus = apiKey.empty()
                 ? "<cr>Not set — go to mod settings</c>"
@@ -2570,6 +2650,24 @@ class AIGeneratorPopup : public Popup {
             requestBody["options"] = options;
 
             url = ollamaUrl + "/api/generate";
+
+        // ── Local AI (EditorAI trained model) ─────────────────────────────────
+        // Uses the same NDJSON streaming API as Ollama, served by training/server/serve.py
+        } else if (provider == "local") {
+            std::string localUrl = getLocalUrl();
+            log::info("Using Local AI at: {}", localUrl + "/api/generate");
+
+            auto options = matjson::Value::object();
+            options["temperature"] = 0.7;
+
+            requestBody = matjson::Value::object();
+            requestBody["model"] = model;
+            requestBody["prompt"] = systemPrompt + "\n\n" + fullPrompt;
+            requestBody["stream"] = true;
+            requestBody["format"] = "json";
+            requestBody["options"] = options;
+
+            url = localUrl + "/api/generate";
         }
 
         std::string jsonBody = requestBody.dump();
@@ -2590,6 +2688,8 @@ class AIGeneratorPopup : public Popup {
             // Ollama can be very slow on large models with partial GPU offload.
             int timeoutSec = (int)Mod::get()->getSettingValue<int64_t>("ollama-timeout");
             request.timeout(std::chrono::seconds(timeoutSec));
+        } else if (provider == "local") {
+            request.timeout(std::chrono::seconds(300));
         }
 
         request.bodyString(jsonBody);
@@ -2647,7 +2747,7 @@ class AIGeneratorPopup : public Popup {
         std::string provider = Mod::get()->getSettingValue<std::string>("ai-provider");
         std::string apiKey = getProviderApiKey(provider);
 
-        if (apiKey.empty() && provider != "ollama") {
+        if (apiKey.empty() && provider != "ollama" && provider != "local") {
             FLAlertLayer::create("API Key Required",
                 gd::string(fmt::format(
                     "Please open mod settings and enter your API key under the {} section.",
@@ -2712,7 +2812,7 @@ class AIGeneratorPopup : public Popup {
         // and always fails on streaming output. We must parse line by line,
         // accumulate all "response" fields, and verify "done":true on the
        // final line.
-        if (provider == "ollama") {
+        if (provider == "ollama" || provider == "local") {
             auto rawResult = response.string();
             if (!rawResult) {
                 onError("Invalid Response", "The API returned invalid data.");

0 commit comments

Comments
 (0)