May 14, 2024

TorchCat v0.1

tl;dr

New web application for visualizing PyTorch and TensorFlow operations.

https://torchcat.com

A little PyTorch quiz

Consider the following example:

  • Let 'a' be a vector of length 8 containing the integers 0 through 7.
  • 'a_wide' is a 2x4 matrix obtained by reshaping 'a' with reshape().
  • 'a_tall' is a 4x2 reshape of 'a'.

Equivalently, in code:

import torch

a = torch.arange(8) # [0,1,2,3,4,5,6,7]
a_wide = a.reshape(2,4) # 2 rows, 4 columns
a_tall = a.reshape(4,2) # 4 rows, 2 columns

Which of the following operations would make 'a_prime' equal to 'a_wide', so that the following statement is True?

torch.equal(a_wide, a_prime)
(a) a_prime = a_tall.transpose(1,0)
(b) a_prime = a_tall.rot90() # rotate 90 degs counterclockwise
(c) a_prime = a_tall.rot90(-1) # rotate 90 degs clockwise

If you answered (d) none of the above, you are correct.

The key lies in how reshape(2,4) and reshape(4,2) lay out the elements. Both read the eight values in the same row-major order and differ only in where they cut the sequence into rows.

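  • reshape(2,4) makes a single cut down the middle, resulting in [[0,1,2,3],[4,5,6,7]].
  • reshape(4,2) cuts the sequence into four consecutive pairs: [[0,1],[2,3],[4,5],[6,7]]. Read column by column, the elements interleave like a zipper, with the evens in one column and the odds in the other.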

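Transposing or rotating a_tall only moves whole rows and columns around (for example, a_tall.transpose(1,0) gives [[0,2,4,6],[1,3,5,7]]), so neither can reproduce a_wide. Reshaping again, as in a_tall.reshape(2,4), does.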
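The quiz can be checked by hand with a small pure-Python sketch. It does not use torch. The helpers reshape, transpose, rot90_ccw and rot90_cw are hypothetical stand-ins written for this illustration, and they assume the row-major layout described above.

```python
# Pure-Python sketch of the quiz; helper names are hypothetical, not torch APIs.

def reshape(flat, rows, cols):
    """Cut a flat list into `rows` consecutive chunks of length `cols` (row-major)."""
    assert len(flat) == rows * cols
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]

def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]

def rot90_ccw(m):
    """Rotate 90 degrees counterclockwise: transpose, then reverse the row order."""
    return transpose(m)[::-1]

def rot90_cw(m):
    """Rotate 90 degrees clockwise: reverse the row order, then transpose."""
    return transpose(m[::-1])

a = list(range(8))
a_wide = reshape(a, 2, 4)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
a_tall = reshape(a, 4, 2)   # [[0, 1], [2, 3], [4, 5], [6, 7]]

candidates = {
    "(a) transpose": transpose(a_tall),   # [[0, 2, 4, 6], [1, 3, 5, 7]]
    "(b) rot90 ccw": rot90_ccw(a_tall),   # [[1, 3, 5, 7], [0, 2, 4, 6]]
    "(c) rot90 cw": rot90_cw(a_tall),     # [[6, 4, 2, 0], [7, 5, 3, 1]]
}
for name, result in candidates.items():
    print(name, result, result == a_wide)  # the last column is False every time

# Flattening a_tall and cutting it again is what recovers a_wide.
flat_again = [value for row in a_tall for value in row]
recovered = reshape(flat_again, 2, 4)
print(recovered == a_wide)  # True
```

All three candidates have the right 2x4 shape but the wrong element order. Only the second reshape gets back to a_wide.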

Why working with tensors can be tricky

This nuance of torch.reshape() may be trivial to experienced practitioners, but it has caused (and still causes) me a lot of frustration. The documentation is often vague, and some interactions can only be learned through trial and error. The above example is particularly insidious because a_tall.t() has the same shape as a_wide. No shape-mismatch error is raised, so the wrong values pass silently into downstream code and only surface later at runtime.
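To make the silent failure concrete, here is a minimal pure-Python sketch. The helpers shape and matvec and the weight vector w are made up for this illustration. Both matrices pass the same shape check and both multiply without error, yet the results differ.

```python
# Hypothetical helpers for illustration; not part of torch.

def shape(m):
    """Return (rows, cols) of a 2D nested list."""
    return (len(m), len(m[0]))

def matvec(m, v):
    """Matrix-vector product that only validates shapes, as most libraries do."""
    assert len(m[0]) == len(v), "shape mismatch"
    return [sum(x * y for x, y in zip(row, v)) for row in m]

a_wide = [[0, 1, 2, 3], [4, 5, 6, 7]]     # a.reshape(2,4)
a_tall_t = [[0, 2, 4, 6], [1, 3, 5, 7]]   # a.reshape(4,2).t()
w = [1, 10, 100, 1000]                    # arbitrary weights, chosen for readability

print(shape(a_wide) == shape(a_tall_t))   # True: the shape check cannot tell them apart

expected = matvec(a_wide, w)    # [3210, 7654]
actual = matvec(a_tall_t, w)    # [6420, 7531]
print(expected, actual)

# Comparing values (what torch.equal does for tensors) is what exposes the mistake.
print(a_wide == a_tall_t)       # False
```

A value-level assertion right after a reshape or transpose is cheap, and it catches this class of bug where a shape assertion does not.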

There is a gap between how we conceptualize matrices and how we implement them in code. As learners we are introduced to matrices visually, yet the underlying data structure of n-dimensional arrays like [[1,2],[3,4]] does not readily map to our mental model. This problem is exacerbated when working across multiple libraries, as one discovers that even methods with the same name may work differently under the hood. [1]

Introducing TorchCat

That's why I made TorchCat. It's a web-based tool for visualizing and animating various tensor operations in real time. Using TorchCat, the user can observe the movement of individual elements that is abstracted away behind each function call. Some operations are just mesmerizing to gaze at, like the changing formations of migratory birds.

TorchCat's tech is simple. It's a single-page React application built on Vite. I used TensorFlow.js to implement the tensor data structure that runs in the browser. Framer Motion handles the animation of the SVG elements. As PyTorch does not have a dedicated JavaScript backend, I referenced nobuco to implement its functions in TensorFlow.js.

Being client-side only has two advantages. First, it costs me practically nothing, since I'm just serving a static page. That lets me keep TorchCat free and ad-free, even in the unlikely event of high traffic. The second benefit is the smooth user experience. The tensors glide about almost instantaneously, making the app a convenient resource to have on a second screen or a phone.

On the other hand, the current implementation cannot guarantee that the output from TorchCat matches the output of the corresponding operation in a Python runtime. The app works well in most scenarios, but I encourage you to double-check the output if it matters for an important project. I see several possible sources of discrepancy:

  • Implementation inconsistencies between TensorFlow Python and TensorFlow.js.
  • Possible errors stemming from differences in data types and low-level execution between the Python and JavaScript environments.
  • Human errors when porting functions from other libraries (e.g., PyTorch) to TensorFlow.
  • Potential differences in how the web UI presents the data compared to the actual underlying data structure.
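As one possible instance of the data-type point, consider a value stored as a 64-bit integer on the Python side and as a 32-bit float on the browser side. That pairing is an assumption made for illustration, not a statement about TorchCat's internals. The sketch below uses only the standard library, and to_float32 is a hypothetical helper.

```python
import struct

def to_float32(x):
    """Round a Python float (64-bit) to the nearest 32-bit float and back."""
    return struct.unpack("f", struct.pack("f", x))[0]

small = 7                       # small integers survive the round trip
big = 2**24 + 1                 # 16777217 needs more precision than float32 offers

small_f32 = to_float32(float(small))
big_f32 = to_float32(float(big))

print(small, small_f32, small == small_f32)   # 7 7.0 True
print(big, big_f32, big == big_f32)           # 16777217 16777216.0 False
```

For the small integers typical of a visualization this never shows up. It only matters once values exceed 2**24 or carry fractional parts.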

TO-DO

Output consistency is only one of the many issues I'd like to address in the future. Below is the list of features that I plan to implement:

  1. Visualization option for 4+ dimensions. I had an interesting discussion with a colleague on using a slider for the 4th dimension, allowing the user to step through each 3D tensor along that axis. This falls in line with the intuition of conceptualizing "time" as the 4th dimension. It's also appropriate for the most common use case of 4D tensors in image tasks, where one dimension is reserved for batches. For tensors with 5 or more dimensions, I will probably use a dropdown.

  2. Server-side computation. A Django server that runs the same function (or sequence of functions) in Python and compares it to the displayed output for validation. It may also be used to handle features and larger computations that are not feasible client-side.

  3. Code conversion. Server-side validation should be preceded by a module to convert the function calls in the browser to Python code. The inverse may also be useful as a UI option, allowing users to insert a code snippet for TorchCat to run.

  4. Functions with two or more tensor inputs. The site's namesake torch.cat() has not yet been implemented. Work should primarily be on the UI, since the app can track multiple tensors already. Ideally, I'd like the output of a concatenation or multiplication to retain its provenance, so that each element maps to its origin in the input tensors. I believe this is doable for 2 tensors. Design would depend on actual use cases.
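One way the provenance idea could work for concatenation is sketched below in pure Python. This is a hypothetical design, not TorchCat's implementation. The helpers tag and cat are invented for this illustration, with cat mimicking torch.cat() on 2D inputs only.

```python
# Hypothetical provenance tracking for a 2D concatenation.

def tag(name, m):
    """Pair every value with (tensor name, row, col) so its origin survives later operations."""
    return [[(value, (name, r, c)) for c, value in enumerate(row)]
            for r, row in enumerate(m)]

def cat(x, y, dim=0):
    """Concatenate two 2D nested lists along dim 0 (rows) or dim 1 (columns)."""
    if dim == 0:
        assert len(x[0]) == len(y[0]), "column counts must match"
        return x + y
    assert len(x) == len(y), "row counts must match"
    return [row_x + row_y for row_x, row_y in zip(x, y)]

x = tag("x", [[0, 1], [2, 3]])
y = tag("y", [[4, 5], [6, 7]])

out = cat(x, y, dim=1)
values = [[value for value, _ in row] for row in out]
print(values)          # [[0, 1, 4, 5], [2, 3, 6, 7]]

value, origin = out[1][2]
print(value, origin)   # 6 ('y', 1, 0)
```

Because the origin travels with each element, a UI could highlight the source cell when the user hovers over an output cell. Multiplication would need a different scheme, since each output element there has several origins.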

  5. Support for more functions, more libraries, and more versions.

  6. Open-sourcing. Given that TorchCat is a hobby project, updates are contingent on my spare time. If there's sufficient interest and demand for more frequent updates, it'd be good to clean up the code and make the repo public.

This post will be kept up to date with new developments. If you're reading this post, thank you!

I welcome any feedback or ideas for improvement. Please contact me if you have any suggestions.

Footnotes

  1. This Stack Overflow post on torch.gather() is a good example. Additionally, even functions that serve a similar purpose, like torch.Tensor.repeat() and tf.tile(), vary slightly in how they assign default values to parameters.