# base_mode.py — new file (171 lines), added in commit "init".
|
||||
import base64
import logging
import threading
import time

import dashscope
from dashscope.audio.asr import RecognitionCallback, RecognitionResult, Recognition
from dashscope.audio.qwen_tts_realtime import QwenTtsRealtime, AudioFormat, QwenTtsRealtimeCallback

import audio_player
import audio_recorder
from config import *
|
||||
|
||||
# Open the audio devices once at import time.  Playback is mono 24 kHz, the
# same format the TTS session is asked to produce (PCM_24000HZ_MONO_16BIT);
# capture is mono 16 kHz, the sample rate handed to the ASR session.
audio_player.init(channels=1, rate=24000)
audio_recorder.init(channels=1, rate=16000)

# Credential used by the DashScope SDK for both ASR and TTS calls.
# TTS_API_KEY is presumably exported by `config` (star-imported) — verify.
dashscope.api_key = TTS_API_KEY
|
||||
|
||||
|
||||
class Mode:
    """Base voice mode: microphone -> streaming ASR, text -> streaming TTS.

    ASR side: ``ready_asr_session()`` opens a recognition session and
    ``start_asr_record()`` pumps recorder frames into it from a daemon
    thread.  Each finished sentence is appended to ``asr_res`` and handed to
    ``asr_callback``; ``last_asr_time`` records when the last ASR event
    arrived.

    TTS side: ``stream_pipeline(gen)`` consumes an iterable of text chunks,
    forwards them to a realtime TTS session one sentence group at a time and
    feeds the returned audio to ``audio_player``.  ``tts_finish()`` interrupts
    a running pipeline.

    ``run()`` and ``stop()`` are empty in this class.
    """

    # Sentence terminators that trigger a flush to TTS: ASCII marks, newline,
    # and the CJK ideographic full stop plus the full-width forms of ! ? ;
    # (written as escapes so the set survives encoding normalisation).
    _END_MARKS = frozenset(".!?;\n\u3002\uff01\uff1f\uff1b")

    # Upper bound for waiting on the recorder thread in stop_asr_record().
    _ASR_JOIN_TIMEOUT = 0.5

    _log = logging.getLogger(__name__)

    def __init__(self, voice="Cherry",
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True),
                 tts_done_timeout=None):
        """Set up state; no network session is opened here.

        Args:
            voice: Voice name passed to the TTS session.
            asr_callback: Called with the text of every finished ASR
                sentence; ``None`` disables the notification.
            tts_callback: Called with every text segment handed to TTS;
                ``None`` disables the notification.
            tts_done_timeout: Seconds ``stream_pipeline`` waits for the TTS
                session to report completion.  ``None`` (default) waits
                without limit.
        """
        self.voice = voice
        self._tts_sess = None
        self._tts_done = threading.Event()
        self._tts_done_timeout = tts_done_timeout
        self._asr_sess = None
        self._asr_thread = None
        self.asr_res = ""
        self.last_asr_time = None

        self.is_asr_recording = False
        self.is_tts_running = False

        self.asr_callback = asr_callback
        self.tts_callback = tts_callback

    # ------------------------------------------------------------------ TTS

    def _ensure_tts_session(self):
        """Retire any previous TTS session and open a fresh one."""
        stale = self._tts_sess
        if stale is not None:
            # Best-effort teardown: attempt close() even if finish() fails,
            # and never let a dead old session block the new one.
            for step in (stale.finish, stale.close):
                try:
                    step()
                except Exception:
                    self._log.debug("ignoring error while retiring old TTS session",
                                    exc_info=True)

        # One completion event per session, so a late event from a retired
        # session cannot release a wait on the new one.
        done = threading.Event()
        self._tts_done = done

        class CB(QwenTtsRealtimeCallback):

            def on_open(_):
                done.clear()

            def on_close(_, *args, **kwargs):
                # A closed connection delivers no further audio; release
                # any waiter instead of letting it hang.
                done.set()

            def on_event(_, rsp):
                kind = rsp.get('type')
                if kind == 'response.audio.delta':
                    audio_player.feed(base64.b64decode(rsp['delta']))
                elif kind == 'session.finished':
                    done.set()

        self._tts_sess = QwenTtsRealtime(model='qwen3-tts-flash-realtime', callback=CB())
        self._tts_sess.connect()
        self._tts_sess.update_session(
            voice=self.voice,
            response_format=AudioFormat.PCM_24000HZ_MONO_16BIT,
            mode='server_commit'
        )

    def tts(self, text):
        """Append ``text`` to the current TTS session.

        A session must already exist (``stream_pipeline`` creates one);
        otherwise the attribute access on ``None`` raises ``AttributeError``.
        """
        self._tts_sess.append_text(text)

    @classmethod
    def _first_end_mark(cls, chunk):
        """Return the index of the first sentence terminator in ``chunk``, or ``None``."""
        return next((i for i, ch in enumerate(chunk) if ch in cls._END_MARKS), None)

    def _speak(self, text):
        """Best-effort send of ``text`` to TTS; return ``True`` if it was sent.

        Nothing is sent once ``tts_finish()`` has cleared ``is_tts_running``,
        and a failing send is logged rather than raised.
        """
        if not self.is_tts_running:
            return False
        try:
            self.tts(text)
        except Exception:
            self._log.warning("TTS append failed; segment skipped", exc_info=True)
            return False
        return True

    def _notify_tts(self, text):
        """Pass ``text`` to ``tts_callback`` when one is configured."""
        if self.tts_callback is not None:
            self.tts_callback(text)

    def stream_pipeline(self, gen):
        """Speak a stream of text chunks and block until playback is done.

        Chunks are buffered until one contains a sentence terminator; the
        buffer up to and including the first terminator of that chunk is then
        sent to TTS and reported through ``tts_callback``.  Whatever remains
        after the loop is flushed the same way.  ``tts_callback`` receives
        every segment even after an interruption; only the TTS send is
        skipped.

        Args:
            gen: Iterable of text chunks (empty chunks are ignored).
        """
        self._ensure_tts_session()
        self.is_tts_running = True
        try:
            pending = ""
            for chunk in gen:
                if not chunk:
                    continue
                cut = self._first_end_mark(chunk)
                if cut is None:
                    pending += chunk
                    continue
                pending += chunk[:cut + 1]
                if self._speak(pending):
                    # Short pause between segments, as in the original pacing.
                    time.sleep(0.3)
                self._notify_tts(pending)
                pending = chunk[cut + 1:]

            if pending:
                self._speak(pending)
                self._notify_tts(pending)

            # Skipped when tts_finish() interrupted us: it already ended the
            # session and cleared the player.
            sess = self._tts_sess
            if self.is_tts_running and sess is not None:
                done = self._tts_done
                # End the TTS session.
                sess.finish()
                self._tts_sess = None
                # Wait until all audio has been received and played.
                if not done.wait(getattr(self, '_tts_done_timeout', None)):
                    self._log.warning("timed out waiting for TTS session to finish")
                audio_player.wait()
                try:
                    sess.close()
                except Exception:
                    self._log.debug("ignoring error while closing TTS session",
                                    exc_info=True)

            time.sleep(0.1)
        finally:
            # Reset the flag even if the generator or the session raised.
            self.is_tts_running = False

    def tts_finish(self):
        """Interrupt TTS: stop sending, end the session, drop queued audio."""
        # Clear the flag first so a concurrent stream_pipeline stops sending.
        self.is_tts_running = False
        sess = self._tts_sess
        try:
            if sess is not None:
                sess.finish()
        except Exception:
            self._log.warning("TTS finish failed during interrupt", exc_info=True)
        finally:
            audio_player.clear()

    # ------------------------------------------------------------------ ASR

    @staticmethod
    def _stop_if_running(sess):
        """Stop an ASR session unless it is absent or reports not running.

        ``_running`` is the SDK's own flag; if it is missing, stop() is
        attempted anyway.
        """
        if sess is not None and getattr(sess, '_running', True):
            sess.stop()

    def ready_asr_session(self):
        """Reset ASR results and open a new recognition session."""
        self.asr_res = ""
        self.last_asr_time = None
        previous, self._asr_sess = self._asr_sess, None
        self._stop_if_running(previous)

        class CB(RecognitionCallback):
            def on_open(_):
                pass

            def on_event(_, result: RecognitionResult) -> None:
                self.last_asr_time = time.time()
                res = result.get_sentence()
                # Only finished sentences are accumulated and reported.
                if isinstance(res, dict) and res.get("sentence_end"):
                    text = res.get("text", "")
                    self.asr_res += text
                    if self.asr_callback is not None:
                        self.asr_callback(text)

            def on_close(_) -> None:
                pass

            def on_error(_, result: RecognitionResult) -> None:
                print(result)

        sess = Recognition(model='paraformer-realtime-v2',
                           format='pcm',
                           sample_rate=16000,
                           callback=CB())
        sess.start()
        # Publish only after start() so the recorder thread never sees an
        # unstarted session.
        self._asr_sess = sess

    def start_asr_record(self):
        """Start pumping recorder frames into the ASR session.

        Raises:
            RuntimeError: if ``ready_asr_session()`` has not been called.
        """
        if self._asr_sess is None:
            raise RuntimeError("未准备asr会话,请调用ready_asr_session方法")
        self.is_asr_recording = True

        worker = getattr(self, '_asr_thread', None)
        if worker is not None and worker.is_alive():
            # A pump is already running; do not start a second one that
            # would interleave frames.
            return

        def pump():
            while self.is_asr_recording:
                data = audio_recorder.get(3200)
                sess = self._asr_sess
                if sess is None:
                    # Session is being swapped by ready_asr_session().
                    continue
                try:
                    sess.send_audio_frame(data)
                except Exception:
                    if self.is_asr_recording:
                        self._log.exception("ASR send failed; recording stopped")
                        self.is_asr_recording = False
                    else:
                        self._log.debug("ASR send after stop ignored", exc_info=True)

        worker = threading.Thread(target=pump, daemon=True)
        self._asr_thread = worker
        worker.start()

    def stop_asr_record(self):
        """Stop the recorder pump, then stop the ASR session if it is running."""
        self.is_asr_recording = False
        worker = getattr(self, '_asr_thread', None)
        if worker is not None and worker is not threading.current_thread():
            # Let the in-flight frame be sent before the session is stopped.
            worker.join(timeout=self._ASR_JOIN_TIMEOUT)
            if not worker.is_alive():
                self._asr_thread = None
        else:
            # No tracked pump (e.g. a subclass runs its own loop): keep the
            # original grace period.
            time.sleep(0.1)
        self._stop_if_running(self._asr_sess)

    # ------------------------------------------------------------ lifecycle

    def run(self):
        """No-op in the base class (presumably overridden by concrete modes)."""
        pass

    def close(self):
        """Close the TTS session if one is open; safe to call repeatedly."""
        # getattr: __del__ may call this on a partially constructed instance.
        sess = getattr(self, '_tts_sess', None)
        if sess is None:
            return
        self._tts_sess = None
        try:
            sess.close()
        except Exception:
            self._log.debug("ignoring error while closing TTS session", exc_info=True)

    def __del__(self):
        self.close()

    def stop(self):
        """No-op in the base class (presumably overridden by concrete modes)."""
        pass
|
||||
Reference in New Issue
Block a user