Program Listing for File DimRedTools.hpp

Return to documentation for file (include\DimRedTools\DimRedTools.hpp)

#ifndef DIMREDTOOLS_INCLUDE_DIMREDTOOLS_DIMREDTOOLS_HPP_
#define DIMREDTOOLS_INCLUDE_DIMREDTOOLS_DIMREDTOOLS_HPP_

#include <algorithm>
#include <cmath>
#include <functional>
#include <queue>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <Eigen/Dense>

namespace dim_red {

// Dynamic-size dense matrix of doubles, stored row-major
// (assumed: one data point per row — confirm with callers).
using Matrix = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Dynamic-length 1 x N row vector of doubles (identical to Eigen::RowVectorXd).
using Vector = Eigen::Matrix<double, 1, Eigen::Dynamic>;

// Dynamic-length 1 x N row vector of ints (identical to Eigen::RowVectorXi).
using IntVector = Eigen::Matrix<int, 1, Eigen::Dynamic>;

// 1 x N array of doubles for coefficient-wise arithmetic.
using Array = Eigen::Array<double, 1, Eigen::Dynamic, Eigen::RowMajor>;

// Distance function between two row vectors, both received as read-only
// Eigen::Ref views; returns the distance as a double.
using Metric = std::function<double(const Eigen::Ref<const Vector> &,
                                    const Eigen::Ref<const Vector> &)>;

/// Bounded max-heap that retains the `limit` smallest values offered to it.
///
/// Ordering is std::less<T> (std::priority_queue's default), so peek()
/// returns the largest value currently retained. With T = std::pair<double,
/// int> — presumably (distance, index) — this keeps the k best candidates
/// seen so far, and the top is the current worst of them.
template <typename T>
class NeighborsHeap {
public:
    /// @param limit  Maximum number of values retained; must be non-zero.
    /// @throws std::invalid_argument if limit == 0.
    NeighborsHeap(size_t limit) : limit_(limit) {
        // size_t is unsigned, so zero is the only invalid value.
        if (limit_ == 0) {
            throw std::invalid_argument("Invalid limit: " + std::to_string(limit));
        }
    }

    /// Offers a value. It is kept if the heap is not yet full, or if it
    /// compares less than the current maximum (which is then evicted).
    void add(T value) {
        if (queue_.size() < limit_) {
            queue_.push(std::move(value));
        } else if (std::less<T>()(value, queue_.top())) {
            // Evict first so the underlying container never grows past limit_.
            queue_.pop();
            queue_.push(std::move(value));
        }
        // Otherwise the value could not displace anything: skip the
        // push/pop round-trip entirely.
    }

    /// Returns the largest retained value.
    /// @throws std::out_of_range if the heap is empty.
    const T &peek() const {
        if (queue_.empty()) {
            throw std::out_of_range("NeighborsHeap::peek called on an empty heap");
        }
        return queue_.top();
    }

    /// Drains the heap and returns its values in descending order
    /// (largest first). The heap is empty afterwards.
    std::vector<T> extract() {
        std::vector<T> result;
        result.reserve(queue_.size());
        while (!queue_.empty()) {
            result.push_back(queue_.top());
            queue_.pop();
        }
        return result;
    }

private:
    size_t limit_;                  // maximum number of retained values (> 0)
    std::priority_queue<T> queue_;  // max-heap; top() is the largest retained value
};

/// Abstract interface for nearest-neighbour search structures.
///
/// Both queries return a (Vector, IntVector) pair — presumably the neighbour
/// distances and the matching row indices into the indexed data; confirm
/// against the implementations.
///
/// NOTE: default arguments on virtual functions are resolved from the static
/// type at the call site, not the dynamic type. Overrides must therefore
/// repeat exactly the same defaults (Bruteforce does), or calls through a
/// base reference and a derived reference would behave differently.
class NearestNeighbors {
public:
    virtual ~NearestNeighbors() = default;

    /// k-nearest-neighbour query.
    /// @param point         Query point.
    /// @param k             Number of neighbours requested.
    /// @param sort_results  Whether the results should be ordered (default true).
    virtual std::pair<Vector, IntVector> query(const Eigen::Ref<const Vector> &point, int k,
                                               bool sort_results = true) const = 0;

    /// Fixed-radius query.
    /// @param point         Query point.
    /// @param radius        Search radius (boundary inclusivity is decided by
    ///                      the implementation — confirm there).
    /// @param sort_results  Whether the results should be ordered (default false).
    virtual std::pair<Vector, IntVector> queryRadius(const Eigen::Ref<const Vector> &point,
                                                     double radius,
                                                     bool sort_results = false) const = 0;

protected:
    /// Shared argument checks for implementations. Defined out of line;
    /// presumably throws on an invalid k (when k_nearest) or radius — confirm.
    void validate(int data_size, int k, double radius, bool k_nearest) const;

    /// Converts collected (value, index) pairs into the returned
    /// (Vector, IntVector) form. Defined out of line.
    /// NOTE(review): the distinct roles of `neighbors` and `bound_neighbors`
    /// are not visible here — document them from the implementation.
    std::pair<Vector, IntVector> processNeighbors(
        int k, bool sort_results, std::vector<std::pair<double, int>> *neighbors,
        std::vector<std::pair<double, int>> *bound_neighbors) const;
};

class Bruteforce : public NearestNeighbors {
public:
    Bruteforce(const Eigen::Ref<const Matrix> &x, const std::string &metric = "euclidean");

    std::pair<Vector, IntVector> query(const Eigen::Ref<const Vector> &point, int k,
                                       bool sort_results = true) const override;

    std::pair<Vector, IntVector> queryRadius(const Eigen::Ref<const Vector> &point, double radius,
                                             bool sort_results = false) const override;

private:
    const Eigen::Ref<const Matrix> data_;
    Metric distance_;
};

/// Resolves a distance function from its name.
///
/// The accepted names, and what happens for an unknown name, are defined in
/// the implementation file — NOTE(review): list them here once confirmed.
auto getMetricByName(const std::string &name) -> Metric;

}  // namespace dim_red

#endif