[mob][photos] Remove any model differentiation between onnx and ggml
This commit is contained in:
@@ -16,7 +16,6 @@ class EmbeddingsDB {
|
||||
static const databaseName = "ente.embeddings.db";
|
||||
static const tableName = "embeddings";
|
||||
static const columnFileID = "file_id";
|
||||
static const columnModel = "model";
|
||||
static const columnEmbedding = "embedding";
|
||||
static const columnUpdationTime = "updation_time";
|
||||
|
||||
@@ -42,7 +41,7 @@ class EmbeddingsDB {
|
||||
1,
|
||||
(tx) async {
|
||||
await tx.execute(
|
||||
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnModel INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID, $columnModel))',
|
||||
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))',
|
||||
);
|
||||
},
|
||||
),
|
||||
@@ -57,19 +56,16 @@ class EmbeddingsDB {
|
||||
await db.execute('DELETE FROM $tableName');
|
||||
}
|
||||
|
||||
Future<List<Embedding>> getAll(Model model) async {
|
||||
Future<List<Embedding>> getAll() async {
|
||||
final db = await _database;
|
||||
final results = await db.getAll('SELECT * FROM $tableName');
|
||||
return _convertToEmbeddings(results);
|
||||
}
|
||||
|
||||
// Get FileIDs for a specific model
|
||||
Future<Set<int>> getFileIDs(Model model) async {
|
||||
Future<Set<int>> getFileIDs() async {
|
||||
final db = await _database;
|
||||
final results = await db.getAll(
|
||||
'SELECT $columnFileID FROM $tableName WHERE $columnModel = ?',
|
||||
[modelToInt(model)!],
|
||||
);
|
||||
final results = await db.getAll('SELECT $columnFileID FROM $tableName');
|
||||
if (results.isEmpty) {
|
||||
return <int>{};
|
||||
}
|
||||
@@ -79,7 +75,7 @@ class EmbeddingsDB {
|
||||
Future<void> put(Embedding embedding) async {
|
||||
final db = await _database;
|
||||
await db.execute(
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)',
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)',
|
||||
_getRowFromEmbedding(embedding),
|
||||
);
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
@@ -89,7 +85,7 @@ class EmbeddingsDB {
|
||||
final db = await _database;
|
||||
final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList();
|
||||
await db.executeBatch(
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)',
|
||||
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)',
|
||||
inputs,
|
||||
);
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
@@ -111,12 +107,9 @@ class EmbeddingsDB {
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
}
|
||||
|
||||
Future<void> deleteAllForModel(Model model) async {
|
||||
Future<void> deleteAll() async {
|
||||
final db = await _database;
|
||||
await db.execute(
|
||||
'DELETE FROM $tableName WHERE $columnModel = ?',
|
||||
[modelToInt(model)!],
|
||||
);
|
||||
await db.execute('DELETE FROM $tableName');
|
||||
Bus.instance.fire(EmbeddingUpdatedEvent());
|
||||
}
|
||||
|
||||
@@ -132,16 +125,14 @@ class EmbeddingsDB {
|
||||
|
||||
Embedding _getEmbeddingFromRow(Map<String, dynamic> row) {
|
||||
final fileID = row[columnFileID];
|
||||
final model = intToModel(row[columnModel])!;
|
||||
final bytes = row[columnEmbedding] as Uint8List;
|
||||
final list = Float32List.view(bytes.buffer);
|
||||
return Embedding(fileID: fileID, model: model, embedding: list);
|
||||
return Embedding(fileID: fileID, embedding: list);
|
||||
}
|
||||
|
||||
List<Object?> _getRowFromEmbedding(Embedding embedding) {
|
||||
return [
|
||||
embedding.fileID,
|
||||
modelToInt(embedding.model)!,
|
||||
Float32List.fromList(embedding.embedding).buffer.asUint8List(),
|
||||
embedding.updationTime,
|
||||
];
|
||||
@@ -157,26 +148,4 @@ class EmbeddingsDB {
|
||||
await deprecatedIsar.delete();
|
||||
}
|
||||
}
|
||||
|
||||
int? modelToInt(Model model) {
|
||||
switch (model) {
|
||||
case Model.onnxClip:
|
||||
return 1;
|
||||
case Model.ggmlClip:
|
||||
return 2;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
Model? intToModel(int model) {
|
||||
switch (model) {
|
||||
case 1:
|
||||
return Model.onnxClip;
|
||||
case 2:
|
||||
return Model.ggmlClip;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ import "dart:convert";
|
||||
|
||||
class Embedding {
|
||||
final int fileID;
|
||||
final Model model;
|
||||
final List<double> embedding;
|
||||
int? updationTime;
|
||||
|
||||
@@ -10,15 +9,13 @@ class Embedding {
|
||||
|
||||
Embedding({
|
||||
required this.fileID,
|
||||
required this.model,
|
||||
required this.embedding,
|
||||
this.updationTime,
|
||||
});
|
||||
|
||||
factory Embedding.empty(int fileID, Model model) {
|
||||
factory Embedding.empty(int fileID) {
|
||||
return Embedding(
|
||||
fileID: fileID,
|
||||
model: model,
|
||||
embedding: <double>[],
|
||||
);
|
||||
}
|
||||
@@ -31,34 +28,3 @@ class Embedding {
|
||||
return jsonEncode(embedding);
|
||||
}
|
||||
}
|
||||
|
||||
enum Model {
|
||||
onnxClip,
|
||||
ggmlClip,
|
||||
}
|
||||
|
||||
extension ModelExtension on Model {
|
||||
String get name => serialize(this);
|
||||
}
|
||||
|
||||
String serialize(Model model) {
|
||||
switch (model) {
|
||||
case Model.onnxClip:
|
||||
return 'onnx-clip';
|
||||
case Model.ggmlClip:
|
||||
return 'ggml-clip';
|
||||
default:
|
||||
throw Exception('$model is not a valid Model');
|
||||
}
|
||||
}
|
||||
|
||||
Model deserialize(String model) {
|
||||
switch (model) {
|
||||
case 'onnx-clip':
|
||||
return Model.onnxClip;
|
||||
case 'ggml-clip':
|
||||
return Model.ggmlClip;
|
||||
default:
|
||||
throw Exception('$model is not a valid Model');
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,17 +33,16 @@ class EmbeddingStore {
|
||||
_preferences = await SharedPreferences.getInstance();
|
||||
}
|
||||
|
||||
Future<bool> pullEmbeddings(Model model) async {
|
||||
return true; // TODO: remove this
|
||||
Future<bool> pullEmbeddings() async {
|
||||
if (_remoteSyncStatus != null) {
|
||||
return _remoteSyncStatus!.future;
|
||||
}
|
||||
_remoteSyncStatus = Completer();
|
||||
try {
|
||||
var remoteEmbeddings = await _getRemoteEmbeddings(model);
|
||||
var remoteEmbeddings = await _getRemoteEmbeddings();
|
||||
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
|
||||
while (remoteEmbeddings.hasMore) {
|
||||
remoteEmbeddings = await _getRemoteEmbeddings(model);
|
||||
remoteEmbeddings = await _getRemoteEmbeddings();
|
||||
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
|
||||
}
|
||||
_remoteSyncStatus!.complete(true);
|
||||
@@ -85,8 +84,8 @@ class EmbeddingStore {
|
||||
unawaited(_pushEmbedding(file, embedding));
|
||||
}
|
||||
|
||||
Future<void> clearEmbeddings(Model model) async {
|
||||
await EmbeddingsDB.instance.deleteAllForModel(model);
|
||||
Future<void> clearEmbeddings() async {
|
||||
await EmbeddingsDB.instance.deleteAll();
|
||||
await _preferences.remove(kEmbeddingsSyncTimeKey);
|
||||
}
|
||||
|
||||
@@ -106,7 +105,6 @@ class EmbeddingStore {
|
||||
"/embeddings",
|
||||
data: {
|
||||
"fileID": embedding.fileID,
|
||||
"model": embedding.model.name,
|
||||
"encryptedEmbedding": encryptedData,
|
||||
"decryptionHeader": header,
|
||||
},
|
||||
@@ -119,10 +117,7 @@ class EmbeddingStore {
|
||||
}
|
||||
}
|
||||
|
||||
Future<RemoteEmbeddings> _getRemoteEmbeddings(
|
||||
Model model, {
|
||||
int limit = 200,
|
||||
}) async {
|
||||
Future<RemoteEmbeddings> _getRemoteEmbeddings({int limit = 200}) async {
|
||||
final remoteEmbeddings = <RemoteEmbedding>[];
|
||||
try {
|
||||
final sinceTime = _preferences.getInt(kEmbeddingsSyncTimeKey) ?? 0;
|
||||
@@ -130,7 +125,6 @@ class EmbeddingStore {
|
||||
final response = await _dio.get(
|
||||
"/embeddings/diff",
|
||||
queryParameters: {
|
||||
"model": model.name,
|
||||
"sinceTime": sinceTime,
|
||||
"limit": limit,
|
||||
},
|
||||
@@ -212,7 +206,6 @@ Future<List<Embedding>> _decodeEmbeddings(Map<String, dynamic> args) async {
|
||||
embeddings.add(
|
||||
Embedding(
|
||||
fileID: input.embedding.fileID,
|
||||
model: deserialize(input.embedding.model),
|
||||
embedding: decodedEmbedding,
|
||||
updationTime: input.embedding.updatedAt,
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user