diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index f180f2bfff..cba068c51b 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -37,6 +37,7 @@ import "package:photos/services/machine_learning/face_ml/person/person_service.d import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart'; import "package:photos/services/machine_learning/machine_learning_controller.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; +import "package:photos/services/magic_cache_service.dart"; import 'package:photos/services/memories_service.dart'; import 'package:photos/services/push_service.dart'; import 'package:photos/services/remote_sync_service.dart'; @@ -303,6 +304,8 @@ Future _init(bool isBackground, {String via = ''}) async { preferences, ); + MagicCacheService.instance.init(preferences); + initComplete = true; _logger.info("Initialization done"); } catch (e, s) { diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index d65c67aba3..1de02434e3 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -277,6 +277,49 @@ class SemanticSearchService { return results; } + Future> getMatchingFileIDs(String query, double minScore) async { + final textEmbedding = await _getTextEmbedding(query); + + final queryResults = + await _getScores(textEmbedding, scoreThreshold: minScore); + + final queryResultIds = []; + for (QueryResult result in queryResults) { + queryResultIds.add(result.id); + } + + final filesMap = await FilesDB.instance.getFilesFromIDs( + queryResultIds, + ); + final results = []; + + final ignoredCollections = + CollectionsService.instance.getHiddenCollectionIds(); + final deletedEntries = []; + for (final result in queryResults) { + final file = filesMap[result.id]; + if (file != null && !ignoredCollections.contains(file.collectionID)) { + results.add(file); + } + if (file == null) { + deletedEntries.add(result.id); + } + } + + _logger.info(results.length.toString() + " results"); + + if (deletedEntries.isNotEmpty) { + unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries)); + } + + final matchingFileIDs = []; + for (EnteFile file in results) { + matchingFileIDs.add(file.uploadedFileID!); + } + + return matchingFileIDs; + } + void _addToQueue(EnteFile file) { if (!LocalSettings.instance.hasEnabledMagicSearch()) { return; diff --git a/mobile/lib/services/magic_cache_service.dart b/mobile/lib/services/magic_cache_service.dart new file mode 100644 index 0000000000..e02f07b77a --- /dev/null +++ b/mobile/lib/services/magic_cache_service.dart @@ -0,0 +1,142 @@ +import "dart:convert"; +import 'dart:math'; + +import "package:logging/logging.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/search/generic_search_result.dart"; +import "package:photos/models/search/search_types.dart"; +import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; +import "package:photos/services/remote_assets_service.dart"; +import "package:photos/services/search_service.dart"; +import "package:shared_preferences/shared_preferences.dart"; + +class MagicCache { + final String title; + final List fileUploadedIDs; + MagicCache(this.title, this.fileUploadedIDs); + + factory MagicCache.fromJson(Map json) { + return MagicCache( + json['title'], + List.from(json['fileUploadedIDs']), + ); + } + + Map toJson() { + return { + 'title': title, + 'fileUploadedIDs': fileUploadedIDs, + }; + } + + static String encodeListToJson(List magicCaches) { + final jsonList = magicCaches.map((cache) => cache.toJson()).toList(); + return jsonEncode(jsonList); + } + + static List decodeJsonToList(String jsonString) { + final jsonList = jsonDecode(jsonString) as List; + return jsonList.map((json) => MagicCache.fromJson(json)).toList(); + } +} + +extension MagicCacheServiceExtension on MagicCache { + Future toGenericSearchResult() async { + final allEnteFiles = await SearchService.instance.getAllFiles(); + final enteFilesInMagicCache = []; + for (EnteFile file in allEnteFiles) { + if (file.uploadedFileID != null && + fileUploadedIDs.contains(file.uploadedFileID as int)) { + enteFilesInMagicCache.add(file); + } + } + return GenericSearchResult( + ResultType.magic, + title, + enteFilesInMagicCache, + ); + } +} + +class MagicCacheService { + static const _key = "magic_cache"; + static const _kMagicPromptsDataUrl = "https://discover.ente.io/v1.json"; + + late SharedPreferences prefs; + final Logger _logger = Logger((MagicCacheService).toString()); + MagicCacheService._privateConstructor(); + + static final MagicCacheService instance = + MagicCacheService._privateConstructor(); + + void init(SharedPreferences preferences) { + prefs = preferences; + _updateCacheIfNeeded(); + } + + Future _updateCacheIfNeeded() async { + final jsonFile = await RemoteAssetsService.instance + .getAssetIfUpdated(_kMagicPromptsDataUrl); + if (jsonFile == null) {} + } + + Future>> getMatchingFileIDsForPromptData( + Map promptData, + ) async { + final result = await SemanticSearchService.instance.getMatchingFileIDs( + promptData["prompt"] as String, + promptData["minimumScore"] as double, + ); + + return {promptData["title"] as String: result}; + } + + Future updateMagicCache() async { + final magicCaches = []; + await prefs.setString( + _key, + MagicCache.encodeListToJson(magicCaches), + ); + } + + Future?> getMagicCache() async { + final jsonString = prefs.getString(_key); + if (jsonString == null) { + _logger.info("No $_key in shared preferences"); + return null; + } + return MagicCache.decodeJsonToList(jsonString); + } + + Future clearMagicCache() async { + await prefs.remove(_key); + } + + Future> getMagicGenericSearchResult() async { + final magicCaches = await getMagicCache(); + if (magicCaches == null) { + _logger.info("No magic cache found"); + return []; + } + final List genericSearchResults = []; + for (MagicCache magicCache in magicCaches) { + final genericSearchResult = await magicCache.toGenericSearchResult(); + genericSearchResults.add(genericSearchResult); + } + return genericSearchResults; + } + + ///Generates from 0 to max non-repeating random numbers + List _generateUniqueRandomNumbers(int max, int count) { + final numbers = []; + for (int i = 1; i <= count;) { + final randomNumber = Random().nextInt(max + 1); + if (numbers.contains(randomNumber)) { + continue; + } + numbers.add(randomNumber); + i++; + } + return numbers; + } +}