172 lines
5.3 KiB
Python
172 lines
5.3 KiB
Python
import base64
|
||
import threading
|
||
import time
|
||
|
||
import dashscope
|
||
from dashscope.audio.asr import RecognitionCallback, RecognitionResult, Recognition
|
||
from dashscope.audio.qwen_tts_realtime import QwenTtsRealtime, AudioFormat, QwenTtsRealtimeCallback
|
||
|
||
import audio_player
|
||
import audio_recorder
|
||
from config import *
|
||
|
||
# Initialise playback as mono at rate=24000 (presumably Hz); this is the same
# rate/channel layout as the PCM_24000HZ_MONO_16BIT format requested from the
# TTS session in Mode._ensure_tts_session.
audio_player.init(channels=1, rate=24000)
# Initialise capture as mono at rate=16000 (presumably Hz), matching
# sample_rate=16000 passed to the ASR Recognition session in
# Mode.ready_asr_session.
audio_recorder.init(channels=1, rate=16000)

# NOTE(review): TTS_API_KEY is not defined in this file; presumably it is
# provided by `from config import *` -- confirm. The same key is used for both
# the ASR and TTS sessions because it is set globally on the dashscope module.
dashscope.api_key = TTS_API_KEY
|
||
|
||
|
||
class Mode:
    """Voice I/O helper that pairs a DashScope realtime ASR session with a
    DashScope realtime TTS session.

    ASR side: ``ready_asr_session`` opens a recognition session,
    ``start_asr_record`` pumps frames from ``audio_recorder`` into it on a
    background thread, and ``stop_asr_record`` ends it. Finished sentences are
    accumulated in ``asr_res`` and forwarded to ``asr_callback``.

    TTS side: ``stream_pipeline`` consumes an iterable of text chunks, sends
    them sentence by sentence to a fresh TTS session and blocks until playback
    is over; ``tts_finish`` interrupts it.
    """

    # Characters that end a sentence for TTS purposes. The non-ASCII marks are
    # written as escapes so they cannot be silently normalised to their ASCII
    # look-alikes by an editor or a copy/paste.
    _END_MARKS = frozenset({
        '.', '!', '?', ';', '\n',
        '\u3002',  # ideographic full stop
        '\uff01',  # fullwidth exclamation mark
        '\uff1f',  # fullwidth question mark
        '\uff1b',  # fullwidth semicolon
    })

    # Upper bound, in seconds, on how long stop_asr_record waits for the
    # recorder thread to notice the stop flag.
    _ASR_JOIN_TIMEOUT = 1.0

    def __init__(self, voice="Cherry",
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """Create an idle instance; no network session is opened here.

        Args:
            voice: Voice name passed to the TTS session.
            asr_callback: Called with the text of every finished ASR
                sentence, or ``None`` to disable.
            tts_callback: Called with every piece of text handled by
                ``stream_pipeline``, or ``None`` to disable.
        """
        self.voice = voice
        self._tts_sess = None
        self._tts_done = threading.Event()
        self._asr_sess = None
        self._asr_thread = None
        self.asr_res = ""
        self.last_asr_time = None

        self.is_asr_recording = False
        self.is_tts_running = False

        self.asr_callback = asr_callback
        self.tts_callback = tts_callback

    @staticmethod
    def _first_end_mark(chunk):
        """Return the index of the first sentence-ending character in
        *chunk*, or ``None`` when it contains none."""
        return next(
            (i for i, ch in enumerate(chunk) if ch in Mode._END_MARKS),
            None,
        )

    def _ensure_tts_session(self):
        """Discard any previous TTS session and open a new one.

        The new session feeds decoded ``response.audio.delta`` payloads to
        ``audio_player`` and sets ``_tts_done`` on ``session.finished``.
        """
        old = self._tts_sess
        if old is not None:
            self._tts_sess = None
            # Best-effort teardown. Each step is guarded on its own so that a
            # failing finish() (e.g. the session was already finished by
            # tts_finish) no longer prevents close() from running.
            for shutdown in (old.finish, old.close):
                try:
                    shutdown()
                except Exception:
                    pass

        # ``_`` is the callback instance; ``self`` deliberately stays bound to
        # the enclosing Mode object.
        class CB(QwenTtsRealtimeCallback):

            def on_open(_):
                self._tts_done.clear()

            def on_event(_, rsp):
                if rsp.get('type') == 'response.audio.delta':
                    audio_player.feed(base64.b64decode(rsp['delta']))
                if rsp.get('type') == 'session.finished':
                    self._tts_done.set()

        self._tts_sess = QwenTtsRealtime(model='qwen3-tts-flash-realtime', callback=CB())
        self._tts_sess.connect()
        self._tts_sess.update_session(
            voice=self.voice,
            response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
            mode='server_commit'
        )

    def ready_asr_session(self):
        """Reset the ASR result state and start a new recognition session,
        stopping the previous one first if it is still running."""
        self.asr_res = ""
        self.last_asr_time = None
        if self._asr_sess is not None:
            if self._asr_sess._running:
                self._asr_sess.stop()
            self._asr_sess = None

        # ``_`` is the callback instance; ``self`` is the enclosing Mode.
        class CB(RecognitionCallback):
            def on_open(_):
                pass

            def on_event(_, result: RecognitionResult) -> None:
                self.last_asr_time = time.time()
                res = result.get_sentence()
                # Only finished sentences are accumulated and reported.
                if res["sentence_end"]:
                    self.asr_res += res["text"]
                    if self.asr_callback is not None:
                        self.asr_callback(res["text"])

            def on_close(_) -> None:
                pass

            def on_error(_, result: RecognitionResult) -> None:
                print(result)

        self._asr_sess = Recognition(model='paraformer-realtime-v2',
                                     format='pcm',
                                     sample_rate=16000,
                                     callback=CB())
        self._asr_sess.start()

    def start_asr_record(self):
        """Start forwarding recorder frames to the ASR session on a daemon
        thread.

        Raises:
            RuntimeError: if ``ready_asr_session`` has not been called.
        """
        if self._asr_sess is None:
            raise RuntimeError("未准备asr会话,请调用ready_asr_session方法")
        self.is_asr_recording = True

        worker = self._asr_thread
        if worker is not None and worker.is_alive():
            # A sender loop is already running (or was just revived by the
            # flag above); a second one would duplicate every audio frame.
            return

        def pump():
            while self.is_asr_recording:
                data = audio_recorder.get(3200)
                self._asr_sess.send_audio_frame(data)

        self._asr_thread = threading.Thread(target=pump, daemon=True)
        self._asr_thread.start()

    def stop_asr_record(self):
        """Stop the recorder thread, then stop the ASR session (if any)."""
        self.is_asr_recording = False
        worker = self._asr_thread
        if worker is not None:
            # Wait for the sender loop to really exit instead of sleeping a
            # fixed amount, so no frame is sent to a stopped session.
            worker.join(timeout=self._ASR_JOIN_TIMEOUT)
        if self._asr_sess is not None:
            self._asr_sess.stop()

    def tts(self, text):
        """Append *text* to the current TTS session.

        Raises:
            RuntimeError: if no TTS session is open.
        """
        sess = self._tts_sess
        if sess is None:
            raise RuntimeError("未准备tts会话,请先调用stream_pipeline方法")
        sess.append_text(text)

    def stream_pipeline(self, gen):
        """Speak the text produced by *gen* and block until playback ends.

        Chunks are buffered until the first sentence-ending character, then
        the buffered sentence is sent to TTS and reported to ``tts_callback``.
        Text left over after the last chunk is flushed the same way. If
        ``tts_finish`` interrupts the run, remaining text is still reported to
        ``tts_callback`` but is no longer sent to the (finished) session.

        Args:
            gen: Iterable of text chunks; empty chunks are ignored.
        """
        self._ensure_tts_session()
        self.is_tts_running = True
        try:
            pending = ""
            for chunk in gen:
                if not chunk:
                    continue
                pos = self._first_end_mark(chunk)
                if pos is None:
                    pending += chunk
                    continue

                pending += chunk[:pos + 1]
                if self.is_tts_running:
                    try:
                        self.tts(pending)
                        # Pause between sentences, kept from the original
                        # implementation.
                        time.sleep(0.3)
                    except Exception as exc:
                        # Best-effort: keep draining the generator, but do not
                        # hide the failure.
                        print(f"TTS append failed: {exc}", flush=True)
                if self.tts_callback is not None:
                    self.tts_callback(pending)
                pending = chunk[pos + 1:]

            if pending:
                if self.is_tts_running:
                    self.tts(pending)
                if self.tts_callback is not None:
                    self.tts_callback(pending)

            if self.is_tts_running:
                # End the TTS session.
                self._tts_sess.finish()
                self._tts_sess = None
                # Wait until synthesis is reported finished and everything
                # queued in the player has been played.
                self._tts_done.wait()
                audio_player.wait()

            time.sleep(0.1)
        finally:
            # Reset the flag even when the generator or the session raised.
            self.is_tts_running = False

    def tts_finish(self):
        """Interrupt speech: finish the TTS session and drop queued audio."""
        # Clear the flag first so a concurrent stream_pipeline stops sending,
        # and always clear the player even if finish() raises.
        self.is_tts_running = False
        try:
            if self._tts_sess is not None:
                self._tts_sess.finish()
        finally:
            audio_player.clear()

    def run(self):
        """Placeholder; does nothing."""
        pass

    def close(self):
        """Close the TTS session if one is open; safe to call repeatedly."""
        # getattr: __del__ may run on an instance whose __init__ did not
        # complete.
        sess = getattr(self, '_tts_sess', None)
        if sess is None:
            return
        try:
            sess.close()
        except Exception:
            pass  # best-effort cleanup

    def __del__(self):
        self.close()

    def stop(self):
        """Placeholder; does nothing."""
        pass
|