# RKOneAIDebate/base_mode.py
# (Metadata from the hosting page, preserved: 2026-02-24 17:41:04 +08:00,
#  172 lines, 5.3 KiB, Python. The page warned that the file "contains
#  ambiguous Unicode characters" — some fullwidth CJK punctuation appears
#  to have been lost during extraction; see the note in stream_pipeline.)
import base64
import threading
import time
import dashscope
from dashscope.audio.asr import RecognitionCallback, RecognitionResult, Recognition
from dashscope.audio.qwen_tts_realtime import QwenTtsRealtime, AudioFormat, QwenTtsRealtimeCallback
import audio_player
import audio_recorder
from config import *
# Module-level initialization: open the playback/capture devices and set the
# DashScope credential before any Mode instance is created.
audio_player.init(channels=1, rate=24000)  # playback matches TTS output: mono 24 kHz PCM
audio_recorder.init(channels=1, rate=16000)  # capture matches ASR input: mono 16 kHz PCM
dashscope.api_key = TTS_API_KEY  # credential supplied by config.py (imported via *)
class Mode:
    """Base class coupling realtime ASR with realtime TTS for a voice loop.

    Manages two DashScope sessions — a ``paraformer-realtime-v2`` recognition
    stream fed from the microphone, and a ``qwen3-tts-flash-realtime`` session
    whose audio deltas are piped to the speaker. Subclasses implement
    :meth:`run` / :meth:`stop` to drive the interaction.
    """

    def __init__(self, voice="Cherry",
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """
        :param voice: TTS voice name passed to ``update_session``.
        :param asr_callback: called with each finalized ASR sentence fragment.
        :param tts_callback: called with each text fragment sent to TTS.
        """
        self.voice = voice
        self._tts_sess = None                # QwenTtsRealtime session, created lazily
        self._tts_done = threading.Event()   # set when TTS reports 'session.finished'
        self._asr_sess = None                # Recognition session, see ready_asr_session()
        self.asr_res = ""                    # accumulated finalized ASR text
        self.last_asr_time = None            # wall-clock time of last ASR event (silence detection)
        self.is_asr_recording = False
        self.is_tts_running = False
        self.asr_callback = asr_callback
        self.tts_callback = tts_callback

    def _ensure_tts_session(self):
        """(Re)create the realtime TTS session, tearing down any previous one."""
        if self._tts_sess is not None:
            try:
                self._tts_sess.finish()
                self._tts_sess.close()
            except Exception:
                # Best-effort teardown: the old session may already be closed.
                pass

        class CB(QwenTtsRealtimeCallback):
            def on_open(_):
                self._tts_done.clear()

            def on_event(_, rsp):
                if rsp.get('type') == 'response.audio.delta':
                    # Audio arrives base64-encoded; decode and feed the player.
                    audio_player.feed(base64.b64decode(rsp['delta']))
                if rsp.get('type') == 'session.finished':
                    self._tts_done.set()

        self._tts_sess = QwenTtsRealtime(model='qwen3-tts-flash-realtime', callback=CB())
        self._tts_sess.connect()
        self._tts_sess.update_session(
            voice=self.voice,
            response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
            mode='server_commit'
        )

    def ready_asr_session(self):
        """Reset ASR state and start a fresh recognition session."""
        self.asr_res = ""
        self.last_asr_time = None
        if self._asr_sess is not None:
            # NOTE(review): ``_running`` is a private SDK attribute — confirm it
            # survives dashscope upgrades.
            if self._asr_sess._running:
                self._asr_sess.stop()
            self._asr_sess = None

        class CB(RecognitionCallback):
            def on_open(_):
                pass

            def on_event(_, result: RecognitionResult) -> None:
                self.last_asr_time = time.time()
                res = result.get_sentence()
                if res["sentence_end"]:
                    # Only accumulate finalized sentences, not partial hypotheses.
                    self.asr_res += res["text"]
                    self.asr_callback(res["text"])

            def on_close(_) -> None:
                pass

            def on_error(_, result: RecognitionResult) -> None:
                print(result)

        self._asr_sess = Recognition(model='paraformer-realtime-v2',
                                     format='pcm',
                                     sample_rate=16000,
                                     callback=CB())
        self._asr_sess.start()

    def start_asr_record(self):
        """Start streaming microphone audio to the ASR session on a daemon thread.

        :raises RuntimeError: if :meth:`ready_asr_session` was not called first.
        """
        if self._asr_sess is None:
            raise RuntimeError("未准备asr会话请调用ready_asr_session方法")
        self.is_asr_recording = True

        def feed_loop():
            while self.is_asr_recording:
                # 3200 bytes = 100 ms of 16 kHz 16-bit mono audio.
                data = audio_recorder.get(3200)
                self._asr_sess.send_audio_frame(data)

        threading.Thread(target=feed_loop, daemon=True).start()

    def stop_asr_record(self):
        """Stop the feed thread, then close the recognition stream."""
        self.is_asr_recording = False
        time.sleep(0.1)  # give the feed thread time to observe the flag
        if self._asr_sess is not None:
            self._asr_sess.stop()

    def tts(self, text):
        """Append text to the current TTS session (requires an open session)."""
        self._tts_sess.append_text(text)

    def stream_pipeline(self, gen):
        """Consume a text-chunk generator, flushing to TTS at sentence boundaries.

        Text is buffered until a sentence-ending mark is seen, then the buffered
        sentence is sent to TTS and echoed to ``tts_callback``. After the
        generator is exhausted, the session is finished and playback drained.
        """
        self._ensure_tts_session()
        self.is_tts_running = True
        # Sentence-end punctuation, ASCII + fullwidth CJK.
        # NOTE(review): the CJK marks were garbled to '' in the original source
        # (the hosting page flagged ambiguous Unicode); restored here as
        # 。！？； — confirm against the pre-corruption file.
        end_marks = {'.', '!', '?', '。', '！', '？', ';', '；', '\n'}
        res_cache = ""
        for chunk in gen:
            if not chunk:
                continue
            pos = next((i for i, c in enumerate(chunk) if c in end_marks), None)
            if pos is not None:
                res_cache += chunk[:pos + 1]
                try:
                    self.tts(res_cache)
                    time.sleep(0.3)  # pace sentence commits to the realtime session
                except Exception:
                    # Best-effort: a dropped TTS send must not abort the pipeline.
                    pass
                if self.tts_callback is not None:
                    self.tts_callback(res_cache)
                res_cache = chunk[pos + 1:]
            else:
                res_cache += chunk
        if res_cache:
            # Flush any trailing text that lacked a sentence-end mark.
            self.tts(res_cache)
            if self.tts_callback is not None:
                self.tts_callback(res_cache)
        if self.is_tts_running:
            # Finish the TTS session, then wait for synthesis and playback to drain.
            self._tts_sess.finish()
            self._tts_sess = None
            self._tts_done.wait()
            audio_player.wait()
            time.sleep(0.1)
        self.is_tts_running = False

    def tts_finish(self):
        """Interrupt TTS: finish the session and drop any queued audio."""
        if self._tts_sess is not None:
            self._tts_sess.finish()
        self.is_tts_running = False
        audio_player.clear()

    def run(self):
        """Main loop entry point; subclasses override."""
        pass

    def close(self):
        """Release the TTS session, if one was ever created."""
        # getattr guards both a missing attribute (partially-constructed
        # instance) and the None placeholder set in __init__.
        if getattr(self, '_tts_sess', None) is not None:
            try:
                self._tts_sess.close()
            except Exception:
                pass

    def __del__(self):
        self.close()

    def stop(self):
        """Request the mode loop to stop; subclasses override."""
        pass