1 module dnslib.parser;
2 
3 import dnslib.defs;
4 import dnslib.aux;
5 
6 import std.stdio;
7 import std.conv: to;
8 
9 
10 // ---------------------------------------------------------------------
11 
/// Outcome reported by dnsParse.
enum dnsParserResult
{
	success            = 0,  /// The message was parsed and every input byte was consumed.
	tooLarge           = 1,  /// Parsing finished, but input bytes were left over.
	dataMissing        = 2,  /// The input ended before a field was complete.
	specificationError = 3,  /// A field violates the expected wire format.
	illegalCharacters  = 4,  /// A domain-name label was rejected by noIllegalCharacters.
}
20 
/// Raised inside dnsParse to abort parsing; dnsParse catches it and
/// returns the carried `parserResult` to its caller.
class dnsParserException : Exception
{
	/// Failure category reported back by dnsParse.
	dnsParserResult parserResult;

	/// Params:
	///   parserResult = failure category; its name is appended to the message.
	///   file, line   = throw site, defaulted to the caller's location.
	this(dnsParserResult parserResult, string file = __FILE__, size_t line = __LINE__)
	{
		auto message = "dnslib.parser exception " ~ parserResult.to!string;
		super(message, file, line);
		this.parserResult = parserResult;
	}
}
32 
33 // ---------------------------------------------------------------------
34 
/**
 * Parses a DNS message in wire format.
 *
 * Params:
 *   input        = raw message bytes.
 *   myDnsMessage = receives the header; parsed query, answer, authority and
 *                  additional records are appended to its section arrays.
 *   printEnabled = when true, parsing trace output is written to stdout.
 *
 * Returns:
 *   dnsParserResult.success when the whole input was consumed,
 *   dnsParserResult.tooLarge when bytes remain after the last record, or the
 *   failure category carried by the dnsParserException raised while parsing
 *   (dataMissing, specificationError, illegalCharacters).
 */
dnsParserResult dnsParse(const ref ubyte[] input, ref dnsMessage myDnsMessage, bool printEnabled = false)
{
	dnsParserResult parserResult = dnsParserResult.success;
	try
	{
		dnsHeader       myDnsHeader;
		querySection    myQuerySection;
		responseSection myResponseSection;

		// Current read offset into `input`. size_t (not ushort) so it can never
		// silently wrap around on inputs longer than 65535 bytes.
		size_t inputPtr = 0;

		// ----------

		// Throws dataMissing unless `count` bytes are available starting at
		// `position`. The first test keeps the subtraction from underflowing when
		// `position` lies beyond the end (e.g. a compression pointer aimed past
		// the end of the message).
		void ensureAvailable(size_t position, size_t count)
		{
			if (position > input.length || input.length - position < count)
			{
				throw new dnsParserException(dnsParserResult.dataMissing);
			}
		}

		// ----------

		// Copies the fixed-size header from the start of the message.
		void parseHeader()
		{
			ensureAvailable(0, dnsHeader.sizeof);

			myDnsHeader = cast(dnsHeader)cast(ubyte[dnsHeader.sizeof])(input[0 .. dnsHeader.sizeof]);
			inputPtr += dnsHeader.sizeof;
		}  // parseHeader

		// ----------

		// Reads a 32-bit unsigned integer in network (big-endian) byte order.
		// bigEndianToNative is correct on both little- and big-endian hosts,
		// unlike an unconditional byte swap.
		uint readUINT32(const ref ubyte[] input)
		{
			import std.bitmanip: bigEndianToNative;

			ensureAvailable(inputPtr, 4);
			ubyte[4] raw = input[inputPtr .. inputPtr + 4];
			inputPtr += 4;
			return bigEndianToNative!uint(raw);
		}

		// ----------

		// Reads a 16-bit unsigned integer in network (big-endian) byte order.
		uint readUINT16(const ref ubyte[] input)
		{
			ensureAvailable(inputPtr, 2);
			immutable uint value = (input[inputPtr] << 8) | input[inputPtr + 1];
			inputPtr += 2;
			return value;
		}

		// ----------

		// Reads a <character-string>: one length byte followed by that many bytes.
		// The bytes are copied (.idup), so the result does not alias `input`.
		string readCharacterString(const ref ubyte[] input)
		{
			ensureAvailable(inputPtr, 1);
			immutable size_t stringLength = input[inputPtr];
			inputPtr += 1;

			if (printEnabled) writefln("Read character string %d %d %d", input.length, inputPtr, stringLength);

			ensureAvailable(inputPtr, stringLength);
			string element = cast(string) input[inputPtr .. inputPtr + stringLength].idup;
			if (printEnabled) writefln("Element: %s", element);

			inputPtr += stringLength;
			return element;
		}

		// ----------

		// Reads a (possibly compressed) domain name starting at inputPtr and
		// returns its labels joined with "."; the root name yields "".
		// inputPtr is advanced past the name as stored at this position: past the
		// terminating zero byte, or past the first compression pointer.
		string readDomainName(const ref ubyte[] input)
		{
			import std.array: join;

			// Bound on compression pointers followed for a single name. Without it
			// a pointer that (directly or indirectly) refers back to itself would
			// never terminate.
			enum maxPointerJumps = 128;

			string[] elements;
			size_t cursor    = inputPtr;  // offset of the next length/pointer byte
			size_t jumpCount = 0;

			while (true)
			{
				ensureAvailable(cursor, 1);
				immutable ubyte labelLength = input[cursor];

				if ((labelLength & 0b11000000) == 0b11000000)  // Compression pointer
				{
					ensureAvailable(cursor, 2);
					immutable size_t pointer = ((labelLength & 0b00111111) << 8) | input[cursor + 1];
					if (printEnabled) writefln("Pointer %d", pointer);

					// Only the first pointer marks where this name ends in place.
					if (jumpCount == 0) inputPtr = cursor + 2;
					if (++jumpCount > maxPointerJumps) { throw new dnsParserException(dnsParserResult.specificationError); }

					cursor = pointer;
					continue;
				}

				cursor += 1;
				if (labelLength == 0) break;  // Root label terminates the name

				// 64..191 carry the 01/10 label-type prefixes, which are not plain labels.
				if (labelLength > 63) { throw new dnsParserException(dnsParserResult.specificationError); }

				if (printEnabled) writefln("%d %d %d", input.length, cursor, labelLength);

				ensureAvailable(cursor, labelLength);
				// .idup copies the label so the result does not alias `input`.
				string element = cast(string) input[cursor .. cursor + labelLength].idup;
				if (printEnabled) writefln("Element: %s", element);

				if (!noIllegalCharacters(element)) { throw new dnsParserException(dnsParserResult.illegalCharacters); }

				elements ~= element;
				cursor += labelLength;
			}

			if (jumpCount == 0) inputPtr = cursor;
			if (printEnabled) writefln("Input length: %d / %d", inputPtr, input.length);

			return elements.join(".");
		}  // readDomainName

		// ----------

		// Formats a 16-byte IPv6 address per RFC 5952: lower-case hex groups
		// without leading zeros; the first longest run of two or more all-zero
		// groups is replaced by "::".
		string formatIPv6(const(ubyte)[] bytes)
		{
			import std.format: format;

			ushort[8] groups;
			foreach (g; 0 .. 8)
			{
				groups[g] = cast(ushort)((bytes[2 * g] << 8) | bytes[2 * g + 1]);
			}

			size_t bestStart  = 0;
			size_t bestLength = 0;
			for (size_t g = 0; g < 8; )
			{
				if (groups[g] != 0) { ++g; continue; }

				immutable size_t runStart = g;
				while (g < 8 && groups[g] == 0) ++g;
				if (g - runStart > bestLength)
				{
					bestStart  = runStart;
					bestLength = g - runStart;
				}
			}

			if (bestLength < 2) return format!"%(%x%|:%)"(groups[]);

			return format!"%(%x%|:%)"(groups[0 .. bestStart]) ~ "::"
			     ~ format!"%(%x%|:%)"(groups[bestStart + bestLength .. $]);
		}

		// ----------

		// Parses one question entry: QNAME, QTYPE, QCLASS.
		void parseQuery()
		{
			myQuerySection.domainName = readDomainName(input);

			ensureAvailable(inputPtr, 4);
			myQuerySection.queryType  = cast(ushort) readUINT16(input);
			myQuerySection.queryClass = cast(ushort) readUINT16(input);
			if (printEnabled) writefln("Query type / class: %s / %s", myQuerySection.queryType, myQuerySection.queryClass);
		}  // parseQuery

		// ----------

		// Parses one resource record into myResponseSection. The RDATA is kept
		// raw in responseData and decoded into responseElements; responseString
		// is those elements joined with " ". Unhandled types are rendered in the
		// RFC 3597 generic form: \# <length> [<hex data>].
		void parseResponse()
		{
			import std.algorithm.iteration: map;
			import std.array: join;
			import std.format: format;

			string ownerName = readDomainName(input);
			if (printEnabled) writefln("Response domain name:  %s", ownerName);
			myResponseSection.domainName = ownerName;

			// TYPE (2) + CLASS (2) + TTL (4) + RDLENGTH (2)
			ensureAvailable(inputPtr, 10);
			myResponseSection.responseType       = cast(ushort) readUINT16(input);
			myResponseSection.responseClass      = cast(ushort) readUINT16(input);
			myResponseSection.TTL                = readUINT32(input);
			myResponseSection.responseDataLength = cast(ushort) readUINT16(input);

			if (printEnabled) writefln("Response type / class / TTL / datalength: %s / %s / %s / %s", myResponseSection.responseType, myResponseSection.responseClass, myResponseSection.TTL, myResponseSection.responseDataLength);

			// The RDATA must lie entirely inside the message.
			immutable size_t rdataStart  = inputPtr;
			immutable size_t rdataLength = myResponseSection.responseDataLength;
			ensureAvailable(rdataStart, rdataLength);
			immutable size_t rdataEnd = rdataStart + rdataLength;
			myResponseSection.responseData = input[rdataStart .. rdataEnd].dup;

			immutable rrType = myResponseSection.responseType;

			if (rrType == dnsType.MX)
			{
				immutable uint priority = readUINT16(input);
				string exchange = readDomainName(input);
				myResponseSection.responseElements ~= [to!string(priority), exchange];
			}
			else if (rrType == dnsType.TXT)
			{
				// TXT RDATA is one or more <character-string>s.
				do
				{
					myResponseSection.responseElements ~= readCharacterString(input);
				}
				while (inputPtr < rdataEnd);
			}
			else if (rrType == dnsType.CNAME
			      || rrType == dnsType.NS
			      || rrType == dnsType.DNAME
			      || rrType == dnsType.PTR)
			{
				myResponseSection.responseElements ~= readDomainName(input);
			}
			else if (rrType == dnsType.A)
			{
				if (rdataLength != 4) { throw new dnsParserException(dnsParserResult.specificationError); }
				myResponseSection.responseElements ~= myResponseSection.responseData.map!(to!string).join(".");
				inputPtr = rdataEnd;
			}
			else if (rrType == dnsType.AAAA)
			{
				// An IPv6 address is 128 bits = 16 bytes (RFC 3596).
				if (rdataLength != 16) { throw new dnsParserException(dnsParserResult.specificationError); }
				myResponseSection.responseElements ~= formatIPv6(myResponseSection.responseData);
				inputPtr = rdataEnd;
			}
			else if (rrType == dnsType.SOA)
			{
				string mname   = readDomainName(input);
				string rname   = readDomainName(input);
				uint   serial  = readUINT32(input);
				uint   refresh = readUINT32(input);
				uint   retry   = readUINT32(input);
				uint   expire  = readUINT32(input);
				uint   minimum = readUINT32(input);

				myResponseSection.responseElements ~= [mname, rname, to!string(serial), to!string(refresh), to!string(retry), to!string(expire), to!string(minimum)];
			}
			else if (rrType == dnsType.SRV)
			{
				immutable uint priority = readUINT16(input);
				immutable uint weight   = readUINT16(input);
				immutable uint port     = readUINT16(input);
				string target = readDomainName(input);

				myResponseSection.responseElements ~= [to!string(priority), to!string(weight), to!string(port), target];
			}
			else
			{
				// Unhandled type (e.g. OPT, CAA): keep the raw bytes and describe
				// them in the RFC 3597 generic form instead of aborting.
				myResponseSection.responseElements ~= [`\#`, to!string(rdataLength)];
				if (rdataLength > 0)
				{
					myResponseSection.responseElements ~= format!"%(%02x%)"(myResponseSection.responseData);
				}
				inputPtr = rdataEnd;
			}

			myResponseSection.responseString = myResponseSection.responseElements.join(" ");

			// Each decoder must consume exactly RDLENGTH bytes; otherwise RDLENGTH
			// and the RDATA contents disagree. This is untrusted input, so report
			// it instead of asserting.
			if (inputPtr != rdataEnd) { throw new dnsParserException(dnsParserResult.specificationError); }
		}  // parseResponse

		// ----------

		if (printEnabled) writefln("Parsing header section: %d / %d", inputPtr, input.length);

		parseHeader();

		myDnsMessage.header = myDnsHeader;

		foreach (i; 0 .. myDnsMessage.header.QDCOUNT)
		{
			myQuerySection = querySection.init;
			if (printEnabled) writefln("Parsing query section %d: %d / %d", i, inputPtr, input.length);
			parseQuery();
			myDnsMessage.query ~= myQuerySection;
		}

		foreach (i; 0 .. myDnsMessage.header.ANCOUNT)
		{
			myResponseSection = responseSection.init;
			if (printEnabled) writefln("Parsing response section %d: %d / %d", i, inputPtr, input.length);
			parseResponse();
			myDnsMessage.answer ~= myResponseSection;
		}

		foreach (i; 0 .. myDnsMessage.header.NSCOUNT)
		{
			myResponseSection = responseSection.init;
			if (printEnabled) writefln("Parsing response section %d: %d / %d", i, inputPtr, input.length);
			parseResponse();
			myDnsMessage.authority ~= myResponseSection;
		}

		foreach (i; 0 .. myDnsMessage.header.ARCOUNT)
		{
			myResponseSection = responseSection.init;
			if (printEnabled) writefln("Parsing response section %d: %d / %d", i, inputPtr, input.length);
			parseResponse();
			myDnsMessage.additional ~= myResponseSection;
		}

		if (printEnabled) writefln("Parsing done: %d / %d", inputPtr, input.length);
		if (input.length != inputPtr) return dnsParserResult.tooLarge;
	}  // try
	catch (dnsParserException e)
	{
		parserResult = e.parserResult;
	}

	return parserResult;
}  // dnsParse
392