1 module dnslib.defs;
2 
3 import std.stdio;
4 import std.conv;
5 import std.bitmanip;
6 
// Refuse to compile on big-endian targets: the header accessors in this
// module call swapEndian() unconditionally and the flag bitfields are laid
// out for a little-endian host.
version(BigEndian)
{
	static assert(false);
}
11 
12 // ---------------------------------------------------------------------
13 
/**
 * Bundle of settings for a single DNS lookup, with its default values.
 *
 * The struct only stores the values; how each one is interpreted is decided
 * by the code that consumes it (not visible in this module).
 */
struct DnsOptions
{
	// --- what to ask ---
	string		name					= "";
	dnsType		type					= dnsType.A;
	bool		recursionDesired		= true;
	bool		reverse					= false;

	// --- input source (presumably: read a hex encoded message from stdin; confirm with caller) ---
	bool		hexStdin				= false;
	
	// --- where and how to send the request ---
	Protocol	protocol				= Protocol.udptcp;
	string		server					= "127.0.0.1";
	string		serverName				= "";
	
	// --- certificate handling (presumably only relevant for TLS transports) ---
	bool		trusted					= true;
	string		trustedCertificateFile	= "/etc/ssl/certs/ca-certificates.crt";
	
	// --- port numbers ---
	ushort		udpTcpPort				=  53;	
	ushort		tlsPort					= 853;	
	
	// --- diagnostics / output verbosity ---
	bool		printData				= false;
	bool		printParsing			= false;
	
	bool		verbose					= false;
	bool		quiet					= false;
	
	/**
	 * Renders all fields, in declaration order, as one line per field:
	 * the field name left-justified to 23 columns, then ": ", the value
	 * converted with std.conv.text, and the platform newline.
	 *
	 * Returns: the concatenated lines.
	 */
	string toString()
	{
		import std.array : appender;
		import std.ascii : newline;
		import std.conv : text;
		import std.string : leftJustify;
		import std.traits : FieldNameTuple;

		auto output = appender!string();

		// Unrolled at compile time: one group of put() calls per field.
		static foreach (fieldName; FieldNameTuple!DnsOptions)
		{
			output.put(leftJustify(fieldName, 23, ' '));
			output.put(": ");
			output.put(text(__traits(getMember, this, fieldName)));
			output.put(newline);
		}

		return output.data;
	}
}
61 
62 // ---------------------------------------------------------------------
63 
// Transport selection for a lookup. The set of members depends on the build:
// with version ENABLE_TLS the TLS variants (tls, tlstcp) exist, without it
// they are commented out and any reference to them fails to compile.
// The numeric values are identical in both variants.
// NOTE(review): udptcp / tlstcp presumably mean "try the first transport,
// fall back to TCP" -- confirm against the code that consumes Protocol.
version(ENABLE_TLS)
{
	enum Protocol : ubyte
	{
		none		= 0,
		udp			= 1,
		tcp			= 2,
		udptcp		= 3,
		tls			= 4,
		tlstcp		= 5,
		//https		= 6,
	}	
}
else
{
	// Build without TLS support: only the plain transports are defined.
	enum Protocol : ubyte
	{
		none		= 0,
		udp			= 1,
		tcp			= 2,
		udptcp		= 3,
		//tls		= 4,
		//tlstcp	= 5,
		//https		= 6,
	}	
}
90 
91 // ---------------------------------------------------------------------
92 
// Resource record TYPE codes known to this module (numeric values as they
// appear in the 16-bit TYPE field of a DNS message).
// typeValidate() accepts exactly the members listed here, so the
// commented-out entries below are treated as invalid types.
enum dnsType : ushort
{
	A		=    1,
	NS		=    2,
	CNAME	=    5,
	SOA		=	 6,
	PTR		=   12,
	MX		=   15,
	TXT		=   16,
	AAAA	=   28,
	SRV		=   33,
	DNAME	=   39,
//	SSHFP	=	44,
//	TLSA	=   52,
//	CAA		=  257,
}
109 
/**
 * Tells whether a raw 16-bit TYPE value is one of the dnsType members.
 *
 * Params:
 *   type = numeric record type as read from a message
 * Returns: true if some dnsType member has this value, false otherwise.
 */
bool typeValidate(ushort type)
{
	import std.traits;

	// Expanded at compile time into one comparison per enum member.
	static foreach (known; EnumMembers!dnsType)
	{
		if (type == known)
			return true;
	}

	return false;
}
121 
122 // ---------------------------------------------------------------------
123 
// CLASS codes known to this module. queryClassToString() renders the same
// numeric values as "INET", "CHAOS" and "HESIOD".
enum dnsClass : ushort
{
	INET	=	1,
	CH		=	3,
	HS		=	4,
}
130 
/**
 * Tells whether a raw CLASS value is one of the supported classes.
 *
 * Params:
 *   myClass = numeric class as read from a message
 * Returns: true for 1, 3 or 4; false for every other value.
 */
bool classValidate(ushort myClass)
{
	switch (myClass)
	{
		case 1, 3, 4:
			return true;

		default:
			return false;
	}
}
135 
/**
 * Maps a numeric CLASS value to a display name.
 *
 * Params:
 *   queryClass = numeric class value
 * Returns: "INET" (1), "CHAOS" (3), "HESIOD" (4) or "Other" for anything else.
 */
string queryClassToString(ushort queryClass)
{
	switch (queryClass)
	{
		case 1:  return "INET";
		case 3:  return "CHAOS";
		case 4:  return "HESIOD";
		default: return "Other";
	}
}
149 
150 // ---------------------------------------------------------------------
151 
// OPCODE values known to this module. Value 3 has no member;
// opCodeValidate() rejects it and opCodeToString() labels it "Unassigned".
enum dnsOpCode : ubyte
{
	QUERY		= 0,
	IQUERY		= 1,
	STATUS		= 2,
	NOTIFY		= 4,
	UPDATE		= 5,
	DSO			= 6,
}
161 
/**
 * Tells whether an OPCODE value is accepted.
 *
 * Params:
 *   opCode = numeric opcode taken from the header flags
 * Returns: true for 0, 1, 2, 4, 5 and 6; false for 3 and for anything above 6.
 */
bool opCodeValidate(ubyte opCode)
{
	// 3 is the single gap inside the otherwise contiguous 0..6 range.
	if (opCode == 3)
		return false;

	return opCode < 7;
}
166 
/**
 * Maps an OPCODE value to a display name.
 *
 * Params:
 *   opCode = numeric opcode taken from the header flags
 * Returns: the name for 0, 1, 2, 4, 5, 6; "Unassigned" for every other value.
 */
string opCodeToString(ubyte	opCode)
{
	// Indexed by opcode; slot 3 is the gap in the assigned range.
	static immutable string[7] labels = [
		"Query",
		"IQuery (OBSOLETE)",
		"Status",
		"Unassigned",
		"Notify",
		"Update",
		"DNS Stateful Operations (DSO)",
	];

	return (opCode < labels.length) ? labels[opCode] : "Unassigned";
}
182 
183 // ---------------------------------------------------------------------
184 
// RCODE values known to this module. responseCodeValidate() accepts exactly
// the range 0..5 covered by these members.
enum dnsResponseCode : ubyte
{
	NOERROR		= 0,
	FORMERROR	= 1,
	SERVERFAIL	= 2,
	NXDOMAIN	= 3,
	NOTIMP		= 4,
	REFUSED		= 5,
}
194 
/**
 * Tells whether an RCODE value is accepted.
 *
 * Params:
 *   responseCode = numeric response code taken from the header flags
 * Returns: true for 0 through 5, false for anything larger.
 */
bool responseCodeValidate(ubyte responseCode)
{
	enum ubyte firstInvalid = 6;
	return responseCode < firstInvalid;
}
199 
/**
 * Maps an RCODE value to a display name.
 *
 * Params:
 *   responseCode = numeric response code taken from the header flags
 * Returns: the name for 0 through 5, "Other" for every other value.
 */
string responseCodeToString(ubyte responseCode)
{
	// Indexed by response code.
	static immutable string[6] labels = [
		"NoError",
		"FormErr",
		"ServFail",
		"NXDomain",
		"NotImp",
		"Refused",
	];

	if (responseCode >= labels.length)
		return "Other";

	return labels[responseCode];
}  // responseCodeToString
217 
218 // ---------------------------------------------------------------------
219 
// The 16 flag bits of a DNS header, accessible either as the raw value `a`
// or through the named bitfield properties sharing the same storage.
// std.bitmanip.bitfields assigns fields in the order listed, starting at the
// least significant bit, so the first eight bits below form the low byte of
// `a` and the remaining eight the high byte.
struct dnsHeaderFlags
{
	union
	{
		// Raw view of all 16 flag bits.
		ushort a;

		// ToDo: Check for bitdirection for opCode, responseCode and reserved
		mixin(bitfields!(
			bool,  "recursionDesired",		1,
			bool,  "truncation",			1,
			bool,  "authoritativeAnswer",	1,
			ubyte, "opCode",				4,
			bool,  "queryResponse",			1,

			// byte border

			ubyte, "responseCode",			4,

			// See https://tools.ietf.org/html/rfc2535#page-15
			ubyte, "reserved",				3,
			//bool,  "checkingDisabled",	1,
			//bool,  "authenticData",		1,
			//bool,  "recursionAvailable",	1,

			bool,  "recursionAvailable",	1,

			));
	}  // union
	
	// One-line summary in the form
	// "<query|response> - <opcode name> - <rcode name> - RD RA TC AA - reserved",
	// where the four flags and the reserved field are printed as integers.
	string toString()
	{
		import std.format;
		return format!"%s - %s - %s - %d %d %d %d - %d"(dnsHeaderFlagQueryResponseToString(queryResponse), opCodeToString(opCode), responseCodeToString(responseCode), recursionDesired, recursionAvailable, truncation, authoritativeAnswer, reserved);
	}
	
}  // struct dnsHeaderFlags
256 
// The union of the raw ushort and the bitfields must occupy exactly two bytes.
static assert(dnsHeaderFlags.sizeof == 2); // 16 bits

// Checks that the bitfield properties and the raw value `a` share storage:
// clearing `a` clears recursionDesired, and setting recursionDesired makes
// `a` non-zero.
unittest
{
	dnsHeaderFlags x;
	x.a = 0;
	assert(x.recursionDesired == false);
	assert(x.a == 0);
	
	x.recursionDesired = true;
	assert(x.recursionDesired == true);
	assert(x.a != 0);
}
270 
271 // ---------------------------------------------------------------------
272 
// Named values for the header's queryResponse flag: false marks a query,
// true marks a response. Used by dnsMessage.validate() to state which kind
// of message is expected.
enum dnsHeaderFlagQueryResponse : bool
{
	query		= false,
	response	= true
}
278 
/**
 * Display name for the header's queryResponse flag.
 *
 * Params:
 *   x = value of the flag
 * Returns: "response" when x is true, "query" when it is false.
 */
string dnsHeaderFlagQueryResponseToString(bool x)
{
	return x ? "response" : "query";
}
283 
/**
 * Fixed 12-byte DNS message header.
 *
 * The ID and the four counters are kept in the private fields exactly as
 * stored and are byte-swapped by the read-only accessors; the flags are
 * exposed directly through `flags`.
 */
struct dnsHeader{
	private ushort _ID;
	dnsHeaderFlags flags;	
	private ushort _QDCOUNT, _ANCOUNT, _NSCOUNT, _ARCOUNT;
	
	import std.bitmanip: swapEndian;
	
	// Accessors returning the byte-swapped value of the stored field.
	@property ushort ID()		{ return swapEndian(_ID); }
	@property ushort QDCOUNT()	{ return swapEndian(_QDCOUNT); }
	@property ushort ANCOUNT()	{ return swapEndian(_ANCOUNT); }
	@property ushort NSCOUNT()	{ return swapEndian(_NSCOUNT); }
	@property ushort ARCOUNT()	{ return swapEndian(_ARCOUNT); }
	
	/// Writes a multi-line, human readable dump of the header to stdout.
	void print()
	{
		// Names of the flags that are set, each followed by one space.
		string activeFlags;
		if (flags.authoritativeAnswer)	activeFlags ~= "authoritativeAnswer ";
		if (flags.truncation)			activeFlags ~= "truncation ";
		if (flags.recursionDesired)		activeFlags ~= "recursionDesired ";
		if (flags.recursionAvailable)	activeFlags ~= "recursionAvailable ";

		writeln("HEADER SECTION:");
		writeln("ID:           ", ID);
		writeln("QR:           ", to!dnsHeaderFlagQueryResponse(flags.queryResponse));
		writeln("Opcode:       ", opCodeToString(flags.opCode));
		writeln("Responsecode: ", responseCodeToString(flags.responseCode));
		writeln("Reserved:     ", flags.reserved);
		writeln("Flags:        ", activeFlags);
		
		writeln("QDCOUNT: ", QDCOUNT);
		writeln("ANCOUNT: ", ANCOUNT);
		writeln("NSCOUNT: ", NSCOUNT);
		writeln("ARCOUNT: ", ARCOUNT);
	}
	
	/// Single-line summary: "ID ; flags ; QDCOUNT ANCOUNT NSCOUNT ARCOUNT".
	string toString()
	{
		import std.format;
		return format!"%s ; %s ; %s %s %s %s"(ID, flags.toString(), QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT);
	}
}  // struct dnsHeader

// ID + flags + four counters, two bytes each.
static assert(dnsHeader.sizeof == 12);
328 
329 
/**
 * One entry of the question section: the name asked for plus the raw
 * numeric type and class.
 */
struct querySection
{
	string domainName;
	ushort queryType;
	ushort queryClass;
	
	/**
	 * Writes the entry to stdout.
	 *
	 * A type value that is a dnsType member is printed by name. Any other
	 * value is printed as "TYPE<number>" (the RFC 3597 notation for unknown
	 * types) instead of letting to!dnsType throw a ConvException.
	 */
	void print()
	{
		import std.conv : text, to;

		// to!dnsType only succeeds for values typeValidate() accepts.
		const string typeName = typeValidate(queryType)
			? text(to!dnsType(queryType))
			: text("TYPE", queryType);

		writeln("QUERY SECTION:");
		writefln("query name:  %s", domainName);
		writefln("query type:  %s", typeName);
		writefln("query class: %s", queryClass.queryClassToString);			
	}
}
344 
/**
 * One resource record of the answer, authority or additional section.
 *
 * Holds the raw numeric type/class, the TTL, the declared RDATA length, the
 * RDATA bytes and two textual renderings (responseElements, responseString)
 * that are filled in elsewhere.
 */
struct responseSection
{
	string	domainName;
	ushort	responseType;
	ushort	responseClass;
	uint	TTL;
	ushort	responseDataLength;
	ubyte[]	responseData;
	string[] responseElements;
	string	responseString;
	
	/**
	 * Returns the 16-bit value formed by the first two RDATA bytes, most
	 * significant byte first (the preference field of an MX record).
	 *
	 * Do not call before responseType has been checked (type == MX) and the
	 * message has been validated (and is valid).
	 */
	ushort	getMXPriority()
	{
		assert(responseDataLength >= 2);
		// The bytes are read from the array itself, so its real length is
		// what has to be large enough, not only the declared length.
		assert(responseData.length >= 2);
		return cast(ushort)((responseData[0] << 8) | responseData[1]);
	}
	
	/**
	 * Writes the record to stdout.
	 *
	 * A type value that is a dnsType member is printed by name. Any other
	 * value is printed as "TYPE<number>" (the RFC 3597 notation for unknown
	 * types) instead of letting to!dnsType throw a ConvException.
	 */
	void print()
	{
		import std.conv : text, to;

		// to!dnsType only succeeds for values typeValidate() accepts.
		const string typeName = typeValidate(responseType)
			? text(to!dnsType(responseType))
			: text("TYPE", responseType);

		writeln("RESPONSE SECTION:");
		writefln("response name:  %s", domainName);
		writefln("response type:  %s", typeName);
		writefln("response class: %s", responseClass.queryClassToString);			
		writefln("TTL: %d", TTL);			
		writefln("response data length: %d", responseDataLength);			
		writefln("response data:        %s", responseData);
		writefln("response elements:    %s", responseElements);
		writefln("response string:      %s", responseString);
	}
}
376 
377 // ---------------------------------------------------------------------
378 
// Bit flags describing what dnsMessage.validate() found wrong with a
// message. Each problem is a distinct power of two so that several can be
// OR-ed into one uint; `success` (0) means no flag is set.
// The lowest bit (value 1) is not assigned: the first flag is 2 ^^ 1.
// messageValidateCodeToString() turns a combined value into member names.
enum dnsMessageValidateCode : uint
{
	success								= 0,

	// HEADER 
	header_wrong_query_response			= 2 ^^  1,
	invalid_op_code						= 2 ^^  2,
	query_op_code_error					= 2 ^^  3,
	query_flag_error					= 2 ^^  4,
	header_reserved_flag_use			= 2 ^^  5,
	invalid_response_code				= 2 ^^  6,
	query_response_code_error			= 2 ^^  7,
	section_count_mismatch				= 2 ^^  8,
	query_section_count_error			= 2 ^^  9,
	query_other_sections_count_error	= 2 ^^ 10,
	
	// SECTION
	msg_class_not_inet					= 2 ^^ 11,

	invalid_class						= 2 ^^ 12,
	invalid_type						= 2 ^^ 13,

	inconsistent_class					= 2 ^^ 14,
	inconsistent_type					= 2 ^^ 15,

	response_data_length_mismatch		= 2 ^^ 16,
	response_data_length_error			= 2 ^^ 17,

	authority_section_type_error		= 2 ^^ 18,
	
	// MISC
	unknown_error						= 2 ^^ 31,	
}
412 
/**
 * Renders a combination of dnsMessageValidateCode flags as text.
 *
 * Params:
 *   code = flags OR-ed together, as returned by dnsMessage.validate()
 * Returns: "success" when code is 0; otherwise the names of all set flags in
 *          declaration order, joined with ", ". Bits that do not belong to
 *          any enum member are listed as "unknown_bits(0x........)" rather
 *          than being silently reported as "success".
 */
string messageValidateCodeToString(uint code)
{
	import std.array : join;
	import std.format : format;

	string[] names;
	uint knownBits = 0;

	// Unrolled at compile time: one test per enum member. `success` is 0 and
	// therefore never matches.
	static foreach (memberName; __traits(allMembers, dnsMessageValidateCode))
	{
		knownBits |= __traits(getMember, dnsMessageValidateCode, memberName);
		if ((code & __traits(getMember, dnsMessageValidateCode, memberName)) != 0)
			names ~= memberName;
	}

	const uint unknownBits = code & ~knownBits;
	if (unknownBits != 0)
		names ~= format("unknown_bits(0x%08X)", unknownBits);

	return (names.length > 0) ? names.join(", ") : "success";
}
431 
432 // ---------------------------------------------------------------------
433 
/**
 * A parsed DNS message: the header plus the four record sections.
 */
struct dnsMessage
{
	dnsHeader			header;
	querySection[]		query;
	responseSection[]	answer;
	responseSection[]	authority;
	responseSection[]	additional;

	/**
	 * Checks shared by every record of the answer, authority and additional
	 * sections: known type, known class, class equal to the question's
	 * class, and declared RDATA length equal to the actual RDATA length.
	 *
	 * Returns: the dnsMessageValidateCode flags for the problems found.
	 */
	private static uint validateRecordCommon(ref const(responseSection) section, ushort msgClass)
	{
		uint flags = dnsMessageValidateCode.success;

		if (!typeValidate(section.responseType))
			flags |= dnsMessageValidateCode.invalid_type;

		if (!classValidate(section.responseClass))
			flags |= dnsMessageValidateCode.invalid_class;

		if (section.responseClass != msgClass)
			flags |= dnsMessageValidateCode.inconsistent_class;

		if (section.responseDataLength != section.responseData.length)
			flags |= dnsMessageValidateCode.response_data_length_mismatch;

		return flags;
	}

	/**
	 * Validates the header and all sections.
	 *
	 * Params:
	 *   flagQueryResponse = the kind of message the caller expects
	 *   strict            = additionally require class INET for the first
	 *                       question and exactly one question in a query
	 * Returns: dnsMessageValidateCode.success (0) when nothing is wrong,
	 *          otherwise the OR of the flags for every problem found.
	 */
	uint validate(dnsHeaderFlagQueryResponse flagQueryResponse, bool strict = true)
	{
		immutable OPCODE_QUERY			= 0;
		immutable RESPONSECODE_NOERROR	= 0;

		uint result = dnsMessageValidateCode.success;

		const bool isQuery = (header.flags.queryResponse == dnsHeaderFlagQueryResponse.query);

		// HEADER
		if (header.flags.queryResponse != flagQueryResponse)
			result |= dnsMessageValidateCode.header_wrong_query_response;

		if (!opCodeValidate(header.flags.opCode))
			result |= dnsMessageValidateCode.invalid_op_code;

		if (isQuery && header.flags.opCode != OPCODE_QUERY)
			result |= dnsMessageValidateCode.query_op_code_error;

		if (isQuery && (header.flags.authoritativeAnswer || header.flags.truncation || header.flags.recursionAvailable))
			result |= dnsMessageValidateCode.query_flag_error;

		if (header.flags.reserved != 0)
			result |= dnsMessageValidateCode.header_reserved_flag_use;

		if (!responseCodeValidate(header.flags.responseCode))
			result |= dnsMessageValidateCode.invalid_response_code;

		if (isQuery && header.flags.responseCode != RESPONSECODE_NOERROR)
			result |= dnsMessageValidateCode.query_response_code_error;

		if (header.QDCOUNT != query.length || header.ANCOUNT != answer.length || header.NSCOUNT != authority.length || header.ARCOUNT != additional.length)
			result |= dnsMessageValidateCode.section_count_mismatch;

		if (header.QDCOUNT == 0)
			result |= dnsMessageValidateCode.query_section_count_error;

		if (isQuery && header.QDCOUNT != 1 && strict)
			result |= dnsMessageValidateCode.query_section_count_error;

		if (isQuery && (header.ANCOUNT != 0 || header.NSCOUNT != 0 || header.ARCOUNT != 0))
			result |= dnsMessageValidateCode.query_other_sections_count_error;

		// SECTIONS
		ushort  msgType		= ushort.max;	// DUMMY VALUE
		ushort  msgClass	= ushort.max;	// DUMMY VALUE

		// Guard on the array that is actually indexed: the header count may
		// claim questions that were never parsed (already reported above as
		// section_count_mismatch), and query[0] would then be out of range.
		if (query.length != 0)
		{
			msgType		= query[0].queryType;
			msgClass	= query[0].queryClass;
		}

		if (msgClass != dnsClass.INET && strict)
			result |= dnsMessageValidateCode.msg_class_not_inet;

		foreach (ref section; query)
		{
			if (!typeValidate(section.queryType))
				result |= dnsMessageValidateCode.invalid_type;

			if (!classValidate(section.queryClass))
				result |= dnsMessageValidateCode.invalid_class;

			if (section.queryType  != msgType)
				result |= dnsMessageValidateCode.inconsistent_type;

			if (section.queryClass != msgClass)
				result |= dnsMessageValidateCode.inconsistent_class;
		}

		foreach (ref section; answer)
		{
			result |= validateRecordCommon(section, msgClass);

			// An answer may be of the requested type or a CNAME.
			if (section.responseType != msgType && section.responseType != dnsType.CNAME)
				result |= dnsMessageValidateCode.inconsistent_type;

			// dnsType.A
			if (section.responseType == dnsType.A && section.responseDataLength != 4)
				result |= dnsMessageValidateCode.response_data_length_error;

			// dnsType.MX: two bytes of priority plus a non-empty name
			if (section.responseType == dnsType.MX && section.responseDataLength <= 2)
				result |= dnsMessageValidateCode.response_data_length_error;

			// ... ToDo for other responseTypes
		}

		foreach (ref section; authority)
		{
			result |= validateRecordCommon(section, msgClass);

			// Only NS and SOA records are accepted here. The two tests must
			// be combined with &&: with || the condition holds for every
			// possible type, so every authority record would be flagged.
			// ToDo: Check RFC specification
			if (section.responseType != dnsType.NS && section.responseType != dnsType.SOA)
				result |= dnsMessageValidateCode.authority_section_type_error;
		}

		foreach (ref section; additional)
		{
			result |= validateRecordCommon(section, msgClass);

			// NOTE(review): this flags every additional record whose type
			// differs from the question's type -- confirm that is intended.
			if (section.responseType != msgType)
				result |= dnsMessageValidateCode.inconsistent_type;
		}

		return result;
	}  // validate

	/// Writes the header and every section to stdout; an empty section is
	/// reported with a "NO ... SECTION" line.
	void print()
	{
		header.print();

		if (query.length == 0)
		{
			writeln("NO QUERY SECTION");
		}
		else
		{
			writeln("QUERY SECTIONS");
			foreach (e; query)
				e.print;
		}

		if (answer.length == 0)
		{
			writeln("NO ANSWER SECTION");
		}
		else
		{
			writeln("ANSWER SECTIONS");
			foreach (e; answer)
				e.print;
		}

		if (authority.length == 0)
		{
			writeln("NO AUTHORITY SECTION");
		}
		else
		{
			writeln("AUTHORITY SECTIONS");
			foreach (e; authority)
				e.print;
		}

		if (additional.length == 0)
		{
			writeln("NO ADDITIONAL SECTION");
		}
		else
		{
			writeln("ADDITIONAL SECTIONS");
			foreach (e; additional)
				e.print;
		}
	}  // print

	/**
	 * Collects responseString of every answer record of the given type.
	 *
	 * Params:
	 *   type = record type to select (default: A)
	 * Returns: the matching strings in answer order; empty if none match.
	 */
	string[] getShortResult(dnsType type = dnsType.A)
	{
		string[] result;
		foreach (element; answer)
		{
			if (element.responseType != type) continue;
			result ~= element.responseString;
		}

		return result;
	}  // getShortResult

	/// Writes each entry of getShortResult(type) to stdout, one per line.
	void printShort(dnsType type = dnsType.A)
	{
		foreach (result; getShortResult(type))
		{
			writeln(result);
		}
	}  // printShort

}