From 3f262c5ba2a7a9b3bc0f9697d4380b4dcd3154b3 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:40:27 +0530 Subject: [PATCH] Add db for offline ml data --- mobile/lib/db/ml/base.dart | 2 +- mobile/lib/db/ml/db.dart | 14 +- mobile/lib/db/ml/db_model_mappers.dart | 14 +- mobile/lib/db/ml/offlinedb.dart | 1252 ++++++++++++++++++++++++ mobile/lib/db/ml/schema.dart | 18 +- 5 files changed, 1279 insertions(+), 21 deletions(-) create mode 100644 mobile/lib/db/ml/offlinedb.dart diff --git a/mobile/lib/db/ml/base.dart b/mobile/lib/db/ml/base.dart index 35e745fd1f..553000719b 100644 --- a/mobile/lib/db/ml/base.dart +++ b/mobile/lib/db/ml/base.dart @@ -9,7 +9,7 @@ import "package:photos/services/machine_learning/face_ml/face_clustering/face_db abstract class IMLDataDB { Future bulkInsertFaces(List faces); Future updateFaceIdToClusterId(Map faceIDToClusterID); - Future> faceIndexedFileIds({int minimumMlVersion}); + Future> faceIndexedFileIds({int minimumMlVersion}); Future getFaceIndexedFileCount({int minimumMlVersion}); Future> clusterIdToFaceCount(); Future> getPersonIgnoredClusters(String personID); diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index fde278a9f2..b36c862649 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -49,13 +49,13 @@ class MLDataDB with SqlDbBase implements IMLDataDB { static final MLDataDB instance = MLDataDB._privateConstructor(); static final _migrationScripts = [ - createFacesTable, + getCreateFacesTable(false), createFaceClustersTable, createClusterPersonTable, createClusterSummaryTable, createNotPersonFeedbackTable, fcClusterIDIndex, - createClipEmbeddingsTable, + getCreateClipEmbeddingsTable(false), createFileDataTable, ]; @@ -337,10 +337,10 @@ class MLDataDB with SqlDbBase implements IMLDataDB { (element) => (element[fileIDColumn] as int) == avatarFileId, ); if (row != null) { - return mapRowToFace(row); + return mapRowToFace(row); } } - return 
mapRowToFace(faceMaps.first); + return mapRowToFace(faceMaps.first); } } if (clusterID != null) { @@ -384,7 +384,7 @@ class MLDataDB with SqlDbBase implements IMLDataDB { if (maps.isEmpty) { return null; } - return maps.map((e) => mapRowToFace(e)).toList(); + return maps.map((e) => mapRowToFace(e)).toList(); } @override @@ -401,7 +401,7 @@ class MLDataDB with SqlDbBase implements IMLDataDB { } final result = >{}; for (final map in maps) { - final face = mapRowToFaceWithoutEmbedding(map); + final face = mapRowToFaceWithoutEmbedding(map); final fileID = map[fileIDColumn] as int; result.putIfAbsent(fileID, () => []).add(face); } @@ -1060,7 +1060,7 @@ class MLDataDB with SqlDbBase implements IMLDataDB { final db = await instance.asyncDB; if (faces) { await db.execute(deleteFacesTable); - await db.execute(createFacesTable); + await db.execute(getCreateFacesTable(false)); await db.execute(deleteFaceClustersTable); await db.execute(createFaceClustersTable); await db.execute(fcClusterIDIndex); diff --git a/mobile/lib/db/ml/db_model_mappers.dart b/mobile/lib/db/ml/db_model_mappers.dart index 4940b13dfc..84f7dba576 100644 --- a/mobile/lib/db/ml/db_model_mappers.dart +++ b/mobile/lib/db/ml/db_model_mappers.dart @@ -7,7 +7,7 @@ import "package:photos/models/ml/face/face.dart"; import "package:photos/models/ml/face/face_with_embedding.dart"; import "package:photos/models/ml/ml_versions.dart"; -Map mapRemoteToFaceDB(Face face) { +Map mapRemoteToFaceDB(Face face) { return { faceIDColumn: face.faceID, fileIDColumn: face.fileID, @@ -24,10 +24,10 @@ Map mapRemoteToFaceDB(Face face) { }; } -Face mapRowToFace(Map row) { +Face mapRowToFace(Map row) { return Face( row[faceIDColumn] as String, - row[fileIDColumn] as int, + row[fileIDColumn] as T, EVector.fromBuffer(row[embeddingColumn] as List).values, row[faceScore] as double, Detection.fromJson(json.decode(row[faceDetectionColumn] as String)), @@ -39,10 +39,12 @@ Face mapRowToFace(Map row) { ); } -FaceWithoutEmbedding 
mapRowToFaceWithoutEmbedding(Map row) { - return FaceWithoutEmbedding( +FaceWithoutEmbedding mapRowToFaceWithoutEmbedding( + Map row, +) { + return FaceWithoutEmbedding( row[faceIDColumn] as String, - row[fileIDColumn] as int, + row[fileIDColumn] as T, row[faceScore] as double, Detection.fromJson(json.decode(row[faceDetectionColumn] as String)), row[faceBlur] as double, diff --git a/mobile/lib/db/ml/offlinedb.dart b/mobile/lib/db/ml/offlinedb.dart new file mode 100644 index 0000000000..d2db839e96 --- /dev/null +++ b/mobile/lib/db/ml/offlinedb.dart @@ -0,0 +1,1252 @@ +import 'dart:async'; +import "dart:math"; + +import "package:collection/collection.dart"; +import "package:flutter/foundation.dart"; +import 'package:logging/logging.dart'; +import 'package:path/path.dart' show join; +import 'package:path_provider/path_provider.dart'; +import "package:photos/core/event_bus.dart"; +import "package:photos/db/common/base.dart"; +import "package:photos/db/ml/base.dart"; +import "package:photos/db/ml/db_model_mappers.dart"; +import 'package:photos/db/ml/schema.dart'; +import "package:photos/events/embedding_updated_event.dart"; +import "package:photos/extensions/stop_watch.dart"; +import "package:photos/models/ml/clip.dart"; +import "package:photos/models/ml/face/face.dart"; +import "package:photos/models/ml/face/face_with_embedding.dart"; +import "package:photos/models/ml/ml_versions.dart"; +import "package:photos/models/ml/vector.dart"; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart"; +import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; +import "package:photos/services/machine_learning/ml_result.dart"; +import "package:photos/utils/ml_util.dart"; +import 'package:sqlite_async/sqlite_async.dart'; + +/// Stores all data for the ML related features. The database can be accessed by `MLDataDB.instance.database`. 
+/// +/// This includes: +/// [facesTable] - Stores all the detected faces and their embeddings in the images. +/// [faceClustersTable] - Stores all the mappings from the faces (faceID) to the clusters (clusterID). +/// [clusterPersonTable] - Stores all the clusters that are mapped to a certain person. +/// [clusterSummaryTable] - Stores a summary of each cluster, containing the mean embedding and the number of faces in the cluster. +/// [notPersonFeedback] - Stores the clusters that are confirmed not to belong to a certain person by the user +/// +/// [clipTable] - Stores the embeddings of the CLIP model +/// [fileDataTable] - Stores data about the files that are already processed by the ML models +class OfflineMLDataDB with SqlDbBase implements IMLDataDB { + static final Logger _logger = Logger("MLDataDB"); + + static const _databaseName = "ente.ml.db"; // NOTE(review): presumably must differ from the file name used by MLDataDB, otherwise both DBs would open the same SQLite file - confirm. + + static Logger get logger => _logger; + + // static const _databaseVersion = 1; + + OfflineMLDataDB._privateConstructor(); + + static final OfflineMLDataDB instance = OfflineMLDataDB._privateConstructor(); + + static final _migrationScripts = [ + getCreateFacesTable(true), + createFaceClustersTable, + createClusterPersonTable, + createClusterSummaryTable, + createNotPersonFeedbackTable, + fcClusterIDIndex, + getCreateClipEmbeddingsTable(true), + createFileDataTable, + ]; + + // only have a single app-wide reference to the database + static Future? 
_sqliteAsyncDBFuture; + + Future get asyncDB async { + _sqliteAsyncDBFuture ??= _initSqliteAsyncDatabase(); + return _sqliteAsyncDBFuture!; + } + + Future _initSqliteAsyncDatabase() async { + final documentsDirectory = await getApplicationDocumentsDirectory(); + final String databaseDirectory = + join(documentsDirectory.path, _databaseName); + _logger.info("Opening sqlite_async access: DB path " + databaseDirectory); + final asyncDBConnection = + SqliteDatabase(path: databaseDirectory, maxReaders: 2); + final stopwatch = Stopwatch()..start(); + _logger.info("MLDataDB: Starting migration"); + await migrate(asyncDBConnection, _migrationScripts); + _logger.info( + "MLDataDB Migration took ${stopwatch.elapsedMilliseconds} ms", + ); + stopwatch.stop(); + + return asyncDBConnection; + } + + // bulkInsertFaces inserts the faces in the database in batches of 500. + // This is done to avoid the error "too many SQL variables" when inserting + // a large number of faces. + @override + Future bulkInsertFaces(List faces) async { + final db = await instance.asyncDB; + const batchSize = 500; + final numBatches = (faces.length / batchSize).ceil(); + for (int i = 0; i < numBatches; i++) { + final start = i * batchSize; + final end = min((i + 1) * batchSize, faces.length); + final batch = faces.sublist(start, end); + + const String sql = ''' + INSERT INTO $facesTable ( + $fileIDColumn, $faceIDColumn, $faceDetectionColumn, $embeddingColumn, $faceScore, $faceBlur, $isSideways, $imageHeight, $imageWidth, $mlVersionColumn + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT($fileIDColumn, $faceIDColumn) DO UPDATE SET $faceIDColumn = excluded.$faceIDColumn, $faceDetectionColumn = excluded.$faceDetectionColumn, $embeddingColumn = excluded.$embeddingColumn, $faceScore = excluded.$faceScore, $faceBlur = excluded.$faceBlur, $isSideways = excluded.$isSideways, $imageHeight = excluded.$imageHeight, $imageWidth = excluded.$imageWidth, $mlVersionColumn = excluded.$mlVersionColumn + '''; + final parameterSets = batch.map((face) { + final map = mapRemoteToFaceDB(face); + return [ + map[fileIDColumn], + map[faceIDColumn], + map[faceDetectionColumn], + map[embeddingColumn], + map[faceScore], + map[faceBlur], + map[isSideways], + map[imageHeight], + map[imageWidth], + map[mlVersionColumn], + ]; + }).toList(); + + await db.executeBatch(sql, parameterSets); + } + } + + @override + Future updateFaceIdToClusterId( + Map faceIDToClusterID, + ) async { + final db = await instance.asyncDB; + const batchSize = 500; + final numBatches = (faceIDToClusterID.length / batchSize).ceil(); + for (int i = 0; i < numBatches; i++) { + final start = i * batchSize; + final end = min((i + 1) * batchSize, faceIDToClusterID.length); + final batch = faceIDToClusterID.entries.toList().sublist(start, end); + + const String sql = ''' + INSERT INTO $faceClustersTable ($faceIDColumn, $clusterIDColumn) + VALUES (?, ?) 
+ ON CONFLICT($faceIDColumn) DO UPDATE SET $clusterIDColumn = excluded.$clusterIDColumn + '''; + final parameterSets = batch.map((e) => [e.key, e.value]).toList(); + + await db.executeBatch(sql, parameterSets); + } + } + + /// Returns a map of fileID to the indexed ML version + @override + Future> faceIndexedFileIds({ + int minimumMlVersion = faceMlVersion, + }) async { + final db = await instance.asyncDB; + final String query = ''' + SELECT $fileIDColumn, $mlVersionColumn + FROM $facesTable + WHERE $mlVersionColumn >= $minimumMlVersion + '''; + final List> maps = await db.getAll(query); + final Map result = {}; + for (final map in maps) { + result[map[fileIDColumn] as String] = map[mlVersionColumn] as int; + } + return result; + } + + @override + Future getFaceIndexedFileCount({ + int minimumMlVersion = faceMlVersion, + }) async { + final db = await instance.asyncDB; + final String query = + 'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $facesTable WHERE $mlVersionColumn >= $minimumMlVersion'; + final List> maps = await db.getAll(query); + return maps.first['count'] as int; + } + + @override + Future> clusterIdToFaceCount() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $clusterIDColumn, COUNT(*) as count FROM $faceClustersTable where $clusterIDColumn IS NOT NULL GROUP BY $clusterIDColumn ', + ); + final Map result = {}; + for (final map in maps) { + result[map[clusterIDColumn] as String] = map['count'] as int; + } + return result; + } + + @override + Future> getPersonIgnoredClusters(String personID) async { + final db = await instance.asyncDB; + // find out clusterIds that are assigned to other persons using the clusters table + final List> otherPersonMaps = await db.getAll( + 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? 
AND $personIdColumn IS NOT NULL', + [personID], + ); + final Set ignoredClusterIDs = + otherPersonMaps.map((e) => e[clusterIDColumn] as String).toSet(); + final List> rejectMaps = await db.getAll( + 'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?', + [personID], + ); + final Set rejectClusterIDs = + rejectMaps.map((e) => e[clusterIDColumn] as String).toSet(); + return ignoredClusterIDs.union(rejectClusterIDs); + } + + @override + Future> getPersonClusterIDs(String personID) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?', + [personID], + ); + return maps.map((e) => e[clusterIDColumn] as String).toSet(); + } + + @override + Future> getPersonsClusterIDs(List personID) async { + final db = await instance.asyncDB; + final inParam = personID.map((e) => "'$e'").join(','); + final List> maps = await db.getAll( + 'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn IN ($inParam)', + ); + return maps.map((e) => e[clusterIDColumn] as String).toSet(); + } + + @override + Future clearTable() async { + final db = await instance.asyncDB; + + await db.execute(deleteFacesTable); + await db.execute(deleteFaceClustersTable); + await db.execute(deleteClusterPersonTable); + await db.execute(deleteClusterSummaryTable); + await db.execute(deleteNotPersonFeedbackTable); + await db.execute(deleteClipEmbeddingsTable); + await db.execute(deleteFileDataTable); + } + + @override + Future> getFaceEmbeddingsForCluster( + String clusterID, { + int? limit, + }) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $embeddingColumn FROM $facesTable WHERE $faceIDColumn in (SELECT $faceIDColumn from $faceClustersTable where $clusterIDColumn = ?) ${limit != null ? 
'LIMIT $limit' : ''}', + [clusterID], + ); + return maps.map((e) => e[embeddingColumn] as Uint8List); + } + + @override + Future>> getFaceEmbeddingsForClusters( + Iterable clusterIDs, { + int? limit, + }) async { + final db = await instance.asyncDB; + final Map> result = {}; + + final selectQuery = ''' + SELECT fc.$clusterIDColumn, fe.$embeddingColumn + FROM $faceClustersTable fc + INNER JOIN $facesTable fe ON fc.$faceIDColumn = fe.$faceIDColumn + WHERE fc.$clusterIDColumn IN (${List.filled(clusterIDs.length, '?').join(',')}) + ${limit != null ? 'LIMIT ?' : ''} +'''; + + final List selectQueryParams = [...clusterIDs]; + if (limit != null) { + selectQueryParams.add(limit); + } + + final List> maps = + await db.getAll(selectQuery, selectQueryParams); + + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final faceEmbedding = map[embeddingColumn] as Uint8List; + result.putIfAbsent(clusterID, () => []).add(faceEmbedding); + } + + return result; + } + + @override + Future getCoverFaceForPerson({ + required String recentFileID, + String? personID, + String? avatarFaceId, + String? clusterID, + }) async { + // read person from db + final db = await instance.asyncDB; + if (personID != null) { + final List fileId = [recentFileID]; + String? avatarFileId; + if (avatarFaceId != null) { + avatarFileId = getFileIdFromFaceId(avatarFaceId); + if (avatarFileId != null) { + fileId.add(avatarFileId); + } + } + const String queryClusterID = ''' + SELECT $clusterIDColumn + FROM $clusterPersonTable + WHERE $personIdColumn = ? 
+ '''; + final clusterRows = await db.getAll( + queryClusterID, + [personID], + ); + final clusterIDs = + clusterRows.map((e) => e[clusterIDColumn] as String).toList(); + + final List> faceMaps = await db.getAll( + ''' + SELECT * FROM $facesTable + WHERE $faceIDColumn IN ( + SELECT $faceIDColumn + FROM $faceClustersTable + WHERE $clusterIDColumn IN (${List.filled(clusterIDs.length, '?').join(',')}) + ) + AND $fileIDColumn IN (${List.filled(fileId.length, '?').join(',')}) + ORDER BY $faceScore DESC + ''', + [...clusterIDs, ...fileId], + ); + if (faceMaps.isNotEmpty) { + if (avatarFileId != null) { + final row = faceMaps.firstWhereOrNull( + (element) => (element[fileIDColumn] as String) == avatarFileId, + ); + if (row != null) { + return mapRowToFace(row); + } + } + return mapRowToFace(faceMaps.first); + } + } + if (clusterID != null) { + const String queryFaceID = ''' + SELECT $faceIDColumn + FROM $faceClustersTable + WHERE $clusterIDColumn = ? + '''; + final List> faceMaps = await db.getAll( + queryFaceID, + [clusterID], + ); + final List? faces = await getFacesForGivenFileID(recentFileID); + if (faces != null) { + for (final face in faces) { + if (faceMaps.any( + (element) => (element[faceIDColumn] as String) == face.faceID, + )) { + return face; + } + } + } + } + if (personID == null && clusterID == null) { + throw Exception("personID and clusterID cannot be null"); + } + return null; + } + + @override + Future?> getFacesForGivenFileID(String fileUploadID) async { + final db = await instance.asyncDB; + const String query = ''' + SELECT * FROM $facesTable + WHERE $fileIDColumn = ? 
+ '''; + final List> maps = await db.getAll( + query, + [fileUploadID], + ); + if (maps.isEmpty) { + return null; + } + return maps.map((e) => mapRowToFace(e)).toList(); + } + + @override + Future>> + getFileIDsToFacesWithoutEmbedding() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + ''' + SELECT $faceIDColumn, $fileIDColumn, $faceScore, $faceDetectionColumn, $faceBlur FROM $facesTable + ''', + ); + if (maps.isEmpty) { + return {}; + } + final result = >{}; + for (final map in maps) { + final face = mapRowToFaceWithoutEmbedding(map); + final fileID = map[fileIDColumn] as int; + result.putIfAbsent(fileID, () => []).add(face); + } + return result; + } + + @override + Future>> getClusterToFaceIDs( + Set clusterIDs, + ) async { + final db = await instance.asyncDB; + final Map> result = {}; + + final List> maps = await db.getAll( + ''' + SELECT $clusterIDColumn, $faceIDColumn + FROM $faceClustersTable + WHERE $clusterIDColumn IN (${List.filled(clusterIDs.length, '?').join(',')}) + ''', + [...clusterIDs], + ); + + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final faceID = map[faceIDColumn] as String; + result.putIfAbsent(clusterID, () => []).add(faceID); + } + return result; + } + + @override + Future getClusterIDForFaceID(String faceID) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $clusterIDColumn FROM $faceClustersTable WHERE $faceIDColumn = ?', + [faceID], + ); + if (maps.isEmpty) { + return null; + } + return maps.first[clusterIDColumn] as String; + } + + @override + Future>> getAllClusterIdToFaceIDs() async { + final db = await instance.asyncDB; + final Map> result = {}; + final List> maps = await db.getAll( + 'SELECT $clusterIDColumn, $faceIDColumn FROM $faceClustersTable', + ); + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final faceID = map[faceIDColumn] as String; + result.putIfAbsent(clusterID, () => 
[]).add(faceID); + } + return result; + } + + @override + Future> getFaceIDsForCluster(String clusterID) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $faceIDColumn FROM $faceClustersTable ' + 'WHERE $faceClustersTable.$clusterIDColumn = ?', + [clusterID], + ); + return maps.map((e) => e[faceIDColumn] as String).toSet(); + } + + // Get Map of personID to Map of clusterID to faceIDs + @override + Future>>> + getPersonToClusterIdToFaceIds() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $personIdColumn, $faceClustersTable.$clusterIDColumn, $faceIDColumn FROM $clusterPersonTable ' + 'INNER JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$clusterIDColumn', + ); + final Map>> result = {}; + for (final map in maps) { + final personID = map[personIdColumn] as String; + final clusterID = map[clusterIDColumn] as String; + final faceID = map[faceIDColumn] as String; + result + .putIfAbsent(personID, () => {}) + .putIfAbsent(clusterID, () => {}) + .add(faceID); + } + return result; + } + + Future> getFaceIdToPersonIdForFaces( + Iterable faceIDs, + ) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $personIdColumn FROM $clusterPersonTable ' + 'INNER JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$clusterIDColumn ' + 'WHERE $faceIDColumn IN (${faceIDs.map((id) => "'$id'").join(",")})', + ); + final Map result = {}; + for (final map in maps) { + result[map[faceIDColumn] as String] = map[personIdColumn] as String; + } + return result; + } + + @override + Future>> getClusterIdToFaceIdsForPerson( + String personID, + ) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $faceClustersTable.$clusterIDColumn, $faceIDColumn FROM $clusterPersonTable ' + 'INNER JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn 
= $faceClustersTable.$clusterIDColumn ' + 'WHERE $personIdColumn = ?', + [personID], + ); + final Map> result = {}; + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final faceID = map[faceIDColumn] as String; + result.putIfAbsent(clusterID, () => {}).add(faceID); + } + return result; + } + + @override + Future> getFaceIDsForPerson(String personID) async { + final db = await instance.asyncDB; + final faceIdsResult = await db.getAll( + 'SELECT $faceIDColumn FROM $faceClustersTable LEFT JOIN $clusterPersonTable ' + 'ON $faceClustersTable.$clusterIDColumn = $clusterPersonTable.$clusterIDColumn ' + 'WHERE $clusterPersonTable.$personIdColumn = ?', + [personID], + ); + return faceIdsResult.map((e) => e[faceIDColumn] as String).toSet(); + } + + @override + Future> getBlurValuesForCluster(String clusterID) async { + final db = await instance.asyncDB; + const String query = ''' + SELECT $facesTable.$faceBlur + FROM $facesTable + JOIN $faceClustersTable ON $facesTable.$faceIDColumn = $faceClustersTable.$faceIDColumn + WHERE $faceClustersTable.$clusterIDColumn = ? + '''; + // const String query2 = ''' + // SELECT $faceBlur + // FROM $facesTable + // WHERE $faceIDColumn IN (SELECT $faceIDColumn FROM $faceClustersTable WHERE $clusterIDColumn = ?) 
+ // '''; + final List> maps = await db.getAll( + query, + [clusterID], + ); + return maps.map((e) => e[faceBlur] as double).toSet(); + } + + @override + Future> getFaceIdsToClusterIds( + Iterable faceIds, + ) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $clusterIDColumn FROM $faceClustersTable where $faceIDColumn IN (${faceIds.map((id) => "'$id'").join(",")})', + ); + final Map result = {}; + for (final map in maps) { + result[map[faceIDColumn] as String] = map[clusterIDColumn] as String?; + } + return result; + } + + @override + Future>> getFileIdToClusterIds() async { + final Map> result = {}; + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $clusterIDColumn, $faceIDColumn FROM $faceClustersTable', + ); + + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final faceID = map[faceIDColumn] as String; + final fileID = getFileIdFromFaceId(faceID); + result[fileID] = (result[fileID] ?? {})..add(clusterID); + } + return result; + } + + @override + Future forceUpdateClusterIds( + Map faceIDToClusterID, + ) async { + final db = await instance.asyncDB; + + const String sql = ''' + INSERT INTO $faceClustersTable ($faceIDColumn, $clusterIDColumn) + VALUES (?, ?) + ON CONFLICT($faceIDColumn) DO UPDATE SET $clusterIDColumn = excluded.$clusterIDColumn + '''; + final parameterSets = + faceIDToClusterID.entries.map((e) => [e.key, e.value]).toList(); + await db.executeBatch(sql, parameterSets); + } + + @override + Future removeFaceIdToClusterId( + Map faceIDToClusterID, + ) async { + final db = await instance.asyncDB; + const String sql = ''' + DELETE FROM $faceClustersTable + WHERE $faceIDColumn = ? AND $clusterIDColumn = ? 
+ '''; + final parameterSets = + faceIDToClusterID.entries.map((e) => [e.key, e.value]).toList(); + await db.executeBatch(sql, parameterSets); + } + + @override + Future removePerson(String personID) async { + final db = await instance.asyncDB; + + await db.writeTransaction((tx) async { + try { + await tx.execute( + 'DELETE FROM $clusterPersonTable WHERE $personIdColumn = ?', + [personID], + ); + } catch (e) { + _logger.severe('Error in the first write of removePerson', e); + rethrow; + } + try { + await tx.execute( + 'DELETE FROM $notPersonFeedback WHERE $personIdColumn = ?', + [personID], + ); + } catch (e) { + _logger.severe('Error in the second write of removePerson', e); + rethrow; + } + }); + } + + @override + Future> getFaceInfoForClustering({ + int maxFaces = 20000, + int offset = 0, + int batchSize = 10000, + }) async { + try { + final EnteWatch w = EnteWatch("getFaceEmbeddingMap")..start(); + w.logAndReset( + 'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize', + ); + final db = await instance.asyncDB; + + final List result = []; + while (true) { + // Query a batch of rows + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $embeddingColumn, $faceScore, $faceBlur, $isSideways FROM $facesTable' + ' WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold' + ' ORDER BY $faceIDColumn' + ' DESC LIMIT $batchSize OFFSET $offset', + ); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + final List faceIds = []; + for (final map in maps) { + faceIds.add(map[faceIDColumn] as String); + } + final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds); + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + final faceInfo = FaceDbInfoForClustering( + faceID: faceID, + clusterId: faceIdToClusterId[faceID], + embeddingBytes: map[embeddingColumn] as Uint8List, + faceScore: map[faceScore] as double, + blurValue: map[faceBlur] as double, + isSideways: 
(map[isSideways] as int) == 1, + ); + result.add(faceInfo); + } + if (result.length >= maxFaces) { + break; + } + offset += batchSize; + } + w.stopWithLog('done reading face embeddings ${result.length}'); + return result; + } catch (e) { + _logger.severe('err in getFaceInfoForClustering', e); + rethrow; + } + } + + @override + Future> getFaceEmbeddingMapForFaces( + Iterable faceIDs, + ) async { + _logger.info('reading face embeddings for ${faceIDs.length} faces'); + final db = await instance.asyncDB; + + // Define the batch size + const batchSize = 10000; + int offset = 0; + + final Map result = {}; + while (true) { + // Query a batch of rows + final String query = ''' + SELECT $faceIDColumn, $embeddingColumn + FROM $facesTable + WHERE $faceIDColumn IN (${faceIDs.map((id) => "'$id'").join(",")}) + ORDER BY $faceIDColumn DESC + LIMIT $batchSize OFFSET $offset + '''; + final List> maps = await db.getAll(query); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + result[faceID] = map[embeddingColumn] as Uint8List; + } + if (result.length > 10000) { + break; + } + offset += batchSize; + } + _logger.info('done reading face embeddings for ${faceIDs.length} faces'); + return result; + } + + @override + Future getTotalFaceCount() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold', + ); + return maps.first['count'] as int; + } + + @override + Future getErroredFaceCount() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore < 0', + ); + return maps.first['count'] as int; + } + + @override + Future> getErroredFileIDs() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT DISTINCT 
$fileIDColumn FROM $facesTable WHERE $faceScore < 0', + ); + return maps.map((e) => e[fileIDColumn] as String).toSet(); + } + + /// Deletes every row of [facesTable] whose file ID is in [fileIDs]. The IDs are bound + /// as `?` parameters, so String IDs are quoted by SQLite instead of spliced into the SQL. + @override + Future deleteFaceIndexForFiles(List fileIDs) async { + final db = await instance.asyncDB; + final placeholders = List.filled(fileIDs.length, '?').join(','); + await db.execute('DELETE FROM $facesTable WHERE $fileIDColumn IN ($placeholders)', [...fileIDs]); + } + + @override + Future getClusteredOrFacelessFileCount() async { + final db = await instance.asyncDB; + final List> clustered = await db.getAll( + 'SELECT $faceIDColumn FROM $faceClustersTable', + ); + final Set clusteredFileIDs = {}; + for (final map in clustered) { + final int fileID = getFileIdFromFaceId(map[faceIDColumn] as String); + clusteredFileIDs.add(fileID); + } + + final List> badFacesFiles = await db.getAll( + 'SELECT DISTINCT $fileIDColumn FROM $facesTable WHERE $faceScore <= $kMinimumQualityFaceScore OR $faceBlur <= $kLaplacianHardThreshold', + ); + final Set badFileIDs = {}; + for (final map in badFacesFiles) { + badFileIDs.add(map[fileIDColumn] as int); + } + + final List> goodFacesFiles = await db.getAll( + 'SELECT DISTINCT $fileIDColumn FROM $facesTable WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold', + ); + final Set goodFileIDs = {}; + for (final map in goodFacesFiles) { + goodFileIDs.add(map[fileIDColumn] as int); + } + final trulyFacelessFiles = badFileIDs.difference(goodFileIDs); + return clusteredFileIDs.length + trulyFacelessFiles.length; + } + + @override + Future getClusteredToIndexableFilesRatio() async { + final int indexableFiles = (await getIndexableFileIDs()).length; + final int clusteredFiles = await getClusteredOrFacelessFileCount(); + + return clusteredFiles / indexableFiles; + } + + @override + Future getUnclusteredFaceCount() async { + final db = await instance.asyncDB; + const String query = ''' + SELECT f.$faceIDColumn + FROM $facesTable f + LEFT JOIN $faceClustersTable fc ON f.$faceIDColumn = fc.$faceIDColumn + WHERE f.$faceScore 
> $kMinimumQualityFaceScore + AND f.$faceBlur > $kLaplacianHardThreshold + AND fc.$faceIDColumn IS NULL + '''; + final List> maps = await db.getAll(query); + return maps.length; + } + + /// WARNING: Only use this method if the person has just been created. + /// Otherwise, use [ClusterFeedbackService.instance.addClusterToExistingPerson] instead. + @override + Future assignClusterToPerson({ + required String personID, + required String clusterID, + }) async { + final db = await instance.asyncDB; + + const String sql = ''' + INSERT INTO $clusterPersonTable ($personIdColumn, $clusterIDColumn) VALUES (?, ?) ON CONFLICT($personIdColumn, $clusterIDColumn) DO NOTHING + '''; + await db.execute(sql, [personID, clusterID]); + } + + @override + Future bulkAssignClusterToPersonID( + Map clusterToPersonID, + ) async { + final db = await instance.asyncDB; + + const String sql = ''' + INSERT INTO $clusterPersonTable ($personIdColumn, $clusterIDColumn) VALUES (?, ?) ON CONFLICT($personIdColumn, $clusterIDColumn) DO NOTHING + '''; + final parameterSets = + clusterToPersonID.entries.map((e) => [e.value, e.key]).toList(); + await db.executeBatch(sql, parameterSets); + } + + @override + Future captureNotPersonFeedback({ + required String personID, + required String clusterID, + }) async { + final db = await instance.asyncDB; + + const String sql = ''' + INSERT INTO $notPersonFeedback ($personIdColumn, $clusterIDColumn) VALUES (?, ?) ON CONFLICT DO NOTHING + '''; + await db.execute(sql, [personID, clusterID]); + } + + @override + Future bulkCaptureNotPersonFeedback( + Map clusterToPersonID, + ) async { + final db = await instance.asyncDB; + + const String sql = ''' + INSERT INTO $notPersonFeedback ($personIdColumn, $clusterIDColumn) VALUES (?, ?) 
ON CONFLICT DO NOTHING + '''; + final parameterSets = + clusterToPersonID.entries.map((e) => [e.value, e.key]).toList(); + + await db.executeBatch(sql, parameterSets); + } + + @override + Future removeNotPersonFeedback({ + required String personID, + required String clusterID, + }) async { + final db = await instance.asyncDB; + + const String sql = ''' + DELETE FROM $notPersonFeedback WHERE $personIdColumn = ? AND $clusterIDColumn = ? + '''; + await db.execute(sql, [personID, clusterID]); + } + + @override + Future removeClusterToPerson({ + required String personID, + required String clusterID, + }) async { + final db = await instance.asyncDB; + + const String sql = ''' + DELETE FROM $clusterPersonTable WHERE $personIdColumn = ? AND $clusterIDColumn = ? + '''; + await db.execute(sql, [personID, clusterID]); + } + + // for a given personID, return a map of clusterID to fileIDs using join query + @override + Future>> getFileIdToClusterIDSet(String personID) { + final db = instance.asyncDB; + return db.then((db) async { + final List> maps = await db.getAll( + 'SELECT $faceClustersTable.$clusterIDColumn, $faceIDColumn FROM $faceClustersTable ' + 'INNER JOIN $clusterPersonTable ' + 'ON $faceClustersTable.$clusterIDColumn = $clusterPersonTable.$clusterIDColumn ' + 'WHERE $clusterPersonTable.$personIdColumn = ?', + [personID], + ); + final Map> result = {}; + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final String faceID = map[faceIDColumn] as String; + final fileID = getFileIdFromFaceId(faceID); + result[fileID] = (result[fileID] ?? 
{})..add(clusterID); + } + return result; + }); + } + + @override + Future>> getFileIdToClusterIDSetForCluster( + Set clusterIDs, + ) { + final db = instance.asyncDB; + return db.then((db) async { + final List> maps = await db.getAll( + ''' + SELECT $clusterIDColumn, $faceIDColumn + FROM $faceClustersTable + WHERE $clusterIDColumn IN (${List.filled(clusterIDs.length, '?').join(',')}) + ''', + [...clusterIDs], + ); + final Map> result = {}; + for (final map in maps) { + final clusterID = map[clusterIDColumn] as String; + final faceID = map[faceIDColumn] as String; + final fileID = getFileIdFromFaceId(faceID); + result[fileID] = (result[fileID] ?? {})..add(clusterID); + } + return result; + }); + } + + @override + Future clusterSummaryUpdate( + Map summary, + ) async { + final db = await instance.asyncDB; + + const String sql = ''' + INSERT INTO $clusterSummaryTable ($clusterIDColumn, $avgColumn, $countColumn) VALUES (?, ?, ?) ON CONFLICT($clusterIDColumn) DO UPDATE SET $avgColumn = excluded.$avgColumn, $countColumn = excluded.$countColumn + '''; + final List> parameterSets = []; + int batchCounter = 0; + for (final entry in summary.entries) { + if (batchCounter == 400) { + await db.executeBatch(sql, parameterSets); + batchCounter = 0; + parameterSets.clear(); + } + final String clusterID = entry.key; + final int count = entry.value.$2; + final Uint8List avg = entry.value.$1; + parameterSets.add([clusterID, avg, count]); + batchCounter++; + } + await db.executeBatch(sql, parameterSets); + } + + @override + Future deleteClusterSummary(String clusterID) async { + final db = await instance.asyncDB; + const String sqlDelete = + 'DELETE FROM $clusterSummaryTable WHERE $clusterIDColumn = ?'; + await db.execute(sqlDelete, [clusterID]); + } + + /// Returns a map of clusterID to (avg embedding, count) + @override + Future> getAllClusterSummary([ + int? 
minClusterSize, + ]) async { + final db = await instance.asyncDB; + final Map result = {}; + final rows = await db.getAll( + 'SELECT * FROM $clusterSummaryTable${minClusterSize != null ? ' WHERE $countColumn >= $minClusterSize' : ''}', + ); + for (final r in rows) { + final id = r[clusterIDColumn] as String; + final avg = r[avgColumn] as Uint8List; + final count = r[countColumn] as int; + result[id] = (avg, count); + } + return result; + } + + @override + Future> getClusterToClusterSummary( + Iterable clusterIDs, + ) async { + final db = await instance.asyncDB; + final Map result = {}; + + final rows = await db.getAll( + 'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${List.filled(clusterIDs.length, '?').join(',')})', + [...clusterIDs], + ); + + for (final r in rows) { + final id = r[clusterIDColumn] as String; + final avg = r[avgColumn] as Uint8List; + final count = r[countColumn] as int; + result[id] = (avg, count); + } + return result; + } + + @override + Future> getClusterIDToPersonID() async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable', + ); + final Map result = {}; + for (final map in maps) { + result[map[clusterIDColumn] as String] = map[personIdColumn] as String; + } + return result; + } + + /// WARNING: This will delete ALL data in the database! Only use this for debug/testing purposes! 
+ @override + Future dropClustersAndPersonTable({bool faces = false}) async { + try { + final db = await instance.asyncDB; + if (faces) { + await db.execute(deleteFacesTable); + await db.execute(getCreateFacesTable(true)); + await db.execute(deleteFaceClustersTable); + await db.execute(createFaceClustersTable); + await db.execute(fcClusterIDIndex); + } + + await db.execute(deleteClusterPersonTable); + await db.execute(deleteNotPersonFeedbackTable); + await db.execute(deleteClusterSummaryTable); + await db.execute(deleteFaceClustersTable); + + await db.execute(createClusterPersonTable); + await db.execute(createNotPersonFeedbackTable); + await db.execute(createClusterSummaryTable); + await db.execute(createFaceClustersTable); + await db.execute(fcClusterIDIndex); + } catch (e, s) { + _logger.severe('Error dropping clusters and person table', e, s); + } + } + + /// WARNING: This will delete ALL data in the tables! Only use this for debug/testing purposes! + @override + Future dropFacesFeedbackTables() async { + try { + final db = await instance.asyncDB; + + // Drop the tables + await db.execute(deleteClusterPersonTable); + await db.execute(deleteNotPersonFeedbackTable); + + // Recreate the tables + await db.execute(createClusterPersonTable); + await db.execute(createNotPersonFeedbackTable); + } catch (e) { + _logger.severe('Error dropping feedback tables', e); + } + } + + @override + Future> getFileIDsOfPersonID(String personID) async { + final db = await instance.asyncDB; + final result = await db.getAll( + ''' + SELECT DISTINCT $facesTable.$fileIDColumn + FROM $clusterPersonTable + JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$clusterIDColumn + JOIN $facesTable ON $faceClustersTable.$faceIDColumn = $facesTable.$faceIDColumn + WHERE $clusterPersonTable.$personIdColumn = ? 
+ ''', + [personID], + ); + + return [for (final row in result) row[fileIDColumn]]; + } + + @override + Future> getFileIDsOfClusterID(String clusterID) async { + final db = await instance.asyncDB; + final result = await db.getAll( + ''' + SELECT DISTINCT $facesTable.$fileIDColumn + FROM $faceClustersTable + JOIN $facesTable ON $faceClustersTable.$faceIDColumn = $facesTable.$faceIDColumn + WHERE $faceClustersTable.$clusterIDColumn = ? + ''', + [clusterID], + ); + + return [for (final row in result) row[fileIDColumn]]; + } + + @override + Future> getAllFileIDsOfFaceIDsNotInAnyCluster() async { + final db = await instance.asyncDB; + final result = await db.getAll( + ''' + SELECT DISTINCT file_id + FROM faces + LEFT JOIN face_clusters ON faces.face_id = face_clusters.face_id + WHERE face_clusters.face_id IS NULL; + ''', + ); + return {for (final row in result) row[fileIDColumn]}; + } + + @override + Future> getAllFilesAssociatedWithAllClusters({ + List? exceptClusters, + }) async { + final notInParam = exceptClusters?.map((e) => "'$e'").join(',') ?? 
''; + final db = await instance.asyncDB; + final result = await db.getAll(''' + SELECT DISTINCT $facesTable.$fileIDColumn + FROM $facesTable + JOIN $faceClustersTable on $faceClustersTable.$faceIDColumn = $facesTable.$faceIDColumn + WHERE $faceClustersTable.$clusterIDColumn NOT IN ($notInParam); + '''); + + return {for (final row in result) row[fileIDColumn]}; + } + + @override + Future> getAllClipVectors() async { + Logger("ClipDB").info("reading all embeddings from DB"); + final db = await instance.asyncDB; + final results = await db + .getAll('SELECT $fileIDColumn, $embeddingColumn FROM $clipTable'); + + // Convert rows to vectors + final List embeddings = []; + for (final result in results) { + // Convert to EmbeddingVector + final embedding = EmbeddingVector( + fileID: result[fileIDColumn], + embedding: Float32List.view(result[embeddingColumn].buffer), + ); + if (embedding.isEmpty) continue; + embeddings.add(embedding); + } + return embeddings; + } + + // Get indexed FileIDs + @override + Future> clipIndexedFileWithVersion() async { + final db = await instance.asyncDB; + final maps = await db + .getAll('SELECT $fileIDColumn , $mlVersionColumn FROM $clipTable'); + final Map result = {}; + for (final map in maps) { + result[map[fileIDColumn] as int] = map[mlVersionColumn] as int; + } + return result; + } + + @override + Future getClipIndexedFileCount({ + int minimumMlVersion = clipMlVersion, + }) async { + final db = await instance.asyncDB; + final String query = + 'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $clipTable WHERE $mlVersionColumn >= $minimumMlVersion'; + final List> maps = await db.getAll(query); + return maps.first['count'] as int; + } + + @override + Future putClip(List embeddings) async { + if (embeddings.isEmpty) return; + final db = await instance.asyncDB; + if (embeddings.length == 1) { + await db.execute( + 'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) VALUES (?, ?, ?)', + 
_getRowFromEmbedding(embeddings.first), + ); + } else { + final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList(); + await db.executeBatch( + 'INSERT OR REPLACE INTO $clipTable ($fileIDColumn, $embeddingColumn, $mlVersionColumn) values(?, ?, ?)', + inputs, + ); + } + Bus.instance.fire(EmbeddingUpdatedEvent()); + } + + @override + Future deleteClipEmbeddings(List fileIDs) async { + final db = await instance.asyncDB; + await db.execute( + 'DELETE FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})', + ); + Bus.instance.fire(EmbeddingUpdatedEvent()); + } + + @override + Future deleteClipIndexes() async { + final db = await instance.asyncDB; + await db.execute('DELETE FROM $clipTable'); + Bus.instance.fire(EmbeddingUpdatedEvent()); + } + + List _getRowFromEmbedding(ClipEmbedding embedding) { + return [ + embedding.fileID, + Float32List.fromList(embedding.embedding).buffer.asUint8List(), + embedding.version, + ]; + } +} diff --git a/mobile/lib/db/ml/schema.dart b/mobile/lib/db/ml/schema.dart index 5968215cf5..3973e0009d 100644 --- a/mobile/lib/db/ml/schema.dart +++ b/mobile/lib/db/ml/schema.dart @@ -16,8 +16,9 @@ const mlVersionColumn = 'ml_version'; const personIdColumn = 'person_id'; const clusterIDColumn = 'cluster_id'; -const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( - $fileIDColumn INTEGER NOT NULL, +String getCreateFacesTable(bool isOfflineDB) { + return '''CREATE TABLE IF NOT EXISTS $facesTable ( + $fileIDColumn ${isOfflineDB ? 
'TEXT' : 'INTEGER'} NOT NULL, $faceIDColumn TEXT NOT NULL UNIQUE, $faceDetectionColumn TEXT NOT NULL, $embeddingColumn BLOB NOT NULL, @@ -30,6 +31,7 @@ const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( PRIMARY KEY($fileIDColumn, $faceIDColumn) ); '''; +} const deleteFacesTable = 'DELETE FROM $facesTable'; // End of Faces Table Fields & Schema Queries @@ -97,18 +99,20 @@ const deleteNotPersonFeedbackTable = 'DELETE FROM $notPersonFeedback'; // ## CLIP EMBEDDINGS TABLE const clipTable = 'clip'; -const createClipEmbeddingsTable = ''' -CREATE TABLE IF NOT EXISTS $clipTable ( - $fileIDColumn INTEGER NOT NULL, +String getCreateClipEmbeddingsTable(bool isOfflineDB) { + return '''CREATE TABLE IF NOT EXISTS $clipTable ( + $fileIDColumn ${isOfflineDB ? 'TEXT' : 'INTEGER'} NOT NULL, $embeddingColumn BLOB NOT NULL, $mlVersionColumn INTEGER NOT NULL, - PRIMARY KEY ($fileIDColumn) + PRIMARY KEY($fileIDColumn) ); -'''; + '''; +} const deleteClipEmbeddingsTable = 'DELETE FROM $clipTable'; const fileDataTable = 'filedata'; + const createFileDataTable = ''' CREATE TABLE IF NOT EXISTS $fileDataTable ( $fileIDColumn INTEGER NOT NULL,