Compare commits

...

1 Commits

Author SHA1 Message Date
laurenspriem
dd8b94d1a7 Move embeddings for discover section in assets 2025-08-23 17:00:19 +05:30
6 changed files with 3054 additions and 81 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,3 @@
// https://discover.ente.io/v1.json
class Prompt {
final String id;
final int position;
@@ -9,6 +8,17 @@ class Prompt {
final bool showVideo;
final bool recentFirst;
Prompt({
String? id,
this.position = 0,
required this.query,
required this.minScore,
required this.minSize,
required this.title,
this.showVideo = true,
this.recentFirst = false,
}) : id = id ?? title;
// fromJson
Prompt.fromJson(Map<String, dynamic> json)
: id = json['id'] ?? json['title'],

View File

@@ -193,22 +193,13 @@ class SemanticSearchService {
return results;
}
Future<Map<String, List<int>>> getMatchingFileIDs(
Future<Map<String, List<int>>> getMatchingFileIDsWithEmbeddings(
Map<String, List<double>> queryToEmbedding,
Map<String, double> queryToScore,
) async {
final textEmbeddings = <String, List<double>>{};
final minimumSimilarityMap = <String, double>{};
for (final entry in queryToScore.entries) {
final query = entry.key;
final score = entry.value;
final textEmbedding = await _getTextEmbedding(query);
textEmbeddings[query] = textEmbedding;
minimumSimilarityMap[query] = score;
}
final queryResults = await _getSimilarities(
textEmbeddings,
minimumSimilarityMap: minimumSimilarityMap,
queryToEmbedding,
minimumSimilarityMap: queryToScore,
);
final result = <String, List<int>>{};
for (final entry in queryResults.entries) {

View File

@@ -2,7 +2,6 @@ import "dart:async";
import "dart:convert";
import "dart:io";
import "package:computer/computer.dart";
import "package:flutter/foundation.dart";
import "package:flutter/widgets.dart";
import "package:logging/logging.dart";
@@ -21,12 +20,12 @@ import "package:photos/models/search/hierarchical/magic_filter.dart";
import "package:photos/models/search/search_types.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
import "package:photos/services/remote_assets_service.dart";
import "package:photos/services/search_service.dart";
import "package:photos/ui/viewer/search/result/magic_result_screen.dart";
import "package:photos/utils/cache_util.dart";
import "package:photos/utils/file_util.dart";
import "package:photos/utils/navigation_util.dart";
import "package:photos/utils/text_embeddings_util.dart";
import "package:shared_preferences/shared_preferences.dart";
class MagicCache {
@@ -119,8 +118,13 @@ GenericSearchResult? toGenericSearchResult(
}
if (!prompt.recentFirst) {
enteFilesInMagicCache.sort((a, b) {
return fileIdToPositionMap[a.uploadedFileID!]!
.compareTo(fileIdToPositionMap[b.uploadedFileID!]!);
final aID = a.uploadedFileID;
final bID = b.uploadedFileID;
if (aID == null || bID == null) return 0;
final aPos = fileIdToPositionMap[aID];
final bPos = fileIdToPositionMap[bID];
if (aPos == null || bPos == null) return 0;
return aPos.compareTo(bPos);
});
}
final String title = getLocalizedTitle(context, prompt.title);
@@ -168,7 +172,6 @@ GenericSearchResult? toGenericSearchResult(
class MagicCacheService {
static const _lastMagicCacheUpdateTime = "last_magic_cache_update_time";
static const _kMagicPromptsDataUrl = "https://discover.ente.io/v2.json";
/// Delay is for cache update to be done not during app init, during which a
/// lot of other things are happening.
@@ -178,7 +181,6 @@ class MagicCacheService {
late final Logger _logger = Logger((MagicCacheService).toString());
Future<List<MagicCache>>? _magicCacheFuture;
Future<List<Prompt>>? _promptFuture;
final Set<String> _pendingUpdateReason = {};
bool _isUpdateInProgress = false;
@@ -213,11 +215,7 @@ class MagicCacheService {
if (!enableDiscover) {
return;
}
final updatedJSONFile = await RemoteAssetsService.instance
.getAssetIfUpdated(_kMagicPromptsDataUrl);
if (updatedJSONFile != null) {
queueUpdate("Prompts data updated");
} else if (lastMagicCacheUpdateTime <
if (lastMagicCacheUpdateTime <
DateTime.now()
.subtract(const Duration(hours: 12))
.millisecondsSinceEpoch) {
@@ -247,10 +245,7 @@ class MagicCacheService {
_isUpdateInProgress = true;
final EnteWatch? w = kDebugMode ? EnteWatch("magicCacheWatch") : null;
w?.start();
final magicPromptsData = await getPrompts();
w?.log("loadedPrompts");
final List<MagicCache> magicCaches =
await _nonEmptyMagicResults(magicPromptsData);
final List<MagicCache> magicCaches = await _nonEmptyMagicResults();
w?.log("resultComputed");
_magicCacheFuture = Future.value(magicCaches);
await writeToJsonFile<List<MagicCache>>(
@@ -262,7 +257,6 @@ class MagicCacheService {
await _resetLastMagicCacheUpdateTime();
w?.logAndReset('done');
_pendingUpdateReason.clear();
Bus.instance.fire(MagicCacheUpdatedEvent());
} catch (e, s) {
_logger.info("Error updating magic cache", e, s);
} finally {
@@ -271,14 +265,6 @@ class MagicCacheService {
}
}
Future<List<Prompt>> getPrompts() async {
if (_promptFuture != null) {
return _promptFuture!;
}
_promptFuture = _readPromptFromDiskOrNetwork();
return _promptFuture!;
}
Future<List<MagicCache>> getMagicCache() async {
if (_magicCacheFuture != null) {
return _magicCacheFuture!;
@@ -287,17 +273,6 @@ class MagicCacheService {
return _magicCacheFuture!;
}
Future<List<Prompt>> _readPromptFromDiskOrNetwork() async {
final String path =
await RemoteAssetsService.instance.getAssetPath(_kMagicPromptsDataUrl);
return Computer.shared().compute(
_loadMagicPrompts,
param: <String, dynamic>{
"path": path,
},
);
}
Future<List<MagicCache>> _readResultFromDisk() async {
_logger.info("Reading magic cache result from disk");
final cache = await decodeJsonFile<List<MagicCache>>(
@@ -312,6 +287,7 @@ class MagicCacheService {
if (file.existsSync()) {
await file.delete();
}
_magicCacheFuture = null;
}
Future<List<GenericSearchResult>> getMagicGenericSearchResult(
@@ -322,7 +298,30 @@ class MagicCacheService {
kDebugMode ? EnteWatch("magicGenericSearchResult") : null;
w?.start();
final magicCaches = await getMagicCache();
final List<Prompt> prompts = await getPrompts();
// Load discover embeddings to get prompts
final discoverEmbeddings = await loadDiscoverEmbeddings();
if (discoverEmbeddings == null) {
_logger.severe("No discover embeddings available in assets");
throw Exception("No discover embeddings available in assets");
}
final prompts = <Prompt>[];
for (final entry in discoverEmbeddings.queryToPromptData.entries) {
final query = entry.key;
final promptData = entry.value;
prompts.add(
Prompt(
query: query,
title: promptData.title,
minScore: promptData.minScore,
minSize: promptData.minSize,
showVideo: promptData.showVideo ?? true,
recentFirst: promptData.recentFirst ?? false,
),
);
}
if (magicCaches.isEmpty) {
w?.log("No magic cache found");
return [];
@@ -375,35 +374,53 @@ class MagicCacheService {
}
}
static Future<List<Prompt>> _loadMagicPrompts(
Map<String, dynamic> args,
) async {
final String path = args["path"] as String;
final File file = File(path);
final String contents = await file.readAsString();
final Map<String, dynamic> promptsJson = jsonDecode(contents);
final List<dynamic> promptData = promptsJson['prompts'];
return promptData
.map<Prompt>((jsonItem) => Prompt.fromJson(jsonItem))
.toList();
}
///Returns non-empty magic results from magicPromptsData
///Length is number of prompts, can be less if there are not enough non-empty
///results
Future<List<MagicCache>> _nonEmptyMagicResults(
List<Prompt> magicPromptsData,
) async {
Future<List<MagicCache>> _nonEmptyMagicResults() async {
final TimeLogger t = TimeLogger();
final results = <MagicCache>[];
final List<int> matchCount = [];
final Map<String, double> queryToScore = {};
for (Prompt prompt in magicPromptsData) {
queryToScore[prompt.query] = prompt.minScore;
// Uncomment this code to generate embeddings and save them to a JSON file
// await generateAndSaveDiscoverEmbeddings();
// (from package:photos/utils/text_embeddings_util.dart)
// Load pre-computed discover embeddings from assets
final discoverEmbeddings = await loadDiscoverEmbeddings();
if (discoverEmbeddings == null) {
_logger.severe('Failed to load discover embeddings');
throw Exception('Failed to load discover embeddings');
}
final clipResults =
await SemanticSearchService.instance.getMatchingFileIDs(queryToScore);
for (Prompt prompt in magicPromptsData) {
// Build the embeddings and score maps
final Map<String, List<double>> queryToEmbedding = {};
final Map<String, double> queryToScore = {};
final List<Prompt> prompts = [];
for (final entry in discoverEmbeddings.queryToPromptData.entries) {
final query = entry.key;
final promptData = entry.value;
final vector = discoverEmbeddings.queryToVector[query];
if (vector != null) {
queryToEmbedding[query] = vector.toList();
queryToScore[query] = promptData.minScore;
prompts.add(
Prompt(
query: query,
title: promptData.title,
minScore: promptData.minScore,
minSize: promptData.minSize,
showVideo: promptData.showVideo ?? true,
recentFirst: promptData.recentFirst ?? false,
),
);
}
}
_logger.info('Using pre-computed discover embeddings from assets');
final clipResults = await SemanticSearchService.instance
.getMatchingFileIDsWithEmbeddings(queryToEmbedding, queryToScore);
for (final prompt in prompts) {
final List<int> fileUploadedIDs = clipResults[prompt.query] ?? [];
if (fileUploadedIDs.isNotEmpty) {
results.add(

View File

@@ -104,7 +104,7 @@ class SmartMemoriesService {
);
// Load pre-computed text embeddings from assets
final textEmbeddings = await loadTextEmbeddingsFromAssets();
final textEmbeddings = await loadMemoriesEmbeddings();
if (textEmbeddings == null) {
_logger.severe('Failed to load pre-computed text embeddings');
throw Exception(

View File

@@ -9,13 +9,14 @@ import "package:path_provider/path_provider.dart"
show getExternalStorageDirectory;
import 'package:photos/models/memories/clip_memory.dart';
import 'package:photos/models/memories/people_memory.dart';
import 'package:photos/models/ml/discover/prompt.dart';
import "package:photos/services/machine_learning/ml_computer.dart"
show MLComputer;
final _logger = Logger('TextEmbeddingsUtil');
/// Loads pre-computed text embeddings from assets
Future<TextEmbeddings?> loadTextEmbeddingsFromAssets() async {
/// Loads pre-computed memories text embeddings from assets
Future<MemoriesEmbeddings?> loadMemoriesEmbeddings() async {
try {
_logger.info('Loading text embeddings from assets');
final jsonString =
@@ -81,7 +82,7 @@ Future<TextEmbeddings?> loadTextEmbeddingsFromAssets() async {
}
_logger.info('Text embeddings loaded successfully from JSON assets');
return TextEmbeddings(
return MemoriesEmbeddings(
clipPositiveVector: clipPositiveVector,
peopleActivityVectors: peopleActivityVectors,
clipMemoryTypeVectors: clipMemoryTypeVectors,
@@ -92,18 +93,94 @@ Future<TextEmbeddings?> loadTextEmbeddingsFromAssets() async {
}
}
class TextEmbeddings {
/// Container for all memories text embeddings
class MemoriesEmbeddings {
final Vector clipPositiveVector;
final Map<PeopleActivity, Vector> peopleActivityVectors;
final Map<ClipMemoryType, Vector> clipMemoryTypeVectors;
const TextEmbeddings({
const MemoriesEmbeddings({
required this.clipPositiveVector,
required this.peopleActivityVectors,
required this.clipMemoryTypeVectors,
});
}
/// Loads pre-computed discover text embeddings from assets
Future<DiscoverEmbeddings?> loadDiscoverEmbeddings() async {
try {
_logger.info('Loading discover embeddings from assets');
final jsonString =
await rootBundle.loadString('assets/ml/discover_embeddings.json');
final data = json.decode(jsonString) as Map<String, dynamic>;
final promptsList = data['prompts'] as List;
final Map<String, Vector> queryToVector = {};
final Map<String, PromptData> queryToPromptData = {};
// Parse all discover embeddings
for (final promptJson in promptsList) {
final query = promptJson['query'] as String;
final vector = (promptJson['vector'] as List).cast<double>();
if (vector.isNotEmpty) {
queryToVector[query] = Vector.fromList(vector);
queryToPromptData[query] = PromptData(
title: promptJson['title'] as String,
minScore: (promptJson['minScore'] as num).toDouble(),
minSize: (promptJson['minSize'] as num).toDouble(),
showVideo: promptJson['showVideo'] as bool?,
recentFirst: promptJson['recentFirst'] as bool?,
);
}
}
// Check if we have all embeddings
if (queryToVector.isEmpty) {
_logger.warning('No discover embeddings found');
return null;
}
_logger.info(
'Discover embeddings loaded successfully with ${queryToVector.length} queries',
);
return DiscoverEmbeddings(
queryToVector: queryToVector,
queryToPromptData: queryToPromptData,
);
} catch (e, stackTrace) {
_logger.severe('Failed to load discover embeddings', e, stackTrace);
return null;
}
}
/// Container for discover/magic embeddings
class DiscoverEmbeddings {
final Map<String, Vector> queryToVector;
final Map<String, PromptData> queryToPromptData;
const DiscoverEmbeddings({
required this.queryToVector,
required this.queryToPromptData,
});
}
/// Container for prompt metadata
class PromptData {
final String title;
final double minScore;
final double minSize;
final bool? showVideo;
final bool? recentFirst;
const PromptData({
required this.title,
required this.minScore,
required this.minSize,
this.showVideo,
this.recentFirst,
});
}
/// Helper function to generate text embeddings and save them to a JSON file
/// Run this once to generate the embeddings, then copy the output
/// to assets/ml/text_embeddings.json
@@ -170,3 +247,161 @@ Future<void> generateAndSaveTextEmbeddings() async {
'_generateAndSaveTextEmbeddings: Text embeddings generation complete! Copy the JSON output above to assets/ml/text_embeddings.json',
);
}
/// Helper function to generate discover text embeddings and save them to a JSON file
/// Run this once to generate the embeddings, then copy the output
/// to assets/ml/discover_embeddings.json
Future<void> generateAndSaveDiscoverEmbeddings() async {
// Hardcoded prompts from v2.json
final magicPromptsData = [
Prompt(
query: "identity documents and cards",
title: "Identity",
minScore: 0.20,
minSize: 0.0,
showVideo: false,
),
Prompt(
query: "screenshots",
title: "Screenshots",
minScore: 0.18,
minSize: 0.0,
recentFirst: true,
showVideo: false,
),
Prompt(
query: "receipts, bills and invoices",
title: "Receipts",
minScore: 0.20,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "screenshots of digital notes, photos of notebook and whiteboards",
title: "Notes",
minScore: 0.20,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "memes",
title: "Memes",
minScore: 0.18,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "business card",
title: "Visiting Cards",
minScore: 0.20,
minSize: 0.0,
recentFirst: true,
showVideo: false,
),
Prompt(
query: "photo of toddler",
title: "Babies",
minScore: 0.20,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "photos of pets",
title: "Pets",
minScore: 0.20,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "selfie",
title: "Selfies",
minScore: 0.18,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "high resolution hd wallpaper",
title: "Wallpapers",
minScore: 0.18,
minSize: 0.0,
showVideo: false,
),
Prompt(
query: "photo of food",
title: "Food",
minScore: 0.20,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "celebration with friends",
title: "Celebrations",
minScore: 0.20,
minSize: 0.0,
recentFirst: true,
),
Prompt(
query: "photo of setting sun",
title: "Sunset",
minScore: 0.20,
minSize: 0.0,
),
Prompt(
query: "photo of mountains or hills",
title: "Hills",
minScore: 0.20,
minSize: 0.0,
),
Prompt(
query: "outdoor photo of trees, leaves, flowers",
title: "Greenery",
minScore: 0.20,
minSize: 0.0,
),
];
final Map<String, dynamic> embeddingsData = {
'version': '1.0.0',
'prompts': [],
};
final prompts = <Map<String, dynamic>>[];
for (final prompt in magicPromptsData) {
_logger.info('Generating embedding for discover prompt: ${prompt.title}');
final vector = await MLComputer.instance.runClipText(prompt.query);
final promptData = {
'query': prompt.query,
'title': prompt.title,
'minScore': prompt.minScore,
'minSize': prompt.minSize,
'vector': vector,
};
// Add optional fields
promptData['showVideo'] = prompt.showVideo;
promptData['recentFirst'] = prompt.recentFirst;
prompts.add(promptData);
}
embeddingsData['prompts'] = prompts;
// Convert to JSON and log it
final jsonString = const JsonEncoder.withIndent(' ').convert(embeddingsData);
dev.log(
'_generateAndSaveDiscoverEmbeddings: Generated discover embeddings JSON',
);
final tempDir = await getExternalStorageDirectory();
final file = File('${tempDir!.path}/discover_embeddings.json');
await file.writeAsString(jsonString);
dev.log(
'_generateAndSaveDiscoverEmbeddings: Saved discover embeddings to ${file.path}',
);
dev.log(
'_generateAndSaveDiscoverEmbeddings: Discover embeddings generation complete! Copy the JSON output above to assets/ml/discover_embeddings.json',
);
}