I started making a ray tracer in Haskell for fun so I’m going to blog about it. A ray tracer is a piece of software that generates an image of a 3D scene by simulating the way light moves around the scene and enters a camera. Since light travels in straight lines until it interacts with matter, this is done by tracing straight light rays around the scene.
This isn’t going to be a tutorial, but I’ll go over some of the quirks of writing a ray tracer in a pure functional language, where you don’t have access to mutable state.
I used two books for reference, Ray Tracing from the Ground Up by Kevin Suffern, and Physically Based Rendering by Matt Pharr and Greg Humphreys.
Why write a ray tracer in Haskell?
There were a lot of things I didn’t know how to do in Haskell before starting this project, such as loading a mesh from a file into a Haskell-style algebraic data type (ADT). Writing a ray tracer is a great way to learn them.
Also, I thought that, in principle, a ray tracer should be easy to express in Haskell, because every operation in a ray tracer is a well-defined function derived from the rendering equation. For example, when shading a surface point, we take a few inputs (the incoming light direction, the surface and light colors, and the surface normal) and produce an outgoing light ray direction. There’s no need to modify existing state, or even to store the old state of the ray for later use; the previous ray can simply be discarded.
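For reference, this is the rendering equation in its common hemispherical form. The notation is the usual textbook one and is not taken from the project’s code: L_o is the outgoing radiance at point p, L_e the emitted radiance, f_r the BRDF (what the shaders later in this post evaluate), L_i the incoming radiance, and n the surface normal.

```latex
L_o(p, \omega_o) = L_e(p, \omega_o)
  + \int_{\Omega} f_r(p, \omega_i, \omega_o)\, L_i(p, \omega_i)\, (n \cdot \omega_i)\, d\omega_i
```

Every term is a pure function of the surface point and the two directions, which is why the shading code maps so directly onto Haskell functions.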
The basics: ray intersections
I started out writing the data types to represent shapes and rays. Like so:
data Shape = Plane (Point V3 Double) (V3 Double)
           | Sphere (Point V3 Double) Double
           | AABB (M44 Double) (V3 Double) (V3 Double)
           | Triangle (Point V3 Double) (Point V3 Double) (Point V3 Double) (V3 Double)
           | Disk (Point V3 Double) (V3 Double) Double
           | Rectangle (Point V3 Double) (V3 Double) (V3 Double) (V3 Double)
           deriving (Show, Eq)

data Ray = Ray { rayOrigin :: Point V3 Double
               , rayDirection :: V3 Double
               }
         deriving (Show, Eq)
In order to store the location where a light ray hits an object, I created an intersection data structure:
data Intersection = Intersection { intersectionPoint :: Point V3 Double
                                 , intersectionNormal :: V3 Double
                                 , tMin :: Double
                                 }
                  deriving (Show, Eq)
Performing a ray-shape intersection is then as simple as it would be in any other language: take a ray and a shape and return an intersection. In this case, we return a Maybe Intersection, because the ray may miss the object entirely, in which case the result is Nothing. Here’s how that works for a ray and a flat plane:
rayIntersection :: Ray -> Shape -> Maybe Intersection
rayIntersection (Ray {rayOrigin = ro, rayDirection = rd}) (Plane planePoint planeNormal) =
    let denominator = (rd `dot` planeNormal)
    in if (denominator > -rayEpsilon) && (denominator < rayEpsilon)
       then Nothing -- Ray is (nearly) parallel to the plane
       else let t = (planePoint .-. ro) `dot` (planeNormal ^/ denominator)
            in if t <= rayEpsilon
               then Nothing -- Hit is behind the ray origin or too close to it
               else Just (Intersection {intersectionPoint = ro .+^ (rd ^* t), intersectionNormal = planeNormal, tMin = t})
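The other constructors of Shape follow the same pattern. As an illustration, here is a minimal sketch of how a Sphere case could look, solving the usual quadratic for the ray parameter t. It is not the project’s code: Vec3 and its helpers are small stand-ins for the linear types so the sketch compiles with base only, intersectSphere is a hypothetical name, and the value of rayEpsilon is assumed.

```haskell
-- Stand-in for the V3/Point types used in the post (assumption, base only).
data Vec3 = Vec3 Double Double Double deriving (Show, Eq)

vadd, vsub :: Vec3 -> Vec3 -> Vec3
vadd (Vec3 a b c) (Vec3 x y z) = Vec3 (a + x) (b + y) (c + z)
vsub (Vec3 a b c) (Vec3 x y z) = Vec3 (a - x) (b - y) (c - z)

scale :: Double -> Vec3 -> Vec3
scale k (Vec3 x y z) = Vec3 (k * x) (k * y) (k * z)

dot :: Vec3 -> Vec3 -> Double
dot (Vec3 a b c) (Vec3 x y z) = a * x + b * y + c * z

-- Assumed tolerance; the post does not state the value it uses.
rayEpsilon :: Double
rayEpsilon = 1.0e-6

-- | Hypothetical sphere case. Arguments: ray origin, ray direction,
-- sphere center, sphere radius. Returns (t, hit point, unit normal) for
-- the nearest hit in front of the origin, or Nothing on a miss.
intersectSphere :: Vec3 -> Vec3 -> Vec3 -> Double -> Maybe (Double, Vec3, Vec3)
intersectSphere ro rd center radius
  | a == 0 || disc < 0 = Nothing            -- degenerate ray, or no real roots
  | otherwise =
      case filter (> rayEpsilon) [tNear, tFar] of
        []      -> Nothing                  -- sphere is entirely behind the ray
        (t : _) -> let p = ro `vadd` scale t rd
                   in Just (t, p, scale (1 / radius) (p `vsub` center))
  where
    oc    = ro `vsub` center
    a     = rd `dot` rd
    b     = 2 * (oc `dot` rd)
    c     = (oc `dot` oc) - radius * radius
    disc  = b * b - 4 * a * c
    root  = sqrt disc                       -- only forced when disc >= 0
    tNear = (-b - root) / (2 * a)
    tFar  = (-b + root) / (2 * a)
```

As in the plane case, hits with t at or below rayEpsilon are rejected, so a ray that starts on the surface does not immediately hit the same surface again. When the origin is inside the sphere, tNear is negative and the far root is returned instead.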
Lights, cameras, materials, shaders
There are a few other things you need to represent a scene, including lights:
data Light = EnvironmentLight (Color Double)
           | PointLight (Point V3 Double) (Color Double)
           | DirectionalLight (V3 Double) (Color Double)
           | DiskLight (Point V3 Double) (V3 Double) Double (Color Double) -- Point, normal, and radius
           | SphereLight (Point V3 Double) Double (Color Double) -- Point and radius
           | RectangleLight (Point V3 Double) (V3 Double) (V3 Double) (Color Double) -- Point and two edge vectors
           deriving (Show, Eq)
cameras:
data Camera = Camera (Point V3 Double) (V3 Double) (V3 Double) -- Origin, look, and up
            deriving (Show, Eq)
and materials:
data Material = ColorMaterial (Color Double) -- Color (No shading)
              | MatteMaterial (Color Double) Double -- Diffuse, kD
              | PlasticMaterial (Color Double) Double (Color Double) Double Double -- Diffuse, kD, Specular, kS, kExp
              deriving (Show, Eq)
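The Camera type above stores only an origin, a look direction, and an up vector. One common way to turn those three values into primary rays is to build an orthonormal basis from look and up and offset the look direction by the pixel’s screen position. The sketch below shows that idea under stated assumptions: primaryRay, the vertical field of view fovY, and the screen coordinates u and v in [-1, 1] are all hypothetical and not taken from the project, and Vec3 is a base-only stand-in for the linear types.

```haskell
-- Stand-in for the V3/Point types used in the post (assumption, base only).
data Vec3 = Vec3 Double Double Double deriving (Show, Eq)

vadd :: Vec3 -> Vec3 -> Vec3
vadd (Vec3 a b c) (Vec3 x y z) = Vec3 (a + x) (b + y) (c + z)

scale :: Double -> Vec3 -> Vec3
scale k (Vec3 x y z) = Vec3 (k * x) (k * y) (k * z)

dot :: Vec3 -> Vec3 -> Double
dot (Vec3 a b c) (Vec3 x y z) = a * x + b * y + c * z

cross :: Vec3 -> Vec3 -> Vec3
cross (Vec3 a b c) (Vec3 x y z) =
  Vec3 (b * z - c * y) (c * x - a * z) (a * y - b * x)

normalize :: Vec3 -> Vec3
normalize v = scale (1 / sqrt (v `dot` v)) v

-- Same field order as the post's Camera: origin, look, and up.
data Camera = Camera Vec3 Vec3 Vec3

-- | Hypothetical primary-ray generator. fovY is the vertical field of view
-- in radians; u and v are screen coordinates in [-1, 1], with (0, 0) at the
-- image center. Returns (ray origin, unit ray direction).
-- Note: look and up must not be parallel, or the basis is undefined.
primaryRay :: Camera -> Double -> Double -> Double -> (Vec3, Vec3)
primaryRay (Camera origin look up) fovY u v = (origin, normalize dir)
  where
    forward = normalize look
    right   = normalize (forward `cross` up)   -- points to the viewer's right
    trueUp  = right `cross` forward            -- up, re-orthogonalized
    halfH   = tan (fovY / 2)                   -- half-height of the image plane at distance 1
    dir     = forward `vadd` scale (u * halfH) right `vadd` scale (v * halfH) trueUp
```

For non-square images, u would additionally be multiplied by the aspect ratio. Because the function is pure, generating the rays for a whole image is just a map over pixel coordinates.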
In order to tie all of these systems together, I created a data type to represent an object, which has a shape, a material, and a function, called a shader, from a ShadePoint to a Color.
data Object = Object Shape Material (ShadePoint -> Color Double)
A ShadePoint contains everything you need to shade a surface point of an object: the material of the object, the surface normal, the incoming light ray direction, and the outgoing light ray direction:
data ShadePoint = ShadePoint Material (V3 Double) (V3 Double) (V3 Double)
                deriving (Show, Eq)
Here’s an example of a diffuse shader (diffuseF is a helper function):
diffuseF :: Color Double
         -> Double
         -> Color Double
diffuseF diffuse kD =
    let invPi = 1.0 / pi
    in diffuse ^* (kD * invPi)

lambertShader :: ShadePoint -> Color Double
lambertShader (ShadePoint (ColorMaterial color) normal wIn wOut) = color
lambertShader (ShadePoint (MatteMaterial diffuse kD) normal wIn wOut) = diffuseF diffuse kD
lambertShader (ShadePoint (PlasticMaterial diffuse kD _ _ _) normal wIn wOut) = diffuseF diffuse kD
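The Lambert shader ignores the specular fields of PlasticMaterial. Those fields (Specular, kS, kExp) suggest a Phong-style glossy term, so here is a minimal sketch of what such a helper could look like next to diffuseF. This is an assumption about how the fields might be used, not the project’s code: specularF, plasticF, and reflectAbout are hypothetical names, Vec3 stands in for both V3 and Color, and wIn and wOut are assumed to be unit vectors pointing away from the surface.

```haskell
-- Stand-in for both V3 Double and Color Double (assumption, base only).
data Vec3 = Vec3 Double Double Double deriving (Show, Eq)

vadd, vsub :: Vec3 -> Vec3 -> Vec3
vadd (Vec3 a b c) (Vec3 x y z) = Vec3 (a + x) (b + y) (c + z)
vsub (Vec3 a b c) (Vec3 x y z) = Vec3 (a - x) (b - y) (c - z)

scale :: Double -> Vec3 -> Vec3
scale k (Vec3 x y z) = Vec3 (k * x) (k * y) (k * z)

dot :: Vec3 -> Vec3 -> Double
dot (Vec3 a b c) (Vec3 x y z) = a * x + b * y + c * z

-- Same formula as the post's diffuseF, ported to the stand-in type.
diffuseF :: Vec3 -> Double -> Vec3
diffuseF diffuse kD = scale (kD / pi) diffuse

-- Mirror image of wIn about the normal: r = 2 (n . wIn) n - wIn.
reflectAbout :: Vec3 -> Vec3 -> Vec3
reflectAbout normal wIn = scale (2 * (normal `dot` wIn)) normal `vsub` wIn

-- | Hypothetical Phong-style specular term: kS * specular * (r . wOut) ^ kExp,
-- clamped to black when the viewer is outside the reflection lobe.
specularF :: Vec3 -> Double -> Double -> Vec3 -> Vec3 -> Vec3 -> Vec3
specularF specular kS kExp normal wIn wOut =
  let rDotWo = reflectAbout normal wIn `dot` wOut
  in if rDotWo > 0
       then scale (kS * rDotWo ** kExp) specular
       else Vec3 0 0 0

-- | Hypothetical plastic shader body: diffuse plus specular.
plasticF :: Vec3 -> Double -> Vec3 -> Double -> Double -> Vec3 -> Vec3 -> Vec3 -> Vec3
plasticF diffuse kD specular kS kExp normal wIn wOut =
  diffuseF diffuse kD `vadd` specularF specular kS kExp normal wIn wOut
```

A larger kExp narrows the highlight, since the dot product is at most 1 for unit vectors. In the post’s terms, a function like this could be called from a shader clause that matches on PlasticMaterial and passes along the normal, wIn, and wOut from the ShadePoint.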
Tracing
If you’re going to trace rays, there’s no point tracing against a single triangle. To make a meaningful render, you need to trace a scene like this:
data Scene = ListScene [Object]
           | KDScene KDTree
           deriving (Show)
I started out with the most basic kind of scene, just a list of objects, and added a scene with a KD-tree accelerator later. I’ll talk about that in a later post, but for now let’s look at how to trace a list scene.
The approach to tracing a list scene is to simply compare every ray for intersections against every object. This produces a very slow, but correct render.
Here’s the entire code for the tracer (don’t worry about the LowDiscrepancySequence for now, I’ll describe that in a later post):
traceRays :: (LowDiscrepancySequence s)
          => Scene
          -> Color Double
          -> Ray
          -> s
          -> ((TraceResult, Ray), s)
traceRays (ListScene objects) bgColor ray gen =
    ((foldl' (\traceResult@(TraceResult (Intersection {tMin = traceTMin}) material shader) (Object shape objectMaterial objectShader) ->
                  case rayIntersection ray shape of
                      Nothing -> traceResult
                      Just objectIntersection@(Intersection {tMin = tm}) ->
                          if tm < traceTMin
                          then TraceResult objectIntersection objectMaterial objectShader
                          else traceResult) (emptyTraceResult bgColor) objects, ray), gen)
That’s it, a single foldl'. You keep the intersection with the smallest tMin, and that’s the point you shade. The shadow rays are cast in a separate function which is pretty much the same.
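The shadow-ray function isn’t shown in this post, so here is a minimal sketch of one way it could work, assuming a light at a known position. A shadow ray only needs to know whether anything lies between the surface point and the light, not which object is nearest, so `any` can replace the fold and stop at the first blocker. The names isOccluded and planeT are hypothetical, Vec3 is a base-only stand-in for the linear types, and the rayEpsilon value is assumed.

```haskell
-- Stand-in for the V3/Point types used in the post (assumption, base only).
data Vec3 = Vec3 Double Double Double deriving (Show, Eq)

vsub :: Vec3 -> Vec3 -> Vec3
vsub (Vec3 a b c) (Vec3 x y z) = Vec3 (a - x) (b - y) (c - z)

scale :: Double -> Vec3 -> Vec3
scale k (Vec3 x y z) = Vec3 (k * x) (k * y) (k * z)

dot :: Vec3 -> Vec3 -> Double
dot (Vec3 a b c) (Vec3 x y z) = a * x + b * y + c * z

data Ray = Ray Vec3 Vec3      -- origin, direction
data Plane = Plane Vec3 Vec3  -- point on the plane, normal

-- Assumed tolerance; the post does not state the value it uses.
rayEpsilon :: Double
rayEpsilon = 1.0e-6

-- Same test as the post's Plane case, reduced to returning only t.
planeT :: Ray -> Plane -> Maybe Double
planeT (Ray ro rd) (Plane p n)
  | abs denom < rayEpsilon = Nothing   -- ray (nearly) parallel to the plane
  | t <= rayEpsilon        = Nothing   -- behind the origin, or self-hit
  | otherwise              = Just t
  where
    denom = rd `dot` n
    t     = ((p `vsub` ro) `dot` n) / denom

-- | Hypothetical occlusion test: True if any shape is hit strictly between
-- hitPoint and lightPos. The intersection routine is passed in, so the same
-- function works for any shape type. Assumes hitPoint /= lightPos.
isOccluded :: (Ray -> shape -> Maybe Double) -> [shape] -> Vec3 -> Vec3 -> Bool
isOccluded intersect shapes hitPoint lightPos = any blocks shapes
  where
    toLight  = lightPos `vsub` hitPoint
    dist     = sqrt (toLight `dot` toLight)
    ray      = Ray hitPoint (scale (1 / dist) toLight)  -- unit direction, so t is a distance
    blocks s = maybe False (< dist) (intersect ray s)
```

The comparison against dist matters: an object beyond the light is hit by the ray but does not cast a shadow on the point. The rayEpsilon check in the intersection routine keeps the surface from shadowing itself.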
Lessons learned
I was surprised by a few things when I started developing the ray tracer, and I think they are relevant to any project which involves large amounts of data and computation:
- The tracer is very slow compared to C++. It renders at approximately 36000 rays per second on a Core i5-8250U using 4 cores, tracing a mesh with 16300 faces. This is a lot slower than I was hoping for, but still faster than it would be if written in a scripting language like Python. I didn’t spend much time on optimization, and the code won’t be performant without more of it. Switching from ADTs to other types and using more strict evaluation would probably help a lot, but those changes are not obvious to me as a Haskell beginner.
- The code is really small. 1052 lines of code for a ray tracer with random sampling, a KD-tree accelerator, and mesh loading is incredibly concise. I like this because it makes it feasible for me to develop a large project like this on my own in my spare time. It might also be important if your work involves fast prototyping.
- Using generic types makes computations even slower. My original code used the generic Floating typeclass. Removing the typeclass and specializing the code to use Doubles resulted in a 10x speedup!
- If you don’t use strict folds when working with lots of data, your program will crash. If you use foldl or foldr over enough data, your program will run out of memory from all of the thunks it allocates. Use foldl' to avoid this.
- Dealing with I/O was actually pretty easy. There are only a couple of files which actually deal with I/O in the project and the rest of the code is 100% pure.
- Adding multithreading was annoying, but ultimately required very few changes. It took a lot of digging to figure out how to get the tracer to run on multiple cores, but in the end I just used parListChunk and it worked.
- Modifying the code is super simple. Because everything is pure, adding a feature like random numbers to the program takes a fraction of the time it normally would. Iterating on a pure functional program is super fast.
- Lambdas are the most natural way to express shaders. In a ray tracer, when a ray hits an object, you need to determine how the light will reflect off of the surface and what color the resulting ray will be. This is done using a function called a shader. In most languages, shaders are bound to a surface using an ID and accessed using function pointers. In Haskell, you can simply store the shader as a lambda inside the object!
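One detail about the strict-fold point above is worth a sketch: foldl' only forces the accumulator to weak head normal form, so an accumulator that is a lazy tuple can still pile up thunks inside its fields. The example below averages a list of sample values, a typical ray tracer reduction, using an accumulator with strict fields. The names Acc, meanStrict, and meanLeaky are hypothetical and not from the project.

```haskell
import Data.List (foldl')

-- Strict fields (the ! annotations): forcing an Acc to weak head normal
-- form also forces the running total and the count.
data Acc = Acc !Double !Int

-- | Mean of a list of samples in constant space; 0 for an empty list.
meanStrict :: [Double] -> Double
meanStrict xs =
  let Acc total n = foldl' step (Acc 0 0) xs
  in if n == 0 then 0 else total / fromIntegral n
  where
    step (Acc t k) x = Acc (t + x) (k + 1)

-- | Same computation with a lazy tuple accumulator. foldl' forces only the
-- tuple constructor at each step, so the sums inside it can remain
-- unevaluated chains of (+) until the very end (unless the optimizer
-- happens to remove them). Shown for contrast only.
meanLeaky :: [Double] -> Double
meanLeaky xs =
  let (total, n) = foldl' (\(t, k) x -> (t + x, k + 1)) (0, 0 :: Int) xs
  in if n == 0 then 0 else total / fromIntegral n
```

Both functions return the same result. The difference only shows up in memory use on large inputs, such as one value per ray over a whole image.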
You can find the source code for this project under an MIT license at https://github.com/WhatTheFunctional/HaskellTracer.