diff --git a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart index 17546d58cc..e75546b803 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_recognition_service.dart @@ -101,12 +101,12 @@ class FaceRecognitionService { .toSet(); _logger.info('starting remote fetch for ${fileIds.length} files'); final res = - await RemoteFileMLService.instance.getFaceEmbedding(fileIds); + await RemoteFileMLService.instance.getFileEmbeddings(fileIds); _logger.info('fetched ${res.mlData.length} embeddings'); fetchedCount += res.mlData.length; final List faces = []; final remoteFileIdToVersion = {}; - for (FileMl fileMl in res.mlData.values) { + for (RemoteFileML fileMl in res.mlData.values) { if (shouldDiscardRemoteEmbedding(fileMl)) continue; if (fileMl.faceEmbedding.faces.isEmpty) { faces.add( @@ -117,8 +117,8 @@ class FaceRecognitionService { } else { for (final f in fileMl.faceEmbedding.faces) { f.fileInfo = FileInfo( - imageHeight: fileMl.height, - imageWidth: fileMl.width, + imageHeight: fileMl.faceEmbedding.height, + imageWidth: fileMl.faceEmbedding.width, ); faces.add(f); } diff --git a/mobile/lib/services/machine_learning/file_ml/file_ml.dart b/mobile/lib/services/machine_learning/file_ml/file_ml.dart index 4ac1370834..51b2ceec6a 100644 --- a/mobile/lib/services/machine_learning/file_ml/file_ml.dart +++ b/mobile/lib/services/machine_learning/file_ml/file_ml.dart @@ -1,46 +1,54 @@ import "package:photos/face/model/face.dart"; -class FileMl { +class RemoteFileML { final int fileID; - final int? height; - final int? width; - final FaceEmbeddings faceEmbedding; + final Map remoteRawData; + final RemoteFaceEmbedding faceEmbedding; + final RemoteClipEmbedding? clipEmbedding; - FileMl( + RemoteFileML( this.fileID, - this.faceEmbedding, { - this.height, - this.width, + this.remoteRawData, { + required this.faceEmbedding, + this.clipEmbedding, }); // toJson - Map toJson() => { - 'fileID': fileID, - 'height': height, - 'width': width, - 'faceEmbedding': faceEmbedding.toJson(), - }; - // fromJson - factory FileMl.fromJson(Map json) { - return FileMl( - json['fileID'] as int, - FaceEmbeddings.fromJson(json['faceEmbedding'] as Map), - height: json['height'] as int?, - width: json['width'] as int?, + Map toJson() { + throw UnimplementedError(); + } + + // fromRemote + factory RemoteFileML.fromRemote(int fileID, Map json) { + return RemoteFileML( + fileID, + json, + faceEmbedding: RemoteFaceEmbedding.fromJson( + json['face'] as Map, + ), + clipEmbedding: json['clip'] == null + ? null + : RemoteClipEmbedding.fromJson( + json['clip'] as Map, + ), ); } } -class FaceEmbeddings { +class RemoteFaceEmbedding { final List faces; final int version; - // pkgname/version + // packageName/version final String client; + final int height; + final int width; - FaceEmbeddings( + RemoteFaceEmbedding( this.faces, this.version, { required this.client, + required this.height, + required this.width, }); // toJson @@ -48,33 +56,48 @@ class FaceEmbeddings { 'faces': faces.map((x) => x.toJson()).toList(), 'version': version, 'client': client, + 'height': height, + 'width': width, }; + // fromJson - factory FaceEmbeddings.fromJson(Map json) { - return FaceEmbeddings( + factory RemoteFaceEmbedding.fromJson(Map json) { + return RemoteFaceEmbedding( List.from( json['faces'].map((x) => Face.fromJson(x as Map)), ), json['version'] as int, - client: json['client'] ?? 'unknown', + client: json['client'] as String, + height: json['height'] as int, + width: json['width'] as int, ); } } -class ClipEmbedding { - final int? version; +class RemoteClipEmbedding { + final int version; + final String client; final List embedding; - ClipEmbedding(this.embedding, {this.version}); + + RemoteClipEmbedding( + this.embedding, { + required this.version, + required this.client, + }); + // toJson Map toJson() => { - 'version': version, 'embedding': embedding, + 'version': version, + 'client': client, }; + // fromJson - factory ClipEmbedding.fromJson(Map json) { - return ClipEmbedding( + factory RemoteClipEmbedding.fromJson(Map json) { + return RemoteClipEmbedding( List.from(json['embedding'] as List), - version: json['version'] as int?, + version: json['version'] as int, + client: json['client'] as String, ); } } diff --git a/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart b/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart index 475f52d0a3..67b69dcf5c 100644 --- a/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart +++ b/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart @@ -1,7 +1,7 @@ import 'package:photos/services/machine_learning/file_ml/file_ml.dart'; class FilesMLDataResponse { - final Map mlData; + final Map mlData; // fileIDs that were indexed but they don't contain any meaningful embeddings // and hence should be discarded for re-indexing final Set noEmbeddingFileIDs; diff --git a/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart b/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart index edfd6293e1..51b52fbaf0 100644 --- a/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart +++ b/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart @@ -1,8 +1,9 @@ import "dart:async"; import "dart:convert"; +import "dart:io"; import "package:computer/computer.dart"; -import "package:flutter/foundation.dart" show debugPrint; +import "package:flutter/foundation.dart" show Uint8List, debugPrint; import "package:logging/logging.dart"; import "package:photos/core/network/network.dart"; import "package:photos/db/files_db.dart"; @@ -29,7 +30,8 @@ class RemoteFileMLService { void init(SharedPreferences prefs) {} - Future putFileEmbedding(EnteFile file, FileMl fileML) async { + Future putFileEmbedding(EnteFile file, RemoteFileML fileML) async { + throw Exception("need to update implementation"); final encryptionKey = getFileKey(file); final embeddingJSON = jsonEncode(fileML.toJson()); final encryptedEmbedding = await CryptoUtil.encryptChaCha( @@ -56,7 +58,7 @@ class RemoteFileMLService { } } - Future getFaceEmbedding( + Future getFileEmbeddings( Set fileIds, ) async { try { @@ -64,7 +66,7 @@ class RemoteFileMLService { "/embeddings/files", data: { "fileIDs": fileIds.toList(), - "model": 'file-ml-clip-face', + "model": 'ggml-clip', }, ); final remoteEmb = res.data['embeddings'] as List; @@ -93,10 +95,10 @@ class RemoteFileMLService { } } - Future> decryptFileMLData( + Future> decryptFileMLData( List remoteEmbeddings, ) async { - final result = {}; + final result = {}; if (remoteEmbeddings.isEmpty) { return result; } @@ -112,7 +114,7 @@ class RemoteFileMLService { final input = EmbeddingsDecoderInput(embedding, fileKey); inputs.add(input); } - return _computer.compute, Map>( + return _computer.compute, Map>( _decryptFileMLComputer, param: { "inputs": inputs, @@ -121,10 +123,16 @@ class RemoteFileMLService { } } -Future> _decryptFileMLComputer( +Uint8List ungzipUint8List(Uint8List compressedData) { + final codec = GZipCodec(); + final List decompressedList = codec.decode(compressedData); + return Uint8List.fromList(decompressedList); +} + +Future> _decryptFileMLComputer( Map args, ) async { - final result = {}; + final result = {}; final inputs = args["inputs"] as List; for (final input in inputs) { final decryptArgs = {}; @@ -134,15 +142,19 @@ Future> _decryptFileMLComputer( decryptArgs["header"] = CryptoUtil.base642bin(input.embedding.decryptionHeader); final embeddingData = chachaDecryptData(decryptArgs); - final decodedJson = jsonDecode(utf8.decode(embeddingData)); - final FileMl decodedEmbedding = - FileMl.fromJson(decodedJson as Map); + // unzip the gzip data + final unzippedData = ungzipUint8List(embeddingData); + final decodedJson = jsonDecode(utf8.decode(unzippedData)); + final RemoteFileML decodedEmbedding = RemoteFileML.fromRemote( + input.embedding.fileID, + decodedJson as Map, + ); result[input.embedding.fileID] = decodedEmbedding; } return result; } -bool shouldDiscardRemoteEmbedding(FileMl fileMl) { +bool shouldDiscardRemoteEmbedding(RemoteFileML fileMl) { if (fileMl.faceEmbedding.version < faceMlVersion) { debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} " "because version is ${fileMl.faceEmbedding.version} and we need $faceMlVersion"); @@ -175,10 +187,6 @@ bool shouldDiscardRemoteEmbedding(FileMl fileMl) { ); return true; } - if (fileMl.width == null || fileMl.height == null) { - debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} " - "because width is null"); - return true; - } + return false; } diff --git a/mobile/lib/services/machine_learning/ml_service.dart b/mobile/lib/services/machine_learning/ml_service.dart index 875d54a650..52aad2a824 100644 --- a/mobile/lib/services/machine_learning/ml_service.dart +++ b/mobile/lib/services/machine_learning/ml_service.dart @@ -29,8 +29,8 @@ import 'package:photos/services/machine_learning/face_ml/face_embedding/face_emb import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; -import 'package:photos/services/machine_learning/file_ml/file_ml.dart'; -import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart'; +import "package:photos/services/machine_learning/file_ml/file_ml.dart"; +import "package:photos/services/machine_learning/file_ml/remote_fileml_service.dart"; import 'package:photos/services/machine_learning/ml_exceptions.dart'; import 'package:photos/services/machine_learning/ml_result.dart'; import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart"; @@ -470,15 +470,23 @@ class MLService { if (!result.errorOccured) { await RemoteFileMLService.instance.putFileEmbedding( instruction.enteFile, - FileMl( + RemoteFileML( instruction.enteFile.uploadedFileID!, - FaceEmbeddings( + {}, + faceEmbedding: RemoteFaceEmbedding( faces, result.mlVersion, client: client, + height: result.decodedImageSize.height, + width: result.decodedImageSize.width, ), - height: result.decodedImageSize.height, - width: result.decodedImageSize.width, + clipEmbedding: result.clipRan + ? RemoteClipEmbedding( + result.clip!.embedding, + version: result.mlVersion, + client: client, + ) + : null, ), ); } else { @@ -487,7 +495,6 @@ class MLService { ); } await FaceMLDataDB.instance.bulkInsertFaces(faces); - return actuallyRanML; } if (result.clipRan) {