[mob][photos] First iteration of file selection

This commit is contained in:
laurenspriem
2025-01-28 10:43:58 +05:30
parent a575e5f3c8
commit e50ab7ca84

View File

@@ -4,12 +4,14 @@ import "package:flutter/cupertino.dart";
import "package:flutter/material.dart";
import "package:intl/intl.dart";
import 'package:logging/logging.dart';
import "package:ml_linalg/linalg.dart";
import "package:photos/core/constants.dart";
import 'package:photos/core/event_bus.dart';
import 'package:photos/data/holidays.dart';
import 'package:photos/data/months.dart';
import 'package:photos/data/years.dart';
import 'package:photos/db/files_db.dart';
import "package:photos/db/ml/clip_db.dart";
import "package:photos/db/ml/db.dart";
import 'package:photos/events/local_photos_updated_event.dart';
import "package:photos/extensions/user_extension.dart";
@@ -40,6 +42,7 @@ import "package:photos/services/filter/db_filters.dart";
import "package:photos/services/location_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/machine_learning/ml_computer.dart";
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/services/user_remote_flag_service.dart";
import "package:photos/states/location_screen_state.dart";
@@ -1233,17 +1236,18 @@ class SearchService {
final date =
DateTime(currentTime.year, dayMonth ~/ 100, dayMonth % 100);
final allPhotos = yearGroups.values.expand((x) => x).toList();
final photoSelection = await _bestSelection(allPhotos);
searchResults.add(
GenericSearchResult(
ResultType.event,
'Memories of ${DateFormat('MMMM d').format(date)}',
allPhotos,
photoSelection,
hierarchicalSearchFilter: TopLevelGenericFilter(
filterName: DateFormat('MMMM d').format(date),
occurrence: kMostRelevantFilter,
filterResultType: ResultType.event,
matchedUploadedIDs: filesToUploadedFileIDs(allPhotos),
matchedUploadedIDs: filesToUploadedFileIDs(photoSelection),
filterIcon: Icons.event_outlined,
),
),
@@ -1253,6 +1257,7 @@ class SearchService {
for (final year in significantYears) {
final date = DateTime(year, dayMonth ~/ 100, dayMonth % 100);
final files = yearGroups[year]!;
final photoSelection = await _bestSelection(files);
String name =
DateFormat.yMMMd(Localizations.localeOf(context).languageCode)
.format(date);
@@ -1264,12 +1269,12 @@ class SearchService {
GenericSearchResult(
ResultType.event,
name,
files,
photoSelection,
hierarchicalSearchFilter: TopLevelGenericFilter(
filterName: name,
occurrence: kMostRelevantFilter,
filterResultType: ResultType.event,
matchedUploadedIDs: filesToUploadedFileIDs(files),
matchedUploadedIDs: filesToUploadedFileIDs(photoSelection),
filterIcon: Icons.event_outlined,
),
),
@@ -1283,6 +1288,83 @@ class SearchService {
return searchResults;
}
/// Returns the best selection of files from the given list.
///
/// Lists of at most 10 files are returned as-is. For larger lists whose files
/// were all created in the same year, at most 10 files are picked: the files
/// are ranked by the number of faces stored for them (descending), with ties
/// broken by the dot product of their CLIP embedding with the embedding of a
/// fixed text prompt (descending). Files whose CLIP embedding is too similar
/// to an already picked file are then skipped, but only while more than 10
/// non-skipped candidates remain.
///
/// Lists spanning multiple years are currently returned unchanged (and may
/// therefore contain more than 10 files); see the TODO in the body.
///
/// The given [files] list is never reordered or otherwise modified. Files
/// without an uploaded file ID are treated as having no faces and no CLIP
/// embedding instead of causing a null-assertion failure.
Future<List<EnteFile>> _bestSelection(List<EnteFile> files) async {
  const maxSelection = 10;
  // Two files whose CLIP embeddings have a dot product above this value are
  // considered too similar to both be part of the selection.
  const similarityThreshold = 0.15;
  if (files.length <= maxSelection) return files;

  // Check the years before touching the ML database so that the (not yet
  // implemented) multi-year case does not pay for queries it never uses.
  final allYears = files
      .map((e) => DateTime.fromMicrosecondsSinceEpoch(e.creationTime!).year)
      .toSet();
  if (allYears.length > 1) {
    // TODO: lau: add logic for multiple years. Main extra thing is getting distribution over the years right
    return files;
  }

  final fileIDs = files.map((e) => e.uploadedFileID).whereType<int>().toSet();
  final fileIdToFace = await MLDataDB.instance.getFacesForFileIDs(fileIDs);
  final fileIdToClip =
      await MLDataDB.instance.getClipVectorsForFileIDs(fileIDs);

  // TODO: lau: come up with a better prompt
  const query = 'Beautiful photo of a nice memory';
  // TODO: lau: optimize this later so we don't keep computing embedding
  final textEmbedding = await MLComputer.instance.runClipText(query);
  final textVector = Vector.fromList(textEmbedding);

  // Missing embeddings score 0 and missing face entries count as 0 faces.
  final fileToScore = <int, double>{};
  final fileToFaceCount = <int, int>{};
  for (final id in fileIDs) {
    fileToScore[id] = fileIdToClip[id]?.vector.dot(textVector) ?? 0.0;
    fileToFaceCount[id] = fileIdToFace[id]?.length ?? 0;
  }

  // Rank with a single compound comparator. List.sort is not guaranteed to be
  // stable, so sorting on CLIP score and then re-sorting on face count could
  // lose the CLIP ordering among files with an equal number of faces.
  // Sorting a copy keeps the caller's list untouched.
  final ranked = List<EnteFile>.of(files)
    ..sort((a, b) {
      final byFaces = (fileToFaceCount[b.uploadedFileID] ?? 0)
          .compareTo(fileToFaceCount[a.uploadedFileID] ?? 0);
      if (byFaces != 0) return byFaces;
      return (fileToScore[b.uploadedFileID] ?? 0.0)
          .compareTo(fileToScore[a.uploadedFileID] ?? 0.0);
    });

  // Then filter out similar images as much as possible.
  final selection = <EnteFile>[ranked.first];
  var skipped = 0;
  for (final candidate in ranked.skip(1)) {
    if (selection.length >= maxSelection) break;
    final clip = fileIdToClip[candidate.uploadedFileID];
    // Only drop look-alikes while more than maxSelection candidates are left,
    // so that filtering never shrinks the result below the target size.
    final canSkip = ranked.length - skipped > maxSelection;
    final isTooSimilar = clip != null &&
        canSkip &&
        selection.any((picked) {
          final pickedClip = fileIdToClip[picked.uploadedFileID];
          return pickedClip != null &&
              clip.vector.dot(pickedClip.vector) > similarityThreshold;
        });
    if (isTooSimilar) {
      skipped++;
      continue;
    }
    selection.add(candidate);
  }
  return selection;
}
Future<GenericSearchResult?> getRandomDateResults(
BuildContext context,
) async {