package quic

import (
	"bytes"
	"crypto"
	"errors"
	"fmt"
	"io"
	"sort"

	"github.com/quic-go/quic-go/quicvarint"
	"golang.org/x/crypto/hkdf"
)

// ReadCryptoPayload treats packet as a QUIC Initial packet and returns the
// data carried by its CRYPTO frames, assembled into a single slice.
//
// The header is parsed with ParseInitialHeader, the client Initial protection
// key is derived from the destination connection ID and the version-specific
// salt, and protection is removed before the frames are extracted.
//
// An error is returned when the header cannot be parsed, the version is
// neither V1 nor V2, the header reports a zero offset or length, the packet is
// shorter than the header claims, protection cannot be removed, a frame cannot
// be decoded, or the CRYPTO frames cannot be assembled.
func ReadCryptoPayload(packet []byte) ([]byte, error) {
	hdr, offset, err := ParseInitialHeader(packet)
	if err != nil {
		return nil, err
	}
	// Sanity checks before doing any key derivation.
	switch {
	case hdr.Version != V1 && hdr.Version != V2:
		return nil, fmt.Errorf("unsupported version: %x", hdr.Version)
	case offset == 0 || hdr.Length == 0:
		return nil, errors.New("invalid packet")
	}

	// Derive the client Initial secret, then the packet protection key.
	sha := crypto.SHA256
	initialSecret := hkdf.Extract(sha.New, hdr.DestConnectionID, getSalt(hdr.Version))
	clientSecret := hkdfExpandLabel(sha.New, initialSecret, "client in", []byte{}, sha.Size())
	key, err := NewInitialProtectionKey(clientSecret, hdr.Version)
	if err != nil {
		return nil, fmt.Errorf("NewInitialProtectionKey: %w", err)
	}
	protector := NewPacketProtector(key)

	// https://datatracker.ietf.org/doc/html/draft-ietf-quic-tls-32#name-client-initial
	//
	// "The unprotected header includes the connection ID and a 4-byte packet number encoding for a packet number of 2"
	//
	// The protected region ends at offset+Length; anything after it in the
	// datagram is not passed to UnProtect.
	end := offset + hdr.Length
	if int64(len(packet)) < end {
		return nil, fmt.Errorf("packet is too short: %d < %d", len(packet), end)
	}
	plain, err := protector.UnProtect(packet[:end], offset, 2)
	if err != nil {
		return nil, err
	}

	frames, err := extractCryptoFrames(bytes.NewReader(plain))
	if err != nil {
		return nil, err
	}
	if payload := assembleCryptoFrames(frames); payload != nil {
		return payload, nil
	}
	return nil, errors.New("unable to assemble crypto frames")
}

// Frame type identifiers that extractCryptoFrames recognizes while scanning
// an unprotected payload.
const (
	// paddingFrameType frames are skipped.
	paddingFrameType = 0x00
	// pingFrameType frames are skipped.
	pingFrameType    = 0x01
	// cryptoFrameType frames are the only ones whose contents are collected.
	cryptoFrameType  = 0x06
)

// cryptoFrame holds the fields decoded from a single CRYPTO frame.
type cryptoFrame struct {
	// Offset is the offset value read from the frame; assembleCryptoFrames
	// uses it to order frames and to check that they are contiguous.
	Offset int64
	// Data is the frame's payload bytes.
	Data   []byte
}

// extractCryptoFrames reads frames from r until it is exhausted and returns
// every CRYPTO frame in the order it was encountered.
//
// PADDING and PING frames are skipped. Any other frame type aborts parsing
// with an error, as does a truncated or malformed frame. A CRYPTO frame whose
// declared length exceeds the bytes remaining in r is rejected with an error
// wrapping io.ErrUnexpectedEOF.
func extractCryptoFrames(r *bytes.Reader) ([]cryptoFrame, error) {
	var frames []cryptoFrame
	for r.Len() > 0 {
		typ, err := quicvarint.Read(r)
		if err != nil {
			return nil, err
		}
		switch typ {
		case paddingFrameType, pingFrameType:
			// Neither frame type carries a body; move on to the next frame.
			continue
		case cryptoFrameType:
			// Handled below.
		default:
			return nil, fmt.Errorf("encountered unexpected frame type: %d", typ)
		}

		offset, err := quicvarint.Read(r)
		if err != nil {
			return nil, err
		}
		dataLen, err := quicvarint.Read(r)
		if err != nil {
			return nil, err
		}
		// dataLen comes straight from the packet and can be as large as
		// 2^62-1. Validate it against what is actually left before allocating,
		// otherwise a malformed packet can trigger a makeslice panic or a huge
		// allocation.
		if remaining := r.Len(); dataLen > uint64(remaining) {
			return nil, fmt.Errorf("crypto frame length %d exceeds remaining payload %d: %w",
				dataLen, remaining, io.ErrUnexpectedEOF)
		}
		data := make([]byte, dataLen)
		if _, err := io.ReadFull(r, data); err != nil {
			return nil, err
		}
		frames = append(frames, cryptoFrame{Offset: int64(offset), Data: data})
	}
	return frames, nil
}

// assembleCryptoFrames joins the data of several crypto frames into a single
// slice, ordered by frame offset.
//
// It returns nil if frames is empty or if the frames cannot be assembled
// because they are not contiguous (there is a gap or an overlap between two
// neighbouring frames). A single frame is returned as-is, whatever its offset.
// For multiple frames the result likewise starts at the lowest frame offset:
// the payloads are concatenated, not zero-padded up to that offset.
//
// The frames slice is sorted in place.
func assembleCryptoFrames(frames []cryptoFrame) []byte {
	switch len(frames) {
	case 0:
		return nil
	case 1:
		return frames[0].Data
	}
	// Order by offset. Ties are broken by length so that an empty frame
	// sharing an offset with a non-empty one always sorts first, which keeps
	// the contiguity check below independent of the input order.
	sort.Slice(frames, func(i, j int) bool {
		if frames[i].Offset != frames[j].Offset {
			return frames[i].Offset < frames[j].Offset
		}
		return len(frames[i].Data) < len(frames[j].Data)
	})
	// Every frame must begin exactly where the previous one ended.
	total := len(frames[0].Data)
	for i := 1; i < len(frames); i++ {
		prev := frames[i-1]
		if frames[i].Offset != prev.Offset+int64(len(prev.Data)) {
			return nil
		}
		total += len(frames[i].Data)
	}
	// Size the result by the payload actually present, never by the absolute
	// offsets: those are read from the packet and may be arbitrarily large, so
	// using them as an allocation size can panic or exhaust memory.
	data := make([]byte, 0, total)
	for _, frame := range frames {
		data = append(data, frame.Data...)
	}
	return data
}