/* /////////////////////////////////////////////////////////////////////////////
 * File:        loader.d (originally from synsoft.win32.loader)
 *
 * Purpose:     Executable module loader (ExeModule) and Win32 exception classes
 *
 * Created      18th October 2003
 * Updated:     11th May 2004
 *
 * Author:      Matthew Wilson
 *
 * License:     (Licensed under the Synesis Software Standard Source License)
 *
 *              Copyright (C) 2003-2004, Synesis Software Pty Ltd.
 *
 *              All rights reserved.
 *
 *              www:        http://www.synesis.com.au/software
 *                          http://www.synsoft.org/
 *
 *              email:      submissions@synsoft.org  for submissions
 *                          admin@synsoft.org        for other enquiries
 *
 *              Redistribution and use in source and binary forms, with or
 *              without modification, are permitted provided that the following
 *              conditions are met:
 *
 *              (i) Redistributions of source code must retain the above
 *              copyright notice and contact information, this list of
 *              conditions and the following disclaimer.
 *
 *              (ii) Any derived versions of this software (howsoever modified)
 *              remain the sole property of Synesis Software.
 *
 *              (iii) Any derived versions of this software (howsoever modified)
 *              remain subject to all these conditions.
 *
 *              (iv) Neither the name of Synesis Software nor the names of any
 *              subdivisions, employees or agents of Synesis Software, nor the
 *              names of any other contributors to this software may be used to
 *              endorse or promote products derived from this software without
 *              specific prior written permission.
 *
 *              This source code is provided by Synesis Software "as is" and any
 *              warranties, whether expressed or implied, including, but not
 *              limited to, the implied warranties of merchantability and
 *              fitness for a particular purpose are disclaimed. In no event
 *              shall the Synesis Software be liable for any direct, indirect,
 *              incidental, special, exemplary, or consequential damages
 *              (including, but not limited to, procurement of substitute goods
 *              or services; loss of use, data, or profits; or business
 *              interruption) however caused and on any theory of liability,
 *              whether in contract, strict liability, or tort (including
 *              negligence or otherwise) arising in any way out of the use of
 *              this software, even if advised of the possibility of such
 *              damage.
 *
 * ////////////////////////////////////////////////////////////////////////// */



/** \file D/std/loader.d This file contains the \c D standard library 
 * executable module loader library, and the ExeModule class.
 */

/* ////////////////////////////////////////////////////////////////////////// */

module std.loader;

/* /////////////////////////////////////////////////////////////////////////////
 * Imports
 */

private import std.string;
private import std.c.stdlib;
private import std.c.stdio;
private import std.syserror;
//private import std.windows.exceptions;
version(Windows)
{
    import std.c.windows.windows;
}

//import synsoft.types;
/+ + These are borrowed from synsoft.types, until such time as something similar is in Phobos ++
 +
 + boolean: an int-sized truth value, used in this file for the bThrowOnFailure
 + flags and for the success result of ExeModule_Release_().
 +/
public alias int                    boolean;

/* /////////////////////////////////////////////////////////////////////////////
 * External function declarations
 */

version(Windows)
{
    extern(Windows)
    {
        // Platform module handle: the HMODULE returned by LoadLibraryA()
        alias HMODULE   HModule_;

// These four will be rolled into a platform-independent TSS API soon
        // Thread-local storage API; used to hold each thread's last ExeModule error code
        DWORD   TlsAlloc();
        BOOL    TlsFree(DWORD key);
        LPVOID  TlsGetValue(DWORD key);
        BOOL    TlsSetValue(DWORD key, LPVOID value);
    }
}
else version(linux)
{
    extern(C)
    {
        // dlopen() mode flag; the numeric value is platform-specific
        const int RTLD_NOW  =   0x00002; /* Correct for Red Hat 8 */

        // Platform module handle: the opaque handle returned by dlopen()
        typedef void    *HModule_;

        // Dynamic linking API (<dlfcn.h>)
        HModule_    dlopen(char *path, int mode);
        int         dlclose(HModule_ handle);
        void        *dlsym(HModule_ handle, char *symbolName);
        char        *dlerror();
    }
}
else
{
    // Deliberately fail compilation on platforms that have no implementation
    const int platform_not_discriminated = 0;

    static assert(platform_not_discriminated);
}

/** The platform-independent module handle. Note that this has to be
 * separate from the platform-dependent handle because different module names
 * can result in the same module being loaded, which cannot be detected in
 * some operating systems
 *
 * On Windows the value is the HMODULE returned by LoadLibraryA(); on Linux it
 * is a reference to the library's internal bookkeeping object for the module.
 * Either way it must only be passed to the ExeModule functions.
 */
typedef void    *HXModule;

/* /////////////////////////////////////////////////////////////////////////////
 * ExeModule library Initialisation
 *
 */
/// Module constructor: sets up the platform-specific ExeModule state
/// by delegating to ExeModule_Init_()
static this()
{
    // Any result from the platform initialiser is deliberately ignored
    ExeModule_Init_();
}

/// Module destructor: tears down the platform-specific ExeModule state
/// by delegating to ExeModule_Uninit_()
static ~this()
{
    ExeModule_Uninit_();
}

/* /////////////////////////////////////////////////////////////////////////////
 * ExeModule functions
 */

/* These are "forward declared" here because I don't like the way D forces me
 * to provide my declaration and implementation together, and mixed in with all
 * the other implementation gunk.
 */

/// \name ExeModule functions
/// @{

/** Loads the named executable module.
 *
 * \param moduleName The name of the module to load. Must not be null.
 * \return A handle to the module, or null if it could not be loaded. No
 * exception is thrown; call ExeModule_Error() to find out why it failed.
 *
 * \note The value of the handle returned may not be a valid handle for your
 * operating system, and you <b>must not</b> pass it to any operating-system
 * or other APIs. It is only valid for use with the ExeModule library.
 */
public HXModule ExeModule_Load(in char[] moduleName)
{
    HXModule    hmod    =   ExeModule_Load_(moduleName, false);

    return hmod;
}

/// Adds a reference to a module previously obtained from ExeModule_Load().
///
/// \param hModule The module handle. It must not be null.
/// \return The handle for the new reference, or null if it fails (no exception is thrown)
public HXModule ExeModule_AddRef(in HXModule hModule)
{
    const boolean   bThrowOnFailure =   false;

    return ExeModule_AddRef_(hModule, bThrowOnFailure);
}

/** Releases one reference to a module.
 *
 * \param hModule The module handle. It must not be null. It is set to null
 * when the release succeeds.
 *
 * \note Failures are not reported by exception; use ExeModule_Error().
 */
public void ExeModule_Release(inout HXModule hModule)
{
    const boolean   bThrowOnFailure =   false;

    ExeModule_Release_(hModule, bThrowOnFailure);
}

/// Looks up a named symbol in a loaded module.
///
/// \param hModule The module handle. It must not be null.
/// \param symbolName The name of the symbol to look up
/// \return The address of the symbol, or null if it is not found (no exception is thrown)
public void *ExeModule_GetSymbol(inout HXModule hModule, in char[] symbolName)
{
    void    *address    =   ExeModule_GetSymbol_(hModule, symbolName, false);

    return address;
}

/// Returns a description of the most recent failure recorded by the
/// ExeModule functions. On Windows the recorded error is per-thread; on Linux
/// it is shared by all threads.
public char[] ExeModule_Error()
{
    char[]  description =   ExeModule_Error_();

    return description;
}

/// Returns the path (Windows) or load name (Linux) of a loaded module.
///
/// \param hModule The module handle
/// \note No exception is thrown on failure
public char[] ExeModule_GetPath(HXModule hModule)
{
    char[]  path    =   ExeModule_GetPath_(hModule, false);

    return path;
}

/// @}

/* /////////////////////////////////////////////////////////////////////////////
 * TEMP STUFF HERE TO MAKE COMPILE IN ONE UNIT
 */

extern (C)
{
    // Win32 wsprintfA(); used by Win32Exception to render the error code into
    // its message. Returns the number of characters written.
    int wsprintfA(char *dest, char *fmt, ...);
}

// NOTE(review): unlike the declarations above, this section is not wrapped in
// version(Windows), although HMODULE is only imported on Windows -- confirm
// whether the file is expected to compile on other platforms.
extern (Windows)
{
    // Overload used with FORMAT_MESSAGE_FROM_SYSTEM: the source argument is a
    // placeholder (RESERVED). With FORMAT_MESSAGE_ALLOCATE_BUFFER, lpBuffer
    // receives a system-allocated buffer that the caller frees with LocalFree().
    uint    FormatMessageA( in uint         dwFlags
                        ,   in Reserved 
                        ,   in uint         dwMessageId
                        ,   in uint         dwLanguageId
                        ,   out char        *lpBuffer
                        ,   in uint         nSize
                        ,   in Reserved
                        );
    // Overload used with FORMAT_MESSAGE_FROM_HMODULE: messages come from the
    // message table of hModule.
    uint    FormatMessageA( in uint         dwFlags
                        ,   in HMODULE      hModule
                        ,   in uint         dwMessageId
                        ,   in uint         dwLanguageId
                        ,   out char        *lpBuffer
                        ,   in uint         nSize
                        ,   in Reserved
                        );
    // Frees the buffer allocated by FormatMessageA()
    void    *LocalFree(in void *);
}


// Distinct type for FormatMessageA() arguments that are always passed as zero.
// Being a typedef, it also selects between the two FormatMessageA()
// overloads (Reserved vs HMODULE source argument).
private typedef uint    Reserved;

// FormatMessageA() dwFlags values
private const uint      FORMAT_MESSAGE_ALLOCATE_BUFFER      =   0x00000100;
private const uint      FORMAT_MESSAGE_FROM_HMODULE         =   0x00000800;
private const uint      FORMAT_MESSAGE_FROM_SYSTEM          =   0x00001000;
private const uint      FORMAT_MESSAGE_MAX_WIDTH_MASK       =   0x000000FF;

// Primary language / sub-language identifiers, combined by MAKELANGID()
private const ushort    LANG_NEUTRAL                        =   0x00;
private const ushort    SUBLANG_DEFAULT                     =   0x01;

// The zero value passed for every Reserved argument
private const Reserved  RESERVED                            =   cast(Reserved)0;

/// Builds a Win32 language identifier from a primary language identifier
/// \c p (low bits) and a sub-language identifier \c s (shifted left by 10).
/// The result is truncated to 16 bits.
ushort MAKELANGID(ushort p, ushort s)
{
    int     combined    =   (s << 10) | p;

    return cast(ushort)combined;
}

/// \name Error code translation
/// @{

/// Looks up the text for a Win32 error code in the system message table.
///
/// \param error The Win32 error code
/// \return The message text with trailing padding (e.g. spaces and full
/// stops) removed, or null if the system has no message for \c error
///
/// \note This does <b>not</b> use any dynamic library loading
public char[] FormatMessage(uint error)
{
    char[]  text    =   FormatMessage_(error);

    return text;
}

/// Looks up the text for a Win32 error code in the message table of the named
/// module, falling back to the system message table if the module cannot be
/// loaded or has no message for the code.
///
/// \param error The Win32 error code
/// \param moduleName The module whose message table is searched
///
/// \note This loads the named module via the ExeModule class
public char[] FormatMessage(uint error, char[] moduleName)
{
    char[]  text    =   FormatMessage_(error, moduleName);

    return text;
}

/// @}

/// Copies the first \c cch characters of a FormatMessageA() result, dropping
/// trailing NULs, spaces, tabs, line breaks and full stops.
///
/// \param rawMessage The message text; only the first \c cch characters are read
/// \param cch The number of characters in \c rawMessage
/// \return A new array holding the tidied text; empty if the text consists
/// solely of characters that are stripped
private char[] TidyMessage_(in char *rawMessage, in uint cch)
{
    uint    len =   cch;

    // Walk back from the end, staying strictly inside [0, cch): no reliance
    // on a terminator being present after the reported length.
    while(0 != len)
    {
        char    ch  =   rawMessage[len - 1];

        if( ch != '\0' &&
            ch != ' ' &&
            ch != '\t' &&
            ch != '\r' &&
            ch != '\n' &&
            ch != '.')
        {
            break;
        }

        --len;
    }

    return rawMessage[0 .. len].dup;
}

/// Looks up \c error in the system message table.
///
/// \return The tidied message text, or null if FormatMessageA() fails
private char[] FormatMessage_(uint error)
{
    // The system allocates the buffer; it is released with LocalFree() below
    uint    flags   =   FORMAT_MESSAGE_ALLOCATE_BUFFER
                    |   FORMAT_MESSAGE_FROM_SYSTEM
                    |   FORMAT_MESSAGE_MAX_WIDTH_MASK;
    char    *buffer;
    uint    length  =   FormatMessageA(flags, RESERVED, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buffer, 0, RESERVED);

    if(0 != length)
    {
        char[]  result  =   TidyMessage_(buffer, length);

        LocalFree(buffer);

        return result;
    }

    return null;
}

/// Looks up \c error in the message table of the module \c moduleName.
///
/// Falls back to the system message table if the module cannot be loaded
/// (ExeModuleException) or FormatMessageA() finds no message in it.
private char[] FormatMessage_(uint error, char[] moduleName)
{
    char[]  result;

    try
    {
        // auto: the module is released when this scope is left
        auto ExeModule  image   =   new ExeModule(moduleName);
        uint            flags   =   FORMAT_MESSAGE_ALLOCATE_BUFFER
                                |   FORMAT_MESSAGE_FROM_HMODULE
                                |   FORMAT_MESSAGE_MAX_WIDTH_MASK;
        char            *buffer;
        uint            length  =   FormatMessageA(flags, cast(HMODULE)image.handle, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buffer, 0, RESERVED);

        if(0 != length)
        {
            result = TidyMessage_(buffer, length);

            LocalFree(buffer);
        }
        else
        {
            // Try the default system library
            result = FormatMessage_(error);
        }
    }
    catch(ExeModuleException x)
    {
        // Try the default system library
        result = FormatMessage_(error);
    }

    return result;
}

/// This class is the root exception class for Win32, and provides mechanisms
/// for representing Win32 error codes and extracting error translation messages.
///
/// The text passed to the Exception base (and so returned by toString()) is the
/// message followed by the error code in parentheses, e.g. "Test 2 (3)";
/// message() returns the message alone.
class Win32Exception
    : Exception
{
/// \name Construction
//@{
public:
    /// \brief Creates an instance of the exception, using the calling thread's
    /// current GetLastError() value as the error code
    ///
    /// \param message The message associated with the exception
    this(char[] message)
    {
        this(message, GetLastError());
    }
    /// \brief Creates an instance of the exception, with the given message and
    /// Win32 error code
    ///
    /// \param message The message associated with the exception
    /// \param error The Win32 error number associated with the exception
    this(char[] message, int error)
    {
        char    sz[24]; // Enough for the three " ()" characters and a 64-bit integer value
        int     cch = wsprintfA(sz, " (%d)", error);

        m_message = message;
        m_error   = error;

        super(message ~ sz[0 .. cch]);
    }
//@}

/// \name Attributes
//@{
public:
    /// Returns the message string associated with the exception (without the
    /// error code suffix)
    char[] message()
    {
        return m_message;
    }

    /// Returns the Win32 error code associated with the exception
    int error()
    {
        return m_error;
    }

    /// Converts the error code into a string, searching the default system message libraries
    ///
    /// \return The message text, or null if no message is found
    char[] lookupError()
    {
        return FormatMessage(m_error);
    }

    /// Converts the error code into a string, searching the given message module
    ///
    /// \param moduleName The name of the module whose message table is searched
    /// \note Falls back to the default system message table if the module cannot
    /// be loaded, or does not contain a message for the error code
    char[] lookupError(char[] moduleName)
    {
        return FormatMessage(m_error, moduleName);
    }
//@}

/// \name Members
//@{
private:
    char[]  m_message;
    int     m_error;
//@}
}

unittest
{
    // (i) Thrown and caught as its own type: the message and error code given
    //     to the constructor must be retrievable unchanged.
    try
    {
        char[]  expectedMessage =   "Test 1";
        int     expectedCode    =   3;

        try
        {
            throw new Win32Exception(expectedMessage, expectedCode);
        }
        catch(Win32Exception e)
        {
            assert(e.error == expectedCode);

            if(e.message != expectedMessage)
            {
                printf( "UnitTest failure for Win32Exception:\n"
                        "  x.message [%d;\"%.*s\"] does not equal [%d;\"%.*s\"]\n"
                    ,   e.message.length, e.message
                    ,   expectedMessage.length, expectedMessage);
            }
            assert(e.message == expectedMessage);
        }
    }
    catch(Exception /* e */)
    {
        // Nothing may escape the inner handler
        int code_flow_should_never_reach_here = 0;
        assert(code_flow_should_never_reach_here);
    }

    // (ii) Caught through the Exception base: toString() must yield the
    //      message followed by the parenthesised error code.
    {
        char[]  msg         =   "Test 2";
        int     code        =   3;
        char[]  expected    =   "Test 2 (3)";

        try
        {
            throw new Win32Exception(msg, code);
        }
        catch(Exception e)
        {
            char[]  actual  =   e.toString();

            if(expected != actual)
            {
                printf( "UnitTest failure for Win32Exception:\n"
                        "  x.toString() [%d;\"%.*s\"] does not equal [%d;\"%.*s\"]\n"
                    ,   actual.length, actual
                    ,   expected.length, expected);
            }
            assert(expected == actual);
        }
    }
}



/* /////////////////////////////////////////////////////////////////////////////
 * Implementation
 */

version(Windows)
{
    // Initialisation count, maintained with InterlockedExchangeAdd() /
    // InterlockedDecrement() by ExeModule_Init_() / ExeModule_Uninit_()
    private int     s_init;
    // TLS slot index from TlsAlloc(); the slot holds each thread's last
    // recorded ExeModule error code
    private DWORD   s_key;

    /// Stores the calling thread's GetLastError() value in the TLS slot, for
    /// later retrieval by get_error_()
    private void record_error_()
    in
    {
        // The library must have been initialised (s_key allocated). s_key
        // itself cannot be tested: 0 is a valid TLS index.
        assert(0 != s_init);
    }
    body
    {
        TlsSetValue(s_key, cast(LPVOID)(GetLastError()));
    }

    /// Returns the error code most recently stored by record_error_() on the
    /// calling thread
    private DWORD get_error_()
    in
    {
        // The library must have been initialised (s_key allocated). s_key
        // itself cannot be tested: 0 is a valid TLS index.
        assert(0 != s_init);
    }
    body
    {
        return cast(DWORD)(TlsGetValue(s_key));
    }

    /// Increments the initialisation count and, on the first call, allocates
    /// the TLS slot used to record per-thread errors.
    ///
    /// \note Throws Win32Exception if no TLS slot can be allocated
    private void ExeModule_Init_()
    {
        // TlsAlloc() reports failure with TLS_OUT_OF_INDEXES (0xFFFFFFFF);
        // 0 is a perfectly valid slot index, so it cannot signal failure.
        const DWORD TLS_OUT_OF_INDEXES_ = 0xFFFFFFFF;

        if(0 == InterlockedExchangeAdd(&s_init, 1))
        {
            DWORD   key = TlsAlloc();

            if(TLS_OUT_OF_INDEXES_ == key)
            {
                throw new Win32Exception("Failed to allocate TSS slot", GetLastError());
            }

            s_key = key;
        }
    }

    /// Decrements the initialisation count and, when it reaches zero, frees
    /// the TLS slot allocated by ExeModule_Init_()
    private void ExeModule_Uninit_()
    {
        if(0 == InterlockedDecrement(&s_init))
        {
            // No assertion on s_key: 0 is a valid TLS index, so it cannot be
            // used to detect an unallocated slot
            TlsFree(s_key);
        }
    }

    /// Loads a module with LoadLibraryA().
    ///
    /// \return The module handle, or null on failure (the error is recorded
    /// for ExeModule_Error(); ExeModuleException is thrown instead if
    /// \c bThrowOnFailure is set)
    private HXModule ExeModule_Load_(in char[] moduleName, boolean bThrowOnFailure)
    in
    {
        assert(null !== moduleName);
    }
    body
    {
        HXModule    result  =   cast(HXModule)LoadLibraryA(toStringz(moduleName));

        if(null !== result)
        {
            return result;
        }

        record_error_();

        if(bThrowOnFailure)
        {
            throw new ExeModuleException("Failed to load module \"" ~ moduleName ~ "\": ", ExeModule_Error());
        }

        return null;
    }

    /// Adds a reference to a loaded module by loading it again through its
    /// file path.
    ///
    /// \return The handle for the new reference, or null on failure (unless
    /// \c bThrowOnFailure is set, in which case ExeModuleException is thrown)
    private HXModule ExeModule_AddRef_(in HXModule hModule, boolean bThrowOnFailure)
    in
    {
        assert(null !== hModule);
    }
    body
    {
        char[]  path = ExeModule_GetPath_(hModule, bThrowOnFailure);

        // An empty path means the lookup failed and the error was recorded.
        // Passing it on would violate ExeModule_Load_()'s non-null
        // precondition, or ask LoadLibraryA() to load "".
        if(0 == path.length)
        {
            return null;
        }

        return ExeModule_Load_(path, bThrowOnFailure);
    }

    /// Releases a module reference with FreeLibrary().
    ///
    /// \return true (and \c hModule set to null) on success; false on failure,
    /// unless \c bThrowOnFailure is set, in which case ExeModuleException is thrown
    private boolean ExeModule_Release_(inout HXModule hModule, boolean bThrowOnFailure)
    in
    {
        assert(null !== hModule);
    }
    body
    {
        if(FreeLibrary(cast(HModule_)hModule))
        {
            hModule = null;

            return cast(boolean)(true);
        }

        record_error_();

        if(bThrowOnFailure)
        {
            throw new ExeModuleException(ExeModule_Error());
        }

        return cast(boolean)(false);
    }

    /// Looks up a symbol with GetProcAddress().
    ///
    /// \return The symbol's address, or null if not found (the error is
    /// recorded; ExeModuleException is thrown instead if \c bThrowOnFailure is set)
    private void *ExeModule_GetSymbol_(inout HXModule hModule, in char[] symbolName, boolean bThrowOnFailure)
    in
    {
        assert(null !== hModule);
    }
    body
    {
        void    *proc   =   GetProcAddress(cast(HModule_)hModule, toStringz(symbolName));

        if(null !== proc)
        {
            return proc;
        }

        record_error_();

        if(bThrowOnFailure)
        {
            throw new ExeModuleException("Failed to locate symbol \"" ~ symbolName ~ "\": ", ExeModule_Error());
        }

        return null;
    }

    /// Translates the calling thread's last recorded error code to text
    private char[] ExeModule_Error_()
    {
        DWORD   code    =   get_error_();

        return FormatMessage(code);
    }

    /// Returns the full file path of a loaded module.
    ///
    /// \return The path, or an empty array on failure (the error is recorded;
    /// ExeModuleException is thrown instead if \c bThrowOnFailure is set)
    private char[] ExeModule_GetPath_(HXModule hModule, boolean bThrowOnFailure)
    {
        // GetModuleFileNameA() signals truncation by returning the full buffer
        // size (possibly without a NUL terminator), so retry with a larger
        // buffer until the result fits. Windows bounds path lengths, so the
        // loop terminates.
        // http://msdn.microsoft.com/library/default.asp?url=/library/en-us/dllproc/base/getmodulefilename.asp
        for(uint size = 260; ; size *= 2)
        {
            char[]  buffer  =   new char[size];
            uint    cch     =   GetModuleFileNameA(cast(HModule_)hModule, &buffer[0], buffer.length);

            if(0 == cch)
            {
                record_error_();

                if(bThrowOnFailure)
                {
                    throw new ExeModuleException(ExeModule_Error());
                }

                return null;
            }

            if(cch < buffer.length)
            {
                return buffer[0 .. cch];
            }
        }
    }
}
else version(linux)
{
    /// Bookkeeping entry for a module opened with dlopen(). A reference to
    /// the entry is what is handed out as an HXModule.
    private class ExeModuleInfo
    {
    public:
        int         m_cRefs;    // Number of outstanding references
        HModule_    m_hmod;     // Handle returned by dlopen()
        char[]      m_name;     // Name the module was loaded by; also its key in s_modules

        /// Creates an entry holding a single reference
        this(HModule_ hmod, char[] name)
        {
            m_name  =   name;
            m_hmod  =   hmod;
            m_cRefs =   1;
        }
    };

    // Initialisation count (ExeModule_Init_() / ExeModule_Uninit_())
    private int                     s_init;
    // Registry of loaded modules, keyed by the name passed to ExeModule_Load_()
    private ExeModuleInfo [char[]]  s_modules;
    // Text of the most recently recorded dlerror() message
    private char[]                  s_lastError;    // This is NOT thread-specific

    /// Records the current dlerror() text (or "" if there is none) in
    /// s_lastError, for ExeModule_Error_()
    private void record_error_()
    {
        char    *err = dlerror();

        if(null === err)
        {
            s_lastError = "";
        }
        else
        {
            // dlerror()'s string belongs to the dynamic linker and may be
            // overwritten by later dl*() calls, so keep a private copy
            s_lastError = err[0 .. std.string.strlen(err)].dup;
        }
    }

    /// Increments the initialisation count.
    ///
    /// \return 0 for the first initialisation, 1 for any subsequent one
    private int ExeModule_Init_()
    {
        ++s_init;

        return (1 == s_init) ? 0 : 1;
    }

    /// Decrements the initialisation count. Nothing is released yet when the
    /// count reaches zero.
    private void ExeModule_Uninit_()
    {
        --s_init;

        if(s_init == 0)
        {
            // TODO ...
        }
    }

    /// Loads a module with dlopen(), or adds a reference to it if it is
    /// already registered under the same name.
    ///
    /// \return The module handle, or null on failure (the error is recorded;
    /// ExeModuleException is thrown instead if \c bThrowOnFailure is set)
    private HXModule ExeModule_Load_(in char[] moduleName, boolean bThrowOnFailure)
    in
    {
        assert(null !== moduleName);
    }
    body
    {
        // Test membership first, rather than indexing the associative array
        // with a key that may not be present
        if(moduleName in s_modules)
        {
            ExeModuleInfo   existing    =   s_modules[moduleName];

            ++existing.m_cRefs;

            return cast(HXModule)existing;
        }

        HModule_    hmod    =   dlopen(toStringz(moduleName), RTLD_NOW);

        if(null === hmod)
        {
            record_error_();

            if(bThrowOnFailure)
            {
                throw new ExeModuleException("Failed to load module \"" ~ moduleName ~ "\": ", ExeModule_Error());
            }

            return null;
        }

        // Take a private copy of the name: it becomes the registry key, and
        // the caller's buffer may be modified afterwards
        char[]          name    =   moduleName.dup;
        ExeModuleInfo   mi      =   new ExeModuleInfo(hmod, name);

        s_modules[name] = mi;

        return cast(HXModule)mi;
    }

    /// Adds a reference to a registered module by bumping its count.
    ///
    /// \return \c hModule, or null if it is null (ExeModuleException is thrown
    /// instead if \c bThrowOnFailure is set)
    private HXModule ExeModule_AddRef_(in HXModule hModule, boolean bThrowOnFailure)
    in
    {
        assert(null !== hModule);

        ExeModuleInfo   mi = cast(ExeModuleInfo)hModule;

        assert(0 < mi.m_cRefs);
        assert(null !== mi.m_hmod);
        assert(null !== mi.m_name);
        assert(null !== s_modules[mi.m_name]);
        assert(mi === s_modules[mi.m_name]);
    }
    body
    {
        ExeModuleInfo   info    =   cast(ExeModuleInfo)hModule;

        if(null === info)
        {
            if(bThrowOnFailure)
            {
                throw new ExeModuleException(ExeModule_Error());
            }

            return null;
        }

        ++info.m_cRefs;

        return hModule;
    }

    /// Releases a module reference. When the last reference goes, the module
    /// is dlclose()d and removed from the registry.
    ///
    /// \c hModule is always set to null.
    /// \return false if dlclose() fails (ExeModuleException is thrown instead
    /// if \c bThrowOnFailure is set), otherwise true
    private boolean ExeModule_Release_(inout HXModule hModule, boolean bThrowOnFailure)
    in
    {
        assert(null !== hModule);

        ExeModuleInfo   mi = cast(ExeModuleInfo)hModule;

        assert(0 < mi.m_cRefs);
        assert(null !== mi.m_hmod);
        assert(null !== mi.m_name);
        assert(null !== s_modules[mi.m_name]);
        assert(mi === s_modules[mi.m_name]);
    }
    body
    {
        boolean         bRet    =   cast(boolean)(true);
        ExeModuleInfo   mi      =   cast(ExeModuleInfo)hModule;

        // The caller's reference is given up whatever happens below, since
        // the entry may be destroyed
        hModule = null;

        if(0 == --mi.m_cRefs)
        {
            char[]  name    =   mi.m_name;
            int     res     =   dlclose(mi.m_hmod);

            if(0 != res)
            {
                // Capture the diagnostic before anything else touches dlerror()
                record_error_();
            }

            // Unregister and destroy the entry before any exception is thrown.
            // Otherwise a failed dlclose() would leave a zero-count entry in
            // s_modules, which a later ExeModule_Load_() of the same name
            // would hand out again.
            delete s_modules[name];
            delete mi;

            if(0 != res)
            {
                if(bThrowOnFailure)
                {
                    throw new ExeModuleException(ExeModule_Error());
                }

                bRet = cast(boolean)(false);
            }
        }

        return bRet;
    }

    /// Looks up a symbol with dlsym().
    ///
    /// The \c bThrowOnFailure parameter matches the Windows implementation and
    /// every caller in this file (ExeModule_GetSymbol(), ExeModule.getSymbol(),
    /// ExeModule.findSymbol()).
    ///
    /// \return The symbol's address, or null if not found (the error is
    /// recorded; ExeModuleException is thrown instead if \c bThrowOnFailure is set)
    private void *ExeModule_GetSymbol_(inout HXModule hModule, in char[] symbolName, boolean bThrowOnFailure)
    in
    {
        assert(null !== hModule);

        ExeModuleInfo   mi = cast(ExeModuleInfo)hModule;

        assert(0 < mi.m_cRefs);
        assert(null !== mi.m_hmod);
        assert(null !== mi.m_name);
        assert(null !== s_modules[mi.m_name]);
        assert(mi === s_modules[mi.m_name]);
    }
    body
    {
        ExeModuleInfo   mi      =   cast(ExeModuleInfo)hModule;
        void            *symbol =   dlsym(mi.m_hmod, toStringz(symbolName));

        if(null === symbol)
        {
            record_error_();

            if(bThrowOnFailure)
            {
                throw new ExeModuleException("Failed to locate symbol \"" ~ symbolName ~ "\": ", ExeModule_Error());
            }
        }

        return symbol;
    }

    /// Returns the text stored by the most recent record_error_() call
    /// (shared by all threads)
    private char[] ExeModule_Error_()
    {
        char[]  text    =   s_lastError;

        return text;
    }

    /// Returns the name the module was loaded by.
    ///
    /// \c bThrowOnFailure is accepted for parity with the Windows
    /// implementation; this lookup cannot fail.
    private char[] ExeModule_GetPath_(HXModule hModule, boolean bThrowOnFailure)
    in
    {
        assert(null !== hModule);

        ExeModuleInfo   mi = cast(ExeModuleInfo)hModule;

        assert(0 < mi.m_cRefs);
        assert(null !== mi.m_hmod);
        assert(null !== mi.m_name);
        assert(null !== s_modules[mi.m_name]);
        assert(mi === s_modules[mi.m_name]);
    }
    body
    {
        ExeModuleInfo   mi = cast(ExeModuleInfo)hModule;

        // Return a copy: m_name is also the module's key in s_modules, so
        // handing out the original would let callers corrupt the registry by
        // modifying the returned array
        return mi.m_name.dup;
    }
}
else
{
    const int platform_not_discriminated = 0;

    static assert(platform_not_discriminated);
}

/* /////////////////////////////////////////////////////////////////////////////
 * Classes
 */

/// Exception thrown by the ExeModule API and the ExeModule class when a module
/// cannot be loaded, released or queried
public class ExeModuleException
    : Exception
{
public:
    /// Creates an exception carrying \c message
    this(char[] message)
    {
        super(message);
    }

    /// Creates an exception whose message is \c message0 immediately followed
    /// by \c message1 (typically a context prefix and the error text)
    this(char[] message0, char[] message1)
    {
        super(message0 ~ message1);
    }
/+
    this(char[] message, uint errcode)
    {
        super(message ~ " (" ~ std.string.toString(errcode) ~ ")");
    }
+/
}

/// This class represents an executable image. The module handle it holds is
/// released when the instance is destroyed, unless close() released it first.
public auto class ExeModule
{
/// \name Construction
/// @{
public:
    /// Constructs an instance which manipulates an existing image handle
    ///
    /// \param hModule The module handle (created by ExeModule_Load()). Must not be null
    /// \param bTakeOwnership If true, the instance takes ownership of \c hModule. If false, it increases the reference count of \c hModule
    /// \note Throws ExeModuleException if not taking ownership and the module handle cannot be incremented
    this(in HXModule hModule, boolean bTakeOwnership)
    in
    {
        assert(null !== hModule);
    }
    body
    {
        if(bTakeOwnership)
        {
            m_hModule = hModule;
        }
        else
        {
            m_hModule = ExeModule_AddRef_(hModule, true);
        }
    }

    /// Constructs an instance which loads the given module by name
    ///
    /// \param moduleName The name of the module to load
    /// \note Throws ExeModuleException if the given module name is invalid, or the module cannot be loaded
    this(char[] moduleName)
    in
    {
        assert(null !== moduleName);
    }
    body
    {
        m_hModule = ExeModule_Load_(moduleName, true);
    }

    /// Destructor.
    ///
    /// Releases the module handle if it was not already released by an
    /// explicit call to close(). Unlike close(), a release failure is not
    /// reported by exception: the destructor of an auto instance also runs
    /// while another exception is propagating, and throwing from it then is
    /// unsafe.
    ~this()
    {
        if(null !== m_hModule)
        {
            ExeModule_Release_(m_hModule, false);
        }
    }
/// @}

/// \name Operations
/// @{
public:
    /// Closes the library
    ///
    /// \note This is available to close the module at any time. Repeated
    /// calls do not result in an error, and are simply ignored.
    /// \note Throws ExeModuleException if the module cannot be released
    void close()
    {
        if(null !== m_hModule)
        {
            ExeModule_Release_(m_hModule, true);
        }
    }
/// @}

/// \name Accessors
/// @{
public:
    /// Retrieves the named symbol.
    ///
    /// \param symbolName The name of the symbol to load
    /// \return A pointer to the symbol. There is no null return - failure to retrieve the symbol
    /// results in an ExeModuleException exception being thrown.
    void *getSymbol(in char[] symbolName)
    {
        return ExeModule_GetSymbol_(m_hModule, symbolName, true);
    }

    /// Retrieves the named symbol.
    ///
    /// \param symbolName The name of the symbol to load
    /// \return A pointer to the symbol, or null if it does not exist. An exception is not thrown.
    void *findSymbol(in char[] symbolName)
    {
        return ExeModule_GetSymbol_(m_hModule, symbolName, false);
    }

/// @}

/// \name Properties
/// @{
public:
    /// The handle of the module
    ///
    /// \note The constructors throw rather than leave this null; it becomes
    /// \c null once close() has released the module
    HXModule handle()
    {
        return m_hModule;
    }
    /// The path of the module: the file path reported by the system on
    /// Windows, or the name the module was loaded by on Linux
    ///
    /// \note Must not be called after close(). Throws ExeModuleException if
    /// the path cannot be retrieved
    char[] path()
    {
        assert(null !== m_hModule);

        return ExeModule_GetPath_(m_hModule, true);
    }
/// @}

private:
    HXModule m_hModule;
};

/* ////////////////////////////////////////////////////////////////////////// */

version(TestMain)
{
    /// Test driver: loads the module named by args[1] and acquires the symbol
    /// named by args[2], reporting progress and any ExeModuleException
    int main(char[][] args)
    {
        if(args.length >= 3)
        {
            char[]  moduleName  =   args[1];
            char[]  symbolName  =   args[2];

            try
            {
                auto ExeModule xmod =   new ExeModule(moduleName);

                printf("\"%.*s\" is loaded\n", moduleName);

                void    *symbol =   xmod.getSymbol(symbolName);

                if(null != symbol)
                {
                    printf("\"%.*s\" is acquired\n", symbolName);
                }
                else
                {
                    throw new ExeModuleException(ExeModule_Error());
                }
            }
            catch(ExeModuleException x)
            {
                x.print();
            }
        }
        else
        {
            printf("USAGE: <moduleName> <symbolName>\n");
        }

        return 0;
    }
}

/* ////////////////////////////////////////////////////////////////////////// */
