#include "opencv2/opencv.hpp"

#include <map>
#include <vector>
#include <string>
#include <iostream>

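// Lookup tables mapping the --backend/--target CLI strings to OpenCV DNN enums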
const std::map<std::string, int> str2backend{
    {"opencv", cv::dnn::DNN_BACKEND_OPENCV}, {"cuda", cv::dnn::DNN_BACKEND_CUDA},
    {"timvx",  cv::dnn::DNN_BACKEND_TIMVX},  {"cann", cv::dnn::DNN_BACKEND_CANN}
};
const std::map<std::string, int> str2target{
    {"cpu", cv::dnn::DNN_TARGET_CPU}, {"cuda", cv::dnn::DNN_TARGET_CUDA},
    {"npu", cv::dnn::DNN_TARGET_NPU}, {"cuda_fp16", cv::dnn::DNN_TARGET_CUDA_FP16}
};

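// Thin wrapper around cv::FaceDetectorYN that bundles the creation parameters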
class YuNet
{
public:
    YuNet(const std::string& model_path,
          const cv::Size& input_size = cv::Size(320, 320),
          float conf_threshold = 0.6f,
          float nms_threshold = 0.3f,
          int top_k = 5000,
          int backend_id = 0,
          int target_id = 0)
        : model_path_(model_path), input_size_(input_size),
          conf_threshold_(conf_threshold), nms_threshold_(nms_threshold),
          top_k_(top_k), backend_id_(backend_id), target_id_(target_id)
    {
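        // Create the underlying cv::FaceDetectorYN detector from the ONNX model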
        model = cv::FaceDetectorYN::create(model_path_, "", input_size_, conf_threshold_, nms_threshold_, top_k_, backend_id_, target_id_);
    }

    // Override the input size set at model creation. Size format: [width, height].
    void setInputSize(const cv::Size& input_size)
    {
        input_size_ = input_size;
        model->setInputSize(input_size_);
    }

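    // Run detection. The output is an N x 15 CV_32F matrix, one row per face:
    // [x, y, w, h] of the bounding box, five (x, y) landmark pairs
    // (right eye, left eye, nose tip, right and left mouth corners),
    // and the confidence score in the last column.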
    cv::Mat infer(const cv::Mat& image)
    {
        cv::Mat res;
        model->detect(image, res);
        return res;
    }

private:
    cv::Ptr<cv::FaceDetectorYN> model;

    std::string model_path_;
    cv::Size input_size_;
    float conf_threshold_;
    float nms_threshold_;
    int top_k_;
    int backend_id_;
    int target_id_;
};

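// Draw face boxes, landmarks, and confidence scores on a copy of the image,
// plus an FPS readout (when fps >= 0) and a face counter.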
cv::Mat visualize(const cv::Mat& image, const cv::Mat& faces, float fps = -1.f)
{
    static cv::Scalar box_color{0, 255, 0};
    static std::vector<cv::Scalar> landmark_color{
        cv::Scalar(255,   0,   0), // right eye
        cv::Scalar(  0,   0, 255), // left eye
        cv::Scalar(  0, 255,   0), // nose tip
        cv::Scalar(255,   0, 255), // right mouth corner
        cv::Scalar(  0, 255, 255)  // left mouth corner
    };
    static cv::Scalar text_color{0, 255, 0};

    auto output_image = image.clone();

    if (fps >= 0)
    {
        std::string text = cv::format("FPS: %0.2f", fps);
        int baseLine = 0;
        cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.6, 2, &baseLine);
        cv::rectangle(output_image, cv::Rect(cv::Point(0, 0), cv::Size(label_size.width + 20, label_size.height + 15)), cv::Scalar(0, 0, 0), -1);
        cv::putText(output_image, text, cv::Point(10, 20), cv::FONT_HERSHEY_SIMPLEX, 0.6, text_color, 2);
    }

    for (int i = 0; i < faces.rows; ++i)
    {
        // Draw bounding boxes
        int x1 = static_cast<int>(faces.at<float>(i, 0));
        int y1 = static_cast<int>(faces.at<float>(i, 1));
        int w = static_cast<int>(faces.at<float>(i, 2));
        int h = static_cast<int>(faces.at<float>(i, 3));
        cv::rectangle(output_image, cv::Rect(x1, y1, w, h), box_color, 2);

        // Confidence as text
        float conf = faces.at<float>(i, 14);
        cv::putText(output_image, cv::format("%.4f", conf), cv::Point(x1, y1+12), cv::FONT_HERSHEY_DUPLEX, 0.5, text_color);

        // Draw landmarks
        for (int j = 0; j < static_cast<int>(landmark_color.size()); ++j)
        {
            int x = static_cast<int>(faces.at<float>(i, 2*j+4)), y = static_cast<int>(faces.at<float>(i, 2*j+5));
            cv::circle(output_image, cv::Point(x, y), 2, landmark_color[j], 2);
        }
    }

    std::string persons = cv::format("Persons: %d", faces.rows);
    int baseLine = 0;
    cv::Size label_size = cv::getTextSize(persons, cv::FONT_HERSHEY_SIMPLEX, 0.6, 2, &baseLine);
    int y_offset = label_size.height + 20; // place the counter just below the FPS label
    cv::rectangle(output_image, cv::Rect(cv::Point(0, y_offset), cv::Size(label_size.width + 20, label_size.height + 15)), cv::Scalar(0, 0, 0), -1);
    cv::putText(output_image, persons, cv::Point(10, y_offset + 20), cv::FONT_HERSHEY_SIMPLEX, 0.6, cv::Scalar(255, 255, 0), 2);

    return output_image;
}

int main(int argc, char** argv)
{
    cv::CommandLineParser parser(argc, argv,
        "{help  h           |                                   | Print this message}"
        "{input i           |                                   | Set input to a certain image, omit if using camera}"
        "{model m           | face_detection_yunet_2023mar.onnx  | Set path to the model}"
        "{backend b         | opencv                            | Set DNN backend}"
        "{target t          | cpu                               | Set DNN target}"
        "{save s            | false                             | Whether to save result image or not}"
        "{vis v             | false                             | Whether to visualize result image or not}"
        "{device_id d       | 0                                 | Device ID to use}"
        "{height H          | 0                                 | Input video height}"
        /* model params below*/
        "{conf_threshold    | 0.9                               | Set the minimum confidence for the model to identify a face. Filter out faces of conf < conf_threshold}"
        "{nms_threshold     | 0.3                               | Set the threshold to suppress overlapped boxes. Suppress boxes if IoU(box1, box2) >= nms_threshold, the one of higher score is kept.}"
        "{top_k             | 5000                              | Keep top_k bounding boxes before NMS. Set a lower value may help speed up postprocessing.}"
    );
    if (parser.has("help"))
    {
        parser.printMessage();
        return 0;
    }

    std::string input_path = parser.get<std::string>("input");
    std::string model_path = parser.get<std::string>("model");
    std::string backend = parser.get<std::string>("backend");
    std::string target = parser.get<std::string>("target");
    bool save_flag = parser.get<bool>("save");
    bool vis_flag = parser.get<bool>("vis");

    // model params
    float conf_threshold = parser.get<float>("conf_threshold");
    float nms_threshold = parser.get<float>("nms_threshold");
    int top_k = parser.get<int>("top_k");
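    // std::map::at throws std::out_of_range for an unrecognized backend/target name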
    const int backend_id = str2backend.at(backend);
    const int target_id = str2target.at(target);

    int device_id = parser.get<int>("device_id");
    // Open an image/video file when --input is given, otherwise the camera
    auto cap = input_path.empty() ? cv::VideoCapture(device_id) : cv::VideoCapture(input_path);
    if (!cap.isOpened())
    {
        std::cout << "Failed to open the input source! Exiting ...\n";
        return 1;
    }
    int w = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_WIDTH));
    int h = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_HEIGHT));

    // Apply the requested output height, if given
    int new_height = parser.get<int>("height");
    if (new_height != 0) {
        // Compute the aspect ratio of the source frames
        double aspect_ratio = static_cast<double>(w) / h;
        // Derive the new width so the aspect ratio is preserved
        w = static_cast<int>(new_height * aspect_ratio);
        h = new_height;
    }
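    // e.g. a 1280x720 source with --height=360 becomes 640x360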
    std::cout << "New width: " << w << ", New height: " << h << std::endl;

    // Instantiate YuNet; the constructor already applies the (w, h) input size
    YuNet model(model_path, cv::Size(w, h), conf_threshold, nms_threshold, top_k, backend_id, target_id);

    auto tick_meter = cv::TickMeter();
    cv::Mat frame;
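    // Run until a key is pressed; cv::waitKey(1) also services the HighGUI event loop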
    while (cv::waitKey(1) < 0)
    {
        bool has_frame = cap.read(frame);
        if (!has_frame)
        {
            std::cout << "No frames grabbed! Exiting ...\n";
            break;
        }
        if (frame.empty()) {
            continue;
        }

        // Resize the frame to the model's input size
        cv::Mat resized_frame;
        cv::resize(frame, resized_frame, cv::Size(w, h));

        // Inference
        tick_meter.start();
        cv::Mat faces = model.infer(resized_frame);
        tick_meter.stop();

        // Scale factors to map detections back to the original frame size
        float scale_x = static_cast<float>(frame.cols) / resized_frame.cols;
        float scale_y = static_cast<float>(frame.rows) / resized_frame.rows;

        // Rescale detections back to the original frame size
        for (int i = 0; i < faces.rows; ++i)
        {
            faces.at<float>(i, 0) *= scale_x; // x1
            faces.at<float>(i, 1) *= scale_y; // y1
            faces.at<float>(i, 2) *= scale_x; // w
            faces.at<float>(i, 3) *= scale_y; // h
            for (int j = 0; j < 5; ++j)
            {
                faces.at<float>(i, 2*j+4) *= scale_x; // landmark x
                faces.at<float>(i, 2*j+5) *= scale_y; // landmark y
            }
        }

        // Draw results on the input image
        auto res_image = visualize(frame, faces, (float)tick_meter.getFPS());
        // Save and/or display the result
        if (save_flag)
        {
            // Note: "result.jpg" is a name chosen for this sample; it is overwritten every frame
            cv::imwrite("result.jpg", res_image);
        }
        if (vis_flag)
        {
            cv::namedWindow("image", cv::WINDOW_NORMAL);
            cv::imshow("image", res_image);
        }

        tick_meter.reset();
    }

    return 0;
}