Revisions
This commit is contained in:
@@ -72,9 +72,13 @@ class CNNClassifier:
|
||||
self._model = model
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
def predict_proba(self, X):
|
||||
X_te = torch.tensor(X, dtype=torch.float32)
|
||||
self._model.eval()
|
||||
with torch.no_grad():
|
||||
predictions = self._model(X_te.to(self._device)).argmax(dim=1).cpu().numpy()
|
||||
return predictions
|
||||
logits = self._model(X_te.to(self._device))
|
||||
probabilities = torch.softmax(logits, dim=1).cpu().numpy()
|
||||
return probabilities
|
||||
|
||||
def predict(self, X):
|
||||
return self.predict_proba(X).argmax(axis=1)
|
||||
|
||||
@@ -34,7 +34,7 @@ class FeatureExtractor:
|
||||
}
|
||||
|
||||
|
||||
class HandPickedClassifier:
|
||||
class FeatureClassifier:
|
||||
def fit(self, X, y):
|
||||
self._pipeline = Pipeline([
|
||||
("features", FeatureExtractor()),
|
||||
@@ -46,3 +46,6 @@ class HandPickedClassifier:
|
||||
|
||||
def predict(self, X):
|
||||
return self._pipeline.predict(X)
|
||||
|
||||
def predict_proba(self, X):
|
||||
return self._pipeline.predict_proba(X)
|
||||
@@ -66,9 +66,13 @@ class MLPClassifier:
|
||||
self._model = model
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
def predict_proba(self, X):
|
||||
X_te = torch.tensor(X, dtype=torch.float32)
|
||||
self._model.eval()
|
||||
with torch.no_grad():
|
||||
predictions = self._model(X_te.to(self._device)).argmax(dim=1).cpu().numpy()
|
||||
return predictions
|
||||
logits = self._model(X_te.to(self._device))
|
||||
probabilities = torch.softmax(logits, dim=1).cpu().numpy()
|
||||
return probabilities
|
||||
|
||||
def predict(self, X):
|
||||
return self.predict_proba(X).argmax(axis=1)
|
||||
|
||||
@@ -9,3 +9,6 @@ class PixelClassifier:
|
||||
|
||||
def predict(self, X):
|
||||
return self._classifier.predict(X)
|
||||
|
||||
def predict_proba(self, X):
|
||||
return self._classifier.predict_proba(X)
|
||||
|
||||
Reference in New Issue
Block a user