Compare commits
1 Commits
fixReqVal
...
discover_e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd8b94d1a7 |
2720
mobile/apps/photos/assets/ml/discover_embeddings.json
Normal file
2720
mobile/apps/photos/assets/ml/discover_embeddings.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,3 @@
|
|||||||
// https://discover.ente.io/v1.json
|
|
||||||
class Prompt {
|
class Prompt {
|
||||||
final String id;
|
final String id;
|
||||||
final int position;
|
final int position;
|
||||||
@@ -9,6 +8,17 @@ class Prompt {
|
|||||||
final bool showVideo;
|
final bool showVideo;
|
||||||
final bool recentFirst;
|
final bool recentFirst;
|
||||||
|
|
||||||
|
/// Creates a prompt from explicit values.
///
/// When [id] is omitted or null, [title] is used as the id, which matches
/// the `json['id'] ?? json['title']` fallback applied by [Prompt.fromJson].
/// [position] defaults to 0, [showVideo] to true and [recentFirst] to false.
Prompt({
  String? id,
  this.position = 0,
  required this.query,
  required this.minScore,
  required this.minSize,
  required this.title,
  this.showVideo = true,
  this.recentFirst = false,
}) : id = id ?? title;
|
||||||
|
|
||||||
// fromJson
|
// fromJson
|
||||||
Prompt.fromJson(Map<String, dynamic> json)
|
Prompt.fromJson(Map<String, dynamic> json)
|
||||||
: id = json['id'] ?? json['title'],
|
: id = json['id'] ?? json['title'],
|
||||||
|
|||||||
@@ -193,22 +193,13 @@ class SemanticSearchService {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<Map<String, List<int>>> getMatchingFileIDs(
|
Future<Map<String, List<int>>> getMatchingFileIDsWithEmbeddings(
|
||||||
|
Map<String, List<double>> queryToEmbedding,
|
||||||
Map<String, double> queryToScore,
|
Map<String, double> queryToScore,
|
||||||
) async {
|
) async {
|
||||||
final textEmbeddings = <String, List<double>>{};
|
|
||||||
final minimumSimilarityMap = <String, double>{};
|
|
||||||
for (final entry in queryToScore.entries) {
|
|
||||||
final query = entry.key;
|
|
||||||
final score = entry.value;
|
|
||||||
final textEmbedding = await _getTextEmbedding(query);
|
|
||||||
textEmbeddings[query] = textEmbedding;
|
|
||||||
minimumSimilarityMap[query] = score;
|
|
||||||
}
|
|
||||||
|
|
||||||
final queryResults = await _getSimilarities(
|
final queryResults = await _getSimilarities(
|
||||||
textEmbeddings,
|
queryToEmbedding,
|
||||||
minimumSimilarityMap: minimumSimilarityMap,
|
minimumSimilarityMap: queryToScore,
|
||||||
);
|
);
|
||||||
final result = <String, List<int>>{};
|
final result = <String, List<int>>{};
|
||||||
for (final entry in queryResults.entries) {
|
for (final entry in queryResults.entries) {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import "dart:async";
|
|||||||
import "dart:convert";
|
import "dart:convert";
|
||||||
import "dart:io";
|
import "dart:io";
|
||||||
|
|
||||||
import "package:computer/computer.dart";
|
|
||||||
import "package:flutter/foundation.dart";
|
import "package:flutter/foundation.dart";
|
||||||
import "package:flutter/widgets.dart";
|
import "package:flutter/widgets.dart";
|
||||||
import "package:logging/logging.dart";
|
import "package:logging/logging.dart";
|
||||||
@@ -21,12 +20,12 @@ import "package:photos/models/search/hierarchical/magic_filter.dart";
|
|||||||
import "package:photos/models/search/search_types.dart";
|
import "package:photos/models/search/search_types.dart";
|
||||||
import "package:photos/service_locator.dart";
|
import "package:photos/service_locator.dart";
|
||||||
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
|
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
|
||||||
import "package:photos/services/remote_assets_service.dart";
|
|
||||||
import "package:photos/services/search_service.dart";
|
import "package:photos/services/search_service.dart";
|
||||||
import "package:photos/ui/viewer/search/result/magic_result_screen.dart";
|
import "package:photos/ui/viewer/search/result/magic_result_screen.dart";
|
||||||
import "package:photos/utils/cache_util.dart";
|
import "package:photos/utils/cache_util.dart";
|
||||||
import "package:photos/utils/file_util.dart";
|
import "package:photos/utils/file_util.dart";
|
||||||
import "package:photos/utils/navigation_util.dart";
|
import "package:photos/utils/navigation_util.dart";
|
||||||
|
import "package:photos/utils/text_embeddings_util.dart";
|
||||||
import "package:shared_preferences/shared_preferences.dart";
|
import "package:shared_preferences/shared_preferences.dart";
|
||||||
|
|
||||||
class MagicCache {
|
class MagicCache {
|
||||||
@@ -119,8 +118,13 @@ GenericSearchResult? toGenericSearchResult(
|
|||||||
}
|
}
|
||||||
if (!prompt.recentFirst) {
|
if (!prompt.recentFirst) {
|
||||||
enteFilesInMagicCache.sort((a, b) {
|
enteFilesInMagicCache.sort((a, b) {
|
||||||
return fileIdToPositionMap[a.uploadedFileID!]!
|
final aID = a.uploadedFileID;
|
||||||
.compareTo(fileIdToPositionMap[b.uploadedFileID!]!);
|
final bID = b.uploadedFileID;
|
||||||
|
if (aID == null || bID == null) return 0;
|
||||||
|
final aPos = fileIdToPositionMap[aID];
|
||||||
|
final bPos = fileIdToPositionMap[bID];
|
||||||
|
if (aPos == null || bPos == null) return 0;
|
||||||
|
return aPos.compareTo(bPos);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
final String title = getLocalizedTitle(context, prompt.title);
|
final String title = getLocalizedTitle(context, prompt.title);
|
||||||
@@ -168,7 +172,6 @@ GenericSearchResult? toGenericSearchResult(
|
|||||||
|
|
||||||
class MagicCacheService {
|
class MagicCacheService {
|
||||||
static const _lastMagicCacheUpdateTime = "last_magic_cache_update_time";
|
static const _lastMagicCacheUpdateTime = "last_magic_cache_update_time";
|
||||||
static const _kMagicPromptsDataUrl = "https://discover.ente.io/v2.json";
|
|
||||||
|
|
||||||
/// Delay is for cache update to be done not during app init, during which a
|
/// Delay is for cache update to be done not during app init, during which a
|
||||||
/// lot of other things are happening.
|
/// lot of other things are happening.
|
||||||
@@ -178,7 +181,6 @@ class MagicCacheService {
|
|||||||
late final Logger _logger = Logger((MagicCacheService).toString());
|
late final Logger _logger = Logger((MagicCacheService).toString());
|
||||||
|
|
||||||
Future<List<MagicCache>>? _magicCacheFuture;
|
Future<List<MagicCache>>? _magicCacheFuture;
|
||||||
Future<List<Prompt>>? _promptFuture;
|
|
||||||
final Set<String> _pendingUpdateReason = {};
|
final Set<String> _pendingUpdateReason = {};
|
||||||
bool _isUpdateInProgress = false;
|
bool _isUpdateInProgress = false;
|
||||||
|
|
||||||
@@ -213,11 +215,7 @@ class MagicCacheService {
|
|||||||
if (!enableDiscover) {
|
if (!enableDiscover) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
final updatedJSONFile = await RemoteAssetsService.instance
|
if (lastMagicCacheUpdateTime <
|
||||||
.getAssetIfUpdated(_kMagicPromptsDataUrl);
|
|
||||||
if (updatedJSONFile != null) {
|
|
||||||
queueUpdate("Prompts data updated");
|
|
||||||
} else if (lastMagicCacheUpdateTime <
|
|
||||||
DateTime.now()
|
DateTime.now()
|
||||||
.subtract(const Duration(hours: 12))
|
.subtract(const Duration(hours: 12))
|
||||||
.millisecondsSinceEpoch) {
|
.millisecondsSinceEpoch) {
|
||||||
@@ -247,10 +245,7 @@ class MagicCacheService {
|
|||||||
_isUpdateInProgress = true;
|
_isUpdateInProgress = true;
|
||||||
final EnteWatch? w = kDebugMode ? EnteWatch("magicCacheWatch") : null;
|
final EnteWatch? w = kDebugMode ? EnteWatch("magicCacheWatch") : null;
|
||||||
w?.start();
|
w?.start();
|
||||||
final magicPromptsData = await getPrompts();
|
final List<MagicCache> magicCaches = await _nonEmptyMagicResults();
|
||||||
w?.log("loadedPrompts");
|
|
||||||
final List<MagicCache> magicCaches =
|
|
||||||
await _nonEmptyMagicResults(magicPromptsData);
|
|
||||||
w?.log("resultComputed");
|
w?.log("resultComputed");
|
||||||
_magicCacheFuture = Future.value(magicCaches);
|
_magicCacheFuture = Future.value(magicCaches);
|
||||||
await writeToJsonFile<List<MagicCache>>(
|
await writeToJsonFile<List<MagicCache>>(
|
||||||
@@ -262,7 +257,6 @@ class MagicCacheService {
|
|||||||
await _resetLastMagicCacheUpdateTime();
|
await _resetLastMagicCacheUpdateTime();
|
||||||
w?.logAndReset('done');
|
w?.logAndReset('done');
|
||||||
_pendingUpdateReason.clear();
|
_pendingUpdateReason.clear();
|
||||||
Bus.instance.fire(MagicCacheUpdatedEvent());
|
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
_logger.info("Error updating magic cache", e, s);
|
_logger.info("Error updating magic cache", e, s);
|
||||||
} finally {
|
} finally {
|
||||||
@@ -271,14 +265,6 @@ class MagicCacheService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<Prompt>> getPrompts() async {
|
|
||||||
if (_promptFuture != null) {
|
|
||||||
return _promptFuture!;
|
|
||||||
}
|
|
||||||
_promptFuture = _readPromptFromDiskOrNetwork();
|
|
||||||
return _promptFuture!;
|
|
||||||
}
|
|
||||||
|
|
||||||
Future<List<MagicCache>> getMagicCache() async {
|
Future<List<MagicCache>> getMagicCache() async {
|
||||||
if (_magicCacheFuture != null) {
|
if (_magicCacheFuture != null) {
|
||||||
return _magicCacheFuture!;
|
return _magicCacheFuture!;
|
||||||
@@ -287,17 +273,6 @@ class MagicCacheService {
|
|||||||
return _magicCacheFuture!;
|
return _magicCacheFuture!;
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<Prompt>> _readPromptFromDiskOrNetwork() async {
|
|
||||||
final String path =
|
|
||||||
await RemoteAssetsService.instance.getAssetPath(_kMagicPromptsDataUrl);
|
|
||||||
return Computer.shared().compute(
|
|
||||||
_loadMagicPrompts,
|
|
||||||
param: <String, dynamic>{
|
|
||||||
"path": path,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
Future<List<MagicCache>> _readResultFromDisk() async {
|
Future<List<MagicCache>> _readResultFromDisk() async {
|
||||||
_logger.info("Reading magic cache result from disk");
|
_logger.info("Reading magic cache result from disk");
|
||||||
final cache = await decodeJsonFile<List<MagicCache>>(
|
final cache = await decodeJsonFile<List<MagicCache>>(
|
||||||
@@ -312,6 +287,7 @@ class MagicCacheService {
|
|||||||
if (file.existsSync()) {
|
if (file.existsSync()) {
|
||||||
await file.delete();
|
await file.delete();
|
||||||
}
|
}
|
||||||
|
_magicCacheFuture = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<GenericSearchResult>> getMagicGenericSearchResult(
|
Future<List<GenericSearchResult>> getMagicGenericSearchResult(
|
||||||
@@ -322,7 +298,30 @@ class MagicCacheService {
|
|||||||
kDebugMode ? EnteWatch("magicGenericSearchResult") : null;
|
kDebugMode ? EnteWatch("magicGenericSearchResult") : null;
|
||||||
w?.start();
|
w?.start();
|
||||||
final magicCaches = await getMagicCache();
|
final magicCaches = await getMagicCache();
|
||||||
final List<Prompt> prompts = await getPrompts();
|
|
||||||
|
// Load discover embeddings to get prompts
|
||||||
|
final discoverEmbeddings = await loadDiscoverEmbeddings();
|
||||||
|
if (discoverEmbeddings == null) {
|
||||||
|
_logger.severe("No discover embeddings available in assets");
|
||||||
|
throw Exception("No discover embeddings available in assets");
|
||||||
|
}
|
||||||
|
|
||||||
|
final prompts = <Prompt>[];
|
||||||
|
for (final entry in discoverEmbeddings.queryToPromptData.entries) {
|
||||||
|
final query = entry.key;
|
||||||
|
final promptData = entry.value;
|
||||||
|
prompts.add(
|
||||||
|
Prompt(
|
||||||
|
query: query,
|
||||||
|
title: promptData.title,
|
||||||
|
minScore: promptData.minScore,
|
||||||
|
minSize: promptData.minSize,
|
||||||
|
showVideo: promptData.showVideo ?? true,
|
||||||
|
recentFirst: promptData.recentFirst ?? false,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (magicCaches.isEmpty) {
|
if (magicCaches.isEmpty) {
|
||||||
w?.log("No magic cache found");
|
w?.log("No magic cache found");
|
||||||
return [];
|
return [];
|
||||||
@@ -375,35 +374,53 @@ class MagicCacheService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static Future<List<Prompt>> _loadMagicPrompts(
|
Future<List<MagicCache>> _nonEmptyMagicResults() async {
|
||||||
Map<String, dynamic> args,
|
|
||||||
) async {
|
|
||||||
final String path = args["path"] as String;
|
|
||||||
final File file = File(path);
|
|
||||||
final String contents = await file.readAsString();
|
|
||||||
final Map<String, dynamic> promptsJson = jsonDecode(contents);
|
|
||||||
final List<dynamic> promptData = promptsJson['prompts'];
|
|
||||||
return promptData
|
|
||||||
.map<Prompt>((jsonItem) => Prompt.fromJson(jsonItem))
|
|
||||||
.toList();
|
|
||||||
}
|
|
||||||
|
|
||||||
///Returns non-empty magic results from magicPromptsData
|
|
||||||
///Length is number of prompts, can be less if there are not enough non-empty
|
|
||||||
///results
|
|
||||||
Future<List<MagicCache>> _nonEmptyMagicResults(
|
|
||||||
List<Prompt> magicPromptsData,
|
|
||||||
) async {
|
|
||||||
final TimeLogger t = TimeLogger();
|
final TimeLogger t = TimeLogger();
|
||||||
final results = <MagicCache>[];
|
final results = <MagicCache>[];
|
||||||
final List<int> matchCount = [];
|
final List<int> matchCount = [];
|
||||||
final Map<String, double> queryToScore = {};
|
|
||||||
for (Prompt prompt in magicPromptsData) {
|
// Uncomment this code to generate embeddings and save them to a JSON file
|
||||||
queryToScore[prompt.query] = prompt.minScore;
|
// await generateAndSaveDiscoverEmbeddings();
|
||||||
|
// (from package:photos/utils/text_embeddings_util.dart)
|
||||||
|
|
||||||
|
// Load pre-computed discover embeddings from assets
|
||||||
|
final discoverEmbeddings = await loadDiscoverEmbeddings();
|
||||||
|
if (discoverEmbeddings == null) {
|
||||||
|
_logger.severe('Failed to load discover embeddings');
|
||||||
|
throw Exception('Failed to load discover embeddings');
|
||||||
}
|
}
|
||||||
final clipResults =
|
|
||||||
await SemanticSearchService.instance.getMatchingFileIDs(queryToScore);
|
// Build the embeddings and score maps
|
||||||
for (Prompt prompt in magicPromptsData) {
|
final Map<String, List<double>> queryToEmbedding = {};
|
||||||
|
final Map<String, double> queryToScore = {};
|
||||||
|
final List<Prompt> prompts = [];
|
||||||
|
|
||||||
|
for (final entry in discoverEmbeddings.queryToPromptData.entries) {
|
||||||
|
final query = entry.key;
|
||||||
|
final promptData = entry.value;
|
||||||
|
final vector = discoverEmbeddings.queryToVector[query];
|
||||||
|
|
||||||
|
if (vector != null) {
|
||||||
|
queryToEmbedding[query] = vector.toList();
|
||||||
|
queryToScore[query] = promptData.minScore;
|
||||||
|
prompts.add(
|
||||||
|
Prompt(
|
||||||
|
query: query,
|
||||||
|
title: promptData.title,
|
||||||
|
minScore: promptData.minScore,
|
||||||
|
minSize: promptData.minSize,
|
||||||
|
showVideo: promptData.showVideo ?? true,
|
||||||
|
recentFirst: promptData.recentFirst ?? false,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_logger.info('Using pre-computed discover embeddings from assets');
|
||||||
|
final clipResults = await SemanticSearchService.instance
|
||||||
|
.getMatchingFileIDsWithEmbeddings(queryToEmbedding, queryToScore);
|
||||||
|
|
||||||
|
for (final prompt in prompts) {
|
||||||
final List<int> fileUploadedIDs = clipResults[prompt.query] ?? [];
|
final List<int> fileUploadedIDs = clipResults[prompt.query] ?? [];
|
||||||
if (fileUploadedIDs.isNotEmpty) {
|
if (fileUploadedIDs.isNotEmpty) {
|
||||||
results.add(
|
results.add(
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ class SmartMemoriesService {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Load pre-computed text embeddings from assets
|
// Load pre-computed text embeddings from assets
|
||||||
final textEmbeddings = await loadTextEmbeddingsFromAssets();
|
final textEmbeddings = await loadMemoriesEmbeddings();
|
||||||
if (textEmbeddings == null) {
|
if (textEmbeddings == null) {
|
||||||
_logger.severe('Failed to load pre-computed text embeddings');
|
_logger.severe('Failed to load pre-computed text embeddings');
|
||||||
throw Exception(
|
throw Exception(
|
||||||
|
|||||||
@@ -9,13 +9,14 @@ import "package:path_provider/path_provider.dart"
|
|||||||
show getExternalStorageDirectory;
|
show getExternalStorageDirectory;
|
||||||
import 'package:photos/models/memories/clip_memory.dart';
|
import 'package:photos/models/memories/clip_memory.dart';
|
||||||
import 'package:photos/models/memories/people_memory.dart';
|
import 'package:photos/models/memories/people_memory.dart';
|
||||||
|
import 'package:photos/models/ml/discover/prompt.dart';
|
||||||
import "package:photos/services/machine_learning/ml_computer.dart"
|
import "package:photos/services/machine_learning/ml_computer.dart"
|
||||||
show MLComputer;
|
show MLComputer;
|
||||||
|
|
||||||
final _logger = Logger('TextEmbeddingsUtil');
|
final _logger = Logger('TextEmbeddingsUtil');
|
||||||
|
|
||||||
/// Loads pre-computed text embeddings from assets
|
/// Loads pre-computed memories text embeddings from assets
|
||||||
Future<TextEmbeddings?> loadTextEmbeddingsFromAssets() async {
|
Future<MemoriesEmbeddings?> loadMemoriesEmbeddings() async {
|
||||||
try {
|
try {
|
||||||
_logger.info('Loading text embeddings from assets');
|
_logger.info('Loading text embeddings from assets');
|
||||||
final jsonString =
|
final jsonString =
|
||||||
@@ -81,7 +82,7 @@ Future<TextEmbeddings?> loadTextEmbeddingsFromAssets() async {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_logger.info('Text embeddings loaded successfully from JSON assets');
|
_logger.info('Text embeddings loaded successfully from JSON assets');
|
||||||
return TextEmbeddings(
|
return MemoriesEmbeddings(
|
||||||
clipPositiveVector: clipPositiveVector,
|
clipPositiveVector: clipPositiveVector,
|
||||||
peopleActivityVectors: peopleActivityVectors,
|
peopleActivityVectors: peopleActivityVectors,
|
||||||
clipMemoryTypeVectors: clipMemoryTypeVectors,
|
clipMemoryTypeVectors: clipMemoryTypeVectors,
|
||||||
@@ -92,18 +93,94 @@ Future<TextEmbeddings?> loadTextEmbeddingsFromAssets() async {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class TextEmbeddings {
|
/// Container for all memories text embeddings
|
||||||
|
class MemoriesEmbeddings {
|
||||||
final Vector clipPositiveVector;
|
final Vector clipPositiveVector;
|
||||||
final Map<PeopleActivity, Vector> peopleActivityVectors;
|
final Map<PeopleActivity, Vector> peopleActivityVectors;
|
||||||
final Map<ClipMemoryType, Vector> clipMemoryTypeVectors;
|
final Map<ClipMemoryType, Vector> clipMemoryTypeVectors;
|
||||||
|
|
||||||
const TextEmbeddings({
|
const MemoriesEmbeddings({
|
||||||
required this.clipPositiveVector,
|
required this.clipPositiveVector,
|
||||||
required this.peopleActivityVectors,
|
required this.peopleActivityVectors,
|
||||||
required this.clipMemoryTypeVectors,
|
required this.clipMemoryTypeVectors,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Loads the pre-computed discover text embeddings bundled with the app.
///
/// Reads `assets/ml/discover_embeddings.json` through [rootBundle] and
/// expects a top-level `prompts` list whose entries carry `query`, `title`,
/// `minScore`, `minSize` and `vector`, plus the optional booleans
/// `showVideo` and `recentFirst`.
///
/// Returns a [DiscoverEmbeddings] with one vector and one [PromptData] per
/// query. Entries whose vector is empty are skipped. Returns `null` when no
/// usable entry is found or when loading/parsing fails; failures are logged
/// instead of being rethrown.
Future<DiscoverEmbeddings?> loadDiscoverEmbeddings() async {
  try {
    _logger.info('Loading discover embeddings from assets');
    final jsonString =
        await rootBundle.loadString('assets/ml/discover_embeddings.json');
    final data = json.decode(jsonString) as Map<String, dynamic>;

    final promptsList = data['prompts'] as List;
    final Map<String, Vector> queryToVector = {};
    final Map<String, PromptData> queryToPromptData = {};

    // Parse all discover embeddings
    for (final promptJson in promptsList) {
      final query = promptJson['query'] as String;
      // JSON numbers written without a fractional part decode as `int`.
      // Convert every component eagerly: a lazy `cast<double>()` would throw
      // on the first such component and, through the catch below, discard
      // every embedding instead of just accepting the value.
      final vector = <double>[
        for (final component in promptJson['vector'] as List)
          (component as num).toDouble(),
      ];
      if (vector.isEmpty) {
        continue;
      }
      queryToVector[query] = Vector.fromList(vector);
      queryToPromptData[query] = PromptData(
        title: promptJson['title'] as String,
        minScore: (promptJson['minScore'] as num).toDouble(),
        minSize: (promptJson['minSize'] as num).toDouble(),
        showVideo: promptJson['showVideo'] as bool?,
        recentFirst: promptJson['recentFirst'] as bool?,
      );
    }

    // Nothing usable was parsed (no prompts, or every vector was empty).
    if (queryToVector.isEmpty) {
      _logger.warning('No discover embeddings found');
      return null;
    }

    _logger.info(
      'Discover embeddings loaded successfully with ${queryToVector.length} queries',
    );
    return DiscoverEmbeddings(
      queryToVector: queryToVector,
      queryToPromptData: queryToPromptData,
    );
  } catch (e, stackTrace) {
    _logger.severe('Failed to load discover embeddings', e, stackTrace);
    return null;
  }
}
|
||||||
|
|
||||||
|
/// Pre-computed discover (magic) embeddings, indexed by prompt query text.
///
/// When built by [loadDiscoverEmbeddings], both maps are filled together and
/// therefore share the same set of query keys.
class DiscoverEmbeddings {
  const DiscoverEmbeddings({
    required this.queryToVector,
    required this.queryToPromptData,
  });

  /// Embedding vector for each prompt query.
  final Map<String, Vector> queryToVector;

  /// Title, thresholds and display flags for each prompt query.
  final Map<String, PromptData> queryToPromptData;
}
|
||||||
|
|
||||||
|
/// Metadata stored alongside one discover prompt's embedding.
class PromptData {
  const PromptData({
    required this.title,
    required this.minScore,
    required this.minSize,
    this.showVideo,
    this.recentFirst,
  });

  /// Title of the prompt.
  final String title;

  /// Minimum score for this prompt; callers pass it as the per-query minimum
  /// similarity when matching files.
  final double minScore;

  /// Minimum size for this prompt.
  // NOTE(review): units and exact meaning are not visible here; confirm
  // against the code that consumes `Prompt.minSize`.
  final double minSize;

  /// Whether videos are shown for this prompt; `null` when the source JSON
  /// omits the field (callers in this codebase then fall back to `true`).
  final bool? showVideo;

  /// Whether results are ordered most-recent first; `null` when the source
  /// JSON omits the field (callers in this codebase then fall back to
  /// `false`).
  final bool? recentFirst;
}
|
||||||
|
|
||||||
/// Helper function to generate text embeddings and save them to a JSON file
|
/// Helper function to generate text embeddings and save them to a JSON file
|
||||||
/// Run this once to generate the embeddings, then copy the output
|
/// Run this once to generate the embeddings, then copy the output
|
||||||
/// to assets/ml/text_embeddings.json
|
/// to assets/ml/text_embeddings.json
|
||||||
@@ -170,3 +247,161 @@ Future<void> generateAndSaveTextEmbeddings() async {
|
|||||||
'_generateAndSaveTextEmbeddings: Text embeddings generation complete! Copy the JSON output above to assets/ml/text_embeddings.json',
|
'_generateAndSaveTextEmbeddings: Text embeddings generation complete! Copy the JSON output above to assets/ml/text_embeddings.json',
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates text embeddings for the discover prompts and saves them to a
/// JSON file.
///
/// Developer helper: run this once to generate the embeddings, then copy the
/// written file to `assets/ml/discover_embeddings.json`.
///
/// Each prompt's query is embedded with [MLComputer.runClipText]. The output
/// has a `version` string and a `prompts` list; every entry holds `query`,
/// `title`, `minScore`, `minSize`, `vector`, `showVideo` and `recentFirst`,
/// which is the layout [loadDiscoverEmbeddings] reads back.
///
/// Throws a [StateError] when no external storage directory is available,
/// because there is then nowhere to write the file.
Future<void> generateAndSaveDiscoverEmbeddings() async {
  // Hardcoded prompts from v2.json
  final magicPromptsData = [
    Prompt(
      query: "identity documents and cards",
      title: "Identity",
      minScore: 0.20,
      minSize: 0.0,
      showVideo: false,
    ),
    Prompt(
      query: "screenshots",
      title: "Screenshots",
      minScore: 0.18,
      minSize: 0.0,
      recentFirst: true,
      showVideo: false,
    ),
    Prompt(
      query: "receipts, bills and invoices",
      title: "Receipts",
      minScore: 0.20,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "screenshots of digital notes, photos of notebook and whiteboards",
      title: "Notes",
      minScore: 0.20,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "memes",
      title: "Memes",
      minScore: 0.18,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "business card",
      title: "Visiting Cards",
      minScore: 0.20,
      minSize: 0.0,
      recentFirst: true,
      showVideo: false,
    ),
    Prompt(
      query: "photo of toddler",
      title: "Babies",
      minScore: 0.20,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "photos of pets",
      title: "Pets",
      minScore: 0.20,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "selfie",
      title: "Selfies",
      minScore: 0.18,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "high resolution hd wallpaper",
      title: "Wallpapers",
      minScore: 0.18,
      minSize: 0.0,
      showVideo: false,
    ),
    Prompt(
      query: "photo of food",
      title: "Food",
      minScore: 0.20,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "celebration with friends",
      title: "Celebrations",
      minScore: 0.20,
      minSize: 0.0,
      recentFirst: true,
    ),
    Prompt(
      query: "photo of setting sun",
      title: "Sunset",
      minScore: 0.20,
      minSize: 0.0,
    ),
    Prompt(
      query: "photo of mountains or hills",
      title: "Hills",
      minScore: 0.20,
      minSize: 0.0,
    ),
    Prompt(
      query: "outdoor photo of trees, leaves, flowers",
      title: "Greenery",
      minScore: 0.20,
      minSize: 0.0,
    ),
  ];

  // Embed the prompts one at a time, keeping the declared order in the
  // output list.
  final prompts = <Map<String, dynamic>>[];
  for (final prompt in magicPromptsData) {
    _logger.info('Generating embedding for discover prompt: ${prompt.title}');
    final vector = await MLComputer.instance.runClipText(prompt.query);

    prompts.add({
      'query': prompt.query,
      'title': prompt.title,
      'minScore': prompt.minScore,
      'minSize': prompt.minSize,
      'vector': vector,
      // Always written, even when they hold the default values.
      'showVideo': prompt.showVideo,
      'recentFirst': prompt.recentFirst,
    });
  }

  final embeddingsData = <String, dynamic>{
    'version': '1.0.0',
    'prompts': prompts,
  };

  // Convert to JSON; the payload itself is written to disk, not logged.
  final jsonString = const JsonEncoder.withIndent(' ').convert(embeddingsData);
  dev.log(
    '_generateAndSaveDiscoverEmbeddings: Generated discover embeddings JSON',
  );

  // The directory lookup can yield null; fail with a clear message instead
  // of an opaque null-assertion error.
  final storageDir = await getExternalStorageDirectory();
  if (storageDir == null) {
    throw StateError(
      'External storage directory is unavailable; cannot save discover_embeddings.json',
    );
  }
  final file = File('${storageDir.path}/discover_embeddings.json');
  await file.writeAsString(jsonString);
  dev.log(
    '_generateAndSaveDiscoverEmbeddings: Saved discover embeddings to ${file.path}',
  );

  dev.log(
    '_generateAndSaveDiscoverEmbeddings: Discover embeddings generation complete! Copy ${file.path} to assets/ml/discover_embeddings.json',
  );
}
|
||||||
|
|||||||
Reference in New Issue
Block a user