Skip to content

Commit 9765fdf

Browse files
committed
Replaced brittle prompt with pydantic schema validation
1 parent 838375f commit 9765fdf

File tree

2 files changed

+31
-31
lines changed

2 files changed

+31
-31
lines changed

GSoC25/NEF/NEF.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from urllib.parse import quote
1010
from getpass import getpass
1111

12+
from output_schema import DisambiguationResult
13+
1214
# =============== Gemini client bootstrap ===============
1315
try:
1416
from google import genai
@@ -280,13 +282,6 @@ def _safe_idx(lst: List[Any], value: Any) -> Optional[int]:
280282
i = 0
281283
return max(0, min(i, len(lst) - 1))
282284

283-
@staticmethod
284-
def _get_json(text: Optional[str]) -> Dict[str, Any]:
285-
try:
286-
return json.loads((text or "").strip() or "{}")
287-
except Exception:
288-
return {}
289-
290285
# --------------------- main entrypoint ---------------------
291286

292287
def disambiguate_triple(
@@ -332,46 +327,45 @@ def disambiguate_triple(
332327
subj_list_text = self._fmt_indexed(subject_candidates)
333328
obj_list_text = self._fmt_indexed(object_candidates)
334329

335-
prompt = f"""Pick the best RDF triple using ONLY these options.
330+
prompt = f"""
331+
Analyze the context and disambiguate the triple by selecting the correct Subject, Predicate, and Object.
336332
337-
Allowed predicate URIs:
338-
{pred_list_text}
333+
Context: {context}
339334
340-
Subject candidates (choose by INDEX):
341-
{subj_list_text}
335+
Allowed Predicates (Select one URI):
336+
{pred_list_text}
342337
343-
Object candidates (choose by INDEX):
344-
{obj_list_text}
338+
Subject Candidates (Select by Index):
339+
{subj_list_text}
345340
346-
Context (helps decide, but does NOT add new options):
347-
{context}
341+
Object Candidates (Select by Index):
342+
{obj_list_text}
343+
"""
348344

349-
Return ONLY strict JSON on one line (no prose):
350-
{{"subject_index": 0, "predicate": "URI", "object_index": 0}}
351-
Rules:
352-
- "predicate" MUST be exactly one URI from Allowed predicate URIs.
353-
- "subject_index" MUST be an integer index from Subject candidates.
354-
- "object_index" MUST be an integer index from Object candidates.
355-
- Do not invent or modify URIs. Do not swap roles.
356-
"""
345+
generation_config = {
346+
"response_mime_type": "application/json",
347+
"response_schema": DisambiguationResult # <--- The Pydantic class output
348+
}
357349

358350
# Call the model (Gemini client style)
359351
resp = self.client.models.generate_content(
360352
model=self.model_name,
361353
contents=prompt,
362-
config={"response_mime_type": "application/json"},
354+
config=generation_config
363355
)
364356

365-
data = self._get_json(getattr(resp, "text", None))
357+
data = DisambiguationResult.model_validate_json(resp.text)
358+
pred_uri = data.predicate_uri
359+
s_idx = data.subject_index
360+
o_idx = data.object_index
366361

367362
# Validate predicate
368-
pred_uri = data.get("predicate", "")
369363
if pred_uri not in allowed:
370-
pred_uri = allowed[0]
364+
pred_uri = None ## if none
371365

372366
# Clamp indices and map to URIs
373-
si = self._safe_idx(subject_candidates, data.get("subject_index", 0))
374-
oi = self._safe_idx(object_candidates, data.get("object_index", 0))
367+
si = self._safe_idx(subject_candidates, s_idx)
368+
oi = self._safe_idx(object_candidates, o_idx)
375369

376370
s_uri = subject_candidates[si][0] if (subject_candidates and si is not None) else ""
377371
o_uri = object_candidates[oi][0] if (object_candidates and oi is not None) else ""
@@ -392,7 +386,7 @@ def disambiguate_triple(
392386
# Build meta (compatible with your previous code)
393387
chosen_sim, rank0 = sim_map.get(pred_uri, (None, None))
394388
meta = {
395-
"label": "candidate",
389+
"label": "candidate" if pred_uri else "hallucinaiton_rejected",
396390
"chosen_similarity": float(chosen_sim) if chosen_sim is not None else None,
397391
"rank_in_topk": (rank0 + 1) if rank0 is not None else None,
398392
"topk": total_k,

GSoC25/NEF/output_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel, Field
2+
3+
class DisambiguationResult(BaseModel):
4+
subject_index: int = Field(description="The integer index of the correct subject from the provided list.")
5+
predicate_uri: str = Field(description="The exact URI of the predicate selected from the allowed list.")
6+
object_index: int = Field(description="The integer index of the correct object from the provided list.")

0 commit comments

Comments
 (0)