[mob] Store face embeddings from remote

This commit is contained in:
Neeraj Gupta
2024-07-20 00:04:44 +05:30
parent 6d06c02148
commit 39b8223979
5 changed files with 102 additions and 64 deletions

View File

@@ -101,12 +101,12 @@ class FaceRecognitionService {
.toSet();
_logger.info('starting remote fetch for ${fileIds.length} files');
final res =
await RemoteFileMLService.instance.getFaceEmbedding(fileIds);
await RemoteFileMLService.instance.getFileEmbeddings(fileIds);
_logger.info('fetched ${res.mlData.length} embeddings');
fetchedCount += res.mlData.length;
final List<Face> faces = [];
final remoteFileIdToVersion = <int, int>{};
for (FileMl fileMl in res.mlData.values) {
for (RemoteFileML fileMl in res.mlData.values) {
if (shouldDiscardRemoteEmbedding(fileMl)) continue;
if (fileMl.faceEmbedding.faces.isEmpty) {
faces.add(
@@ -117,8 +117,8 @@ class FaceRecognitionService {
} else {
for (final f in fileMl.faceEmbedding.faces) {
f.fileInfo = FileInfo(
imageHeight: fileMl.height,
imageWidth: fileMl.width,
imageHeight: fileMl.faceEmbedding.height,
imageWidth: fileMl.faceEmbedding.width,
);
faces.add(f);
}

View File

@@ -1,46 +1,54 @@
import "package:photos/face/model/face.dart";
class FileMl {
class RemoteFileML {
final int fileID;
final int? height;
final int? width;
final FaceEmbeddings faceEmbedding;
final Map<String, dynamic> remoteRawData;
final RemoteFaceEmbedding faceEmbedding;
final RemoteClipEmbedding? clipEmbedding;
FileMl(
RemoteFileML(
this.fileID,
this.faceEmbedding, {
this.height,
this.width,
this.remoteRawData, {
required this.faceEmbedding,
this.clipEmbedding,
});
// toJson
Map<String, dynamic> toJson() => {
'fileID': fileID,
'height': height,
'width': width,
'faceEmbedding': faceEmbedding.toJson(),
};
// fromJson
factory FileMl.fromJson(Map<String, dynamic> json) {
return FileMl(
json['fileID'] as int,
FaceEmbeddings.fromJson(json['faceEmbedding'] as Map<String, dynamic>),
height: json['height'] as int?,
width: json['width'] as int?,
Map<String, dynamic> toJson() {
throw UnimplementedError();
}
// fromRemote
factory RemoteFileML.fromRemote(int fileID, Map<String, dynamic> json) {
return RemoteFileML(
fileID,
json,
faceEmbedding: RemoteFaceEmbedding.fromJson(
json['face'] as Map<String, dynamic>,
),
clipEmbedding: json['clip'] == null
? null
: RemoteClipEmbedding.fromJson(
json['clip'] as Map<String, dynamic>,
),
);
}
}
class FaceEmbeddings {
class RemoteFaceEmbedding {
final List<Face> faces;
final int version;
// pkgname/version
// packageName/version
final String client;
final int height;
final int width;
FaceEmbeddings(
RemoteFaceEmbedding(
this.faces,
this.version, {
required this.client,
required this.height,
required this.width,
});
// toJson
@@ -48,33 +56,48 @@ class FaceEmbeddings {
'faces': faces.map((x) => x.toJson()).toList(),
'version': version,
'client': client,
'height': height,
'width': width,
};
// fromJson
factory FaceEmbeddings.fromJson(Map<String, dynamic> json) {
return FaceEmbeddings(
factory RemoteFaceEmbedding.fromJson(Map<String, dynamic> json) {
return RemoteFaceEmbedding(
List<Face>.from(
json['faces'].map((x) => Face.fromJson(x as Map<String, dynamic>)),
),
json['version'] as int,
client: json['client'] ?? 'unknown',
client: json['client'] as String,
height: json['height'] as int,
width: json['width'] as int,
);
}
}
class ClipEmbedding {
final int? version;
class RemoteClipEmbedding {
final int version;
final String client;
final List<double> embedding;
ClipEmbedding(this.embedding, {this.version});
RemoteClipEmbedding(
this.embedding, {
required this.version,
required this.client,
});
// toJson
Map<String, dynamic> toJson() => {
'version': version,
'embedding': embedding,
'version': version,
'client': client,
};
// fromJson
factory ClipEmbedding.fromJson(Map<String, dynamic> json) {
return ClipEmbedding(
factory RemoteClipEmbedding.fromJson(Map<String, dynamic> json) {
return RemoteClipEmbedding(
List<double>.from(json['embedding'] as List),
version: json['version'] as int?,
version: json['version'] as int,
client: json['client'] as String,
);
}
}

View File

@@ -1,7 +1,7 @@
import 'package:photos/services/machine_learning/file_ml/file_ml.dart';
class FilesMLDataResponse {
final Map<int, FileMl> mlData;
final Map<int, RemoteFileML> mlData;
// fileIDs that were indexed but don't contain any meaningful embeddings
// and hence should be discarded for re-indexing
final Set<int> noEmbeddingFileIDs;

View File

@@ -1,8 +1,9 @@
import "dart:async";
import "dart:convert";
import "dart:io";
import "package:computer/computer.dart";
import "package:flutter/foundation.dart" show debugPrint;
import "package:flutter/foundation.dart" show Uint8List, debugPrint;
import "package:logging/logging.dart";
import "package:photos/core/network/network.dart";
import "package:photos/db/files_db.dart";
@@ -29,7 +30,8 @@ class RemoteFileMLService {
void init(SharedPreferences prefs) {}
Future<void> putFileEmbedding(EnteFile file, FileMl fileML) async {
Future<void> putFileEmbedding(EnteFile file, RemoteFileML fileML) async {
throw Exception("need to update implementation");
final encryptionKey = getFileKey(file);
final embeddingJSON = jsonEncode(fileML.toJson());
final encryptedEmbedding = await CryptoUtil.encryptChaCha(
@@ -56,7 +58,7 @@ class RemoteFileMLService {
}
}
Future<FilesMLDataResponse> getFaceEmbedding(
Future<FilesMLDataResponse> getFileEmbeddings(
Set<int> fileIds,
) async {
try {
@@ -64,7 +66,7 @@ class RemoteFileMLService {
"/embeddings/files",
data: {
"fileIDs": fileIds.toList(),
"model": 'file-ml-clip-face',
"model": 'ggml-clip',
},
);
final remoteEmb = res.data['embeddings'] as List;
@@ -93,10 +95,10 @@ class RemoteFileMLService {
}
}
Future<Map<int, FileMl>> decryptFileMLData(
Future<Map<int, RemoteFileML>> decryptFileMLData(
List<RemoteEmbedding> remoteEmbeddings,
) async {
final result = <int, FileMl>{};
final result = <int, RemoteFileML>{};
if (remoteEmbeddings.isEmpty) {
return result;
}
@@ -112,7 +114,7 @@ class RemoteFileMLService {
final input = EmbeddingsDecoderInput(embedding, fileKey);
inputs.add(input);
}
return _computer.compute<Map<String, dynamic>, Map<int, FileMl>>(
return _computer.compute<Map<String, dynamic>, Map<int, RemoteFileML>>(
_decryptFileMLComputer,
param: {
"inputs": inputs,
@@ -121,10 +123,16 @@ class RemoteFileMLService {
}
}
Future<Map<int, FileMl>> _decryptFileMLComputer(
/// Decompresses a gzip-encoded byte buffer.
///
/// Takes the gzip-compressed bytes in [compressedData] and returns the
/// decoded bytes copied into a new [Uint8List].
Uint8List ungzipUint8List(Uint8List compressedData) {
  // `gzip` is the default GZipCodec instance exposed by dart:io.
  final inflatedBytes = gzip.decode(compressedData);
  return Uint8List.fromList(inflatedBytes);
}
Future<Map<int, RemoteFileML>> _decryptFileMLComputer(
Map<String, dynamic> args,
) async {
final result = <int, FileMl>{};
final result = <int, RemoteFileML>{};
final inputs = args["inputs"] as List<EmbeddingsDecoderInput>;
for (final input in inputs) {
final decryptArgs = <String, dynamic>{};
@@ -134,15 +142,19 @@ Future<Map<int, FileMl>> _decryptFileMLComputer(
decryptArgs["header"] =
CryptoUtil.base642bin(input.embedding.decryptionHeader);
final embeddingData = chachaDecryptData(decryptArgs);
final decodedJson = jsonDecode(utf8.decode(embeddingData));
final FileMl decodedEmbedding =
FileMl.fromJson(decodedJson as Map<String, dynamic>);
// unzip the gzip data
final unzippedData = ungzipUint8List(embeddingData);
final decodedJson = jsonDecode(utf8.decode(unzippedData));
final RemoteFileML decodedEmbedding = RemoteFileML.fromRemote(
input.embedding.fileID,
decodedJson as Map<String, dynamic>,
);
result[input.embedding.fileID] = decodedEmbedding;
}
return result;
}
bool shouldDiscardRemoteEmbedding(FileMl fileMl) {
bool shouldDiscardRemoteEmbedding(RemoteFileML fileMl) {
if (fileMl.faceEmbedding.version < faceMlVersion) {
debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} "
"because version is ${fileMl.faceEmbedding.version} and we need $faceMlVersion");
@@ -175,10 +187,6 @@ bool shouldDiscardRemoteEmbedding(FileMl fileMl) {
);
return true;
}
if (fileMl.width == null || fileMl.height == null) {
debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} "
"because width is null");
return true;
}
return false;
}

View File

@@ -29,8 +29,8 @@ import 'package:photos/services/machine_learning/face_ml/face_embedding/face_emb
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import "package:photos/services/machine_learning/face_ml/face_recognition_service.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import 'package:photos/services/machine_learning/file_ml/file_ml.dart';
import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart';
import "package:photos/services/machine_learning/file_ml/file_ml.dart";
import "package:photos/services/machine_learning/file_ml/remote_fileml_service.dart";
import 'package:photos/services/machine_learning/ml_exceptions.dart';
import 'package:photos/services/machine_learning/ml_result.dart';
import "package:photos/services/machine_learning/semantic_search/clip/clip_image_encoder.dart";
@@ -470,15 +470,23 @@ class MLService {
if (!result.errorOccured) {
await RemoteFileMLService.instance.putFileEmbedding(
instruction.enteFile,
FileMl(
RemoteFileML(
instruction.enteFile.uploadedFileID!,
FaceEmbeddings(
{},
faceEmbedding: RemoteFaceEmbedding(
faces,
result.mlVersion,
client: client,
height: result.decodedImageSize.height,
width: result.decodedImageSize.width,
),
height: result.decodedImageSize.height,
width: result.decodedImageSize.width,
clipEmbedding: result.clipRan
? RemoteClipEmbedding(
result.clip!.embedding,
version: result.mlVersion,
client: client,
)
: null,
),
);
} else {
@@ -487,7 +495,6 @@ class MLService {
);
}
await FaceMLDataDB.instance.bulkInsertFaces(faces);
return actuallyRanML;
}
if (result.clipRan) {