Rax is a library for composable Learning-to-Rank (LTR) written entirely in JAX. The goal of Rax is to facilitate easy prototyping of LTR systems by leveraging the flexibility and simplicity of JAX. Rax provides a diverse set of popular ranking metrics and losses that integrate well with the rest of the JAX ecosystem. Furthermore, Rax implements a system of ranking-specific function transformations that allows fine-grained customization of ranking losses and metrics. Most notably, Rax provides approx-t12n: a function transformation (t12n) that can transform any of our ranking metrics into an approximate and differentiable form that can be optimized. This provides a systematic way to directly optimize neural ranking models for ranking metrics that are not easily optimizable in other libraries. We empirically demonstrate the effectiveness of Rax by benchmarking neural models implemented using Flax and trained using Rax on two popular LTR benchmarks: WEB30K and Istella. Furthermore, we show that integrating ranking losses with T5, a large language model, can improve overall ranking performance on the MS MARCO passage ranking task. We are sharing the Rax library with the open-source community as part of the larger JAX ecosystem at https://github.com/google/rax.
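The idea behind turning a ranking metric into an approximate, differentiable form can be illustrated with a minimal sketch in plain JAX. This is not Rax's API or implementation: the function names (`exact_ranks`, `approx_ranks`, `dcg`, `ndcg`, `approx_ndcg_loss`), the sigmoid-based rank approximation, and the toy scores and labels are all assumptions chosen for illustration. The sketch replaces the discrete ranks inside NDCG with a smooth surrogate so that `jax.grad` yields a useful gradient with respect to the scores.

```python
import jax
import jax.numpy as jnp


def exact_ranks(scores):
    # 1-based rank of each item when items are sorted by descending score.
    return jnp.argsort(jnp.argsort(-scores)) + 1.0


def approx_ranks(scores, temperature=1.0):
    # Hypothetical smooth surrogate for the rank of item i:
    #   1 + sum over j != i of sigmoid((s_j - s_i) / temperature)
    diffs = (scores[None, :] - scores[:, None]) / temperature
    pairwise = jax.nn.sigmoid(diffs)
    # The diagonal contributes sigmoid(0) = 0.5 per row, so subtract it.
    return 1.0 + jnp.sum(pairwise, axis=1) - 0.5


def dcg(ranks, labels):
    # Discounted cumulative gain with exponential gains and log2 discounts.
    gains = 2.0 ** labels - 1.0
    discounts = 1.0 / jnp.log2(ranks + 1.0)
    return jnp.sum(gains * discounts)


def ndcg(scores, labels, rank_fn=exact_ranks):
    # Assumes at least one label is positive, so the ideal DCG is nonzero.
    ideal = dcg(exact_ranks(labels), labels)
    return dcg(rank_fn(scores), labels) / ideal


def approx_ndcg_loss(scores, labels):
    # Negated approximate NDCG: minimizing this loss maximizes the metric.
    return -ndcg(scores, labels, rank_fn=approx_ranks)


# Toy example: three items, only the first one is relevant.
scores = jnp.array([2.2, 1.3, 5.4])
labels = jnp.array([1.0, 0.0, 0.0])

metric_value = ndcg(scores, labels)                # exact, non-differentiable ranks
grads = jax.grad(approx_ndcg_loss)(scores, labels) # gradient through smooth ranks
```

In this sketch, the gradient for the relevant item is negative (raising its score lowers the loss), while the gradients for the two non-relevant items are positive. Lowering `temperature` brings the smooth ranks closer to the exact ranks at the cost of flatter gradients.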
CITATION STYLE
Jagerman, R., Wang, X., Zhuang, H., Qin, Z., Bendersky, M., & Najork, M. (2022). Rax: Composable Learning-to-Rank Using JAX. In Proceedings of the ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (pp. 3051–3060). Association for Computing Machinery. https://doi.org/10.1145/3534678.3539065