This commit is contained in:
wtz
2026-02-24 17:41:04 +08:00
commit fc32cab12b
9 changed files with 1239 additions and 0 deletions

171
base_mode.py Normal file
View File

@@ -0,0 +1,171 @@
import base64
import threading
import time
import dashscope
from dashscope.audio.asr import RecognitionCallback, RecognitionResult, Recognition
from dashscope.audio.qwen_tts_realtime import QwenTtsRealtime, AudioFormat, QwenTtsRealtimeCallback
import audio_player
import audio_recorder
from config import *
# Module-level audio setup: 24 kHz mono playback matches the TTS output
# format requested below; 16 kHz mono capture matches the ASR sample rate.
audio_player.init(channels=1, rate=24000)
audio_recorder.init(channels=1, rate=16000)
# Credential comes from config via `from config import *`.
dashscope.api_key = TTS_API_KEY
class Mode:
def __init__(self, voice="Cherry",
asr_callback=lambda x: print(x, end="", flush=True),
tts_callback=lambda x: print(x, end="", flush=True)):
self.voice = voice
self._tts_sess = None
self._tts_done = threading.Event()
self._asr_sess = None
self.asr_res = ""
self.last_asr_time = None
self.is_asr_recording = False
self.is_tts_running = False
self.asr_callback = asr_callback
self.tts_callback = tts_callback
def _ensure_tts_session(self):
    """(Re)create the realtime TTS session.

    Any previous session is torn down best-effort, then a fresh
    QwenTtsRealtime session is connected and configured with this
    mode's voice, 24 kHz 16-bit mono PCM output, and server-commit mode.
    """
    if self._tts_sess is not None:
        # Best-effort teardown: a half-dead websocket may raise here
        # and that is fine to ignore — but never use a bare except,
        # which would also hide KeyboardInterrupt/SystemExit.
        try:
            self._tts_sess.finish()
            self._tts_sess.close()
        except Exception:
            pass

    class CB(QwenTtsRealtimeCallback):
        # NOTE: `_` is the callback instance; `self` is the enclosing Mode.
        def on_open(_):
            self._tts_done.clear()

        def on_event(_, rsp):
            # Stream decoded PCM straight into the player as it arrives.
            if rsp.get('type') == 'response.audio.delta':
                audio_player.feed(base64.b64decode(rsp['delta']))
            if rsp.get('type') == 'session.finished':
                self._tts_done.set()

    self._tts_sess = QwenTtsRealtime(model='qwen3-tts-flash-realtime', callback=CB())
    self._tts_sess.connect()
    self._tts_sess.update_session(
        voice=self.voice,
        response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
        mode='server_commit'
    )
def ready_asr_session(self):
    """Reset recognition state and (re)start a fresh ASR session."""
    self.asr_res = ""
    self.last_asr_time = None
    if self._asr_sess is not None:
        # NOTE(review): `_running` is a private attribute of the dashscope
        # Recognition object — confirm the SDK has no public way to check
        # whether the session is still active.
        if self._asr_sess._running:
            self._asr_sess.stop()
        self._asr_sess = None
    class CB(RecognitionCallback):
        # `_` is the callback instance; `self` is the enclosing Mode.
        def on_open(_):
            pass
        def on_event(_, result: RecognitionResult) -> None:
            # Record when the last ASR event arrived (callers can use this
            # for silence detection); append only finalized sentences.
            self.last_asr_time = time.time()
            res = result.get_sentence()
            if res["sentence_end"]:
                self.asr_res += res["text"]
                self.asr_callback(res["text"])
        def on_close(_) -> None:
            pass
        def on_error(_, result: RecognitionResult) -> None:
            # Errors are only printed; recognition is not restarted here.
            print(result)
    # paraformer-realtime-v2 consumes the 16 kHz PCM frames pushed by
    # start_asr_record().
    self._asr_sess = Recognition(model='paraformer-realtime-v2',
                                 format='pcm',
                                 sample_rate=16000,
                                 callback=CB())
    self._asr_sess.start()
def start_asr_record(self):
if self._asr_sess is None:
raise RuntimeError("未准备asr会话请调用ready_asr_session方法")
self.is_asr_recording = True
def th():
while self.is_asr_recording:
data = audio_recorder.get(3200)
self._asr_sess.send_audio_frame(data)
threading.Thread(target=th, daemon=True).start()
def stop_asr_record(self):
self.is_asr_recording = False
time.sleep(0.1)
if self._asr_sess is not None:
self._asr_sess.stop()
def tts(self, text):
self._tts_sess.append_text(text)
def stream_pipeline(self, gen):
    """Consume a text-chunk generator, speaking each sentence as it completes.

    Chunks from *gen* are buffered until a sentence-terminating mark is
    seen; the buffered sentence is then sent to TTS and echoed to
    tts_callback. After the generator is exhausted the session is
    finished and this call blocks until all audio has played out.
    """
    self._ensure_tts_session()
    self.is_tts_running = True
    # Sentence terminators. NOTE(review): the CJK marks were mojibake'd
    # to empty strings in the original source (an empty string can never
    # match a single character, so they were dead); restored as the usual
    # fullwidth 。！？； — confirm against upstream.
    end_marks = {'.', '!', '?', '。', '！', '？', ';', '；', '\n'}
    res_cache = ""
    for chunk in gen:
        if not chunk:
            continue
        # Split at the first terminator found in this chunk, if any.
        pos = next((i for i, c in enumerate(chunk) if c in end_marks), None)
        if pos is None:
            res_cache += chunk
            continue
        res_cache += chunk[:pos + 1]
        try:
            self.tts(res_cache)
            # Pace requests slightly so synthesis keeps up with the text.
            time.sleep(0.3)
        except Exception as e:
            # Best-effort: one failed sentence must not kill the stream,
            # but surface the error instead of silently swallowing it.
            print(f"TTS error: {e}")
        if self.tts_callback is not None:
            self.tts_callback(res_cache)
        res_cache = chunk[pos + 1:]
    # Flush trailing text that never hit a terminator.
    if res_cache:
        self.tts(res_cache)
        if self.tts_callback is not None:
            self.tts_callback(res_cache)
    if self.is_tts_running:
        # Close out the TTS session, then wait for the 'session.finished'
        # event and for the player to drain before returning.
        self._tts_sess.finish()
        self._tts_sess = None
        self._tts_done.wait()
        audio_player.wait()
        time.sleep(0.1)
        self.is_tts_running = False
def tts_finish(self):
    """Abort TTS output: finish the session, drop the flag, flush audio."""
    sess = self._tts_sess
    if sess is not None:
        sess.finish()
    self.is_tts_running = False
    audio_player.clear()
def run(self):
pass
def close(self):
if hasattr(self, '_tts_sess'):
try:
self._tts_sess.close()
except:
pass
def __del__(self):
    # Release the TTS session when the instance is garbage collected;
    # close() is written to be safe to call from here.
    self.close()
def stop(self):
pass