diff --git a/mobile/lib/models/search/search_types.dart b/mobile/lib/models/search/search_types.dart index a13fd57dcb..30b8a0bb0b 100644 --- a/mobile/lib/models/search/search_types.dart +++ b/mobile/lib/models/search/search_types.dart @@ -10,10 +10,10 @@ import "package:photos/events/people_changed_event.dart"; import "package:photos/generated/l10n.dart"; import "package:photos/models/collection/collection.dart"; import "package:photos/models/collection/collection_items.dart"; -import "package:photos/models/search/generic_search_result.dart"; import "package:photos/models/search/search_result.dart"; import "package:photos/models/typedefs.dart"; import "package:photos/services/collections_service.dart"; +import "package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart"; import "package:photos/services/search_service.dart"; import "package:photos/ui/viewer/gallery/collection_page.dart"; import "package:photos/ui/viewer/location/add_location_sheet.dart"; @@ -251,7 +251,7 @@ extension SectionTypeExtensions on SectionType { case SectionType.face: return SearchService.instance.getAllFace(limit); case SectionType.content: - return Future.value(List.empty()); + return SearchService.instance.getMagicSectionResutls(); case SectionType.moment: return SearchService.instance.getRandomMomentsSearchResults(context); @@ -293,6 +293,8 @@ extension SectionTypeExtensions on SectionType { switch (this) { case SectionType.location: return [Bus.instance.on()]; + case SectionType.content: + return [Bus.instance.on()]; default: return []; } diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 5f7f509355..d65c67aba3 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -131,11 +131,15 @@ class SemanticSearchService { _isSyncing = false; } + bool isMagicSearchEnabledAndReady() { + return LocalSettings.instance.hasEnabledMagicSearch() && + _frameworkInitialization.isCompleted; + } + // searchScreenQuery should only be used for the user initiate query on the search screen. // If there are multiple call tho this method, then for all the calls, the result will be the same as the last query. Future<(String, List)> searchScreenQuery(String query) async { - if (!LocalSettings.instance.hasEnabledMagicSearch() || - !_frameworkInitialization.isCompleted) { + if (!isMagicSearchEnabledAndReady()) { return (query, []); } // If there's an ongoing request, just update the last query and return its future. @@ -144,7 +148,7 @@ class SemanticSearchService { return _searchScreenRequest!; } else { // No ongoing request, start a new search. - _searchScreenRequest = _getMatchingFiles(query).then((result) { + _searchScreenRequest = getMatchingFiles(query).then((result) { // Search completed, reset the ongoing request. _searchScreenRequest = null; // If there was a new query during the last search, start a new search with the last query. @@ -236,7 +240,7 @@ class SemanticSearchService { _queue.clear(); } - Future> _getMatchingFiles( + Future> getMatchingFiles( String query, { double? scoreThreshold, }) async { @@ -247,11 +251,13 @@ class SemanticSearchService { final filesMap = await FilesDB.instance .getFilesFromIDs(queryResults.map((e) => e.id).toList()); - final results = []; final ignoredCollections = CollectionsService.instance.getHiddenCollectionIds(); + final deletedEntries = []; + final results = []; + for (final result in queryResults) { final file = filesMap[result.id]; if (file != null && !ignoredCollections.contains(file.collectionID)) { diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index d15eddb718..401cb574d4 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -40,6 +40,33 @@ import 'package:photos/utils/date_time_util.dart'; import "package:photos/utils/navigation_util.dart"; import 'package:tuple/tuple.dart'; +const magicPromptsData = [ + { + "prompt": "identity document", + "title": "Identity Document", + "minimumScore": 0.269, + "minimumSize": 0.0, + }, + { + "prompt": "sunset at the beach", + "title": "Sunset", + "minimumScore": 0.25, + "minimumSize": 0.0, + }, + { + "prompt": "roadtrip", + "title": "Roadtrip", + "minimumScore": 0.26, + "minimumSize": 0.0, + }, + { + "prompt": "pizza pasta burger", + "title": "Food", + "minimumScore": 0.27, + "minimumSize": 0.0, + } +]; + class SearchService { Future>? _cachedFilesFuture; Future>? _cachedHiddenFilesFuture; @@ -174,6 +201,29 @@ class SearchService { return searchResults; } + Future> getMagicSectionResutls() async { + if (!SemanticSearchService.instance.isMagicSearchEnabledAndReady()) { + return []; + } + final searchResuts = []; + for (Map magicPrompt in magicPromptsData) { + final files = await SemanticSearchService.instance.getMatchingFiles( + magicPrompt["prompt"], + scoreThreshold: magicPrompt["minimumScore"], + ); + if (files.isNotEmpty) { + searchResuts.add( + GenericSearchResult( + ResultType.magic, + magicPrompt["title"], + files, + ), + ); + } + } + return searchResuts; + } + Future> getRandomMomentsSearchResults( BuildContext context, ) async { diff --git a/mobile/lib/states/all_sections_examples_state.dart b/mobile/lib/states/all_sections_examples_state.dart index a40ecd9255..716de8db56 100644 --- a/mobile/lib/states/all_sections_examples_state.dart +++ b/mobile/lib/states/all_sections_examples_state.dart @@ -88,9 +88,6 @@ class _AllSectionsExamplesProviderState _logger.info("'_debounceTimer: reloading all sections in search tab"); final allSectionsExamples = >>[]; for (SectionType sectionType in SectionType.values) { - if (sectionType == SectionType.content) { - continue; - } allSectionsExamples.add( sectionType.getData(context, limit: kSearchSectionLimit), ); diff --git a/mobile/lib/ui/viewer/search_tab/magic_section.dart b/mobile/lib/ui/viewer/search_tab/magic_section.dart new file mode 100644 index 0000000000..3b49413e06 --- /dev/null +++ b/mobile/lib/ui/viewer/search_tab/magic_section.dart @@ -0,0 +1,285 @@ +import "dart:async"; +import "dart:math"; + +import "package:figma_squircle/figma_squircle.dart"; +import "package:flutter/material.dart"; +import "package:photos/core/constants.dart"; +import "package:photos/events/event.dart"; +import "package:photos/models/search/generic_search_result.dart"; +import "package:photos/models/search/recent_searches.dart"; +import "package:photos/models/search/search_types.dart"; +import "package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; +import "package:photos/ui/viewer/file/thumbnail_widget.dart"; +import "package:photos/ui/viewer/search/result/search_result_page.dart"; +import "package:photos/ui/viewer/search/search_section_cta.dart"; +import "package:photos/ui/viewer/search_tab/section_header.dart"; +import "package:photos/utils/navigation_util.dart"; + +class MagicSection extends StatefulWidget { + final List magicSearchResults; + const MagicSection(this.magicSearchResults, {super.key}); + + @override + State createState() => _MagicSectionState(); +} + +class _MagicSectionState extends State { + late List _magicSearchResults; + final streamSubscriptions = []; + + @override + void initState() { + super.initState(); + _magicSearchResults = widget.magicSearchResults; + + //At times, ml framework is not initialized when the search results are + //requested (widget.momentsSearchResults is empty) and is initialized + //(which fires MLFrameworkInitializationUpdateEvent with + //InitializationState.initialized) before initState of this widget is + //called. We do listen to MLFrameworkInitializationUpdateEvent and reload + //this widget but the event with InitializationState.initialized would have + //already been fired in the above case. + if (_magicSearchResults.isEmpty) { + SectionType.content + .getData( + context, + limit: kSearchSectionLimit, + ) + .then((value) { + if (mounted) { + setState(() { + _magicSearchResults = value as List; + }); + } + }); + } + + final streamsToListenTo = SectionType.content.sectionUpdateEvents(); + for (Stream stream in streamsToListenTo) { + streamSubscriptions.add( + stream.listen((event) async { + final mlFrameWorkEvent = + event as MLFrameworkInitializationUpdateEvent; + if (mlFrameWorkEvent.state == InitializationState.initialized) { + _magicSearchResults = (await SectionType.content.getData( + context, + limit: kSearchSectionLimit, + )) as List; + setState(() {}); + } + }), + ); + } + } + + @override + void dispose() { + for (var subscriptions in streamSubscriptions) { + subscriptions.cancel(); + } + super.dispose(); + } + + @override + void didUpdateWidget(covariant MagicSection oldWidget) { + super.didUpdateWidget(oldWidget); + _magicSearchResults = widget.magicSearchResults; + } + + @override + Widget build(BuildContext context) { + if (_magicSearchResults.isEmpty) { + final textTheme = getEnteTextTheme(context); + return Padding( + padding: const EdgeInsets.only(left: 12, right: 8), + child: Row( + children: [ + Expanded( + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + SectionType.moment.sectionTitle(context), + style: textTheme.largeBold, + ), + const SizedBox(height: 24), + Padding( + padding: const EdgeInsets.only(left: 4), + child: Text( + SectionType.moment.getEmptyStateText(context), + style: textTheme.smallMuted, + ), + ), + ], + ), + ), + const SizedBox(width: 8), + const SearchSectionEmptyCTAIcon(SectionType.moment), + ], + ), + ); + } else { + return Padding( + padding: const EdgeInsets.symmetric(vertical: 8), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + SectionHeader( + SectionType.moment, + hasMore: (_magicSearchResults.length >= kSearchSectionLimit - 1), + ), + const SizedBox(height: 2), + SizedBox( + child: SingleChildScrollView( + clipBehavior: Clip.none, + padding: const EdgeInsets.symmetric(horizontal: 4.5), + physics: const BouncingScrollPhysics(), + scrollDirection: Axis.horizontal, + child: Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: _magicSearchResults + .map( + (momentSearchResult) => + MomentRecommendation(momentSearchResult), + ) + .toList(), + ), + ), + ), + ], + ), + ); + } + } +} + +class MomentRecommendation extends StatelessWidget { + static const _width = 100.0; + static const _height = 145.0; + static const _borderWidth = 1.0; + static const _cornerRadius = 5.0; + static const _cornerSmoothing = 1.0; + final GenericSearchResult momentSearchResult; + const MomentRecommendation(this.momentSearchResult, {super.key}); + + @override + Widget build(BuildContext context) { + final heroTag = momentSearchResult.heroTag() + + (momentSearchResult.previewThumbnail()?.tag ?? ""); + final enteTextTheme = getEnteTextTheme(context); + return Padding( + padding: EdgeInsets.symmetric(horizontal: max(2.5 - _borderWidth, 0)), + child: GestureDetector( + onTap: () { + RecentSearches().add(momentSearchResult.name()); + if (momentSearchResult.onResultTap != null) { + momentSearchResult.onResultTap!(context); + } else { + routeToPage( + context, + SearchResultPage(momentSearchResult), + ); + } + }, + child: SizedBox( + width: _width + _borderWidth * 2, + height: _height + _borderWidth * 2, + child: Stack( + clipBehavior: Clip.none, + alignment: Alignment.center, + children: [ + ClipSmoothRect( + radius: SmoothBorderRadius( + cornerRadius: _cornerRadius + _borderWidth, + cornerSmoothing: _cornerSmoothing, + ), + child: Container( + color: Colors.white.withOpacity(0.16), + width: _width + _borderWidth * 2, + height: _height + _borderWidth * 2, + ), + ), + Container( + decoration: BoxDecoration( + boxShadow: [ + BoxShadow( + color: Colors.black.withOpacity(0.1), + blurRadius: 6.25, + offset: const Offset(-1.25, 2.5), + ), + ], + ), + child: ClipSmoothRect( + radius: SmoothBorderRadius( + cornerRadius: _cornerRadius, + cornerSmoothing: _cornerSmoothing, + ), + child: Stack( + alignment: Alignment.bottomCenter, + clipBehavior: Clip.none, + children: [ + SizedBox( + width: _width, + height: _height, + child: momentSearchResult.previewThumbnail() != null + ? Hero( + tag: heroTag, + child: ThumbnailWidget( + momentSearchResult.previewThumbnail()!, + shouldShowArchiveStatus: false, + shouldShowSyncStatus: false, + ), + ) + : const NoThumbnailWidget(), + ), + Container( + height: 145, + width: 100, + decoration: BoxDecoration( + gradient: LinearGradient( + begin: Alignment.topCenter, + end: Alignment.bottomCenter, + colors: [ + Colors.black.withOpacity(0), + Colors.black.withOpacity(0), + Colors.black.withOpacity(0.5), + ], + stops: const [ + 0, + 0.1, + 1, + ], + ), + ), + ), + ConstrainedBox( + constraints: const BoxConstraints( + maxWidth: 76, + ), + child: Padding( + padding: const EdgeInsets.only( + bottom: 8, + ), + child: Text( + momentSearchResult.name(), + style: enteTextTheme.small.copyWith( + color: Colors.white, + ), + maxLines: 3, + overflow: TextOverflow.fade, + ), + ), + ), + ], + ), + ), + ), + ], + ), + ), + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/search_tab/search_tab.dart b/mobile/lib/ui/viewer/search_tab/search_tab.dart index 46dcfda036..9c3d1d2ae2 100644 --- a/mobile/lib/ui/viewer/search_tab/search_tab.dart +++ b/mobile/lib/ui/viewer/search_tab/search_tab.dart @@ -18,6 +18,7 @@ import "package:photos/ui/viewer/search_tab/contacts_section.dart"; import "package:photos/ui/viewer/search_tab/descriptions_section.dart"; import "package:photos/ui/viewer/search_tab/file_type_section.dart"; import "package:photos/ui/viewer/search_tab/locations_section.dart"; +import "package:photos/ui/viewer/search_tab/magic_section.dart"; import "package:photos/ui/viewer/search_tab/moments_section.dart"; import "package:photos/ui/viewer/search_tab/people_section.dart"; import "package:photos/utils/local_settings.dart"; @@ -82,7 +83,6 @@ class _AllSearchSectionsState extends State { @override Widget build(BuildContext context) { final searchTypes = SectionType.values.toList(growable: true); - searchTypes.remove(SectionType.content); return Padding( padding: const EdgeInsets.only(top: 8), @@ -153,6 +153,11 @@ class _AllSearchSectionsState extends State { snapshot.data!.elementAt(index) as List, ); + case SectionType.content: + return MagicSection( + snapshot.data!.elementAt(index) + as List, + ); default: const SizedBox.shrink(); }