Deep Learning On Demand with Spark, Akka, and GraphQL
Machine learning can be an opaque undertaking. As algorithms grow more and more complex, we need specialized tools to answer questions like, "Why did the computer think this was spam?" or "Why did your service recommend this movie to me?" In my last post, I wrote about a model that was elegantly straightforward: the information content of an item is the negative logarithm of its relative frequency. Newer models, however, are far less transparent: most famously, Google's DeepDream pattern-recognition software produces phantasmagoric images that are at once completely fascinating and entirely unfathomable.
In this article, I'll start by creating a Spark job to train a neural net for some straightforward text analysis. But to explore and analyse the model, we'll need dynamic and interactive query abilities. For that purpose, we're going to use GraphQL to build a powerful, self-documenting query language for our data, and put it up on an Akka HTTP server. GraphQL is pretty far out on the cutting edge, but as you'll see, the largely in-memory nature of our data makes things a lot more streamlined than a database-backed system might be. And in a follow-up post, we'll look at how we can use GraphQL, React, and Relay to build powerful UIs and visualizations based on the architecture described here. But first, we need to go back to the data, and the algorithm.
Oh, and if you want the code, it's available in a GitHub repo, of course.
What is a word embedding?
A word embedding is a numerical representation of the meaning of a word in a low-dimensional vector space, usually in the range of 50 to 300 dimensions. This sort of "distributional representation" has been a staple of machine learning for decades, but there's been a renaissance since the publication of Tomas Mikolov's Word2Vec algorithm in 2013.
Whereas older models like latent semantic analysis (LSA) used "bag of words" representations of word frequency at a document level, Word2Vec uses a neural network to learn a representation of words based on their usage and context at the sentence level. These vector representations serve as the fundamental building blocks of algorithms for sentiment analysis, entity recognition, machine translation, and even image classification.
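To make "context at the sentence level" concrete: Spark's Word2Vec implementation uses the skip-gram model, in which the network is trained to predict the words surrounding each target word within a fixed window. The sketch below only illustrates that idea in plain Scala; it is not Spark's internal code, and the SkipGram.pairs name and the window parameter are assumptions for illustration. It lists the (target, context) pairs that a single tokenized sentence contributes to training.

```scala
// Illustrative sketch of skip-gram training pairs; not Spark's internal implementation.
object SkipGram {
  /** Every (target, context) pair where the context word lies within `window` positions of the target. */
  def pairs(sentence: Seq[String], window: Int): Seq[(String, String)] =
    for {
      (target, i) <- sentence.zipWithIndex
      // Clamp the window to the sentence boundaries.
      j <- math.max(0, i - window) to math.min(sentence.length - 1, i + window)
      if j != i // a word is not its own context
    } yield (target, sentence(j))
}
```

For "the cat sat" with a window of 1, this yields (the, cat), (cat, the), (cat, sat), and (sat, cat). Words that keep showing up in similar contexts end up with similar vectors, which is what makes nearest-neighbor queries over the learned vectors meaningful.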
Implementation details aside, Word2Vec is a remarkable algorithm. Unlike the algorithms used for other deep learning tasks such as image recognition, Word2Vec can be trained efficiently on mainstream hardware rather than on specialized GPUs or Google's custom tensor processing hardware. And it produces surprisingly good results on relatively small datasets.
Training our model
For this blog post, I'm going to use the celebrated Brown Corpus, which consists of about 1 million words of American English drawn from texts published in 1961, ranging from newspaper articles to fiction and academic writing. Since it's already split up into a single sentence per line, it's perfect for our model. You can download it here if you're following along with the code.
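Since we're using Spark ML's high-level interface, the actual training is straightforward. Given a DataFrame with a column of sentences as Array[String], it's as easy as configuring a Word2Vec estimator with that input column and a vector size, then calling fit on the DataFrame.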
And if you look at the SparkJob class, you'll see that most of the work is in reading and tokenizing the text, or in extracting and manipulating the vectors of the model afterwards.
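The SparkJob's own tokenizer isn't reproduced here, so the following is a hypothetical stand-in showing the kind of per-line tokenization the training DataFrame needs: lowercase each sentence and split it into an Array[String] of word tokens. The Tokenizer.tokenize name and the regular expression (keeping only ASCII letters and apostrophes) are assumptions; a real tokenizer for the Brown Corpus may need to treat hyphens, numbers, or part-of-speech tags differently.

```scala
// Hypothetical tokenizer sketch; not the SparkJob's actual implementation.
object Tokenizer {
  /** Lowercase a sentence and split it on any run of characters that are not letters or apostrophes. */
  def tokenize(line: String): Array[String] =
    line.toLowerCase
      .split("[^a-z']+")
      .filter(_.nonEmpty) // drop the empty token produced by leading punctuation
}
```

Inside a Spark job, a function like this could be mapped over a Dataset[String] of lines to produce the Array[String] column the estimator expects; Spark ML's RegexTokenizer offers a comparable transformation as a pipeline stage.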
Building our schema
Now that we have a model, we can build a query schema to interact with it. GraphQL uses the same schema language to define queries and their resulting data types, which turns out to be very powerful: it enables sophisticated recursive and asynchronous programming techniques that support complex nested queries. The big win for developers is that query parsing and execution are common infrastructure; I'll be using Sangria, an excellent and well-documented Scala implementation of the specification.
In Sangria, each ObjectType declaration has a fields definition, which is parameterized with Ctx (context) and Value type parameters. Value is easy: it's whatever concrete value you're returning. Designing the global Ctx is more subtle. The context has to carry every global resource needed to evaluate any query, which means that remote services, data access, and other complexities all get encapsulated within the Ctx. For our purposes, I've designed a simple case class named ModelRepo, which holds the following:
words holds the 100-dimension vector for each word in the model. dimensions is a list of all 100 dimensions, and for each one, we have a ranked list of all the words and their weights in that dimension. Finally, we've attached the model itself and the SparkSession, which we'll use to fire off new computations for things like finding neighbors and clusters. With that in place, implementing our schema is mostly straightforward: we define all the fields of our type, with resolve callbacks for each one. Some are trivial attributes of the Value, whereas others are computed.
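The dimensions structure described above can be derived directly from the word vectors. Here is a minimal sketch, assuming (hypothetically) that the vectors are available as a plain Map[String, Array[Double]] rather than whatever representation ModelRepo actually uses; the Dimensions.rank name is also illustrative. For each dimension index, it pairs every word with its weight in that dimension and sorts the list so the highest-weighted words come first.

```scala
// Hypothetical sketch: derive a per-dimension ranking from a word -> vector map.
object Dimensions {
  /** For each dimension index, every word paired with its weight there, highest weight first. */
  def rank(words: Map[String, Array[Double]]): IndexedSeq[Seq[(String, Double)]] = {
    // Assumes every vector has the same length (100 in this post's model).
    val size = words.values.headOption.map(_.length).getOrElse(0)
    (0 until size).map { d =>
      words.toSeq
        .map { case (word, vector) => (word, vector(d)) }
        .sortWith((x, y) => x._2 > y._2) // descending by weight
    }
  }
}
```

Computing this ranking once when the model is loaded means that asking for the top words in a dimension is just an index lookup followed by a take, with no new Spark computation required.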
In the schema, the vector fields just return the raw data, and vectorSize returns the length of the vector. synonyms is different: it takes the number of synonyms to fetch as an argument, launches a task in the SparkSession to find the synonyms of that word, and returns the results to the user.
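Spark ML's Word2VecModel provides findSynonyms(word, num), which returns a DataFrame of the num words closest to the given word (excluding the word itself), ranked by cosine similarity between their vectors. The following plain-Scala sketch illustrates that ranking over an in-memory Map[String, Array[Double]]; it is not the post's resolver, and the Synonyms object and its method names are hypothetical.

```scala
// Hypothetical sketch: brute-force cosine-similarity synonyms over in-memory vectors.
object Synonyms {
  /** Cosine similarity: dot product divided by the product of the vector lengths. */
  def cosine(a: Array[Double], b: Array[Double]): Double = {
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    val norms = math.sqrt(a.map(x => x * x).sum) * math.sqrt(b.map(x => x * x).sum)
    if (norms == 0.0) 0.0 else dot / norms // treat a zero vector as unrelated to everything
  }

  /** The `num` words most similar to `word` (excluding the word itself), most similar first. */
  def find(words: Map[String, Array[Double]], word: String, num: Int): Seq[(String, Double)] =
    words.get(word) match {
      case None => Seq.empty // unknown word: no synonyms
      case Some(target) =>
        words.toSeq
          .collect { case (other, vector) if other != word => (other, cosine(target, vector)) }
          .sortWith((x, y) => x._2 > y._2)
          .take(num)
    }
}
```

This is a brute-force scan that scores every word in the vocabulary against the target, which is simple and adequate for a vocabulary the size of the Brown Corpus.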
I won't go through the whole SchemaDef line by line, but we define all of our types along with their documentation, including references to other types, and eventually build up a whole query language for retrieving word vectors and dimensions by name, by number, or in bulk. Once we've defined the schema, we're done! Sangria takes care of parsing queries for us, and executing them with the schema and resolvers we've provided. Compared to a typical CRUD-oriented REST API, this can seem highly formalized, but it really pays off when you want to specify complex nested queries and the like: rather than having to write new request handlers for every query you might ever want, the framework handles the query structure for you, in a way that feels almost magical.
Putting it all together
Once we have our Schema, we're ready to put it on the network. For this, I'm going to use Akka HTTP, which is my preferred tool for building embedded web servers in my Scala programs. Although Spark no longer uses Akka internally, Akka HTTP is still straightforward to embed in a Spark job's driver process, which makes it trivial to launch a web server after completing some distributed computation. In our Main.scala, we do just that, but for convenience we also have entry points that will compute a model and write it to a file, or read a model file directly into a SparkSession without the parsing and fitting. The Server class itself is adapted from Oleg Ilyenko's sangria-akka-http-example project, with an attempt at abstracting over the
Once our server is up and running, we're good to go. We've included the wonderful GraphiQL tool along with the program, so if you're following along, you can navigate to http://localhost:8080/graphiql.html and start kicking the tires. GraphiQL uses GraphQL's schema introspection capabilities to build an intelligent, general-purpose query UI with extremely nice syntax assistance and autocompletion. It's a joy to use after a decade or so of curl and Postman. We're going to query for the 10 closest synonyms of "government" according to our model, and look at the results.
Anecdotally, most machine learning researchers I've spoken to report a common experience: initially skeptical of deep learning as an over-hyped novelty, they run their data through it and are immediately surprised at the quality of the output. For me, this was that kind of moment: these are unusually good results for a data set this small. Informally, I've noticed that the kind of big, abstract nouns that are common in newspaper articles, such as "education" and "business", are well represented, whereas more specific or concrete terms get more mixed results.
But a detailed exploration of the model will have to wait until the next post in this series, where we'll build an interactive front-end for exploring it with React and Relay, and talk about what a full-stack application looks like with this architecture.