Make ML-related classes generic

Neeraj Gupta
2025-04-14 16:23:06 +05:30
parent eaf74e4059
commit e9a24efecb
10 changed files with 55 additions and 62 deletions

View File

@@ -112,7 +112,7 @@ abstract class IMLDataDB<T> {
Future<List<EmbeddingVector>> getAllClipVectors();
Future<Map<int, int>> clipIndexedFileWithVersion();
Future<int> getClipIndexedFileCount({int minimumMlVersion});
Future<void> putClip(List<ClipEmbedding> embeddings);
Future<void> putClip(List<ClipEmbedding<T>> embeddings);
Future<void> deleteClipEmbeddings(List<T> fileIDs);
Future<void> deleteClipIndexes();
}
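
Parameterizing IMLDataDB on the file-ID type ties every clip API to one key type. A minimal usage sketch, assuming an int-keyed instance like the mlDataDB used later in this commit (the literal values are made up):

// Store one clip embedding and later delete it, both keyed by the same int ID.
Future<void> storeAndDrop(IMLDataDB<int> db) async {
  final embedding = ClipEmbedding<int>(
    fileID: 42, // hypothetical uploaded file ID
    embedding: <double>[0.1, 0.2, 0.3],
    version: 1, // stands in for clipMlVersion
  );
  await db.putClip([embedding]);
  await db.deleteClipEmbeddings(<int>[42]);
}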

View File

@@ -2,8 +2,8 @@ import "dart:convert";
import "package:photos/models/ml/ml_versions.dart";
class ClipEmbedding {
final int fileID;
class ClipEmbedding<T> {
final T fileID;
final List<double> embedding;
int version;
@@ -15,7 +15,7 @@ class ClipEmbedding {
required this.version,
});
factory ClipEmbedding.empty(int fileID) {
factory ClipEmbedding.empty(T fileID) {
return ClipEmbedding(
fileID: fileID,
embedding: <double>[],
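
With the class generic over T, the ID type is inferred from the fileID argument at the call site. A small sketch, assuming the class is imported from the path shown elsewhere in this commit; the IDs are invented:

import "package:photos/models/ml/clip.dart";

void main() {
  // Explicit type argument: keyed by an uploaded file ID (int).
  final byUploadID = ClipEmbedding<int>(
    fileID: 1234,
    embedding: <double>[0.0, 1.0],
    version: 3,
  );

  // Inferred type argument: T becomes String, e.g. a local-gallery asset ID
  // (hypothetical value; the string format is not defined in this diff).
  final byLocalID = ClipEmbedding.empty("local-asset-001");

  print("${byUploadID.fileID} / ${byLocalID.fileID}");
}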

View File

@@ -16,7 +16,7 @@ class FileInfo {
});
}
class Face {
class Face<T> {
final String faceID;
final List<double> embedding;
Detection detection;
@@ -26,7 +26,7 @@ class Face {
///#region Local DB fields
// This is not stored on the server, using it for local DB row
FileInfo? fileInfo;
final int fileID;
final T fileID;
///#endregion
@@ -48,7 +48,7 @@ class Face {
factory Face.fromFaceResult(
FaceResult faceResult,
int fileID,
T fileID,
Dimensions decodedDimensions,
) {
final detection = Detection(
@@ -81,7 +81,7 @@ class Face {
);
}
factory Face.empty(int fileID, {bool error = false}) {
static Face<T> empty<T>(T fileID, {bool error = false}) {
return Face(
"${fileID}_0_0_0_0",
fileID,
@@ -92,9 +92,9 @@ class Face {
);
}
factory Face.fromJson(Map<String, dynamic> json) {
static Face<T> fromJson<T>(Map<String, dynamic> json) {
final String faceID = json['faceID'] as String;
final int fileID = getFileIdFromFaceId<int>(faceID);
final T fileID = getFileIdFromFaceId(faceID);
return Face(
faceID,
fileID,
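
The factory constructors become static generic helpers, so the ID type travels with the call rather than being fixed to int. A small usage sketch (the ID is invented; import path as added in this commit):

import "package:photos/models/ml/face/face.dart";

void markIndexingFailure() {
  // Placeholder row for a file whose face indexing failed;
  // T is inferred as int from the uploaded file ID.
  final Face<int> errorFace = Face.empty(101, error: true);

  // The synthetic faceID embeds the file ID ("101_0_0_0_0"),
  // which is what lets fromJson recover a typed ID later.
  print(errorFace.faceID);
}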

View File

@@ -162,7 +162,8 @@ class FaceDetectionRelative extends Detection {
);
}
String toFaceID({required int fileID}) {
// todo: neeraj review later #localgallery
String toFaceID({required dynamic fileID}) {
// Assert that the values are within the expected range
assert(
(xMinBox >= 0 && xMinBox <= 1) &&

View File

@@ -71,8 +71,8 @@ class FaceRecognitionService {
}
}
static Future<List<FaceResult>> runFacesPipeline(
int enteFileID,
static Future<List<FaceResult>> runFacesPipeline<T>(
T enteFileID,
Image image,
Uint8List rawRgbaBytes,
int faceDetectionAddress,
@@ -131,8 +131,8 @@ class FaceRecognitionService {
}
/// Runs face recognition on the given image data.
static Future<List<FaceDetectionRelative>> _detectFacesSync(
int fileID,
static Future<List<FaceDetectionRelative>> _detectFacesSync<T>(
T fileID,
Image image,
Uint8List rawRgbaBytes,
int interpreterAddress,

View File

@@ -55,11 +55,11 @@ class MLIndexingIsolate extends SuperIsolate {
factory MLIndexingIsolate() => instance;
/// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageStatic] in the isolate.
Future<MLResult?> analyzeImage(
Future<MLResult<T>?> analyzeImage<T>(
FileMLInstruction instruction,
String filePath,
) async {
late MLResult result;
late MLResult<T> result;
try {
final resultJsonString =
@@ -78,7 +78,7 @@ class MLIndexingIsolate extends SuperIsolate {
}
return null;
}
result = MLResult.fromJsonString(resultJsonString);
result = MLResult.fromJsonString<T>(resultJsonString);
} catch (e, s) {
_logger.severe(
"Could not analyze image with ID ${instruction.file.uploadedFileID} \n",

View File

@@ -7,8 +7,8 @@ import 'package:photos/services/machine_learning/face_ml/face_alignment/alignmen
import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
class MLResult {
int fileId;
class MLResult<T> {
T fileId;
List<FaceResult>? faces = <FaceResult>[];
ClipResult? clip;
@@ -20,14 +20,14 @@ class MLResult {
bool get clipRan => clip != null;
MLResult({
this.fileId = -1,
required this.fileId,
this.faces,
this.clip,
this.decodedImageSize = const Dimensions(width: -1, height: -1),
});
MLResult.fromEnteFileID(
fileID, {
T fileID, {
this.decodedImageSize = const Dimensions(width: -1, height: -1),
}) : fileId = fileID;
@@ -43,7 +43,9 @@ class MLResult {
String toJsonString() => jsonEncode(_toJson());
static MLResult _fromJson(Map<String, dynamic> json) {
static MLResult<T> _fromJson<T>(
Map<String, dynamic> json,
) {
return MLResult(
fileId: json['fileId'],
faces: json['faces'] != null
@@ -70,13 +72,13 @@ class MLResult {
);
}
static MLResult fromJsonString(String jsonString) {
return _fromJson(jsonDecode(jsonString));
static MLResult<T> fromJsonString<T>(String jsonString) {
return _fromJson<T>(jsonDecode(jsonString));
}
}
class ClipResult {
final int fileID;
class ClipResult<T> {
final T fileID;
final Embedding embedding;
ClipResult({
@@ -97,12 +99,12 @@ class ClipResult {
}
}
class FaceResult {
class FaceResult<T> {
late FaceDetectionRelative detection;
late double blurValue;
late AlignmentResult alignment;
late Embedding embedding;
late int fileId;
late T fileId;
late String faceId;
bool get isBlurry => blurValue < kLaplacianHardThreshold;
@@ -118,7 +120,7 @@ class FaceResult {
FaceResult.fromFaceDetection(
FaceDetectionRelative faceDetection,
int fileID,
T fileID,
) {
fileId = fileID;
faceId = faceDetection.toFaceID(fileID: fileID);
@@ -146,7 +148,7 @@ class FaceResult {
}
}
T getFileIdFromFaceId<T extends Object>(String faceId) {
T getFileIdFromFaceId<T>(String faceId) {
final String faceIdSplit = faceId.substring(0, faceId.indexOf('_'));
if (T == int) {
return int.parse(faceIdSplit) as T;
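
Two constraints fall out of making MLResult generic: the result crosses the isolate boundary as JSON, so T has to be a JSON-friendly key type (int or String in practice), and getFileIdFromFaceId has to map the parsed string back onto T at runtime. The hunk is cut off after the int branch; a plausible completion of that dispatch, with the String branch purely an assumption prompted by the local-gallery TODO earlier in this commit:

T getFileIdFromFaceId<T>(String faceId) {
  final String faceIdSplit = faceId.substring(0, faceId.indexOf('_'));
  if (T == int) {
    return int.parse(faceIdSplit) as T;
  } else if (T == String) {
    // Assumed branch: local-gallery assets keyed by string IDs.
    return faceIdSplit as T;
  }
  throw ArgumentError("Unsupported file ID type $T");
}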

View File

@@ -10,6 +10,7 @@ import "package:photos/db/files_db.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/events/machine_learning_control_event.dart";
import "package:photos/events/people_changed_event.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/face/face.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/service_locator.dart";
@@ -20,7 +21,6 @@ import "package:photos/services/machine_learning/face_ml/face_clustering/face_db
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/machine_learning/ml_indexing_isolate.dart";
import 'package:photos/services/machine_learning/ml_result.dart';
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
import "package:photos/utils/ml_util.dart";
import "package:photos/utils/network_util.dart";
import "package:photos/utils/ram_check_util.dart";
@@ -428,13 +428,14 @@ class MLService {
}
}
Future<bool> processImage(FileMLInstruction instruction) async {
Future<bool> processImage<T>(FileMLInstruction instruction) async {
bool actuallyRanML = false;
try {
final String filePath = await getImagePathForML(instruction.file);
final MLResult? result = await MLIndexingIsolate.instance.analyzeImage(
final MLResult<T>? result =
await MLIndexingIsolate.instance.analyzeImage<T>(
instruction,
filePath,
);
@@ -502,9 +503,13 @@ class MLService {
// Storing results locally
if (result.facesRan) await mlDataDB.bulkInsertFaces(faces);
if (result.clipRan) {
await SemanticSearchService.instance.storeClipImageResult(
result.clip!,
);
await mlDataDB.putClip([
ClipEmbedding<int>(
fileID: result.clip!.fileID,
embedding: result.clip!.embedding,
version: clipMlVersion,
),
]);
}
_logger.info("ML results for fileID ${result.fileId} stored locally");
return actuallyRanML;
@@ -526,9 +531,9 @@ class MLService {
await mlDataDB.bulkInsertFaces(
[Face.empty(instruction.file.uploadedFileID!, error: true)],
);
await SemanticSearchService.instance.storeEmptyClipImageResult(
instruction.file,
);
await mlDataDB
.putClip([ClipEmbedding.empty(instruction.file.uploadedFileID!)]);
return true;
}
_logger.severe(

View File

@@ -12,7 +12,6 @@ import "package:photos/db/ml/db.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/models/file/file.dart";
import "package:photos/models/ml/clip.dart";
import "package:photos/models/ml/ml_versions.dart";
import "package:photos/service_locator.dart";
import "package:photos/services/collections_service.dart";
import "package:photos/services/machine_learning/ml_computer.dart";
@@ -235,20 +234,6 @@ class SemanticSearchService {
_logger.info("Clip text model loaded");
}
Future<void> storeClipImageResult(ClipResult clipResult) async {
final embedding = ClipEmbedding(
fileID: clipResult.fileID,
embedding: clipResult.embedding,
version: clipMlVersion,
);
await mlDataDB.putClip([embedding]);
}
Future<void> storeEmptyClipImageResult(EnteFile entefile) async {
final embedding = ClipEmbedding.empty(entefile.uploadedFileID!);
await mlDataDB.putClip([embedding]);
}
Future<List<double>> _getTextEmbedding(String query) async {
_logger.info("Searching for ${kDebugMode ? query : ''}");
final cachedResult = _queryEmbeddingCache.get(query);
@@ -292,8 +277,8 @@ class SemanticSearchService {
});
}
static Future<ClipResult> runClipImage(
int enteFileID,
static Future<ClipResult<T>> runClipImage<T>(
T enteFileID,
Image image,
Uint8List rawRgbaBytes,
int clipImageAddress,
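
With storeClipImageResult and storeEmptyClipImageResult deleted, callers persist clip output straight to the DB, as the MLService hunk above shows. A minimal sketch of that replacement path, assuming an int-keyed database handle (names as they appear elsewhere in this diff):

Future<void> persistClip(IMLDataDB<int> db, ClipResult<int> clipResult) async {
  await db.putClip([
    ClipEmbedding<int>(
      fileID: clipResult.fileID,
      embedding: clipResult.embedding,
      version: clipMlVersion,
    ),
  ]);
}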

View File

@@ -221,7 +221,7 @@ Stream<List<FileMLInstruction>> fetchEmbeddingsAndInstructions(
fileMl.getClipEmbeddingIfCompatible(clipMlVersion);
if (remoteClipEmbedding != null) {
clipEmbeddings.add(
ClipEmbedding(
ClipEmbedding<int>(
fileID: fileMl.fileID,
embedding: remoteClipEmbedding.embedding,
version: remoteClipEmbedding.version,
@@ -395,9 +395,9 @@ void normalizeEmbedding(List<double> embedding) {
}
}
Future<MLResult> analyzeImageStatic(Map args) async {
Future<MLResult<T>> analyzeImageStatic<T>(Map args) async {
try {
final int enteFileID = args["enteFileID"] as int;
final T enteFileID = args["enteFileID"] as T;
final String imagePath = args["filePath"] as String;
final bool runFaces = args["runFaces"] as bool;
final bool runClip = args["runClip"] as bool;
@@ -422,7 +422,7 @@ Future<MLResult> analyzeImageStatic(Map args) async {
String faceMsString = "", clipMsString = "";
final pipelines = await Future.wait([
runFaces
? FaceRecognitionService.runFacesPipeline(
? FaceRecognitionService.runFacesPipeline<T>(
enteFileID,
image,
rawRgbaBytes,
@@ -435,7 +435,7 @@ Future<MLResult> analyzeImageStatic(Map args) async {
})
: Future.value(null),
runClip
? SemanticSearchService.runClipImage(
? SemanticSearchService.runClipImage<T>(
enteFileID,
image,
rawRgbaBytes,
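
analyzeImageStatic now pulls the ID back out of the args map with an as T cast, so the caller must put a value that already has that exact type into the map. A sketch of the calling side, assuming T = int; only the map keys visible in this hunk are shown, and the real map almost certainly carries more entries (model addresses, versions) that fall outside it:

import "package:photos/utils/ml_util.dart";

Future<void> indexOneFile() async {
  final args = <String, dynamic>{
    "enteFileID": 1234,           // must already be an int when T == int
    "filePath": "/tmp/photo.jpg", // placeholder path
    "runFaces": true,
    "runClip": true,
  };
  // Result is an MLResult<int>, keyed by the same file ID.
  final result = await analyzeImageStatic<int>(args);
  print("indexed file ${result.fileId}");
}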