64,266
社区成员
发帖
与我相关
我的任务
分享
#include <iterator>
#include <sstream>
#include <string>
#include <vector>
using namespace std;
// One node of a decision tree, as created by TreeLoadHelper::load().
// Split nodes and leaves share this layout:
//   split node: index >= 0 is its position in the tree's split-node vector;
//               fid is the split_feature id and value the threshold;
//               left/right are child links; isleaf stays 0.
//   leaf:       isleaf == 1; value is the leaf output; the k-th leaf has
//               index == -(k + 1).
// Child links use the same encoding: >= 0 selects a split node, a negative
// link n selects leaf (-n - 1).
// NOTE: the members have no default initializers. load() creates nodes with
// TreeNode() (value-initialisation, all fields zero); a plain `TreeNode n;`
// would leave every field indeterminate.
struct TreeNode
{
int fid;        // split feature id (read as featlst[fid - 1] when predicting)
double value;   // threshold for split nodes, output value for leaves
int index;      // >= 0 for split nodes, -(k + 1) for leaf k
int left;       // child link taken when the feature value <= value
int right;      // child link taken otherwise
int isleaf;     // 1 for leaves, 0 for split nodes
};
// Loads a text tree-ensemble model (see load()) and scores a feature vector
// by summing the outputs of all loaded trees.
//
// No constructor/destructor is declared on purpose: every member is a
// standard container, so the implicitly generated special members (including
// C++11 move operations, which a user-declared destructor would suppress)
// are correct.
class TreeLoadHelper
{
public:
    // Returns the sum of predict_onetree() over every loaded tree.
    double model_predict(const vector<double>& feature_list);

    // Appends the trees found in the model file at `path`; returns false
    // when the file cannot be opened.
    bool load(const string& path);

private:
    // Per-tree node storage filled by load(); entry i of both vectors
    // belongs to the same tree (split nodes in alltrees_noleaf[i], leaves in
    // alltrees_leaf[i]).
    vector<vector<TreeNode> > alltrees_noleaf;
    vector<vector<TreeNode> > alltrees_leaf;

    // Raw tokens of the tree currently being parsed by load().
    vector<string> split_vet;
    vector<string> thred_vet;
    vector<string> left_vet;
    vector<string> right_vet;
    vector<string> leaf_vet;

    // Extracts the space-separated values of a "key=v1 v2 ..." line for lines
    // that begin with `col_name`; returns an empty vector for other lines.
    vector<string> parse_line(string line, string col_name);

    // Evaluates one tree (split nodes + leaves) on `featlst`.
    double predict_onetree(vector<TreeNode>& no_leaf_node_vet,
                           vector<TreeNode>& leaf_node_vet,
                           vector<double> featlst);

    // Writes each `delim`-separated field of `s` to the output iterator
    // `result`. Uses no object state, hence static.
    template<typename Out>
    static void inner_split(const std::string &s, char delim, Out result) {
        std::stringstream ss(s);
        std::string item;
        while (std::getline(ss, item, delim)) {
            *(result++) = item;
        }
    }

    // Splits `s` on `delim`. Consecutive delimiters yield empty fields; a
    // single trailing delimiter does not add an empty field.
    // std::back_inserter requires <iterator>.
    static std::vector<std::string> str_split(const std::string &s, char delim) {
        std::vector<std::string> elems;
        inner_split(s, delim, std::back_inserter(elems));
        return elems;
    }
};
#include "TreeModel.h"
#include <fstream>
#include <cstdlib>
bool TreeLoadHelper::load(const string& path)
{
ifstream fin(path.c_str());
string line;
if (!fin) {
LOG_ERROR("load tree model file error:%s", path.c_str());
return false;
}
while(getline(fin, line)) {
if (line.find("Tree=")!=string::npos) {
if (!split_vet.empty() && !left_vet.empty() &&
!right_vet.empty() &&!leaf_vet.empty()) {
vector<TreeNode> no_leaf_node_vet;
vector<TreeNode> leaf_node_vet;
no_leaf_node_vet.clear();
leaf_node_vet.clear();
for (size_t i=0;i<split_vet.size();i++) {
no_leaf_node_vet.push_back(TreeNode());
}
for (size_t i=0;i<leaf_vet.size();i++) {
TreeNode leaf_node_ = TreeNode();
leaf_node_.isleaf = 1;
leaf_node_.index = -i-1;
//leaf_node_.value = stod(leaf_vet[i]);
leaf_node_.value = atof(leaf_vet[i].c_str());
leaf_node_vet.push_back(leaf_node_);
}
for (size_t i=0;i<split_vet.size();i++) {
TreeNode& index_node = no_leaf_node_vet[i];
index_node.index = i;
index_node.fid = atoi(split_vet[i].c_str());
index_node.value = atof(thred_vet[i].c_str());
int left_index = atoi(left_vet[i].c_str());
int right_index = atoi(right_vet[i].c_str());
index_node.left = left_index;
index_node.right = right_index;
}
alltrees_noleaf.push_back(no_leaf_node_vet);
alltrees_leaf.push_back(leaf_node_vet);
}
split_vet.clear();
thred_vet.clear();
left_vet.clear();
right_vet.clear();
leaf_vet.clear();
}
else {
if (split_vet.empty()) {
split_vet = parse_line(line, string("split_feature="));
}
if (thred_vet.empty()) {
thred_vet = parse_line(line, string("threshold="));
}
if (left_vet.empty()) {
left_vet = parse_line(line, string("left_child="));
}
if (right_vet.empty()) {
right_vet = parse_line(line, string("right_child="));
}
if (leaf_vet.empty()) {
leaf_vet = parse_line(line, string("leaf_value="));
}
}
}
return true;
}
// Walks one tree from its root (split node 0) down to a leaf and returns
// that leaf's value.
//
// Index encoding (as built by load()): index >= 0 selects
// no_leaf_node_vet[index]; a negative index n selects leaf_node_vet[-n - 1].
// At a split node, a feature value <= the node threshold goes to `left`,
// otherwise to `right`.
//
// Returns 0.0 when the walk cannot finish normally: a child link outside
// the tree, or more than kMaxSteps steps (guards against cyclic links).
//
// NOTE(review): featlst is taken by value because that is how the class
// declares this method, so the feature vector is copied once per tree;
// changing both the declaration and this definition to
// `const vector<double>&` would avoid the copy.
double TreeLoadHelper::predict_onetree(vector<TreeNode>& no_leaf_node_vet,
                                       vector<TreeNode>& leaf_node_vet,
                                       vector<double> featlst) {
    // With no split nodes there is nothing to traverse; the only possible
    // output is the first leaf (indexing split node 0 would be out of bounds).
    if (no_leaf_node_vet.empty()) {
        return leaf_node_vet.empty() ? 0.0 : leaf_node_vet[0].value;
    }
    const int kMaxSteps = 1000;
    int index = 0;
    for (int step = 0; step < kMaxSteps; ++step) {
        const TreeNode* node = NULL;
        if (index >= 0) {
            const size_t pos = static_cast<size_t>(index);
            if (pos >= no_leaf_node_vet.size()) {
                return 0.0;
            }
            node = &no_leaf_node_vet[pos];
        } else {
            // -(index + 1) == -index - 1, but cannot overflow for INT_MIN.
            const size_t pos = static_cast<size_t>(-(index + 1));
            if (pos >= leaf_node_vet.size()) {
                return 0.0;
            }
            node = &leaf_node_vet[pos];
        }
        if (node->isleaf == 1) {
            return node->value;
        }
        // Feature `fid` is read from featlst[fid - 1] (indexing kept from the
        // original code). NOTE(review): presumably the caller's feature vector
        // is offset by one relative to split_feature ids; confirm. A feature
        // outside the vector is treated as 0.0 instead of reading out of bounds.
        double feat_value = 0.0;
        if (node->fid >= 1 && static_cast<size_t>(node->fid - 1) < featlst.size()) {
            feat_value = featlst[static_cast<size_t>(node->fid - 1)];
        }
        index = (feat_value <= node->value) ? node->left : node->right;
    }
    return 0.0;
}
// Splits a "key=v1 v2 ..." model-file line into its values.
//
// Returns the space-separated tokens after `col_name` when `line` begins
// with `col_name`, otherwise an empty vector. The key must sit at the start
// of the line: the value slice starts at col_name.length(), and a plain
// substring match would also let e.g. "cat_threshold=" satisfy "threshold="
// and then slice the wrong part of the line.
vector<string> TreeLoadHelper::parse_line(string line, string col_name)
{
    vector<string> feat_vet;
    // compare() clamps the length to line.size(), so a short line simply
    // fails the comparison instead of throwing.
    if (line.compare(0, col_name.length(), col_name) == 0)
    {
        feat_vet = str_split(line.substr(col_name.length()), ' ');
    }
    return feat_vet;
}
double TreeLoadHelper::model_predict(const vector<double>& feature_list)
{
double pred = 0.0;
for (size_t i=0;i<alltrees_noleaf.size();i++) {
double score = predict_onetree(alltrees_noleaf[i], alltrees_leaf[i], feature_list);
pred+=score;
}
return pred;
}