import {
  ApiPath,
  Google,
  REQUEST_TIMEOUT_MS,
  REQUEST_TIMEOUT_MS_FOR_THINKING,
} from "@/app/constant";
import {
  ChatOptions,
  getHeaders,
  LLMApi,
  LLMModel,
  LLMUsage,
  SpeechOptions,
} from "../api";
import {
  useAccessStore,
  useAppConfig,
  useChatStore,
  usePluginStore,
  ChatMessageTool,
} from "@/app/store";
import { stream } from "@/app/utils/chat";
import { getClientConfig } from "@/app/config/client";
import { GEMINI_BASE_URL } from "@/app/constant";

import {
  getMessageTextContent,
  getMessageImages,
  isVisionModel,
} from "@/app/utils";
import { preProcessImageContent } from "@/app/utils/chat";
import { nanoid } from "nanoid";
import { RequestPayload } from "./openai";
import { fetch } from "@/app/utils/stream";

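/**
 * Google Gemini chat client implementing the shared LLMApi interface.
 * Handles streaming and non-streaming generateContent requests, vision input
 * via inline_data parts, and plugin function calls; speech() and usage() are
 * not implemented for this provider.
 */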
export class GeminiProApi implements LLMApi {
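  /**
   * Build the request URL: use the custom Google endpoint when configured,
   * otherwise fall back to GEMINI_BASE_URL (app builds) or the ApiPath.Google
   * proxy path (web), and append `alt=sse` when a streaming response is
   * requested.
   */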
  path(path: string, shouldStream = false): string {
    const accessStore = useAccessStore.getState();

    let baseUrl = "";
    if (accessStore.useCustomConfig) {
      baseUrl = accessStore.googleUrl;
    }

    const isApp = !!getClientConfig()?.isApp;
    if (baseUrl.length === 0) {
      baseUrl = isApp ? GEMINI_BASE_URL : ApiPath.Google;
    }
    if (baseUrl.endsWith("/")) {
      baseUrl = baseUrl.slice(0, baseUrl.length - 1);
    }
    if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Google)) {
      baseUrl = "https://" + baseUrl;
    }

    console.log("[Proxy Endpoint] ", baseUrl, path);

    let chatPath = [baseUrl, path].join("/");
    if (shouldStream) {
      chatPath += chatPath.includes("?") ? "&alt=sse" : "?alt=sse";
    }

    return chatPath;
  }
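  /**
   * Extract the reply text from a generateContent response. Accepts either a
   * single response object or an array of streamed chunks, and falls back to
   * the API error message when no candidate text is present.
   */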
  extractMessage(res: any) {
    console.log("[Response] gemini-pro response: ", res);

    const getTextFromParts = (parts: any[]) => {
      if (!Array.isArray(parts)) return "";

      return parts
        .map((part) => part?.text || "")
        .filter((text) => text.trim() !== "")
        .join("\n\n");
    };

    let content = "";
    if (Array.isArray(res)) {
      res.map((item) => {
        content += getTextFromParts(item?.candidates?.at(0)?.content?.parts);
      });
    }

    return (
      getTextFromParts(res?.candidates?.at(0)?.content?.parts) ||
      content || //getTextFromParts(res?.at(0)?.candidates?.at(0)?.content?.parts) ||
      res?.error?.message ||
      ""
    );
  }
  speech(options: SpeechOptions): Promise<ArrayBuffer> {
    throw new Error("Method not implemented.");
  }

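  /**
   * Send a chat request. Messages are converted to Gemini `contents`
   * (assistant -> model, system -> user, images as inline_data parts), then
   * either streamed over SSE with tool-call support or sent as a single
   * request whose response is resolved via extractMessage().
   */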
  async chat(options: ChatOptions): Promise<void> {
    const apiClient = this;
    let multimodal = false;

    // try to get base64 image data from the locally cached image_url
    const _messages: ChatOptions["messages"] = [];
    for (const v of options.messages) {
      const content = await preProcessImageContent(v.content);
      _messages.push({ role: v.role, content });
    }
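    // Convert each message into Gemini's { role, parts } shape; for vision
    // models, split data URLs into a mime type and base64 payload for
    // inline_data parts.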
    const messages = _messages.map((v) => {
      let parts: any[] = [{ text: getMessageTextContent(v) }];
      if (isVisionModel(options.config.model)) {
        const images = getMessageImages(v);
        if (images.length > 0) {
          multimodal = true;
          parts = parts.concat(
            images.map((image) => {
              const imageType = image.split(";")[0].split(":")[1];
              const imageData = image.split(",")[1];
              return {
                inline_data: {
                  mime_type: imageType,
                  data: imageData,
                },
              };
            }),
          );
        }
      }
      return {
        role: v.role.replace("assistant", "model").replace("system", "user"),
        parts: parts,
      };
    });

    // Google requires that neighboring messages do not share the same role
    for (let i = 0; i < messages.length - 1; ) {
      // If the current and the next message have the same role
      if (messages[i].role === messages[i + 1].role) {
        // Concatenate the 'parts' of the current and next item
        messages[i].parts = messages[i].parts.concat(messages[i + 1].parts);
        // Remove the next item
        messages.splice(i + 1, 1);
      } else {
        // Move to the next item
        i++;
      }
    }
    // if (visionModel && messages.length > 1) {
    //   options.onError?.(new Error("Multiturn chat is not enabled for models/gemini-pro-vision"));
    // }

    const accessStore = useAccessStore.getState();

    const modelConfig = {
      ...useAppConfig.getState().modelConfig,
      ...useChatStore.getState().currentSession().mask.modelConfig,
      ...{
        model: options.config.model,
      },
    };
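    // Build the generateContent request body: conversation contents, sampling
    // parameters, and the user-configured safety threshold applied to every
    // harm category.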
    const requestPayload = {
      contents: messages,
      generationConfig: {
        // stopSequences: [
        //   "Title"
        // ],
        temperature: modelConfig.temperature,
        maxOutputTokens: modelConfig.max_tokens,
        topP: modelConfig.top_p,
        // "topK": modelConfig.top_k,
      },
      safetySettings: [
        {
          category: "HARM_CATEGORY_HARASSMENT",
          threshold: accessStore.googleSafetySettings,
        },
        {
          category: "HARM_CATEGORY_HATE_SPEECH",
          threshold: accessStore.googleSafetySettings,
        },
        {
          category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
          threshold: accessStore.googleSafetySettings,
        },
        {
          category: "HARM_CATEGORY_DANGEROUS_CONTENT",
          threshold: accessStore.googleSafetySettings,
        },
      ],
    };

    let shouldStream = !!options.config.stream;
    const controller = new AbortController();
    options.onController?.(controller);
    try {
      // https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb
      const chatPath = this.path(
        Google.ChatPath(modelConfig.model),
        shouldStream,
      );

      const chatPayload = {
        method: "POST",
        body: JSON.stringify(requestPayload),
        signal: controller.signal,
        headers: getHeaders(),
      };

      const isThinking = options.config.model.includes("-thinking");
      // abort the request if it exceeds the (thinking-aware) timeout
      const requestTimeoutId = setTimeout(
        () => controller.abort(),
        isThinking ? REQUEST_TIMEOUT_MS_FOR_THINKING : REQUEST_TIMEOUT_MS,
      );

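      // Streaming: hand off to the shared stream() helper with Gemini-specific
      // SSE chunk parsing and function-call handling for plugin tools.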
      if (shouldStream) {
        const [tools, funcs] = usePluginStore
          .getState()
          .getAsTools(
            useChatStore.getState().currentSession().mask?.plugin || [],
          );
        return stream(
          chatPath,
          requestPayload,
          getHeaders(),
          // @ts-ignore
          tools.length > 0
            ? // @ts-ignore
              [{ functionDeclarations: tools.map((tool) => tool.function) }]
            : [],
          funcs,
          controller,
          // parseSSE
          (text: string, runTools: ChatMessageTool[]) => {
            // console.log("parseSSE", text, runTools);
            const chunkJson = JSON.parse(text);

            const functionCall = chunkJson?.candidates
              ?.at(0)
              ?.content.parts.at(0)?.functionCall;
            if (functionCall) {
              const { name, args } = functionCall;
              runTools.push({
                id: nanoid(),
                type: "function",
                function: {
                  name,
                  arguments: JSON.stringify(args), // utils/chat parses this back with JSON.parse before calling the function
                },
              });
            }
            return chunkJson?.candidates
              ?.at(0)
              ?.content.parts?.map((part: { text: string }) => part.text)
              .join("\n\n");
          },
          // processToolMessage: append the tool_calls message and the tool call results
          (
            requestPayload: RequestPayload,
            toolCallMessage: any,
            toolCallResult: any[],
          ) => {
            // @ts-ignore
            requestPayload?.contents?.splice(
              // @ts-ignore
              requestPayload?.contents?.length,
              0,
              {
                role: "model",
                parts: toolCallMessage.tool_calls.map(
                  (tool: ChatMessageTool) => ({
                    functionCall: {
                      name: tool?.function?.name,
                      args: JSON.parse(tool?.function?.arguments as string),
                    },
                  }),
                ),
              },
              // @ts-ignore
              ...toolCallResult.map((result) => ({
                role: "function",
                parts: [
                  {
                    functionResponse: {
                      name: result.name,
                      response: {
                        name: result.name,
                        content: result.content, // TODO just text content...
                      },
                    },
                  },
                ],
              })),
            );
          },
          options,
        );
      } else {
        const res = await fetch(chatPath, chatPayload);
        clearTimeout(requestTimeoutId);
        const resJson = await res.json();
        if (resJson?.promptFeedback?.blockReason) {
          // the prompt was blocked; surface the block reason
          options.onError?.(
            new Error(
              "Message is being blocked for reason: " +
                resJson.promptFeedback.blockReason,
            ),
          );
        }
        const message = apiClient.extractMessage(resJson);
        options.onFinish(message, res);
      }
    } catch (e) {
      console.log("[Request] failed to make a chat request", e);
      options.onError?.(e as Error);
    }
  }
  usage(): Promise<LLMUsage> {
    throw new Error("Method not implemented.");
  }
  async models(): Promise<LLMModel[]> {
    return [];
  }
}
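
// Minimal usage sketch (illustrative only; ChatOptions may require more fields,
// such as extra callbacks or model config values, than the ones shown here):
//
//   const api = new GeminiProApi();
//   await api.chat({
//     messages: [{ role: "user", content: "Hello" }],
//     config: { model: "gemini-pro", stream: true },
//     onFinish: (message, res) => console.log(message),
//     onError: (e) => console.error(e),
//   });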