99from urllib .parse import quote
1010from getpass import getpass
1111
12+ try :
13+ from .output_schema import DisambiguationResult
14+ except ImportError :
15+ from output_schema import DisambiguationResult
16+
1217# =============== Gemini client bootstrap ===============
1318try :
1419 from google import genai
@@ -280,13 +285,6 @@ def _safe_idx(lst: List[Any], value: Any) -> Optional[int]:
280285 i = 0
281286 return max (0 , min (i , len (lst ) - 1 ))
282287
283- @staticmethod
284- def _get_json (text : Optional [str ]) -> Dict [str , Any ]:
285- try :
286- return json .loads ((text or "" ).strip () or "{}" )
287- except Exception :
288- return {}
289-
290288 # --------------------- main entrypoint ---------------------
291289
292290 def disambiguate_triple (
@@ -332,46 +330,53 @@ def disambiguate_triple(
332330 subj_list_text = self ._fmt_indexed (subject_candidates )
333331 obj_list_text = self ._fmt_indexed (object_candidates )
334332
335- prompt = f"""Pick the best RDF triple using ONLY these options.
333+ prompt = f"""
334+ Analyze the context and disambiguate the triple by selecting the correct Subject, Predicate, and Object.
336335
337- Allowed predicate URIs:
338- { pred_list_text }
336+ Context: { context }
339337
340- Subject candidates (choose by INDEX ):
341- { subj_list_text }
338+ Allowed Predicates (Select one URI ):
339+ { pred_list_text }
342340
343- Object candidates (choose by INDEX ):
344- { obj_list_text }
341+ Subject Candidates (Select by Index ):
342+ { subj_list_text }
345343
346- Context (helps decide, but does NOT add new options):
347- { context }
344+ Object Candidates (Select by Index):
345+ { obj_list_text }
346+ """
348347
349- Return ONLY strict JSON on one line (no prose):
350- {{"subject_index": 0, "predicate": "URI", "object_index": 0}}
351- Rules:
352- - "predicate" MUST be exactly one URI from Allowed predicate URIs.
353- - "subject_index" MUST be an integer index from Subject candidates.
354- - "object_index" MUST be an integer index from Object candidates.
355- - Do not invent or modify URIs. Do not swap roles.
356- """
348+ generation_config = types .GenerateContentConfig (
349+ response_mime_type = "application/json" ,
350+ response_scheme = DisambiguationResult # <--- The Pydantic class output
351+ )
352+
357353
358354 # Call the model (Gemini client style)
359355 resp = self .client .models .generate_content (
360356 model = self .model_name ,
361357 contents = prompt ,
362- config = { "response_mime_type" : "application/json" },
358+ config = generation_config
363359 )
364360
365- data = self ._get_json (getattr (resp , "text" , None ))
361+ try :
362+ data = DisambiguationResult .model_validate_json (resp .text )
363+ pred_uri = data .predicate_uri
364+ s_idx = data .subject_index
365+ o_idx = data .object_index
366+
367+ # Validate predicate
368+ if pred_uri not in allowed :
369+ pred_uri = None ## if none
366370
367- # Validate predicate
368- pred_uri = data .get ("predicate" , "" )
369- if pred_uri not in allowed :
370- pred_uri = allowed [0 ]
371+ except Exception as e :
372+ print (f"validation error{ e } " )
373+ pred_uri = None
374+ s_idx = 0
375+ o_idx = 0
371376
372377 # Clamp indices and map to URIs
373- si = self ._safe_idx (subject_candidates , data . get ( "subject_index" , 0 ) )
374- oi = self ._safe_idx (object_candidates , data . get ( "object_index" , 0 ) )
378+ si = self ._safe_idx (subject_candidates , s_idx )
379+ oi = self ._safe_idx (object_candidates , o_idx )
375380
376381 s_uri = subject_candidates [si ][0 ] if (subject_candidates and si is not None ) else ""
377382 o_uri = object_candidates [oi ][0 ] if (object_candidates and oi is not None ) else ""
@@ -392,7 +397,7 @@ def disambiguate_triple(
392397 # Build meta (compatible with your previous code)
393398 chosen_sim , rank0 = sim_map .get (pred_uri , (None , None ))
394399 meta = {
395- "label" : "candidate" ,
400+ "label" : "candidate" if pred_uri else "hallucination_rejected" ,
396401 "chosen_similarity" : float (chosen_sim ) if chosen_sim is not None else None ,
397402 "rank_in_topk" : (rank0 + 1 ) if rank0 is not None else None ,
398403 "topk" : total_k ,
0 commit comments