
What Is a scikit-learn-Compatible Estimator?

When developing machine learning models in Python, ensuring they integrate smoothly with established tools can save significant time and effort.

 

The scikit-learn API has become a de facto standard for structuring estimators, predictors, and related components, making it easier to reuse code, test new ideas, and plug into existing workflows. In this post, we’ll explore what it means to create a scikit-learn–compatible estimator, break down its key building blocks, and walk through a concrete implementation using a perceptron example.

 

With reuse in mind, it makes sense to package the code we develop in a scikit-learn-compatible implementation. There are in fact templates for this, which we’ve drawn on here. The basic components of such an implementation are the following objects, above all the estimator object for learning and the predictor object for performing the classification:

 

Estimator

An estimator is an object that learns from data, for example, using a classification, regression, or clustering algorithm. An estimator derives from the base class sklearn.base.BaseEstimator and implements the fit method to learn from the data:

 

estimator = Perceptron.fit(data, targets)

 

Perceptron is the name of the object we’ll instantiate from the PerceptronEstimator class. When the fit method is called, it returns a reference to the object itself, so the return value can be assigned to the variable estimator.
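Since fit returns the estimator itself, calls can also be chained. A minimal sketch (data, targets, and sample are placeholder variables, not defined in this post):

# fit returns self, so we can chain straight into predict
prediction = PerceptronEstimator().fit(data, targets).predict(sample)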

 

Predictor

The predictor object implements, for example, the classification method for supervised learning:

 

prediction = Perceptron.predict(data)

 

Transformer

A transformer object would implement data filtering or data transformations; we won’t need one for our example.

 

new_data = obj.transform(data)
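For illustration only, since we don’t build one here, a minimal transformer could look like the following sketch; the class ScaleTransformer and its factor parameter are made up for this example, and TransformerMixin contributes fit_transform for free:

from sklearn.base import BaseEstimator, TransformerMixin

# A hypothetical minimal transformer: scales features by a constant factor
class ScaleTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, factor=1.0):
        self.factor = factor
    def fit(self, X, y=None):
        # Nothing to learn for a fixed scaling
        return self
    def transform(self, X):
        return X * self.factor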

 

Model

A model object can report its quality via a goodness-of-fit measure through the score method; we likewise omit a custom implementation here.

 

score = obj.score(data, targets)
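In fact, ClassifierMixin already supplies a default score method, the mean accuracy computed from predict, so a compatible classifier gets scoring for free, provided its predict accepts a two-dimensional array of samples (the perceptron below processes one vector at a time). A sketch with placeholder variables:

# Mean accuracy via the ClassifierMixin default score method
# (assumes predict accepts a 2D array of samples)
accuracy = estimator.score(data, targets)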

 

We’ll now break the implementation of our estimator into separate parts for the explanation, although you should think of the code as one piece without the interleaved commentary. You should also find the code easy to follow because you’ve already developed it; we’re merely distributing the passages across the appropriate methods.

 

Create another code cell in your Jupyter Notebook and transfer the coding passages one below the other into the cell, starting here. We begin with the declaration:

  • The new estimator we’re building inherits from the BaseEstimator and ClassifierMixin classes, as shown in this listing.

# NumPy helps us with the arrays
import numpy as np
# Graphical display
import matplotlib.pyplot as plt
# These are our base classes
from sklearn.base import BaseEstimator, ClassifierMixin
# Check routines for the consistency of the data, etc.
from sklearn.utils.validation import check_X_y, check_is_fitted, check_random_state
# Collecting the unique target values
from sklearn.utils.multiclass import unique_labels
# Very important, otherwise the plot will not be displayed
%matplotlib inline

# Our estimator, appropriately named, with its base classes
class PerceptronEstimator(BaseEstimator, ClassifierMixin):

  • The classifier has three methods: initialize (__init__), learn (fit), and analyze (predict).
  • The __init__ method shouldn’t accept any training data; training data should instead be passed to the fit method.
  • It should also be possible to instantiate the estimator without arguments, which means that every parameter of the __init__ method requires a default value.
  • All parameters of the __init__ method must be stored in object attributes with the same names, as shown in the listing below.

    # Initialization
    def __init__(self, n_iterations=20, random_state=None):
        """ Initialization of the object

        n_iterations: Number of iterations for learning
        random_state: To guarantee repeatability, a
              numpy.random.RandomState object is constructed
              in fit, initialized with this seed
        """
        # The number of iterations
        self.n_iterations = n_iterations
        # The seed for the random generator
        self.random_state = random_state
        # Buffer the errors during learning for the plot
        self.errors = []
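Storing the constructor arguments unchanged under the same names is what allows BaseEstimator’s get_params, and with it sklearn.base.clone and the model selection tools, to introspect the estimator. Once the class is complete, a quick check might look like this sketch:

from sklearn.base import clone

p = PerceptronEstimator(n_iterations=30, random_state=10)
# get_params reads the __init__ signature and the matching attributes
print(p.get_params())   # {'n_iterations': 30, 'random_state': 10}
# clone builds a fresh, unfitted copy with the same parameters
p2 = clone(p)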

  • We write the step function as before, but as a method rather than a lambda (see the listing below).

    # A step function, named after the mathematician and physicist
    # Oliver Heaviside
    def heaviside(self, x):
        """ A step function

        x: The value at which the step function is evaluated
        """
        if x < 0:
            result = 0
        else:
            result = 1
        return result
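As an aside, NumPy ships a vectorized np.heaviside; with the value at zero set to 1, it matches the method above. We stick with the explicit method, but for reference:

# Vectorized alternative (the second argument is the value at x = 0,
# set to 1 to match our method)
vals = np.heaviside(np.array([-2.0, 0.0, 3.0]), 1)   # -> [0., 1., 1.]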

  • If a random number generator (RNG) is used in the code, as in our example, numpy.random.random() shouldn’t be used; use numpy.random.RandomState instead. For repeatability, especially when the fit method is called multiple times, the RNG should be created in the fit method. This works as follows: the __init__ method takes a parameter called random_state, defaults it to None, and saves it unchanged in an attribute of the same name.
  • The fit method uses check_random_state to create the RNG and saves it in an attribute called random_state_, as shown in this listing.

    # Learn
    def fit(self, X, y):
        """ Train

        X: Array-like structure with shape [N, D], where
           N = rows = number of training examples and
           D = columns = number of features
        y: Array with shape [N], with N as above
        """
        # Check whether X and y have consistent shapes:
        # X.shape[0] == y.shape[0]
        X, y = check_X_y(X, y)
        # Create the random number generator (RNG)
        self.random_state_ = check_random_state(self.random_state)
        # Initialize the weights
        # np.size(., 1) = number of columns
        self.w = self.random_state_.random_sample(np.size(X, 1))
        # Save the unique target values
        self.classes_ = unique_labels(y)
        # Save the training data for the later check in the predict method
        self.X_ = X
        self.y_ = y
        # Learn
        for i in range(self.n_iterations):
            # Draw a random example (random mixing, batch size = 1)
            # np.size(., 0) = number of rows
            rand_index = self.random_state_.randint(0, np.size(X, 0))
            # A random input vector
            x_ = X[rand_index]
            # The matching output
            y_ = y[rand_index]
            # Determine the calculated output:
            # weighted sum with downstream step function
            y_hat = self.heaviside(np.dot(self.w, x_))
            # The error is the difference between the desired and
            # the calculated output
            error = y_ - y_hat
            # Collect the errors for the plot
            self.errors.append(error)
            # Weight adjustment = learning
            self.w += error * x_
        # Return the estimator to allow chained calls
        return self

 

We should say a few more words about the fit method. It receives the training data in the two-dimensional array X and the desired values in the array y. The check_X_y function verifies that X and y fit together. A new RNG is then created from the seed random_state, and the weights are initialized to match the length of an input vector. The training data is stored in object attributes, and then the familiar learning iterations take place. That’s it for the learning method.
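Because check_random_state turns the stored seed into a fresh RandomState on every call, refitting with the same random_state reproduces the run exactly. A quick check, as a sketch with X and y as in the main() function further below:

p1 = PerceptronEstimator(n_iterations=30, random_state=10).fit(X, y)
p2 = PerceptronEstimator(n_iterations=30, random_state=10).fit(X, y)
# Same seed, same sequence of random draws, same weights
print(np.allclose(p1.w, p2.w))   # True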

 

Now follows the last required method, predict, as shown in the next listing; with it, the analysis can take place. Here, too, scikit-learn provides support functions, such as a check of whether learning has already taken place.

 

    # Analyze
    def predict(self, x):
        """ Analyze a vector

        x: A test input vector
        """
        # Check whether fit has already been called;
        # the attributes were set in the fit method
        check_is_fitted(self, ['X_', 'y_'])
        # Analyze: the forward path
        y_hat = self.heaviside(np.dot(self.w, x))
        return y_hat
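If predict is called before fit, check_is_fitted raises a NotFittedError, which we can verify with a small sketch:

from sklearn.exceptions import NotFittedError

try:
    PerceptronEstimator().predict(np.array([1, 0, 0]))
except NotFittedError as err:
    print(err)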

 

Of course, we can also implement methods in the estimator class that can be useful in addition to the standard methods, for example, a method for outputting the error curve. We’ve called this method plot (see below).

 

    # Plot
    def plot(self):
        """ Output the errors

        Output the errors stored in the errors list as a graphic
        """
        # Use a style with grid lines (set before plotting)
        plt.style.use('seaborn-v0_8-whitegrid')
        # Figure numbering starts at 1
        fignr = 1
        # Figure size in inches
        plt.figure(fignr, figsize=(10, 10))
        # Output the errors as a plot
        plt.plot(self.errors)
        # Labels
        plt.xlabel('Iteration')
        plt.ylabel(r"$(y - \hat y)$")

 

This completes the estimator class, and we’re now ready to use it. First, the training data is set up in the main() function, which we create only to structure the script better; then our perceptron estimator gets instantiated. We’ve set the number of iterations to 30 and chosen 10 as the seed for the RNG. The fit method receives the training data for learning, and our estimator can then be used for evaluation, as shown in this listing.

 

def main():
    # Training data
    X = np.array([[1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])
    y = np.array([0, 1, 1, 1])
    # Learn
    Perceptron = PerceptronEstimator(n_iterations=30, random_state=10)
    Perceptron.fit(X, y)
    # Analysis after training, one test vector at a time
    print("Analysis at the end of the training:")
    for index, x in enumerate(X):
        p = Perceptron.predict(x)
        print("{}: {} -> {}".format(x, y[index], p))
    # Output the graph
    Perceptron.plot()

# Main program
main()

 

The predict method analyzes the training data correctly, and the plot also shows what it should show (see below): Success!

 

# Output:
Analysis at the end of the training:
[1 0 0]: 0 -> 0
[1 0 1]: 1 -> 1
[1 1 0]: 1 -> 1
[1 1 1]: 1 -> 1

 

The figure below shows the graphical output.

[Figure: Differences in the Output of the Perceptron Estimator]

 

Conclusion

By following the scikit-learn estimator guidelines—structuring your code into clear, consistent methods and inheriting from the right base classes—you gain compatibility with scikit-learn’s ecosystem of tools, from cross-validation and pipelines to model persistence. The perceptron example here demonstrates how to implement __init__, fit, predict, and supporting methods in a reusable, standards-compliant way. Once you understand this pattern, you can adapt it to a wide variety of algorithms, making your machine learning code more modular, maintainable, and easy to integrate into larger projects.

 

Editor’s note: This post has been adapted from a section of the book Programming Neural Networks with Python by Joachim Steinwendner and Roland Schwaiger. Dr. Steinwendner is a scientific project leader specializing in data science, machine learning, recommendation systems, and deep learning. Dr. Schwaiger is a software developer, freelance trainer, and consultant. He has a PhD in mathematics and has spent many years working as a researcher in the development of artificial neural networks, applying them in the field of image recognition.

 

This post was originally published 8/2025.
