package v5

import (
	"fmt"

	"github.com/facebookincubator/nvdtools/wfn"

	cpeUtil "github.com/anchore/grype/grype/cpe"
	"github.com/anchore/grype/grype/db/v5/namespace"
	"github.com/anchore/grype/grype/distro"
	"github.com/anchore/grype/grype/pkg"
	"github.com/anchore/grype/grype/vulnerability"
	"github.com/anchore/grype/internal/log"
	"github.com/anchore/syft/syft/cpe"
	syftPkg "github.com/anchore/syft/syft/pkg"
)

// vulnerabilityProvider serves vulnerability lookups out of a
// VulnerabilityStoreReader, using a namespace index to decide which store
// namespaces apply to a given distro, language, or CPE query.
type vulnerabilityProvider struct {
	// namespaceIndex selects the namespaces to search for a given query.
	namespaceIndex *namespace.Index

	// reader is the backing store that vulnerability records are fetched from.
	reader VulnerabilityStoreReader
}

// NewVulnerabilityProvider builds a VulnerabilityProvider on top of the given
// store reader. The namespaces known to the store are fetched and indexed up
// front; an error is returned if they cannot be read or parsed.
func NewVulnerabilityProvider(reader VulnerabilityStoreReader) (VulnerabilityProvider, error) {
	rawNamespaces, err := reader.GetVulnerabilityNamespaces()
	if err != nil {
		return nil, fmt.Errorf("unable to get namespaces from store: %w", err)
	}

	index, err := namespace.FromStrings(rawNamespaces)
	if err != nil {
		return nil, fmt.Errorf("unable to parse namespaces from store: %w", err)
	}

	provider := vulnerabilityProvider{
		namespaceIndex: index,
		reader:         reader,
	}
	return &provider, nil
}

// Get returns the vulnerability records stored under the given ID within the
// given namespace.
//
// A store failure is returned as a wrapped error. Records that cannot be
// inflated into a vulnerability.Vulnerability are logged and skipped, so the
// result may hold fewer entries than the store returned (and is nil when no
// record could be inflated).
func (pr *vulnerabilityProvider) Get(id, namespace string) ([]vulnerability.Vulnerability, error) {
	// note: getting a vulnerability record by id doesn't necessarily return a single record
	// since records are duplicated by the set of fixes they have.
	vulns, err := pr.reader.GetVulnerability(namespace, id)
	if err != nil {
		// the second value is a vulnerability ID, not a package name, so label it as such
		return nil, fmt.Errorf("provider failed to fetch namespace=%q id=%q: %w", namespace, id, err)
	}

	var results []vulnerability.Vulnerability
	for _, vuln := range vulns {
		vulnObj, err := NewVulnerability(vuln)
		if err != nil {
			// best effort: one bad record should not hide the remaining ones
			log.WithFields("namespace", vuln.Namespace, "id", vuln.ID).Errorf("failed to inflate vulnerability record: %v", err)
			continue
		}

		results = append(results, *vulnObj)
	}
	return results, nil
}

// GetByDistro returns the vulnerabilities recorded against package p in every
// namespace the index associates with distro d.
//
// It returns (nil, nil) when d is nil or when no namespace applies to d.
// Otherwise the result is a non-nil, possibly empty, slice. A store search
// failure aborts the lookup with a wrapped error; records that cannot be
// inflated are logged and skipped.
func (pr *vulnerabilityProvider) GetByDistro(d *distro.Distro, p pkg.Package) ([]vulnerability.Vulnerability, error) {
	if d == nil {
		return nil, nil
	}

	distroNamespaces := pr.namespaceIndex.NamespacesForDistro(d)
	if len(distroNamespaces) == 0 {
		log.Debugf("no vulnerability namespaces found in grype database for distro=%s package=%s", d.String(), p.Name)
		return nil, nil
	}

	found := []vulnerability.Vulnerability{}

	for _, ns := range distroNamespaces {
		// a single package may be known under several names within a namespace
		for _, name := range ns.Resolver().Resolve(p) {
			nsName := ns.String()

			records, err := pr.reader.SearchForVulnerabilities(nsName, name)
			if err != nil {
				return nil, fmt.Errorf("provider failed to search for vulnerabilities (namespace=%q pkg=%q): %w", nsName, name, err)
			}

			for _, record := range records {
				inflated, err := NewVulnerability(record)
				if err != nil {
					log.WithFields("namespace", record.Namespace, "id", record.ID).Errorf("failed to inflate vulnerability record (by distro): %v", err)
					continue
				}

				found = append(found, *inflated)
			}
		}
	}

	return found, nil
}

// GetByLanguage returns the vulnerabilities recorded against package p in
// every namespace the index associates with language l.
//
// It returns (nil, nil) when no namespace applies to l. Otherwise the result
// is a non-nil, possibly empty, slice. A store search failure aborts the
// lookup with a wrapped error; records that cannot be inflated are logged and
// skipped.
func (pr *vulnerabilityProvider) GetByLanguage(l syftPkg.Language, p pkg.Package) ([]vulnerability.Vulnerability, error) {
	languageNamespaces := pr.namespaceIndex.NamespacesForLanguage(l)
	if len(languageNamespaces) == 0 {
		log.Debugf("no vulnerability namespaces found in grype database for language=%s package=%s", l, p.Name)
		return nil, nil
	}

	found := []vulnerability.Vulnerability{}

	for _, ns := range languageNamespaces {
		// a single package may be known under several names within a namespace
		for _, name := range ns.Resolver().Resolve(p) {
			nsName := ns.String()

			records, err := pr.reader.SearchForVulnerabilities(nsName, name)
			if err != nil {
				return nil, fmt.Errorf("provider failed to fetch namespace=%q pkg=%q: %w", nsName, name, err)
			}

			for _, record := range records {
				inflated, err := NewVulnerability(record)
				if err != nil {
					log.WithFields("namespace", record.Namespace, "id", record.ID).Errorf("failed to inflate vulnerability record (by language): %v", err)
					continue
				}

				found = append(found, *inflated)
			}
		}
	}

	return found, nil
}

// GetByCPE returns the vulnerabilities whose CPEs match requestCPE, ignoring
// the version (version filtering is handled downstream). Each returned
// vulnerability has its CPEs field narrowed to the CPEs that matched.
//
// It returns (nil, nil) when the index has no CPE namespaces, and an error
// when requestCPE has no concrete product name. A store search failure or an
// unparsable CPE on a stored record aborts the lookup with a wrapped error;
// records that cannot be inflated are logged and skipped.
func (pr *vulnerabilityProvider) GetByCPE(requestCPE cpe.CPE) ([]vulnerability.Vulnerability, error) {
	namespaces := pr.namespaceIndex.CPENamespaces()
	if len(namespaces) == 0 {
		log.Debugf("no vulnerability namespaces found for arbitrary CPEs in grype database")
		return nil, nil
	}

	product := requestCPE.Attributes.Product
	if product == wfn.Any || product == wfn.NA {
		return nil, fmt.Errorf("product name is required")
	}

	requestCPEStr := requestCPE.Attributes.BindToFmtString()
	vulns := make([]vulnerability.Vulnerability, 0)

	for _, ns := range namespaces {
		nsStr := ns.String()
		resolver := ns.Resolver()

		allPkgVulns, err := pr.reader.SearchForVulnerabilities(nsStr, resolver.Normalize(product))
		if err != nil {
			return nil, fmt.Errorf("provider failed to fetch namespace=%q product=%q: %w", nsStr, product, err)
		}

		// best effort: if the normalized CPE string does not parse, match
		// against the request CPE exactly as it was given
		normalizedRequestCPE, err := cpe.New(resolver.Normalize(requestCPEStr), requestCPE.Source)
		if err != nil {
			normalizedRequestCPE = requestCPE
		}

		for _, vuln := range allPkgVulns {
			vulnCPEs, err := cpeUtil.NewSlice(vuln.CPEs...)
			if err != nil {
				// add context so the offending record can be identified
				return nil, fmt.Errorf("provider failed to parse CPEs for namespace=%q id=%q: %w", vuln.Namespace, vuln.ID, err)
			}

			// compare the request CPE to the potential matches (excluding version, which is handled downstream)
			candidateMatchCpes := cpeUtil.MatchWithoutVersion(normalizedRequestCPE, vulnCPEs)
			if len(candidateMatchCpes) == 0 {
				continue
			}

			vulnObj, err := NewVulnerability(vuln)
			if err != nil {
				log.WithFields("namespace", vuln.Namespace, "id", vuln.ID, "cpe", requestCPEStr).Errorf("failed to inflate vulnerability record (by CPE): %v", err)
				continue
			}

			// report only the CPEs that actually matched the request
			vulnObj.CPEs = candidateMatchCpes

			vulns = append(vulns, *vulnObj)
		}
	}

	return vulns, nil
}
