/*
 *  Copyright 2001-2005 Internet2
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* SAMLArtifact.cpp - SAML artifact implementations

   Scott Cantor
   2/15/05

   $History:$
*/

#include "internal.h"

#include <xercesc/util/Base64.hpp>
#include <xsec/enc/XSECCryptoProvider.hpp>

using namespace saml;
using namespace std;

// Registry of artifact factories keyed by the raw two-byte type code.
// SAMLArtifact::parse() looks the decoded type code up here and invokes the
// mapped factory with the original base64 text. Entries are presumably
// registered elsewhere (e.g. with the SAMLArtifactType000xFactory functions
// below) -- registration is not shown in this file; confirm.
SAMLArtifact::SAMLArtifactFactoryMap SAMLArtifact::m_map;

/* C-linkage factory for type 0x0001 artifacts.
 * Builds a SAMLArtifactType0001 from base64 text s; the constructor decodes
 * and validates it (throwing on malformed input). Caller owns the result.
 */
extern "C" SAMLArtifact* SAMLArtifactType0001Factory(const char* s)
{
    SAMLArtifact* artifact = new SAMLArtifactType0001(s);
    return artifact;
}

/* C-linkage factory for type 0x0002 artifacts.
 * Builds a SAMLArtifactType0002 from base64 text s; the constructor decodes
 * and validates it (throwing on malformed input). Caller owns the result.
 */
extern "C" SAMLArtifact* SAMLArtifactType0002Factory(const char* s)
{
    SAMLArtifact* artifact = new SAMLArtifactType0002(s);
    return artifact;
}

/* Renders the bytes of s as lowercase hexadecimal, two characters per byte,
 * high nibble first (e.g. "\x00\x01" -> "0001"). Used to build readable
 * error messages containing raw artifact type codes.
 */
string SAMLArtifact::toHex(const string& s)
{
    // const: the lookup table is never meant to be modified.
    static const char DIGITS[] = "0123456789abcdef";

    string ret;
    ret.reserve(s.length() * 2);
    // size_type index: no signed/unsigned mismatch or truncation on long input.
    for (string::size_type i = 0; i < s.length(); ++i) {
        // Work on the unsigned byte value so shifting is well-defined even
        // where plain char is signed.
        const unsigned char c = static_cast<unsigned char>(s[i]);
        ret += DIGITS[c >> 4];
        ret += DIGITS[c & 0x0F];
    }
    return ret;
}

/* Wide-string entry point: transcodes s to a narrow string with
 * auto_ptr_char and delegates to the const char* overload, which performs
 * the actual decoding and factory dispatch.
 */
SAMLArtifact* SAMLArtifact::parse(const XMLCh* s)
{
    auto_ptr_char narrow(s);
    const char* text = narrow.get();
    return parse(text);
}

/* Parses a base64-encoded artifact of any registered type.
 *
 * The text is decoded only far enough to read the two-byte type code, which
 * selects a factory from m_map; that factory is then handed the original
 * base64 text s and returns a newly allocated artifact (caller owns it).
 *
 * Throws MalformedException if s cannot be base64-decoded or decodes to fewer
 * than two bytes; UnsupportedExtensionException if no factory is registered
 * for the type code.
 */
SAMLArtifact* SAMLArtifact::parse(const char* s)
{
    // Decode and extract the type code first.
    unsigned int len=0;
    XMLByte* decoded=Base64::decode(reinterpret_cast<const XMLByte*>(s),&len);
    if (!decoded)
        throw MalformedException("SAMLArtifact::parse() unable to decode base64 artifact");

    // Without this check, decoded[0..1] could be read past the decoded data.
    if (len < 2) {
        XMLString::release(&decoded);
        throw MalformedException("SAMLArtifact::parse() given artifact too short to contain a type code");
    }

    const string type(reinterpret_cast<const char*>(decoded), 2);
    XMLString::release(&decoded);

    SAMLArtifactFactoryMap::const_iterator i=m_map.find(type);
    if (i==m_map.end())
        throw UnsupportedExtensionException(
            string("SAMLArtifact::parse() unable to parse unknown artifact typecode (0x") + toHex(type) + ")"
            );
    return i->second(s);
}

/* Decoding constructor: base64-decodes s and stores every decoded byte in
 * m_raw. Subclasses validate the resulting layout.
 * Throws MalformedException if s is not decodable base64.
 */
SAMLArtifact::SAMLArtifact(const char* s)
{
    unsigned int decodedLen=0;
    XMLByte* bytes=Base64::decode(reinterpret_cast<const XMLByte*>(s),&decodedLen);
    if (bytes==0)
        throw MalformedException("SAMLArtifact() unable to decode base64 artifact");
    // Copy all decodedLen bytes in one call (the buffer may contain NULs).
    m_raw.append(reinterpret_cast<const char*>(bytes),decodedLen);
    XMLString::release(&bytes);
}

/* Returns the base64 encoding of the raw artifact bytes (m_raw) as a single
 * token, or an empty string if Xerces fails to encode.
 *
 * Xerces Base64 produces RFC 2045-style output, which can contain line feeds
 * (line breaking). An artifact is transmitted as a single token, so any
 * whitespace is stripped; valid base64 characters never include whitespace,
 * so this cannot alter a correct encoding.
 */
string SAMLArtifact::encode() const
{
    unsigned int len=0;
    XMLByte* out=Base64::encode(reinterpret_cast<const XMLByte*>(m_raw.data()),m_raw.size(),&len);
    if (!out)
        return string();

    string ret;
    try {
        ret.reserve(len);
        for (unsigned int i=0; i<len; ++i) {
            const char c = static_cast<char>(out[i]);
            if (c!='\r' && c!='\n' && c!=' ' && c!='\t')
                ret += c;
        }
    }
    catch (...) {
        // Don't leak the Xerces buffer if string growth throws.
        XMLString::release(&out);
        throw;
    }
    XMLString::release(&out);
    return ret;
}

// Field sizes, in bytes, of a type 0x0001 artifact. The full artifact is the
// 2-byte type code followed by the source ID and then the handle, i.e.
// 2 + SOURCEID_LENGTH + HANDLE_LENGTH bytes in total.
const unsigned int SAMLArtifactType0001::SOURCEID_LENGTH = 20;
const unsigned int SAMLArtifactType0001::HANDLE_LENGTH = 20;

/* Derives a type 0x0001 source ID from s: the SHA-1 digest of the bytes of
 * s (excluding the terminator), returned as a SOURCEID_LENGTH-byte binary
 * string.
 *
 * Throws MalformedException if s is null, and InvalidCryptoException if the
 * crypto provider cannot supply a SHA-1 hasher or yields a short digest.
 */
string SAMLArtifactType0001::generateSourceId(const char* s)
{
    if (!s)
        throw MalformedException("SAMLArtifactType0001::generateSourceId() given null input");

    auto_ptr<XSECCryptoHash> hasher(XSECPlatformUtils::g_cryptoProvider->hashSHA1());
    if (hasher.get()) {
        // hash() takes a non-const buffer, so hash a private copy rather than
        // casting away const. The copy includes the terminator so &data[0] is
        // always valid, and std::string owns it, so nothing leaks if the
        // provider throws (the previous strdup/free pairing did leak then).
        const string::size_type len = strlen(s);
        string data(s, len + 1);
        hasher->hash(reinterpret_cast<unsigned char*>(&data[0]), len);

        unsigned char buf[SOURCEID_LENGTH+1];
        if (hasher->finish(buf,SOURCEID_LENGTH)==SOURCEID_LENGTH)
            return string(reinterpret_cast<const char*>(buf), SOURCEID_LENGTH);
    }
    throw InvalidCryptoException("SAMLArtifactType0001::generateSourceId() unable to generate SHA-1 hash");
}

/* Decoding constructor for type 0x0001 artifacts.
 * SAMLArtifact(s) decodes the base64 text into m_raw; this body only checks
 * that the result is exactly 2 + SOURCEID_LENGTH + HANDLE_LENGTH bytes and
 * starts with the type code 0x00 0x01. Throws MalformedException otherwise.
 */
SAMLArtifactType0001::SAMLArtifactType0001(const char* s) : SAMLArtifact(s)
{
    const bool lengthOk = (m_raw.size() == 2 + SOURCEID_LENGTH + HANDLE_LENGTH);
    if (!lengthOk)
        throw MalformedException("SAMLArtifactType0001() given artifact of incorrect length");

    // Only reached once the length is known to be sufficient.
    const bool typeOk = (m_raw[0] == 0x0 && m_raw[1] == 0x1);
    if (!typeOk)
        throw MalformedException(
            string("SAMLArtifactType0001() given artifact of invalid type (") + toHex(getTypeCode()) + ")"
            );
}

/* Builds a new type 0x0001 artifact for the given binary source ID, with a
 * freshly generated random handle.
 * Layout: 0x00 0x01 | sourceid (SOURCEID_LENGTH bytes) | handle (HANDLE_LENGTH bytes).
 * Throws MalformedException if sourceid is not exactly SOURCEID_LENGTH bytes.
 */
SAMLArtifactType0001::SAMLArtifactType0001(const std::string& sourceid)
{
    if (sourceid.size()!=SOURCEID_LENGTH)
        throw MalformedException("SAMLArtifactType0001() given sourceid of incorrect length");
    m_raw+='\x00';
    m_raw+='\x01';
    // Use the named field sizes (not a literal 20) so they stay in step with
    // the length checks.
    m_raw.append(sourceid,0,SOURCEID_LENGTH);
    char handle[HANDLE_LENGTH];
    SAMLIdentifier::generateRandomBytes(handle,HANDLE_LENGTH);
    m_raw.append(handle,HANDLE_LENGTH);
}

/* Builds a type 0x0001 artifact from an explicit binary source ID and handle.
 * Layout: 0x00 0x01 | sourceid (SOURCEID_LENGTH bytes) | handle (HANDLE_LENGTH bytes).
 * Throws MalformedException if either input has the wrong length.
 */
SAMLArtifactType0001::SAMLArtifactType0001(const std::string& sourceid, const std::string& handle)
{
    if (sourceid.size()!=SOURCEID_LENGTH)
        throw MalformedException("SAMLArtifactType0001() given sourceid of incorrect length");
    if (handle.size()!=HANDLE_LENGTH)
        throw MalformedException("SAMLArtifactType0001() given handle of incorrect length");
    m_raw+='\x00';
    m_raw+='\x01';
    // Named field sizes instead of literal 20s, matching the checks above.
    m_raw.append(sourceid,0,SOURCEID_LENGTH);
    m_raw.append(handle,0,HANDLE_LENGTH);
}

/* Polymorphic copy: returns a newly allocated copy of this type 0x0001
 * artifact (made with the copy constructor). Caller owns the result.
 */
SAMLArtifact* SAMLArtifactType0001::clone() const
{
    SAMLArtifactType0001* copy = new SAMLArtifactType0001(*this);
    return copy;
}


// Size in bytes of the handle field of a type 0x0002 artifact. The handle
// follows the 2-byte type code and precedes the variable-length source
// location.
const unsigned int SAMLArtifactType0002::HANDLE_LENGTH = 20;

/* Decoding constructor for type 0x0002 artifacts.
 * SAMLArtifact(s) decodes the base64 text into m_raw; this body validates it.
 * Layout: 0x00 0x02 | handle (HANDLE_LENGTH bytes) | source location (>= 1 byte),
 * hence the strict "<=" length check. Throws MalformedException on a short
 * artifact or wrong type code.
 */
SAMLArtifactType0002::SAMLArtifactType0002(const char* s) : SAMLArtifact(s)
{
    // The base class does the work, we just do the checking.
    if (m_raw.size() <= 2 + HANDLE_LENGTH)
        // Message previously (wrongly) named SAMLArtifactType0001().
        throw MalformedException("SAMLArtifactType0002() given artifact of incorrect length");
    else if (m_raw[0] != 0x0 || m_raw[1] != 0x2)
        throw MalformedException(
            string("SAMLArtifactType0002() given artifact of invalid type (") + toHex(getTypeCode()) + ")"
            );
}

/* Builds a new type 0x0002 artifact for the given source location, with a
 * freshly generated random handle.
 * Layout: 0x00 0x02 | handle (HANDLE_LENGTH bytes) | sourceLocation.
 * Throws MalformedException if sourceLocation is empty.
 */
SAMLArtifactType0002::SAMLArtifactType0002(const std::string& sourceLocation)
{
    if (sourceLocation.empty())
        throw MalformedException("SAMLArtifactType0002() given empty source location");
    m_raw+='\x00';
    m_raw+='\x02';
    // Size the random handle with HANDLE_LENGTH (not a literal 20) so it
    // matches the decoding constructor's length check.
    char handle[HANDLE_LENGTH];
    SAMLIdentifier::generateRandomBytes(handle,HANDLE_LENGTH);
    m_raw.append(handle,HANDLE_LENGTH);
    m_raw+=sourceLocation;
}

/* Builds a type 0x0002 artifact from an explicit binary handle and source
 * location.
 * Layout: 0x00 0x02 | handle (HANDLE_LENGTH bytes) | sourceLocation.
 * Throws MalformedException if sourceLocation is empty or handle has the
 * wrong length.
 */
SAMLArtifactType0002::SAMLArtifactType0002(const std::string& sourceLocation, const std::string& handle)
{
    if (sourceLocation.empty())
        throw MalformedException("SAMLArtifactType0002() given empty source location");
    if (handle.size()!=HANDLE_LENGTH)
        throw MalformedException("SAMLArtifactType0002() given handle of incorrect length");
    m_raw+='\x00';
    m_raw+='\x02';
    // Named field size instead of a literal 20, matching the check above.
    m_raw.append(handle,0,HANDLE_LENGTH);
    m_raw+=sourceLocation;
}

/* Polymorphic copy: returns a newly allocated copy of this type 0x0002
 * artifact (made with the copy constructor). Caller owns the result.
 */
SAMLArtifact* SAMLArtifactType0002::clone() const
{
    SAMLArtifactType0002* copy = new SAMLArtifactType0002(*this);
    return copy;
}
