Playground.js 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
import React, { useCallback, useContext, useEffect, useRef, useState } from 'react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getUserIdFromLocalStorage, showError } from '../helpers';
import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography } from '@douyinfe/semi-ui';
import { SSE } from 'sse';
  7. const defaultMessage = [
  8. {
  9. role: 'user',
  10. id: '2',
  11. createAt: 1715676751919,
  12. content: "你好",
  13. },
  14. {
  15. role: 'assistant',
  16. id: '3',
  17. createAt: 1715676751919,
  18. content: "你好,请问有什么可以帮助您的吗?",
  19. }
  20. ];
  21. let id = 4;
  22. function getId() {
  23. return `${id++}`
  24. }
  25. const Playground = () => {
  26. const [inputs, setInputs] = useState({
  27. model: 'gpt-4o-mini',
  28. group: '',
  29. max_tokens: 0,
  30. temperature: 0,
  31. });
  32. const [searchParams, setSearchParams] = useSearchParams();
  33. const [userState, userDispatch] = useContext(UserContext);
  34. const [status, setStatus] = useState({});
  35. const [systemPrompt, setSystemPrompt] = useState('You are a helpful assistant. You can help me by answering my questions. You can also ask me questions.');
  36. const [message, setMessage] = useState(defaultMessage);
  37. const [models, setModels] = useState([]);
  38. const [groups, setGroups] = useState([]);
  39. const handleInputChange = (name, value) => {
  40. setInputs((inputs) => ({ ...inputs, [name]: value }));
  41. };
  42. useEffect(() => {
  43. if (searchParams.get('expired')) {
  44. showError('未登录或登录已过期,请重新登录!');
  45. }
  46. let status = localStorage.getItem('status');
  47. if (status) {
  48. status = JSON.parse(status);
  49. setStatus(status);
  50. }
  51. loadModels();
  52. loadGroups();
  53. }, []);
  54. const loadModels = async () => {
  55. let res = await API.get(`/api/user/models`);
  56. const { success, message, data } = res.data;
  57. if (success) {
  58. let localModelOptions = data.map((model) => ({
  59. label: model,
  60. value: model,
  61. }));
  62. setModels(localModelOptions);
  63. } else {
  64. showError(message);
  65. }
  66. };
  67. const loadGroups = async () => {
  68. let res = await API.get(`/api/user/groups`);
  69. const { success, message, data } = res.data;
  70. if (success) {
  71. // return data is a map, key is group name, value is group description
  72. // label is group description, value is group name
  73. let localGroupOptions = Object.keys(data).map((group) => ({
  74. label: data[group],
  75. value: group,
  76. }));
  77. // handleInputChange('group', localGroupOptions[0].value);
  78. if (localGroupOptions.length > 0) {
  79. } else {
  80. localGroupOptions = [{
  81. label: '用户分组',
  82. value: '',
  83. }];
  84. setGroups(localGroupOptions);
  85. }
  86. setGroups(localGroupOptions);
  87. handleInputChange('group', localGroupOptions[0].value);
  88. } else {
  89. showError(message);
  90. }
  91. };
  92. const commonOuterStyle = {
  93. border: '1px solid var(--semi-color-border)',
  94. borderRadius: '16px',
  95. margin: '0px 8px',
  96. }
  97. const getSystemMessage = () => {
  98. if (systemPrompt !== '') {
  99. return {
  100. role: 'system',
  101. id: '1',
  102. createAt: 1715676751919,
  103. content: systemPrompt,
  104. }
  105. }
  106. }
  107. let handleSSE = (payload) => {
  108. let source = new SSE('/pg/chat/completions', {
  109. headers: {
  110. "Content-Type": "application/json",
  111. "New-Api-User": getUserIdFromLocalStorage(),
  112. },
  113. method: "POST",
  114. payload: JSON.stringify(payload),
  115. });
  116. source.addEventListener("message", (e) => {
  117. if (e.data !== "[DONE]") {
  118. let payload = JSON.parse(e.data);
  119. // console.log("Payload: ", payload);
  120. if (payload.choices.length === 0) {
  121. source.close();
  122. completeMessage();
  123. } else {
  124. let text = payload.choices[0].delta.content;
  125. if (text) {
  126. generateMockResponse(text);
  127. }
  128. }
  129. } else {
  130. completeMessage();
  131. }
  132. });
  133. source.addEventListener("error", (e) => {
  134. generateMockResponse(e.data)
  135. completeMessage('error')
  136. });
  137. source.addEventListener("readystatechange", (e) => {
  138. if (e.readyState >= 2) {
  139. if (source.status === undefined) {
  140. source.close();
  141. completeMessage();
  142. }
  143. }
  144. });
  145. source.stream();
  146. }
  147. const onMessageSend = useCallback((content, attachment) => {
  148. console.log("attachment: ", attachment);
  149. setMessage((prevMessage) => {
  150. const newMessage = [
  151. ...prevMessage,
  152. {
  153. role: 'user',
  154. content: content,
  155. createAt: Date.now(),
  156. id: getId()
  157. }
  158. ];
  159. // 将 getPayload 移到这里
  160. const getPayload = () => {
  161. let systemMessage = getSystemMessage();
  162. let messages = newMessage.map((item) => {
  163. return {
  164. role: item.role,
  165. content: item.content,
  166. }
  167. });
  168. if (systemMessage) {
  169. messages.unshift(systemMessage);
  170. }
  171. return {
  172. messages: messages,
  173. stream: true,
  174. model: inputs.model,
  175. group: inputs.group,
  176. max_tokens: parseInt(inputs.max_tokens),
  177. temperature: inputs.temperature,
  178. };
  179. };
  180. // 使用更新后的消息状态调用 handleSSE
  181. handleSSE(getPayload());
  182. newMessage.push({
  183. role: 'assistant',
  184. content: '',
  185. createAt: Date.now(),
  186. id: getId(),
  187. status: 'loading'
  188. });
  189. return newMessage;
  190. });
  191. }, [getSystemMessage]);
  192. const completeMessage = useCallback((status = 'complete') => {
  193. // console.log("Complete Message: ", status)
  194. setMessage((prevMessage) => {
  195. const lastMessage = prevMessage[prevMessage.length - 1];
  196. // only change the status if the last message is not complete and not error
  197. if (lastMessage.status === 'complete' || lastMessage.status === 'error') {
  198. return prevMessage;
  199. }
  200. return [
  201. ...prevMessage.slice(0, -1),
  202. { ...lastMessage, status: status }
  203. ];
  204. });
  205. }, [])
  206. const generateMockResponse = useCallback((content) => {
  207. // console.log("Generate Mock Response: ", content);
  208. setMessage((message) => {
  209. const lastMessage = message[message.length - 1];
  210. let newMessage = {...lastMessage};
  211. if (lastMessage.status === 'loading' || lastMessage.status === 'incomplete') {
  212. newMessage = {
  213. ...newMessage,
  214. content: (lastMessage.content || '') + content,
  215. status: 'incomplete'
  216. }
  217. }
  218. return [ ...message.slice(0, -1), newMessage ]
  219. })
  220. }, []);
  221. return (
  222. <Layout style={{height: '100%'}}>
  223. <Layout.Sider>
  224. <Card style={commonOuterStyle}>
  225. <div style={{ marginTop: 10 }}>
  226. <Typography.Text strong>分组:</Typography.Text>
  227. </div>
  228. <Select
  229. placeholder={'请选择分组'}
  230. name='group'
  231. required
  232. selection
  233. onChange={(value) => {
  234. handleInputChange('group', value);
  235. }}
  236. value={inputs.group}
  237. autoComplete='new-password'
  238. optionList={groups}
  239. />
  240. <div style={{ marginTop: 10 }}>
  241. <Typography.Text strong>模型:</Typography.Text>
  242. </div>
  243. <Select
  244. placeholder={'请选择模型'}
  245. name='model'
  246. required
  247. selection
  248. filter
  249. onChange={(value) => {
  250. handleInputChange('model', value);
  251. }}
  252. value={inputs.model}
  253. autoComplete='new-password'
  254. optionList={models}
  255. />
  256. <div style={{ marginTop: 10 }}>
  257. <Typography.Text strong>Temperature:</Typography.Text>
  258. </div>
  259. <Slider
  260. step={0.1}
  261. min={0.1}
  262. max={1}
  263. value={inputs.temperature}
  264. onChange={(value) => {
  265. handleInputChange('temperature', value);
  266. }}
  267. />
  268. <div style={{ marginTop: 10 }}>
  269. <Typography.Text strong>MaxTokens:</Typography.Text>
  270. </div>
  271. <Input
  272. placeholder='MaxTokens'
  273. name='max_tokens'
  274. required
  275. autoComplete='new-password'
  276. defaultValue={0}
  277. value={inputs.max_tokens}
  278. onChange={(value) => {
  279. handleInputChange('max_tokens', value);
  280. }}
  281. />
  282. <div style={{ marginTop: 10 }}>
  283. <Typography.Text strong>System:</Typography.Text>
  284. </div>
  285. <TextArea
  286. placeholder='System Prompt'
  287. name='system'
  288. required
  289. autoComplete='new-password'
  290. autosize
  291. defaultValue={systemPrompt}
  292. // value={systemPrompt}
  293. onChange={(value) => {
  294. setSystemPrompt(value);
  295. }}
  296. />
  297. </Card>
  298. </Layout.Sider>
  299. <Layout.Content>
  300. <div style={{height: '100%'}}>
  301. <Chat
  302. chatBoxRenderConfig={{
  303. renderChatBoxAction: () => {
  304. return <div></div>
  305. }
  306. }}
  307. style={commonOuterStyle}
  308. chats={message}
  309. onMessageSend={onMessageSend}
  310. showClearContext
  311. onClear={() => {
  312. setMessage([]);
  313. }}
  314. />
  315. </div>
  316. </Layout.Content>
  317. </Layout>
  318. );
  319. };
  320. export default Playground;