/* ============================================================
 *
 * This file is a part of digiKam project
 * https://www.digikam.org
 *
 * Date        : 2010-09-03
 * Description : Integrated, multithread face detection / recognition
 *
 * Copyright (C) 2010-2011 by Marcel Wiesweg <marcel dot wiesweg at gmx dot de>
 * Copyright (C) 2012-2020 by Gilles Caulier <caulier dot gilles at gmail dot com>
 *
 * This program is free software; you can redistribute it
 * and/or modify it under the terms of the GNU General
 * Public License as published by the Free Software Foundation;
 * either version 2, or (at your option)
 * any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * ============================================================ */

#include "trainerworker.h"

// KDE includes

#include <ksharedconfig.h>
#include <kconfiggroup.h>

// Local includes

#include "digikam_debug.h"

namespace Digikam
{

/**
 * TrainingDataProvider backed by an in-memory map from identity id to the
 * list of face images queued for that identity. Identities with nothing
 * queued are served an empty provider.
 */
class Q_DECL_HIDDEN MapListTrainingDataProvider : public TrainingDataProvider
{
public:

    MapListTrainingDataProvider() = default;

    /**
     * Returns the image list queued for @p identity, after calling reset() on it
     * (presumably rewinding its read position - confirm against QListImageListProvider).
     * Returns the shared empty provider when no images were queued for that id.
     * The returned pointer refers to storage owned by this object.
     */
    ImageListProvider* newImages(const Identity& identity) override
    {
        // One map lookup instead of contains() followed by operator[].
        auto it = imagesToTrain.find(identity.id());

        if (it == imagesToTrain.end())
        {
            return &empty;
        }

        it->reset();

        return &it.value();
    }

    /**
     * Always returns the empty provider.
     */
    ImageListProvider* images(const Identity&) override
    {
        // Not implemented. Would be needed if we use a backend with a "holistic" approach that needs all images to train.

        return &empty;
    }

public:

    EmptyImageListProvider            empty;          ///< Returned for identities without queued images.
    QMap<int, QListImageListProvider> imagesToTrain;  ///< Identity id -> images to train for it.
};

// ----------------------------------------------------------------------------------------

/**
 * Creates the trainer worker and selects the face recognizer for its
 * recognition database. The recognizer comes from the "Recognize Algorithm"
 * entry of the "Face Management Settings" config group and defaults to DNN.
 */
TrainerWorker::TrainerWorker(FacePipeline::Private* const d)
    : imageRetriever(d),
      d(d)
{
    KSharedConfig::Ptr config = KSharedConfig::openConfig();
    KConfigGroup group        = config->group(QLatin1String("Face Management Settings"));

    // The algorithm is persisted as a plain integer; named casts replace the former C-style casts.
    const int storedAlgo      = group.readEntry(QLatin1String("Recognize Algorithm"),
                                                static_cast<int>(RecognitionDatabase::RecognizeAlgorithm::DNN));

    database.activeFaceRecognizer(static_cast<RecognitionDatabase::RecognizeAlgorithm>(storedAlgo));
}

/**
 * Calls wait() before the members are destroyed.
 * NOTE(review): presumably this blocks until in-flight process() work finishes,
 * so the recognition database and image retriever are not torn down while in use - confirm.
 */
TrainerWorker::~TrainerWorker()
{
    this->wait();
}

/**
 * Handles every face in @p package flagged ForTraining:
 * - retypes it as FaceForTraining;
 * - fetches its image and trains the recognition database under the "digikam" context;
 * - passes the faces to FaceUtils::removeFaces().
 *
 * Afterwards it replaces the ForTraining role with Trained, sets the
 * ProcessedByTrainer flag and emits processed(package).
 */
void TrainerWorker::process(FacePipelineExtendedPackage::Ptr package)
{
    //qCDebug(DIGIKAM_GENERAL_LOG) << "TrainerWorker: processing one package";

    // Get a list of faces with type FaceForTraining (probably type is ConfirmedFace)

    QList<FaceTagsIface> toTrain;
    QList<int>           identities;    // identities.at(i) is the identity id of toTrain.at(i)
    QList<Identity>      identitySet;   // distinct identities to hand to train()
    FaceUtils            utils;

    foreach (const FacePipelineFaceTagsIface& face, package->databaseFaces)
    {
        if (face.roles & FacePipelineFaceTagsIface::ForTraining)
        {
            FaceTagsIface dbFace = face;
            dbFace.setType(FaceTagsIface::FaceForTraining);
            toTrain << dbFace;

            Identity identity    = utils.identityForTag(dbFace.tagId(), database);

            identities  << identity.id();

            if (!identitySet.contains(identity))
            {
                identitySet << identity;
            }
        }
    }

    if (!toTrain.isEmpty())
    {
        QList<QImage> images;

        // Without a loaded image, fetch thumbnails by file path; otherwise extract details from the image.

        if (package->image.isNull())
        {
            images = imageRetriever.getThumbnails(package->filePath, toTrain);
        }
        else
        {
            images = imageRetriever.getDetails(package->image, toTrain);
        }

        // NOTE(review): nothing here guarantees exactly one image per requested face
        // (for example after aboutToDeactivate() cancels the retriever). Only pair
        // indices that exist in both lists so at() is never called out of range.

        const int count = qMin(toTrain.size(), images.size());

        if (count != toTrain.size())
        {
            qCWarning(DIGIKAM_GENERAL_LOG) << "TrainerWorker: expected" << toTrain.size()
                                           << "face images but got" << images.size();
        }

        MapListTrainingDataProvider provider;

        // Group images by identity

        for (int i = 0 ; i < count ; ++i)
        {
            provider.imagesToTrain[identities.at(i)].list << images.at(i);
        }

        database.train(identitySet, &provider, QLatin1String("digikam"));
    }

    utils.removeFaces(toTrain);
    package->databaseFaces.replaceRole(FacePipelineFaceTagsIface::ForTraining, FacePipelineFaceTagsIface::Trained);
    package->processFlags |= FacePipelinePackage::ProcessedByTrainer;

    emit processed(package);
}

void TrainerWorker::aboutToDeactivate()
{
    imageRetriever.cancel();
}

} // namespace Digikam
