Skip to content

Commit c1e50d5

Browse files
committed
Fix dia2 auto-install: clone from GitHub + editable pip install
dia2 is not on PyPI and has a packaging bug (missing subpackages). Auto-installer now clones the repo and uses pip install -e --no-deps. Also added version-agnostic TemplateResponse wrapper for Starlette 0.x/1.x compatibility (dia2 deps can downgrade starlette).
1 parent 252db2d commit c1e50d5

File tree

3 files changed

+56
-16
lines changed

3 files changed

+56
-16
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ env/
3636
outputs/
3737
reference_audio/
3838
model_cache/ # Also good practice to ignore the model cache
39+
dia2_src/ # Auto-cloned dia2 source (installed via editable pip)
3940

4041
# Ignore test reports/coverage
4142
.coverage

engine.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -382,19 +382,51 @@ def get_compute_dtype(device: torch.device, weights_filename: str) -> str:
382382

383383

384384
def _auto_install_dia2():
385-
"""Attempts to install the dia2 package via pip at runtime."""
385+
"""
386+
Attempts to install the dia2 package from GitHub at runtime.
387+
dia2 is NOT on PyPI — must be cloned and installed in editable mode
388+
because the pyproject.toml has a packaging bug (missing subpackages).
389+
"""
386390
global DIA2_AVAILABLE, Dia2, GenerationConfig, SamplingConfig, PrefixConfig, GenerationResult
387391
import subprocess
388392
import sys
389393

394+
DIA2_REPO_URL = "https://github.com/nari-labs/dia2.git"
395+
dia2_src_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "dia2_src")
396+
390397
try:
391398
_check_cancelled()
399+
400+
# Step 1: Clone the dia2 repo if not already present
401+
if not os.path.isdir(dia2_src_dir):
402+
logger.info(f"Cloning dia2 repository from {DIA2_REPO_URL}...")
403+
_update_download_status("downloading", "Cloning dia2 repository from GitHub...", 10)
404+
result = subprocess.run(
405+
["git", "clone", DIA2_REPO_URL, dia2_src_dir],
406+
capture_output=True, text=True, timeout=300,
407+
)
408+
if result.returncode != 0:
409+
raise RuntimeError(f"git clone failed:\n{result.stderr}")
410+
logger.info("dia2 repository cloned successfully.")
411+
else:
412+
logger.info(f"dia2 source already exists at {dia2_src_dir}, pulling latest...")
413+
_update_download_status("downloading", "Updating dia2 repository...", 10)
414+
subprocess.run(
415+
["git", "-C", dia2_src_dir, "pull", "--ff-only"],
416+
capture_output=True, text=True, timeout=60,
417+
)
418+
419+
_check_cancelled()
420+
421+
# Step 2: Install in editable mode with --no-deps to avoid breaking other packages
422+
logger.info("Installing dia2 package (editable mode, no-deps)...")
423+
_update_download_status("installing", "Installing dia2 package...", 25)
392424
result = subprocess.run(
393-
[sys.executable, "-m", "pip", "install", "dia2"],
394-
capture_output=True, text=True, timeout=600,
425+
[sys.executable, "-m", "pip", "install", "-e", dia2_src_dir, "--no-deps"],
426+
capture_output=True, text=True, timeout=300,
395427
)
396428
if result.returncode != 0:
397-
raise RuntimeError(f"pip install dia2 failed:\n{result.stderr}")
429+
raise RuntimeError(f"pip install -e dia2 failed:\n{result.stderr}")
398430

399431
logger.info("dia2 package installed successfully. Importing...")
400432
from dia2 import Dia2 as _Dia2, GenerationConfig as _GC, SamplingConfig as _SC, PrefixConfig as _PC, GenerationResult as _GR
@@ -409,7 +441,7 @@ def _auto_install_dia2():
409441
logger.error(f"Failed to auto-install dia2: {e}", exc_info=True)
410442
raise ImportError(
411443
f"dia2 package could not be installed automatically: {e}. "
412-
"Please install it manually: pip install dia2"
444+
"Please install manually: git clone https://github.com/nari-labs/dia2.git dia2_src && pip install -e dia2_src --no-deps"
413445
)
414446

415447

server.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,20 @@ def _delayed_browser_open(host, port):
221221
templates = Jinja2Templates(directory="ui")
222222

223223

224+
def _render_template(request, name, context, status_code=200):
225+
"""Version-agnostic TemplateResponse wrapper for Starlette 0.x and 1.x."""
226+
import starlette
227+
major = int(starlette.__version__.split(".")[0])
228+
if major >= 1:
229+
# Starlette 1.0+: request is a separate positional arg
230+
return templates.TemplateResponse(request, name, context=context, status_code=status_code)
231+
else:
232+
# Starlette 0.x: request goes inside the context dict
233+
ctx = dict(context) if context else {}
234+
ctx["request"] = request
235+
return templates.TemplateResponse(name, ctx, status_code=status_code)
236+
237+
224238
# --- Configuration Routes (New YAML-based) ---
225239
@app.post(
226240
"/save_settings",
@@ -831,7 +845,7 @@ async def get_web_ui(request: Request):
831845
"success": None,
832846
}
833847

834-
return templates.TemplateResponse(request, "index.html", context=template_context)
848+
return _render_template(request, "index.html", template_context)
835849

836850
except Exception as e:
837851
logger.error(f"Error rendering Web UI: {e}", exc_info=True)
@@ -909,7 +923,7 @@ async def handle_web_ui_generate(
909923
"output_file_url": None,
910924
"generation_time": None,
911925
}
912-
return templates.TemplateResponse(request, "index.html", context=error_context, status_code=503)
926+
return _render_template(request, "index.html", error_context, status_code=503)
913927

914928
# --- Start processing the valid request ---
915929
logger.info(
@@ -1052,9 +1066,7 @@ async def handle_web_ui_generate(
10521066
"output_file_url": None,
10531067
"generation_time": None,
10541068
}
1055-
return templates.TemplateResponse(
1056-
request, "index.html", context=error_context, status_code=400
1057-
) # Bad Request
1069+
return _render_template(request, "index.html", error_context, status_code=400)
10581070

10591071
# --- Generation ---
10601072
try:
@@ -1218,12 +1230,7 @@ async def handle_web_ui_generate(
12181230
}
12191231

12201232
# Render and return the HTML response
1221-
return templates.TemplateResponse(
1222-
request,
1223-
"index.html",
1224-
context=template_context,
1225-
status_code=status_code,
1226-
)
1233+
return _render_template(request, "index.html", template_context, status_code=status_code)
12271234

12281235

12291236
# --- Reference Audio Upload Endpoint ---

0 commit comments

Comments
 (0)