diff --git a/mobile/lib/db/ml/base.dart b/mobile/lib/db/ml/base.dart index a99f70f813..c2685b1214 100644 --- a/mobile/lib/db/ml/base.dart +++ b/mobile/lib/db/ml/base.dart @@ -108,6 +108,9 @@ abstract class IMLDataDB { Future> getAllClipVectors(); Future> clipIndexedFileWithVersion(); + Future> getClipVectorsForFileIDs( + Iterable fileIDs, + ); Future getClipIndexedFileCount({int minimumMlVersion}); Future putClip(List embeddings); Future deleteClipEmbeddings(List fileIDs); diff --git a/mobile/lib/db/ml/db.dart b/mobile/lib/db/ml/db.dart index 0242d82269..862f748e04 100644 --- a/mobile/lib/db/ml/db.dart +++ b/mobile/lib/db/ml/db.dart @@ -418,6 +418,28 @@ class MLDataDB extends IMLDataDB { return maps.map((e) => mapRowToFace(e)).toList(); } + Future>> getFacesForFileIDs( + Iterable fileUploadIDs, + ) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + ''' + SELECT * FROM $facesTable + WHERE $fileIDColumn IN (${fileUploadIDs.map((id) => "'$id'").join(",")}) + ''', + ); + if (maps.isEmpty) { + return {}; + } + final result = >{}; + for (final map in maps) { + final face = mapRowToFace(map); + final fileID = map[fileIDColumn] as int; + result.putIfAbsent(fileID, () => []).add(face); + } + return result; + } + @override Future>> getClusterToFaceIDs( Set clusterIDs, @@ -503,6 +525,22 @@ class MLDataDB extends IMLDataDB { return result; } + Future> getFaceIdToPersonIdForFaces( + Iterable faceIDs, + ) async { + final db = await instance.asyncDB; + final List> maps = await db.getAll( + 'SELECT $faceIDColumn, $personIdColumn FROM $clusterPersonTable ' + 'INNER JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$clusterIDColumn ' + 'WHERE $faceIDColumn IN (${faceIDs.map((id) => "'$id'").join(",")})', + ); + final Map result = {}; + for (final map in maps) { + result[map[faceIDColumn] as String] = map[personIdColumn] as String; + } + return result; + } + @override Future>> getClusterIdToFaceIdsForPerson( String personID, @@ -1164,6 +1202,22 @@ class MLDataDB extends IMLDataDB { return _convertToVectors(results); } + @override + Future> getClipVectorsForFileIDs( + Iterable fileIDs, + ) async { + final db = await MLDataDB.instance.asyncDB; + final results = await db.getAll( + 'SELECT * FROM $clipTable WHERE $fileIDColumn IN (${fileIDs.join(", ")})', + ); + final Map embeddings = {}; + for (final result in results) { + final embedding = _getVectorFromRow(result); + embeddings[embedding.fileID] = embedding; + } + return embeddings; + } + // Get indexed FileIDs @override Future> clipIndexedFileWithVersion() async { diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index e5544a9c46..93e2aabf1f 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -4,6 +4,7 @@ import "package:flutter/cupertino.dart"; import "package:flutter/material.dart"; import "package:intl/intl.dart"; import 'package:logging/logging.dart'; +import "package:ml_linalg/linalg.dart"; import "package:photos/core/constants.dart"; import 'package:photos/core/event_bus.dart'; import 'package:photos/data/holidays.dart'; @@ -40,6 +41,7 @@ import "package:photos/services/filter/db_filters.dart"; import "package:photos/services/location_service.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import "package:photos/services/machine_learning/ml_computer.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import "package:photos/services/user_remote_flag_service.dart"; import "package:photos/states/location_screen_state.dart"; @@ -1196,9 +1198,12 @@ class SearchService { final currentTime = DateTime.now().toLocal(); final currentDayMonth = currentTime.month * 100 + currentTime.day; + final currentWeek = _getWeekNumber(currentTime); + final currentMonth = currentTime.month; final cutOffTime = currentTime.subtract(const Duration(days: 365)); final averageDailyPhotos = allFiles.length / 365; - final significanceThreshold = averageDailyPhotos * 0.25; + final significantDayThreshold = averageDailyPhotos * 0.25; + final significantWeekThreshold = averageDailyPhotos * 0.40; // Group files by day-month and year final dayMonthYearGroups = >>{}; @@ -1217,59 +1222,62 @@ class SearchService { .add(file); } - // Process each day-month + // Process each nearby day-month to find significant days for (final dayMonth in dayMonthYearGroups.keys) { final dayDiff = dayMonth - currentDayMonth; if (dayDiff < 0 || dayDiff > 2) continue; + // TODO: lau: this doesn't cover month changes properly final yearGroups = dayMonthYearGroups[dayMonth]!; - final significantYears = yearGroups.entries - .where((e) => e.value.length > significanceThreshold) + final significantDays = yearGroups.entries + .where((e) => e.value.length > significantDayThreshold) .map((e) => e.key) .toList(); - if (significantYears.length >= 3) { + if (significantDays.length >= 3) { // Combine all years for this day-month final date = DateTime(currentTime.year, dayMonth ~/ 100, dayMonth % 100); final allPhotos = yearGroups.values.expand((x) => x).toList(); + final photoSelection = await _bestSelection(allPhotos); searchResults.add( GenericSearchResult( ResultType.event, - 'Memories of ${DateFormat('MMMM d').format(date)}', - allPhotos, + "${DateFormat('MMMM d').format(date)} through the years", + photoSelection, hierarchicalSearchFilter: TopLevelGenericFilter( filterName: DateFormat('MMMM d').format(date), occurrence: kMostRelevantFilter, filterResultType: ResultType.event, - matchedUploadedIDs: filesToUploadedFileIDs(allPhotos), + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), filterIcon: Icons.event_outlined, ), ), ); } else { // Individual entries for significant years - for (final year in significantYears) { + for (final year in significantDays) { final date = DateTime(year, dayMonth ~/ 100, dayMonth % 100); final files = yearGroups[year]!; + final photoSelection = await _bestSelection(files); String name = DateFormat.yMMMd(Localizations.localeOf(context).languageCode) .format(date); if (date.day == currentTime.day && date.month == currentTime.month) { - name = 'This day, ${currentTime.year - date.year} years back'; + name = "This day, ${currentTime.year - date.year} years back"; } searchResults.add( GenericSearchResult( ResultType.event, name, - files, + photoSelection, hierarchicalSearchFilter: TopLevelGenericFilter( filterName: name, occurrence: kMostRelevantFilter, filterResultType: ResultType.event, - matchedUploadedIDs: filesToUploadedFileIDs(files), + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), filterIcon: Icons.event_outlined, ), ), @@ -1277,12 +1285,342 @@ class SearchService { } } - if (limit != null && searchResults.length >= limit) break; + if (limit != null && searchResults.length >= limit) return searchResults; } + // process to find significant weeks (only if there are no significant days) + if (searchResults.isEmpty) { + // Group files by week and year + final currentWeekYearGroups = >{}; + for (final file in allFiles) { + if (file.creationTime! > cutOffTime.microsecondsSinceEpoch) continue; + + final creationTime = + DateTime.fromMicrosecondsSinceEpoch(file.creationTime!); + final week = _getWeekNumber(creationTime); + if (week != currentWeek) continue; + final year = creationTime.year; + + currentWeekYearGroups.putIfAbsent(year, () => []).add(file); + } + + // Process the week and see if it's significant + if (currentWeekYearGroups.isNotEmpty) { + final significantWeeks = currentWeekYearGroups.entries + .where((e) => e.value.length > significantWeekThreshold) + .map((e) => e.key) + .toList(); + if (significantWeeks.length >= 3) { + // Combine all years for this week + final allPhotos = + currentWeekYearGroups.values.expand((x) => x).toList(); + final photoSelection = await _bestSelection(allPhotos); + + searchResults.add( + GenericSearchResult( + ResultType.event, + "This week through the years", + photoSelection, + hierarchicalSearchFilter: TopLevelGenericFilter( + filterName: "Week $currentWeek", + occurrence: kMostRelevantFilter, + filterResultType: ResultType.event, + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), + filterIcon: Icons.event_outlined, + ), + ), + ); + } else { + // Individual entries for significant years + for (final year in significantWeeks) { + final date = DateTime(year, 1, 1).add( + Duration(days: (currentWeek - 1) * 7), + ); + final files = currentWeekYearGroups[year]!; + final photoSelection = await _bestSelection(files); + final name = + "This week, ${currentTime.year - date.year} years back"; + + searchResults.add( + GenericSearchResult( + ResultType.event, + name, + photoSelection, + hierarchicalSearchFilter: TopLevelGenericFilter( + filterName: name, + occurrence: kMostRelevantFilter, + filterResultType: ResultType.event, + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), + filterIcon: Icons.event_outlined, + ), + ), + ); + } + } + } + } + + if (limit != null && searchResults.length >= limit) return searchResults; + + // process to find fillers (months) + const wantedMemories = 3; + final neededMemories = wantedMemories - searchResults.length; + if (neededMemories <= 0) return searchResults; + const monthSelectionSize = 20; + + // Group files by month and year + final currentMonthYearGroups = >{}; + for (final file in allFiles) { + if (file.creationTime! > cutOffTime.microsecondsSinceEpoch) continue; + + final creationTime = + DateTime.fromMicrosecondsSinceEpoch(file.creationTime!); + final month = creationTime.month; + if (month != currentMonth) continue; + final year = creationTime.year; + + currentMonthYearGroups.putIfAbsent(year, () => []).add(file); + } + + // Add the largest two months plus the month through the years + final sortedYearsForCurrentMonth = currentMonthYearGroups.keys.toList() + ..sort( + (a, b) => currentMonthYearGroups[b]!.length.compareTo( + currentMonthYearGroups[a]!.length, + ), + ); + if (neededMemories > 1) { + for (int i = neededMemories; i > 1; i--) { + if (sortedYearsForCurrentMonth.isEmpty) break; + final year = sortedYearsForCurrentMonth.removeAt(0); + final monthYearFiles = currentMonthYearGroups[year]!; + final photoSelection = await _bestSelection( + monthYearFiles, + prefferedSize: monthSelectionSize, + ); + final monthName = + DateFormat.MMMM(Localizations.localeOf(context).languageCode) + .format(DateTime(year, currentMonth)); + final name = monthName + ", ${currentTime.year - year} years back"; + searchResults.add( + GenericSearchResult( + ResultType.event, + name, + photoSelection, + hierarchicalSearchFilter: TopLevelGenericFilter( + filterName: name, + occurrence: kMostRelevantFilter, + filterResultType: ResultType.event, + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), + filterIcon: Icons.event_outlined, + ), + ), + ); + } + } + // Show the month through the remaining years + if (sortedYearsForCurrentMonth.isEmpty) return searchResults; + final allPhotos = sortedYearsForCurrentMonth + .expand((year) => currentMonthYearGroups[year]!) + .toList(); + final photoSelection = + await _bestSelection(allPhotos, prefferedSize: monthSelectionSize); + final monthName = + DateFormat.MMMM(Localizations.localeOf(context).languageCode) + .format(DateTime(currentTime.year, currentMonth)); + final name = monthName + " through the years"; + searchResults.add( + GenericSearchResult( + ResultType.event, + name, + photoSelection, + hierarchicalSearchFilter: TopLevelGenericFilter( + filterName: name, + occurrence: kMostRelevantFilter, + filterResultType: ResultType.event, + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), + filterIcon: Icons.event_outlined, + ), + ), + ); + return searchResults; } + int _getWeekNumber(DateTime date) { + // Get day of year (1-366) + final int dayOfYear = int.parse(DateFormat('D').format(date)); + // Integer division by 7 and add 1 to start from week 1 + return ((dayOfYear - 1) ~/ 7) + 1; + } + + /// Returns the best selection of files from the given list. + /// Makes sure that the selection is not more than [prefferedSize] or 10 files, + /// and that each year of the original list is represented. + Future> _bestSelection( + List files, { + int? prefferedSize, + }) async { + final fileCount = files.length; + int targetSize = prefferedSize ?? 10; + if (fileCount <= targetSize) return files; + final safeFiles = + files.where((file) => file.uploadedFileID != null).toList(); + final fileIDs = safeFiles.map((e) => e.uploadedFileID!).toSet(); + final fileIdToFace = await MLDataDB.instance.getFacesForFileIDs(fileIDs); + final faceIDs = + fileIdToFace.values.expand((x) => x.map((face) => face.faceID)).toSet(); + final faceIDsToPersonID = + await MLDataDB.instance.getFaceIdToPersonIdForFaces(faceIDs); + final fileIdToClip = + await MLDataDB.instance.getClipVectorsForFileIDs(fileIDs); + final allYears = safeFiles.map((e) { + final creationTime = DateTime.fromMicrosecondsSinceEpoch(e.creationTime!); + return creationTime.year; + }).toSet(); + + // Get clip scores for each file + const query = + 'Photo of a precious memory radiating warmth, vibrant energy, or quiet beauty — alive with color, light, or emotion'; + // TODO: lau: optimize this later so we don't keep computing embedding + final textEmbedding = await MLComputer.instance.runClipText(query); + final textVector = Vector.fromList(textEmbedding); + const clipThreshold = 0.75; + final fileToScore = {}; + for (final file in safeFiles) { + final clip = fileIdToClip[file.uploadedFileID!]; + if (clip == null) { + fileToScore[file.uploadedFileID!] = 0; + continue; + } + final score = clip.vector.dot(textVector); + fileToScore[file.uploadedFileID!] = score; + } + + // Get face scores for each file + final fileToFaceCount = {}; + for (final file in safeFiles) { + final fileID = file.uploadedFileID!; + fileToFaceCount[fileID] = 0; + final faces = fileIdToFace[fileID]; + if (faces == null || faces.isEmpty) { + continue; + } + for (final face in faces) { + if (faceIDsToPersonID.containsKey(face.faceID)) { + fileToFaceCount[fileID] = fileToFaceCount[fileID]! + 10; + } else { + fileToFaceCount[fileID] = fileToFaceCount[fileID]! + 1; + } + } + } + + final filteredFiles = []; + if (allYears.length <= 1) { + // TODO: lau: eventually this sorting might have to be replaced with some scoring system + // sort first on clip embeddings score (descending) + safeFiles.sort( + (a, b) => fileToScore[b.uploadedFileID!]! + .compareTo(fileToScore[a.uploadedFileID!]!), + ); + // then sort on faces (descending), heavily prioritizing named faces + safeFiles.sort( + (a, b) => fileToFaceCount[b.uploadedFileID!]! + .compareTo(fileToFaceCount[a.uploadedFileID!]!), + ); + + // then filter out similar images as much as possible + filteredFiles.add(safeFiles.first); + int skipped = 0; + filesLoop: + for (final file in safeFiles.sublist(1)) { + if (filteredFiles.length >= targetSize) break; + final clip = fileIdToClip[file.uploadedFileID!]; + if (clip != null && (fileCount - skipped) > targetSize) { + for (final filteredFile in filteredFiles) { + final fClip = fileIdToClip[filteredFile.uploadedFileID!]; + if (fClip == null) continue; + final similarity = clip.vector.dot(fClip.vector); + if (similarity > clipThreshold) { + skipped++; + continue filesLoop; + } + } + } + filteredFiles.add(file); + } + } else { + // Multiple years, each represented and roughly equally distributed + if (prefferedSize == null && (allYears.length * 2) > 10) { + targetSize = allYears.length * 3; + if (fileCount < targetSize) return safeFiles; + } + + // Group files by year and sort each year's list by CLIP then face count + final yearToFiles = >{}; + for (final file in safeFiles) { + final creationTime = + DateTime.fromMicrosecondsSinceEpoch(file.creationTime!); + final year = creationTime.year; + yearToFiles.putIfAbsent(year, () => []).add(file); + } + + for (final year in yearToFiles.keys) { + final yearFiles = yearToFiles[year]!; + // sort first on clip embeddings score (descending) + yearFiles.sort( + (a, b) => fileToScore[b.uploadedFileID!]! + .compareTo(fileToScore[a.uploadedFileID!]!), + ); + // then sort on faces (descending), heavily prioritizing named faces + yearFiles.sort( + (a, b) => fileToFaceCount[b.uploadedFileID!]! + .compareTo(fileToFaceCount[a.uploadedFileID!]!), + ); + } + + // Then join the years together one by one and filter similar images + final years = yearToFiles.keys.toList() + ..sort((a, b) => b.compareTo(a)); // Recent years first + int round = 0; + int skipped = 0; + whileLoop: + while (filteredFiles.length + skipped < fileCount) { + yearLoop: + for (final year in years) { + final yearFiles = yearToFiles[year]!; + if (yearFiles.isEmpty) continue; + final newFile = yearFiles.removeAt(0); + if (round != 0 && (fileCount - skipped) > targetSize) { + // check for filtering + final clip = fileIdToClip[newFile.uploadedFileID!]; + if (clip != null) { + for (final filteredFile in filteredFiles) { + final fClip = fileIdToClip[filteredFile.uploadedFileID!]; + if (fClip == null) continue; + final similarity = clip.vector.dot(fClip.vector); + if (similarity > clipThreshold) { + skipped++; + continue yearLoop; + } + } + } + } + filteredFiles.add(newFile); + if (filteredFiles.length >= targetSize || + filteredFiles.length + skipped >= fileCount) { + break whileLoop; + } + } + round++; + } + } + + // Order the final selection chronologically + filteredFiles.sort((a, b) => b.creationTime!.compareTo(a.creationTime!)); + return filteredFiles; + } + Future getRandomDateResults( BuildContext context, ) async { diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index 98a635b3ab..07b2938512 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -12,7 +12,7 @@ description: ente photos application # Read more about iOS versioning at # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html -version: 0.9.90+990 +version: 0.9.91+991 publish_to: none environment: