Make ML-related classes generic

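In brief: the commit threads a file-ID type parameter T through the ML data
classes (ClipEmbedding, Face, MLResult, ClipResult, FaceResult) and the
pipeline entry points that produce them, so these records are no longer
hard-wired to the server's int uploadedFileID. The already-generic
IMLDataDB<T> interface and the "#localgallery" TODO below suggest the goal is
to later key the same records by non-integer local-gallery IDs. A
self-contained sketch of the pattern (class body abbreviated relative to the
real one; the string key is a hypothetical local-gallery ID):

    class ClipEmbedding<T> {
      final T fileID;
      final List<double> embedding;
      int version;

      ClipEmbedding({
        required this.fileID,
        required this.embedding,
        required this.version,
      });
    }

    void main() {
      // Today's call sites pin T to int, the server's uploadedFileID...
      final remote =
          ClipEmbedding<int>(fileID: 42, embedding: [0.1], version: 1);
      // ...but any key type now works, e.g. a local-gallery path.
      final local = ClipEmbedding<String>(
        fileID: "local/IMG_0001",
        embedding: [0.1],
        version: 1,
      );
      print("${remote.fileID} ${local.fileID}");
    }
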
@@ -112,7 +112,7 @@ abstract class IMLDataDB<T> {
   Future<List<EmbeddingVector>> getAllClipVectors();
   Future<Map<int, int>> clipIndexedFileWithVersion();
   Future<int> getClipIndexedFileCount({int minimumMlVersion});
-  Future<void> putClip(List<ClipEmbedding> embeddings);
+  Future<void> putClip(List<ClipEmbedding<T>> embeddings);
   Future<void> deleteClipEmbeddings(List<T> fileIDs);
   Future<void> deleteClipIndexes();
 }

@@ -2,8 +2,8 @@ import "dart:convert";
 
 import "package:photos/models/ml/ml_versions.dart";
 
-class ClipEmbedding {
-  final int fileID;
+class ClipEmbedding<T> {
+  final T fileID;
   final List<double> embedding;
   int version;
 
@@ -15,7 +15,7 @@ class ClipEmbedding {
     required this.version,
   });
 
-  factory ClipEmbedding.empty(int fileID) {
+  factory ClipEmbedding.empty(T fileID) {
     return ClipEmbedding(
       fileID: fileID,
       embedding: <double>[],

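Note: unlike Face below, ClipEmbedding.empty can stay a factory constructor,
because a factory may freely use the class-level T that is already in scope.
Usage then looks like this (a minimal sketch; the string key is hypothetical):

    final e = ClipEmbedding<String>.empty("local/IMG_0001");
    assert(e.embedding.isEmpty);
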
@@ -16,7 +16,7 @@ class FileInfo {
   });
 }
 
-class Face {
+class Face<T> {
   final String faceID;
   final List<double> embedding;
   Detection detection;
@@ -26,7 +26,7 @@ class Face {
   ///#region Local DB fields
   // This is not stored on the server, using it for local DB row
   FileInfo? fileInfo;
-  final int fileID;
+  final T fileID;
 
   ///#endregion
 
@@ -48,7 +48,7 @@ class Face {
 
   factory Face.fromFaceResult(
     FaceResult faceResult,
-    int fileID,
+    T fileID,
     Dimensions decodedDimensions,
   ) {
     final detection = Detection(
@@ -81,7 +81,7 @@ class Face {
     );
   }
 
-  factory Face.empty(int fileID, {bool error = false}) {
+  static Face<T> empty<T>(T fileID, {bool error = false}) {
     return Face(
       "${fileID}_0_0_0_0",
       fileID,
@@ -92,9 +92,9 @@ class Face {
     );
   }
 
-  factory Face.fromJson(Map<String, dynamic> json) {
+  static Face<T> fromJson<T>(Map<String, dynamic> json) {
     final String faceID = json['faceID'] as String;
-    final int fileID = getFileIdFromFaceId<int>(faceID);
+    final T fileID = getFileIdFromFaceId(faceID);
     return Face(
       faceID,
       fileID,

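Note: Face.empty and Face.fromJson switch from factory constructors to static
generic methods. A Dart factory constructor cannot declare type parameters of
its own, so the static form is what lets each method carry an independent <T>.
For empty, T is inferred from the argument; for fromJson, nothing in the
parameters mentions T, so callers must supply it (or let the assignment
context infer it, as the `final T fileID = ...` line above does internally).
A sketch of the inference, assuming the class above:

    void main() {
      final a = Face.empty(42);          // inferred as Face<int>
      final b = Face.empty("local_abc"); // inferred as Face<String>
    }
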
@@ -162,7 +162,8 @@ class FaceDetectionRelative extends Detection {
     );
   }
 
-  String toFaceID({required int fileID}) {
+  // todo: neeraj review later #localgallery
+  String toFaceID({required dynamic fileID}) {
     // Assert that the values are within the expected range
     assert(
       (xMinBox >= 0 && xMinBox <= 1) &&

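Note: toFaceID takes the ID as dynamic rather than a type parameter, and the
TODO flags that shortcut for review. The diff cuts the body off after the
asserts, but given the "<fileID>_0_0_0_0" shape used by Face.empty above, the
ID is presumably only ever interpolated into a string, where any type with a
sensible toString() works. A standalone illustration of that pattern:

    String tag(dynamic fileID) => "${fileID}_0_0_0_0";

    void main() {
      print(tag(42));          // 42_0_0_0_0
      print(tag("local_abc")); // local_abc_0_0_0_0
    }
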
@@ -71,8 +71,8 @@ class FaceRecognitionService {
     }
   }
 
-  static Future<List<FaceResult>> runFacesPipeline(
-    int enteFileID,
+  static Future<List<FaceResult>> runFacesPipeline<T>(
+    T enteFileID,
     Image image,
     Uint8List rawRgbaBytes,
     int faceDetectionAddress,
@@ -131,8 +131,8 @@ class FaceRecognitionService {
   }
 
   /// Runs face recognition on the given image data.
-  static Future<List<FaceDetectionRelative>> _detectFacesSync(
-    int fileID,
+  static Future<List<FaceDetectionRelative>> _detectFacesSync<T>(
+    T fileID,
     Image image,
     Uint8List rawRgbaBytes,
     int interpreterAddress,

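Note: only the incoming ID is parameterized here; the declared return type
stays the raw FaceResult even though FaceResult gains a <T> below, and a raw
generic type in Dart instantiates to <dynamic>. A self-contained reminder of
that rule:

    class Box<T> {
      final T value;
      Box(this.value);
    }

    void main() {
      final List<Box> raw = [Box<int>(1), Box<String>("a")];
      print(raw.runtimeType); // List<Box<dynamic>>
    }
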
@@ -55,11 +55,11 @@ class MLIndexingIsolate extends SuperIsolate {
   factory MLIndexingIsolate() => instance;
 
   /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageStatic] in the isolate.
-  Future<MLResult?> analyzeImage(
+  Future<MLResult<T>?> analyzeImage<T>(
     FileMLInstruction instruction,
     String filePath,
   ) async {
-    late MLResult result;
+    late MLResult<T> result;
 
     try {
       final resultJsonString =
@@ -78,7 +78,7 @@ class MLIndexingIsolate extends SuperIsolate {
         }
         return null;
       }
-      result = MLResult.fromJsonString(resultJsonString);
+      result = MLResult.fromJsonString<T>(resultJsonString);
     } catch (e, s) {
       _logger.severe(
         "Could not analyze image with ID ${instruction.file.uploadedFileID} \n",

@@ -7,8 +7,8 @@ import 'package:photos/services/machine_learning/face_ml/face_alignment/alignmen
 import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
 import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
 
-class MLResult {
-  int fileId;
+class MLResult<T> {
+  T fileId;
 
   List<FaceResult>? faces = <FaceResult>[];
   ClipResult? clip;
@@ -20,14 +20,14 @@ class MLResult {
   bool get clipRan => clip != null;
 
   MLResult({
-    this.fileId = -1,
+    required this.fileId,
     this.faces,
     this.clip,
     this.decodedImageSize = const Dimensions(width: -1, height: -1),
   });
 
   MLResult.fromEnteFileID(
-    fileID, {
+    T fileID, {
     this.decodedImageSize = const Dimensions(width: -1, height: -1),
   }) : fileId = fileID;
 
@@ -43,7 +43,9 @@ class MLResult {
 
   String toJsonString() => jsonEncode(_toJson());
 
-  static MLResult _fromJson(Map<String, dynamic> json) {
+  static MLResult<T> _fromJson<T>(
+    Map<String, dynamic> json,
+  ) {
     return MLResult(
       fileId: json['fileId'],
       faces: json['faces'] != null
@@ -70,13 +72,13 @@ class MLResult {
     );
   }
 
-  static MLResult fromJsonString(String jsonString) {
-    return _fromJson(jsonDecode(jsonString));
+  static MLResult<T> fromJsonString<T>(String jsonString) {
+    return _fromJson<T>(jsonDecode(jsonString));
   }
 }
 
-class ClipResult {
-  final int fileID;
+class ClipResult<T> {
+  final T fileID;
   final Embedding embedding;
 
   ClipResult({
@@ -97,12 +99,12 @@ class ClipResult {
   }
 }
 
-class FaceResult {
+class FaceResult<T> {
   late FaceDetectionRelative detection;
   late double blurValue;
   late AlignmentResult alignment;
   late Embedding embedding;
-  late int fileId;
+  late T fileId;
   late String faceId;
 
   bool get isBlurry => blurValue < kLaplacianHardThreshold;
@@ -118,7 +120,7 @@ class FaceResult {
 
   FaceResult.fromFaceDetection(
     FaceDetectionRelative faceDetection,
-    int fileID,
+    T fileID,
   ) {
     fileId = fileID;
     faceId = faceDetection.toFaceID(fileID: fileID);
@@ -146,7 +148,7 @@ class FaceResult {
   }
 }
 
-T getFileIdFromFaceId<T extends Object>(String faceId) {
+T getFileIdFromFaceId<T>(String faceId) {
   final String faceIdSplit = faceId.substring(0, faceId.indexOf('_'));
   if (T == int) {
     return int.parse(faceIdSplit) as T;

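Two details in this file are easy to miss. First, _fromJson<T> assigns
fileId: json['fileId'] via an implicit cast from dynamic, so decoding only
succeeds when the JSON value's runtime type matches T; int and String both
survive a jsonEncode/jsonDecode round-trip unchanged, which is what makes the
generic fromJsonString<T> workable across the isolate boundary. Second,
getFileIdFromFaceId<T> dispatches on the type argument at runtime, and
dropping the `extends Object` bound also admits nullable type arguments. The
diff truncates the function after the int branch; a sketch of its likely
shape (the non-int fallback is an assumption):

    T getFileIdFromFaceId<T>(String faceId) {
      // Everything before the first '_' is the encoded file ID.
      final String faceIdSplit = faceId.substring(0, faceId.indexOf('_'));
      if (T == int) {
        return int.parse(faceIdSplit) as T;
      }
      // Assumed remainder, not shown in the diff: other callers get
      // the raw string segment back.
      return faceIdSplit as T;
    }
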
@@ -10,6 +10,7 @@ import "package:photos/db/files_db.dart";
 import "package:photos/db/ml/db.dart";
 import "package:photos/events/machine_learning_control_event.dart";
 import "package:photos/events/people_changed_event.dart";
+import "package:photos/models/ml/clip.dart";
 import "package:photos/models/ml/face/face.dart";
 import "package:photos/models/ml/ml_versions.dart";
 import "package:photos/service_locator.dart";
@@ -20,7 +21,6 @@ import "package:photos/services/machine_learning/face_ml/face_clustering/face_db
 import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
 import "package:photos/services/machine_learning/ml_indexing_isolate.dart";
 import 'package:photos/services/machine_learning/ml_result.dart';
-import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
 import "package:photos/utils/ml_util.dart";
 import "package:photos/utils/network_util.dart";
 import "package:photos/utils/ram_check_util.dart";
@@ -428,13 +428,14 @@ class MLService {
     }
   }
 
-  Future<bool> processImage(FileMLInstruction instruction) async {
+  Future<bool> processImage<T>(FileMLInstruction instruction) async {
     bool actuallyRanML = false;
 
     try {
       final String filePath = await getImagePathForML(instruction.file);
 
-      final MLResult? result = await MLIndexingIsolate.instance.analyzeImage(
+      final MLResult<T>? result =
+          await MLIndexingIsolate.instance.analyzeImage<T>(
         instruction,
         filePath,
       );
@@ -502,9 +503,13 @@ class MLService {
       // Storing results locally
       if (result.facesRan) await mlDataDB.bulkInsertFaces(faces);
       if (result.clipRan) {
-        await SemanticSearchService.instance.storeClipImageResult(
-          result.clip!,
-        );
+        await mlDataDB.putClip([
+          ClipEmbedding<int>(
+            fileID: result.clip!.fileID,
+            embedding: result.clip!.embedding,
+            version: clipMlVersion,
+          ),
+        ]);
       }
       _logger.info("ML results for fileID ${result.fileId} stored locally");
       return actuallyRanML;
@@ -526,9 +531,9 @@ class MLService {
       await mlDataDB.bulkInsertFaces(
         [Face.empty(instruction.file.uploadedFileID!, error: true)],
       );
-      await SemanticSearchService.instance.storeEmptyClipImageResult(
-        instruction.file,
-      );
+
+      await mlDataDB
+          .putClip([ClipEmbedding.empty(instruction.file.uploadedFileID!)]);
       return true;
     }
     _logger.severe(

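Note: processImage<T> is generic, but the body still pins the ID to the
server type on both paths: the success path writes ClipEmbedding<int> from
result.clip!.fileID, and the error path uses instruction.file.uploadedFileID!.
So T is effectively int at every current call site, roughly (a sketch; how
the MLService instance is obtained is an assumption, e.g. via the service
locator imported above):

    final bool ranML = await mlService.processImage<int>(instruction);
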
@@ -12,7 +12,6 @@ import "package:photos/db/ml/db.dart";
 import 'package:photos/events/embedding_updated_event.dart';
 import "package:photos/models/file/file.dart";
-import "package:photos/models/ml/clip.dart";
 import "package:photos/models/ml/ml_versions.dart";
 import "package:photos/service_locator.dart";
 import "package:photos/services/collections_service.dart";
 import "package:photos/services/machine_learning/ml_computer.dart";
@@ -235,20 +234,6 @@ class SemanticSearchService {
     _logger.info("Clip text model loaded");
   }
 
-  Future<void> storeClipImageResult(ClipResult clipResult) async {
-    final embedding = ClipEmbedding(
-      fileID: clipResult.fileID,
-      embedding: clipResult.embedding,
-      version: clipMlVersion,
-    );
-    await mlDataDB.putClip([embedding]);
-  }
-
-  Future<void> storeEmptyClipImageResult(EnteFile entefile) async {
-    final embedding = ClipEmbedding.empty(entefile.uploadedFileID!);
-    await mlDataDB.putClip([embedding]);
-  }
-
   Future<List<double>> _getTextEmbedding(String query) async {
     _logger.info("Searching for ${kDebugMode ? query : ''}");
     final cachedResult = _queryEmbeddingCache.get(query);
@@ -292,8 +277,8 @@ class SemanticSearchService {
     });
   }
 
-  static Future<ClipResult> runClipImage(
-    int enteFileID,
+  static Future<ClipResult<T>> runClipImage<T>(
+    T enteFileID,
     Image image,
     Uint8List rawRgbaBytes,
     int clipImageAddress,

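Note: the two store helpers are dropped because MLService now writes the rows
directly (the putClip calls in the previous file), which is also why the
clip.dart import moves from this file to ml_service.dart. Restated as a
function, the relocated logic is (hypothetical name; body taken from the
MLService hunk above):

    Future<void> storeClip(ClipResult<int> clipResult) async {
      await mlDataDB.putClip([
        ClipEmbedding<int>(
          fileID: clipResult.fileID,
          embedding: clipResult.embedding,
          version: clipMlVersion,
        ),
      ]);
    }
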
@@ -221,7 +221,7 @@ Stream<List<FileMLInstruction>> fetchEmbeddingsAndInstructions(
           fileMl.getClipEmbeddingIfCompatible(clipMlVersion);
       if (remoteClipEmbedding != null) {
         clipEmbeddings.add(
-          ClipEmbedding(
+          ClipEmbedding<int>(
             fileID: fileMl.fileID,
             embedding: remoteClipEmbedding.embedding,
             version: remoteClipEmbedding.version,
@@ -395,9 +395,9 @@ void normalizeEmbedding(List<double> embedding) {
   }
 }
 
-Future<MLResult> analyzeImageStatic(Map args) async {
+Future<MLResult<T>> analyzeImageStatic<T>(Map args) async {
   try {
-    final int enteFileID = args["enteFileID"] as int;
+    final T enteFileID = args["enteFileID"] as T;
     final String imagePath = args["filePath"] as String;
     final bool runFaces = args["runFaces"] as bool;
     final bool runClip = args["runClip"] as bool;
@@ -422,7 +422,7 @@ Future<MLResult> analyzeImageStatic(Map args) async {
   String faceMsString = "", clipMsString = "";
   final pipelines = await Future.wait([
     runFaces
-        ? FaceRecognitionService.runFacesPipeline(
+        ? FaceRecognitionService.runFacesPipeline<T>(
            enteFileID,
            image,
            rawRgbaBytes,
@@ -435,7 +435,7 @@ Future<MLResult> analyzeImageStatic(Map args) async {
         })
         : Future.value(null),
     runClip
-        ? SemanticSearchService.runClipImage(
+        ? SemanticSearchService.runClipImage<T>(
            enteFileID,
            image,
            rawRgbaBytes,

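Note: analyzeImageStatic<T> runs behind an isolate boundary and receives its
inputs as an untyped Map, so `args["enteFileID"] as T` is an unchecked
contract with the sender: the map must be built with the same ID type the
receiver instantiates. A sketch of the sending side under that assumption
(key names taken from the reads above; the real map presumably also carries
the model addresses used further down):

    final args = <String, dynamic>{
      // Must match the T used when calling analyzeImageStatic<T>.
      "enteFileID": 42,
      "filePath": "/tmp/img.jpg",
      "runFaces": true,
      "runClip": true,
    };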