Updates
This commit is contained in:
@@ -3,16 +3,26 @@ import os
|
||||
import joblib
|
||||
|
||||
MODEL_FILE = "model.joblib"
|
||||
WEIGHTS_DIR = "weights"
|
||||
|
||||
|
||||
def _resolve(name):
|
||||
if name.startswith(WEIGHTS_DIR + os.sep) or name.startswith(WEIGHTS_DIR + "/"):
|
||||
return name
|
||||
return os.path.join(WEIGHTS_DIR, name)
|
||||
|
||||
|
||||
def is_saved_model(path):
|
||||
return os.path.isdir(path) and os.path.exists(os.path.join(path, MODEL_FILE))
|
||||
directory = _resolve(path)
|
||||
return os.path.isdir(directory) and os.path.exists(os.path.join(directory, MODEL_FILE))
|
||||
|
||||
|
||||
def save_model(clf, directory):
|
||||
def save_model(clf, name):
|
||||
directory = _resolve(name)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
joblib.dump(clf, os.path.join(directory, MODEL_FILE))
|
||||
|
||||
|
||||
def load_model(directory):
|
||||
def load_model(path):
|
||||
directory = _resolve(path)
|
||||
return joblib.load(os.path.join(directory, MODEL_FILE))
|
||||
|
||||
Reference in New Issue
Block a user