// Source: useDialogueEngine.ts (13 KB)
  1. import { ref, reactive, computed, onUnmounted } from 'vue'
  2. import type { PreviewChatMessage, DialogueAPI, SessionConfig, DialogueReport } from '@/types/englishSpeaking'
  3. import { MockDialogueAPI, RealDialogueAPI } from '../services/llmService'
  /**
   * Vue composable that drives one spoken-dialogue session.
   *
   * It owns the chat transcript, the session id, the round counter, the
   * session-expiry countdown and browser text-to-speech, and exposes two ways
   * to submit a student utterance:
   *  - `sendStudentMessage`: upload a complete audio blob through `api.speak`;
   *  - `beginStudentStream`: stream audio chunks over a WebSocket.
   *
   * Must be called from a component `setup` because it registers an
   * `onUnmounted` hook that aborts in-flight work.
   *
   * @param mode `'real'` uses `RealDialogueAPI`; anything else (default
   *             `'preview'`) uses `MockDialogueAPI`.
   * @returns Reactive state plus the action functions listed in the final
   *          `return` statement.
   */
  export function useDialogueEngine(mode: 'preview' | 'real' = 'preview') {
    /** Maximum number of `api.getReport` attempts before giving up. */
    const REPORT_MAX_ATTEMPTS = 15 // 30s / 2s
    /** Delay between two report attempts, in milliseconds. */
    const REPORT_POLL_INTERVAL_MS = 2000

    const messages = ref<PreviewChatMessage[]>([])
    const sessionId = ref<string | null>(null)
    const expiresAt = ref<string | null>(null)
    const currentRound = ref(1)
    const isComplete = ref(false)
    const countdownSeconds = ref<number | null>(null)

    const api: DialogueAPI = mode === 'real' ? new RealDialogueAPI() : new MockDialogueAPI()
    let currentAbortController: AbortController | null = null
    let countdownTimer: ReturnType<typeof setInterval> | null = null
    let ttsUtterance: SpeechSynthesisUtterance | null = null
    // Teardown of the WebSocket stream that is currently open, if any.
    let activeStreamAbort: (() => void) | null = null
    // Set once the owning component is unmounted; stops background polling.
    let disposed = false

    /** True while any message in the transcript is still `loading`. */
    const isProcessing = computed(() => messages.value.some(m => m.status === 'loading'))
    /** Recording is allowed only when nothing is in flight and the session is not over. */
    const canRecord = computed(() => !isProcessing.value && !isComplete.value)

    // ==================== Internal helpers ====================

    type LiveMessage = ReturnType<typeof createMessage>

    /** Shape of the JSON frames read from the streaming WebSocket. */
    type StreamServerEvent = {
      type?: string
      text?: string
      content?: string
      message?: string
      isComplete?: boolean
    }

    /** Creates a reactive, empty, `loading` chat message for the given role. */
    function createMessage(role: PreviewChatMessage['role'], audioBlob?: Blob) {
      const message = reactive<PreviewChatMessage>({
        id: crypto.randomUUID(),
        role,
        content: '',
        timestamp: new Date(),
        status: 'loading',
      })
      if (audioBlob) message.audioBlob = audioBlob
      return message
    }

    /** Marks the first candidate that is still `loading` as failed with `reason`. */
    function failFirstPending(reason: string, ...candidates: LiveMessage[]) {
      const pending = candidates.find(m => m.status === 'loading')
      if (!pending) return
      pending.status = 'error'
      pending.error = reason
    }

    /** Finishes an AI turn: settles the bubble, updates round/completion state and reads it aloud. */
    function completeAiTurn(aiMsg: LiveMessage, sessionFinished: boolean) {
      aiMsg.status = 'done'
      isComplete.value = sessionFinished
      if (!sessionFinished) currentRound.value++
      speakTTS(aiMsg.content)
    }

    /** Turns a caught value into the text stored on a failed message. */
    function failureReason(err: unknown): string {
      const details = typeof err === 'object' && err !== null
        ? (err as { name?: unknown; message?: unknown })
        : {}
      // An aborted request must still settle its placeholder, otherwise
      // `isProcessing` would stay true and block every later recording.
      if (details.name === 'AbortError') return friendlyErrorMessage('Aborted')
      return typeof details.message === 'string' && details.message
        ? details.message
        : 'Request failed'
    }

    /** Parses one WebSocket text frame; returns null for binary or malformed frames. */
    function parseStreamEvent(raw: unknown): StreamServerEvent | null {
      if (typeof raw !== 'string') return null
      try {
        const parsed: unknown = JSON.parse(raw)
        return typeof parsed === 'object' && parsed !== null
          ? (parsed as StreamServerEvent)
          : null
      } catch {
        return null
      }
    }

    // ==================== Session ====================

    /**
     * Creates a session through the API, stores its id/expiry, pushes the
     * opening AI message, starts the countdown when an expiry is provided and
     * speaks the opening line. Failures are logged and not rethrown.
     */
    async function initSession(config: SessionConfig) {
      try {
        const info = await api.createSession(config)
        sessionId.value = info.sessionId
        expiresAt.value = info.expiresAt || null
        messages.value.push({
          id: crypto.randomUUID(),
          role: 'ai',
          content: info.aiMessage,
          timestamp: new Date(),
          status: 'done',
        })
        if (info.expiresAt) startCountdown(info.expiresAt)
        speakTTS(info.aiMessage)
      } catch (err: unknown) {
        console.error('Failed to init session:', err)
      }
    }

    // ==================== Send Message ====================

    /**
     * Uploads a complete recording and streams the reply into the transcript.
     *
     * The student bubble is pushed immediately; the AI bubble is pushed once
     * the transcript event arrives. Does nothing when there is no session or
     * another message is still loading. On failure or abort, the bubble that
     * is still loading is marked `error`.
     */
    async function sendStudentMessage(audioBlob: Blob) {
      const activeSession = sessionId.value
      if (!activeSession || isProcessing.value) return

      const studentMsg = createMessage('student', audioBlob)
      messages.value.push(studentMsg)
      // Prepared now, but only shown once the transcript is known.
      const aiMsg = createMessage('ai')

      const controller = new AbortController()
      currentAbortController = controller
      try {
        for await (const event of api.speak(activeSession, audioBlob, controller.signal)) {
          if (event.type === 'transcript') {
            studentMsg.content = event.text
            studentMsg.status = 'done'
            messages.value.push(aiMsg)
          } else if (event.type === 'token') {
            aiMsg.content += event.text
          } else if (event.type === 'done') {
            completeAiTurn(aiMsg, event.isComplete)
          }
        }
        // The stream ended without the expected events: settle what is left.
        if (studentMsg.status === 'loading') studentMsg.status = 'done'
        if (aiMsg.status === 'loading') aiMsg.status = 'done'
      } catch (err: unknown) {
        failFirstPending(failureReason(err), studentMsg, aiMsg)
      } finally {
        // Only clear the slot if a newer request has not replaced it.
        if (currentAbortController === controller) currentAbortController = null
      }
    }

    // ==================== Retry / Regenerate ====================

    /**
     * Retries a failed student message that still carries its recording:
     * removes it and everything after it, then resends the audio.
     */
    async function retryMessage(messageId: string) {
      const idx = messages.value.findIndex(m => m.id === messageId)
      if (idx === -1) return
      const failed = messages.value[idx]
      if (!failed || failed.status !== 'error') return
      if (failed.role !== 'student' || !failed.audioBlob) return

      const recording = failed.audioBlob
      messages.value.splice(idx)
      await sendStudentMessage(recording)
    }

    /**
     * Regenerates a failed AI message by resending the recording of the
     * closest preceding student message. Transcript events are ignored
     * because the student bubble already exists. Does nothing while another
     * message is loading, so two replies are never streamed concurrently.
     */
    async function regenerateAiMessage(messageId: string) {
      const idx = messages.value.findIndex(m => m.id === messageId)
      if (idx === -1) return
      const failed = messages.value[idx]
      if (!failed || failed.role !== 'ai' || failed.status !== 'error') return
      if (isProcessing.value) return

      const activeSession = sessionId.value
      const prevStudent = messages.value.slice(0, idx).reverse().find(m => m.role === 'student')
      if (!prevStudent?.audioBlob || !activeSession) return
      const recording = prevStudent.audioBlob

      // Swap the failed bubble for a fresh placeholder at the end of the list.
      messages.value.splice(idx, 1)
      const aiMsg = createMessage('ai')
      messages.value.push(aiMsg)

      const controller = new AbortController()
      currentAbortController = controller
      try {
        for await (const event of api.speak(activeSession, recording, controller.signal)) {
          if (event.type === 'token') {
            aiMsg.content += event.text
          } else if (event.type === 'done') {
            completeAiTurn(aiMsg, event.isComplete)
          }
        }
        if (aiMsg.status === 'loading') aiMsg.status = 'done'
      } catch (err: unknown) {
        failFirstPending(failureReason(err), aiMsg)
      } finally {
        if (currentAbortController === controller) currentAbortController = null
      }
    }

    // ==================== Report ====================

    /**
     * Polls `api.getReport` until it succeeds.
     *
     * @returns The report of the current session.
     * @throws Error('No session') when no session has been created.
     * @throws Error('Report timeout') after `REPORT_MAX_ATTEMPTS` failed attempts.
     * @throws Error('Report polling cancelled') when the component was unmounted
     *         while waiting for the next attempt.
     */
    async function getReport(): Promise<DialogueReport> {
      const activeSession = sessionId.value
      if (!activeSession) throw new Error('No session')

      for (let attempt = 1; ; attempt++) {
        try {
          return await api.getReport(activeSession)
        } catch {
          if (attempt >= REPORT_MAX_ATTEMPTS) throw new Error('Report timeout')
        }
        await new Promise<void>(resolve => setTimeout(resolve, REPORT_POLL_INTERVAL_MS))
        if (disposed) throw new Error('Report polling cancelled')
      }
    }

    // ==================== TTS ====================

    /** Speaks `text` with the browser speech synthesizer, replacing any current utterance. */
    function speakTTS(text: string) {
      if (!text || typeof speechSynthesis === 'undefined') return
      cancelTTS()
      ttsUtterance = new SpeechSynthesisUtterance(text)
      ttsUtterance.lang = 'en-US'
      ttsUtterance.rate = 0.9
      speechSynthesis.speak(ttsUtterance)
    }

    /** Cancels speech synthesis (when available) and drops the current utterance. */
    function cancelTTS() {
      if (typeof speechSynthesis !== 'undefined') {
        speechSynthesis.cancel()
      }
      ttsUtterance = null
    }

    // ==================== Countdown ====================

    /**
     * Starts a one-second countdown towards `expiresAtStr`. When it reaches
     * zero the countdown is stopped and the session is flagged complete.
     * An unparseable date is logged and no countdown is started.
     */
    function startCountdown(expiresAtStr: string) {
      stopCountdown()
      const deadline = new Date(expiresAtStr).getTime()
      if (Number.isNaN(deadline)) {
        console.warn('Invalid session expiry:', expiresAtStr)
        return
      }

      // Returns true while time is left, false once the session has expired.
      const tick = (): boolean => {
        const remaining = Math.max(0, Math.floor((deadline - Date.now()) / 1000))
        countdownSeconds.value = remaining
        if (remaining > 0) return true
        stopCountdown()
        isComplete.value = true
        return false
      }

      // Do not schedule a timer for a session that is already expired.
      if (tick()) countdownTimer = setInterval(tick, 1000)
    }

    /** Stops the countdown timer and resets the displayed seconds to null. */
    function stopCountdown() {
      if (countdownTimer) {
        clearInterval(countdownTimer)
        countdownTimer = null
      }
      countdownSeconds.value = null
    }

    // ==================== Abort ====================

    /** Aborts the in-flight `api.speak` request, if any. */
    function abort() {
      currentAbortController?.abort()
      currentAbortController = null
    }

    // ==================== Streaming Speak (WebSocket) ====================

    /**
     * Starts a streamed utterance: immediately pushes the student placeholder
     * and the AI placeholder, opens the WebSocket and returns a controller the
     * caller uses to push audio chunks, finish, or abort.
     *
     * @returns null when there is no session or another message is loading;
     *          otherwise the placeholder ids and the stream controls.
     */
    function beginStudentStream(opts: {
      sampleRate: number
      bits?: number
      channels?: number
    }) {
      const activeSession = sessionId.value
      if (!activeSession || isProcessing.value) {
        return null
      }

      // Open the socket first: if the constructor throws, no placeholder is
      // left behind in the `loading` state.
      const ws = new WebSocket(buildWsUrl('/speak-stream'))
      ws.binaryType = 'arraybuffer'

      // Placeholders appear the moment recording ends: student bubble + AI bubble.
      const studentMsg = createMessage('student')
      messages.value.push(studentMsg)
      const aiMsg = createMessage('ai')
      messages.value.push(aiMsg)

      let aborted = false
      let chunkQueue: ArrayBuffer[] = []
      let open = false

      const finalizeError = (msg: string) => failFirstPending(msg, studentMsg, aiMsg)

      ws.onopen = () => {
        open = true
        ws.send(JSON.stringify({
          type: 'start',
          sessionId: activeSession,
          sampleRate: opts.sampleRate,
          bits: opts.bits ?? 16,
          channels: opts.channels ?? 1,
        }))
        // Flush the chunks that were queued while the socket was connecting.
        for (const queued of chunkQueue) ws.send(queued)
        chunkQueue = []
      }

      ws.onmessage = (e: MessageEvent) => {
        // Frames that arrive after an abort must not touch state or start TTS.
        if (aborted) return
        const data = parseStreamEvent(e.data)
        if (!data) return

        switch (data.type) {
          case 'transcript':
            studentMsg.content = data.text ?? ''
            studentMsg.status = 'done'
            break
          case 'token':
            aiMsg.content += data.content ?? ''
            break
          case 'done':
            completeAiTurn(aiMsg, !!data.isComplete)
            ws.close()
            break
          case 'error':
            finalizeError(friendlyErrorMessage(data.message))
            ws.close()
            break
        }
      }

      ws.onerror = () => {
        if (!aborted) finalizeError(friendlyErrorMessage('WebSocket error'))
      }

      ws.onclose = () => {
        if (activeStreamAbort === abortStream) activeStreamAbort = null
        // Settle every placeholder that is still loading; leaving one behind
        // would keep `isProcessing` true and block recording.
        const reason = friendlyErrorMessage('Connection closed')
        for (const pending of [studentMsg, aiMsg]) {
          if (pending.status === 'loading') {
            pending.status = 'error'
            pending.error = reason
          }
        }
      }

      const pushChunk = (chunk: ArrayBuffer) => {
        if (aborted) return
        if (open && ws.readyState === WebSocket.OPEN) ws.send(chunk)
        else chunkQueue.push(chunk)
      }

      const finish = () => {
        if (aborted) return
        if (open && ws.readyState === WebSocket.OPEN) {
          ws.send(JSON.stringify({ type: 'stop' }))
        } else {
          // Asked to finish before the socket opened: just close it.
          ws.close()
        }
      }

      const abortStream = () => {
        aborted = true
        try { ws.close() } catch { /* ignore */ }
        finalizeError(friendlyErrorMessage('Aborted'))
      }
      activeStreamAbort = abortStream

      // The ids let the caller retry or fall back for these placeholders.
      return { studentMsgId: studentMsg.id, aiMsgId: aiMsg.id, pushChunk, finish, abortStream }
    }

    /**
     * HTTP fallback for a failed stream: removes the placeholders that
     * `beginStudentStream` pushed (to avoid duplicates) and sends the complete
     * recording through `sendStudentMessage`.
     */
    async function streamFallback(audioBlob: Blob, studentMsgId: string, aiMsgId: string) {
      messages.value = messages.value.filter(m => m.id !== studentMsgId && m.id !== aiMsgId)
      await sendStudentMessage(audioBlob)
    }

    // ==================== Cleanup ====================

    onUnmounted(() => {
      disposed = true
      abort()
      // Close a WebSocket stream that is still open so it cannot outlive the component.
      activeStreamAbort?.()
      cancelTTS()
      stopCountdown()
    })

    return {
      messages,
      sessionId,
      currentRound,
      isComplete,
      isProcessing,
      canRecord,
      countdownSeconds,
      initSession,
      sendStudentMessage,
      beginStudentStream,
      streamFallback,
      retryMessage,
      regenerateAiMessage,
      getReport,
      abort,
      cancelTTS,
    }
  }
  354. // ==================== Helpers ====================
  /**
   * Builds the WebSocket URL of a dialogue endpoint by turning the HTTP(S)
   * API base into its WS(S) counterpart and appending `path`.
   *
   * @param path    Endpoint path, with or without a leading slash. An empty
   *                path yields the base URL itself.
   * @param apiBase HTTP(S) base URL of the dialogue API; trailing slashes are
   *                ignored. Defaults to the local development server.
   * @returns The `ws://` or `wss://` URL for the endpoint.
   */
  function buildWsUrl(
    path: string,
    apiBase: string = 'http://localhost:8000/api/speaking/dialogue',
  ): string {
    // `http` -> `ws` also maps `https` -> `wss`, since only the prefix is replaced.
    const wsBase = apiBase.replace(/^http/, 'ws').replace(/\/+$/, '')
    if (!path) return wsBase
    // Guarantee exactly one slash between base and path.
    const suffix = path.startsWith('/') ? path : `/${path}`
    return wsBase + suffix
  }
  /**
   * Known backend / transport error texts and the user-facing Chinese copy
   * shown instead. Kept in a Map so lookups never reach `Object.prototype`
   * (a raw text such as "toString" must not resolve to an inherited member).
   */
  const FRIENDLY_ERROR_MESSAGES: ReadonlyMap<string, string> = new Map<string, string>([
    ['No speech detected', '没听清,请再说一次'],
    ['Session not found', '会话已失效,请刷新页面重新开始'],
    ['Session is not active', '会话已结束'],
    ['sessionId required', '会话参数缺失'],
    ['First message must be start', '协议异常,请刷新重试'],
    ['Invalid start payload', '协议异常,请刷新重试'],
    ['Internal error', '服务暂时不可用,请稍后重试'],
    ['WebSocket error', '网络连接异常'],
    ['Connection closed', '连接中断,请重试'],
    ['Aborted', '已取消'],
  ])

  /**
   * Converts an English backend error text into user-facing Chinese copy.
   *
   * @param raw Error text as received; may be undefined or empty.
   * @returns The mapped copy for a known text, a generic retry message when
   *          `raw` is missing or empty, otherwise `raw` unchanged.
   */
  function friendlyErrorMessage(raw: string | undefined): string {
    if (!raw) return '请求失败,请重试'
    return FRIENDLY_ERROR_MESSAGES.get(raw) ?? raw
  }