You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
	
	
		
			259 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			TypeScript
		
	
			
		
		
	
	
			259 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			TypeScript
		
	
| import { DEFAULT_MODELS, ServiceProvider } from "../constant";
 | |
| import { LLMModel } from "../client/api";
 | |
| 
 | |
/**
 * Sequence generator for ordering custom providers/models.
 *
 * Numbers start at -1000 (see the note referencing constant.ts) so that custom
 * entries are placed in front. Each id keeps the first number it was given;
 * later calls with the same id return the cached value.
 */
const CustomSeq = {
  val: -1000, //To ensure the custom model located at front, start from -1000, refer to constant.ts
  cache: new Map<string, number>(),
  next: (id: string): number => {
    const known = CustomSeq.cache.get(id);
    if (known !== undefined) {
      return known;
    }
    // first time we see this id: hand out the current value, then advance
    const fresh = CustomSeq.val;
    CustomSeq.val += 1;
    CustomSeq.cache.set(id, fresh);
    return fresh;
  },
};
 | |
| 
 | |
/**
 * Builds the provider descriptor attached to a user-defined (custom) model.
 *
 * The id is the lowercased provider name; the sort key comes from `CustomSeq`
 * and is keyed by the provider name exactly as given.
 *
 * @param providerName provider name as written in the custom model config
 * @returns provider object with `id`, `providerName`, `providerType` and `sorted`
 */
const customProvider = (providerName: string) => {
  const id = providerName.toLowerCase();
  const sorted = CustomSeq.next(providerName);
  return {
    id,
    providerName,
    providerType: "custom",
    sorted,
  };
};
 | |
| 
 | |
/**
 * Orders models by provider rank first, then by model rank.
 *
 * When both models have a provider, `provider.sorted` decides and ties fall
 * back to the model's own `sorted`. If either lacks a provider, only the
 * model-level `sorted` values are compared.
 * NOTE(review): mixing entries with and without a provider can make this
 * comparator non-transitive.
 *
 * The array is sorted in place (`Array.prototype.sort`) and the same array is
 * returned.
 */
const sortModelTable = (models: ReturnType<typeof collectModels>) =>
  models.sort((left, right) => {
    if (!left.provider || !right.provider) {
      return left.sorted - right.sorted;
    }
    const providerDiff = left.provider.sorted - right.provider.sorted;
    return providerDiff !== 0 ? providerDiff : left.sorted - right.sorted;
  });
 | |
| 
 | |
/**
 * Splits a `<model>@<provider>` string at its LAST `@` character.
 *
 * e.g. `gpt-4@OpenAi` -> `["gpt-4", "OpenAi"]`, and
 * `claude-3-5-sonnet@20240620@Google` -> `["claude-3-5-sonnet@20240620", "Google"]`.
 *
 * @param modelWithProvider model name and provider separated by the last `@`
 * @returns `[model, provider]` tuple; `provider` is `undefined` when the input
 *   contains no `@`, and `""` when the input ends with `@`
 */
export function getModelProvider(modelWithProvider: string): [string, string?] {
  // lastIndexOf guarantees exactly one split at the last `@`. The previous
  // regex lookahead `/@(?!.*@)/` used `.`, which does not match line
  // terminators, so inputs containing a newline were split more than once.
  const at = modelWithProvider.lastIndexOf("@");
  if (at === -1) {
    return [modelWithProvider, undefined];
  }
  return [modelWithProvider.slice(0, at), modelWithProvider.slice(at + 1)];
}
 | |
| 
 | |
/**
 * Builds a table of models keyed by `<modelName>@<providerId>`.
 *
 * The table is seeded from `models`, then adjusted by `customModels`, a
 * comma-separated list of rules applied left to right (surrounding whitespace
 * around each rule is ignored):
 * - `+name` or `name` enables a model, `-name` disables it;
 * - `all` (with `+`/`-`) toggles every model collected so far;
 * - `name@provider` restricts a rule to one provider; without `@provider` the
 *   rule applies to every provider that has a model called `name`;
 * - `name=displayName` overrides the display name. For ByteDance entries the
 *   value after `=` becomes the model `name` instead, and the original model
 *   name becomes the display name.
 * A rule that matches no existing model creates a new custom model entry.
 *
 * @param models default models used to seed the table
 * @param customModels comma-separated custom model rules (may be empty)
 * @returns the model table keyed by `<modelName>@<providerId>`
 */
export function collectModelTable(
  models: readonly LLMModel[],
  customModels: string,
) {
  const modelTable: Record<
    string,
    {
      available: boolean;
      name: string;
      displayName: string;
      sorted: number;
      provider?: LLMModel["provider"]; // Marked as optional
      isDefault?: boolean;
    }
  > = {};

  // default models
  models.forEach((m) => {
    // using <modelName>@<providerId> as fullName
    modelTable[`${m.name}@${m?.provider?.id}`] = {
      ...m,
      displayName: m.name, // 'provider' is copied over if it exists
    };
  });

  // server custom models
  customModels
    .split(",")
    // tolerate whitespace around entries, e.g. "-all, +gpt-4"
    .map((v) => v.trim())
    .filter((v) => v.length > 0)
    .forEach((m) => {
      const available = !m.startsWith("-");
      const nameConfig =
        m.startsWith("+") || m.startsWith("-") ? m.slice(1) : m;
      const [name, displayName] = nameConfig.split("=");

      // enable or disable all models
      if (name === "all") {
        Object.values(modelTable).forEach(
          (model) => (model.available = available),
        );
        return;
      }

      // 1. update every existing model matching the name (and provider, if given)
      const [customModelName, customProviderName] = getModelProvider(name);
      let count = 0;
      for (const fullName in modelTable) {
        const [modelName, providerName] = getModelProvider(fullName);
        if (
          customModelName !== modelName ||
          (customProviderName !== undefined &&
            customProviderName !== providerName)
        ) {
          continue;
        }
        count += 1;
        const entry = modelTable[fullName];
        entry.available = available;
        // without `=displayName` only availability changes; in particular the
        // ByteDance name must not be replaced by `undefined`
        if (!displayName) {
          continue;
        }
        if (providerName === "bytedance") {
          // ByteDance: the configured value becomes the model name and the
          // original model name is shown instead
          entry.name = displayName;
          entry.displayName = modelName;
        } else {
          entry.displayName = displayName;
        }
      }

      // 2. nothing matched: register a new custom model
      if (count === 0) {
        const provider = customProvider(customProviderName || customModelName);
        let modelName = customModelName;
        let shownName = displayName;
        // same ByteDance swap as above; compare the lowercase provider id so
        // `@bytedance` and `@ByteDance` behave the same
        if (displayName && provider.id === "bytedance") {
          [modelName, shownName] = [displayName, customModelName];
        }
        const fullName = `${modelName}@${provider.id}`;
        modelTable[fullName] = {
          name: modelName,
          displayName: shownName || modelName,
          available,
          provider,
          sorted: CustomSeq.next(fullName),
        };
      }
    });

  return modelTable;
}
 | |
| 
 | |
/**
 * Like `collectModelTable`, but additionally flags one entry as the default.
 *
 * - If `defaultModel` contains `@`, it is treated as a full table key and
 *   flagged if present (regardless of availability).
 * - Otherwise the first available entry whose model name (the part before
 *   the last `@` of its key) equals `defaultModel` is flagged.
 * - An empty `defaultModel` flags nothing.
 *
 * @param models default models used to seed the table
 * @param customModels comma-separated custom model rules
 * @param defaultModel model name, or `<model>@<providerId>` key, to mark as default
 * @returns the model table, with `isDefault` set on at most one entry
 */
export function collectModelTableWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const modelTable = collectModelTable(models, customModels);
  if (!defaultModel) {
    return modelTable;
  }

  if (defaultModel.includes("@")) {
    if (defaultModel in modelTable) {
      modelTable[defaultModel].isDefault = true;
    }
    return modelTable;
  }

  const defaultKey = Object.keys(modelTable).find(
    (key) =>
      modelTable[key].available && getModelProvider(key)[0] === defaultModel,
  );
  if (defaultKey !== undefined) {
    modelTable[defaultKey].isDefault = true;
  }
  return modelTable;
}
 | |
| 
 | |
/**
 * Generates the full, ordered model list.
 *
 * @param models default models used to seed the table
 * @param customModels comma-separated custom model rules
 * @returns every table entry, ordered by provider rank then model rank
 */
export function collectModels(
  models: readonly LLMModel[],
  customModels: string,
) {
  const allModels = Object.values(collectModelTable(models, customModels));
  // sortModelTable sorts in place and hands back the same array
  sortModelTable(allModels);
  return allModels;
}
 | |
| 
 | |
/**
 * Generates the full, ordered model list with the default model flagged
 * (`isDefault`) as decided by `collectModelTableWithDefaultModel`.
 *
 * @param models default models used to seed the table
 * @param customModels comma-separated custom model rules
 * @param defaultModel model name, or `<model>@<providerId>` key, to mark as default
 * @returns every table entry, ordered by provider rank then model rank
 */
export function collectModelsWithDefaultModel(
  models: readonly LLMModel[],
  customModels: string,
  defaultModel: string,
) {
  const table = collectModelTableWithDefaultModel(
    models,
    customModels,
    defaultModel,
  );
  const allModels = Object.values(table);
  // sortModelTable sorts in place and hands back the same array
  sortModelTable(allModels);
  return allModels;
}
 | |
| 
 | |
/**
 * Reports whether `modelName` is explicitly DISABLED for `providerName` in
 * the server's model table (built from `DEFAULT_MODELS` plus `customModels`).
 *
 * NOTE(review): despite its name, this returns `true` when the entry exists
 * and is marked unavailable, and `false` when it is enabled or not in the
 * table. The semantics are kept unchanged for existing callers.
 *
 * @param customModels comma-separated custom model rules
 * @param modelName model name without provider suffix
 * @param providerName provider name; matched case-insensitively against the
 *   provider id in the table key
 * @returns `true` only if the model exists for that provider and is disabled
 */
export function isModelAvailableInServer(
  customModels: string,
  modelName: string,
  providerName: string,
) {
  // Table keys use the lowercase provider id (the same lookup
  // isModelNotavailableInServer performs), so normalize before building the key.
  const fullName = `${modelName}@${providerName.toLowerCase()}`;
  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  return modelTable[fullName]?.available === false;
}
 | |
| 
 | |
/**
 * Check if the model name is a GPT-4 related model.
 *
 * Matches names starting with `gpt-4`, `chatgpt-4o` or `o1`, except those
 * starting with `gpt-4o-mini`. The comparison is case-sensitive.
 *
 * @param modelName The name of the model to check
 * @returns True if the model is a GPT-4 related model (excluding gpt-4o-mini)
 */
export function isGPT4Model(modelName: string): boolean {
  if (modelName.startsWith("gpt-4o-mini")) {
    return false;
  }
  return ["gpt-4", "chatgpt-4o", "o1"].some((prefix) =>
    modelName.startsWith(prefix),
  );
}
 | |
| 
 | |
/**
 * Checks if a model is not available on any of the specified providers in the server.
 *
 * Returns `true` immediately when `DISABLE_GPT4=1` and the model is GPT-4
 * related. Otherwise returns `false` as soon as one provider has the model
 * available, and `true` if none does.
 *
 * For ByteDance, models are looked up by their configured `name` (which may
 * differ from the table key, see the ByteDance rename in collectModelTable),
 * restricted to entries whose provider id is `bytedance`.
 *
 * @param customModels - A string of custom models, comma-separated.
 * @param modelName - The name of the model to check.
 * @param providerNames - A provider name or array of provider names to check
 *   against (compared case-insensitively).
 * @returns True if the model is not available on any of the specified providers, false otherwise.
 */
export function isModelNotavailableInServer(
  customModels: string,
  modelName: string,
  providerNames: string | string[],
): boolean {
  // Check DISABLE_GPT4 environment variable
  if (
    process.env.DISABLE_GPT4 === "1" &&
    isGPT4Model(modelName.toLowerCase())
  ) {
    return true;
  }

  const modelTable = collectModelTable(DEFAULT_MODELS, customModels);
  const byteDanceName = ServiceProvider.ByteDance.toLowerCase();

  const providerNamesArray = Array.isArray(providerNames)
    ? providerNames
    : [providerNames];
  for (const providerName of providerNamesArray) {
    const providerId = providerName.toLowerCase();

    if (providerId === byteDanceName) {
      // Match by configured model name, but only among ByteDance entries, and
      // keep checking the remaining providers when it is unavailable here
      // (previously this branch returned unconditionally).
      const availableOnByteDance = Object.values(modelTable).some(
        (v) =>
          v.name === modelName &&
          v.provider?.id === "bytedance" &&
          v.available,
      );
      if (availableOnByteDance) return false;
      continue;
    }

    const fullName = `${modelName}@${providerId}`;
    if (modelTable[fullName]?.available === true) return false;
  }
  return true;
}
 |