How do you transform a quadrilateral area of a BufferedImage into a rectangular BufferedImage in Java?

I am trying to undo the perspective distortion of a rectangle that, viewed in 3D, appears as a quadrilateral. Here is an example image that I would like to process:

[Image: example input]

I know the coordinates of the 4 corners of the quadrilateral in the image.

I have been experimenting with AffineTransform, specifically the shear method. However, I cannot find any good information on how to determine the shx and shy values for an arbitrary quadrilateral.
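On the shear question: AffineTransform (and therefore shear) can only represent affine maps, which send parallelograms to parallelograms. An affine map is fully fixed once three corners are chosen, so the fourth corner cannot be placed independently. The sketch below (the class and method names are made up for illustration) builds the affine map from three of the corners used in the answer's code and shows where it forces the fourth corner to land.

```java
import java.awt.geom.AffineTransform;
import java.awt.geom.Point2D;

// Illustration only: an affine map is determined by three point pairs,
// so the image of the fourth corner is forced, not chosen.
class AffineFourthCorner {

    /** Affine map sending (0,0), (1,0), (0,1) to p0, p1, p3; returns where it sends (1,1). */
    static Point2D fourthCorner(Point2D p0, Point2D p1, Point2D p3) {
        // Constructor order is (m00, m10, m01, m11, m02, m12):
        // x' = m00*x + m01*y + m02,  y' = m10*x + m11*y + m12
        AffineTransform at = new AffineTransform(
                p1.getX() - p0.getX(), p1.getY() - p0.getY(),
                p3.getX() - p0.getX(), p3.getY() - p0.getY(),
                p0.getX(), p0.getY());
        return at.transform(new Point2D.Double(1, 1), null);
    }

    public static void main(String[] args) {
        // Three corners of the quad; its actual fourth corner is (816, 574).
        Point2D p = fourthCorner(new Point2D.Double(439, 42),
                new Point2D.Double(841, 3), new Point2D.Double(472, 683));
        System.out.println(p.getX() + ", " + p.getY()); // 874.0, 644.0
    }
}
```

The forced corner (874, 644) is far from the real one (816, 574), which is why no choice of shx and shy works for a general quadrilateral; a projective (perspective) transform is needed instead.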

The final image also needs to be a rectangle that contains only the inside of the quadrilateral, without any of the black background. So I need some way of selecting only the quadrilateral for the transformation. I tried describing the quadrilateral with java.awt shapes such as Polygon and Area, but they only seemed to account for the outline, not the pixels contained in the shape.
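On the masking question: a java.awt.Shape such as Polygon, Path2D or Area does describe the enclosed area, not just its outline. Shape.contains(x, y) reports whether a point lies inside, and Graphics2D.fill(shape) and setClip(shape) both operate on the interior. A minimal sketch (the QuadMask class is hypothetical) that builds a per-pixel mask by testing each pixel's center:

```java
import java.awt.geom.Path2D;

// Sketch: per-pixel inside/outside mask for a quadrilateral using Shape.contains.
class QuadMask {

    /** Row-major mask: true where the pixel center (x + 0.5, y + 0.5) lies inside the polygon. */
    static boolean[] insideMask(double[][] corners, int width, int height) {
        Path2D.Double quad = new Path2D.Double();
        quad.moveTo(corners[0][0], corners[0][1]);
        for (int i = 1; i < corners.length; i++) {
            quad.lineTo(corners[i][0], corners[i][1]);
        }
        quad.closePath();

        boolean[] mask = new boolean[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                mask[y * width + x] = quad.contains(x + 0.5, y + 0.5);
            }
        }
        return mask;
    }

    public static void main(String[] args) {
        double[][] square = {{2, 2}, {8, 2}, {8, 8}, {2, 8}};
        int count = 0;
        for (boolean inside : insideMask(square, 10, 10)) {
            if (inside) {
                count++;
            }
        }
        System.out.println(count); // 36 (a 6x6 block of pixels)
    }
}
```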

1 Answer

I was able to solve this with projective transformations. It doesn't run as fast as I would have liked, but it works: on my computer, 1000 iterations take about 24 seconds (roughly 24 ms per frame, or about 42 fps), whereas I was aiming for at least 60 fps. I had hoped Java would have a built-in way of handling this kind of image transformation.

Here is the output image:

[Image: output image]

Here is my code:

```java
/*
 * File:    ImageUtility.java
 * Package: utility
 * Author:  Zachary Gill
 */

package utility;

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Point;
import java.awt.Polygon;
import java.awt.Shape;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferInt;
import java.io.File;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import javax.imageio.ImageIO;

import math.matrix.Matrix3;
import math.vector.Vector;

/**
 * Handles image operations.
 */
public class ImageUtility {
    
    public static void main(String[] args) throws Exception {
        File image = new File("test2.jpg");
        BufferedImage src = loadImage(image);
        if (src == null) {
            throw new IllegalStateException("Could not load " + image);
        }
        
        // Corners of the quad in the source image: top-left, top-right, bottom-right, bottom-left.
        List<Vector> srcBounds = new ArrayList<>();
        srcBounds.add(new Vector(439, 42));
        srcBounds.add(new Vector(841, 3));
        srcBounds.add(new Vector(816, 574));
        srcBounds.add(new Vector(472, 683));
        
        // Output size: average length of the two opposite edges along each axis.
        int width = (int) ((Math.abs(srcBounds.get(1).getX() - srcBounds.get(0).getX()) + Math.abs(srcBounds.get(3).getX() - srcBounds.get(2).getX())) / 2);
        int height = (int) ((Math.abs(srcBounds.get(3).getY() - srcBounds.get(0).getY()) + Math.abs(srcBounds.get(2).getY() - srcBounds.get(1).getY())) / 2);
        BufferedImage dest = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        List<Vector> destBounds = getBoundsForImage(dest);
        
        transformImage(src, srcBounds, dest, destBounds);
        ImageIO.write(dest, "jpg", new File("result.jpg"));
    }
    
    /**
     * Performs a quad to quad image transformation.
     *
     * @param src        The source image.
     * @param srcBounds  The bounds from the source image of the quad to transform.
     * @param dest       The destination image.
     * @param destBounds The bounds from the destination image of the quad to place the result of the transformation.
     */
    public static void transformImage(BufferedImage src, List<Vector> srcBounds, BufferedImage dest, List<Vector> destBounds) {
        Graphics2D destGraphics = dest.createGraphics();
        transformImage(src, srcBounds, destGraphics, dest.getWidth(), dest.getHeight(), destBounds);
        destGraphics.dispose();
    }
    
    /**
     * Performs a quad to quad image transformation.
     *
     * @param src        The source image (must be backed by a DataBufferInt, e.g. TYPE_INT_RGB from loadImage).
     * @param srcBounds  The bounds from the source image of the quad to transform.
     * @param dest       The destination graphics.
     * @param destWidth  The width of the destination graphics.
     * @param destHeight The height of the destination graphics.
     * @param destBounds The bounds from the destination graphics of the quad to place the result of the transformation.
     */
    @SuppressWarnings("IntegerDivisionInFloatingPointContext")
    public static void transformImage(BufferedImage src, List<Vector> srcBounds, Graphics2D dest, int destWidth, int destHeight, List<Vector> destBounds) {
        if ((src == null) || (srcBounds == null) || (dest == null) || (destBounds == null) ||
                (srcBounds.size() != 4) || (destBounds.size() != 4)) {
            return;
        }
        
        // Maps destination pixel coordinates back to source coordinates.
        Matrix3 projectiveMatrix = calculateProjectiveMatrix(srcBounds, destBounds);
        if (projectiveMatrix == null) {
            return;
        }
        
        final int filterColor = new Color(0, 255, 0).getRGB();
        
        // Mask image: green everywhere, black inside the destination quad.
        BufferedImage maskImage = new BufferedImage(destWidth, destHeight, BufferedImage.TYPE_INT_RGB);
        Graphics2D maskGraphics = maskImage.createGraphics();
        maskGraphics.setColor(new Color(filterColor));
        maskGraphics.fillRect(0, 0, maskImage.getWidth(), maskImage.getHeight());
        Polygon mask = new Polygon(
                destBounds.stream().mapToInt(e -> (int) e.getX()).toArray(),
                destBounds.stream().mapToInt(e -> (int) e.getY()).toArray(),
                4
        );
        Vector maskCenter = Vector.averageVector(destBounds);
        maskGraphics.setColor(new Color(0, 0, 0));
        maskGraphics.fillPolygon(mask);
        maskGraphics.dispose();
        
        int srcWidth = src.getWidth();
        int srcHeight = src.getHeight();
        int maskWidth = maskImage.getWidth();
        int maskHeight = maskImage.getHeight();
        
        int[] srcData = ((DataBufferInt) src.getRaster().getDataBuffer()).getData();
        int[] maskData = ((DataBufferInt) maskImage.getRaster().getDataBuffer()).getData();
        
        // Flood fill from the center of the quad to collect the indices of all pixels inside the mask.
        Set<Integer> visited = new HashSet<>();
        Stack<Point> stack = new Stack<>();
        stack.push(new Point((int) maskCenter.getX(), (int) maskCenter.getY()));
        while (!stack.isEmpty()) {
            Point p = stack.pop();
            int x = (int) p.getX();
            int y = (int) p.getY();
            int index = (y * maskWidth) + x;
            
            // Compare RGB only; the top byte of TYPE_INT_RGB pixel data is not part of the color.
            if ((x < 0) || (x >= maskWidth) || (y < 0) || (y >= maskHeight) ||
                    visited.contains(index) || ((maskData[index] & 0xFFFFFF) == (filterColor & 0xFFFFFF))) {
                continue;
            }
            visited.add(index);
            
            stack.push(new Point(x + 1, y));
            stack.push(new Point(x - 1, y));
            stack.push(new Point(x, y + 1));
            stack.push(new Point(x, y - 1));
        }
        
        // Inverse mapping: each destination pixel takes the color of the (clamped) source pixel it maps to.
        visited.parallelStream().forEach(p -> {
            Vector homogeneousSourcePoint = projectiveMatrix.multiply(new Vector(p % maskWidth, p / maskWidth, 1.0));
            int sX = (int) Math.max(0, Math.min(srcWidth - 1, homogeneousSourcePoint.getX() / homogeneousSourcePoint.getZ()));
            int sY = (int) Math.max(0, Math.min(srcHeight - 1, homogeneousSourcePoint.getY() / homogeneousSourcePoint.getZ()));
            maskData[p] = srcData[sY * srcWidth + sX];
        });
        visited.clear();
        
        Shape saveClip = dest.getClip();
        dest.setClip(mask);
        dest.drawImage(maskImage, 0, 0, maskWidth, maskHeight, null);
        dest.setClip(saveClip);
    }
    
    /**
     * Calculates the projective matrix for a quad to quad image transformation.
     *
     * @param src  The bounds of the quad in the source.
     * @param dest The bounds of the quad in the destination.
     * @return The projective matrix (destination to source), or null if it cannot be computed.
     */
    private static Matrix3 calculateProjectiveMatrix(List<Vector> src, List<Vector> dest) {
        // Columns are corners 0, 1 and 3 in homogeneous coordinates; solve for the weights
        // that combine them into corner 2, then scale each column by its weight.
        Matrix3 projectiveMatrixSrc = new Matrix3(new double[] {
                src.get(0).getX(), src.get(1).getX(), src.get(3).getX(),
                src.get(0).getY(), src.get(1).getY(), src.get(3).getY(),
                1.0, 1.0, 1.0});
        Vector solutionSrc = new Vector(src.get(2).getX(), src.get(2).getY(), 1.0);
        Vector coordinateSystemSrc = projectiveMatrixSrc.solveSystem(solutionSrc);
        Matrix3 coordinateMatrixSrc = new Matrix3(new double[] {
                coordinateSystemSrc.getX(), coordinateSystemSrc.getY(), coordinateSystemSrc.getZ(),
                coordinateSystemSrc.getX(), coordinateSystemSrc.getY(), coordinateSystemSrc.getZ(),
                coordinateSystemSrc.getX(), coordinateSystemSrc.getY(), coordinateSystemSrc.getZ()
        });
        projectiveMatrixSrc = projectiveMatrixSrc.scale(coordinateMatrixSrc);
        
        // Same construction for the destination quad.
        Matrix3 projectiveMatrixDest = new Matrix3(new double[] {
                dest.get(0).getX(), dest.get(1).getX(), dest.get(3).getX(),
                dest.get(0).getY(), dest.get(1).getY(), dest.get(3).getY(),
                1.0, 1.0, 1.0});
        Vector solutionDest = new Vector(dest.get(2).getX(), dest.get(2).getY(), 1.0);
        Vector coordinateSystemDest = projectiveMatrixDest.solveSystem(solutionDest);
        Matrix3 coordinateMatrixDest = new Matrix3(new double[] {
                coordinateSystemDest.getX(), coordinateSystemDest.getY(), coordinateSystemDest.getZ(),
                coordinateSystemDest.getX(), coordinateSystemDest.getY(), coordinateSystemDest.getZ(),
                coordinateSystemDest.getX(), coordinateSystemDest.getY(), coordinateSystemDest.getZ()
        });
        projectiveMatrixDest = projectiveMatrixDest.scale(coordinateMatrixDest);
        
        // src * dest^-1 maps destination points back into the source image.
        try {
            projectiveMatrixDest = projectiveMatrixDest.inverse();
        } catch (ArithmeticException ignored) {
            return null;
        }
        return projectiveMatrixSrc.multiply(projectiveMatrixDest);
    }
    
    /**
     * Loads an image.
     *
     * @param file The image file.
     * @return The BufferedImage loaded from the file, or null if there was an error.
     */
    public static BufferedImage loadImage(File file) {
        try {
            BufferedImage tmpImage = ImageIO.read(file);
            BufferedImage image = new BufferedImage(tmpImage.getWidth(), tmpImage.getHeight(), BufferedImage.TYPE_INT_RGB);
            Graphics2D imageGraphics = image.createGraphics();
            imageGraphics.drawImage(tmpImage, 0, 0, tmpImage.getWidth(), tmpImage.getHeight(), null);
            imageGraphics.dispose();
            return image;
        } catch (Exception ignored) {
            return null;
        }
    }
    
    /**
     * Creates the default bounds for an image.
     *
     * @param image The image.
     * @return The default bounds for the image.
     */
    public static List<Vector> getBoundsForImage(BufferedImage image) {
        List<Vector> bounds = new ArrayList<>();
        bounds.add(new Vector(0, 0));
        bounds.add(new Vector(image.getWidth() - 1, 0));
        bounds.add(new Vector(image.getWidth() - 1, image.getHeight() - 1));
        bounds.add(new Vector(0, image.getHeight() - 1));
        return bounds;
    }
    
}
```
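Since the destination quad in main covers the whole output image, the mask image and flood fill are not strictly needed in that case: every output pixel can be mapped directly. The sketch below is an illustration, not the code above. The DirectWarp class is hypothetical, and it assumes the projective matrix has been copied into a row-major double[9] h that maps destination pixel coordinates to source coordinates (the role of the matrix calculateProjectiveMatrix returns). It avoids boxing every pixel index into a HashSet<Integer> and allocating a Point per neighbour, which may account for part of the running time.

```java
import java.awt.image.BufferedImage;

// Sketch: inverse-map every pixel of a width x height output through h
// (row-major 3x3, destination -> source), with nearest-neighbour sampling.
class DirectWarp {

    static BufferedImage warp(BufferedImage src, double[] h, int width, int height) {
        int srcW = src.getWidth();
        int srcH = src.getHeight();
        int[] srcPx = src.getRGB(0, 0, srcW, srcH, null, 0, srcW);
        int[] out = new int[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // Homogeneous multiply by (x, y, 1), then divide by w.
                double w = h[6] * x + h[7] * y + h[8];
                double sx = (h[0] * x + h[1] * y + h[2]) / w;
                double sy = (h[3] * x + h[4] * y + h[5]) / w;
                // Clamp to the source image, as the answer's code does.
                int ix = (int) Math.max(0, Math.min(srcW - 1, sx));
                int iy = (int) Math.max(0, Math.min(srcH - 1, sy));
                out[y * width + x] = srcPx[iy * srcW + ix];
            }
        }
        BufferedImage dest = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        dest.setRGB(0, 0, width, height, out, 0, width);
        return dest;
    }
}
```

If this is still too slow, the loop over rows could be run in parallel, for example with IntStream.range(0, height).parallel(), since each row writes to a disjoint part of out.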

If you would like to run this yourself, the Matrix3 and Vector classes can be found here:

https://github.com/ZGorlock/Graphy/blob/master/src/math/matrix/Matrix3.java

https://github.com/ZGorlock/Graphy/blob/master/src/math/vector/Vector.java
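To avoid the Matrix3 and Vector dependency entirely, the same construction can be written with plain arrays. This is a sketch under the assumption that corners are listed in the same order as in main (top-left, top-right, bottom-right, bottom-left); the Homography class and its method names are hypothetical, and matrices are row-major double[9].

```java
// Sketch: quad-to-quad projective matrix with plain arrays (row-major 3x3).
class Homography {

    static double[] mulVec(double[] m, double[] v) {
        return new double[] {
                m[0] * v[0] + m[1] * v[1] + m[2] * v[2],
                m[3] * v[0] + m[4] * v[1] + m[5] * v[2],
                m[6] * v[0] + m[7] * v[1] + m[8] * v[2]};
    }

    static double[] mul(double[] a, double[] b) {
        double[] c = new double[9];
        for (int r = 0; r < 3; r++) {
            for (int col = 0; col < 3; col++) {
                for (int k = 0; k < 3; k++) {
                    c[r * 3 + col] += a[r * 3 + k] * b[k * 3 + col];
                }
            }
        }
        return c;
    }

    /** Inverse via the adjugate; throws if the matrix is (nearly) singular, e.g. collinear corners. */
    static double[] inverse(double[] m) {
        double a = m[0], b = m[1], c = m[2];
        double d = m[3], e = m[4], f = m[5];
        double g = m[6], h = m[7], k = m[8];
        double det = a * (e * k - f * h) - b * (d * k - f * g) + c * (d * h - e * g);
        if (Math.abs(det) < 1e-12) {
            throw new ArithmeticException("singular matrix");
        }
        return new double[] {
                (e * k - f * h) / det, (c * h - b * k) / det, (b * f - c * e) / det,
                (f * g - d * k) / det, (a * k - c * g) / det, (c * d - a * f) / det,
                (d * h - e * g) / det, (b * g - a * h) / det, (a * e - b * d) / det};
    }

    /** Columns are corners 0, 1, 3 (homogeneous), each scaled so that the columns sum to corner 2. */
    static double[] basisToPoints(double[][] p) {
        double[] m = {
                p[0][0], p[1][0], p[3][0],
                p[0][1], p[1][1], p[3][1],
                1, 1, 1};
        double[] weights = mulVec(inverse(m), new double[] {p[2][0], p[2][1], 1});
        for (int r = 0; r < 3; r++) {
            for (int col = 0; col < 3; col++) {
                m[r * 3 + col] *= weights[col];
            }
        }
        return m;
    }

    /** Matrix mapping destination coordinates to source coordinates: src * dest^-1. */
    static double[] destToSrc(double[][] src, double[][] dest) {
        return mul(basisToPoints(src), inverse(basisToPoints(dest)));
    }

    /** Applies h to (x, y, 1) and divides by the homogeneous coordinate. */
    static double[] map(double[] h, double x, double y) {
        double[] v = mulVec(h, new double[] {x, y, 1});
        return new double[] {v[0] / v[2], v[1] / v[2]};
    }

    public static void main(String[] args) {
        // Source quad from main; destination is the 373 x 606 output rectangle it computes.
        double[][] src = {{439, 42}, {841, 3}, {816, 574}, {472, 683}};
        double[][] dest = {{0, 0}, {372, 0}, {372, 605}, {0, 605}};
        double[] s = map(destToSrc(src, dest), 372, 605);
        System.out.println(Math.round(s[0]) + ", " + Math.round(s[1])); // 816, 574
    }
}
```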

Also, here is some good reference material for projective transformations:

http://graphics.cs.cmu.edu/courses/15-463/2006_fall/www/Papers/proj.pdf

https://mc.ai/part-ii-projective-transformations-in-2d/
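In the notation of those references, calculateProjectiveMatrix roughly computes the following, with corners numbered 0 to 3 as in the code and $(x_i, y_i)$ the corners of one quad:

```latex
M = \begin{pmatrix} x_0 & x_1 & x_3 \\ y_0 & y_1 & y_3 \\ 1 & 1 & 1 \end{pmatrix},\qquad
\begin{pmatrix} \lambda_0 \\ \lambda_1 \\ \lambda_3 \end{pmatrix} = M^{-1} \begin{pmatrix} x_2 \\ y_2 \\ 1 \end{pmatrix},\qquad
A = M \,\operatorname{diag}(\lambda_0, \lambda_1, \lambda_3)

H = A_{\mathrm{src}} \, A_{\mathrm{dest}}^{-1},\qquad
\begin{pmatrix} u \\ v \\ w \end{pmatrix} = H \begin{pmatrix} x \\ y \\ 1 \end{pmatrix},\qquad
(s_x, s_y) = \left( \frac{u}{w}, \frac{v}{w} \right)
```

$A$ sends $(1,0,0)$, $(0,1,0)$, $(0,0,1)$ and $(1,1,1)$ to corners 0, 1, 3 and 2 (up to scale), so $H$ sends each destination corner to the matching source corner, and $(s_x, s_y)$ is the source pixel sampled for destination pixel $(x, y)$.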

Answered on July 16, 2020.