Skip to content

Commit ea527ae

Browse files
committed
Replaced brittle prompt with pydantic schema validation
1 parent 838375f commit ea527ae

File tree

3 files changed

+48
-33
lines changed

3 files changed

+48
-33
lines changed

GSoC25/NEF/NEF.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
from urllib.parse import quote
1010
from getpass import getpass
1111

12+
from pydantic import ValidationError
13+
try:
14+
from .output_schema import DisambiguationResult
15+
except ImportError:
16+
from output_schema import DisambiguationResult
17+
1218
# =============== Gemini client bootstrap ===============
1319
try:
1420
from google import genai
@@ -280,13 +286,6 @@ def _safe_idx(lst: List[Any], value: Any) -> Optional[int]:
280286
i = 0
281287
return max(0, min(i, len(lst) - 1))
282288

283-
@staticmethod
284-
def _get_json(text: Optional[str]) -> Dict[str, Any]:
285-
try:
286-
return json.loads((text or "").strip() or "{}")
287-
except Exception:
288-
return {}
289-
290289
# --------------------- main entrypoint ---------------------
291290

292291
def disambiguate_triple(
@@ -332,46 +331,56 @@ def disambiguate_triple(
332331
subj_list_text = self._fmt_indexed(subject_candidates)
333332
obj_list_text = self._fmt_indexed(object_candidates)
334333

335-
prompt = f"""Pick the best RDF triple using ONLY these options.
334+
prompt = f"""
335+
Analyze the context and disambiguate the triple by selecting the correct Subject, Predicate, and Object.
336336
337-
Allowed predicate URIs:
338-
{pred_list_text}
337+
Context: {context}
339338
340-
Subject candidates (choose by INDEX):
341-
{subj_list_text}
339+
Allowed Predicates (Select one URI):
340+
{pred_list_text}
342341
343-
Object candidates (choose by INDEX):
344-
{obj_list_text}
342+
Subject Candidates (Select by Index):
343+
{subj_list_text}
345344
346-
Context (helps decide, but does NOT add new options):
347-
{context}
345+
Object Candidates (Select by Index):
346+
{obj_list_text}
347+
"""
348348

349-
Return ONLY strict JSON on one line (no prose):
350-
{{"subject_index": 0, "predicate": "URI", "object_index": 0}}
351-
Rules:
352-
- "predicate" MUST be exactly one URI from Allowed predicate URIs.
353-
- "subject_index" MUST be an integer index from Subject candidates.
354-
- "object_index" MUST be an integer index from Object candidates.
355-
- Do not invent or modify URIs. Do not swap roles.
356-
"""
349+
generation_config = types.GenerateContentConfig(
350+
response_mime_type="application/json",
351+
response_schema=DisambiguationResult # <--- The Pydantic class output
352+
)
353+
357354

358355
# Call the model (Gemini client style)
359356
resp = self.client.models.generate_content(
360357
model=self.model_name,
361358
contents=prompt,
362-
config={"response_mime_type": "application/json"},
359+
config=generation_config
363360
)
364361

365-
data = self._get_json(getattr(resp, "text", None))
362+
try:
363+
data = DisambiguationResult.model_validate_json(resp.text)
364+
pred_uri = data.predicate_uri
365+
s_idx = data.subject_index
366+
o_idx = data.object_index
367+
368+
# Validate predicate
369+
if pred_uri not in allowed:
370+
if self.verbose:
371+
print(f"Hallucination detected {pred_uri}")
372+
pred_uri = None ## if none
366373

367-
# Validate predicate
368-
pred_uri = data.get("predicate", "")
369-
if pred_uri not in allowed:
370-
pred_uri = allowed[0]
374+
except (ValidationError, ValueError) as e:
375+
if self.verbose:
376+
print(f"Disambiguation validation failed: {e}")
377+
pred_uri=None
378+
s_idx=0
379+
o_idx=0
371380

372381
# Clamp indices and map to URIs
373-
si = self._safe_idx(subject_candidates, data.get("subject_index", 0))
374-
oi = self._safe_idx(object_candidates, data.get("object_index", 0))
382+
si = self._safe_idx(subject_candidates, s_idx)
383+
oi = self._safe_idx(object_candidates, o_idx)
375384

376385
s_uri = subject_candidates[si][0] if (subject_candidates and si is not None) else ""
377386
o_uri = object_candidates[oi][0] if (object_candidates and oi is not None) else ""
@@ -392,7 +401,7 @@ def disambiguate_triple(
392401
# Build meta (compatible with your previous code)
393402
chosen_sim, rank0 = sim_map.get(pred_uri, (None, None))
394403
meta = {
395-
"label": "candidate",
404+
"label": "candidate" if pred_uri else "hallucination_rejected",
396405
"chosen_similarity": float(chosen_sim) if chosen_sim is not None else None,
397406
"rank_in_topk": (rank0 + 1) if rank0 is not None else None,
398407
"topk": total_k,

GSoC25/NEF/__init__.py

Whitespace-only changes.

GSoC25/NEF/output_schema.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from pydantic import BaseModel, Field
2+
3+
class DisambiguationResult(BaseModel):
4+
subject_index: int = Field(description="The integer index of the correct subject from the provided list.")
5+
predicate_uri: str = Field(description="The exact URI of the predicate selected from the allowed list.")
6+
object_index: int = Field(description="The integer index of the correct object from the provided list.")

0 commit comments

Comments
 (0)