# include <iostream>	// for std::cout
# include <algorithm>	// for std::for_each, ...
# include <stdexcept>	// for std::runtime_error, std::exception
# include <string>	// for std::string, std::wstring

# include <comstl_coll_sequence.h>	// for comstl::collection_sequence
# include <comstl_enum_sequence.h>	// for comstl::enumerator_sequence
# include <comstl_value_policies.h>	// for interface_policy
# include <comstl_interface_cast.h>	// for interface casts

# include "dia2.h"

struct com_initializer
{
	com_initializer()
	{ 
		::CoInitialize(NULL); 
	}

	~com_initializer()
	{ 
		::CoUninitialize(); 
	}
} com_initializer;

void throwError() 
{
  // get the error
  DWORD lastError = GetLastError();
  if (lastError == S_OK)
	  throw std::runtime_error("Unknown error");
  
  // Get the error text
  LPVOID lpMsgBuf;
  FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | 
                FORMAT_MESSAGE_FROM_SYSTEM | 
                FORMAT_MESSAGE_IGNORE_INSERTS,
                NULL,
                lastError,
                MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), 
                (LPTSTR) &lpMsgBuf,
                0, NULL);
  std::string errStr = (char*) lpMsgBuf;
  LocalFree( lpMsgBuf );

  // throw it!
  throw std::runtime_error(errStr);
}

void openPdbFile(std::wstring filename, 
                 IDiaDataSource** diaDataSource, IDiaSession** diaSession, 
                 IDiaSymbol** globalScope) 
{
	if (FAILED(CoCreateInstance( CLSID_DiaSource, NULL, CLSCTX_INPROC_SERVER, 
								__uuidof( IDiaDataSource ), (void **) diaDataSource))) 
		 throwError();
	
  // Load the PDB file
  if (FAILED((*diaDataSource)->loadDataFromPdb(filename.c_str())))
    throwError();

  // Open a session
  if (FAILED((*diaDataSource)->openSession(diaSession)))
    throwError();

  // Get the global scope
  if (FAILED((*diaSession)->get_globalScope(globalScope)))
    throwError();
};

// Returns the name of a DIA symbol.
//
// When the name contains "__unnamed", the symbol's index id is appended in
// hexadecimal as "_<id>_" to make up a distinguishable name.
//
// Calls throwError() (which throws std::runtime_error) if the name or the
// index id cannot be retrieved, or if the symbol reports no name.
std::wstring _getName(IDiaSymbol* symbol) {
  // get the name
  BSTR nameB = NULL;
  if (FAILED(symbol->get_name(&nameB)) || !nameB)
    throwError();

  // Copy the text, then release the BSTR with the BSTR deallocator
  // (a BSTR is not LocalAlloc memory, so LocalFree must not be used).
  std::wstring name = nameB;
  ::SysFreeString(nameB);

  // if the name is "__unnamed", make up a name
  if (name.find(L"__unnamed") != std::wstring::npos) {
    // get unique type code
    DWORD typeId = 0;
    if (FAILED(symbol->get_symIndexId(&typeId)))
      throwError();

    // Format only the fixed-size suffix ("_" + up to 8 hex digits + "_")
    // into the local buffer and append it, so an arbitrarily long symbol
    // name can never overflow the buffer.
    wchar_t suffix[16];
    wsprintfW(suffix, L"_%x_", typeId);
    name += suffix;
  }

  // return the name
  return name;
}
#if 0
// Disabled (#if 0) alternative: traverse the symbols with
// comstl::collection_sequence, passing IDiaEnumSymbols as both of its first
// two template arguments and IDiaSymbol* as the element and reference types.
typedef comstl::collection_sequence<IDiaEnumSymbols, IDiaEnumSymbols, IDiaSymbol*, 
			comstl::interface_policy<IDiaSymbol>, IDiaSymbol*> dia_symbol_traversal_type;

// Writes the name of the given symbol, as produced by _getName(), to
// std::wcout on a line of its own and flushes the stream.
void dump_symbol(IDiaSymbol *com_ptr)
{
	const std::wstring text = _getName(com_ptr);

	std::wcout << text.c_str();
	std::wcout << std::endl;
}
#else /* ? 0 */

//# define USE_COMSTL_COLLECTION_SEQUENCE
# define USE_COMSTL_ENUMERATOR_SEQUENCE

# ifdef USE_COMSTL_COLLECTION_SEQUENCE
// Traversal type used when USE_COMSTL_COLLECTION_SEQUENCE is defined:
// a comstl::collection_sequence over IDiaEnumSymbols whose elements are
// VARIANTs obtained through IEnumVARIANT (see the matching dump_symbol).
typedef comstl::collection_sequence<IDiaEnumSymbols, IEnumVARIANT, VARIANT, 
			comstl::VARIANT_policy> dia_symbol_traversal_type;

// Writes the name of the symbol carried by a VARIANT element to std::wcout
// on a line of its own and flushes the stream. The VARIANT's punkVal member
// is cast to IDiaSymbol* (without adding a lasting reference) and the
// resulting pointer is handed to _getName().
void dump_symbol(VARIANT const &var)
{
	// The cast object is a temporary that lives until the end of this
	// statement, i.e. for the whole duration of the _getName() call.
	const std::wstring symbolName =
		_getName(stlsoft::get_ptr(comstl::interface_cast_noaddref<IDiaSymbol*>(var.punkVal)));

	std::wcout << symbolName.c_str() << std::endl;
}

# elif defined(USE_COMSTL_ENUMERATOR_SEQUENCE)

// Traversal type used when USE_COMSTL_ENUMERATOR_SEQUENCE is defined (the
// configuration selected above): a comstl::enumerator_sequence that adapts
// IDiaEnumSymbols directly and yields IDiaSymbol* elements.
typedef comstl::enumerator_sequence<IDiaEnumSymbols, IDiaSymbol*, 
			comstl::interface_policy<IDiaSymbol> > dia_symbol_traversal_type;

// Prints one symbol: looks up its name with _getName() and writes it to
// std::wcout followed by std::endl (newline plus flush).
void dump_symbol(IDiaSymbol *com_ptr)
{
	const std::wstring symbolName = _getName(com_ptr);
	std::wcout << symbolName.c_str() << std::endl;
}

# else /* ? sequence type */
#  error Need to define one of USE_COMSTL_COLLECTION_SEQUENCE or USE_COMSTL_ENUMERATOR_SEQUENCE
# endif /* sequence type */

#endif /* 0 */




// Opens the PDB file named below, enumerates the user-defined-type (SymTagUDT)
// children of its global scope and prints each symbol's name.
//
// Returns 0 on success and 1 if the PDB cannot be opened, the children cannot
// be enumerated, or a symbol name cannot be read. Every COM interface obtained
// here is released before returning, on both the success and failure paths.
int main (void)
{
	// Set the PDB file name here.
	std::wstring pdbFilename(L"MSDIA71.PDB");
	IDiaDataSource *diaDataSource = 0;
	IDiaSession *diaSession = 0;
	IDiaSymbol *globalScope = 0;
	IDiaEnumSymbols *itemsEnum = 0;
	int exitCode = 0;

	try
	{
		openPdbFile(pdbFilename, &diaDataSource, &diaSession, &globalScope);

		if (FAILED(globalScope->findChildren(SymTagUDT, NULL, nsNone, &itemsEnum)) || !itemsEnum)
		{
			// Without an enumerator there is nothing to traverse.
			std::cout << "couldn't get children" << std::endl;
			exitCode = 1;
		}
		else
		{
			// Passing true makes the sequence take its own reference on the
			// enumerator, so the reference held in itemsEnum is still ours
			// to release below.
			dia_symbol_traversal_type symbol_traversal(itemsEnum, true);

			// Known problem (noted by the original author): obtaining begin()
			// fails inside the iterator because a QueryInterface call
			// returns E_NOINTERFACE.
			std::for_each(symbol_traversal.begin(), symbol_traversal.end(), dump_symbol);
		}
	}
	catch (std::exception const &e)
	{
		std::cerr << e.what() << std::endl;
		exitCode = 1;
	}

	// Release in reverse order of acquisition. This happens inside main(), so
	// COM is still initialized (the global com_initializer outlives main).
	if (itemsEnum)
		itemsEnum->Release();
	if (globalScope)
		globalScope->Release();
	if (diaSession)
		diaSession->Release();
	if (diaDataSource)
		diaDataSource->Release();

	return exitCode;
}