Learning Bug Context for PyTorch-to-JAX Translation with LLMs

Abstract

Large language models (LLMs) have shown strong performance on code translation between widely used programming languages. However, translation becomes much less reliable for domain-specific code, where correctness depends on framework-specific APIs and execution semantics. One example is translating deep-learning code from PyTorch to JAX, where LLM outputs often contain subtle bugs or non-idiomatic usage that prevent execution or change behavior. Prior work suggests that curated bug-fix data from LLM-generated code can help improve code generation quality, but such resources are still limited for PyTorch-to-JAX translation. In this work, we introduce T2J, a benchmark of LLM translation bugs paired with developer-written fixes for PyTorch-to-JAX code. We start from 20 kernels in the TorchLeet dataset, translate them to JAX using the weak LLM gpt-4o-mini, and hire software developers to debug and repair the generated JAX implementations. We then use T2J to improve PyTorch-to-JAX translation for the same model via in-context learning. Our evaluation shows that using T2J improves our proposed metric, T2J-CodeTrans-Score, by up to 20%.
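To make the in-context learning idea concrete, the following is a minimal sketch of how bug-fix pairs might be placed in a prompt ahead of a new translation task. It is an illustration under stated assumptions, not the T2J pipeline: the `BugFixExample` record, its fields, the `build_prompt` helper, and the prompt wording are all hypothetical. The one sample pair reflects a well-known PyTorch-to-JAX pitfall, namely that JAX arrays are immutable, so in-place assignment must become a functional update with `.at[...].set(...)`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BugFixExample:
    """Hypothetical record: one translation bug paired with its fix."""
    buggy_jax: str
    fixed_jax: str
    note: str


def build_prompt(examples, pytorch_source):
    """Assemble a few-shot prompt: bug-fix demonstrations, then the task."""
    # Assumed instruction wording; the paper's actual prompt is not given.
    parts = [
        "Translate the PyTorch code to JAX. "
        "Avoid the mistakes shown in these examples."
    ]
    for i, ex in enumerate(examples, start=1):
        parts.append(
            f"Example {i} ({ex.note})\n"
            f"Buggy JAX:\n{ex.buggy_jax}\n"
            f"Fixed JAX:\n{ex.fixed_jax}"
        )
    parts.append(f"PyTorch code to translate:\n{pytorch_source}")
    return "\n\n".join(parts)


# Illustrative pair only; not taken from the T2J benchmark.
EXAMPLES = [
    BugFixExample(
        buggy_jax="x[0] = 1.0",
        fixed_jax="x = x.at[0].set(1.0)",
        note="JAX arrays are immutable; use functional updates",
    ),
]

PROMPT = build_prompt(EXAMPLES, "y = torch.relu(x)")
print(PROMPT)
```

The resulting string would be sent to the model as the translation request. Keeping the demonstrations separate from the task text makes it straightforward to vary how many pairs are included.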
