You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
	
	
		
			164 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			TypeScript
		
	
		
		
			
		
	
	
			164 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			TypeScript
		
	
| 
											9 months ago
										 | import { | ||
|  |   Stability, | ||
|  |   StoreKey, | ||
|  |   ACCESS_CODE_PREFIX, | ||
|  |   ApiPath, | ||
|  | } from "@/app/constant"; | ||
|  | import { getBearerToken } from "@/app/client/api"; | ||
|  | import { createPersistStore } from "@/app/utils/store"; | ||
|  | import { nanoid } from "nanoid"; | ||
|  | import { uploadImage, base64Image2Blob } from "@/app/utils/chat"; | ||
|  | import { models, getModelParamBasicData } from "@/app/components/sd/sd-panel"; | ||
|  | import { useAccessStore } from "./access"; | ||
|  | 
 | ||
// The first entry of the SD model catalogue serves as the out-of-the-box
// selection for a fresh store.
const firstModel = models[0];

/** Default model selection: only the `name`/`value` pair is stored. */
const defaultModel = {
  name: firstModel.name,
  value: firstModel.value,
};

/** Baseline parameter values for the default model, built from its (empty-input) param definitions. */
const defaultParams = getModelParamBasicData(firstModel.params({}), {});

/** Initial persisted state for the SD store. */
const DEFAULT_SD_STATE = {
  currentId: 0,
  draw: [],
  currentModel: defaultModel,
  currentParams: defaultParams,
};
|  | 
 | ||
/**
 * Persisted store for Stability AI image generation.
 *
 * State:
 * - `currentId`: counter bumped by `getNextId()`; it is incremented when a task
 *   is submitted and again when its request settles.
 * - `draw`: generation tasks, newest first. Each task is the submitted data plus
 *   a generated `id` and a `status` of "running" | "success" | "error".
 * - `currentModel` / `currentParams`: the currently selected model and params.
 */
export const useSdStore = createPersistStore<
  {
    currentId: number;
    draw: any[];
    currentModel: typeof defaultModel;
    currentParams: any;
  },
  {
    getNextId: () => number;
    sendTask: (data: any, okCall?: Function) => void;
    updateDraw: (draw: any) => void;
    setCurrentModel: (model: any) => void;
    setCurrentParams: (data: any) => void;
  }
>(
  DEFAULT_SD_STATE,
  (set, _get) => {
    /**
     * Current state merged with the store's methods. Methods call each other
     * through this instead of `this`, so they keep working when destructured
     * from the store.
     */
    function get() {
      return {
        ..._get(),
        ...methods,
      };
    }

    /**
     * Converts an unknown thrown/rejected value into a readable message.
     * `JSON.stringify(new Error("x"))` yields "{}", so Errors use `.message`.
     */
    function describeError(e: unknown): string {
      if (e instanceof Error) return e.message;
      if (typeof e === "string") return e;
      try {
        return JSON.stringify(e);
      } catch {
        return String(e);
      }
    }

    const methods = {
      /** Increments `currentId` (without mutating the current snapshot) and returns it. */
      getNextId() {
        const id = _get().currentId + 1;
        set({ currentId: id });
        return id;
      },

      /**
       * Queues a generation task: prepends it to `draw` with status "running",
       * fires the Stability request, then calls `okCall` immediately (it does
       * not wait for the request to finish).
       */
      sendTask(data: any, okCall?: Function) {
        const task = { ...data, id: nanoid(), status: "running" };
        set({ draw: [task, ..._get().draw] });
        get().getNextId();
        get().stabilityRequestCall(task);
        okCall?.();
      },

      /**
       * POSTs the task's params as multipart form data to the Stability
       * generate endpoint for `data.model`, then records the outcome on the
       * task via `updateDraw`. On success the returned base64 PNG is uploaded
       * and the upload result is stored as `img_data`.
       */
      stabilityRequestCall(data: any) {
        const accessStore = useAccessStore.getState();
        let prefix: string = ApiPath.Stability as string;
        let bearerToken = "";
        if (accessStore.useCustomConfig) {
          prefix = accessStore.stabilityUrl || (ApiPath.Stability as string);
          bearerToken = getBearerToken(accessStore.stabilityApiKey);
        }
        // Fall back to the access code when no custom API key supplied a token.
        if (!bearerToken && accessStore.enabledAccessControl()) {
          bearerToken = getBearerToken(
            ACCESS_CODE_PREFIX + accessStore.accessCode,
          );
        }
        const headers = {
          Accept: "application/json",
          Authorization: bearerToken,
        };
        const path = `${prefix}/${Stability.GeneratePath}/${data.model}`;
        const formData = new FormData();
        for (const paramsKey in data.params) {
          formData.append(paramsKey, data.params[paramsKey]);
        }
        fetch(path, {
          method: "POST",
          headers,
          body: formData,
        })
          .then((response) => response.json())
          .then((resData) => {
            if (resData.errors && resData.errors.length > 0) {
              get().updateDraw({
                ...data,
                status: "error",
                error: resData.errors[0],
              });
              get().getNextId();
              return;
            }
            if (resData.finish_reason === "SUCCESS") {
              uploadImage(base64Image2Blob(resData.image, "image/png"))
                .then((img_data) => {
                  console.debug("uploadImage success", img_data);
                  get().updateDraw({
                    ...data,
                    status: "success",
                    img_data,
                  });
                })
                .catch((e) => {
                  console.error("uploadImage error", e);
                  get().updateDraw({
                    ...data,
                    status: "error",
                    error: describeError(e),
                  });
                });
            } else {
              get().updateDraw({
                ...data,
                status: "error",
                error: JSON.stringify(resData),
              });
            }
            get().getNextId();
          })
          .catch((error) => {
            get().updateDraw({
              ...data,
              status: "error",
              error: describeError(error),
            });
            console.error("Error:", error);
            get().getNextId();
          });
      },

      /**
       * Replaces the task in `draw` whose `id` equals `_draw.id`; no-op when
       * none matches. Writes a fresh array rather than mutating the existing
       * one so subscribers comparing `draw` by reference observe the change.
       */
      updateDraw(_draw: any) {
        const draw: any[] = _get().draw || [];
        const index = draw.findIndex((item) => item.id === _draw.id);
        if (index === -1) return;
        const next = draw.slice();
        next[index] = _draw;
        set({ draw: next });
      },

      /** Stores the selected model. */
      setCurrentModel(model: any) {
        set({ currentModel: model });
      },

      /** Stores the selected generation parameters. */
      setCurrentParams(data: any) {
        set({
          currentParams: data,
        });
      },
    };

    return methods;
  },
  {
    name: StoreKey.SdList,
    version: 1.0,
  },
);