From e50ab7ca8415277a01ca72a313e8666e64a5da19 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Tue, 28 Jan 2025 10:43:58 +0530 Subject: [PATCH] [mob][photos] First iteration of file selection --- mobile/lib/services/search_service.dart | 90 +++++++++++++++++++++++-- 1 file changed, 86 insertions(+), 4 deletions(-) diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index ad49703911..d0d7df0642 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -4,12 +4,14 @@ import "package:flutter/cupertino.dart"; import "package:flutter/material.dart"; import "package:intl/intl.dart"; import 'package:logging/logging.dart'; +import "package:ml_linalg/linalg.dart"; import "package:photos/core/constants.dart"; import 'package:photos/core/event_bus.dart'; import 'package:photos/data/holidays.dart'; import 'package:photos/data/months.dart'; import 'package:photos/data/years.dart'; import 'package:photos/db/files_db.dart'; +import "package:photos/db/ml/clip_db.dart"; import "package:photos/db/ml/db.dart"; import 'package:photos/events/local_photos_updated_event.dart'; import "package:photos/extensions/user_extension.dart"; @@ -40,6 +42,7 @@ import "package:photos/services/filter/db_filters.dart"; import "package:photos/services/location_service.dart"; import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import "package:photos/services/machine_learning/ml_computer.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import "package:photos/services/user_remote_flag_service.dart"; import "package:photos/states/location_screen_state.dart"; @@ -1233,17 +1236,18 @@ class SearchService { final date = DateTime(currentTime.year, dayMonth ~/ 100, dayMonth % 100); final allPhotos = yearGroups.values.expand((x) => x).toList(); + final photoSelection = await _bestSelection(allPhotos); searchResults.add( GenericSearchResult( ResultType.event, 'Memories of ${DateFormat('MMMM d').format(date)}', - allPhotos, + photoSelection, hierarchicalSearchFilter: TopLevelGenericFilter( filterName: DateFormat('MMMM d').format(date), occurrence: kMostRelevantFilter, filterResultType: ResultType.event, - matchedUploadedIDs: filesToUploadedFileIDs(allPhotos), + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), filterIcon: Icons.event_outlined, ), ), @@ -1253,6 +1257,7 @@ class SearchService { for (final year in significantYears) { final date = DateTime(year, dayMonth ~/ 100, dayMonth % 100); final files = yearGroups[year]!; + final photoSelection = await _bestSelection(files); String name = DateFormat.yMMMd(Localizations.localeOf(context).languageCode) .format(date); @@ -1264,12 +1269,12 @@ class SearchService { GenericSearchResult( ResultType.event, name, - files, + photoSelection, hierarchicalSearchFilter: TopLevelGenericFilter( filterName: name, occurrence: kMostRelevantFilter, filterResultType: ResultType.event, - matchedUploadedIDs: filesToUploadedFileIDs(files), + matchedUploadedIDs: filesToUploadedFileIDs(photoSelection), filterIcon: Icons.event_outlined, ), ), @@ -1283,6 +1288,83 @@ class SearchService { return searchResults; } + /// Returns the best selection of files from the given list. + /// Makes sure that the selection is not more than 10 files, + /// and that each year of the original list is represented. + Future> _bestSelection(List files) async { + if (files.length <= 10) return files; + final fileIDs = files.map((e) => e.uploadedFileID!).toSet(); + final fileIdToFace = await MLDataDB.instance.getFacesForFileIDs(fileIDs); + final fileIdToClip = + await MLDataDB.instance.getClipVectorsForFileIDs(fileIDs); + final allYears = files.map((e) { + final creationTime = DateTime.fromMicrosecondsSinceEpoch(e.creationTime!); + return creationTime.year; + }).toSet(); + if (allYears.length <= 1) { + // sort first on clip embeddings TODO: lau: come up with a better prompt + const query = 'Beautiful photo of a nice memory'; + // TODO: lau: optimize this later so we don't keep computing embedding + final textEmbedding = await MLComputer.instance.runClipText(query); + final textVector = Vector.fromList(textEmbedding); + final fileToScore = {}; + for (final file in files) { + final clip = fileIdToClip[file.uploadedFileID!]; + if (clip == null) { + fileToScore[file.uploadedFileID!] = 0; + continue; + } + final score = clip.vector.dot(textVector); + fileToScore[file.uploadedFileID!] = score; + } + files.sort( + (a, b) => fileToScore[b.uploadedFileID!]! + .compareTo(fileToScore[a.uploadedFileID!]!), + ); + + // then sort on faces + final fileToFaceCount = {}; + for (final file in files) { + final faces = fileIdToFace[file.uploadedFileID!]; + if (faces == null) { + fileToFaceCount[file.uploadedFileID!] = 0; + continue; + } + fileToFaceCount[file.uploadedFileID!] = faces.length; + } + files.sort( + (a, b) => fileToFaceCount[b.uploadedFileID!]! + .compareTo(fileToFaceCount[a.uploadedFileID!]!), + ); + + // then filter out similar images as much as possible + final filteredFiles = [files.first]; + final fileCount = files.length; + int skipped = 0; + filesLoop: + for (final file in files.sublist(1)) { + if (filteredFiles.length >= 10) break; + final clip = fileIdToClip[file.uploadedFileID!]; + if (clip != null && (fileCount - skipped) > 10) { + for (final filteredFile in filteredFiles) { + final fClip = fileIdToClip[filteredFile.uploadedFileID!]; + if (fClip == null) continue; + final similarity = clip.vector.dot(fClip.vector); + if (similarity > 0.15) { + skipped++; + continue filesLoop; + } + } + } + filteredFiles.add(file); + } + return filteredFiles; + } else { + // TODO: lau: add logic for multiple years. Main extra thing is getting distribution over the years right + return files; + } + } + Future getRandomDateResults( BuildContext context, ) async {