diff --git a/desktop/src/main/services/ml-worker.ts b/desktop/src/main/services/ml-worker.ts index 35e342b029..f700d84477 100644 --- a/desktop/src/main/services/ml-worker.ts +++ b/desktop/src/main/services/ml-worker.ts @@ -184,14 +184,13 @@ const downloadModel = async (saveLocation: string, name: string) => { /** * Create an ONNX {@link InferenceSession} with some defaults. */ -const createInferenceSession = async (modelPath: string) => { - return await ort.InferenceSession.create(modelPath, { +const createInferenceSession = async (modelPath: string) => + ort.InferenceSession.create(modelPath, { // Restrict the number of threads to 1. intraOpNumThreads: 1, // Be more conservative with RAM usage. enableCpuMemArena: false, }); -}; const cachedCLIPImageSession = makeCachedInferenceSession( "mobileclip_s2_image_opset18_rgba_opt.onnx", @@ -233,9 +232,11 @@ const getTokenizer = () => (_tokenizer ??= new Tokenizer()); export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => { const sessionOrSkip = await Promise.race([ cachedCLIPTextSession(), - // Wait for a tick to get the session promise to resolved the first time - // this code runs on each app start (and the model has been downloaded). - wait(0).then(() => 1), + // Wait a bit to let the session promise resolve the first time this + // code runs on each app start (in these cases the model will already be + // downloaded, so session creation should take only 1 or 2 ticks: file + // system stat, and ort.InferenceSession.create). + wait(50).then(() => 1), ]); // Don't wait for the download to complete.