forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
/
file_iterator.h
114 lines (99 loc) · 3.49 KB
/
file_iterator.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
/*!
* Copyright 2021 XGBoost contributors
*/
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "dmlc/data.h"
#include "xgboost/c_api.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"
#include "array_interface.h"
namespace xgboost {
namespace data {
/**
* An iterator for implementing external memory support with file inputs. Users of
* external memory are encouraged to define their own file parsers/loaders so this one is
* just here for compatibility with old versions of XGBoost and CLI interface.
*/
class FileIterator {
// uri of input file, encodes parameters about whether it's 1-based index etc. dmlc
// parser will decode these information.
std::string uri_;
// Equals to rank_id in distributed training, used to split file into parts for each
// worker.
uint32_t part_idx_;
// Equals to total number of workers.
uint32_t n_parts_;
// Format of the input file, like "libsvm".
std::string type_;
DMatrixHandle proxy_;
std::unique_ptr<dmlc::Parser<uint32_t>> parser_;
// Temporary reference to stage the data.
dmlc::RowBlock<uint32_t, float> row_block_;
// Storage for the array interface strings.
std::string indptr_;
std::string values_;
std::string indices_;
public:
FileIterator(std::string uri, unsigned part_index, unsigned num_parts,
std::string type)
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts},
type_{std::move(type)} {
XGProxyDMatrixCreate(&proxy_);
}
~FileIterator() {
XGDMatrixFree(proxy_);
}
int Next() {
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
using linalg::MakeVec;
indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
size_t n_columns = *std::max_element(row_block_.index,
row_block_.index + row_block_.offset[row_block_.size]);
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
// this condition and just add 1 to n_columns
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(),
values_.c_str(), n_columns);
if (row_block_.label) {
XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1);
}
if (row_block_.qid) {
XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1);
}
if (row_block_.weight) {
XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1);
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
auto Proxy() -> decltype(proxy_) { return proxy_; }
void Reset() {
CHECK(!type_.empty());
parser_.reset(dmlc::Parser<uint32_t>::Create(uri_.c_str(), part_idx_,
n_parts_, type_.c_str()));
}
};
namespace fileiter {
// C-style callbacks adapting FileIterator to the DataIterHandle-based iterator API.
// `self` is expected to point at a live FileIterator.

/*! \brief Forward a reset request on the opaque handle to FileIterator::Reset(). */
inline void Reset(DataIterHandle self) {
  auto* iter = static_cast<FileIterator*>(self);
  iter->Reset();
}

/*! \brief Forward to FileIterator::Next(); returns its result unchanged. */
inline int Next(DataIterHandle self) {
  auto* iter = static_cast<FileIterator*>(self);
  int const has_batch = iter->Next();
  return has_batch;
}
}  // namespace fileiter
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_FILE_ITERATOR_H_