365 lines
12 KiB
Python
365 lines
12 KiB
Python
import threading
|
||
import time
|
||
from http import HTTPStatus
|
||
|
||
import dashscope
|
||
import openai
|
||
|
||
import base_mode as bm
|
||
from config import *
|
||
|
||
class DialogMode(bm.Mode):
    """Free-form voice chat: record the user via ASR, send the utterance to a chat model, play the reply.

    The conversation runs on a background thread started by :meth:`run` and
    stopped by :meth:`stop`.
    """

    # Seconds to sleep between checks while waiting on ASR activity; replaces a
    # bare busy-wait that would otherwise pin a CPU core.
    _POLL_INTERVAL = 0.05

    def __init__(self, model="qwen-plus", system_prompt="You are a helpful assistant", threshold_no_speak=2,
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """Constructor for the free-dialog mode.

        :param model: name of the chat model to call
        :param system_prompt: system prompt placed at the start of the conversation history
        :param threshold_no_speak: seconds of silence after which the user's turn is considered finished
        :param asr_callback: forwarded to ``bm.Mode``
        :param tts_callback: forwarded to ``bm.Mode``
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback,
        )
        self.model = model
        # DashScope's OpenAI-compatible endpoint, accessed with the stock OpenAI client.
        self.client = openai.OpenAI(
            api_key=COM_API_KEY,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
        # Full chat history sent with every request; ask_ai appends one user and
        # one assistant message per completed turn.
        self.context = [
            {"role": "system", "content": system_prompt},
        ]
        self.system_prompt = system_prompt
        self.main_loop_thread = threading.Thread(target=self.main_loop)
        self.running = False
        self.threshold_no_speak = threshold_no_speak

    def ask_ai(self, prompt):
        """Send *prompt* to the model and yield the streamed reply fragment by fragment.

        The prompt is appended to ``self.context`` before the request. The joined
        reply is appended as an assistant message once the stream is exhausted.
        A final empty string is yielded after the stream ends.
        """
        self.context.append({"role": "user", "content": prompt})
        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.context,
            max_tokens=2048,
            stream=True,
        )
        parts = []
        for chunk in response:
            # Some stream chunks (e.g. a trailing usage chunk) carry no choices at all.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content or ""
            parts.append(content)
            yield content
        self.context.append({"role": "assistant", "content": "".join(parts)})
        yield ""

    def _wait_for_utterance(self):
        """Block until the user has spoken and then stayed silent for ``threshold_no_speak`` seconds.

        :return: True when the utterance finished; False if stop() was called first.
        """
        # NOTE(review): assumes last_asr_time is None at the start of every turn
        # (reset by the base class / ready_asr_session). Otherwise a stale timestamp
        # from the previous turn ends both waits immediately. Confirm in base_mode.
        # Wait for the user to start speaking.
        while self.last_asr_time is None:
            if not self.running:
                return False
            time.sleep(self._POLL_INTERVAL)
        # Wait until no new ASR activity has been seen for threshold_no_speak seconds.
        while time.time() - self.last_asr_time <= self.threshold_no_speak:
            if not self.running:
                return False
            time.sleep(self._POLL_INTERVAL)
        return True

    def main_loop(self):
        """Worker loop: record one utterance, send it to the model, play the reply; repeat until stopped."""
        while self.running:
            # 1. Prepare the ASR session
            self.ready_asr_session()
            # 2. Start recording
            self.start_asr_record()
            # 3-4. Wait for the user to speak and then fall silent
            finished = self._wait_for_utterance()
            # 5. Close ASR (also when stopping, so the recording is not left open)
            self.stop_asr_record()
            if not finished:
                # stop() was called mid-turn: do not send a partial utterance.
                break
            # 6. Send the utterance to the AI and play the reply
            self.stream_pipeline(self.ask_ai(self.asr_res))

    def run(self):
        """Start the base mode, then launch the background conversation thread."""
        super().run()
        self.running = True
        print("请开始说话...")
        self.main_loop_thread.start()

    def stop(self):
        """Signal the conversation thread to exit and call the base class's tts_finish().

        The worker notices the flag at its next poll. The thread is not joined here.
        """
        self.running = False
        self.tts_finish()
|
||
|
||
|
||
class MakePointMode(bm.Mode):
    """Opening-statement (立论) stage of a debate, answered by a DashScope application.

    ``side`` is the side the AI argues for: "正方" (affirmative) or "反方" (negative).
    The DashScope session id from the first successful response is kept in
    ``self.session_id`` so later stages can continue the same session.
    """

    def __init__(self, side="正方", topic="", opinion="",
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """
        :param side: the side the AI argues for, "正方" or "反方"
        :param topic: debate topic, passed to the application as a prompt parameter
        :param opinion: the AI side's position, passed to the application as a prompt parameter
        :param asr_callback: forwarded to ``bm.Mode``
        :param tts_callback: forwarded to ``bm.Mode``
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback,
        )
        self.side = side
        self.topic = topic
        self.opinion = opinion
        # DashScope application session; None until the first successful response.
        self.session_id = None
        # When side == "正方": the opponent's (反方) opening statement captured by ASR; otherwise "".
        self.last_prompt = ""

    def ask_ai(self, prompt):
        """Send *prompt* to the DashScope application and yield the streamed reply text.

        Non-OK stream chunks are logged and skipped. ``self.session_id`` is
        updated from every successful chunk.
        """
        import logging

        call_kwargs = dict(
            api_key=COM_API_KEY,
            app_id=APP_ID,
            prompt=prompt,
            stream=True,
            incremental_output=True,
            biz_params={
                "user_prompt_params": {
                    "side": self.side,
                    "topic": self.topic,
                    "opinion": self.opinion,
                }
            },
        )
        # Continue an existing session if there is one; on the first call the
        # service is left to create the session.
        if self.session_id is not None:
            call_kwargs["session_id"] = self.session_id

        for chunk in dashscope.Application.call(**call_kwargs):
            if chunk.status_code != HTTPStatus.OK:
                logging.getLogger(__name__).warning(
                    "DashScope application call failed: status=%s code=%s message=%s request_id=%s",
                    chunk.status_code,
                    getattr(chunk, "code", None),
                    getattr(chunk, "message", None),
                    getattr(chunk, "request_id", None),
                )
                continue
            self.session_id = chunk.output.session_id
            # Guard against empty incremental chunks so downstream concatenation never sees None.
            yield chunk.output.text or ""

    def ask(self, other_op=None):
        """Build the chair's opening-stage prompt and return the streaming reply generator.

        :param other_op: the opponent's (正方) opening statement; used only when side is 反方
        """
        if self.side == "正方":
            return self.ask_ai("主席:请正方开始立论")
        # Side 反方: 正方 has already spoken, so include their statement in the prompt.
        return self.ask_ai("主席:请正方开始立论\n正方:" + (other_op or ""))

    def start_record(self):
        """Prepare an ASR session and start recording the opponent."""
        self.ready_asr_session()
        self.start_asr_record()

    def stop_record(self):
        """Stop recording.

        Must finish before start_talk or start_identify_op is called, because
        they read the recognised text from ``self.asr_res``.
        """
        self.stop_asr_record()

    def start_talk(self):
        """Side 反方 only: deliver the AI's opening statement, answering the recorded 正方 statement."""
        if self.side == "反方":
            self.stream_pipeline(self.ask(self.asr_res))

    def start_identify_op(self):
        """Side 正方 only: store the recorded opponent statement in ``last_prompt`` for the free-debate stage."""
        if self.side == "正方":
            self.last_prompt = self.asr_res

    def ready_next(self):
        """Finish this stage before moving on to free debate.

        :return: for 正方, the opponent's opening statement (to prefix the
                 free-debate prompt); for 反方, None (the AI speaks its opening
                 statement instead).
        """
        if self.side == "正方":
            self.start_identify_op()
            return self.last_prompt
        self.start_talk()
        return None

    def run(self):
        """Start the opening stage.

        Flow:

        1. 正方 (speak first, then record): calling this method makes the AI
           speak immediately. After that, when the opponent begins their opening
           statement, call start_record explicitly. When they finish, call
           stop_record. On entering free debate, call ready_next and keep its
           return value: the opponent's opening statement, which goes in front of
           the free-debate prompt.
        2. 反方 (record first, then speak): calling this method starts recording,
           so switch into this mode just before the opponent speaks. When the
           opponent finishes, call stop_record. When our side should give its
           opening statement, call start_talk.
        """
        super().run()
        if self.side == "正方":
            # Speak first, then record.
            self.stream_pipeline(self.ask())
        else:
            # Record first, then speak; later steps are driven by the caller.
            self.start_record()
|
||
|
||
|
||
class FreeDebateMode(bm.Mode):
    """Free-debate stage: alternately record the opponent and stream the AI's rebuttal.

    ``side`` is the side the AI argues for ("正方" or "反方"). Every request
    continues the DashScope application session ``session_id``. The exchange
    runs on a background thread started by :meth:`run` and stopped by :meth:`stop`.
    """

    # Seconds to sleep between checks while waiting on ASR activity; replaces a
    # bare busy-wait that would otherwise pin a CPU core.
    _POLL_INTERVAL = 0.05

    def __init__(self, session_id, side="正方", topic="", opinion="", last_prompt="", threshold_no_speak=2,
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """
        :param session_id: DashScope application session to continue (presumably the one from the opening stage)
        :param side: the side the AI argues for, "正方" or "反方"
        :param topic: debate topic, passed to the application as a prompt parameter
        :param opinion: the AI side's position, passed to the application as a prompt parameter
        :param last_prompt: for side 正方, the opponent's opening statement sent before the AI's first turn
        :param threshold_no_speak: seconds of silence after which the opponent's turn is considered finished
        :param asr_callback: forwarded to ``bm.Mode``
        :param tts_callback: forwarded to ``bm.Mode``
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback,
        )
        self.threshold_no_speak = threshold_no_speak
        self.side = side
        self.topic = topic
        self.opinion = opinion
        self.last_prompt = last_prompt
        self.session_id = session_id
        # Side 反方 only: whether the next prompt must still include the chair's announcement.
        self.is_first = True
        # Opponent speech that was recorded when stop() interrupted a turn and was
        # never sent to the AI; None if nothing was pending. Meant to be handed to
        # the closing stage as its last_prompt.
        self.last_prompt_next_mode = None
        self.main_loop_thread = threading.Thread(target=self.main_loop)
        self.running = False

    def ask_ai(self, prompt):
        """Send *prompt* to the DashScope application and yield the streamed reply text.

        Non-OK stream chunks are logged and skipped. ``self.session_id`` is
        updated from every successful chunk.
        """
        import logging

        biz_params = {
            "user_prompt_params": {
                "side": self.side,
                "topic": self.topic,
                "opinion": self.opinion,
            }
        }
        res = dashscope.Application.call(
            api_key=COM_API_KEY,
            app_id=APP_ID,
            prompt=prompt,
            stream=True,
            incremental_output=True,
            session_id=self.session_id,
            biz_params=biz_params,
        )
        for chunk in res:
            if chunk.status_code != HTTPStatus.OK:
                logging.getLogger(__name__).warning(
                    "DashScope application call failed: status=%s code=%s message=%s request_id=%s",
                    chunk.status_code,
                    getattr(chunk, "code", None),
                    getattr(chunk, "message", None),
                    getattr(chunk, "request_id", None),
                )
                continue
            self.session_id = chunk.output.session_id
            # Guard against empty incremental chunks so downstream concatenation never sees None.
            yield chunk.output.text or ""

    def init_ask(self):
        """Side 正方 only: pass in the opponent's opening statement and let the AI open the free debate.

        :return: the streaming reply generator, or None when side is not 正方
        """
        if self.side == "正方":
            return self.ask_ai("反方:" + (self.last_prompt or "") + "\n主席:下面进入自由辩论环节,请正方开始发言")
        return None

    def ask(self, context=None):
        """Forward the opponent's latest remark to the AI and return the streaming reply generator.

        :param context: what the opponent just said (ASR text)
        """
        context = context or ""
        if self.side == "正方":
            return self.ask_ai("反方:" + context)
        if self.is_first:
            # 反方's first turn: 正方 opened the free debate, so include the chair's announcement.
            self.is_first = False
            return self.ask_ai("主席:下面进入自由辩论环节,请正方开始发言\n正方:" + context)
        return self.ask_ai("正方:" + context)

    def _wait_for_speech_start(self):
        """Block until ASR reports activity. Return False if stop() was called first."""
        # NOTE(review): assumes last_asr_time is None at the start of every turn
        # (reset by the base class / ready_asr_session). Confirm in base_mode.
        while self.last_asr_time is None:
            if not self.running:
                return False
            time.sleep(self._POLL_INTERVAL)
        return True

    def _wait_for_silence(self):
        """Block until no ASR activity for threshold_no_speak seconds. Return False if stop() was called first."""
        while time.time() - self.last_asr_time <= self.threshold_no_speak:
            if not self.running:
                return False
            time.sleep(self._POLL_INTERVAL)
        return True

    def main_loop(self):
        """Worker loop: optionally open as 正方, then record/answer the opponent until stopped."""
        # Initial turn: as 正方, the AI speaks first in free debate.
        if self.side == "正方":
            self.stream_pipeline(self.init_ask())
        # Main exchange loop
        while self.running:
            # 1. Prepare the ASR session
            self.ready_asr_session()
            # 2. Start recording
            self.start_asr_record()
            # 3-4. Wait for the opponent to speak and then fall silent
            spoke = self._wait_for_speech_start()
            finished = spoke and self._wait_for_silence()
            # 5. Close ASR (also when stopping, so the recording is not left open)
            self.stop_asr_record()
            if not finished:
                if spoke:
                    # Stopped mid-utterance: keep the opponent's words for the next stage.
                    self.last_prompt_next_mode = self.asr_res
                break
            # 6. Send the remark to the AI and play the reply
            self.stream_pipeline(self.ask(self.asr_res))

    def run(self):
        """Start the free-debate stage.

        Flow: similar to DialogMode. To move on to the closing stage, the caller
        can simply discard this object and start the closing mode. Read the
        last_prompt_next_mode field first: after stop(), it holds any opponent
        speech that was recorded but not yet answered.
        """
        super().run()
        self.running = True
        self.main_loop_thread.start()

    def stop(self):
        """Signal the worker thread to exit and call the base class's tts_finish().

        The worker notices the flag at its next poll. The thread is not joined here.
        """
        self.running = False
        self.tts_finish()
|
||
|
||
|
||
class EndDebateMode(bm.Mode):
    """Closing-statement (结辩) stage. The chair calls on 反方 first.

    ``side`` is the side the AI argues for ("正方" or "反方"). Every request
    continues the DashScope application session ``session_id``.
    """

    def __init__(self, session_id, side="正方", topic="", opinion="", last_prompt="",
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """
        :param session_id: DashScope application session to continue
        :param side: the side the AI argues for, "正方" or "反方"
        :param topic: debate topic, passed to the application as a prompt parameter
        :param opinion: the AI side's position, passed to the application as a prompt parameter
        :param last_prompt: for side 反方, the opponent's (正方) last remark not yet sent to the AI
                            (presumably FreeDebateMode.last_prompt_next_mode); "" or None if none
        :param asr_callback: forwarded to ``bm.Mode``
        :param tts_callback: forwarded to ``bm.Mode``
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback,
        )
        self.side = side
        self.topic = topic
        self.opinion = opinion
        self.last_prompt = last_prompt
        self.session_id = session_id

    def ask_ai(self, prompt):
        """Send *prompt* to the DashScope application and yield the streamed reply text.

        Non-OK stream chunks are logged and skipped. ``self.session_id`` is
        updated from every successful chunk.
        """
        import logging

        biz_params = {
            "user_prompt_params": {
                "side": self.side,
                "topic": self.topic,
                "opinion": self.opinion,
            }
        }
        res = dashscope.Application.call(
            api_key=COM_API_KEY,
            app_id=APP_ID,
            prompt=prompt,
            stream=True,
            incremental_output=True,
            session_id=self.session_id,
            biz_params=biz_params,
        )
        for chunk in res:
            if chunk.status_code != HTTPStatus.OK:
                logging.getLogger(__name__).warning(
                    "DashScope application call failed: status=%s code=%s message=%s request_id=%s",
                    chunk.status_code,
                    getattr(chunk, "code", None),
                    getattr(chunk, "message", None),
                    getattr(chunk, "request_id", None),
                )
                continue
            self.session_id = chunk.output.session_id
            # Guard against empty incremental chunks so downstream concatenation never sees None.
            yield chunk.output.text or ""

    def ask(self, context=None):
        """Build the chair's closing-stage prompt and return the streaming reply generator.

        :param context: the opponent's (反方) closing statement; used only when side is 正方
        """
        if self.side == "反方":
            if self.last_prompt:
                # The AI is 反方, so the pending remark came from the opponent, 正方.
                # It was previously mislabelled "反方:".
                return self.ask_ai("正方:" + self.last_prompt + "\n主席:下面进入结辩环节,请反方开始发言")
            return self.ask_ai("主席:下面进入结辩环节,请反方开始发言")
        # Side 正方: 反方 has already delivered its closing statement.
        return self.ask_ai("主席:下面进入结辩环节,请反方开始发言\n反方:" + (context or ""))

    def start_record(self):
        """Prepare an ASR session and start recording the opponent."""
        self.ready_asr_session()
        self.start_asr_record()

    def stop_record(self):
        """Stop recording.

        Must finish before start_talk is called, because for side 正方 it reads
        the recognised text from ``self.asr_res``.
        """
        self.stop_asr_record()

    def start_talk(self):
        """Deliver the AI's closing statement.

        For side 正方, it answers the recorded 反方 closing statement.
        """
        if self.side == "反方":
            self.stream_pipeline(self.ask())
        else:
            self.stream_pipeline(self.ask(self.asr_res))

    def run(self):
        """Start the closing stage.

        Flow:

        1. 正方: while the opponent speaks, the caller must call start_record.
           When they finish, call stop_record, then start_talk.
        2. 反方: the caller calls start_talk, which starts the closing statement directly.

        This is the final stage of the debate.
        """
        super().run()
|
||
|
||
|
||
if __name__ == '__main__':
    # Manual demo: run a voice dialog for a fixed time window.
    mode = DialogMode(
        asr_callback=lambda x: print(x, end="", flush=True),
        tts_callback=lambda x: print(x, end="", flush=True),
    )
    mode.run()
    try:
        # Let the conversation run for 200 seconds (or until Ctrl+C).
        time.sleep(200)
    finally:
        # The worker is a non-daemon thread looping while mode.running is True.
        # Without stop() it would keep the process alive after the demo window.
        mode.stop()
|