package csidh

import (
	"bytes"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"os"
	"testing"

	crand "crypto/rand"

	"github.com/henrydcase/nobs/drbg"
)

// Possible values for "Status". Each constant identifies one kind of test
// vector; TestKAT uses it to decide which checks to run and which outcome
// to expect. The textual form found in the test vector file is given by
// StatusValues.
const (
	Valid               = iota // Indicates that shared secret must be agreed correctly
	ValidPublicKey2            // Public key 2 must succeed validation
	InvalidSharedSecret        // Calculated shared secret must be different than test vector
	InvalidPublicKey1          // Public key 1 generated from private key must be different than test vector
	InvalidPublicKey2          // Public key 2 must fail validation
)

// StatusValues maps each status constant to the string used for it in the
// "status" field of a test vector. TestKAT compares TestVector.Status
// against these strings to dispatch a vector to the matching check.
var StatusValues = map[int]string{
	Valid:               "valid",
	ValidPublicKey2:     "valid_public_key2",
	InvalidSharedSecret: "invalid_shared_secret",
	InvalidPublicKey1:   "invalid_public_key1",
	InvalidPublicKey2:   "invalid_public_key2",
}

// TestVector is a single known-answer test case as stored in the JSON test
// vector file. Pk1, Pr1, Pk2 and Ss are hex strings (TestKAT decodes them
// with hex.DecodeString): Pr1 is imported as a private key, Pk1 and Pk2 as
// public keys, and Ss is the shared secret the derived value is compared
// against. ID is only used to tag error messages. Status is expected to be
// one of the strings in StatusValues.
type TestVector struct {
	ID     int    `json:"Id"`
	Pk1    string `json:"Pk1"`
	Pr1    string `json:"Pr1"`
	Pk2    string `json:"Pk2"`
	Ss     string `json:"Ss"`
	Status string `json:"status"`
}

// TestVectors is the top-level JSON object of a test vector file: a list of
// test cases stored under the "Vectors" key.
type TestVectors struct {
	Vectors []TestVector `json:"Vectors"`
}

// rng is the DRBG shared by all tests and benchmarks in this file. It is
// created and seeded once in init.
var rng *drbg.CtrDrbg

// init creates the package-level DRBG used by the tests and seeds it with
// 32 bytes read from crypto/rand. It panics if the seed can't be read or
// if the DRBG rejects it, since no test can run without a working rng.
func init() {
	var seed [32]byte

	// A failed read would leave the seed (partially) zeroed, so treat it
	// as fatal instead of silently seeding the DRBG with weak input.
	if _, err := crand.Read(seed[:]); err != nil {
		panic("Can't read seed for DRBG: " + err.Error())
	}

	// Init drbg
	rng = drbg.NewCtrDrbg()
	if !rng.Init(seed[:], nil) {
		panic("Can't initialize DRBG")
	}
}

// TestCompare64 checks fp.isZero: it must return true for a zero-valued fp
// and false for values that contain non-zero elements, including one whose
// first element is zero and one with all bits set.
func TestCompare64(t *testing.T) {
	const s uint64 = 0xFFFFFFFFFFFFFFFF
	var zero fp
	var val1 = fp{0, 2, 3, 4, 5, 6, 7, 8}
	var val2 = fp{s, s, s, s, s, s, s, s}

	if !zero.isZero() {
		t.Errorf("isZero returned false, where it should be true")
	}
	if val1.isZero() {
		t.Errorf("isZero returned true, where it should be false")
	}
	if val2.isZero() {
		t.Errorf("isZero returned true, where it should be false")
	}
}

func TestEphemeralKeyExchange(t *testing.T) {
	var ss1, ss2 [64]byte
	var prv1, prv2 PrivateKey
	var pub1, pub2 PublicKey

	prvBytes1 := []byte{0xaa, 0x54, 0xe4, 0xd4, 0xd0, 0xbd, 0xee, 0xcb, 0xf4, 0xd0, 0xc2, 0xbc, 0x52, 0x44, 0x11, 0xee, 0xe1, 0x14, 0xd2, 0x24, 0xe5, 0x0, 0xcc, 0xf5, 0xc0, 0xe1, 0x1e, 0xb3, 0x43, 0x52, 0x45, 0xbe, 0xfb, 0x54, 0xc0, 0x55, 0xb2}
	prv1.Import(prvBytes1)
	GeneratePublicKey(&pub1, &prv1, rng)

	GeneratePrivateKey(&prv2, rng)
	GeneratePublicKey(&pub2, &prv2, rng)

	if !DeriveSecret(&ss1, &pub1, &prv2, rng) {
		t.Errorf("Derivation failed\n")
	}

	if !DeriveSecret(&ss2, &pub2, &prv1, rng) {
		t.Errorf("Derivation failed\n")
	}

	if !bytes.Equal(ss1[:], ss2[:]) {
		fmt.Printf("%X\n", ss1)
		fmt.Printf("%X\n", ss2)
		t.Error("ss1 != ss2")
	}
}

// TestPrivateKeyExportImport checks that a private key survives an
// Export/Import round trip: numIter random keys are exported to a buffer,
// imported into a fresh key, and compared element by element.
func TestPrivateKeyExportImport(t *testing.T) {
	var buf [37]byte
	for i := 0; i < numIter; i++ {
		var prv1, prv2 PrivateKey
		GeneratePrivateKey(&prv1, rng)
		prv1.Export(buf[:])
		if !prv2.Import(buf[:]) {
			t.Error("Private key import failed")
			continue
		}

		for j := range prv1.e {
			if prv1.e[j] != prv2.e[j] {
				t.Error("Error occurred when private key export/import")
				// One report per key is enough.
				break
			}
		}
	}
}

// TestValidateNegative feeds Validate a set of public keys that must be
// rejected: p with its first element incremented, p itself, two and twoNeg.
func TestValidateNegative(t *testing.T) {
	aboveP := PublicKey{a: p}
	aboveP.a[0]++

	rejected := []struct {
		key  PublicKey
		fail string
	}{
		{aboveP, "Public key > p has been validated"},
		{PublicKey{a: p}, "Public key == p has been validated"},
		{PublicKey{a: two}, "Public key == 2 has been validated"},
		{PublicKey{a: twoNeg}, "Public key == -2 has been validated"},
	}

	for idx := range rejected {
		if Validate(&rejected[idx].key, rng) {
			t.Error(rejected[idx].fail)
		}
	}
}

// TestPublicKeyExportImport verifies the Export/Import round trip of public
// keys. For numIter freshly generated key pairs the public key is written
// to a buffer, read back into a second key, and both are compared.
func TestPublicKeyExportImport(t *testing.T) {
	var buf [64]byte

	// sameWords reports whether y agrees with x at every index of x.
	sameWords := func(x, y []uint64) bool {
		matches := true
		for idx, word := range x {
			if word != y[idx] {
				matches = false
				break
			}
		}
		return matches
	}

	for round := 0; round < numIter; round++ {
		var (
			secret            PrivateKey
			original, decoded PublicKey
		)
		GeneratePrivateKey(&secret, rng)
		GeneratePublicKey(&original, &secret, rng)

		original.Export(buf[:])
		decoded.Import(buf[:])

		if sameWords(original.a[:], decoded.a[:]) {
			continue
		}
		t.Error("Error occurred when public key export/import")
	}
}

// TestKAT runs known-answer tests against test vectors generated by the
// reference implementation. The vector file is chosen depending on
// hasADXandBMI2. Each vector's "status" field selects which of the checks
// below is run and what outcome is expected; vectors whose status matches
// none of the StatusValues entries are skipped.
func TestKAT(t *testing.T) {
	var tests TestVectors
	var testVectorFile string

	// Helper checks if e==true and reports an error if not. The message is
	// passed as an argument rather than as part of the format string, so a
	// '%' in msg can never be misinterpreted as a formatting verb.
	checkExpr := func(e bool, vec *TestVector, t *testing.T, msg string) {
		t.Helper()
		if !e {
			t.Errorf("[Test ID=%d] %s", vec.ID, msg)
		}
	}

	// decodeHex decodes a hex string from a test vector and aborts the
	// test if the string is malformed.
	decodeHex := func(t *testing.T, s string) []byte {
		t.Helper()
		buf, err := hex.DecodeString(s)
		if err != nil {
			t.Fatal(err)
		}
		return buf
	}

	if hasADXandBMI2 {
		testVectorFile = "testdata/csidh_testvectors.dat"
	} else {
		testVectorFile = "testdata/csidh_testvectors_small.dat"
	}

	// checkSharedSecret implements nominal case - imports asymmetric keys for
	// both parties, derives secret key and compares it to value in test vector.
	// Comparison must succeed in case status is "Valid" in any other case
	// it must fail.
	checkSharedSecret := func(vec *TestVector, t *testing.T, status int) {
		var prv1 PrivateKey
		var pub1, pub2 PublicKey
		var ss [SharedSecretSize]byte

		checkExpr(
			prv1.Import(decodeHex(t, vec.Pr1)),
			vec, t, "PrivateKey wrong")

		checkExpr(
			pub1.Import(decodeHex(t, vec.Pk1)),
			vec, t, "PublicKey 1 wrong")

		checkExpr(
			pub2.Import(decodeHex(t, vec.Pk2)),
			vec, t, "PublicKey 2 wrong")

		checkExpr(
			DeriveSecret(&ss, &pub2, &prv1, rng),
			vec, t, "Error when deriving key")

		ssExp := decodeHex(t, vec.Ss)
		checkExpr(
			bytes.Equal(ss[:], ssExp) == (status == Valid),
			vec, t, "Unexpected value of shared secret")
	}

	// checkPublicKey1 imports public and private key for one party A
	// and tries to generate public key for a private key. After that
	// it compares generated key to a key from test vector. Comparison
	// must fail.
	checkPublicKey1 := func(vec *TestVector, t *testing.T) {
		var prv PrivateKey
		var pub PublicKey
		var pubBytesGot [PublicKeySize]byte

		prBuf := decodeHex(t, vec.Pr1)
		pubBytesExp := decodeHex(t, vec.Pk1)

		checkExpr(
			prv.Import(prBuf),
			vec, t, "PrivateKey wrong")

		// Generate public key from the imported private key. The key
		// must be derived from prv: regenerating prv or exporting an
		// untouched pub would make the comparison below pass trivially.
		GeneratePublicKey(&pub, &prv, rng)
		pub.Export(pubBytesGot[:])

		// pubBytesGot must be different than pubBytesExp
		checkExpr(
			!bytes.Equal(pubBytesGot[:], pubBytesExp),
			vec, t, "Public key generated is the same as public key from the test vector")
	}

	// checkPublicKey2 the goal is to test key validation. Test imports
	// public key for B and ensures that validation succeeds in case
	// status is "Valid" or "ValidPublicKey2" and fails otherwise.
	checkPublicKey2 := func(vec *TestVector, t *testing.T, status int) {
		var pub PublicKey

		// The result of Import is not checked here; the expected
		// outcome is asserted on the explicit Validate call below.
		pub.Import(decodeHex(t, vec.Pk2))
		checkExpr(
			Validate(&pub, rng) == (status == Valid || status == ValidPublicKey2),
			vec, t, "Unexpected result of public key validation")
	}

	// Load test data
	file, err := os.Open(testVectorFile)
	if err != nil {
		t.Fatal(err.Error())
	}
	defer file.Close()

	err = json.NewDecoder(file).Decode(&tests)
	if err != nil {
		t.Fatal(err.Error())
	}

	// Loop over all test cases. Index into the slice so that the checks
	// receive a pointer to the stored vector, not to a per-iteration copy.
	for i := range tests.Vectors {
		test := &tests.Vectors[i]
		switch test.Status {
		case StatusValues[Valid]:
			checkSharedSecret(test, t, Valid)
			checkPublicKey2(test, t, Valid)
		case StatusValues[InvalidSharedSecret]:
			checkSharedSecret(test, t, InvalidSharedSecret)
		case StatusValues[InvalidPublicKey1]:
			checkPublicKey1(test, t)
		case StatusValues[InvalidPublicKey2]:
			checkPublicKey2(test, t, InvalidPublicKey2)
		case StatusValues[ValidPublicKey2]:
			checkPublicKey2(test, t, ValidPublicKey2)
		}
	}
}

// Package-level keys written to by the benchmarks below. Several test
// functions declare locals with the same names, which shadow these.
var prv1, prv2 PrivateKey
var pub1, pub2 PublicKey

// BenchmarkGeneratePrivate measures generation of a private key; every
// iteration overwrites the package-level prv1.
func BenchmarkGeneratePrivate(b *testing.B) {
	for left := b.N; left > 0; left-- {
		GeneratePrivateKey(&prv1, rng)
	}
}

// BenchmarkGenerateKeyPair measures generation of a complete key pair:
// a fresh private key followed by the public key computed from it.
func BenchmarkGenerateKeyPair(b *testing.B) {
	for left := b.N; left > 0; left-- {
		var generated PublicKey
		GeneratePrivateKey(&prv1, rng)
		GeneratePublicKey(&generated, &prv1, rng)
	}
}

// BenchmarkValidate measures validation of one and the same public key,
// derived once from a fixed private key before the timed loop starts.
func BenchmarkValidate(b *testing.B) {
	prvBytes := []byte{0xaa, 0x54, 0xe4, 0xd4, 0xd0, 0xbd, 0xee, 0xcb, 0xf4, 0xd0, 0xc2, 0xbc, 0x52, 0x44, 0x11, 0xee, 0xe1, 0x14, 0xd2, 0x24, 0xe5, 0x0, 0xcc, 0xf5, 0xc0, 0xe1, 0x1e, 0xb3, 0x43, 0x52, 0x45, 0xbe, 0xfb, 0x54, 0xc0, 0x55, 0xb2}
	if !prv1.Import(prvBytes) {
		b.Fatal("Private key import failed")
	}

	var pub PublicKey
	GeneratePublicKey(&pub, &prv1, rng)

	// Key import and public key generation are setup, not the operation
	// under measurement, so exclude them from the reported time.
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		Validate(&pub, rng)
	}
}

// BenchmarkValidateRandom measures validation of random (most probably
// wrong) public keys. Every iteration reads 64 fresh bytes from rng,
// imports them as a public key and validates the result, so the reported
// time includes the random read and the import.
func BenchmarkValidateRandom(b *testing.B) {
	var tmp [64]byte

	for n := 0; n < b.N; n++ {
		// Start from a zero-valued key on every iteration so that
		// nothing carries over from the previously imported key.
		var pub PublicKey

		if _, err := rng.Read(tmp[:]); err != nil {
			b.Fatal(err)
		}
		pub.Import(tmp[:])
		Validate(&pub, rng)
	}
}

// BenchmarkValidateGenerated validates a different public key on every
// iteration. Generation of the key pair is part of the measured loop.
func BenchmarkValidateGenerated(b *testing.B) {
	for left := b.N; left > 0; left-- {
		GeneratePrivateKey(&prv1, rng)
		GeneratePublicKey(&pub1, &prv1, rng)
		Validate(&pub1, rng)
	}
}

// Generate some keys and benchmark derive
func BenchmarkDerive(b *testing.B) {
	var ss [64]byte

	GeneratePrivateKey(&prv1, rng)
	GeneratePublicKey(&pub1, &prv1, rng)

	GeneratePrivateKey(&prv2, rng)
	GeneratePublicKey(&pub2, &prv2, rng)

	for n := 0; n < b.N; n++ {
		DeriveSecret(&ss, &pub2, &prv1, rng)
	}
}

// BenchmarkDeriveGenerated measures the combined cost of generating two
// key pairs and deriving a shared secret, all repeated on every iteration.
func BenchmarkDeriveGenerated(b *testing.B) {
	var secret [64]byte

	for left := b.N; left > 0; left-- {
		// First party.
		GeneratePrivateKey(&prv1, rng)
		GeneratePublicKey(&pub1, &prv1, rng)

		// Second party.
		GeneratePrivateKey(&prv2, rng)
		GeneratePublicKey(&pub2, &prv2, rng)

		DeriveSecret(&secret, &pub2, &prv1, rng)
	}
}
