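// NOTE: the next line is web-page residue that was pasted in with this code;
// it is kept as a comment so the file compiles:
// 18,356 社区成员 (community members) · 发帖 (Post) · 与我相关 (Related to me) · 我的任务 (My tasks) · 分享 (Share)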
#include <cctype>
#include <condition_variable>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <map>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <vector>

#include <WinSock2.h>
#pragma comment(lib, "wsock32.lib")
/* DNS message header, RFC 1035 section 4.1.1 (12 bytes).
   When a received buffer is viewed through this struct, every field is still
   in network byte order and has to be converted with ntohs/htons. */
struct _dns_struct {
    unsigned short id;      // ID: 16-bit identifier chosen by the querier and echoed in the reply
    unsigned short flag;    // flags: QR(1) Opcode(4) AA(1) TC(1) RD(1) RA(1) Z(3) RCODE(4)
    unsigned short ques;    // QDCOUNT: number of entries in the question section
    unsigned short answ;    // ANCOUNT: number of resource records in the answer section
    unsigned short auth;    // NSCOUNT: number of resource records in the authority section
    unsigned short addrrs;  // ARCOUNT: number of resource records in the additional section
};
typedef struct _dns_struct dns_struct;

/* Fixed trailer of a question entry (follows QNAME). */
struct _dns_query_type {
    unsigned short type;       // QTYPE
    unsigned short classtype;  // QCLASS
};
typedef struct _dns_query_type dns_query_type;

/* Leading fixed fields of an answer record whose NAME is a 2-byte compression pointer. */
struct _dns_response_type {
    unsigned short name;       // NAME: compression pointer, e.g. C0 0C -> offset 12
    unsigned short type;       // TYPE (query type)
    unsigned short classtype;  // CLASS
};
typedef struct _dns_response_type dns_response_type;
/* One upstream resolver registered through DNS::addServer(). */
struct DNSSERVERSTRUCT {
    int cType;         // transport: 0 = UDP, 1 = TCP, 2 = SOCKS4, 3 = SOCKS5
    SOCKADDR_IN addr;  // address and port of the upstream DNS server
    // A per-server SOCKET and a rude::Socket proxy handle were tried here; both are disabled.
};
typedef struct DNSSERVERSTRUCT DNSSERVER;

/* A client query waiting to be answered. */
struct InfoUDPStruct {
    SOCKET clientSocket;         // TCP client socket; NULL for queries received over UDP
    SOCKADDR_IN clientSockAddr;  // client address (reply destination for UDP)
    char buff[1000];             // the raw query message as received
    int dataLength;              // number of valid bytes in buff
};
typedef struct InfoUDPStruct PACKAGEINFO;

/* One forwarding job: send the query with this ID to this upstream server. */
struct TRANSPACKAGE {
    DNSSERVER *DNSServer;  // upstream server (points into DNS::DNSServer)
    unsigned short ID;     // DNS message ID; key into DNS::IDRecord
};

/* A cached A-record answer. */
struct SavedDNS {
    unsigned TTL;  // absolute expiry time, in time(NULL) seconds
    IN_ADDR IP;    // resolved IPv4 address
};

/* Bookkeeping for a query that has been forwarded upstream. */
struct IDRECORD {
    int count;          // upstream attempts still outstanding
    PACKAGEINFO *info;  // the client query; deleted when the record is erased
};
struct PRIORITYQUENODE {
std::string domain;
SavedDNS *savedDNS;
friend bool operator<(PRIORITYQUENODE &t1, PRIORITYQUENODE &t2) {
return t1.savedDNS->TTL < t2.savedDNS->TTL;
}
};
/* Caching, filtering DNS forwarder.
   startListen() starts maxRecvThread recvProcess() workers and maxTransThread
   transProcess() workers, then runs the UDP receive loop (listenUDP) on the
   calling thread. recvProcess answers from DNSCache or queues one TRANSPACKAGE
   per upstream server; transProcess forwards the query, and A-record answers
   whose IP is not in forbiddenIP are cached and returned to the client.
   NOTE(review): DNSCache and IDRecord are read and written by both worker kinds
   without a common lock; with the workers running concurrently this is a data race. */
class DNS {
public:
    // Start the server and serve UDP queries. Returns 0 when startup fails.
    int startListen();
    int listenPort = 53;       // local port to listen on
    char listenIP[20] = "";    // local IPv4 address; "" or "0.0.0.0" means all interfaces
    int maxRecvThread = 1;     // number of recvProcess() worker threads
    int maxTransThread = 1;    // number of transProcess() worker threads
private:
    // Extract the question name (dotted form) from a raw query message.
    void getDomainName(char *, std::string &);
    // Pick an IPv4 address and its TTL out of an upstream reply; returns 4 when one was found.
    int parse_dns_response_packet(char *buff, int length, char *dstbuff, unsigned int&);
    void recvProcess();   // worker: answer from cache or fan out to upstream servers
    void transProcess();  // worker: query one upstream server and answer/cache the result
    void listenUDP();     // UDP receive loop feeding recvQueue
    void listenTCP();     // TCP accept loop feeding recvQueue (not started by startListen)
    void addForbiddenIP(std::string);  // add an answer IP that must be discarded
    void addServer(std::string);       // register an upstream server from a spec string
    void answerQuery(PACKAGEINFO*, std::string&, SavedDNS&);  // send an A-record reply to the client

    std::queue<PACKAGEINFO*> recvQueue;     // client queries; guarded by recvLock
    std::queue<TRANSPACKAGE*> transQueue;   // forwarding jobs; guarded by transLock
    std::vector<std::thread> threadVector;  // worker threads started by startListen()
    // INVALID_SOCKET until startListen() creates them, so an early use fails
    // cleanly instead of operating on an indeterminate handle.
    SOCKET listenSocketUDP = INVALID_SOCKET;
    SOCKET listenSocketTCP = INVALID_SOCKET;
    std::vector<DNSSERVER> DNSServer;       // upstream servers, in registration order
    std::vector<IN_ADDR> forbiddenIP;       // answers containing these IPs are discarded
    sockaddr_in listenAddr = {};            // local bind address (zero-initialised, incl. sin_zero)
    // NOTE(review): some socket calls pass this as an in/out length argument;
    // sharing one writable copy between threads is unsafe.
    int sizeofAddr = sizeof(listenAddr);
    std::mutex recvLock;
    std::mutex transLock;
    std::mutex recvEmptyMutex;
    std::mutex transEmptyMutex;
    std::condition_variable recvEmpty;      // notified after a push to recvQueue
    std::condition_variable transEmpty;     // notified after a push to transQueue
    std::map<std::string, SavedDNS> DNSCache;     // domain name -> cached A record
    std::map<unsigned short, IDRECORD> IDRecord;  // in-flight forwarded queries, keyed by message ID
    std::priority_queue<PRIORITYQUENODE, std::vector<PRIORITYQUENODE>> TTLQueue;  // not used by the visible code
};
int DNS::startListen() {
if (WSAStartup(MAKEWORD(2, 0), new WSADATA) != 0) {
return 0;
}
//addServer("udp 127.0.0.1:54");
//addServer("tcp 202.120.224.6");
//addServer("tcp 114.114.114.114");
//addServer("udp 114.114.115.115");
addServer("udp 223.5.5.5");
//addServer("udp 223.6.6.6");
//addServer("tcp 8.8.8.8:53");
//addServer("tcp 8.8.4.4:53");
//addServer("SOCKS5 127.0.0.1:1080 8.8.8.8:53");
addForbiddenIP("243.185.187.39");
addForbiddenIP("46.82.174.68");
addForbiddenIP("37.61.54.158");
addForbiddenIP("93.46.8.89");
addForbiddenIP("59.24.3.173");
addForbiddenIP("203.98.7.65");
addForbiddenIP("8.7.198.45");
addForbiddenIP("78.16.49.15");
addForbiddenIP("159.106.121.75");
addForbiddenIP("10.6.0.126");
addForbiddenIP("10.6.0.127");
listenSocketUDP = socket(AF_INET, SOCK_DGRAM, 0);
listenSocketTCP = socket(AF_INET, SOCK_STREAM, 0);
listenAddr.sin_family = AF_INET;
listenAddr.sin_port = htons(listenPort);
if (strcmp(listenIP, "") == 0 || strcmp(listenIP, "0.0.0.0") == 0) {
listenAddr.sin_addr.S_un.S_addr = INADDR_ANY;
}
else {
//InetPtonA(AF_INET, listenIP, &listenAddr);
listenAddr.sin_addr.S_un.S_addr = inet_addr(listenIP);
}
threadVector.reserve(200);
//threadVector.push_back(std::thread(std::bind(&DNS::listenTCP, this)));
for (int i = 0; i < maxRecvThread; i++) {
threadVector.push_back(std::thread(std::bind(&DNS::recvProcess, this)));
}
for (int i = 0; i < maxTransThread; i++) {
threadVector.push_back(std::thread(std::bind(&DNS::transProcess, this)));
}
listenUDP();
return 1;
}
// Add an IPv4 address (dotted-quad string) to forbiddenIP. transProcess()
// neither caches nor returns upstream answers that resolve to one of these.
void DNS::addForbiddenIP(std::string IP) {
    IN_ADDR tmpaddr;
    //InetPtonA(AF_INET, IP.c_str(), &tmpaddr);
    // Parse the argument itself. This previously parsed listenIP by mistake,
    // so every entry in forbiddenIP ended up holding the same value.
    tmpaddr.S_un.S_addr = inet_addr(IP.c_str());
    forbiddenIP.push_back(tmpaddr);
}
void DNS::addServer(std::string server) {
DNSSERVER tmpserver;
unsigned short port1=1080, port2 = 53;
int m1, m2;
tmpserver.addr.sin_family = AF_INET;
for (auto &i : server) i = toupper(i);
m1 = server.find_first_of(' ');
m2 = server.find_last_of(' ');
std::string s1 = server.substr(m1 + 1, m2 - m1 - 1); //代理服务器
std::string s2 = server.substr(m2 + 1); //DNS服务器
if ((int)s1.find(":") > 0) {
port1 = atoi(s1.substr(s1.find(":") + 1).c_str());
s1 = s1.substr(0, s1.find(":"));
}
if ((int)s2.find(":") > 0) {
port2 = atoi(s2.substr(s2.find(":") + 1).c_str());
s2 = s2.substr(0, s2.find(":"));
}
tmpserver.addr.sin_port = htons(port2);
//InetPtonA(AF_INET, s2.c_str(), &tmpserver.addr.sin_addr);
tmpserver.addr.sin_addr.S_un.S_addr = inet_addr(s2.c_str());
if ((int)server.find("UDP") >= 0) {
tmpserver.cType = 0;
//connect(tmpserver.socket, (sockaddr*)&tmpserver.addr, sizeofAddr);
}
else if ((int)server.find("TCP") >= 0) {
tmpserver.cType = 1;
}
else if ((int)server.find("SOCKS4") >= 0) {
tmpserver.cType = 2;
//tmpserver.rudesocket = new rude::Socket();
//tmpserver.rudesocket->insertSocks4(s2.c_str(), port2, "");
//tmpserver.rudesocket->connect(s1.c_str(), port1);
}
else if ((int)server.find("SOCKS5") >= 0) {
tmpserver.cType = 3;
//tmpserver.rudesocket = new rude::Socket();
//tmpserver.rudesocket->insertSocks5(s2.c_str(), port2, "", "");
//tmpserver.rudesocket->connect(s1.c_str(), port1);
}
timeval tv_out = { 0, 300 };
setsockopt(listenSocketUDP, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv_out, sizeof(timeval));
DNSServer.push_back(tmpserver);
}
void DNS::listenUDP() {
char buff[3000] = {};
int recvLen = -1;
SOCKADDR_IN clientAddr; //接入用户址地
bind(listenSocketUDP, (sockaddr*)&listenAddr, sizeofAddr);
timeval tv_out = { 0, 300 };
setsockopt(listenSocketUDP, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv_out, sizeof(timeval));
while (true)
{
//当成功接到新来的用户时,发出处理请求
recvLen = recvfrom(listenSocketUDP, buff, sizeof(buff), 0, (sockaddr*)&clientAddr, &sizeofAddr);
if (recvLen != -1)
{
//队列已满就等待
PACKAGEINFO *info = new PACKAGEINFO;
info->clientSocket = NULL;
info->clientSockAddr = clientAddr;
memcpy(info->buff, buff, recvLen);
info->dataLength = recvLen;
recvLock.lock();
recvQueue.push(info);
recvLock.unlock();
recvEmpty.notify_one();
}
}
}
void DNS::listenTCP() {
char buff[3000] = {};
int recvLen = -1;
SOCKADDR_IN clientAddr; //接入用户址地
SOCKET clientSocket;
bind(listenSocketTCP, (sockaddr*)&listenAddr, sizeofAddr);
listen(listenSocketTCP, 5);
timeval tv_out = { 0, 300 };
setsockopt(listenSocketUDP, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv_out, sizeof(timeval));
while (true)
{
//当成功接到新来的用户时,发出处理请求
clientSocket = accept(listenSocketTCP, (sockaddr*)&clientAddr, &sizeofAddr);
recvLen = recv(clientSocket, buff, sizeof(buff), 0);
if (recvLen != -1)
{
PACKAGEINFO *info = new PACKAGEINFO;
info->clientSocket = clientSocket;
info->clientSockAddr = clientAddr;
memcpy(info->buff, buff, recvLen);
info->dataLength = recvLen;
recvLock.lock();
recvQueue.push(info);
recvLock.unlock();
recvEmpty.notify_one();
}
}
}
void DNS::recvProcess() {
std::unique_lock<std::mutex> recvEmptyUL(recvEmptyMutex);
PACKAGEINFO *info;
unsigned short ID;
std::string domainName = "";
while (true) {
while (recvQueue.empty()) {
recvEmpty.wait(recvEmptyUL); //等特新的处理项
}
recvLock.lock();
info = recvQueue.front();
recvQueue.pop();
recvLock.unlock();
ID = (info->buff[0] << 8) + info->buff[1];
getDomainName(info->buff, domainName);
if (DNSCache.count(domainName) <= 0 ) {
//if (DNSCache.count(domainName) <= 0 || time(NULL) >= DNSCache[domainName].TTL ) {
//无缓存,转发DNS查询
if (IDRecord.count(ID) > 0) {
continue;
}
IDRecord[ID].count = DNSServer.size();
IDRecord[ID].info = info;
for (unsigned int i = 0; i < DNSServer.size(); i++) {
TRANSPACKAGE *newPackage = new TRANSPACKAGE{ &DNSServer[i], ID };
transLock.lock();
transQueue.push(newPackage);
transLock.unlock();
transEmpty.notify_one();
}
}
else {
//有缓存,查缓存并回答
answerQuery(info, domainName, DNSCache[domainName]);
delete info;
}
}
}
void DNS::transProcess() {
std::unique_lock<std::mutex> transEmptyUL(transEmptyMutex);
unsigned short ID;
PACKAGEINFO *info;
char buff[3000];
int recvLen;
int retSend;
char dstIP[10];
bool isForbidden = false;
unsigned int ipOffset;
IN_ADDR newIP;
std::string domainName = "";
unsigned int newTTL = 0;
SavedDNS newDNS;
TRANSPACKAGE *transPackage;
DNSSERVER *thisServer;
while (true) {
while (transQueue.empty()) {
transEmpty.wait(transEmptyUL);
}
transLock.lock();
printf("CacheSize=%04i IDRecordSize=%04i recvQueueSize=%04i transQueueSize=%04i\n", DNSCache.size(), IDRecord.size(), recvQueue.size(), transQueue.size());
transPackage = transQueue.front();
transQueue.pop();
transLock.unlock();
ID = transPackage->ID;
thisServer = transPackage->DNSServer;
if (IDRecord.count(ID) > 0 && IDRecord[ID].count > 0) {
info = IDRecord[ID].info;
if (thisServer->cType == 0) {
SOCKET sSocket = socket(AF_INET, SOCK_DGRAM, 0);
timeval tv_out = { 0, 300 };
setsockopt(sSocket, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv_out, sizeof(timeval));
retSend = sendto(sSocket, info->buff, info->dataLength, 0, (sockaddr*)&thisServer->addr, sizeofAddr);
recvLen = recvfrom(sSocket, buff, sizeof(buff), 0, (sockaddr*)&thisServer->addr, &sizeofAddr);
closesocket(sSocket);
}
else if (thisServer->cType == 1) {
SOCKET sSocket = socket(AF_INET, SOCK_STREAM, 0);
timeval tv_out = { 0, 500 };
setsockopt(sSocket, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv_out, sizeof(timeval));
connect(sSocket, (sockaddr*)&thisServer->addr, sizeofAddr);
retSend = send(sSocket, info->buff, info->dataLength, 0);
recvLen = recv(sSocket, buff, sizeof(buff), 0);
closesocket(sSocket);
}
else {
//retSend = thisServer->rudesocket->send(info->buff, info->dataLength);
//recvLen = thisServer->rudesocket->read(buff, sizeof(buff));
}
if (recvLen > 0) {
getDomainName(info->buff, domainName);
ipOffset = parse_dns_response_packet(buff, recvLen, dstIP, newTTL);
newIP.S_un.S_un_b.s_b1 = dstIP[0];
newIP.S_un.S_un_b.s_b2 = dstIP[1];
newIP.S_un.S_un_b.s_b3 = dstIP[2];
newIP.S_un.S_un_b.s_b4 = dstIP[3];
for (unsigned int j = 0; j < forbiddenIP.size(); j++) {
if (newIP.S_un.S_addr == forbiddenIP[j].S_un.S_addr) {
isForbidden = true;
break;
}
}
if (!isForbidden) {
newDNS.TTL = time(NULL) + newTTL;
newDNS.IP = newIP;
DNSCache[domainName] = newDNS;
answerQuery(IDRecord[ID].info, domainName, newDNS);
}
if (!isForbidden || --IDRecord[ID].count == 0) {
delete info;
IDRecord.erase(ID);
}
}
}
delete transPackage;
}
}
// Send an A-record answer for info's query back to the client that sent it.
// The reply is built from the client's header and question section:
//  - QR, AA and RA are set; ANCOUNT = 1 and NSCOUNT = ARCOUNT = 0.
//  - One A record follows. Its name points at the question name, its address is
//    newDNS.IP, and its TTL is the remaining cache lifetime.
// domainName is not used. UDP clients (clientSocket == NULL) are answered
// through listenSocketUDP; otherwise the reply is sent on clientSocket.
// Assumes the query carries a single question (QDCOUNT = 1).
void DNS::answerQuery(PACKAGEINFO* info, std::string& domainName, SavedDNS& newDNS) {
    (void)domainName;  // the answer name is a compression pointer into the copied question
    if (info->dataLength < 12) {
        return;  // not even a complete DNS header: nothing to answer
    }
    char buff[3000];  // reply buffer

    // Remaining lifetime of the entry. Clamp at 0: the old unsigned subtraction
    // wrapped to roughly 4e9 seconds once the entry had expired.
    const time_t now = time(NULL);
    const time_t expiry = (time_t)newDNS.TTL;
    const unsigned ttl = (expiry > now) ? (unsigned)(expiry - now) : 0u;

    // Copy only the header and question. Anything after the question (typically
    // an EDNS OPT record in the additional section) would otherwise sit in front
    // of the answer record and make the reply unparsable.
    int questionEnd = 12;
    while (questionEnd < info->dataLength && info->buff[questionEnd] != 0) {
        questionEnd += 1 + (unsigned char)info->buff[questionEnd];  // skip one label
    }
    questionEnd += 1 + 4;  // root label, QTYPE, QCLASS
    if (questionEnd > info->dataLength) {
        questionEnd = info->dataLength;
    }
    memcpy(buff, info->buff, questionEnd);
    int package_len = questionEnd;

    buff[2] = (char)(buff[2] | 0x84);  // QR = 1 (response), AA = 1
    buff[3] = (char)(buff[3] | 0x80);  // RA = 1
    buff[6] = 0;                       // ANCOUNT = 1
    buff[7] = 1;
    buff[8] = 0;                       // NSCOUNT = 0
    buff[9] = 0;
    buff[10] = 0;                      // ARCOUNT = 0 (those sections were not copied)
    buff[11] = 0;

    // Answer record: NAME, TYPE, CLASS, TTL, RDLENGTH, then RDATA below.
    const unsigned char answer[12] = {
        0xC0, 0x0C,                                            // NAME: pointer to offset 12
        0x00, 0x01,                                            // TYPE = A
        0x00, 0x01,                                            // CLASS = IN
        (unsigned char)(ttl >> 24), (unsigned char)(ttl >> 16),
        (unsigned char)(ttl >> 8), (unsigned char)ttl,         // TTL, big-endian
        0x00, 0x04,                                            // RDLENGTH = 4
    };
    memcpy(&buff[package_len], answer, sizeof(answer));
    package_len += (int)sizeof(answer);
    memcpy(&buff[package_len], &newDNS.IP, 4);  // RDATA: IPv4 address as stored (network order)
    package_len += 4;

    if (info->clientSocket == NULL) {
        // UDP client
        sendto(listenSocketUDP, buff, package_len, 0, (sockaddr*)&info->clientSockAddr,
               (int)sizeof(info->clientSockAddr));
    }
    else {
        // TCP client
        send(info->clientSocket, buff, package_len, 0);
    }
}
// Decode the question name of a raw DNS message into dotted form
// ("www.example.com"). buff must point at the start of the message; the name
// starts after the 12-byte header. The result replaces domainName.
// Decoding stops at any of these:
//  - the root label;
//  - a compression pointer or other non-plain label (length > 63);
//  - RFC 1035's 255-octet name limit.
// At most 12 + 255 bytes of buff are therefore read, so a malformed or hostile
// packet can neither loop forever nor run off the buffer.
void DNS::getDomainName(char * buff, std::string & domainName)
{
    const int nameStart = 12;               // first byte after the header
    const int nameLimit = nameStart + 255;  // a name is at most 255 octets
    domainName.clear();
    int i = nameStart;
    while (i < nameLimit)
    {
        // Label lengths are unsigned. Read as (signed) char, values >= 0x80 became
        // negative and the old counter never reached zero.
        const int labelLen = (unsigned char)buff[i];
        i++;
        if (labelLen == 0 || labelLen > 63 || i + labelLen > nameLimit)
        {
            break;
        }
        if (!domainName.empty())
        {
            domainName.append(1, '.');
        }
        domainName.append(buff + i, labelLen);
        i += labelLen;
    }
}
// Read a big-endian 16-bit field of a raw DNS message.
static unsigned int dnsReadU16(const unsigned char *p)
{
    return ((unsigned int)p[0] << 8) | p[1];
}

// Read a big-endian 32-bit field of a raw DNS message.
static unsigned int dnsReadU32(const unsigned char *p)
{
    return ((unsigned int)p[0] << 24) | ((unsigned int)p[1] << 16) | ((unsigned int)p[2] << 8) | p[3];
}

// Return the offset just past the encoded domain name that starts at pos, or
// -1 if the name is malformed or runs past length. A compression pointer
// (two bytes, top bits 11) terminates the name.
static int dnsSkipName(const unsigned char *msg, int length, int pos)
{
    while (pos < length) {
        const unsigned int labelLen = msg[pos];
        if (labelLen == 0) {
            return pos + 1;  // root label
        }
        if ((labelLen & 0xC0) == 0xC0) {
            return (pos + 2 <= length) ? pos + 2 : -1;
        }
        if ((labelLen & 0xC0) != 0) {
            return -1;  // reserved label type
        }
        pos += 1 + (int)labelLen;
    }
    return -1;
}

// Parse an upstream reply and pick one IPv4 address out of its answer section.
//   buff, length : the reply message and its size in bytes (untrusted)
//   dstbuff      : receives the 4 address bytes (network order) when one is found
//   ttl          : set to the chosen record's TTL in seconds, or 0 if none
// Returns 4 when an A record was found and 0 otherwise, including for
// malformed replies. Among several A records, the one with the largest TTL wins.
// Every read is bounds-checked against length.
int DNS::parse_dns_response_packet(char *buff, int length, char *dstbuff, unsigned int &ttl)
{
    int dstoffset = 0;
    unsigned int maxttl = 0;
    ttl = 0;
    if (buff == NULL || length < 12) {
        return 0;
    }
    const unsigned char *msg = (const unsigned char *)buff;
    const unsigned int qdcount = dnsReadU16(msg + 4);  // QDCOUNT
    const unsigned int ancount = dnsReadU16(msg + 6);  // ANCOUNT
    int pos = 12;  // skip the header

    // Question section: QNAME + QTYPE(2) + QCLASS(2) per entry.
    for (unsigned int q = 0; q < qdcount; q++) {
        pos = dnsSkipName(msg, length, pos);
        if (pos < 0 || pos + 4 > length) {
            return 0;
        }
        pos += 4;
    }

    // Answer section: NAME, TYPE(2), CLASS(2), TTL(4), RDLENGTH(2), RDATA.
    for (unsigned int a = 0; a < ancount; a++) {
        pos = dnsSkipName(msg, length, pos);
        if (pos < 0 || pos + 10 > length) {
            break;
        }
        const unsigned int type = dnsReadU16(msg + pos);
        const unsigned int recordTtl = dnsReadU32(msg + pos + 4);
        const int rdlength = (int)dnsReadU16(msg + pos + 8);
        pos += 10;
        if (pos + rdlength > length) {
            break;
        }
        // TYPE 1 = A. Take the first A record, then any later one with a larger
        // TTL. (Comparing only against an initial 0 meant a TTL-0 record was never picked.)
        if (type == 1 && rdlength == 4 && (dstoffset == 0 || recordTtl > maxttl)) {
            maxttl = recordTtl;
            memcpy(dstbuff, msg + pos, 4);  // IPv4 address
            dstoffset = 4;
        }
        pos += rdlength;
    }
    ttl = maxttl;
    return dstoffset;
}
// Entry point: run the forwarder with its built-in configuration.
// The exit code is 0 when startListen() returns success (1) and 1 otherwise,
// so a startup failure is visible to the caller.
int main(int argc, char *argv[])
{
    (void)argc;  // no command-line options yet
    (void)argv;
    DNS dns;
    return dns.startListen() ? 0 : 1;
}