99from urllib .parse import quote
1010from getpass import getpass
1111
12+ from pydantic import ValidationError
13+ try :
14+ from .output_schema import DisambiguationResult
15+ except ImportError :
16+ from output_schema import DisambiguationResult
17+
1218# =============== Gemini client bootstrap ===============
1319try :
1420 from google import genai
@@ -280,13 +286,6 @@ def _safe_idx(lst: List[Any], value: Any) -> Optional[int]:
280286 i = 0
281287 return max (0 , min (i , len (lst ) - 1 ))
282288
283- @staticmethod
284- def _get_json (text : Optional [str ]) -> Dict [str , Any ]:
285- try :
286- return json .loads ((text or "" ).strip () or "{}" )
287- except Exception :
288- return {}
289-
290289 # --------------------- main entrypoint ---------------------
291290
292291 def disambiguate_triple (
@@ -332,46 +331,56 @@ def disambiguate_triple(
332331 subj_list_text = self ._fmt_indexed (subject_candidates )
333332 obj_list_text = self ._fmt_indexed (object_candidates )
334333
335- prompt = f"""Pick the best RDF triple using ONLY these options.
334+ prompt = f"""
335+ Analyze the context and disambiguate the triple by selecting the correct Subject, Predicate, and Object.
336336
337- Allowed predicate URIs:
338- { pred_list_text }
337+ Context: { context }
339338
340- Subject candidates (choose by INDEX ):
341- { subj_list_text }
339+ Allowed Predicates (Select one URI ):
340+ { pred_list_text }
342341
343- Object candidates (choose by INDEX ):
344- { obj_list_text }
342+ Subject Candidates (Select by Index ):
343+ { subj_list_text }
345344
346- Context (helps decide, but does NOT add new options):
347- { context }
345+ Object Candidates (Select by Index):
346+ { obj_list_text }
347+ """
348348
349- Return ONLY strict JSON on one line (no prose):
350- {{"subject_index": 0, "predicate": "URI", "object_index": 0}}
351- Rules:
352- - "predicate" MUST be exactly one URI from Allowed predicate URIs.
353- - "subject_index" MUST be an integer index from Subject candidates.
354- - "object_index" MUST be an integer index from Object candidates.
355- - Do not invent or modify URIs. Do not swap roles.
356- """
349+ generation_config = types .GenerateContentConfig (
350+ response_mime_type = "application/json" ,
351+ response_schema = DisambiguationResult # Pydantic model passed as the response schema for the JSON output
352+ )
353+
357354
358355 # Call the model (Gemini client style)
359356 resp = self .client .models .generate_content (
360357 model = self .model_name ,
361358 contents = prompt ,
362- config = { "response_mime_type" : "application/json" },
359+ config = generation_config
363360 )
364361
365- data = self ._get_json (getattr (resp , "text" , None ))
362+ try :
363+ data = DisambiguationResult .model_validate_json (resp .text )
364+ pred_uri = data .predicate_uri
365+ s_idx = data .subject_index
366+ o_idx = data .object_index
367+
368+ # Validate predicate
369+ if pred_uri not in allowed :
370+ if self .verbose :
371+ print (f"Hallucination detected { pred_uri } " )
372+ pred_uri = None # reject the predicate: it is not in the allowed list
366373
367- # Validate predicate
368- pred_uri = data .get ("predicate" , "" )
369- if pred_uri not in allowed :
370- pred_uri = allowed [0 ]
374+ except (ValidationError , ValueError ) as e :
375+ if self .verbose :
376+ print (f"Disambiguation validation failed: { e } " )
377+ pred_uri = None
378+ s_idx = 0
379+ o_idx = 0
371380
372381 # Clamp indices and map to URIs
373- si = self ._safe_idx (subject_candidates , data . get ( "subject_index" , 0 ) )
374- oi = self ._safe_idx (object_candidates , data . get ( "object_index" , 0 ) )
382+ si = self ._safe_idx (subject_candidates , s_idx )
383+ oi = self ._safe_idx (object_candidates , o_idx )
375384
376385 s_uri = subject_candidates [si ][0 ] if (subject_candidates and si is not None ) else ""
377386 o_uri = object_candidates [oi ][0 ] if (object_candidates and oi is not None ) else ""
@@ -392,7 +401,7 @@ def disambiguate_triple(
392401 # Build meta (compatible with your previous code)
393402 chosen_sim , rank0 = sim_map .get (pred_uri , (None , None ))
394403 meta = {
395- "label" : "candidate" ,
404+ "label" : "candidate" if pred_uri else "hallucination_rejected" ,
396405 "chosen_similarity" : float (chosen_sim ) if chosen_sim is not None else None ,
397406 "rank_in_topk" : (rank0 + 1 ) if rank0 is not None else None ,
398407 "topk" : total_k ,
0 commit comments