init
This commit is contained in:
364
modes.py
Normal file
364
modes.py
Normal file
@@ -0,0 +1,364 @@
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
import dashscope
|
||||
import openai
|
||||
|
||||
import base_mode as bm
|
||||
from config import *
|
||||
|
||||
class DialogMode(bm.Mode):
    """Free-form voice conversation mode.

    Each round records the user's speech through the ASR facilities inherited
    from ``bm.Mode``, sends the recognised text to an OpenAI-compatible chat
    endpoint, and streams the reply into ``stream_pipeline``.
    """

    def __init__(self, model="qwen-plus", system_prompt="You are a helpful assistant", threshold_no_speak=2,
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """Create a free-dialog mode.

        :param model: name of the chat model
        :param system_prompt: system prompt placed at the start of the conversation
        :param threshold_no_speak: seconds of silence after which the speaker
            is considered to have finished
        :param asr_callback: callback forwarded to ``bm.Mode`` for ASR output
        :param tts_callback: callback forwarded to ``bm.Mode`` for TTS output
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback
        )
        self.model = model
        self.client = openai.OpenAI(
            api_key=COM_API_KEY,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
        )
        # Full chat history sent with every request; grows by two entries per round.
        self.context = [
            {"role": "system", "content": system_prompt}
        ]
        self.system_prompt = system_prompt
        self.main_loop_thread = threading.Thread(target=self.main_loop)
        self.running = False

        self.threshold_no_speak = threshold_no_speak

    def ask_ai(self, prompt):
        """Send ``prompt`` to the chat model and yield the reply incrementally.

        The user prompt is appended to ``self.context`` before the request and
        the complete assistant reply is appended once the stream is exhausted.
        A final empty string is yielded after the history has been updated.

        :param prompt: user text for this round
        :return: generator of reply text fragments
        """
        self.context.append(
            {"role": "user", "content": prompt}
        )
        response = self.client.chat.completions.create(
            model=self.model,
            messages=self.context,
            max_tokens=2048,
            stream=True
        )
        pieces = []
        for chunk in response:
            # A stream chunk may carry an empty ``choices`` list (for example a
            # usage-only chunk); indexing it would raise IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content or ""
            pieces.append(content)
            yield content
        self.context.append(
            {"role": "assistant", "content": "".join(pieces)}
        )
        yield ""

    def _wait_for_end_of_speech(self, poll_interval=0.05):
        """Block until the speaker has started and then stayed silent long enough.

        Polls ``self.last_asr_time`` (maintained outside this class, presumably
        by ``bm.Mode``) with a short sleep instead of spinning, and gives up as
        soon as ``self.running`` is cleared.

        :param poll_interval: seconds to sleep between polls
        :return: True when an utterance finished, False when the mode was stopped
        """
        # Wait for speech to begin
        while self.last_asr_time is None:
            if not self.running:
                return False
            time.sleep(poll_interval)
        # Wait until the silence has lasted longer than the threshold
        while (time.time() - self.last_asr_time) <= self.threshold_no_speak:
            if not self.running:
                return False
            time.sleep(poll_interval)
        return True

    def main_loop(self):
        """Worker loop: record one utterance, answer it, repeat until stopped."""
        while self.running:
            # 1. Start the ASR service
            self.ready_asr_session()
            # 2. Start recording
            self.start_asr_record()
            # 3./4. Wait for speech to begin and then for the trailing silence
            finished = self._wait_for_end_of_speech()
            # 5. Shut down ASR (also when stopping, so recording is not left open)
            self.stop_asr_record()
            if not finished:
                break
            # 6. Hand the recognised text to the AI
            self.stream_pipeline(self.ask_ai(self.asr_res))

    def run(self):
        """Start the mode and launch the conversation loop in a background thread."""
        super().run()
        self.running = True
        print("请开始说话...")
        self.main_loop_thread.start()

    def stop(self):
        """Ask the conversation loop to end and finish TTS."""
        self.running = False
        self.tts_finish()
|
||||
|
||||
|
||||
class MakePointMode(bm.Mode):
    """Opening-statement phase of a debate.

    ``side`` is "正方" (affirmative) or "反方" (negative). The affirmative side
    speaks first and then records the opponent; the negative side records the
    opponent first and then speaks.
    """

    def __init__(self, side="正方", topic="", opinion="",
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """Create the opening-statement mode.

        :param side: side played by the AI, "正方" or "反方"
        :param topic: debate topic passed to the application as a prompt parameter
        :param opinion: position passed to the application as a prompt parameter
        :param asr_callback: callback forwarded to ``bm.Mode`` for ASR output
        :param tts_callback: callback forwarded to ``bm.Mode`` for TTS output
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback
        )
        self.side = side
        self.topic = topic
        self.opinion = opinion
        # Filled from the first successful response chunk and reused afterwards.
        self.session_id = None
        # When side == "正方": the negative side's opening statement; otherwise "".
        self.last_prompt = ""

    def ask_ai(self, prompt):
        """Call the DashScope application and yield the reply text incrementally.

        ``session_id`` is only sent once one has been received, and is updated
        from every successful chunk. Chunks whose status is not OK are skipped.

        :param prompt: prompt text for this turn
        :return: generator of reply text fragments
        """
        biz_params = {
            "user_prompt_params": {
                "side": self.side,
                "topic": self.topic,
                "opinion": self.opinion
            }
        }
        call_kwargs = {
            "api_key": COM_API_KEY,
            "app_id": APP_ID,
            "prompt": prompt,
            "stream": True,
            "incremental_output": True,
            "biz_params": biz_params,
        }
        # The first call has no session yet; later calls continue the session.
        if self.session_id is not None:
            call_kwargs["session_id"] = self.session_id
        res = dashscope.Application.call(**call_kwargs)
        for chunk in res:
            if chunk.status_code == HTTPStatus.OK:
                self.session_id = chunk.output.session_id
                yield chunk.output.text

    def ask(self, other_op=None):
        """Build the opening-statement prompt for this side and query the AI.

        :param other_op: the affirmative side's opening statement; only used
            when this instance is not the affirmative side. ``None`` is treated
            as an empty statement instead of raising ``TypeError``.
        :return: generator of reply text fragments
        """
        if self.side == "正方":
            return self.ask_ai("主席:请正方开始立论")
        return self.ask_ai("主席:请正方开始立论\n正方:" + (other_op or ""))

    def start_record(self):
        """Prepare an ASR session and start recording the opponent."""
        self.ready_asr_session()
        self.start_asr_record()

    def stop_record(self):
        """Stop recording.

        Must have completed before ``start_talk`` or ``start_identify_op`` is called.
        """
        self.stop_asr_record()

    def start_talk(self):
        """Negative side only: deliver the opening statement from the recorded text."""
        if self.side == "反方":
            self.stream_pipeline(self.ask(self.asr_res))

    def start_identify_op(self):
        """Affirmative side only: store the recorded text in ``last_prompt``.

        The stored text is meant to be handed to the following free-debate mode.
        """
        if self.side == "正方":
            self.last_prompt = self.asr_res

    def ready_next(self):
        """Finish this phase.

        :return: for the affirmative side, the recorded opponent statement
            (``last_prompt``); for any other side the AI starts talking and
            ``None`` is returned.
        """
        if self.side == "正方":
            self.start_identify_op()
            return self.last_prompt
        self.start_talk()
        return None

    def run(self):
        """Start the opening-statement phase.

        Flow:
        1. Affirmative side: speak first, then record.
           Calling this method starts the speech immediately. When the opponent
           begins their opening statement the caller must call ``start_record``
           explicitly, and ``stop_record`` when it ends. When the free-debate
           phase begins the caller must call ``ready_next`` and keep its return
           value, which is the negative side's opening statement to be placed
           in front of the free-debate prompt.
        2. Negative side: record first, then speak.
           Calling this method starts recording (so switch into this mode just
           before the opponent speaks). When the opponent has finished the
           caller must call ``stop_record``; when this side is to speak the
           caller must call ``start_talk``.
        """
        super().run()
        if self.side == "正方":
            # Speak first, record later
            self.stream_pipeline(self.ask())
        else:
            # Record first, speak later; the rest is driven by the caller
            self.start_record()
|
||||
|
||||
|
||||
class FreeDebateMode(bm.Mode):
    """Free-debate phase: alternate between recording the opponent and answering.

    ``side`` is "正方" (affirmative) or "反方" (negative). The affirmative side
    opens the phase; afterwards both sides follow the same listen/answer loop.
    """

    def __init__(self, session_id, side="正方", topic="", opinion="", last_prompt="", threshold_no_speak=2,
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """Create the free-debate mode.

        :param session_id: DashScope application session to continue
        :param side: side played by the AI, "正方" or "反方"
        :param topic: debate topic passed to the application as a prompt parameter
        :param opinion: position passed to the application as a prompt parameter
        :param last_prompt: the negative side's opening statement, used by
            ``init_ask`` when this instance is the affirmative side
        :param threshold_no_speak: seconds of silence after which the speaker
            is considered to have finished
        :param asr_callback: callback forwarded to ``bm.Mode`` for ASR output
        :param tts_callback: callback forwarded to ``bm.Mode`` for TTS output
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback
        )
        self.threshold_no_speak = threshold_no_speak
        self.side = side
        self.topic = topic
        self.opinion = opinion
        self.last_prompt = last_prompt
        self.session_id = session_id
        # True until the negative side has sent its first free-debate prompt.
        self.is_first = True
        # NOTE(review): never assigned anywhere in this class although run()'s
        # documentation advises callers to read it — confirm the intended source.
        self.last_prompt_next_mode = None
        self.main_loop_thread = threading.Thread(target=self.main_loop)
        self.running = False

    def ask_ai(self, prompt):
        """Call the DashScope application and yield the reply text incrementally.

        ``self.session_id`` is updated from every successful chunk. Chunks whose
        status is not OK are skipped.

        :param prompt: prompt text for this turn
        :return: generator of reply text fragments
        """
        biz_params = {
            "user_prompt_params": {
                "side": self.side,
                "topic": self.topic,
                "opinion": self.opinion
            }
        }
        res = dashscope.Application.call(
            api_key=COM_API_KEY,
            app_id=APP_ID,
            prompt=prompt,
            stream=True,
            incremental_output=True,
            session_id=self.session_id,
            biz_params=biz_params
        )
        for chunk in res:
            if chunk.status_code == HTTPStatus.OK:
                self.session_id = chunk.output.session_id
                yield chunk.output.text

    def init_ask(self):
        """Affirmative side only: pass in the negative opening statement and speak.

        :return: generator of reply text fragments, or ``None`` when this
            instance is not the affirmative side
        """
        if self.side == "正方":
            # A missing statement is sent as empty text rather than raising TypeError.
            return self.ask_ai("反方:" + (self.last_prompt or "") + "\n主席:下面进入自由辩论环节,请正方开始发言")
        return None

    def ask(self, context=None):
        """Send the opponent's latest speech to the AI.

        For the negative side the very first prompt also carries the chair's
        announcement of the free-debate phase.

        :param context: recognised text of the opponent's speech; ``None`` is
            treated as empty text instead of raising ``TypeError``
        :return: generator of reply text fragments
        """
        text = context or ""
        if self.side == "正方":
            return self.ask_ai("反方:" + text)
        if self.is_first:
            self.is_first = False
            return self.ask_ai("主席:下面进入自由辩论环节,请正方开始发言\n正方:" + text)
        return self.ask_ai("正方:" + text)

    def _wait_for_end_of_speech(self, poll_interval=0.05):
        """Block until the speaker has started and then stayed silent long enough.

        Polls ``self.last_asr_time`` (maintained outside this class, presumably
        by ``bm.Mode``) with a short sleep instead of spinning, and gives up as
        soon as ``self.running`` is cleared.

        :param poll_interval: seconds to sleep between polls
        :return: True when an utterance finished, False when the mode was stopped
        """
        # Wait for speech to begin
        while self.last_asr_time is None:
            if not self.running:
                return False
            time.sleep(poll_interval)
        # Wait until the silence has lasted longer than the threshold
        while (time.time() - self.last_asr_time) <= self.threshold_no_speak:
            if not self.running:
                return False
            time.sleep(poll_interval)
        return True

    def main_loop(self):
        """Worker loop: optional opening speech, then listen/answer until stopped."""
        # Initial ask
        if self.side == "正方":
            self.stream_pipeline(self.init_ask())
        # Main dialog loop
        while self.running:
            # 1. Start the ASR service
            self.ready_asr_session()
            # 2. Start recording
            self.start_asr_record()
            # 3./4. Wait for speech to begin and then for the trailing silence
            finished = self._wait_for_end_of_speech()
            # 5. Shut down ASR (also when stopping, so recording is not left open)
            self.stop_asr_record()
            if not finished:
                break
            # 6. Hand the recognised text to the AI
            self.stream_pipeline(self.ask(self.asr_res))

    def run(self):
        """Start the free-debate phase in a background thread.

        Flow: similar to the dialog mode. To enter the closing phase the caller
        simply discards this object and starts the closing mode; reading the
        ``last_prompt_next_mode`` field beforehand is recommended.
        """
        super().run()
        self.running = True
        self.main_loop_thread.start()

    def stop(self):
        """Ask the debate loop to end and finish TTS."""
        self.running = False
        self.tts_finish()
|
||||
|
||||
|
||||
class EndDebateMode(bm.Mode):
    """Closing-statement phase of a debate.

    ``side`` is "正方" (affirmative) or "反方" (negative). The negative side
    delivers its closing statement directly; the affirmative side first records
    the opponent's closing statement and then answers it.
    """

    def __init__(self, session_id, side="正方", topic="", opinion="", last_prompt="",
                 asr_callback=lambda x: print(x, end="", flush=True),
                 tts_callback=lambda x: print(x, end="", flush=True)):
        """Create the closing-statement mode.

        :param session_id: DashScope application session to continue
        :param side: side played by the AI, "正方" or "反方"
        :param topic: debate topic passed to the application as a prompt parameter
        :param opinion: position passed to the application as a prompt parameter
        :param last_prompt: text carried over from the previous phase; only
            used by ``ask`` when this instance is the negative side
        :param asr_callback: callback forwarded to ``bm.Mode`` for ASR output
        :param tts_callback: callback forwarded to ``bm.Mode`` for TTS output
        """
        super().__init__(
            asr_callback=asr_callback,
            tts_callback=tts_callback
        )
        self.side = side
        self.topic = topic
        self.opinion = opinion
        self.last_prompt = last_prompt
        self.session_id = session_id

    def ask_ai(self, prompt):
        """Call the DashScope application and yield the reply text incrementally.

        ``self.session_id`` is updated from every successful chunk. Chunks whose
        status is not OK are skipped.

        :param prompt: prompt text for this turn
        :return: generator of reply text fragments
        """
        biz_params = {
            "user_prompt_params": {
                "side": self.side,
                "topic": self.topic,
                "opinion": self.opinion
            }
        }
        res = dashscope.Application.call(
            api_key=COM_API_KEY,
            app_id=APP_ID,
            prompt=prompt,
            stream=True,
            incremental_output=True,
            session_id=self.session_id,
            biz_params=biz_params
        )
        for chunk in res:
            if chunk.status_code == HTTPStatus.OK:
                self.session_id = chunk.output.session_id
                yield chunk.output.text

    def ask(self, context=None):
        """Build the closing-statement prompt for this side and query the AI.

        :param context: the negative side's closing statement; only used when
            this instance is not the negative side. ``None`` is treated as
            empty text instead of raising ``TypeError``.
        :return: generator of reply text fragments
        """
        if self.side == "反方":
            if self.last_prompt != "":
                # NOTE(review): the carried-over text is labelled "反方:" even
                # though this instance itself plays the negative side — confirm
                # the label is intended.
                return self.ask_ai("反方:" + self.last_prompt + "\n主席:下面进入结辩环节,请反方开始发言")
            return self.ask_ai("主席:下面进入结辩环节,请反方开始发言")
        return self.ask_ai("主席:下面进入结辩环节,请反方开始发言\n反方:" + (context or ""))

    def start_record(self):
        """Prepare an ASR session and start recording the opponent."""
        self.ready_asr_session()
        self.start_asr_record()

    def stop_record(self):
        """Stop recording.

        Must have completed before ``start_talk`` is called.
        """
        self.stop_asr_record()

    def start_talk(self):
        """Deliver the closing statement.

        The negative side speaks without extra input; any other side answers
        the text recorded from the opponent.
        """
        if self.side == "反方":
            self.stream_pipeline(self.ask())
        else:
            self.stream_pipeline(self.ask(self.asr_res))

    def run(self):
        """Start the closing phase; all further steps are driven by the caller.

        Flow:
        1. Affirmative side: while the opponent speaks the caller must call
           ``start_record``, then ``stop_record`` once they finish, and then
           ``start_talk``.
        2. Negative side: the caller calls ``start_talk`` and the closing
           statement begins immediately.
        Reaching this point concludes the debate.
        """
        super().run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    def _echo(text):
        """Print a streamed text fragment immediately, without a trailing newline."""
        print(text, end="", flush=True)

    # Demo: run the free-dialog mode, echoing ASR and TTS text to the console.
    mode = DialogMode(asr_callback=_echo, tts_callback=_echo)
    mode.run()
    # Keep the main thread alive while the worker thread handles the dialog.
    time.sleep(200)
|
||||
Reference in New Issue
Block a user