Revisions
This commit is contained in:
@@ -66,9 +66,13 @@ class MLPClassifier:
|
||||
self._model = model
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
def predict_proba(self, X):
|
||||
X_te = torch.tensor(X, dtype=torch.float32)
|
||||
self._model.eval()
|
||||
with torch.no_grad():
|
||||
predictions = self._model(X_te.to(self._device)).argmax(dim=1).cpu().numpy()
|
||||
return predictions
|
||||
logits = self._model(X_te.to(self._device))
|
||||
probabilities = torch.softmax(logits, dim=1).cpu().numpy()
|
||||
return probabilities
|
||||
|
||||
def predict(self, X):
|
||||
return self.predict_proba(X).argmax(axis=1)
|
||||
|
||||
Reference in New Issue
Block a user