Skip to content

Commit f1908da

Browse files
committed
Replaced brittle prompt with pydantic schema validation
1 parent 838375f commit f1908da

File tree

3 files changed

+44
-33
lines changed

3 files changed

+44
-33
lines changed

GSoC25/NEF/NEF.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from urllib.parse import quote
1010
from getpass import getpass
1111

12+
try:
13+
from .output_schema import DisambiguationResult
14+
except ImportError:
15+
from output_schema import DisambiguationResult
16+
1217
# =============== Gemini client bootstrap ===============
1318
try:
1419
from google import genai
@@ -280,13 +285,6 @@ def _safe_idx(lst: List[Any], value: Any) -> Optional[int]:
280285
i = 0
281286
return max(0, min(i, len(lst) - 1))
282287

283-
@staticmethod
284-
def _get_json(text: Optional[str]) -> Dict[str, Any]:
285-
try:
286-
return json.loads((text or "").strip() or "{}")
287-
except Exception:
288-
return {}
289-
290288
# --------------------- main entrypoint ---------------------
291289

292290
def disambiguate_triple(
@@ -332,46 +330,53 @@ def disambiguate_triple(
332330
subj_list_text = self._fmt_indexed(subject_candidates)
333331
obj_list_text = self._fmt_indexed(object_candidates)
334332

335-
prompt = f"""Pick the best RDF triple using ONLY these options.
333+
prompt = f"""
334+
Analyze the context and disambiguate the triple by selecting the correct Subject, Predicate, and Object.
336335
337-
Allowed predicate URIs:
338-
{pred_list_text}
336+
Context: {context}
339337
340-
Subject candidates (choose by INDEX):
341-
{subj_list_text}
338+
Allowed Predicates (Select one URI):
339+
{pred_list_text}
342340
343-
Object candidates (choose by INDEX):
344-
{obj_list_text}
341+
Subject Candidates (Select by Index):
342+
{subj_list_text}
345343
346-
Context (helps decide, but does NOT add new options):
347-
{context}
344+
Object Candidates (Select by Index):
345+
{obj_list_text}
346+
"""
348347

349-
Return ONLY strict JSON on one line (no prose):
350-
{{"subject_index": 0, "predicate": "URI", "object_index": 0}}
351-
Rules:
352-
- "predicate" MUST be exactly one URI from Allowed predicate URIs.
353-
- "subject_index" MUST be an integer index from Subject candidates.
354-
- "object_index" MUST be an integer index from Object candidates.
355-
- Do not invent or modify URIs. Do not swap roles.
356-
"""
348+
generation_config = types.GenerateContentConfig(
349+
response_mime_type = "application/json",
350+
response_schema = DisambiguationResult # <--- The Pydantic class the JSON reply must conform to (field name is `response_schema`; unknown keywords are rejected by the SDK config model)
351+
)
352+
357353

358354
# Call the model (Gemini client style)
359355
resp = self.client.models.generate_content(
360356
model=self.model_name,
361357
contents=prompt,
362-
config={"response_mime_type": "application/json"},
358+
config=generation_config
363359
)
364360

365-
data = self._get_json(getattr(resp, "text", None))
361+
try:
362+
data = DisambiguationResult.model_validate_json(resp.text)
363+
pred_uri = data.predicate_uri
364+
s_idx = data.subject_index
365+
o_idx = data.object_index
366+
367+
# Validate predicate
368+
if pred_uri not in allowed:
369+
pred_uri = None ## if none
366370

367-
# Validate predicate
368-
pred_uri = data.get("predicate", "")
369-
if pred_uri not in allowed:
370-
pred_uri = allowed[0]
371+
except Exception as e:
372+
print(f"validation error{e}")
373+
pred_uri=None
374+
s_idx=0
375+
o_idx=0
371376

372377
# Clamp indices and map to URIs
373-
si = self._safe_idx(subject_candidates, data.get("subject_index", 0))
374-
oi = self._safe_idx(object_candidates, data.get("object_index", 0))
378+
si = self._safe_idx(subject_candidates, s_idx)
379+
oi = self._safe_idx(object_candidates, o_idx)
375380

376381
s_uri = subject_candidates[si][0] if (subject_candidates and si is not None) else ""
377382
o_uri = object_candidates[oi][0] if (object_candidates and oi is not None) else ""
@@ -392,7 +397,7 @@ def disambiguate_triple(
392397
# Build meta (compatible with your previous code)
393398
chosen_sim, rank0 = sim_map.get(pred_uri, (None, None))
394399
meta = {
395-
"label": "candidate",
400+
"label": "candidate" if pred_uri else "hallucination_rejected",
396401
"chosen_similarity": float(chosen_sim) if chosen_sim is not None else None,
397402
"rank_in_topk": (rank0 + 1) if rank0 is not None else None,
398403
"topk": total_k,

GSoC25/NEF/__init__.py

Whitespace-only changes.

GSoC25/NEF/output_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel, Field
2+
3+
class DisambiguationResult(BaseModel):
4+
subject_index: int = Field(description="The integer index of the correct subject from the provided list.")
5+
predicate_uri: str = Field(description="The exact URI of the predicate selected from the allowed list.")
6+
object_index: int = Field(description="The integer index of the correct object from the provided list.")

0 commit comments

Comments
 (0)