mirror of
				https://github.com/ossrs/srs.git
				synced 2025-03-09 15:49:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			316 lines
		
	
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			316 lines
		
	
	
	
		
			6.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package mdns
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"math/big"
 | |
| 	"net"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/pion/logging"
 | |
| 	"golang.org/x/net/dns/dnsmessage"
 | |
| 	"golang.org/x/net/ipv4"
 | |
| )
 | |
| 
 | |
| // Conn represents a mDNS Server
 | |
| type Conn struct {
 | |
| 	mu  sync.RWMutex
 | |
| 	log logging.LeveledLogger
 | |
| 
 | |
| 	socket  *ipv4.PacketConn
 | |
| 	dstAddr *net.UDPAddr
 | |
| 
 | |
| 	queryInterval time.Duration
 | |
| 	localNames    []string
 | |
| 	queries       []query
 | |
| 
 | |
| 	closed chan interface{}
 | |
| }
 | |
| 
 | |
| type query struct {
 | |
| 	nameWithSuffix  string
 | |
| 	queryResultChan chan queryResult
 | |
| }
 | |
| 
 | |
| type queryResult struct {
 | |
| 	answer dnsmessage.ResourceHeader
 | |
| 	addr   net.Addr
 | |
| }
 | |
| 
 | |
| const (
 | |
| 	inboundBufferSize    = 512
 | |
| 	defaultQueryInterval = time.Second
 | |
| 	destinationAddress   = "224.0.0.251:5353"
 | |
| 	maxMessageRecords    = 3
 | |
| 	responseTTL          = 120
 | |
| )
 | |
| 
 | |
// Server establishes a mDNS connection over an existing conn.
//
// It joins the 224.0.0.251 multicast group on every network interface and
// fails with errJoiningMulticastGroup only if no interface could be joined.
// The returned Conn answers questions for config.LocalNames and re-sends
// Query questions every config.QueryInterval (defaultQueryInterval when
// that is zero). The reader goroutine is started before Server returns.
func Server(conn *ipv4.PacketConn, config *Config) (*Conn, error) {
	if config == nil {
		return nil, errNilConfig
	}

	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}

	// Succeed as long as at least one interface joined the group.
	joined := 0
	for i := range ifaces {
		group := &net.UDPAddr{IP: net.IPv4(224, 0, 0, 251)}
		if conn.JoinGroup(&ifaces[i], group) == nil {
			joined++
		}
	}
	if joined == 0 {
		return nil, errJoiningMulticastGroup
	}

	dstAddr, err := net.ResolveUDPAddr("udp", destinationAddress)
	if err != nil {
		return nil, err
	}

	loggerFactory := config.LoggerFactory
	if loggerFactory == nil {
		loggerFactory = logging.NewDefaultLoggerFactory()
	}

	// Stored with a trailing "." so they compare equal to the question
	// names produced by the reader loop (q.Name.String()).
	localNames := make([]string, 0, len(config.LocalNames))
	for _, name := range config.LocalNames {
		localNames = append(localNames, name+".")
	}

	queryInterval := defaultQueryInterval
	if config.QueryInterval != 0 {
		queryInterval = config.QueryInterval
	}

	c := &Conn{
		log:           loggerFactory.NewLogger("mdns"),
		socket:        conn,
		dstAddr:       dstAddr,
		queryInterval: queryInterval,
		localNames:    localNames,
		queries:       []query{},
		closed:        make(chan interface{}),
	}

	go c.start()
	return c, nil
}
 | |
| 
 | |
| // Close closes the mDNS Conn
 | |
| func (c *Conn) Close() error {
 | |
| 	select {
 | |
| 	case <-c.closed:
 | |
| 		return nil
 | |
| 	default:
 | |
| 	}
 | |
| 
 | |
| 	if err := c.socket.Close(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	<-c.closed
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// Query sends mDNS Queries for the following name until
// either the Context is canceled/expires or we get a result.
//
// The question is sent immediately and then re-sent every c.queryInterval.
// It returns errConnectionClosed if c is (or becomes) closed, and
// errContextElapsed if ctx is done before a matching answer arrives.
func (c *Conn) Query(ctx context.Context, name string) (dnsmessage.ResourceHeader, net.Addr, error) {
	select {
	case <-c.closed:
		return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
	default:
	}

	nameWithSuffix := name + "."

	// Capacity 1 lets the reader goroutine deliver the result without
	// blocking, even if this call has already returned.
	queryChan := make(chan queryResult, 1)
	c.mu.Lock()
	c.queries = append(c.queries, query{nameWithSuffix, queryChan})
	ticker := time.NewTicker(c.queryInterval)
	c.mu.Unlock()

	// However we return, release the ticker and drop our pending entry.
	// Without this, queries abandoned via ctx or close would stay in
	// c.queries until some answer for their name happened to arrive.
	defer func() {
		ticker.Stop()

		c.mu.Lock()
		defer c.mu.Unlock()
		for i := range c.queries {
			if c.queries[i].queryResultChan == queryChan {
				c.queries = append(c.queries[:i], c.queries[i+1:]...)
				return
			}
		}
	}()

	c.sendQuestion(nameWithSuffix)
	for {
		select {
		case <-ticker.C:
			c.sendQuestion(nameWithSuffix)
		case <-c.closed:
			return dnsmessage.ResourceHeader{}, nil, errConnectionClosed
		case res := <-queryChan:
			return res.answer, res.addr, nil
		case <-ctx.Done():
			return dnsmessage.ResourceHeader{}, nil, errContextElapsed
		}
	}
}
 | |
| 
 | |
// ipToBytes returns the IPv4 form of ip as a 4-byte array, or the zero
// array when ip has no IPv4 form (ip.To4() returns nil).
func ipToBytes(ip net.IP) (out [4]byte) {
	rawIP := ip.To4()
	if rawIP == nil {
		return out
	}

	// big.Int.Bytes drops leading zero bytes, so the result must be
	// right-aligned; copying it to the front of out would turn an address
	// such as 0.1.2.3 into 1.2.3.0.
	b := new(big.Int).SetBytes(rawIP).Bytes()
	copy(out[len(out)-len(b):], b)
	return out
}
 | |
| 
 | |
// interfaceForRemote returns the local IP address the OS selects when
// dialing remote over UDP. The dialed socket is used only to read its local
// address and is closed before returning; dial and close errors are
// returned to the caller.
func interfaceForRemote(remote string) (net.IP, error) {
	probe, err := net.Dial("udp", remote)
	if err != nil {
		return nil, err
	}

	local := probe.LocalAddr().(*net.UDPAddr).IP
	if closeErr := probe.Close(); closeErr != nil {
		return nil, closeErr
	}
	return local, nil
}
 | |
| 
 | |
| func (c *Conn) sendQuestion(name string) {
 | |
| 	packedName, err := dnsmessage.NewName(name)
 | |
| 	if err != nil {
 | |
| 		c.log.Warnf("Failed to construct mDNS packet %v", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	msg := dnsmessage.Message{
 | |
| 		Header: dnsmessage.Header{},
 | |
| 		Questions: []dnsmessage.Question{
 | |
| 			{
 | |
| 				Type:  dnsmessage.TypeA,
 | |
| 				Class: dnsmessage.ClassINET,
 | |
| 				Name:  packedName,
 | |
| 			},
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	rawQuery, err := msg.Pack()
 | |
| 	if err != nil {
 | |
| 		c.log.Warnf("Failed to construct mDNS packet %v", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	if _, err := c.socket.WriteTo(rawQuery, nil, c.dstAddr); err != nil {
 | |
| 		c.log.Warnf("Failed to send mDNS packet %v", err)
 | |
| 		return
 | |
| 	}
 | |
| }
 | |
| 
 | |
// sendAnswer multicasts an authoritative response to c.dstAddr carrying a
// single A/IN record that maps name to dst, with TTL responseTTL.
// Construction and send failures are logged at warn level and otherwise
// ignored.
func (c *Conn) sendAnswer(name string, dst net.IP) {
	packedName, err := dnsmessage.NewName(name)
	if err != nil {
		c.log.Warnf("Failed to construct mDNS packet %v", err)
		return
	}

	record := dnsmessage.Resource{
		Header: dnsmessage.ResourceHeader{
			Name:  packedName,
			Type:  dnsmessage.TypeA,
			Class: dnsmessage.ClassINET,
			TTL:   responseTTL,
		},
		Body: &dnsmessage.AResource{A: ipToBytes(dst)},
	}
	msg := dnsmessage.Message{
		Header:  dnsmessage.Header{Response: true, Authoritative: true},
		Answers: []dnsmessage.Resource{record},
	}

	rawAnswer, err := msg.Pack()
	if err != nil {
		c.log.Warnf("Failed to construct mDNS packet %v", err)
		return
	}

	if _, err = c.socket.WriteTo(rawAnswer, nil, c.dstAddr); err != nil {
		c.log.Warnf("Failed to send mDNS packet %v", err)
	}
}
 | |
| 
 | |
// start is the reader loop launched by Server. It reads packets from
// c.socket until ReadFrom returns an error (for example after Close closes
// the socket) and hands each one to handlePacket. On exit it closes
// c.closed, which unblocks Close and any in-flight Query calls.
func (c *Conn) start() {
	defer func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		close(c.closed)
	}()

	b := make([]byte, inboundBufferSize)
	p := dnsmessage.Parser{}

	for {
		n, _, src, err := c.socket.ReadFrom(b)
		if err != nil {
			return
		}

		c.handlePacket(&p, b[:n], src)
	}
}

// handlePacket processes one inbound mDNS message from src. Questions
// for one of c.localNames are answered via sendAnswer. A/AAAA answer
// records whose name matches a pending Query are delivered to that query,
// which is then removed from c.queries. Parse failures are logged and end
// processing of this packet.
func (c *Conn) handlePacket(p *dnsmessage.Parser, msg []byte, src net.Addr) {
	// A write lock is required: matched queries are removed from c.queries
	// below.
	c.mu.Lock()
	defer c.mu.Unlock()

	if _, err := p.Start(msg); err != nil {
		c.log.Warnf("Failed to parse mDNS packet %v", err)
		return
	}

	// NOTE(review): `<=` inspects up to maxMessageRecords+1 records per
	// section; confirm whether `<` was intended.
	for i := 0; i <= maxMessageRecords; i++ {
		q, err := p.Question()
		if err == dnsmessage.ErrSectionDone {
			break
		}
		if err != nil {
			c.log.Warnf("Failed to parse mDNS packet %v", err)
			return
		}

		for _, localName := range c.localNames {
			if localName != q.Name.String() {
				continue
			}

			localAddress, err := interfaceForRemote(src.String())
			if err != nil {
				c.log.Warnf("Failed to get local interface to communicate with %s: %v", src.String(), err)
				continue
			}

			c.sendAnswer(q.Name.String(), localAddress)
		}
	}

	// Discard any questions left over from the capped loop above so the
	// parser moves on to the answer section; otherwise AnswerHeader fails
	// for messages with many questions.
	if err := p.SkipAllQuestions(); err != nil {
		c.log.Warnf("Failed to parse mDNS packet %v", err)
		return
	}

	for i := 0; i <= maxMessageRecords; i++ {
		a, err := p.AnswerHeader()
		if err == dnsmessage.ErrSectionDone {
			return
		}
		if err != nil {
			c.log.Warnf("Failed to parse mDNS packet %v", err)
			return
		}

		if a.Type == dnsmessage.TypeA || a.Type == dnsmessage.TypeAAAA {
			name := a.Name.String()
			for j := len(c.queries) - 1; j >= 0; j-- {
				if c.queries[j].nameWithSuffix != name {
					continue
				}
				// Each query channel receives at most this one send, because
				// the entry is removed immediately afterwards.
				c.queries[j].queryResultChan <- queryResult{a, src}
				c.queries = append(c.queries[:j], c.queries[j+1:]...)
			}
		}

		// AnswerHeader parses only the record header. Skip the body so the
		// next iteration reads the next record instead of re-reading this one.
		if err := p.SkipAnswer(); err != nil {
			c.log.Warnf("Failed to parse mDNS packet %v", err)
			return
		}
	}
}
 |