-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
metric.h
143 lines (122 loc) · 4 KB
/
metric.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/*!
* Copyright (c) 2016 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_METRIC_H_
#define LIGHTGBM_METRIC_H_
#include <LightGBM/config.h>
#include <LightGBM/dataset.h>
#include <LightGBM/meta.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/common.h>
#include <string>
#include <vector>
namespace LightGBM {
/*!
* \brief The interface of metric.
*        A metric is initialized with the dataset metadata and then evaluates prediction scores.
*/
class Metric {
 public:
  /*! \brief Default constructor */
  Metric() = default;

  /*! \brief Virtual destructor, so implementations can be destroyed through a Metric pointer */
  virtual ~Metric() = default;

  /*! \brief Copy construction is disabled */
  Metric(const Metric&) = delete;

  /*! \brief Copy assignment is disabled */
  Metric& operator=(const Metric&) = delete;

  /*!
  * \brief Initialize the metric
  * \param metadata Label data
  * \param num_data Number of data
  */
  virtual void Init(const Metadata& metadata, data_size_t num_data) = 0;

  /*!
  * \brief Get the name(s) of this metric
  * \return Reference to the vector of names held by the implementation
  */
  virtual const std::vector<std::string>& GetName() const = 0;

  /*!
  * \brief Get the factor associated with the "bigger is better" convention
  * \return The factor as a double
  * NOTE(review): presumably a sign/scale that converts the metric value so that
  * a larger result means a better model; confirm against the implementations.
  */
  virtual double factor_to_bigger_better() const = 0;

  /*!
  * \brief Calculate the metric result
  * \param score Current prediction score
  * \param objective Objective function handed to the implementation
  *        NOTE(review): whether this may be null is not visible in this header; confirm.
  * \return The metric values, one double per reported result
  */
  virtual std::vector<double> Eval(const double* score,
                                   const ObjectiveFunction* objective) const = 0;

  /*!
  * \brief Create object of metrics
  * \param type Specific type of metric
  * \param config Config for metric
  * \return Pointer to the created metric
  *         NOTE(review): ownership and the result for an unknown type are not
  *         visible in this header; confirm against the implementation.
  */
  LIGHTGBM_EXPORT static Metric* CreateMetric(const std::string& type,
                                              const Config& config);
};
/*!
* \brief Static class, used to calculate DCG score
*/
class DCGCalculator {
 public:
  // ---- Defaults and setup ----

  /*!
  * \brief Provide the default evaluation positions
  * \param eval_at Vector of positions handed to the implementation
  *        NOTE(review): presumably filled in as an output; confirm in the implementation.
  */
  static void DefaultEvalAt(std::vector<int>* eval_at);

  /*!
  * \brief Provide the default label gains
  * \param label_gain Vector of gains handed to the implementation
  *        NOTE(review): presumably filled in as an output; confirm in the implementation.
  */
  static void DefaultLabelGain(std::vector<double>* label_gain);

  /*!
  * \brief Initial logic
  * \param label_gain Gain for labels, default is 2^i - 1
  */
  static void Init(const std::vector<double>& label_gain);

  /*!
  * \brief Check the label range for NDCG and lambdarank
  * \param label Pointer of label
  * \param num_data Number of data
  */
  static void CheckLabel(const label_t* label, data_size_t num_data);

  // ---- DCG ----

  /*!
  * \brief Calculate the DCG score at position k
  * \param k The position to evaluate
  * \param label Pointer of label
  * \param score Pointer of score
  * \param num_data Number of data
  * \return The DCG score
  */
  static double CalDCGAtK(data_size_t k,
                          const label_t* label,
                          const double* score,
                          data_size_t num_data);

  /*!
  * \brief Calculate the DCG score at multiple positions
  * \param ks The positions to evaluate
  * \param label Pointer of label
  * \param score Pointer of score
  * \param num_data Number of data
  * \param out Output result
  */
  static void CalDCG(const std::vector<data_size_t>& ks,
                     const label_t* label,
                     const double* score,
                     data_size_t num_data,
                     std::vector<double>* out);

  // ---- Max DCG ----

  /*!
  * \brief Calculate the max DCG score at position k
  * \param k The position to evaluate at
  * \param label Pointer of label
  * \param num_data Number of data
  * \return The max DCG score
  */
  static double CalMaxDCGAtK(data_size_t k,
                             const label_t* label,
                             data_size_t num_data);

  /*!
  * \brief Calculate the max DCG score at multiple positions
  * \param ks The positions to evaluate at
  * \param label Pointer of label
  * \param num_data Number of data
  * \param out Output result
  */
  static void CalMaxDCG(const std::vector<data_size_t>& ks,
                        const label_t* label,
                        data_size_t num_data,
                        std::vector<double>* out);

  // ---- Accessor ----

  /*!
  * \brief Get discount score of position k
  * \param k The position; used directly as an index into the discount table,
  *        with no bounds check performed here
  * \return The discount of this position
  */
  static inline double GetDiscount(data_size_t k) {
    return discount_[k];
  }

 private:
  /*! \brief Gains stored per label */
  static std::vector<double> label_gain_;
  /*! \brief Discount scores stored per position */
  static std::vector<double> discount_;
  /*! \brief Max position for eval */
  static const data_size_t kMaxPosition;
};
} // namespace LightGBM
#endif  // LIGHTGBM_METRIC_H_