# Time   : 2025/3/29 11:16
# Author : Hui Huang
44import asyncio
5+ import json
56import math
67import os .path
78import re
4243 "very_high" : 4 ,
4344}
4445
# Map a speaker-gender label to the integer id expected by the model.
GENDER_MAP: dict[Literal["male", "female"], int] = dict(female=0, male=1)

# Reverse lookup: integer id back to its gender label.
ID2GENDER = {idx: label for label, idx in GENDER_MAP.items()}
5053
5154@dataclass
5255class SparkAcousticTokens :
5356 prompt : str
57+ gender : Literal ["female" , "male" ]
5458 global_tokens : Optional [torch .Tensor ] = None
5559
5660 def __post_init__ (self ):
@@ -73,15 +77,21 @@ def _parse_prompt(self):
7377 )
7478 self .global_tokens = global_token_ids
7579
80+ def to_dict (self ) -> dict [str , str ]:
81+ return {
82+ "prompt" : self .prompt ,
83+ "gender" : self .gender
84+ }
85+
7686 def save (self , filepath : str ):
7787 with open (filepath , 'w' , encoding = 'utf8' ) as w :
78- w .write (self .prompt )
88+ w .write (json . dumps ( self .to_dict (), ensure_ascii = False , indent = 2 ) )
7989
8090 @classmethod
8191 def load (cls , filepath : str ):
8292 with open (filepath , 'r' , encoding = 'utf8' ) as r :
83- prompt = r . read ( )
84- return cls (prompt = prompt )
93+ data = json . load ( r )
94+ return cls (** data )
8595
8696
8797def process_prompt (
@@ -619,6 +629,18 @@ async def _control_generate(
619629 acoustic_tokens : Optional [SparkAcousticTokens | str ] = None ,
620630 return_acoustic_tokens : bool = False ,
621631 ** kwargs ):
632+ gender : Literal ["female" , "male" ] = gender if gender in ["female" , "male" ] else "female"
633+
634+ if acoustic_tokens is not None and isinstance (acoustic_tokens , str ):
635+ acoustic_tokens = SparkAcousticTokens .load (acoustic_tokens )
636+
637+ if acoustic_tokens is not None :
638+ if acoustic_tokens .gender != gender :
639+ logger .warning (
640+ f"The provided `acoustic_tokens` belong to the `{ acoustic_tokens .gender } `, but the specified gender is { gender } . "
641+ f"The `acoustic_tokens` will therefore not be used." )
642+ acoustic_tokens = None
643+
622644 segments = self .preprocess_text (
623645 text ,
624646 window_size = window_size ,
@@ -654,14 +676,11 @@ async def generate_audio(
654676 "completion" : generated ['completion' ]
655677 }
656678
657- if acoustic_tokens is not None and isinstance (acoustic_tokens , str ):
658- acoustic_tokens = SparkAcousticTokens (acoustic_tokens )
659-
660679 audios = []
661680 if acoustic_tokens is None :
662681 # 如果没有传入音色,使用第一段生成音色token,将其与后面片段一起拼接,使用相同音色token引导输出semantic tokens。
663682 first_output = await generate_audio (segments [0 ], acoustic_token = None )
664- acoustic_tokens = SparkAcousticTokens (first_output ['completion' ])
683+ acoustic_tokens = SparkAcousticTokens (first_output ['completion' ], gender = gender )
665684 audios .append (first_output ['audio' ])
666685 segments = segments [1 :]
667686
@@ -706,7 +725,7 @@ async def speak_async(
706725 logger .error (err_msg )
707726 raise ValueError (err_msg )
708727 self .set_seed (seed = self .seed )
709- acoustic_tokens = None
728+ out_acoustic_tokens = None
710729 if name in ["female" , "male" ]:
711730 output = await self ._control_generate (
712731 text = text ,
@@ -727,7 +746,7 @@ async def speak_async(
727746 )
728747 if return_acoustic_tokens and isinstance (output , tuple ):
729748 audio = output [0 ]
730- acoustic_tokens = output [1 ]
749+ out_acoustic_tokens = output [1 ]
731750 else :
732751 audio = output
733752 else :
@@ -756,8 +775,8 @@ async def speak_async(
756775
757776 torch .cuda .empty_cache ()
758777
759- if acoustic_tokens is not None :
760- return audio , acoustic_tokens
778+ if out_acoustic_tokens is not None :
779+ return audio , out_acoustic_tokens
761780 return audio
762781
763782 async def _control_stream_generate (
@@ -782,6 +801,8 @@ async def _control_stream_generate(
782801 return_acoustic_tokens : bool = False ,
783802 ** kwargs
784803 ):
804+ gender : Literal ["female" , "male" ] = gender if gender in ["female" , "male" ] else "female"
805+
785806 if audio_chunk_duration < 0.5 :
786807 err_msg = "audio_chunk_duration at least 0.5 seconds"
787808 logger .error (err_msg )
@@ -792,7 +813,14 @@ async def _control_stream_generate(
792813 raise ValueError (err_msg )
793814
794815 if acoustic_tokens is not None and isinstance (acoustic_tokens , str ):
795- acoustic_tokens = SparkAcousticTokens (acoustic_tokens )
816+ acoustic_tokens = SparkAcousticTokens .load (acoustic_tokens )
817+
818+ if acoustic_tokens is not None :
819+ if acoustic_tokens .gender != gender :
820+ logger .warning (
821+ f"The provided `acoustic_tokens` belong to the `{ acoustic_tokens .gender } `, but the specified gender is { gender } . "
822+ f"The `acoustic_tokens` will therefore not be used." )
823+ acoustic_tokens = None
796824
797825 audio_tokenizer_frame_rate = 50
798826 max_chunk_size = math .ceil (max_audio_chunk_duration * audio_tokenizer_frame_rate )
@@ -840,7 +868,7 @@ async def _control_stream_generate(
840868 r"(<\|start_acoustic_token\|>.*?<\|end_global_token\|>)" ,
841869 completion )
842870 if len (acoustics ) > 0 :
843- acoustic_tokens = SparkAcousticTokens (acoustics [0 ])
871+ acoustic_tokens = SparkAcousticTokens (acoustics [0 ], gender = gender )
844872 completion = ""
845873 else :
846874 continue
0 commit comments