package udp

import (
	"errors"
	"net"

	"github.com/apernet/OpenGFW/modifier"
	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
)

var _ modifier.Modifier = (*DNSModifier)(nil)

var (
	errInvalidIP           = errors.New("invalid ip")
	errNotValidDNSResponse = errors.New("not a valid dns response")
	errEmptyDNSQuestion    = errors.New("empty dns question")
)

// DNSModifier is the modifier.Modifier named "dns". The instances it creates
// replace the answer section of DNS responses with a configured address.
type DNSModifier struct{}

// Name returns the identifier of this modifier, "dns".
func (m *DNSModifier) Name() string {
	return "dns"
}

// New builds a DNS modifier instance from args.
//
// Recognized keys:
//   - "a":    IPv4 address (string) used to answer A questions.
//   - "aaaa": IP address (string) used to answer AAAA questions.
//
// A key that is missing, or whose value is not a string, is skipped and the
// corresponding record type is left untouched by the instance. A string that
// does not parse yields a *modifier.ErrInvalidArgs wrapping errInvalidIP.
func (m *DNSModifier) New(args map[string]interface{}) (modifier.Instance, error) {
	i := &dnsModifierInstance{}
	aStr, ok := args["a"].(string)
	if ok {
		// To4 returns nil for anything that is not an IPv4 address,
		// including a failed parse.
		a := net.ParseIP(aStr).To4()
		if a == nil {
			return nil, &modifier.ErrInvalidArgs{Err: errInvalidIP}
		}
		i.A = a
	}
	aaaaStr, ok := args["aaaa"].(string)
	if ok {
		// NOTE(review): To16 also succeeds for IPv4 text (16-byte form), so
		// an IPv4 string is accepted here — confirm this is intended.
		aaaa := net.ParseIP(aaaaStr).To16()
		if aaaa == nil {
			return nil, &modifier.ErrInvalidArgs{Err: errInvalidIP}
		}
		i.AAAA = aaaa
	}
	return i, nil
}

var _ modifier.UDPModifierInstance = (*dnsModifierInstance)(nil)

// dnsModifierInstance holds the replacement addresses for one configured
// "dns" modifier. A nil field means that record type is not overridden.
type dnsModifierInstance struct {
	A    net.IP
	AAAA net.IP
}

// Process decodes data as a DNS message, substitutes the answer section when
// an override is configured for the first question's type, and returns the
// re-serialized message.
//
// It returns a *modifier.ErrInvalidPacket when data cannot be decoded, when
// the message is not a response with a NoErr response code, or when it has
// no question. If serialization fails, the serializer's error is returned
// and the byte slice is nil.
func (i *dnsModifierInstance) Process(data []byte) ([]byte, error) {
	dns := &layers.DNS{}
	err := dns.DecodeFromBytes(data, gopacket.NilDecodeFeedback)
	if err != nil {
		return nil, &modifier.ErrInvalidPacket{Err: err}
	}
	if !dns.QR || dns.ResponseCode != layers.DNSResponseCodeNoErr {
		return nil, &modifier.ErrInvalidPacket{Err: errNotValidDNSResponse}
	}
	if len(dns.Questions) == 0 {
		return nil, &modifier.ErrInvalidPacket{Err: errEmptyDNSQuestion}
	}
	// In practice, most if not all DNS clients only send one question
	// per packet, so we don't care about the rest for now.
	q := dns.Questions[0]
	switch q.Type {
	case layers.DNSTypeA:
		if i.A != nil {
			dns.Answers = []layers.DNSResourceRecord{{
				Name:  q.Name,
				Type:  layers.DNSTypeA,
				Class: layers.DNSClassIN,
				IP:    i.A,
			}}
		}
	case layers.DNSTypeAAAA:
		if i.AAAA != nil {
			dns.Answers = []layers.DNSResourceRecord{{
				Name:  q.Name,
				Type:  layers.DNSTypeAAAA,
				Class: layers.DNSClassIN,
				IP:    i.AAAA,
			}}
		}
	}
	// Modifiers must be safe for concurrent use, so we can't reuse the buffer
	buf := gopacket.NewSerializeBuffer()
	err = gopacket.SerializeLayers(buf, gopacket.SerializeOptions{
		FixLengths:       true,
		ComputeChecksums: true,
	}, dns)
	if err != nil {
		// Never hand back buffer contents from a failed serialization;
		// every other failure path in this function returns nil as well.
		return nil, err
	}
	return buf.Bytes(), nil
}