[mob][photos] Discovery (behind feature flag) (#2115)

Show Magic search examples on search tab for internal users.
This commit is contained in:
Ashil
2024-06-13 10:20:37 +05:30
committed by GitHub
9 changed files with 381 additions and 32 deletions

View File

@@ -10,10 +10,10 @@ import "package:photos/events/people_changed_event.dart";
import "package:photos/generated/l10n.dart";
import "package:photos/models/collection/collection.dart";
import "package:photos/models/collection/collection_items.dart";
import "package:photos/models/search/generic_search_result.dart";
import "package:photos/models/search/search_result.dart";
import "package:photos/models/typedefs.dart";
import "package:photos/services/collections_service.dart";
import "package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart";
import "package:photos/services/search_service.dart";
import "package:photos/ui/viewer/gallery/collection_page.dart";
import "package:photos/ui/viewer/location/add_location_sheet.dart";
@@ -41,8 +41,7 @@ enum ResultType {
enum SectionType {
face,
location,
// Grouping based on ML or manual tagging
content,
magic,
// includes year, month , day, event ResultType
moment,
album,
@@ -58,8 +57,8 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.face:
return S.of(context).people;
case SectionType.content:
return S.of(context).contents;
case SectionType.magic:
return "Magic";
case SectionType.moment:
return S.of(context).moments;
case SectionType.location:
@@ -79,8 +78,8 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.face:
return S.of(context).searchFaceEmptySection;
case SectionType.content:
return "Contents";
case SectionType.magic:
return "Magic";
case SectionType.moment:
return S.of(context).searchDatesEmptySection;
case SectionType.location:
@@ -102,7 +101,7 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.face:
return false;
case SectionType.content:
case SectionType.magic:
return false;
case SectionType.moment:
return false;
@@ -125,7 +124,7 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.face:
return false;
case SectionType.content:
case SectionType.magic:
return false;
case SectionType.moment:
return false;
@@ -147,9 +146,9 @@ extension SectionTypeExtensions on SectionType {
case SectionType.face:
// todo: later
return "Setup";
case SectionType.content:
case SectionType.magic:
// todo: later
return "Add tags";
return "temp";
case SectionType.moment:
return S.of(context).addNew;
case SectionType.location:
@@ -169,7 +168,7 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.face:
return Icons.adaptive.arrow_forward_outlined;
case SectionType.content:
case SectionType.magic:
return null;
case SectionType.moment:
return null;
@@ -250,8 +249,8 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.face:
return SearchService.instance.getAllFace(limit);
case SectionType.content:
return Future.value(List<GenericSearchResult>.empty());
case SectionType.magic:
return SearchService.instance.getMagicSectionResutls();
case SectionType.moment:
return SearchService.instance.getRandomMomentsSearchResults(context);
@@ -293,6 +292,8 @@ extension SectionTypeExtensions on SectionType {
switch (this) {
case SectionType.location:
return [Bus.instance.on<LocationTagUpdatedEvent>()];
case SectionType.magic:
return [Bus.instance.on<MLFrameworkInitializationUpdateEvent>()];
default:
return [];
}

View File

@@ -131,11 +131,15 @@ class SemanticSearchService {
_isSyncing = false;
}
bool isMagicSearchEnabledAndReady() {
return LocalSettings.instance.hasEnabledMagicSearch() &&
_frameworkInitialization.isCompleted;
}
// searchScreenQuery should only be used for the user initiate query on the search screen.
// If there are multiple call tho this method, then for all the calls, the result will be the same as the last query.
Future<(String, List<EnteFile>)> searchScreenQuery(String query) async {
if (!LocalSettings.instance.hasEnabledMagicSearch() ||
!_frameworkInitialization.isCompleted) {
if (!isMagicSearchEnabledAndReady()) {
return (query, <EnteFile>[]);
}
// If there's an ongoing request, just update the last query and return its future.
@@ -144,7 +148,7 @@ class SemanticSearchService {
return _searchScreenRequest!;
} else {
// No ongoing request, start a new search.
_searchScreenRequest = _getMatchingFiles(query).then((result) {
_searchScreenRequest = getMatchingFiles(query).then((result) {
// Search completed, reset the ongoing request.
_searchScreenRequest = null;
// If there was a new query during the last search, start a new search with the last query.
@@ -236,18 +240,24 @@ class SemanticSearchService {
_queue.clear();
}
Future<List<EnteFile>> _getMatchingFiles(String query) async {
Future<List<EnteFile>> getMatchingFiles(
String query, {
double? scoreThreshold,
}) async {
final textEmbedding = await _getTextEmbedding(query);
final queryResults = await _getScores(textEmbedding);
final queryResults =
await _getScores(textEmbedding, scoreThreshold: scoreThreshold);
final filesMap = await FilesDB.instance
.getFilesFromIDs(queryResults.map((e) => e.id).toList());
final results = <EnteFile>[];
final ignoredCollections =
CollectionsService.instance.getHiddenCollectionIds();
final deletedEntries = <int>[];
final results = <EnteFile>[];
for (final result in queryResults) {
final file = filesMap[result.id];
if (file != null && !ignoredCollections.contains(file.collectionID)) {
@@ -355,13 +365,17 @@ class SemanticSearchService {
}
}
Future<List<QueryResult>> _getScores(List<double> textEmbedding) async {
Future<List<QueryResult>> _getScores(
List<double> textEmbedding, {
double? scoreThreshold,
}) async {
final startTime = DateTime.now();
final List<QueryResult> queryResults = await _computer.compute(
computeBulkScore,
param: {
"imageEmbeddings": _cachedEmbeddings,
"textEmbedding": textEmbedding,
"scoreThreshold": scoreThreshold,
},
taskName: "computeBulkScore",
);
@@ -402,12 +416,14 @@ List<QueryResult> computeBulkScore(Map args) {
final queryResults = <QueryResult>[];
final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
final textEmbedding = args["textEmbedding"] as List<double>;
final scoreThreshold =
args["scoreThreshold"] ?? SemanticSearchService.kScoreThreshold;
for (final imageEmbedding in imageEmbeddings) {
final score = computeScore(
imageEmbedding.embedding,
textEmbedding,
);
if (score >= SemanticSearchService.kScoreThreshold) {
if (score >= scoreThreshold) {
queryResults.add(QueryResult(imageEmbedding.fileID, score));
}
}
@@ -422,7 +438,8 @@ double computeScore(List<double> imageEmbedding, List<double> textEmbedding) {
"The two embeddings should have the same length",
);
double score = 0;
for (int index = 0; index < imageEmbedding.length; index++) {
final length = imageEmbedding.length;
for (int index = 0; index < length; index++) {
score += imageEmbedding[index] * textEmbedding[index];
}
return score;

View File

@@ -18,10 +18,10 @@ class RemoteAssetsService {
static final RemoteAssetsService instance =
RemoteAssetsService._privateConstructor();
Future<File> getAsset(String remotePath) async {
Future<File> getAsset(String remotePath, {bool refetch = false}) async {
final path = await _getLocalPath(remotePath);
final file = File(path);
if (await file.exists()) {
if (await file.exists() && !refetch) {
_logger.info("Returning cached file for $remotePath");
return file;
} else {

View File

@@ -1,3 +1,4 @@
import "dart:convert";
import "dart:math";
import "package:flutter/cupertino.dart";
@@ -26,11 +27,13 @@ import 'package:photos/models/search/album_search_result.dart';
import 'package:photos/models/search/generic_search_result.dart';
import "package:photos/models/search/search_constants.dart";
import "package:photos/models/search/search_types.dart";
import "package:photos/service_locator.dart";
import 'package:photos/services/collections_service.dart';
import "package:photos/services/location_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/services/remote_assets_service.dart";
import "package:photos/states/location_screen_state.dart";
import "package:photos/ui/viewer/location/add_location_sheet.dart";
import "package:photos/ui/viewer/location/location_screen.dart";
@@ -46,6 +49,9 @@ class SearchService {
final _logger = Logger((SearchService).toString());
final _collectionService = CollectionsService.instance;
static const _maximumResultsLimit = 20;
static const _kMagicPromptsDataUrl = "https://discover.ente.io/v1.json";
var magicPromptsData = [];
SearchService._privateConstructor();
@@ -57,6 +63,17 @@ class SearchService {
_cachedFilesFuture = null;
_cachedHiddenFilesFuture = null;
});
if (flagService.internalUser) {
_loadMagicPrompts();
}
}
Future<dynamic> _loadMagicPrompts() async {
final file = await RemoteAssetsService.instance
.getAsset(_kMagicPromptsDataUrl, refetch: true);
final json = jsonDecode(await file.readAsString());
magicPromptsData = json["prompts"];
}
Set<int> ignoreCollections() {
@@ -174,6 +191,29 @@ class SearchService {
return searchResults;
}
Future<List<GenericSearchResult>> getMagicSectionResutls() async {
if (!SemanticSearchService.instance.isMagicSearchEnabledAndReady()) {
return <GenericSearchResult>[];
}
final searchResuts = <GenericSearchResult>[];
for (Map<String, dynamic> magicPrompt in magicPromptsData) {
final files = await SemanticSearchService.instance.getMatchingFiles(
magicPrompt["prompt"],
scoreThreshold: magicPrompt["minimumScore"],
);
if (files.isNotEmpty) {
searchResuts.add(
GenericSearchResult(
ResultType.magic,
magicPrompt["title"],
files,
),
);
}
}
return searchResuts;
}
Future<List<GenericSearchResult>> getRandomMomentsSearchResults(
BuildContext context,
) async {

View File

@@ -88,9 +88,6 @@ class _AllSectionsExamplesProviderState
_logger.info("'_debounceTimer: reloading all sections in search tab");
final allSectionsExamples = <Future<List<SearchResult>>>[];
for (SectionType sectionType in SectionType.values) {
if (sectionType == SectionType.content) {
continue;
}
allSectionsExamples.add(
sectionType.getData(context, limit: kSearchSectionLimit),
);

View File

@@ -21,7 +21,7 @@ class _NoResultWidgetState extends State<NoResultWidget> {
super.initState();
searchTypes = SectionType.values.toList(growable: true);
// remove face and content sectionType
searchTypes.remove(SectionType.content);
searchTypes.remove(SectionType.magic);
}
@override

View File

@@ -13,14 +13,12 @@ class SearchResultWidget extends StatelessWidget {
final SearchResult searchResult;
final Future<int>? resultCount;
final Function? onResultTap;
final Map<String, dynamic>? params;
const SearchResultWidget(
this.searchResult, {
Key? key,
this.resultCount,
this.onResultTap,
this.params,
}) : super(key: key);
@override

View File

@@ -0,0 +1,291 @@
import "dart:async";
import "dart:math";
import "package:figma_squircle/figma_squircle.dart";
import "package:flutter/material.dart";
import "package:photos/core/constants.dart";
import "package:photos/events/event.dart";
import "package:photos/models/search/generic_search_result.dart";
import "package:photos/models/search/recent_searches.dart";
import "package:photos/models/search/search_types.dart";
import "package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart";
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
import "package:photos/ui/viewer/file/thumbnail_widget.dart";
import "package:photos/ui/viewer/search/result/search_result_page.dart";
import "package:photos/ui/viewer/search_tab/section_header.dart";
import "package:photos/utils/navigation_util.dart";
class MagicSection extends StatefulWidget {
final List<GenericSearchResult> magicSearchResults;
const MagicSection(this.magicSearchResults, {super.key});
@override
State<MagicSection> createState() => _MagicSectionState();
}
class _MagicSectionState extends State<MagicSection> {
late List<GenericSearchResult> _magicSearchResults;
final streamSubscriptions = <StreamSubscription>[];
@override
void initState() {
super.initState();
_magicSearchResults = widget.magicSearchResults;
//At times, ml framework is not initialized when the search results are
//requested (widget.momentsSearchResults is empty) and is initialized
//(which fires MLFrameworkInitializationUpdateEvent with
//InitializationState.initialized) before initState of this widget is
//called. We do listen to MLFrameworkInitializationUpdateEvent and reload
//this widget but the event with InitializationState.initialized would have
//already been fired in the above case.
if (_magicSearchResults.isEmpty) {
SectionType.magic
.getData(
context,
limit: kSearchSectionLimit,
)
.then((value) {
if (mounted) {
setState(() {
_magicSearchResults = value as List<GenericSearchResult>;
});
}
});
}
final streamsToListenTo = SectionType.magic.sectionUpdateEvents();
for (Stream<Event> stream in streamsToListenTo) {
streamSubscriptions.add(
stream.listen((event) async {
final mlFrameWorkEvent =
event as MLFrameworkInitializationUpdateEvent;
if (mlFrameWorkEvent.state == InitializationState.initialized) {
_magicSearchResults = (await SectionType.magic.getData(
context,
limit: kSearchSectionLimit,
)) as List<GenericSearchResult>;
setState(() {});
}
}),
);
}
}
@override
void dispose() {
for (var subscriptions in streamSubscriptions) {
subscriptions.cancel();
}
super.dispose();
}
@override
void didUpdateWidget(covariant MagicSection oldWidget) {
super.didUpdateWidget(oldWidget);
//widget.magicSearch is empty when doing a hot reload
if (widget.magicSearchResults.isNotEmpty) {
_magicSearchResults = widget.magicSearchResults;
}
}
@override
Widget build(BuildContext context) {
if (_magicSearchResults.isEmpty) {
// final textTheme = getEnteTextTheme(context);
// return Padding(
// padding: const EdgeInsets.only(left: 12, right: 8),
// child: Row(
// children: [
// Expanded(
// child: Column(
// crossAxisAlignment: CrossAxisAlignment.start,
// children: [
// Text(
// SectionType.magic.sectionTitle(context),
// style: textTheme.largeBold,
// ),
// const SizedBox(height: 24),
// Padding(
// padding: const EdgeInsets.only(left: 4),
// child: Text(
// SectionType.magic.getEmptyStateText(context),
// style: textTheme.smallMuted,
// ),
// ),
// ],
// ),
// ),
// const SizedBox(width: 8),
// const SearchSectionEmptyCTAIcon(SectionType.magic),
// ],
// ),
// );
return const SizedBox.shrink();
} else {
return Padding(
padding: const EdgeInsets.symmetric(vertical: 8),
child: Column(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
SectionHeader(
SectionType.magic,
hasMore: (_magicSearchResults.length >= kSearchSectionLimit - 1),
),
const SizedBox(height: 2),
SizedBox(
child: SingleChildScrollView(
clipBehavior: Clip.none,
padding: const EdgeInsets.symmetric(horizontal: 4.5),
physics: const BouncingScrollPhysics(),
scrollDirection: Axis.horizontal,
child: Row(
crossAxisAlignment: CrossAxisAlignment.start,
children: _magicSearchResults
.map(
(magicSearchResult) =>
MagicRecommendation(magicSearchResult),
)
.toList(),
),
),
),
],
),
);
}
}
}
class MagicRecommendation extends StatelessWidget {
static const _width = 100.0;
static const _height = 110.0;
static const _borderWidth = 1.0;
static const _cornerRadius = 12.0;
static const _cornerSmoothing = 1.0;
final GenericSearchResult magicSearchResult;
const MagicRecommendation(this.magicSearchResult, {super.key});
@override
Widget build(BuildContext context) {
final heroTag = magicSearchResult.heroTag() +
(magicSearchResult.previewThumbnail()?.tag ?? "");
final enteTextTheme = getEnteTextTheme(context);
return Padding(
padding: EdgeInsets.symmetric(horizontal: max(2.5 - _borderWidth, 0)),
child: GestureDetector(
onTap: () {
RecentSearches().add(magicSearchResult.name());
if (magicSearchResult.onResultTap != null) {
magicSearchResult.onResultTap!(context);
} else {
routeToPage(
context,
SearchResultPage(
magicSearchResult,
enableGrouping: false,
),
);
}
},
child: SizedBox(
width: _width + _borderWidth * 2,
height: _height + _borderWidth * 2,
child: Stack(
clipBehavior: Clip.none,
alignment: Alignment.center,
children: [
ClipSmoothRect(
radius: SmoothBorderRadius(
cornerRadius: _cornerRadius + _borderWidth,
cornerSmoothing: _cornerSmoothing,
),
child: Container(
color: getEnteColorScheme(context).strokeFaint,
width: _width + _borderWidth * 2,
height: _height + _borderWidth * 2,
),
),
Container(
decoration: BoxDecoration(
boxShadow: [
BoxShadow(
color: Colors.black.withOpacity(0.1),
blurRadius: 6.25,
offset: const Offset(-1.25, 2.5),
),
],
),
child: ClipSmoothRect(
radius: SmoothBorderRadius(
cornerRadius: _cornerRadius,
cornerSmoothing: _cornerSmoothing,
),
child: Stack(
alignment: Alignment.bottomCenter,
clipBehavior: Clip.none,
children: [
SizedBox(
width: _width,
height: _height,
child: magicSearchResult.previewThumbnail() != null
? Hero(
tag: heroTag,
child: ThumbnailWidget(
magicSearchResult.previewThumbnail()!,
shouldShowArchiveStatus: false,
shouldShowSyncStatus: false,
),
)
: const NoThumbnailWidget(),
),
Container(
height: _height,
width: _width,
decoration: BoxDecoration(
gradient: LinearGradient(
begin: Alignment.topCenter,
end: Alignment.bottomCenter,
colors: [
Colors.black.withOpacity(0),
Colors.black.withOpacity(0),
Colors.black.withOpacity(0.5),
],
stops: const [
0,
0.1,
1,
],
),
),
),
ConstrainedBox(
constraints: const BoxConstraints(
maxWidth: 76,
),
child: Padding(
padding: const EdgeInsets.only(
bottom: 8,
),
child: Text(
magicSearchResult.name(),
style: enteTextTheme.small.copyWith(
color: Colors.white,
),
maxLines: 3,
overflow: TextOverflow.fade,
),
),
),
],
),
),
),
],
),
),
),
);
}
}

View File

@@ -18,6 +18,7 @@ import "package:photos/ui/viewer/search_tab/contacts_section.dart";
import "package:photos/ui/viewer/search_tab/descriptions_section.dart";
import "package:photos/ui/viewer/search_tab/file_type_section.dart";
import "package:photos/ui/viewer/search_tab/locations_section.dart";
import "package:photos/ui/viewer/search_tab/magic_section.dart";
import "package:photos/ui/viewer/search_tab/moments_section.dart";
import "package:photos/ui/viewer/search_tab/people_section.dart";
import "package:photos/utils/local_settings.dart";
@@ -82,7 +83,6 @@ class _AllSearchSectionsState extends State<AllSearchSections> {
@override
Widget build(BuildContext context) {
final searchTypes = SectionType.values.toList(growable: true);
searchTypes.remove(SectionType.content);
return Padding(
padding: const EdgeInsets.only(top: 8),
@@ -153,6 +153,11 @@ class _AllSearchSectionsState extends State<AllSearchSections> {
snapshot.data!.elementAt(index)
as List<GenericSearchResult>,
);
case SectionType.magic:
return MagicSection(
snapshot.data!.elementAt(index)
as List<GenericSearchResult>,
);
default:
const SizedBox.shrink();
}