diff --git a/mobile/lib/db/ml/clip_db.dart b/mobile/lib/db/ml/clip_db.dart index 59f6dc936b..953a7c5397 100644 --- a/mobile/lib/db/ml/clip_db.dart +++ b/mobile/lib/db/ml/clip_db.dart @@ -17,6 +17,21 @@ extension ClipDB on MLDataDB { return _convertToVectors(results); } + Future> getClipVectorsForFileIDs( + Iterable fileIDs, + ) async { + final db = await MLDataDB.instance.asyncDB; + final results = await db.getAll( + 'SELECT * FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})', + ); + final Map embeddings = {}; + for (final result in results) { + final embedding = _getVectorFromRow(result); + embeddings[embedding.fileID] = embedding; + } + return embeddings; + } + // Get indexed FileIDs Future> clipIndexedFileWithVersion() async { final db = await MLDataDB.instance.asyncDB; diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index 90f31d83ae..81eb486c26 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -400,6 +400,28 @@ class MLDataDB { return maps.map((e) => mapRowToFace(e)).toList(); } + Future>> getFacesForFileIDs( + Iterable fileUploadIDs, + ) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + ''' + SELECT * FROM $facesTable + WHERE $fileIDColumn IN (${fileUploadIDs.map((id) => "'$id'").join(",")}) + ''', + ); + if (maps.isEmpty) { + return {}; + } + final result = >{}; + for (final map in maps) { + final face = mapRowToFace(map); + final fileID = map[fileIDColumn] as int; + result.putIfAbsent(fileID, () => []).add(face); + } + return result; + } + Future>> getClusterToFaceIDs( Set clusterIDs, ) async {