64,647
社区成员
发帖
与我相关
我的任务
分享
class PostagModel {
public:
/**
* @brief 构造函数
**/
PostagModel() {
m_pImpt = NULL;
m_bOpen = false;
};
/**
* @brief 析构函数
*
**/
~PostagModel() {
printf("~postagmodel\n");
fflush(stdout);
}
/**
* @brief 根据模型文件的目录和语言id,创建模型
* @param[in] model_file_dir : const char* 模型文件目录
* @param[in] lang_id : const unsigned 语言ID: 1=english 0=chinese
* @return true=成功 false=失败
**/
bool create(const char *model_file_dir, const unsigned lang_id);
/**
* @brief 标注函数
* @param[in] words : const vector<string> & 句子向量,词条组成基本单位
* @param[out] tags : vector<string> 返回的标注tags序列
* @param[in] hmm_order : const unsigned 解码阶数
* @return true=成功 false=失败
**/
bool postag(const vector<string> &words, vector<string> &tags, const unsigned hmm_order);
/**
* @brief 释放资源
*
**/
void destroy();
private:
PostagModelImplement *m_pImpt; //模型实现接口
bool m_bOpen; //是否已经打开模型
};
#include <string>
#include <iostream>
#include <sstream>
#include <map>
#include <vector>
#include <cstdio>
using namespace std;
#include "HmmModel.h"
//基类
class PostagModelImplement {
public:
/**
* @brief 构造函数
*
**/
PostagModelImplement() {};
virtual ~PostagModelImplement() {};
/**
* @brief 初始化postag_model模型,postag内部初始化调用接口
**/
bool init_postag_model(const char *model_file_dir);
/**
* @brief 词性标注
* @param[in] words : const vector<string>& 词向量(明文)
* @param[out] tags : vector<string>& 词性序列
* @param[in] hmm_order : 指定几阶解码1/2
**/
bool postag(const vector<string> &words, vector<string> &tags, const unsigned hmm_order);
/**
* @brief 归一化oov,并利用归一化的oov获取oov的映射id
**/
virtual int normalizeOov_to_oovId(const char *oov) = 0;
/**
* @brief 判断是否为数字
**/
virtual bool is_CD_string(const char * str) = 0;
/**
* @brief 销毁postag_model
**/
void destroy();
protected:
map<string, int> word2Id; //word到id映射的存储结构
map<string, int> low_oov2Id;
map<string, int> up_oov2Id;
map<int, string> tagId2tag; //tagId到tag的映射,采取N维字符串形式存储,方便直接索引
private:
HmmModel m_hmm_model; //HMM解码接口
/**
* @brief 加载word2id、lowoov2id、upoov2id、tagid2tag文件的行数
**/
bool PostagModelImplement::load_linenum(const char* linenum_file);
/**
* @brief 加载word到wordId的映射关系
**/
bool load_word2Id(const char* word2Id_file);
/**
* @brief 加载oov到oovId的映射关系
**/
bool load_oov2Id(const char* oov2Id_file);
/**
* @brief 加载tagId到tag的映射关系
**/
bool load_tagId2tag(const char* tagId2tag_file);
/**
* @brief 将词序列映射为词id序列
**/
bool trans_word_to_wordId(vector<string>, int*);
};
/**
* 英文语言扩展类
**/
class EnglishPostagModelImpt : public PostagModelImplement {
public:
/**
* @brief EnglishPostagModel构造函数
**/
EnglishPostagModelImpt() {};
/**
* @brief 析构函数
**/
~EnglishPostagModelImpt() {};
/**
* @brief 英文,获取oov的id
**/
virtual int normalizeOov_to_oovId(const char *_oov);
/**
* @brief 是否是数字
**/
virtual bool is_CD_string(const char * str) ;
};
/**
* 中文语言扩展类
**/
class ChinesePostagModelImpt : public PostagModelImplement {
private:
map<string, int> *ch_oov2Id;
public:
ChinesePostagModelImpt() {};
~ChinesePostagModelImpt() {};
/**
* @brief 中文,获取oov的id
**/
virtual int normalizeOov_to_oovId(const char *_oov);
/**
* @brief 是否为数字串
**/
virtual bool is_CD_string(const char *str);
};
void PostagModel::destroy() {
m_pImpt->destroy();
delete m_pImpt;//报错报在这
m_pImpt = NULL;
m_bOpen = false;
}
bool PostagModel::create(const char *model_file_dir, const unsigned lang_id) {
if(model_file_dir == NULL || lang_id > 1) {
return false;
}
if(m_pImpt != NULL) {
delete m_pImpt;
m_pImpt = NULL;
}
if(lang_id == 0) { //中文
m_pImpt = new ChinesePostagModelImpt();
if(m_pImpt == NULL) {
goto failed;
}
} else if(lang_id == 1) { //英文
m_pImpt = new EnglishPostagModelImpt();
if(m_pImpt == NULL) {
goto failed;
}
} else { //其他
goto failed;
}
if (m_pImpt == NULL)
{
goto failed;
}
if(m_pImpt->init_postag_model(model_file_dir) == false) {
goto failed;
}
m_bOpen = true;
failed:
if (!m_bOpen && m_pImpt)
{
delete m_pImpt;
m_pImpt = NULL;
}
return m_bOpen;
}