538 lines
15 KiB
TypeScript
538 lines
15 KiB
TypeScript
import log from "@/next/log";
|
|
import { CACHES } from "@ente/shared/storage/cacheStorage/constants";
|
|
import { cached } from "@ente/shared/storage/cacheStorage/helpers";
|
|
import * as tf from "@tensorflow/tfjs-core";
|
|
import { NormalizedFace } from "blazeface-back";
|
|
import { FILE_TYPE } from "constants/file";
|
|
import { BLAZEFACE_FACE_SIZE } from "constants/mlConfig";
|
|
import { euclidean } from "hdbscan";
|
|
import PQueue from "p-queue";
|
|
import DownloadManager from "services/download";
|
|
import { getLocalFiles } from "services/fileService";
|
|
import { decodeLivePhoto } from "services/livePhotoService";
|
|
import { EnteFile } from "types/file";
|
|
import { Dimensions } from "types/image";
|
|
import {
|
|
AlignedFace,
|
|
DetectedFace,
|
|
Face,
|
|
FaceAlignment,
|
|
FaceImageBlob,
|
|
MlFileData,
|
|
Person,
|
|
Versioned,
|
|
} from "types/machineLearning";
|
|
import { getRenderableImage } from "utils/file";
|
|
import { clamp, imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
|
|
import mlIDbStorage from "utils/storage/mlIDbStorage";
|
|
import { Box, Point } from "../../../thirdparty/face-api/classes";
|
|
import { ibExtractFaceImage, ibExtractFaceImages } from "./faceAlign";
|
|
import { getFaceCropBlobFromStorage } from "./faceCrop";
|
|
|
|
/**
 * Element-wise mean of a list of equally sized Float32Array vectors.
 *
 * Accumulation happens inside a Float32Array, so every partial sum is rounded
 * to float32 precision exactly as it is stored.
 *
 * @param descriptors - vectors to average; all are expected to share the
 *   length of the first one.
 * @returns the first descriptor itself (not a copy) when only one is given,
 *   otherwise a freshly allocated averaged vector.
 * @throws Error when `descriptors` is empty.
 */
export function f32Average(descriptors: Float32Array[]) {
    const count = descriptors.length;
    if (count < 1) {
        throw Error("f32Average: input size 0");
    }
    if (count === 1) {
        return descriptors[0];
    }

    const [first, ...rest] = descriptors;
    // Start from a copy of the first vector, then fold the others in.
    const avg = new Float32Array(first);
    for (const desc of rest) {
        for (let i = 0; i < avg.length; i++) {
            avg[i] += desc[i];
        }
    }
    for (let i = 0; i < avg.length; i++) {
        avg[i] /= count;
    }

    return avg;
}
|
|
|
|
/**
 * Checks whether a value is a tfjs Tensor with exactly `dim` dimensions.
 *
 * @param tensor - any value.
 * @param dim - required rank (length of the tensor's shape).
 * @returns `false` for non-tensors; otherwise whether the rank matches.
 */
export function isTensor(tensor: any, dim: number) {
    if (!(tensor instanceof tf.Tensor)) {
        return false;
    }
    return tensor.shape.length === dim;
}
|
|
|
|
/** Type guard: true when `tensor` is a tfjs Tensor of rank 1. */
export function isTensor1D(tensor: any): tensor is tf.Tensor1D {
    const rank = 1;
    return isTensor(tensor, rank);
}
|
|
|
|
/** Type guard: true when `tensor` is a tfjs Tensor of rank 2. */
export function isTensor2D(tensor: any): tensor is tf.Tensor2D {
    const rank = 2;
    return isTensor(tensor, rank);
}
|
|
|
|
/** Type guard: true when `tensor` is a tfjs Tensor of rank 3. */
export function isTensor3D(tensor: any): tensor is tf.Tensor3D {
    const rank = 3;
    return isTensor(tensor, rank);
}
|
|
|
|
/** Type guard: true when `tensor` is a tfjs Tensor of rank 4. */
export function isTensor4D(tensor: any): tensor is tf.Tensor4D {
    const rank = 4;
    return isTensor(tensor, rank);
}
|
|
|
|
/**
 * Returns `image` as a rank-4 tensor, adding a leading axis of size 1 when
 * given a rank-3 tensor, and optionally casting to `dtype`.
 *
 * Runs inside `tf.tidy`, so intermediates are disposed. A rank-4 input with
 * no `dtype` is returned as-is (same tensor object).
 *
 * @throws Error for inputs that are neither Tensor3D nor Tensor4D.
 */
export function toTensor4D(
    image: tf.Tensor3D | tf.Tensor4D,
    dtype?: tf.DataType,
) {
    return tf.tidy(() => {
        let batched: tf.Tensor4D;
        if (isTensor4D(image)) {
            batched = image;
        } else if (isTensor3D(image)) {
            batched = tf.expandDims(image, 0);
        } else {
            throw Error("toTensor4D only supports Tensor3D and Tensor4D input");
        }

        return dtype ? tf.cast(batched, dtype) : batched;
    });
}
|
|
|
|
/**
 * Converts each ImageBitmap to a pixel tensor via `tf.browser.fromPixels`
 * and stacks them along a new leading axis.
 *
 * The per-image tensors are released by `tf.tidy`; only the stacked result
 * survives.
 */
export function imageBitmapsToTensor4D(imageBitmaps: Array<ImageBitmap>) {
    return tf.tidy(() => {
        const pixelTensors = imageBitmaps.map((bitmap) =>
            tf.browser.fromPixels(bitmap),
        );
        const batch = tf.stack(pixelTensors);
        return batch as tf.Tensor4D;
    });
}
|
|
|
|
/**
 * Crops every box in `facebBoxes` out of `image` and resizes each crop to
 * `faceSize` x `faceSize` using `tf.image.cropAndResize`.
 *
 * The image is converted to a float32 rank-4 tensor first. Each box is
 * rescaled by 1/width and 1/height of that tensor (presumably yielding
 * coordinates as fractions of the image size; confirm against Box.rescale)
 * and passed as [top, left, bottom, right]. Every box is cropped from batch
 * index 0.
 */
export function extractFaces(
    image: tf.Tensor3D | tf.Tensor4D,
    facebBoxes: Array<Box>,
    faceSize: number,
) {
    return tf.tidy(() => {
        const batchedImage = toTensor4D(image, "float32");
        const imageHeight = batchedImage.shape[1];
        const imageWidth = batchedImage.shape[2];

        const cropBoxes = facebBoxes.map((faceBox) => {
            const scaled = faceBox.rescale({
                width: 1 / imageWidth,
                height: 1 / imageHeight,
            });
            return [scaled.top, scaled.left, scaled.bottom, scaled.right];
        });

        // Every crop comes from the first image of the batch.
        const boxIndices = tf.fill([cropBoxes.length], 0, "int32");

        return tf.image.cropAndResize(batchedImage, cropBoxes, boxIndices, [
            faceSize,
            faceSize,
        ]);
    });
}
|
|
|
|
/** Builds a Box from its top-left corner and its dimensions. */
export function newBox(x: number, y: number, width: number, height: number) {
    const rect = { x, y, width, height };
    return new Box(rect);
}
|
|
|
|
/** Builds a Box from its left/top/right/bottom edge coordinates. */
export function newBoxFromPoints(
    left: number,
    top: number,
    right: number,
    bottom: number,
) {
    const edges = { left, top, right, bottom };
    return new Box(edges);
}
|
|
|
|
/**
 * Converts a BlazeFace `NormalizedFace` into a Box.
 *
 * Element 0 of `topLeft`/`bottomRight` is used as the x coordinate and
 * element 1 as y.
 */
export function normFaceBox(face: NormalizedFace) {
    const { topLeft, bottomRight } = face;
    const left = topLeft[0];
    const top = topLeft[1];
    const right = bottomRight[0];
    const bottom = bottomRight[1];
    return newBoxFromPoints(left, top, right, bottom);
}
|
|
|
|
/** Midpoint of the segment from `topLeft` to `bottomRight`. */
export function getBoxCenterPt(topLeft: Point, bottomRight: Point): Point {
    const diagonal = bottomRight.sub(topLeft);
    const halfDiagonal = diagonal.div(new Point(2, 2));
    return topLeft.add(halfDiagonal);
}
|
|
|
|
/** Center point of a Box (midpoint of its top-left and bottom-right corners). */
export function getBoxCenter(box: Box): Point {
    const corner1 = box.topLeft;
    const corner2 = box.bottomRight;
    return getBoxCenterPt(corner1, corner2);
}
|
|
|
|
/**
 * Scales a Box about its center.
 *
 * @param box - box to enlarge.
 * @param factor - multiplier for both width and height (default 1.5).
 * @returns a new Box with the same center and scaled dimensions.
 */
export function enlargeBox(box: Box, factor = 1.5) {
    const center = getBoxCenter(box);
    const halfWidth = (factor * box.width) / 2;
    const halfHeight = (factor * box.height) / 2;

    return new Box({
        left: center.x - halfWidth,
        top: center.y - halfHeight,
        right: center.x + halfWidth,
        bottom: center.y + halfHeight,
    });
}
|
|
|
|
/**
 * Wraps an angle into the interval [-PI, PI) by subtracting whole turns.
 *
 * @param angle - angle in radians.
 */
export function normalizeRadians(angle: number) {
    const fullTurn = 2 * Math.PI;
    const turns = Math.floor((angle + Math.PI) / fullTurn);
    return angle - fullTurn * turns;
}
|
|
|
|
/**
 * Angle of the vector point1 -> point2, measured as PI/2 minus
 * atan2(-dy, dx) and normalized with `normalizeRadians`.
 *
 * The y difference is negated before atan2 (presumably because image y grows
 * downwards; confirm with callers).
 */
export function computeRotation(point1: Point, point2: Point) {
    const dy = point2.y - point1.y;
    const dx = point2.x - point1.x;
    const radians = Math.PI / 2 - Math.atan2(-dy, dx);
    return normalizeRadians(radians);
}
|
|
|
|
/** Concatenates all face arrays stored in the map into one flat array. */
export function getAllFacesFromMap(allFacesMap: Map<number, Array<Face>>) {
    return Array.from(allFacesMap.values()).flat();
}
|
|
|
|
/**
 * Looks up a locally stored EnteFile by id.
 *
 * @returns the matching file, or `undefined` when none has that id.
 */
export async function getLocalFile(fileId: number) {
    const localFiles = await getLocalFiles();
    const match = localFiles.find((candidate) => candidate.id === fileId);
    return match;
}
|
|
|
|
/**
 * Renders the aligned crop of a single face as an image blob.
 *
 * Loads the original image of the face's file (looked up locally when `file`
 * is not supplied), extracts the aligned face with `ibExtractFaceImage`, and
 * converts it with `imageBitmapToBlob`. Both intermediate ImageBitmaps are
 * always closed, even when extraction or conversion fails.
 *
 * @param face - face whose `alignment` and `fileId` are used.
 * @param token - unused here; kept for interface compatibility.
 * @param faceSize - output edge length passed to `ibExtractFaceImage`.
 * @param file - optional pre-fetched file; looked up by `face.fileId` if absent.
 * @throws Error when no local file with `face.fileId` exists.
 */
export async function getFaceImage(
    face: AlignedFace,
    token: string,
    faceSize: number = BLAZEFACE_FACE_SIZE,
    file?: EnteFile,
): Promise<FaceImageBlob> {
    let sourceFile = file;
    if (!sourceFile) {
        sourceFile = await getLocalFile(face.fileId);
    }
    if (!sourceFile) {
        throw Error(`getFaceImage: no local file found for id ${face.fileId}`);
    }

    const imageBitmap = await getOriginalImageBitmap(sourceFile);
    try {
        const faceImageBitmap = ibExtractFaceImage(
            imageBitmap,
            face.alignment,
            faceSize,
        );
        try {
            // Await the conversion before closing the bitmap it reads from.
            return await imageBitmapToBlob(faceImageBitmap);
        } finally {
            faceImageBitmap.close();
        }
    } finally {
        imageBitmap.close();
    }
}
|
|
|
|
/**
 * Extracts aligned face images for `faces` from `image` using each face's
 * `alignment` and `ibExtractFaceImages`.
 *
 * A crop-based path existed once but is disabled, so `image` is effectively
 * required.
 *
 * @throws Error (as a rejected promise) when `image` is not provided.
 */
export async function extractFaceImages(
    faces: Array<AlignedFace>,
    faceSize: number,
    image?: ImageBitmap,
) {
    if (!image) {
        throw Error(
            "Either face crops or image is required to extract face images",
        );
    }

    const alignments = faces.map((face) => face.alignment);
    return ibExtractFaceImages(image, alignments, faceSize);
}
|
|
|
|
/**
 * Warps every face alignment out of `image` into one contiguous Float32Array.
 *
 * Each face occupies `faceSize * faceSize * 3` consecutive floats, written by
 * `warpAffineFloat32List` at offset `i * faceSize * faceSize * 3`. The factor
 * 3 is presumably the RGB channel count (confirm against
 * warpAffineFloat32List).
 */
export async function extractFaceImagesToFloat32(
    faceAlignments: Array<FaceAlignment>,
    faceSize: number,
    image: ImageBitmap,
): Promise<Float32Array> {
    const floatsPerFace = faceSize * faceSize * 3;
    const faceData = new Float32Array(faceAlignments.length * floatsPerFace);

    for (const [i, alignment] of faceAlignments.entries()) {
        warpAffineFloat32List(
            image,
            alignment,
            faceSize,
            faceData,
            i * floatsPerFace,
        );
    }

    return faceData;
}
|
|
|
|
/**
 * Left-pads the decimal form of `num` to at least `length` characters,
 * repeating the decimal form of `padding` as the fill.
 */
export function leftFillNum(num: number, length: number, padding: number) {
    const digits = String(num);
    const fill = String(padding);
    return digits.padStart(length, fill);
}
|
|
|
|
/**
 * Builds a face id of the form `<fileId>_<xMin>_<yMin>_<xMax>_<yMax>`.
 *
 * Each coordinate is the box edge divided by the image dimension, clamped to
 * [0, 0.999999], formatted with 5 decimals, with the leading "0." dropped.
 *
 * NOTE(review): values >= 0.999995 format as "1.00000" and therefore encode
 * as "00000", the same as 0. Presumably this must stay as-is for the ids to
 * match those produced elsewhere; confirm before changing.
 */
export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) {
    const encodeFraction = (fraction: number) =>
        clamp(fraction, 0.0, 0.999999).toFixed(5).substring(2);

    const box = detectedFace.detection.box;
    const xMin = encodeFraction(box.x / imageDims.width);
    const yMin = encodeFraction(box.y / imageDims.height);
    const xMax = encodeFraction((box.x + box.width) / imageDims.width);
    const yMax = encodeFraction((box.y + box.height) / imageDims.height);

    const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`;
    return `${detectedFace.fileId}_${rawFaceID}`;
}
|
|
|
|
/**
 * Decodes an image source into a rank-3 pixel tensor via
 * `tf.browser.fromPixels`.
 *
 * The temporary ImageBitmap is closed even if tensor creation throws.
 *
 * @param blob - anything accepted by `createImageBitmap` (typically a Blob).
 */
export async function getTFImage(
    blob: ImageBitmapSource,
): Promise<tf.Tensor3D> {
    const imageBitmap = await createImageBitmap(blob);
    try {
        return tf.browser.fromPixels(imageBitmap);
    } finally {
        imageBitmap.close();
    }
}
|
|
|
|
/** Decodes a Blob into an ImageBitmap with `createImageBitmap`. */
export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
    const bitmap = await createImageBitmap(blob);
    return bitmap;
}
|
|
|
|
// export async function getTFImageUsingJpegJS(blob: Blob): Promise<TFImageBitmap> {
|
|
// const imageData = jpegjs.decode(await blob.arrayBuffer());
|
|
// const tfImage = tf.browser.fromPixels(imageData);
|
|
|
|
// return new TFImageBitmap(undefined, tfImage);
|
|
// }
|
|
|
|
/**
 * Downloads the original content of `file` and collects it into a Blob.
 *
 * When a queue is given, the download runs through it (presumably to limit
 * concurrency). The value from `DownloadManager.getFile` is wrapped in a
 * Response to read it as a Blob.
 */
async function getOriginalFile(file: EnteFile, queue?: PQueue) {
    const fileStream = queue
        ? await queue.add(() => DownloadManager.getFile(file))
        : await DownloadManager.getFile(file);
    return new Response(fileStream).blob();
}
|
|
|
|
/**
 * Downloads `file` and returns a renderable image blob for it.
 *
 * IMAGE files are passed straight to `getRenderableImage`. Every other file
 * type is decoded as a live photo, and its image component is used.
 */
async function getOriginalConvertedFile(file: EnteFile, queue?: PQueue) {
    const fileBlob = await getOriginalFile(file, queue);

    if (file.metadata.fileType !== FILE_TYPE.IMAGE) {
        const livePhoto = await decodeLivePhoto(file, fileBlob);
        const imageTitle = livePhoto.imageNameTitle;
        const imageBlob = new Blob([livePhoto.image]);
        return await getRenderableImage(imageTitle, imageBlob);
    }

    return await getRenderableImage(file.metadata.title, fileBlob);
}
|
|
|
|
/**
 * Loads the original (renderable) image of `file` as an ImageBitmap.
 *
 * @param queue - optional download queue forwarded to the download.
 * @param useCache - when true, the converted blob is read from / stored in
 *   the `CACHES.FILES` cache under the file id.
 */
export async function getOriginalImageBitmap(
    file: EnteFile,
    queue?: PQueue,
    useCache = false,
) {
    const loadConverted = () => getOriginalConvertedFile(file, queue);

    const fileBlob = useCache
        ? await cached(CACHES.FILES, file.id.toString(), loadConverted)
        : await loadConverted();

    log.info("[MLService] Got file: ", file.id.toString());

    return getImageBlobBitmap(fileBlob);
}
|
|
|
|
/** Fetches the thumbnail of `file` and decodes it into an ImageBitmap. */
export async function getThumbnailImageBitmap(file: EnteFile) {
    const thumbnailData = await DownloadManager.getThumbnail(file);
    log.info("[MLService] Got thumbnail: ", file.id.toString());

    const thumbnailBlob = new Blob([thumbnailData]);
    return getImageBlobBitmap(thumbnailBlob);
}
|
|
|
|
/**
 * Converts a browser File into a renderable image (named after the
 * EnteFile's title) and decodes it into an ImageBitmap.
 */
export async function getLocalFileImageBitmap(
    enteFile: EnteFile,
    localFile: globalThis.File,
) {
    const sourceBlob = localFile as Blob;
    const renderable: Blob = await getRenderableImage(
        enteFile.metadata.title,
        sourceBlob,
    );
    return getImageBlobBitmap(renderable);
}
|
|
|
|
/**
 * Returns the distinct people that faces in `file` have been assigned to.
 *
 * Reads the file's ML data from IndexedDB, collects the non-null
 * `personId`s of its faces (deduplicated, first-occurrence order), and loads
 * each person. Ids whose person record is missing are skipped rather than
 * returned as `undefined`. Timings of both storage reads are logged.
 *
 * @returns an empty array when the file has no faces or no assigned faces.
 */
export async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
    let startTime = Date.now();
    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
    log.info(
        "getPeopleList:mlFilesStore:getItem",
        Date.now() - startTime,
        "ms",
    );
    if (!mlFileData?.faces || mlFileData.faces.length < 1) {
        return [];
    }

    // Several faces in one file can belong to the same person; list each once.
    const peopleIds = [
        ...new Set(
            mlFileData.faces
                .filter((f) => f.personId !== null && f.personId !== undefined)
                .map((f) => f.personId),
        ),
    ];
    if (peopleIds.length < 1) {
        return [];
    }

    startTime = Date.now();
    const peopleList = await Promise.all(
        peopleIds.map((p) => mlIDbStorage.getPerson(p) as Promise<Person>),
    );
    log.info(
        "getPeopleList:mlPeopleStore:getItems",
        Date.now() - startTime,
        "ms",
    );

    // Drop ids that no longer resolve to a stored person.
    return peopleList.filter(
        (person) => person !== null && person !== undefined,
    );
}
|
|
|
|
/**
 * Returns the faces of `file` that have no `personId` assigned.
 *
 * @returns `undefined` when the file has no stored ML data or no faces array.
 */
export async function getUnidentifiedFaces(
    file: EnteFile,
): Promise<Array<Face>> {
    const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
    const isUnassigned = (face: Face) =>
        face.personId === null || face.personId === undefined;

    return mlFileData?.faces?.filter(isUnassigned);
}
|
|
|
|
/**
 * Loads stored crop blobs for every face that has a `crop`, in input order.
 * Faces without a crop are skipped.
 */
export async function getFaceCropBlobs(
    faces: Array<Face>,
): Promise<Array<FaceImageBlob>> {
    const storedCrops = faces
        .map((face) => face.crop)
        .filter((crop) => crop !== null && crop !== undefined);

    const pending = storedCrops.map((crop) => getFaceCropBlobFromStorage(crop));
    return Promise.all(pending);
}
|
|
|
|
/**
 * Returns all stored people, ordered by number of files (descending).
 *
 * The storage result is copied before sorting so that the array returned by
 * `mlIDbStorage.getAllPeople` is not reordered in place.
 *
 * @param limit - maximum number of people to return; all when omitted.
 */
export async function getAllPeople(limit?: number) {
    const people: Array<Person> = (await mlIDbStorage.getAllPeople()) ?? [];
    return [...people]
        .sort((p1, p2) => p2.files.length - p1.files.length)
        .slice(0, limit);
}
|
|
|
|
/**
 * Returns the element that would come first if `elements` were sorted with
 * `comparator`, without sorting. This is a single linear scan; on ties the
 * earliest element wins.
 *
 * @returns `undefined` for a null/undefined or empty array.
 */
export function findFirstIfSorted<T>(
    elements: Array<T>,
    comparator: (a: T, b: T) => number,
) {
    if (!elements || elements.length < 1) {
        return;
    }

    let best = elements[0];
    for (const candidate of elements.slice(1)) {
        if (comparator(candidate, best) < 0) {
            best = candidate;
        }
    }
    return best;
}
|
|
|
|
/**
 * True when `method` is missing, has a different value than `thanMethod`,
 * or has the same value but a lower version.
 */
export function isDifferentOrOld(
    method: Versioned<string>,
    thanMethod: Versioned<string>,
) {
    if (!method) {
        return true;
    }
    if (method.value !== thanMethod.value) {
        return true;
    }
    return method.version < thanMethod.version;
}
|
|
|
|
/**
 * Shallow equality for arrays of primitives: both arguments must be arrays
 * of the same length whose elements are pairwise `===`.
 *
 * Parameters are typed `unknown` (not implicit `any`); `Array.isArray`
 * narrows them before use.
 *
 * @returns false if either argument is not an array.
 */
function primitiveArrayEquals(a: unknown, b: unknown): boolean {
    return (
        Array.isArray(a) &&
        Array.isArray(b) &&
        a.length === b.length &&
        a.every((val, index) => val === b[index])
    );
}
|
|
|
|
/**
 * True when both face lists are null/undefined, or when both are arrays
 * whose face ids match element-by-element in order.
 */
export function areFaceIdsSame(ofFaces: Array<Face>, toFaces: Array<Face>) {
    const isMissing = (faces: Array<Face>) =>
        faces === null || faces === undefined;
    if (isMissing(ofFaces) && isMissing(toFaces)) {
        return true;
    }

    const ofIds = ofFaces?.map((face) => face.id);
    const toIds = toFaces?.map((face) => face.id);
    return primitiveArrayEquals(ofIds, toIds);
}
|
|
|
|
/**
 * Index of the point in `fromPoints` closest (euclidean) to `toPoint`.
 *
 * @param maxDistance - optional cutoff. When given (including 0), the
 *   nearest point is returned only if its distance is <= maxDistance.
 * @returns the index, or `undefined` when `fromPoints` is empty or the
 *   nearest point lies beyond `maxDistance`.
 */
export function getNearestPointIndex(
    toPoint: Point,
    fromPoints: Array<Point>,
    maxDistance?: number,
): number | undefined {
    const dists = fromPoints.map((point, i) => ({
        index: i,
        point: point,
        distance: euclidean([point.x, point.y], [toPoint.x, toPoint.y]),
    }));
    const nearest = findFirstIfSorted(
        dists,
        (a, b) => Math.abs(a.distance) - Math.abs(b.distance),
    );

    // No candidates at all (previously a TypeError on `nearest.distance`).
    if (!nearest) {
        return undefined;
    }

    // Compare against null/undefined explicitly so a cutoff of 0 is honored.
    const unlimited = maxDistance === undefined || maxDistance === null;
    if (unlimited || nearest.distance <= maxDistance) {
        return nearest.index;
    }
    return undefined;
}
|
|
|
|
/**
 * Attaches listeners to `queue` that log its "active", "idle" and "error"
 * events, prefixed with `name`. Errors go to `console.error`; the others go
 * to `log.info`.
 */
export function logQueueStats(queue: PQueue, name: string) {
    const onActive = () =>
        log.info(
            `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
        );
    const onIdle = () => log.info(`queuestats: ${name}: Idle`);
    const onError = (error: unknown) =>
        console.error(`queuestats: ${name}: Error, `, error);

    queue.on("active", onActive);
    queue.on("idle", onIdle);
    queue.on("error", onError);
}
|