module dnslib.parser;

import dnslib.defs;
import dnslib.aux;

import std.stdio;
import std.conv: to;


// ---------------------------------------------------------------------

/// Outcome of `dnsParse`.
enum dnsParserResult
{
  success,            /// The input was parsed and fully consumed.
  tooLarge,           /// Parsing succeeded, but bytes remained after the last record.
  dataMissing,        /// The input ended before a field could be read completely.
  specificationError, /// The input is not valid DNS wire format (e.g. label longer than 63 octets, bad compression pointer, RDATA length mismatch).
  illegalCharacters   /// A domain-name label was rejected by `noIllegalCharacters`.
}

/// Thrown by the helpers inside `dnsParse` to abort parsing; `dnsParse`
/// catches it and returns `parserResult` to its caller.
class dnsParserException : Exception
{
  dnsParserResult parserResult;

  this(dnsParserResult parserResult, string file = __FILE__, size_t line = __LINE__)
  {
    this.parserResult = parserResult;
    import std.conv: to;
    super("dnslib.parser exception " ~ to!string(parserResult), file, line);
  }
}

// ---------------------------------------------------------------------

/**
 * Parses a DNS message in wire format (RFC 1035, section 4).
 *
 * Params:
 *   input        = raw message bytes; not modified. Every string stored in
 *                  `myDnsMessage` is a copy, so the buffer may be reused once
 *                  this function returns.
 *   myDnsMessage = receives the header and the query, answer, authority and
 *                  additional sections. Records appended before an error
 *                  occurred are left in place.
 *   printEnabled = when true, trace output is written to stdout.
 *
 * Returns:
 *   `success` if the whole input was consumed, `tooLarge` if bytes remain
 *   after the last record, otherwise the error that stopped parsing.
 */
dnsParserResult dnsParse(const ref ubyte[] input, ref dnsMessage myDnsMessage, bool printEnabled = false)
{
  import std.algorithm.iteration: map;
  import std.array: array, join;
  import std.format: format;

  dnsParserResult parserResult = dnsParserResult.success;
  try
  {
    dnsHeader myDnsHeader;

    querySection myQuerySection;
    responseSection myResponseSection;

    // Offset of the next unread byte. size_t rather than ushort so it cannot
    // wrap around on inputs longer than 65535 bytes.
    size_t inputPtr = 0;

    // ----------

    // Throws dataMissing unless `count` bytes are available at `offset`.
    // The comparison is ordered so the unsigned subtraction can never wrap.
    void requireBytes(size_t offset, size_t count)
    {
      if (offset > input.length || input.length - offset < count)
      {
        throw new dnsParserException(dnsParserResult.dataMissing);
      }
    }

    // ----------

    // Copies the fixed-size header out of the first dnsHeader.sizeof bytes.
    // NOTE(review): the bytes are copied as-is; presumably dnsHeader's
    // accessors deal with network byte order - confirm in dnslib.defs.
    void parseHeader()
    {
      requireBytes(0, dnsHeader.sizeof);

      myDnsHeader = cast(dnsHeader)cast(ubyte[dnsHeader.sizeof])(input[0 .. dnsHeader.sizeof]);
      inputPtr += dnsHeader.sizeof;
    } // parseHeader

    // ----------

    // Reads a big-endian (network order) 32-bit integer at inputPtr.
    // Built with shifts, so the result does not depend on host byte order.
    uint readUINT32()
    {
      requireBytes(inputPtr, 4);

      uint value = (cast(uint) input[inputPtr] << 24)
                 | (cast(uint) input[inputPtr + 1] << 16)
                 | (cast(uint) input[inputPtr + 2] << 8)
                 |  cast(uint) input[inputPtr + 3];
      inputPtr += 4;
      return value;
    }

    // ----------

    // Reads a big-endian (network order) 16-bit integer at inputPtr.
    uint readUINT16()
    {
      requireBytes(inputPtr, 2);

      uint value = (cast(uint) input[inputPtr] << 8) | input[inputPtr + 1];
      inputPtr += 2;
      return value;
    }

    // ----------

    // Reads one <character-string>: a length octet followed by that many bytes.
    string readCharacterString()
    {
      requireBytes(inputPtr, 1);

      ubyte length = input[inputPtr];
      inputPtr += 1;

      if (printEnabled) writefln("Read character string %d %d %d", input.length, inputPtr, length);

      requireBytes(inputPtr, length);

      // idup copies the bytes out of the caller's buffer; a plain cast would
      // yield an "immutable" string that still aliases the caller's memory.
      string element = cast(string) input[inputPtr .. inputPtr + length].idup;
      if (printEnabled) writefln("Element: %s", element);

      inputPtr += length;
      return element;
    }

    // ----------

    // Reads a possibly compressed domain name (RFC 1035 4.1.4) starting at
    // inputPtr and returns its labels joined with ".". inputPtr is advanced
    // past the name as stored at this position, i.e. past the terminating zero
    // octet, or past the first compression pointer if one is present.
    string readDomainName()
    {
      string[] elements;

      size_t cursor = inputPtr;        // offset of the next length octet
      size_t segmentStart = inputPtr;  // where the current run of labels began
      bool jumped = false;             // has a compression pointer been followed?

      while (true)
      {
        requireBytes(cursor, 1);
        immutable ubyte lengthOctet = input[cursor];

        if (lengthOctet == 0) // terminating root label
        {
          cursor += 1;
          if (!jumped) inputPtr = cursor;
          break;
        }

        if ((lengthOctet & 0b11000000) == 0b11000000) // compression pointer
        {
          if (printEnabled) writefln("Input2[%s]: %d", cursor, lengthOctet);

          requireBytes(cursor, 2);
          immutable size_t pointer = ((lengthOctet & 0b00111111) << 8) | input[cursor + 1];

          if (printEnabled) writefln("Pointer %d", pointer);

          // A pointer must refer to data before the current run of labels.
          // Every jump therefore moves strictly backwards, so a malicious
          // message cannot trap the parser in an endless pointer loop.
          if (pointer >= segmentStart) { throw new dnsParserException(dnsParserResult.specificationError); }

          if (!jumped)
          {
            inputPtr = cursor + 2;
            jumped = true;
          }
          cursor = pointer;
          segmentStart = pointer;
          continue;
        }

        // Ordinary label. Length octets 64..191 carry the reserved 01/10 prefixes.
        if (printEnabled) writefln("Input1[%s]: %d", cursor, lengthOctet);
        if (lengthOctet > 63) { throw new dnsParserException(dnsParserResult.specificationError); }

        cursor += 1;
        requireBytes(cursor, lengthOctet);

        string element = cast(string) input[cursor .. cursor + lengthOctet].idup;
        if (printEnabled) writefln("Element: %s", element);

        if (!noIllegalCharacters(element)) { throw new dnsParserException(dnsParserResult.illegalCharacters); }

        elements ~= element;
        cursor += lengthOctet;
      }

      if (printEnabled) writefln("Input length: %d / %d", inputPtr, input.length);

      return elements.join(".");
    } // readDomainName

    // ----------

    // Parses one question entry: name, type and class.
    void parseQuery()
    {
      myQuerySection.domainName = readDomainName();

      requireBytes(inputPtr, 4);
      myQuerySection.queryType = cast(ushort) readUINT16();
      myQuerySection.queryClass = cast(ushort) readUINT16();
      if (printEnabled) writefln("Query type / class: %s / %s", myQuerySection.queryType, myQuerySection.queryClass);
    } // parseQuery

    // ----------

    // Parses one resource record into myResponseSection: owner name, fixed
    // fields, a copy of the raw RDATA, and a decoded text form of the RDATA in
    // responseElements / responseString.
    void parseResponse()
    {
      string parsedDomainName2 = readDomainName();
      if (printEnabled) writefln("Response domain name: %s", parsedDomainName2);
      myResponseSection.domainName = parsedDomainName2;

      requireBytes(inputPtr, 10);

      myResponseSection.responseType = cast(ushort) readUINT16();
      myResponseSection.responseClass = cast(ushort) readUINT16();
      // Read with shifts: the former `256^3` / `256^2` are XOR in D (259 / 258), not powers.
      myResponseSection.TTL = readUINT32();
      myResponseSection.responseDataLength = cast(ushort) readUINT16();

      if (printEnabled) writefln("Response type / class / TTL / datalength: %s / %s / %s / %s", myResponseSection.responseType, myResponseSection.responseClass, myResponseSection.TTL, myResponseSection.responseDataLength);

      immutable size_t rdataStart = inputPtr;
      immutable size_t rdataLength = myResponseSection.responseDataLength;
      immutable size_t rdataEnd = rdataStart + rdataLength;

      // The whole RDATA must be present before it is sliced.
      requireBytes(rdataStart, rdataLength);
      myResponseSection.responseData = input[rdataStart .. rdataEnd].dup;

      if (myResponseSection.responseType == dnsType.MX)
      {
        immutable ushort priority = cast(ushort) readUINT16();
        myResponseSection.responseElements ~= to!string(priority);

        string parsedDomainName = readDomainName();
        myResponseSection.responseElements ~= parsedDomainName;

        myResponseSection.responseString = to!string(priority) ~ " " ~ parsedDomainName;
      }

      else if (myResponseSection.responseType == dnsType.TXT)
      {
        // TXT RDATA holds one or more character-strings; read all of them.
        while (inputPtr < rdataEnd)
        {
          myResponseSection.responseElements ~= readCharacterString();
        }
        if (myResponseSection.responseElements.length == 0) { throw new dnsParserException(dnsParserResult.specificationError); }

        myResponseSection.responseString = myResponseSection.responseElements.join(" ");
      }

      else if (myResponseSection.responseType == dnsType.CNAME
            || myResponseSection.responseType == dnsType.NS
            || myResponseSection.responseType == dnsType.DNAME
            || myResponseSection.responseType == dnsType.PTR
              )
      {
        string parsedDomainName = readDomainName();
        myResponseSection.responseElements ~= parsedDomainName;
        myResponseSection.responseString = parsedDomainName;
      }

      else if (myResponseSection.responseType == dnsType.A)
      {
        if (rdataLength != 4) { throw new dnsParserException(dnsParserResult.specificationError); }

        // Dotted decimal, one number per RDATA byte.
        string ipString = myResponseSection.responseData.map!(to!string).array.join(".");
        myResponseSection.responseElements ~= ipString;
        myResponseSection.responseString = ipString;

        inputPtr += 4;
      }

      else if (myResponseSection.responseType == dnsType.AAAA)
      {
        // AAAA RDATA is 16 octets; the former check for 8 rejected every AAAA record.
        if (rdataLength != 16) { throw new dnsParserException(dnsParserResult.specificationError); }

        // Eight 16-bit groups in lowercase hex, colon-separated (no "::" compression).
        string[] groups;
        foreach (_; 0 .. 8)
        {
          groups ~= format!"%x"(readUINT16());
        }
        string ipString = groups.join(":");
        myResponseSection.responseElements ~= ipString;
        myResponseSection.responseString = ipString;
      }

      else if (myResponseSection.responseType == dnsType.SOA)
      {
        string parsedMNAME = readDomainName();
        string parsedRNAME = readDomainName();
        uint serial = readUINT32();
        uint refresh = readUINT32();
        uint retry = readUINT32();
        uint expire = readUINT32();
        uint minimum = readUINT32();

        myResponseSection.responseElements ~= [parsedMNAME, parsedRNAME, to!string(serial), to!string(refresh), to!string(retry), to!string(expire), to!string(minimum)];

        myResponseSection.responseString = format!"%s %s %d %d %d %d %d"(parsedMNAME, parsedRNAME, serial, refresh, retry, expire, minimum);
      }

      else if (myResponseSection.responseType == dnsType.SRV)
      {
        uint priority = readUINT16();
        uint weight = readUINT16();
        uint port = readUINT16();

        string parsedTarget = readDomainName();

        myResponseSection.responseElements ~= [to!string(priority), to!string(weight), to!string(port), parsedTarget];

        myResponseSection.responseString = format!"%d %d %d %s"(priority, weight, port, parsedTarget);
      }

      else
      {
        // Types without a dedicated decoder (e.g. OPT, CAA) previously hit
        // assert(false), crashing on well-formed input. They are now kept in
        // the RFC 3597 generic form "\# <length> <hex>" and their RDATA skipped.
        string[] generic = ["\\#", to!string(rdataLength)];
        if (rdataLength > 0)
        {
          generic ~= format!"%(%02x%)"(myResponseSection.responseData);
        }
        myResponseSection.responseElements ~= generic;
        myResponseSection.responseString = generic.join(" ");

        inputPtr = rdataEnd;
      }

      // The decoded fields must cover exactly the declared RDATA length. A
      // mismatch is malformed input, so it is reported rather than asserted.
      if (inputPtr != rdataEnd) { throw new dnsParserException(dnsParserResult.specificationError); }

      // Internal invariant: every branch builds responseString as the
      // space-joined responseElements.
      assert(myResponseSection.responseString == myResponseSection.responseElements.join(" "));

    } // parseResponse

    // ----------

    if (printEnabled) writefln("Parsing header section: %d / %d", inputPtr, input.length);

    parseHeader();

    myDnsMessage.header = myDnsHeader;

    foreach (i; 0 .. myDnsMessage.header.QDCOUNT)
    {
      myQuerySection = querySection.init;
      if (printEnabled) writefln("Parsing query section %d: %d / %d", i, inputPtr, input.length);
      parseQuery();
      myDnsMessage.query ~= myQuerySection;
    }

    foreach (i; 0 .. myDnsMessage.header.ANCOUNT)
    {
      myResponseSection = responseSection.init;
      if (printEnabled) writefln("Parsing response section %d: %d / %d", i, inputPtr, input.length);
      parseResponse();
      myDnsMessage.answer ~= myResponseSection;
    }

    foreach (i; 0 .. myDnsMessage.header.NSCOUNT)
    {
      myResponseSection = responseSection.init;
      if (printEnabled) writefln("Parsing response section %d: %d / %d", i, inputPtr, input.length);
      parseResponse();
      myDnsMessage.authority ~= myResponseSection;
    }

    foreach (i; 0 .. myDnsMessage.header.ARCOUNT)
    {
      myResponseSection = responseSection.init;
      if (printEnabled) writefln("Parsing response section %d: %d / %d", i, inputPtr, input.length);
      parseResponse();
      myDnsMessage.additional ~= myResponseSection;
    }

    if (printEnabled) writefln("Parsing done: %d / %d", inputPtr, input.length);
    if (input.length != inputPtr) return dnsParserResult.tooLarge;
  } // try
  catch (dnsParserException e)
  {
    parserResult = e.parserResult;
  }

  return parserResult;
} // dnsParse