diff --git a/libnd4j/include/helpers/mman.h b/libnd4j/include/helpers/mman.h index 484832beb..618ee23c3 100644 --- a/libnd4j/include/helpers/mman.h +++ b/libnd4j/include/helpers/mman.h @@ -138,13 +138,6 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) { OffsetType off = 0; int prot = PROT_READ | PROT_WRITE; - // we need to convert long path (probably) to short pat (actually) - // it's Windows API, in the middle of 2018! - auto sz = GetShortPathName(fileName, nullptr, 0); - - auto shortName = new TCHAR[sz]; - GetShortPathName(fileName, shortName, sz); - #ifdef _MSC_VER #pragma warning(push) #pragma warning(disable: 4293) @@ -170,8 +163,6 @@ void _mmap(Nd4jLong* result, size_t length, const char *fileName) { h = CreateFileA(fileName, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_WRITE | FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr); - delete[] shortName; - if (h == INVALID_HANDLE_VALUE) { errno = __map_mman_error(GetLastError(), EPERM); nd4j_printf("Error code: %i\n", (int) errno); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java index 9fa23bf45..23aa3533a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java @@ -297,8 +297,12 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace { } protected void init() { + // in case of MMAP we don't want any learning applied + if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP && workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE) + throw new IllegalArgumentException("Workspace backed by memory-mapped file can't have LearningPolicy defined"); + // we don't want overallocation in case of MMAP - if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) { + if (currentSize.get() > 0 && workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) { if (!isOver.get()) { if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE && workspaceConfiguration.getOverallocationLimit() > 0) { @@ -309,7 +313,6 @@ public abstract class Nd4jWorkspace implements MemoryWorkspace { if (workspaceConfiguration.getMaxSize() > 0 && currentSize.get() > workspaceConfiguration.getMaxSize()) currentSize.set(workspaceConfiguration.getMaxSize()); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java index be054a273..1baea4bc9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java @@ -159,7 +159,7 @@ public abstract class BasicWorkspaceManager implements MemoryWorkspaceManager { if (workspace == null || workspace instanceof DummyWorkspace) return; - //workspace.destroyWorkspace(); + workspace.destroyWorkspace(true); backingMap.get().remove(workspace.getId()); } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java index ae21a82cb..05c09ae85 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspace.java @@ -138,6 +138,13 @@ public class CpuWorkspace extends Nd4jWorkspace implements Deallocatable { } } + protected long mappedFileSize() { + if (workspaceConfiguration.getPolicyLocation() != LocationPolicy.MMAP) + return 0; + + return tempFile.length(); + } + @Override protected void clearExternalAllocations() { if (isDebug.get()) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java index 40553448a..436244ff8 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java @@ -18,11 +18,15 @@ package org.nd4j.linalg.cpu.nativecpu.workspace; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import org.bytedeco.javacpp.LongPointer; +import org.bytedeco.javacpp.Pointer; import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.memory.enums.LocationPolicy; import org.nd4j.linalg.api.memory.enums.MemoryKind; import org.nd4j.linalg.api.memory.pointers.PointersPair; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.primitives.Pair; +import org.nd4j.nativeblas.NativeOpsHolder; import java.util.List; import java.util.Queue; @@ -37,12 +41,16 @@ public class CpuWorkspaceDeallocator implements Deallocator { private Queue pinnedPointers; private List externalPointers; private LocationPolicy location; + private Pair mmapInfo; public CpuWorkspaceDeallocator(@NonNull CpuWorkspace workspace) { this.pointersPair = workspace.workspace(); this.pinnedPointers = workspace.pinnedPointers(); this.externalPointers = workspace.externalPointers(); this.location = workspace.getWorkspaceConfiguration().getPolicyLocation(); + + if (workspace.mappedFileSize() > 0) + this.mmapInfo = Pair.makePair(workspace.mmap, workspace.mappedFileSize()); } @Override @@ -50,7 +58,7 @@ public class CpuWorkspaceDeallocator implements Deallocator { log.trace("Deallocating CPU workspace"); // purging workspace planes - if (pointersPair != null) { + if (pointersPair != null && (pointersPair.getDevicePointer() != null || pointersPair.getHostPointer() != null)) { if (pointersPair.getDevicePointer() != null) { Nd4j.getMemoryManager().release(pointersPair.getDevicePointer(), MemoryKind.DEVICE); } @@ -58,6 +66,8 @@ public class CpuWorkspaceDeallocator implements Deallocator { if (pointersPair.getHostPointer() != null) { if (location != LocationPolicy.MMAP) Nd4j.getMemoryManager().release(pointersPair.getHostPointer(), MemoryKind.HOST); + else + NativeOpsHolder.getInstance().getDeviceNativeOps().munmapFile(null, mmapInfo.getFirst(), mmapInfo.getSecond()); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index 0c43ff9ca..5b357d9a4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -930,6 +930,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { WorkspaceConfiguration mmap = WorkspaceConfiguration.builder() .initialSize(1000000) .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) .build(); MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2"); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java index 8613d60a2..aefbafe53 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/SpecialWorkspaceTests.java @@ -344,6 +344,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { .initialSize(200 * 1024L * 1024L) // 200mbs .tempFilePath(tmpFile.toAbsolutePath().toString()) .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) .build(); try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { @@ -373,6 +374,7 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { .initialSize(200 * 1024L * 1024L) // 200mbs .tempFilePath(tmpFile.toAbsolutePath().toString()) .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) .build(); try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { @@ -380,6 +382,49 @@ public class SpecialWorkspaceTests extends BaseNd4jTest { } } + @Test + public void testDeleteMappedFile_1() throws Exception { + if (!Nd4j.getEnvironment().isCPU()) + return; + + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .policyLearning(LearningPolicy.NONE) + .build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + val x = Nd4j.rand(DataType.FLOAT, 1024); + } + + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + + Files.delete(tmpFile); + } + + @Test(expected = IllegalArgumentException.class) + public void testDeleteMappedFile_2() throws Exception { + if (!Nd4j.getEnvironment().isCPU()) + throw new IllegalArgumentException("Don't try to run on CUDA"); + + val tmpFile = Files.createTempFile("some", "file"); + val mmap = WorkspaceConfiguration.builder() + .initialSize(200 * 1024L * 1024L) // 200mbs + .tempFilePath(tmpFile.toAbsolutePath().toString()) + .policyLocation(LocationPolicy.MMAP) + .build(); + + try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(mmap, "M2")) { + val x = Nd4j.rand(DataType.FLOAT, 1024); + } + + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + + Files.delete(tmpFile); + } + @Override public char ordering() { return 'c';