When developing machine learning models in Python, ensuring they integrate smoothly with established tools can save significant time and effort.
The scikit-learn API has become a de facto standard for structuring estimators, predictors, and related components, making it easier to reuse code, test new ideas, and plug into existing workflows. In this post, we’ll explore what it means to create a scikit-learn–compatible estimator, break down its key building blocks, and walk through a concrete implementation using a perceptron example.
With reuse in mind, it makes sense to store the developed code in an implementation that is compatible with scikit-learn. In fact, there are templates for this purpose, and we follow them here. The following objects are the basic components of such an implementation, for example, the estimator object for learning and the predictor object for performing the classification:
Estimator
An estimator is an object that learns from data, for example, using a classification, regression, or clustering algorithm. Estimators derive from the base class sklearn.base.BaseEstimator and implement a fit method to learn from the data:
estimator = Perceptron.fit(data,targets)
Perceptron is the name of an object we’ll instantiate from our PerceptronEstimator class. When the fit method is called, it returns a reference to the object itself, so that reference can be assigned to the variable estimator.
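Because fit returns self, calls can also be chained. Here is a minimal sketch, assuming data, targets, and a test vector x are already defined as NumPy arrays:

# fit returns the estimator itself, so predict can be chained directly
prediction = PerceptronEstimator().fit(data, targets).predict(x)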
Predictor
The predictor object implements the analysis step; for supervised classification, this is the predict method:
prediction = Perceptron.predict(data)
Transformer
A transformer object would implement data filtering or data modification via a transform method; we don’t need one for our perceptron.
new_data = obj.transform(data)
Model
A model object would implement a score method that reports the quality of the fit according to a goodness-of-fit measure; we omit this as well.
score = obj.score(data, targets)
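To make the division of labor concrete, here is a minimal, hypothetical skeleton (the class name MyEstimator and the method bodies are placeholders, not the implementation we build below) showing where each of the four roles lives:

from sklearn.base import BaseEstimator, ClassifierMixin

class MyEstimator(BaseEstimator, ClassifierMixin):
    def fit(self, X, y):
        # Estimator: learn from the data, then return self
        return self

    def predict(self, X):
        # Predictor: analyze data, e.g., perform the classification
        ...

    def transform(self, X):
        # Transformer: filter or modify the data
        ...

    def score(self, X, y):
        # Model: quality according to a goodness-of-fit measure
        ...

Note that ClassifierMixin already contributes a default score method (mean accuracy), which is one reason we can get by without writing our own.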
We’ll now walk through the implementation of our estimator in separate parts, although you should think of the code as a single piece once the explanations are stripped away. The code itself should be easy to follow because you’ve already developed it; we’re merely distributing the familiar passages across the appropriate methods.
Create another code cell in your Jupyter Notebook and transfer the different coding passages one below the other into the cell, starting from here. Let’s start with the declaration:
- The new estimator we’re building inherits from the BaseEstimator and ClassifierMixin classes, as shown in this listing.
# Numpy helps us with the arrays
import numpy as np
# Graphical display
import matplotlib.pyplot as plt
# These are our basic classes
from sklearn.base import BaseEstimator, ClassifierMixin
# Check routines for the consistency of the data, etc.
from sklearn.utils.validation import check_X_y, check_is_fitted, check_random_state
# Buffering the different target values
from sklearn.utils.multiclass import unique_labels
# Very important, otherwise the plot will not be displayed
%matplotlib inline
# Our estimator, appropriately labeled, and the base classes
class PerceptronEstimator(BaseEstimator, ClassifierMixin):
- The classifier has three methods: initialize (__init__), learn (fit), and analyze (predict).
- The __init__ method shouldn’t accept any training data; the data should instead be passed to the fit method.
- It should also be possible to instantiate the estimator without parameters. This means that all parameters of the __init__ method require a default value.
- All parameters of the __init__ method must be stored in object attributes with the same name, as shown in the listing below.
    # Initialization
    def __init__(self, n_iterations=20, random_state=None):
        """ Initialization of the object
        n_iterations: Number of iterations for learning
        random_state: Seed for the random number generator; to
            guarantee repeatability, the fit method constructs a
            numpy.random.RandomState object from this seed
        """
        # The number of iterations
        self.n_iterations = n_iterations
        # The seed for the random generator
        self.random_state = random_state
        # Buffer the errors in the learning process for the plot
        self.errors = []
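Storing every parameter in an attribute of the same name is what makes the get_params and set_params machinery inherited from BaseEstimator work automatically. A quick check (assuming the class has been defined as above):

# BaseEstimator derives the parameter list from the __init__ signature
print(PerceptronEstimator().get_params())
# {'n_iterations': 20, 'random_state': None}

# set_params works the same way; grid search relies on exactly this
est = PerceptronEstimator().set_params(n_iterations=50)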
- We write the step function as before, but this time as a method rather than a lambda (see the listing below).
    # A step function named after the mathematician and physicist
    # Oliver Heaviside
    def heaviside(self, x):
        """ A step function
        x: The value at which the step function is evaluated
        """
        if x < 0:
            result = 0
        else:
            result = 1
        return result
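As an aside, NumPy ships a vectorized step function of its own; np.heaviside(x, 1.0) matches our method, since the second argument fixes the value at x = 0:

# np.heaviside(x, h0) returns 0 for x < 0, h0 for x == 0, and 1 for x > 0
print(np.heaviside(-0.5, 1.0))  # 0.0
print(np.heaviside(0.0, 1.0))   # 1.0
print(np.heaviside(2.0, 1.0))   # 1.0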
- If a random number generator (RNG) is used in the code, as it is in our example, numpy.random.random() shouldn’t be called directly; instead, use a numpy.random.RandomState object. For reasons of repeatability, especially if the fit method is called multiple times, the RNG should be created in the fit method. This works as follows: the __init__ method takes a parameter called random_state, which defaults to None, and saves it unchanged in an attribute of the same name.
- The fit method uses check_random_state to generate an RNG and saves it in an attribute called random_state_, as shown in this listing.
    # Learn
    def fit(self, X, y):
        """ Train
        X: Array-like structure with shape [N,D], where
           N = rows = number of learning examples and
           D = columns = number of features
        y: Array with shape [N], with N as above
        """
        # Check whether X and y have the correct shape: X.shape[0] = y.shape[0]
        X, y = check_X_y(X, y)
        # Generation of the random number generator (RNG)
        self.random_state_ = check_random_state(self.random_state)
        # Initialization of the weights
        # np.size(.,1) = number of columns
        self.w = self.random_state_.random_sample(np.size(X,1))
        # Save the unique target values
        self.classes_ = unique_labels(y)
        # Save learning data for later testing in the predict method
        self.X_ = X
        self.y_ = y
        # Reset the error buffer so that repeated fit calls start fresh
        self.errors = []
        # Learn
        for i in range(self.n_iterations):
            # Random pick, for batch size = 1
            # np.size(.,0) = number of rows
            rand_index = self.random_state_.randint(0, np.size(X,0))
            # A random input vector
            x_ = X[rand_index]
            # The matching output
            y_ = y[rand_index]
            # Determine the calculated output:
            # weighted sum with downstream step function
            y_hat = self.heaviside(np.dot(self.w, x_))
            # Calculate the error as the difference between the desired
            # and the current output
            error = y_ - y_hat
            # Collect the errors for the plot
            self.errors.append(error)
            # Weight adjustment = learning
            self.w += error * x_
        # Return the estimator to allow chained calls
        return self
We should say a few more words about the fit method. It receives the training data in the two-dimensional array X and the desired values in the array y. The check_X_y helper first verifies that the two arrays fit together. A new RNG is then generated from the initialization value (seed) random_state, and the weights are initialized with random values, one per column of the input. The learning data is stored in object attributes, and then the familiar learning iterations take place. That’s it for the learning method.
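If you’re curious what check_random_state actually does, here is a small sketch of its documented behavior:

from sklearn.utils import check_random_state
import numpy as np

# An integer seed yields a freshly seeded RandomState object
rng = check_random_state(10)
print(isinstance(rng, np.random.RandomState))  # True

# The same seed reproduces the same random sequence
print(check_random_state(10).randint(0, 100) ==
      check_random_state(10).randint(0, 100))  # True

# None falls back to NumPy's global RandomState (not reproducible)
rng_global = check_random_state(None)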
Now follows the last method, predict, as shown in the next listing, with which the analysis can take place. Here, too, scikit-learn provides support routines, such as check_is_fitted for verifying that learning has already taken place.
    # Analyze
    def predict(self, x):
        """ Analyze a vector
        x: A test input vector
        """
        # Check whether fit has already been called;
        # these attributes were set in the fit method
        check_is_fitted(self, ['X_', 'y_'])
        # Analyze, forward path
        y_hat = self.heaviside(np.dot(self.w, x))
        return y_hat
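One caveat: the scikit-learn convention is that predict accepts a two-dimensional array of samples and returns one prediction per row, whereas our method analyzes a single vector. A vectorized variant (a sketch, not part of the implementation above) could look like this:

    # Hypothetical vectorized variant: one prediction per row of X
    def predict(self, X):
        check_is_fitted(self, ['X_', 'y_'])
        # Weighted sums for all rows at once, then the step function
        return np.where(np.dot(X, self.w) < 0, 0, 1)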
Of course, we can also implement methods in the estimator class that can be useful in addition to the standard methods, for example, a method for outputting the error curve. We’ve called this method plot (see below).
    # Plot
    def plot(self):
        """ Output of the errors
        Output the errors stored in the error array as a graphic
        """
        # Grid style; must be set before the figure is created
        plt.style.use('seaborn-v0_8-whitegrid')
        # Figure numbers start at 1
        fignr = 1
        # Figure size in inches
        plt.figure(fignr, figsize=(10,10))
        # Output the errors as a plot
        plt.plot(self.errors)
        # Labels
        plt.xlabel('Iteration')
        plt.ylabel(r"$(y - \hat y)$")
This completes the estimator class, and we’re now ready to use it. First, the training data is set up in the main() function, which we only create to better structure the script; then our perceptron estimator gets instantiated. We’ve set the number of iterations to 30 and selected 10 as the initial value for the RNG. The fit method receives the training data for learning, and our estimator can then be used for evaluation, as shown in this listing.
def main():
    # Training data
    X = np.array([[1,0,0], [1,0,1], [1,1,0], [1,1,1]])
    y = np.array([0,1,1,1])
    # Learn
    Perceptron = PerceptronEstimator(n_iterations=30, random_state=10)
    Perceptron.fit(X, y)
    # Analysis after training: check every training example
    print("Analysis at the end of the training:")
    for index, x in enumerate(X):
        p = Perceptron.predict(x)
        print("{}: {} -> {}".format(x, y[index], p))
    # Output graph
    Perceptron.plot()

# Main program
main()
The predict method analyzes the training data correctly, and the plot also shows what it should show (see below): Success!
# Output:
Analysis at the end of the training:
[1 0 0]: 0 -> 0
[1 0 1]: 1 -> 1
[1 1 0]: 1 -> 1
[1 1 1]: 1 -> 1
The figure below shows the graphical output.
Conclusion
By following the scikit-learn estimator guidelines—structuring your code into clear, consistent methods and inheriting from the right base classes—you gain compatibility with scikit-learn’s ecosystem of tools, from cross-validation and pipelines to model persistence. The perceptron example here demonstrates how to implement __init__, fit, predict, and supporting methods in a reusable, standards-compliant way. Once you understand this pattern, you can adapt it to a wide variety of algorithms, making your machine learning code more modular, maintainable, and easy to integrate into larger projects.
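If you want to check how closely an estimator follows these conventions, scikit-learn also ships a conformance test harness, sklearn.utils.estimator_checks.check_estimator. Our perceptron, with its single-vector predict, would likely need small adaptations to pass every check, so treat this as a pointer rather than a guarantee:

from sklearn.utils.estimator_checks import check_estimator

# Runs scikit-learn's suite of API-conformance checks against an instance
check_estimator(PerceptronEstimator())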
Editor’s note: This post has been adapted from a section of the book Programming Neural Networks with Python by Joachim Steinwendner and Roland Schwaiger. Dr. Steinwendner is a scientific project leader specializing in data science, machine learning, recommendation systems, and deep learning. Dr. Schwaiger is a software developer, freelance trainer, and consultant. He has a PhD in mathematics and he has spent many years working as a researcher in the development of artificial neural networks, applying them in the field of image recognition.
This post was originally published 8/2025.