From 2911da061b330f255a8c4b85c4ef4d72705286d3 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 5 Mar 2020 14:11:13 +0300 Subject: [PATCH] blas fallback (#291) Signed-off-by: raver119 Co-authored-by: raver119 --- libnd4j/include/helpers/impl/BlasHelper.cpp | 30 +++++++++++++++++++++ libnd4j/include/legacy/impl/Environment.cpp | 9 +++++++ libnd4j/include/system/Environment.h | 4 +++ 3 files changed, 43 insertions(+) diff --git a/libnd4j/include/helpers/impl/BlasHelper.cpp b/libnd4j/include/helpers/impl/BlasHelper.cpp index 0f270a97e..378c8a6f1 100644 --- a/libnd4j/include/helpers/impl/BlasHelper.cpp +++ b/libnd4j/include/helpers/impl/BlasHelper.cpp @@ -74,6 +74,9 @@ namespace sd { template <> bool BlasHelper::hasGEMV() { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -83,6 +86,9 @@ namespace sd { template <> bool BlasHelper::hasGEMV() { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -132,6 +138,9 @@ namespace sd { bool BlasHelper::hasGEMV(const sd::DataType dtype) { if(dtype == DataType::FLOAT32) { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -139,6 +148,9 @@ namespace sd { #endif } if(dtype == DataType::DOUBLE) { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -150,6 +162,9 @@ namespace sd { template <> bool BlasHelper::hasGEMM() { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -159,6 +174,9 @@ namespace sd { template <> bool BlasHelper::hasGEMM() { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -208,6 +226,9 @@ namespace sd { bool BlasHelper:: hasGEMM(const sd::DataType dtype) { if(dtype == DataType::FLOAT32) { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -215,6 +236,9 @@ namespace sd { #endif } if(dtype == DataType::DOUBLE) { + if (sd::Environment::getInstance()->blasFallback()) + return false; + #if defined(__EXTERNAL_BLAS__) || defined(HAVE_OPENBLAS) return true; #else @@ -227,11 +251,17 @@ namespace sd { template <> bool BlasHelper::hasBatchedGEMM() { + if (sd::Environment::getInstance()->blasFallback()) + return false; + return _hasSgemmBatch; } template <> bool BlasHelper::hasBatchedGEMM() { + if (sd::Environment::getInstance()->blasFallback()) + return false; + return _hasDgemmBatch; } diff --git a/libnd4j/include/legacy/impl/Environment.cpp b/libnd4j/include/legacy/impl/Environment.cpp index 491d33569..fae3a28dc 100644 --- a/libnd4j/include/legacy/impl/Environment.cpp +++ b/libnd4j/include/legacy/impl/Environment.cpp @@ -162,6 +162,11 @@ namespace sd { // still do nothing } } + + const char* blas_fallback = std::getenv("SD_BLAS_FALLBACK"); + if (blas_fallback != nullptr) { + _blasFallback = true; + } #endif #ifdef __CUDABLAS__ @@ -189,6 +194,10 @@ namespace sd { #endif } + bool sd::Environment::blasFallback() { + return _blasFallback; + } + sd::Environment::~Environment() { // } diff --git a/libnd4j/include/system/Environment.h b/libnd4j/include/system/Environment.h index 9a998d705..392e70871 100644 --- a/libnd4j/include/system/Environment.h +++ b/libnd4j/include/system/Environment.h @@ -51,6 +51,8 @@ namespace sd{ std::atomic _maxTotalSpecialMemory{-1}; std::atomic _maxDeviceMemory{-1}; + bool _blasFallback = false; + #ifdef __ND4J_EXPERIMENTAL__ const bool _experimental = true; #else @@ -85,6 +87,8 @@ namespace sd{ void setLeaksDetector(bool reallyDetect); bool helpersAllowed(); void allowHelpers(bool reallyAllow); + + bool blasFallback(); int tadThreshold(); void setTadThreshold(int threshold);