JAX-FEM | Advancing Development Towards Large-Scale Problem Solving

"The integration of machine learning and physical modeling is revolutionizing the paradigm of scientific research. People aiming to push the boundaries of science and solve challenging problems through computational modeling are coming together in unprecedented ways." Recently, the DeepModeling open-source community has welcomed a new member in the field of macro-scale computation. To further advance the development of the JAX-FEM project, a differentiable finite element method library, JAX-FEM will join the DeepModeling community. Together with developers and users in the community, it aims to expand the frontiers of finite element methods in the AI4Science era.

Community project homepage:
https://github.com/deepmodeling/jax-fem

Finite Element Method (FEM):

FEM is a representative physical modeling technique widely applied in engineering manufacturing and industrial design. It divides the entire problem domain into subdomains, defines basis functions over each subdomain, and combines the finite elements to approximate the solution of the original equations. FEM is effectively used in structural analysis, fluid dynamics, heat transfer, and electromagnetic problems for solving partial differential equations numerically.

Challenges in Conventional FEM:

Despite its widespread application, FEM faces challenges as problem scales grow. These challenges include increased computational complexity and memory requirements, leading to prolonged computational times and excessive resource consumption. Additionally, the user-friendliness of interfaces (e.g., explicit derivation of linear incremental forms for solving nonlinear problems) and the extensibility of code interfaces (e.g., integration with data-driven models) in many commercial and open-source FEM software are limited.

What is JAX-FEM?

JAX-FEM is designed to overcome these limitations. By utilizing Google’s JAX library—a high-performance numerical computing and machine learning toolkit—the developers have created an extensible FEM solution written entirely in Python. This enables efficient solving of nonlinear finite element problems.


Key Features and Advantages of JAX-FEM

  1. Interactivity and Extensibility:
    Unlike most existing FEM libraries implemented in low-level languages like C/C++ or Fortran, JAX-FEM is developed in pure Python. This significantly enhances development flexibility and ease of use, providing a streamlined experience for both users and developers.

  2. GPU Acceleration:
    JAX-FEM can leverage GPUs for accelerated computation, which is particularly useful for problems with a large number of degrees of freedom (DOF). Supported by JAX’s XLA backend for linear algebra acceleration, JAX-FEM achieves competitive performance, especially when GPUs are available. For example, in a 3D tensile loading problem with 7.7 million DOF, JAX-FEM’s GPU-based computation is approximately 10 times faster than commercial FEM software.

  3. Automatic Differentiation (AD):
    JAX-FEM incorporates AD technology, eliminating the need for manual derivation of sensitivity or Jacobian matrices. This simplifies solving nonlinear and inverse problems, automating complex mathematical operations. It is particularly useful in areas like engineering optimization and parameter identification.

  1. Integration with Machine Learning:
    Beyond being a finite element solver, JAX-FEM effectively integrates machine learning capabilities. Users can seamlessly train and deploy models on a single platform, enabling data-driven computational mechanics research.

  2. Open Source and Community Support:
    As an open-source project, JAX-FEM encourages contributions and collaboration. Users can freely access the source code, customize it according to their needs, and share improvements and applications with others.


Joining and Using JAX-FEM

JAX-FEM has joined the DeepModeling open-source community. The project is hosted at:
https://github.com/deepmodeling/jax-fem


Future Development Plans

The upcoming development of JAX-FEM will revolve around solving ultra-large-scale problems using CPU/GPU hybrid parallelization and integrating third-party linear solvers to maximize computational efficiency. The project will also explore machine-learning-based multiscale computational problems, organically combining FEM as a macro-scale tool with other micro-scale tools.

In summary, JAX-FEM combines modern computing techniques and automatic differentiation, offering a powerful and efficient tool for automating finite element analysis workflows, particularly in scenarios involving complex physical phenomena and material behaviors.

If you are interested in contributing to the development of differentiable FEM algorithms and frameworks, the development team warmly invites you to connect with them.


Reference:
Xue, T., Liao, S., Gan, Z., Park, C., Xie, X., Liu, W. K., & Cao, J. (2023). JAX-FEM: A differentiable GPU-accelerated 3D finite element solver for automatic inverse design and mechanistic data science. Computer Physics Communications, 108802.