[mob][photos] Remove any model differentiation between onnx and ggml

This commit is contained in:
laurenspriem
2024-07-03 11:20:53 +05:30
parent 2d0cadc8c9
commit 523d3fa201
3 changed files with 16 additions and 88 deletions

View File

@@ -16,7 +16,6 @@ class EmbeddingsDB {
static const databaseName = "ente.embeddings.db";
static const tableName = "embeddings";
static const columnFileID = "file_id";
static const columnModel = "model";
static const columnEmbedding = "embedding";
static const columnUpdationTime = "updation_time";
@@ -42,7 +41,7 @@ class EmbeddingsDB {
1,
(tx) async {
await tx.execute(
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnModel INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID, $columnModel))',
'CREATE TABLE $tableName ($columnFileID INTEGER NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID))',
);
},
),
@@ -57,19 +56,16 @@ class EmbeddingsDB {
await db.execute('DELETE FROM $tableName');
}
Future<List<Embedding>> getAll(Model model) async {
Future<List<Embedding>> getAll() async {
final db = await _database;
final results = await db.getAll('SELECT * FROM $tableName');
return _convertToEmbeddings(results);
}
// Get FileIDs for a specific model
Future<Set<int>> getFileIDs(Model model) async {
Future<Set<int>> getFileIDs() async {
final db = await _database;
final results = await db.getAll(
'SELECT $columnFileID FROM $tableName WHERE $columnModel = ?',
[modelToInt(model)!],
);
final results = await db.getAll('SELECT $columnFileID FROM $tableName');
if (results.isEmpty) {
return <int>{};
}
@@ -79,7 +75,7 @@ class EmbeddingsDB {
Future<void> put(Embedding embedding) async {
final db = await _database;
await db.execute(
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)',
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) VALUES (?, ?, ?, ?)',
_getRowFromEmbedding(embedding),
);
Bus.instance.fire(EmbeddingUpdatedEvent());
@@ -89,7 +85,7 @@ class EmbeddingsDB {
final db = await _database;
final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList();
await db.executeBatch(
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)',
'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnEmbedding, $columnUpdationTime) values(?, ?, ?, ?)',
inputs,
);
Bus.instance.fire(EmbeddingUpdatedEvent());
@@ -111,12 +107,9 @@ class EmbeddingsDB {
Bus.instance.fire(EmbeddingUpdatedEvent());
}
Future<void> deleteAllForModel(Model model) async {
Future<void> deleteAll() async {
final db = await _database;
await db.execute(
'DELETE FROM $tableName WHERE $columnModel = ?',
[modelToInt(model)!],
);
await db.execute('DELETE FROM $tableName');
Bus.instance.fire(EmbeddingUpdatedEvent());
}
@@ -132,16 +125,14 @@ class EmbeddingsDB {
Embedding _getEmbeddingFromRow(Map<String, dynamic> row) {
final fileID = row[columnFileID];
final model = intToModel(row[columnModel])!;
final bytes = row[columnEmbedding] as Uint8List;
final list = Float32List.view(bytes.buffer);
return Embedding(fileID: fileID, model: model, embedding: list);
return Embedding(fileID: fileID, embedding: list);
}
List<Object?> _getRowFromEmbedding(Embedding embedding) {
return [
embedding.fileID,
modelToInt(embedding.model)!,
Float32List.fromList(embedding.embedding).buffer.asUint8List(),
embedding.updationTime,
];
@@ -157,26 +148,4 @@ class EmbeddingsDB {
await deprecatedIsar.delete();
}
}
int? modelToInt(Model model) {
switch (model) {
case Model.onnxClip:
return 1;
case Model.ggmlClip:
return 2;
default:
return null;
}
}
Model? intToModel(int model) {
switch (model) {
case 1:
return Model.onnxClip;
case 2:
return Model.ggmlClip;
default:
return null;
}
}
}

View File

@@ -2,7 +2,6 @@ import "dart:convert";
class Embedding {
final int fileID;
final Model model;
final List<double> embedding;
int? updationTime;
@@ -10,15 +9,13 @@ class Embedding {
Embedding({
required this.fileID,
required this.model,
required this.embedding,
this.updationTime,
});
factory Embedding.empty(int fileID, Model model) {
factory Embedding.empty(int fileID) {
return Embedding(
fileID: fileID,
model: model,
embedding: <double>[],
);
}
@@ -31,34 +28,3 @@ class Embedding {
return jsonEncode(embedding);
}
}
enum Model {
onnxClip,
ggmlClip,
}
extension ModelExtension on Model {
String get name => serialize(this);
}
String serialize(Model model) {
switch (model) {
case Model.onnxClip:
return 'onnx-clip';
case Model.ggmlClip:
return 'ggml-clip';
default:
throw Exception('$model is not a valid Model');
}
}
Model deserialize(String model) {
switch (model) {
case 'onnx-clip':
return Model.onnxClip;
case 'ggml-clip':
return Model.ggmlClip;
default:
throw Exception('$model is not a valid Model');
}
}

View File

@@ -33,17 +33,16 @@ class EmbeddingStore {
_preferences = await SharedPreferences.getInstance();
}
Future<bool> pullEmbeddings(Model model) async {
return true; // TODO: remove this
Future<bool> pullEmbeddings() async {
if (_remoteSyncStatus != null) {
return _remoteSyncStatus!.future;
}
_remoteSyncStatus = Completer();
try {
var remoteEmbeddings = await _getRemoteEmbeddings(model);
var remoteEmbeddings = await _getRemoteEmbeddings();
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
while (remoteEmbeddings.hasMore) {
remoteEmbeddings = await _getRemoteEmbeddings(model);
remoteEmbeddings = await _getRemoteEmbeddings();
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
}
_remoteSyncStatus!.complete(true);
@@ -85,8 +84,8 @@ class EmbeddingStore {
unawaited(_pushEmbedding(file, embedding));
}
Future<void> clearEmbeddings(Model model) async {
await EmbeddingsDB.instance.deleteAllForModel(model);
Future<void> clearEmbeddings() async {
await EmbeddingsDB.instance.deleteAll();
await _preferences.remove(kEmbeddingsSyncTimeKey);
}
@@ -106,7 +105,6 @@ class EmbeddingStore {
"/embeddings",
data: {
"fileID": embedding.fileID,
"model": embedding.model.name,
"encryptedEmbedding": encryptedData,
"decryptionHeader": header,
},
@@ -119,10 +117,7 @@ class EmbeddingStore {
}
}
Future<RemoteEmbeddings> _getRemoteEmbeddings(
Model model, {
int limit = 200,
}) async {
Future<RemoteEmbeddings> _getRemoteEmbeddings({int limit = 200}) async {
final remoteEmbeddings = <RemoteEmbedding>[];
try {
final sinceTime = _preferences.getInt(kEmbeddingsSyncTimeKey) ?? 0;
@@ -130,7 +125,6 @@ class EmbeddingStore {
final response = await _dio.get(
"/embeddings/diff",
queryParameters: {
"model": model.name,
"sinceTime": sinceTime,
"limit": limit,
},
@@ -212,7 +206,6 @@ Future<List<Embedding>> _decodeEmbeddings(Map<String, dynamic> args) async {
embeddings.add(
Embedding(
fileID: input.embedding.fileID,
model: deserialize(input.embedding.model),
embedding: decodedEmbedding,
updationTime: input.embedding.updatedAt,
),