Index: ntoskrnl/include/internal/mm.h =================================================================== --- ntoskrnl/include/internal/mm.h (revision 36186) +++ ntoskrnl/include/internal/mm.h (working copy) @@ -1549,6 +1549,7 @@ NTAPI MmCallDllInitialize( IN PLDR_DATA_TABLE_ENTRY LdrEntry, + IN PUNICODE_STRING ImportName, IN PLIST_ENTRY ListHead ); Index: ntoskrnl/mm/sysldr.c =================================================================== --- ntoskrnl/mm/sysldr.c (revision 36186) +++ ntoskrnl/mm/sysldr.c (working copy) @@ -366,17 +366,45 @@ NTSTATUS NTAPI MmCallDllInitialize(IN PLDR_DATA_TABLE_ENTRY LdrEntry, + IN PUNICODE_STRING ImportName, IN PLIST_ENTRY ListHead) { + UNICODE_STRING ServicesKeyName = RTL_CONSTANT_STRING(L"\\Registry\\Machine\\System\\CurrentControlSet\\Services\\"); PMM_DLL_INITIALIZE DllInit; + UNICODE_STRING RegPath; + NTSTATUS Status; /* Try to see if the image exports a DllInitialize routine */ DllInit = (PMM_DLL_INITIALIZE)MiLocateExportName(LdrEntry->DllBase, "DllInitialize"); if (!DllInit) return STATUS_SUCCESS; - /* FIXME: TODO */ - DPRINT1("DllInitialize not called!\n"); + /* Obtain the path to this dll's service in the registry */ + RegPath.MaximumLength = ServicesKeyName.Length + ImportName->Length + sizeof(UNICODE_NULL); + RegPath.Buffer = ExAllocatePoolWithTag(NonPagedPool, RegPath.MaximumLength, TAG_LDR_WSTR); + + if (RegPath.Buffer) + { + /* Setup the base length and copy it */ + RegPath.Length = ServicesKeyName.Length; + RtlCopyMemory(RegPath.Buffer, ServicesKeyName.Buffer, ServicesKeyName.Length); + + /* Build and append the service name itself */ + USHORT NameLength = ImportName->Length; + ImportName->Length = (wcschr(ImportName->Buffer, L'.') - ImportName->Buffer)*sizeof(WCHAR); + RtlAppendUnicodeStringToString (&RegPath, ImportName); + ImportName->Length = NameLength; + + /* Now call the DllInit func */ + DPRINT1("Calling DllInit(%wZ)\n", &RegPath); + Status = DllInit(&RegPath); + + /* Clean up */ + ExFreePool (RegPath.Buffer); + + return Status; + } + return STATUS_UNSUCCESSFUL; } @@ -517,7 +545,7 @@ MAXIMUM_FILENAME_LENGTH - 1); /* Setup name tables */ - DPRINT("Import name: %s\n", NameImport->Name); + DPRINT1("Import name: %s\n", NameImport->Name); NameTable = (PULONG)((ULONG_PTR)DllBase + ExportDirectory->AddressOfNames); OrdinalTable = (PUSHORT)((ULONG_PTR)DllBase + @@ -855,6 +883,19 @@ return STATUS_PROCEDURE_NOT_FOUND; } + if ((!_strnicmp(ImportName, "ntdll", sizeof("ntdll") - 1)) || + (!_strnicmp(ImportName, "winsrv", sizeof("winsrv") - 1)) || + (!_strnicmp(ImportName, "advapi32", sizeof("advapi32") - 1)) || + (!_strnicmp(ImportName, "kernel32", sizeof("kernel32") - 1)) || + (!_strnicmp(ImportName, "user32", sizeof("user32") - 1)) || + (!_strnicmp(ImportName, "gdi32", sizeof("gdi32") - 1)) ) + { + /* It's importing stuff it shouldn't be! */ + MiDereferenceImports(LoadedImports); + if (LoadedImports) ExFreePool(LoadedImports); + return STATUS_PROCEDURE_NOT_FOUND; + } + /* Check if this is a "core" import, which doesn't get referenced */ if (!(_strnicmp(ImportName, "ntoskrnl", sizeof("ntoskrnl") - 1)) || !(_strnicmp(ImportName, "win32k", sizeof("win32k") - 1)) || @@ -946,6 +987,7 @@ DllName.Buffer[(DllName.MaximumLength - 1) / 2] = UNICODE_NULL; /* Load the image */ + DPRINT1("Loading imports through image %wZ\n",&DllName); Status = MmLoadSystemImage(&DllName, NamePrefix, NULL, @@ -959,7 +1001,9 @@ } else { - /* Fill out the information for the error */ + DPRINT1("Failed to load %wZ\n", &DllName); + + /* Fill out the information for the error */ *MissingDriver = DllName.Buffer; *(PULONG)MissingDriver |= 1; *MissingApi = NULL; @@ -981,7 +1025,7 @@ ASSERT(DllBase = DllEntry->DllBase); /* Call the initialization routines */ - Status = MmCallDllInitialize(DllEntry, &PsLoadedModuleList); + Status = MmCallDllInitialize(DllEntry, &NameString, &PsLoadedModuleList); if (!NT_SUCCESS(Status)) { /* We failed, unload the image */ @@ -1075,6 +1119,8 @@ /* Check if we have an import list */ if (LoadedImports) { + DPRINT1("Loaded imports for %wZ\n", &DllName); + /* Reset the count again, and loop entries*/ ImportCount = 0; for (i = 0; i < LoadedImports->Count; i++)