Playground.js 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. import React, { useCallback, useContext, useEffect, useState } from 'react';
  2. import { useNavigate, useSearchParams } from 'react-router-dom';
  3. import { UserContext } from '../context/User';
  4. import { API, getUserIdFromLocalStorage, showError } from '../helpers';
  5. import { Card, Chat, Input, Layout, Select, Slider, TextArea, Typography } from '@douyinfe/semi-ui';
  6. import { SSE } from 'sse';
  7. const defaultMessage = [
  8. {
  9. role: 'user',
  10. id: '2',
  11. createAt: 1715676751919,
  12. content: "你好",
  13. },
  14. {
  15. role: 'assistant',
  16. id: '3',
  17. createAt: 1715676751919,
  18. content: "你好,请问有什么可以帮助您的吗?",
  19. }
  20. ];
  21. const roleInfo = {
  22. user: {
  23. name: 'User',
  24. avatar: 'https://lf3-static.bytednsdoc.com/obj/eden-cn/ptlz_zlp/ljhwZthlaukjlkulzlp/docs-icon.png'
  25. },
  26. assistant: {
  27. name: 'Assistant',
  28. avatar: 'logo.png'
  29. },
  30. system: {
  31. name: 'System',
  32. avatar: 'https://lf3-static.bytednsdoc.com/obj/eden-cn/ptlz_zlp/ljhwZthlaukjlkulzlp/other/logo.png'
  33. }
  34. }
  35. let id = 4;
  36. function getId() {
  37. return `${id++}`
  38. }
  39. const Playground = () => {
  40. const [inputs, setInputs] = useState({
  41. model: 'gpt-4o-mini',
  42. group: '',
  43. max_tokens: 0,
  44. temperature: 0,
  45. });
  46. const [searchParams, setSearchParams] = useSearchParams();
  47. const [userState, userDispatch] = useContext(UserContext);
  48. const [status, setStatus] = useState({});
  49. const [systemPrompt, setSystemPrompt] = useState('You are a helpful assistant. You can help me by answering my questions. You can also ask me questions.');
  50. const [message, setMessage] = useState(defaultMessage);
  51. const [models, setModels] = useState([]);
  52. const [groups, setGroups] = useState([]);
  53. const handleInputChange = (name, value) => {
  54. setInputs((inputs) => ({ ...inputs, [name]: value }));
  55. };
  56. useEffect(() => {
  57. if (searchParams.get('expired')) {
  58. showError('未登录或登录已过期,请重新登录!');
  59. }
  60. let status = localStorage.getItem('status');
  61. if (status) {
  62. status = JSON.parse(status);
  63. setStatus(status);
  64. }
  65. loadModels();
  66. loadGroups();
  67. }, []);
  68. const loadModels = async () => {
  69. let res = await API.get(`/api/user/models`);
  70. const { success, message, data } = res.data;
  71. if (success) {
  72. let localModelOptions = data.map((model) => ({
  73. label: model,
  74. value: model,
  75. }));
  76. setModels(localModelOptions);
  77. } else {
  78. showError(message);
  79. }
  80. };
  81. const loadGroups = async () => {
  82. let res = await API.get(`/api/user/self/groups`);
  83. const { success, message, data } = res.data;
  84. if (success) {
  85. // return data is a map, key is group name, value is group description
  86. // label is group description, value is group name
  87. let localGroupOptions = Object.keys(data).map((group) => ({
  88. label: data[group],
  89. value: group,
  90. }));
  91. // handleInputChange('group', localGroupOptions[0].value);
  92. if (localGroupOptions.length > 0) {
  93. // set default group at first
  94. localGroupOptions.unshift({
  95. label: '用户分组',
  96. value: '',
  97. });
  98. } else {
  99. localGroupOptions = [{
  100. label: '用户分组',
  101. value: '',
  102. }];
  103. setGroups(localGroupOptions);
  104. }
  105. setGroups(localGroupOptions);
  106. handleInputChange('group', localGroupOptions[0].value);
  107. } else {
  108. showError(message);
  109. }
  110. };
  111. const commonOuterStyle = {
  112. border: '1px solid var(--semi-color-border)',
  113. borderRadius: '16px',
  114. margin: '0px 8px',
  115. }
  116. const getSystemMessage = () => {
  117. if (systemPrompt !== '') {
  118. return {
  119. role: 'system',
  120. id: '1',
  121. createAt: 1715676751919,
  122. content: systemPrompt,
  123. }
  124. }
  125. }
  126. let handleSSE = (payload) => {
  127. let source = new SSE('/pg/chat/completions', {
  128. headers: {
  129. "Content-Type": "application/json",
  130. "New-Api-User": getUserIdFromLocalStorage(),
  131. },
  132. method: "POST",
  133. payload: JSON.stringify(payload),
  134. });
  135. source.addEventListener("message", (e) => {
  136. if (e.data !== "[DONE]") {
  137. let payload = JSON.parse(e.data);
  138. // console.log("Payload: ", payload);
  139. if (payload.choices.length === 0) {
  140. source.close();
  141. completeMessage();
  142. } else {
  143. let text = payload.choices[0].delta.content;
  144. if (text) {
  145. generateMockResponse(text);
  146. }
  147. }
  148. } else {
  149. completeMessage();
  150. }
  151. });
  152. source.addEventListener("error", (e) => {
  153. generateMockResponse(e.data)
  154. completeMessage('error')
  155. });
  156. source.addEventListener("readystatechange", (e) => {
  157. if (e.readyState >= 2) {
  158. if (source.status === undefined) {
  159. source.close();
  160. completeMessage();
  161. }
  162. }
  163. });
  164. source.stream();
  165. }
  166. const onMessageSend = useCallback((content, attachment) => {
  167. console.log("attachment: ", attachment);
  168. setMessage((prevMessage) => {
  169. const newMessage = [
  170. ...prevMessage,
  171. {
  172. role: 'user',
  173. content: content,
  174. createAt: Date.now(),
  175. id: getId()
  176. }
  177. ];
  178. // 将 getPayload 移到这里
  179. const getPayload = () => {
  180. let systemMessage = getSystemMessage();
  181. let messages = newMessage.map((item) => {
  182. return {
  183. role: item.role,
  184. content: item.content,
  185. }
  186. });
  187. if (systemMessage) {
  188. messages.unshift(systemMessage);
  189. }
  190. return {
  191. messages: messages,
  192. stream: true,
  193. model: inputs.model,
  194. group: inputs.group,
  195. max_tokens: parseInt(inputs.max_tokens),
  196. temperature: inputs.temperature,
  197. };
  198. };
  199. // 使用更新后的消息状态调用 handleSSE
  200. handleSSE(getPayload());
  201. newMessage.push({
  202. role: 'assistant',
  203. content: '',
  204. createAt: Date.now(),
  205. id: getId(),
  206. status: 'loading'
  207. });
  208. return newMessage;
  209. });
  210. }, [getSystemMessage]);
  211. const completeMessage = useCallback((status = 'complete') => {
  212. // console.log("Complete Message: ", status)
  213. setMessage((prevMessage) => {
  214. const lastMessage = prevMessage[prevMessage.length - 1];
  215. // only change the status if the last message is not complete and not error
  216. if (lastMessage.status === 'complete' || lastMessage.status === 'error') {
  217. return prevMessage;
  218. }
  219. return [
  220. ...prevMessage.slice(0, -1),
  221. { ...lastMessage, status: status }
  222. ];
  223. });
  224. }, [])
  225. const generateMockResponse = useCallback((content) => {
  226. // console.log("Generate Mock Response: ", content);
  227. setMessage((message) => {
  228. const lastMessage = message[message.length - 1];
  229. let newMessage = {...lastMessage};
  230. if (lastMessage.status === 'loading' || lastMessage.status === 'incomplete') {
  231. newMessage = {
  232. ...newMessage,
  233. content: (lastMessage.content || '') + content,
  234. status: 'incomplete'
  235. }
  236. }
  237. return [ ...message.slice(0, -1), newMessage ]
  238. })
  239. }, []);
  240. return (
  241. <Layout style={{height: '100%'}}>
  242. <Layout.Sider>
  243. <Card style={commonOuterStyle}>
  244. <div style={{ marginTop: 10 }}>
  245. <Typography.Text strong>分组:</Typography.Text>
  246. </div>
  247. <Select
  248. placeholder={'请选择分组'}
  249. name='group'
  250. required
  251. selection
  252. onChange={(value) => {
  253. handleInputChange('group', value);
  254. }}
  255. value={inputs.group}
  256. autoComplete='new-password'
  257. optionList={groups}
  258. />
  259. <div style={{ marginTop: 10 }}>
  260. <Typography.Text strong>模型:</Typography.Text>
  261. </div>
  262. <Select
  263. placeholder={'请选择模型'}
  264. name='model'
  265. required
  266. selection
  267. filter
  268. onChange={(value) => {
  269. handleInputChange('model', value);
  270. }}
  271. value={inputs.model}
  272. autoComplete='new-password'
  273. optionList={models}
  274. />
  275. <div style={{ marginTop: 10 }}>
  276. <Typography.Text strong>Temperature:</Typography.Text>
  277. </div>
  278. <Slider
  279. step={0.1}
  280. min={0.1}
  281. max={1}
  282. value={inputs.temperature}
  283. onChange={(value) => {
  284. handleInputChange('temperature', value);
  285. }}
  286. />
  287. <div style={{ marginTop: 10 }}>
  288. <Typography.Text strong>MaxTokens:</Typography.Text>
  289. </div>
  290. <Input
  291. placeholder='MaxTokens'
  292. name='max_tokens'
  293. required
  294. autoComplete='new-password'
  295. defaultValue={0}
  296. value={inputs.max_tokens}
  297. onChange={(value) => {
  298. handleInputChange('max_tokens', value);
  299. }}
  300. />
  301. <div style={{ marginTop: 10 }}>
  302. <Typography.Text strong>System:</Typography.Text>
  303. </div>
  304. <TextArea
  305. placeholder='System Prompt'
  306. name='system'
  307. required
  308. autoComplete='new-password'
  309. autosize
  310. defaultValue={systemPrompt}
  311. // value={systemPrompt}
  312. onChange={(value) => {
  313. setSystemPrompt(value);
  314. }}
  315. />
  316. </Card>
  317. </Layout.Sider>
  318. <Layout.Content>
  319. <div style={{height: '100%'}}>
  320. <Chat
  321. chatBoxRenderConfig={{
  322. renderChatBoxAction: () => {
  323. return <div></div>
  324. }
  325. }}
  326. roleConfig={roleInfo}
  327. style={commonOuterStyle}
  328. chats={message}
  329. onMessageSend={onMessageSend}
  330. showClearContext
  331. onClear={() => {
  332. setMessage([]);
  333. }}
  334. />
  335. </div>
  336. </Layout.Content>
  337. </Layout>
  338. );
  339. };
  340. export default Playground;