You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
	
	
		
			259 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			TypeScript
		
	
		
		
			
		
	
	
			259 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			TypeScript
		
	
| 
											9 months ago
										 | import { DEFAULT_MODELS, ServiceProvider } from "../constant"; | ||
|  | import { LLMModel } from "../client/api"; | ||
|  | 
 | ||
/**
 * Sequence generator used to order custom providers and custom models.
 *
 * Numbering starts at -1000 so custom entries sort in front of the built-in
 * ones (refer to constant.ts). Each distinct id gets a number once; asking
 * again for the same id returns the number it was first given.
 */
const CustomSeq = {
  val: -1000,
  cache: new Map<string, number>(),
  next: (id: string) => {
    const known = CustomSeq.cache.get(id);
    if (known !== undefined) {
      return known;
    }
    const assigned = CustomSeq.val;
    CustomSeq.val += 1;
    CustomSeq.cache.set(id, assigned);
    return assigned;
  },
};
|  | 
 | ||
/**
 * Build the provider descriptor for a user-defined (custom) provider.
 *
 * `id` is the lower-cased provider name. `sorted` comes from CustomSeq keyed
 * by the name as given, so repeated calls with the same name share one value.
 */
const customProvider = (providerName: string) => {
  const sorted = CustomSeq.next(providerName);
  return {
    id: providerName.toLowerCase(),
    providerName,
    providerType: "custom",
    sorted,
  };
};
|  | 
 | ||
/**
 * Sort a model list in place and return the same array.
 *
 * When both models have a provider, order by the provider's `sorted` value and
 * break ties with the model's own `sorted` value. If either model lacks a
 * provider, only the models' own `sorted` values are compared.
 */
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
  models.sort((left, right) => {
    const lp = left.provider;
    const rp = right.provider;
    const byProvider = lp && rp ? lp.sorted - rp.sorted : 0;
    return byProvider !== 0 ? byProvider : left.sorted - right.sorted;
  });
|  | 
 | ||
/**
 * Split a formatted `<model>@<provider>` string at its LAST `@` character,
 * e.g. `gpt-4@OpenAi` -> ["gpt-4", "OpenAi"] and
 * `claude-3-5-sonnet@20240620@Google` -> ["claude-3-5-sonnet@20240620", "Google"].
 *
 * @param modelWithProvider model name and provider separated by the last `@`
 * @returns [model, provider] tuple; provider is undefined when there is no `@`
 */
export function getModelProvider(modelWithProvider: string): [string, string?] {
  // lastIndexOf instead of the former /@(?!.*@)/ split: `.` does not match
  // line terminators, so that regex could also split at an earlier `@`.
  const at = modelWithProvider.lastIndexOf("@");
  if (at === -1) {
    return [modelWithProvider, undefined];
  }
  return [modelWithProvider.slice(0, at), modelWithProvider.slice(at + 1)];
}
|  | 
 | ||
/**
 * Build the model lookup table from the built-in `models` plus the
 * comma-separated `customModels` config string.
 *
 * Table keys are `<modelName>@<providerId>`. Each config entry has the form
 * `[+|-]<model>[@<provider>][=<displayName>]`. Entries are applied in order:
 * - a leading `-` disables; anything else (with or without `+`) enables;
 * - `all` toggles every model collected so far;
 * - an entry matching existing models (by model name, and by provider when one
 *   is given) updates their availability and display name;
 * - an entry matching nothing adds a new model under a custom provider named
 *   after `<provider>` (or after the model itself when no provider is given).
 * For ByteDance, the part after `=` is stored as `name` and the model name is
 * used as `displayName`.
 *
 * @param models built-in models; entries are shallow-copied, not mutated
 * @param customModels comma-separated config; whitespace around each entry is
 *   ignored and empty entries are skipped
 * @returns the model table keyed by `<modelName>@<providerId>`
 */
export function collectModelTable(
  models: readonly LLMModel[],
  customModels: string,
) {
  const modelTable: Record<
    string,
    {
      available: boolean;
      name: string;
      displayName: string;
      sorted: number;
      provider?: LLMModel["provider"]; // optional: built-in models may lack a provider
      isDefault?: boolean;
    }
  > = {};

  // default models, keyed by <modelName>@<providerId>
  models.forEach((m) => {
    modelTable[`${m.name}@${m.provider?.id}`] = {
      ...m,
      displayName: m.name, // `provider`, if present, is carried over by the spread
    };
  });

  // server custom models
  customModels
    .split(",")
    // tolerate "a, b" style configs; a padded entry like " b" never matched
    .map((v) => v.trim())
    .filter((v) => v.length > 0)
    .forEach((entry) => {
      const available = !entry.startsWith("-");
      const nameConfig =
        entry.startsWith("+") || entry.startsWith("-")
          ? entry.slice(1)
          : entry;
      const [name, configDisplayName] = nameConfig.split("=");

      // enable or disable all models
      if (name === "all") {
        Object.values(modelTable).forEach(
          (model) => (model.available = available),
        );
        return;
      }

      const [customModelName, customProviderName] = getModelProvider(name);

      // 1. update every existing model with this name (and provider, if given)
      let matched = false;
      for (const [fullName, model] of Object.entries(modelTable)) {
        const [modelName, providerName] = getModelProvider(fullName);
        if (
          customModelName !== modelName ||
          (customProviderName !== undefined &&
            customProviderName !== providerName)
        ) {
          continue;
        }
        matched = true;
        model.available = available;
        if (providerName === "bytedance") {
          // ByteDance: the `=` part becomes `name` and the model name becomes
          // the display name. `name` is only replaced when a `=` part was
          // given (the old swap set it to undefined otherwise). Nothing is
          // written back to the entry-level variables, so one match cannot
          // leak swapped values into the next.
          if (configDisplayName) {
            model.name = configDisplayName;
          }
          model.displayName = modelName;
        } else if (configDisplayName) {
          model.displayName = configDisplayName;
        }
      }

      // 2. nothing matched: add a new model under a custom provider
      if (!matched) {
        const provider = customProvider(customProviderName || customModelName);
        let newName = customModelName;
        let newDisplayName = configDisplayName;
        // Same ByteDance swap as in step 1. Compare the lower-cased id so
        // `@bytedance` behaves like `@ByteDance`, matching step 1's check.
        if (newDisplayName && provider.id === "bytedance") {
          [newName, newDisplayName] = [newDisplayName, newName];
        }
        const fullName = `${newName}@${provider.id}`;
        modelTable[fullName] = {
          name: newName,
          displayName: newDisplayName || newName,
          available,
          provider,
          sorted: CustomSeq.next(fullName),
        };
      }
    });

  return modelTable;
}
|  | 
 | ||
/**
 * Build the model table (see collectModelTable) and flag at most one entry
 * with `isDefault: true`.
 *
 * - An empty `defaultModel` flags nothing.
 * - A `defaultModel` containing `@` is used as an exact table key. It is
 *   flagged whenever that key exists, regardless of availability.
 * - Otherwise, the first available entry (in key order) whose model name
 *   equals `defaultModel` is flagged.
 */
export function collectModelTableWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const modelTable = collectModelTable(models, customModels);
  if (!defaultModel) {
    return modelTable;
  }

  if (defaultModel.includes("@")) {
    if (defaultModel in modelTable) {
      modelTable[defaultModel].isDefault = true;
    }
    return modelTable;
  }

  const match = Object.keys(modelTable).find(
    (key) =>
      modelTable[key].available && getModelProvider(key)[0] === defaultModel,
  );
  if (match !== undefined) {
    modelTable[match].isDefault = true;
  }
  return modelTable;
}
|  | 
 | ||
/**
 * Generate the full model list: the table from collectModelTable flattened
 * into an array and ordered by provider, then by model (see sortModelTable).
 */
export function collectModels(
  models: readonly LLMModel[],
  customModels: string,
) {
  const table = collectModelTable(models, customModels);
  const entries = Object.values(table);
  sortModelTable(entries); // sorts in place; `entries` itself is returned
  return entries;
}
|  | 
 | ||
/**
 * Same as collectModels, but built from collectModelTableWithDefaultModel, so
 * the entry selected by `defaultModel` (if any) carries `isDefault: true`.
 */
export function collectModelsWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const table = collectModelTableWithDefaultModel(
    models,
    customModels,
    defaultModel,
  );
  const entries = Object.values(table);
  sortModelTable(entries); // sorts in place; `entries` itself is returned
  return entries;
}
|  | 
 | ||
/**
 * Despite its name, this returns `true` when the model is explicitly DISABLED
 * for the given provider: its entry in the server model table has
 * `available === false`. It returns `false` when the model is enabled or not
 * in the table at all. The polarity is kept for backward compatibility; see
 * isModelNotavailableInServer for a check that also treats unknown models as
 * unavailable.
 *
 * @param customModels comma-separated custom model config
 * @param modelName model name as it appears in the table key
 * @param providerName provider name; matched case-insensitively
 */
export function isModelAvailableInServer(
  customModels: string,
  modelName: string,
  providerName: string,
) {
  // Custom provider ids are lower-cased (customProvider), and
  // isModelNotavailableInServer lower-cases before lookup as well. Without
  // this, e.g. `gpt-4@OpenAI` never matched the `gpt-4@openai` key.
  const fullName = `${modelName}@${providerName.toLowerCase()}`;
  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  return modelTable[fullName]?.available === false;
}
|  | 
 | ||
/**
 * Decide whether a model name belongs to the GPT-4 family.
 *
 * Names starting with `gpt-4`, `chatgpt-4o` or `o1` count, except anything
 * starting with `gpt-4o-mini`. Matching is case-sensitive, so callers should
 * pass a lower-cased name.
 *
 * @param modelName The name of the model to check
 * @returns True for GPT-4 related models (excluding gpt-4o-mini)
 */
export function isGPT4Model(modelName: string): boolean {
  if (modelName.startsWith("gpt-4o-mini")) {
    return false;
  }
  const gpt4Prefixes = ["gpt-4", "chatgpt-4o", "o1"];
  return gpt4Prefixes.some((prefix) => modelName.startsWith(prefix));
}
|  | 
 | ||
/**
 * Checks if a model is unavailable on every one of the specified providers.
 *
 * With `DISABLE_GPT4=1` in the environment, GPT-4 family models (see
 * isGPT4Model, checked on the lower-cased name) are always unavailable.
 * Otherwise each provider is tried in order, and the first one on which the
 * model is available makes the result `false`. Models missing from the table
 * count as unavailable.
 *
 * @param customModels - A string of custom models, comma-separated.
 * @param modelName - The name of the model to check. For ByteDance this is
 *   matched against the entry's `name` rather than the table key.
 * @param providerNames - A provider name or list of names; matched
 *   case-insensitively against provider ids.
 * @returns True if the model is not available on any of the specified
 *   providers, false otherwise.
 */
export function isModelNotavailableInServer(
  customModels: string,
  modelName: string,
  providerNames: string | string[],
): boolean {
  // Check DISABLE_GPT4 environment variable
  if (
    process.env.DISABLE_GPT4 === "1" &&
    isGPT4Model(modelName.toLowerCase())
  ) {
    return true;
  }

  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  const allModels = Object.values(modelTable);

  const providerNamesArray = Array.isArray(providerNames)
    ? providerNames
    : [providerNames];
  for (const providerName of providerNamesArray) {
    const providerId = providerName.toLowerCase();
    if (providerName === ServiceProvider.ByteDance) {
      // ByteDance entries may have `name` replaced by the configured `=` part
      // (see collectModelTable), so look the model up by `name`, restricted to
      // ByteDance entries. Keep checking the remaining providers when it is
      // not available here; the old early return ignored them.
      const availableHere = allModels.some(
        (v) =>
          v.name === modelName && v.provider?.id === providerId && v.available,
      );
      if (availableHere) return false;
      continue;
    }
    if (modelTable[`${modelName}@${providerId}`]?.available === true) {
      return false;
    }
  }
  return true;
}