/*
 * This file is part of LibEuFin.
 * Copyright (C) 2020-2025 Taler Systems S.A.
 *
 * LibEuFin is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation; either version 3, or
 * (at your option) any later version.
 *
 * LibEuFin is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public
 * License along with LibEuFin; see the file COPYING.  If not, see
 * <http://www.gnu.org/licenses/>
 */

package tech.libeufin.ebics

import tech.libeufin.common.decodeBase64
import org.w3c.dom.Document
import org.w3c.dom.Node
import org.w3c.dom.NodeList
import org.w3c.dom.Element
import org.xml.sax.InputSource
import java.io.InputStream
import java.io.StringWriter
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.util.UUID
import java.time.Instant
import java.time.ZoneId
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.security.PrivateKey
import java.security.PublicKey
import javax.xml.XMLConstants
import javax.xml.crypto.*
import javax.xml.crypto.dom.DOMURIReference
import javax.xml.crypto.dsig.*
import javax.xml.crypto.dsig.dom.DOMSignContext
import javax.xml.crypto.dsig.dom.DOMValidateContext
import javax.xml.crypto.dsig.spec.C14NMethodParameterSpec
import javax.xml.crypto.dsig.spec.TransformParameterSpec
import javax.xml.parsers.DocumentBuilderFactory
import javax.xml.transform.OutputKeys
import javax.xml.transform.TransformerFactory
import javax.xml.transform.dom.DOMSource
import javax.xml.transform.stream.StreamResult
import javax.xml.stream.XMLOutputFactory
import javax.xml.stream.XMLStreamWriter
import javax.xml.xpath.XPath
import javax.xml.xpath.XPathConstants
import javax.xml.xpath.XPathFactory

/**
 * Format this instant as an ISO-8601 date evaluated in the UTC zone.
 *
 * The instant is first projected onto UTC and then rendered with
 * [DateTimeFormatter.ISO_DATE]; since a zone is attached, the formatter
 * also appends the offset (e.g. `1970-01-01Z`).
 */
fun Instant.xmlDate(): String {
    val utc = atZone(ZoneId.of("UTC"))
    return utc.format(DateTimeFormatter.ISO_DATE)
}
/**
 * Format this instant as an ISO-8601 date-time with offset, evaluated in
 * the UTC zone (e.g. `1970-01-01T00:00:00Z`).
 */
fun Instant.xmlDateTime(): String {
    val utc = atZone(ZoneId.of("UTC"))
    return DateTimeFormatter.ISO_OFFSET_DATE_TIME.format(utc)
}

/**
 * Small DSL for producing XML, either as serialized bytes ([toBytes]) or as
 * a DOM tree ([toDom]).
 */
interface XmlBuilder {
    /**
     * Create the element(s) named by [path] and run [lambda] inside the
     * innermost one. The implementations in this file split [path] on '/'
     * and nest one element per segment.
     */
    fun el(path: String, lambda: XmlBuilder.() -> Unit = {})

    /** Shorthand for [el] whose innermost element only holds the text [content]. */
    fun el(path: String, content: String) {
        el(path) {
            text(content)
        }
    }

    /** Add an attribute [name]=[value] qualified by the namespace URI [namespace] to the current element. */
    fun attr(namespace: String, name: String, value: String)

    /** Add an unqualified attribute [name]=[value] to the current element. */
    fun attr(name: String, value: String)

    /** Append the character data [content] to the current element. */
    fun text(content: String)

    companion object {
        /**
         * Build a document whose root element is [root] with the content
         * described by [f], and return it serialized as UTF-8 bytes.
         */
        fun toBytes(root: String, f: XmlBuilder.() -> Unit): ByteArray {
            val stream = StringWriter()
            val writer = XMLOutputFactory.newFactory().createXMLStreamWriter(stream)
            // The XML declaration is written by hand instead of going through
            // XMLStreamWriter.writeStartDocument, because it wasn't obvious how
            // to make the latter output the "standalone='yes'" directive.
            stream.write("<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?>")
            XmlStreamBuilder(writer).el(root) {
                this.f()
            }
            writer.writeEndDocument()
            // A stream writer is allowed to cache output: push everything to
            // the StringWriter before reading it back, then release the writer.
            writer.flush()
            writer.close()
            // Encode explicitly as UTF-8 to match the declaration written above.
            return stream.toString().toByteArray(Charsets.UTF_8)
        }

        /**
         * Build a standalone DOM document whose root element is [root] with
         * the content described by [f]. Every element is created in the
         * namespace [schema] (no namespace when null).
         */
        fun toDom(root: String, schema: String?, f: XmlBuilder.() -> Unit): Document {
            val factory = DocumentBuilderFactory.newInstance()
            factory.isNamespaceAware = true
            val doc = factory.newDocumentBuilder().newDocument()
            doc.xmlVersion = "1.0"
            doc.xmlStandalone = true
            // Distinct local name: do not shadow the `root` parameter
            val rootElement = doc.createElementNS(schema, root)
            doc.appendChild(rootElement)
            XmlDOMBuilder(doc, schema, rootElement).f()
            doc.normalize()
            return doc
        }
    }
}

/** [XmlBuilder] that forwards every operation to a StAX [XMLStreamWriter]. */
private class XmlStreamBuilder(private val w: XMLStreamWriter): XmlBuilder {
    /**
     * Open one element per '/'-separated segment of [path], run [lambda]
     * inside the innermost one, then close as many elements as were opened.
     */
    override fun el(path: String, lambda: XmlBuilder.() -> Unit) {
        val segments = path.split('/')
        for (segment in segments) {
            w.writeStartElement(segment)
        }
        this.lambda()
        repeat(segments.size) {
            w.writeEndElement()
        }
    }

    /** Write a namespace-qualified attribute on the currently open element. */
    override fun attr(namespace: String, name: String, value: String) =
        w.writeAttribute(namespace, name, value)

    /** Write an unqualified attribute on the currently open element. */
    override fun attr(name: String, value: String) =
        w.writeAttribute(name, value)

    /** Write character data inside the currently open element. */
    override fun text(content: String) =
        w.writeCharacters(content)
}

/**
 * [XmlBuilder] that grows a DOM tree inside [doc]. [node] is the element
 * currently being filled; new elements are created in the namespace [schema].
 */
private class XmlDOMBuilder(private val doc: Document, private val schema: String?, private var node: Element): XmlBuilder {
    /**
     * Append one nested element per '/'-separated segment of [path], run
     * [lambda] with the innermost one as current element, then make the
     * original element current again.
     */
    override fun el(path: String, lambda: XmlBuilder.() -> Unit) {
        val parent = node
        for (name in path.split('/')) {
            val child = doc.createElementNS(schema, name)
            node.appendChild(child)
            node = child
        }
        this.lambda()
        node = parent
    }

    /** Set a namespace-qualified attribute on the current element. */
    override fun attr(namespace: String, name: String, value: String) =
        node.setAttributeNS(namespace, name, value)

    /** Set an unqualified attribute on the current element. */
    override fun attr(name: String, value: String) =
        node.setAttribute(name, value)

    /** Append a text node holding [content] to the current element. */
    override fun text(content: String) {
        val textNode = doc.createTextNode(content)
        node.appendChild(textNode)
    }
}

/**
 * Lazily yield the direct child elements of this element whose local name
 * is [tag]. When [signed] is true, only children carrying the attribute
 * `authenticate="true"` are yielded.
 */
private fun Element.childrenByTag(tag: String, signed: Boolean): Sequence<Element> = sequence {
    val nodes = childNodes
    // Valid indices are 0 until length; the previous inclusive range also
    // probed item(length), which is out of bounds.
    for (i in 0 until nodes.length) {
        val child = nodes.item(i)
        if (child is Element
            && child.localName == tag
            && (!signed || child.getAttribute("authenticate") == "true")) {
            yield(child)
        }
    }
}

/**
 * Read-only cursor over a DOM element, used to destructure a parsed XML
 * document. Children are looked up by local name among the direct children
 * of the wrapped element; passing `signed = true` restricts the lookup to
 * children carrying `authenticate="true"`.
 */
class XmlDestructor internal constructor(private val el: Element) {
    /** Run [f] on every direct child named [path]. */
    fun each(path: String, signed: Boolean = false, f: XmlDestructor.() -> Unit) {
        for (child in el.childrenByTag(path, signed)) {
            XmlDestructor(child).f()
        }
    }

    /** Apply [f] to every direct child named [path] and collect the results in document order. */
    fun <T> map(path: String, signed: Boolean = false, f: XmlDestructor.() -> T): List<T> =
        el.childrenByTag(path, signed)
            .map { child -> XmlDestructor(child).f() }
            .toList()

    /**
     * Return the only matching child, or null when there is none. Throws
     * when several children match; [expectation] is the word used in the
     * error message ("unique" or "optional").
     */
    private fun atMostOne(tag: String, signed: Boolean, expectation: String): Element? {
        val matches = el.childrenByTag(tag, signed).iterator()
        if (!matches.hasNext()) return null
        val first = matches.next()
        if (matches.hasNext()) {
            // One match was already consumed: add it back to the remaining count
            val total = matches.asSequence().count() + 1
            throw Exception("expected $expectation '${el.tagName}.$tag', got $total")
        }
        return first
    }

    /** Return the single child named [tag]; throws if there is none or more than one. */
    fun one(tag: String, signed: Boolean = false): XmlDestructor {
        val child = atMostOne(tag, signed, "unique")
            ?: throw Exception("expected unique '${el.tagName}.$tag', got none")
        return XmlDestructor(child)
    }

    /** Return the child named [tag], or null if absent; throws if there is more than one. */
    fun opt(tag: String, signed: Boolean = false): XmlDestructor? {
        val child = atMostOne(tag, signed, "optional") ?: return null
        return XmlDestructor(child)
    }

    /** Apply [f] to the single child named [path]. */
    fun <T> one(path: String, signed: Boolean = false, f: XmlDestructor.() -> T): T {
        val child = one(path, signed)
        return child.f()
    }

    /** Apply [f] to the child named [path] if present, else return null. */
    fun <T> opt(path: String, signed: Boolean = false, f: XmlDestructor.() -> T): T? {
        val child = opt(path, signed) ?: return null
        return child.f()
    }

    /** Text content parsed as a [UUID]. */
    fun uuid(): UUID = UUID.fromString(text())
    /** Raw text content of the element. */
    fun text(): String = el.textContent
    /** Text content decoded from base64. */
    fun base64(): ByteArray = text().decodeBase64()
    /** Text content parsed with [String.toBoolean]. */
    fun bool(): Boolean = text().toBoolean()
    /** Text content parsed as a [Float]. */
    fun float(): Float = text().toFloat()
    /** Text content parsed as an ISO date. */
    fun date(): LocalDate = LocalDate.parse(text(), DateTimeFormatter.ISO_DATE)
    /** Text content parsed as an ISO date-time. */
    fun dateTime(): LocalDateTime = LocalDateTime.parse(text(), DateTimeFormatter.ISO_DATE_TIME)
    /** Text content resolved to the constant of enum [T] with that exact name. */
    inline fun <reified T : Enum<T>> enum(): T = enumValueOf<T>(text())

    /** Value of attribute [index], or null when it is absent or empty. */
    fun optAttr(index: String): String? =
        el.getAttribute(index).takeIf { it != "" }

    /** Value of attribute [index]; throws when it is absent or empty. */
    fun attr(index: String): String =
        optAttr(index) ?: throw Exception("missing attribute '$index' at '${el.tagName}'")


    companion object {
        /** Parse the XML string [xml], check its root element is [root] and apply [f] to it. */
        fun <T> parse(xml: String, root: String, f: XmlDestructor.() -> T): T =
            parse(ByteArrayInputStream(xml.toByteArray()), root, f)

        /** Parse the XML read from [xml], check its root element is [root] and apply [f] to it. */
        fun <T> parse(xml: InputStream, root: String, f: XmlDestructor.() -> T): T =
            parse(XMLUtil.parseIntoDom(xml), root, f)

        /** Check the root element of [doc] has local name [root] and apply [f] to it. */
        fun <T> parse(doc: Document, root: String, f: XmlDestructor.() -> T): T {
            val top = doc.documentElement
            val actual = top.localName
            if (actual != root) {
                throw Exception("expected root '$root' got '$actual'")
            }
            return XmlDestructor(top).f()
        }
    }
}


/**
 * URI dereferencer resolving the single reference form used by EBICS XML
 * signatures: `#xpointer(//*[@authenticate='true'])`. It resolves to every
 * element flagged `authenticate="true"` in the document owning the
 * reference, together with all of their descendant nodes.
 */
private class EbicsSigUriDereferencer : URIDereferencer {
    override fun dereference(myRef: URIReference?, myCtx: XMLCryptoContext?): Data {
        val domRef = myRef as? DOMURIReference ?: throw Exception("invalid type")
        val uri = domRef.uri
        if (uri != "#xpointer(//*[@authenticate='true'])") {
            throw Exception("invalid EBICS XML signature URI: '$uri'")
        }

        // Select the authenticated elements and everything below them
        val query = XPathFactory.newInstance().newXPath()
            .compile("//*[@authenticate='true']/descendant-or-self::node()")
        val result = query.evaluate(domRef.here.ownerDocument, XPathConstants.NODESET)
        val selected = result as? NodeList ?: throw Exception("invalid type")
        if (selected.length <= 0) {
            throw Exception("no nodes to sign")
        }

        // Snapshot the node list so the returned data can be iterated repeatedly
        val nodes = (0 until selected.length).mapTo(ArrayList<Node>()) { selected.item(it) }
        return NodeSetData { nodes.iterator() }
    }
}

/**
 * Helpers for dealing with XML in EBICS: DOM serialization, hardened
 * parsing, and creation/verification of the EBICS authentication signature.
 */
object XMLUtil {
    /** Serialize [document] to bytes, declaring it as standalone. */
    fun convertDomToBytes(document: Document): ByteArray {
        val w = ByteArrayOutputStream()
        val transformer = TransformerFactory.newInstance().newTransformer()
        transformer.setOutputProperty(OutputKeys.STANDALONE, "yes")
        transformer.transform(DOMSource(document), StreamResult(w))
        return w.toByteArray()
    }

    /**
     * Parse [xml] into a namespace-aware XML DOM.
     *
     * Secure processing is enabled and access to external DTDs and schemas
     * is disabled. The stream is closed once parsing is over.
     */
    fun parseIntoDom(xml: InputStream): Document {
        val factory = DocumentBuilderFactory.newInstance().apply {
            // Enable secure processing
            setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true)
            // Disable all external access
            setAttribute(XMLConstants.ACCESS_EXTERNAL_DTD, "")
            setAttribute(XMLConstants.ACCESS_EXTERNAL_SCHEMA, "")
            isNamespaceAware = true
        }
        val builder = factory.newDocumentBuilder()
        return xml.use {
            builder.parse(InputSource(it))
        }
    }

    /**
     * Sign an EBICS document in place with [signingPriv].
     *
     * The signature covers every element flagged `authenticate="true"` (and
     * its descendants), digested with SHA-256 after inclusive
     * canonicalization, and is computed with RSA-SHA256. Its content is
     * stored directly under the `AuthSignature` child of the root element
     * of [doc].
     *
     * @throws Exception if [doc] has no `AuthSignature` element under its root
     */
    fun signEbicsDocument(
        doc: Document,
        signingPriv: PrivateKey
    ) {
        val authSigNode = XPathFactory.newInstance().newXPath()
            .evaluate("/*[1]/*[local-name()='AuthSignature']", doc, XPathConstants.NODE)
        if (authSigNode !is Node)
            throw Exception("sign: no AuthSignature")
        val fac = XMLSignatureFactory.getInstance("DOM")
        val c14n = fac.newTransform(CanonicalizationMethod.INCLUSIVE, null as TransformParameterSpec?)
        val ref: Reference =
            fac.newReference(
                "#xpointer(//*[@authenticate='true'])",
                fac.newDigestMethod(DigestMethod.SHA256, null),
                listOf(c14n),
                null,
                null
            )
        val canon: CanonicalizationMethod =
            fac.newCanonicalizationMethod(CanonicalizationMethod.INCLUSIVE, null as C14NMethodParameterSpec?)
        val signatureMethod = fac.newSignatureMethod("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256", null)
        val si: SignedInfo = fac.newSignedInfo(canon, signatureMethod, listOf(ref))
        val sig: XMLSignature = fac.newXMLSignature(si, null)
        val dsc = DOMSignContext(signingPriv, authSigNode)
        dsc.defaultNamespacePrefix = "ds"
        dsc.uriDereferencer = EbicsSigUriDereferencer()
        dsc.setProperty("javax.xml.crypto.dsig.cacheReference", true)
        sig.sign(dsc)
        // The generated Signature element is appended as the LAST child of
        // AuthSignature, so pick it from there: this stays correct even if
        // AuthSignature already had children (e.g. whitespace text nodes).
        val innerSig = authSigNode.lastChild
        // EBICS expects the signature content directly under AuthSignature:
        // move the children up and drop the now empty wrapper.
        while (innerSig.hasChildNodes()) {
            authSigNode.appendChild(innerSig.firstChild)
        }
        authSigNode.removeChild(innerSig)
    }

    /**
     * Check the signature of an EBICS document against [signingPub].
     *
     * NOTE: this mutates [doc]: the element holding the unique `SignedInfo`
     * (the EBICS `AuthSignature`) is renamed into an XML-DSig `Signature`
     * element so that the standard validator can process it.
     *
     * @throws Exception if there is not exactly one `SignedInfo` element or
     *   if the signature does not verify
     */
    fun verifyEbicsDocument(
        doc: Document,
        signingPub: PublicKey
    ) {
        // Find SignedInfo
        val sigInfos = doc.getElementsByTagNameNS(XMLSignature.XMLNS, "SignedInfo")
        if (sigInfos.length == 0) {
            throw Exception("missing SignedInfo")
        } else if (sigInfos.length != 1) {
            throw Exception("many SignedInfo")
        }
        val sigInfo = sigInfos.item(0)

        // Rename AuthSignature, reusing the prefix of SignedInfo when it has
        // one. The prefix is null when the signature namespace is the default
        // one: do not interpolate it blindly or the name becomes "null:Signature".
        val prefix = sigInfo.prefix
        val qualifiedName = if (prefix.isNullOrEmpty()) "Signature" else "$prefix:Signature"
        // renameNode may replace the node instead of renaming it in place:
        // always continue with the node it returns.
        val authSig = doc.renameNode(sigInfo.parentNode, XMLSignature.XMLNS, qualifiedName)

        // Check signature
        val fac = XMLSignatureFactory.getInstance("DOM")
        val dvc = DOMValidateContext(signingPub, authSig)
        dvc.setProperty("javax.xml.crypto.dsig.cacheReference", true)
        dvc.uriDereferencer = EbicsSigUriDereferencer()
        val sig = fac.unmarshalXMLSignature(dvc)
        if (!sig.validate(dvc)) {
            throw Exception("bank signature did not verify")
        }
    }
}