How to Functional in Java

Stephen Harrison
Rue Gilt Groupe Tech Blog
Mar 9, 2016

We’ll show you how to get started using Java 8’s functional programming features. When you’re done, you’ll be able to read and write Java code that looks like this:

private static <T> Map<T, Long> occurrences(final Stream<T> in) {
    return in.collect(groupingBy(x -> x, counting()));
}

Many of us are using tried-and-trusted imperative languages: C#, Java, Python, and so on. There’s nothing wrong with those languages, of course. Amazing code is developed at Rue every day using them. Until Scala or Haskell are mainstream at your company, can you meet those languages halfway? This post looks at Java 8’s functional enhancements and suggests places you can start to incorporate them.

Scala took a different approach from Java. Although Scala also runs on the Java VM, it was designed from the start to marry functional and object-oriented programming at its core, so it’s a new language altogether. Some (including the author) feel this can give Scala code a more polished appearance than the same code in Java, even after Java’s functional features have been layered on. We’ll see. Finished Scala is often startlingly beautiful to look at, and it meets the standard of “provably correct by inspection” more often than most languages.

Does Java’s functional upgrade instantly make your code beautiful and easy to read? Of course it doesn’t. Not without care, in any case. So what do we have to do? Where do we start?

As you’re learning how to include Java 8’s new functional features, you’re probably going to start at the method level. Let’s try an example. Suppose we want to count the number of times each element appears in a List<T>. We’ll keep a Map<T, Long> with keys for each unique element and values for the number of times it occurs.

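private static Map<String, Long> counts(final List<String> in) {
    // Double-brace initialization: an anonymous HashMap subclass
    // whose instance initializer fills in the counts.
    return new HashMap<String, Long>() {
        {
            for (final String s : in) {
                final Long value = get(s);

                put(s, value == null ? 1L : value + 1);
            }
        }
    };
}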

If we’re being completely honest, it’s a bit lame. Digging around, we notice that Java 8’s Map interface includes a new method, getOrDefault(), which lets us remove an explicit conditional. We’ll also make it generic because we’re exceptionally good at this kind of thing.

private static <T> Map<T, Long> counts(final List<T> in) {
    return new HashMap<T, Long>() {
        {
            for (final T e : in) {
                final Long value = getOrDefault(e, 0L);
                put(e, value + 1);
            }
        }
    };
}
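Java 8’s Map also gained merge(key, value, remappingFunction), which folds the lookup and the update into a single call. The sketch below is one way to use it; it also swaps the double-brace initializer for a plain local HashMap. The class name MergeCounts is just for illustration. It’s tidier, but it’s still the same loop underneath.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class MergeCounts {
    // For a new key, merge() stores 1L; for an existing key it combines
    // the current count with 1L using Long::sum.
    static <T> Map<T, Long> counts(final List<T> in) {
        final Map<T, Long> result = new HashMap<>();
        for (final T e : in) {
            result.merge(e, 1L, Long::sum);
        }
        return result;
    }

    public static void main(final String[] args) {
        System.out.println(counts(Arrays.asList("a", "b", "a"))); // prints {a=2, b=1}
    }
}
```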

It’s still lame. What happened? Well, the issue is that although we’re using some new Java 8 features, we didn’t think about our problem in a functional way. So rather than just translate the old approach, let’s explore another Java 8 feature: streams. We find java.util.stream.Collectors.groupingBy, and it feels like we’re getting warm. A few imports later, we land on this version, which now takes a Stream<T> instead of a List<T>:

private static <T> Map<T, Long> counts(final Stream<T> in) {
    return in.collect(groupingBy(x -> x, counting()));
}
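If x -> x reads oddly, Java 8 also provides java.util.function.Function.identity() for exactly that function. Here is a small sketch of the same method written with it, and of calling it when what you have is a List (call stream() on it first). The class name IdentityCounts and the sample data are illustrative.

```java
import static java.util.stream.Collectors.counting;
import static java.util.stream.Collectors.groupingBy;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Stream;

public class IdentityCounts {
    // Same idea as counts(), with Function.identity() in place of x -> x.
    static <T> Map<T, Long> counts(final Stream<T> in) {
        return in.collect(groupingBy(Function.identity(), counting()));
    }

    public static void main(final String[] args) {
        final List<String> words = Arrays.asList("a", "b", "a");
        // A List isn't a Stream, so turn it into one before counting.
        System.out.println(counts(words.stream())); // prints {a=2, b=1}
    }
}
```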

Much better. It reads like English from left to right, as a lot of functional code does. When reading an idiom like x -> x aloud, you can say “identity,” or, better still, name what it’s the identity of: in this case, the key. Try saying it out loud: “take the stream and collect the counts of the groups of keys.” Here’s the method in a micro test harness. It times how long it takes to count the frequency of random integers in a Stream.

package com.ruelala;

import static java.util.stream.Collectors.counting;
import static java.util.stream.Collectors.groupingBy;

import java.util.Map;
import java.util.Random;
import java.util.stream.Stream;

public class HowToFunctionalInJava {
    private static final Random R = new Random();
    private static final int N = 1 << 26; // 67,108,864 values

    public static void main(final String[] args) {
        // Streams are lazy: no ints are generated until counts() collects them,
        // so the timed region covers generating the values as well as counting them.
        final Stream<Integer> randomInts = R.ints(N, 0, 1000).boxed();
        final long start = System.currentTimeMillis();
        final Map<Integer, Long> counts = counts(randomInts);
        final long elapsed = System.currentTimeMillis() - start;

        System.out.println(N + " in " + elapsed + "ms");
        System.out.println(counts);
    }

    private static <T> Map<T, Long> counts(final Stream<T> in) {
        return in.collect(groupingBy(x -> x, counting()));
    }
}

On my Mac laptop, the output is:

67108864 in 2445ms
{0=67493, 1=67284, 2=67167, 3=66703, 4=67240, 5=67653, 6=66953, 7=66922, 8=67060...

That’s over 67 million integers generated and counted in less than 2.5s, or about 27 million a second. That seems good enough for now. We’ve just dipped a toe into Java 8’s functional features by looking at how to manipulate collections. Future posts will take this a lot further. Until then, what could you do differently in the code you’re working on? Write and let me know.
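Because streams are lazy, a harness like the one above also times generating and boxing the random values. If you want to time only the counting, one option is to materialize the values into a List first and time with System.nanoTime(). The sketch below does that; the class name TimedCounts, the smaller size of 1 << 20, and the fixed seed are arbitrary choices for illustration, and a single run like this is only a rough indication (a harness such as JMH is the usual tool for careful JVM benchmarks).

```java
import static java.util.stream.Collectors.counting;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;

import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Stream;

public class TimedCounts {
    static <T> Map<T, Long> counts(final Stream<T> in) {
        return in.collect(groupingBy(x -> x, counting()));
    }

    public static void main(final String[] args) {
        final int n = 1 << 20;            // arbitrary, smaller than the post's N
        final Random r = new Random(42);  // fixed seed, an arbitrary choice
        // Generate and box everything up front so it isn't part of the timing.
        final List<Integer> values = r.ints(n, 0, 1000).boxed().collect(toList());

        final long start = System.nanoTime();
        final Map<Integer, Long> result = counts(values.stream());
        final long elapsedMs = (System.nanoTime() - start) / 1_000_000;

        System.out.println(n + " in " + elapsedMs + "ms");
        // Every value is counted exactly once, so this total equals n.
        System.out.println(result.values().stream().mapToLong(Long::longValue).sum());
    }
}
```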

Functional programming really takes off when you write in a language that’s designed for it. So here’s what a sudoku-board validator looks like in Scala. I tried to make it beautiful, but I’m sure there’s room for improvement. Even if you don’t know Scala, I challenge you to read and understand it, or at least the gist. Small functions with good names help it flow. The only real arithmetic is in blockOf, which maps block b and position i (each 0 to 8) to row (b / 3) * 3 + i / 3 and column (b % 3) * 3 + i % 3, numbering the blocks left to right, top to bottom. Everything up to isValid is setup, and by the time we get there, the check itself seems almost too easy.

Can you make it clearer, faster, or prettier?

case class Index(r: Int, c: Int)

object SudokuValidator extends App {
  val valid = (1 to 9).toSet
  val indexes = (0 to 8).toSet

  def rowOf(r: Int) = indexes.map(Index(r, _))
  def columnOf(c: Int) = indexes.map(Index(_, c))
  def blockOf(b: Int) = indexes.map(i => Index((b / 3) * 3 + i / 3, (b % 3) * 3 + i % 3))

  val allGroups = indexes.map(rowOf) ++ indexes.map(columnOf) ++ indexes.map(blockOf)

  def isValid(board: Seq[Seq[Int]]) = {
    def groupOf(group: Set[Index]) = group.map(i => board(i.r)(i.c))

    allGroups.map(groupOf).forall(valid.equals)
  }

  val badBoard = Seq(
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9),
    Seq(1, 2, 3, 4, 5, 6, 7, 8, 9))
  val goodBoard = Seq(
    Seq(9, 5, 3, 2, 1, 4, 7, 6, 8),
    Seq(2, 7, 6, 8, 5, 3, 4, 1, 9),
    Seq(8, 1, 4, 6, 7, 9, 2, 3, 5),
    Seq(7, 4, 8, 5, 3, 1, 6, 9, 2),
    Seq(6, 9, 1, 7, 4, 2, 5, 8, 3),
    Seq(5, 3, 2, 9, 6, 8, 1, 7, 4),
    Seq(1, 6, 9, 4, 8, 5, 3, 2, 7),
    Seq(3, 2, 5, 1, 9, 7, 8, 4, 6),
    Seq(4, 8, 7, 3, 2, 6, 9, 5, 1))

  val results = Seq(badBoard, goodBoard).map(isValid)

  println(results)
}

We’ll talk about a Java version of this in another post. Until next time, have fun and stay functional.

Originally published at ruetech.io on March 9, 2016.
