diff --git a/.gitignore b/.gitignore index 11be381..cac3b3f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,19 +1,13 @@ *.csv *.gif -events.out.* +*.png *.zip -evaluations.npz -**/nuScenes_data/raw/* -**/nuScenes_data/cached_data/* -plots/megvii/* -!plots/megvii/*.py -!plots/megvii/*.ipynb -!plots/megvii/nuScenes-devkit-mods wandb/ -plots/nuScenes/* -!plots/nuScenes/*.ipynb -!plots/nuScenes/*.txt +*.pkl +*.pickle + + # Byte-compiled / optimized / DLL files __pycache__/ @@ -164,4 +158,6 @@ cython_debug/ # Mac OSX .DS_Store -._.DS_Store \ No newline at end of file +._.DS_Store + +settings.json \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 8be0f64..78846e8 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,72 +1,40 @@ { - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", "configurations": [ { - "name": "Train new DiffStack model from trajdata (recommended)", + "name": "PL train CTT", "type": "python", "request": "launch", - //"module": "torch.distributed.run", - "program": "./diffstack/train.py", + "program": "diffstack/scripts/train_pl.py", "console": "integratedTerminal", - "env": {"PYDEVD_WARN_EVALUATION_TIMEOUT": "15"}, + "justMyCode": true, "args": [ - "--conf=./config/diffstack_default.json", - "--data_loc_dict={\"nusc_mini\": \"~/data/nuscenes_raw_annot\"}", - "--train_data=nusc_mini-mini_val", - "--eval_data=nusc_mini-mini_val", - "--predictor=tpp", - "--plan_cost=corl_default_angle_fix", - "--plan_loss_scaler=100", - "--plan_loss_scaler2=10", - "--device=cuda:0", - "--debug", - ], - }, + "--config_file=${workspaceFolder}/config/templates/CTTPredStack.json", + "--remove_exp_dir", + // "--debug", + "--dataset_path=", + ] + }, { - "name": "Load pretrained model, cached data.", + "name": "PL eval CTT", "type": "python", "request": "launch", - "module": "torch.distributed.run", + "program": "diffstack/scripts/train_pl.py", "console": "integratedTerminal", - "env": {"PYDEVD_WARN_EVALUATION_TIMEOUT": "15"}, + "justMyCode": true, "args": [ - "--nproc_per_node=1", - "./diffstack/train.py", - "--data_source=cache", - "--cached_data_dir=~/data/corl22_public", - "--train_data=nusc_mini-mini_train", - "--eval_data=nusc_mini-mini_val", - "--predictor=tpp_cache", - "--dynamic_edges=yes", - "--plan_cost=corl_default", - "--eval_batch_size=32", - "--load=~/data/corl22_public/diffstack_rl.pt", - "--debug", - ], - }, - - { - "name": "Train model on preprocessed cache data.", - "type": "python", - "request": "launch", - "module": "torch.distributed.run", - "console": "integratedTerminal", - "env": {"PYDEVD_WARN_EVALUATION_TIMEOUT": "15"}, - "args": [ - "--nproc_per_node=1", - "./diffstack/train.py", - "--data_source=cache", - "--cached_data_dir=~/data/corl22_public", - "--train_data=nusc_mini-mini_train", - "--eval_data=nusc_mini-mini_val", - "--predictor=tpp_cache", - "--dynamic_edges=yes", - "--plan_cost=corl_default", - "--device=cuda:0", - ], - }, + "--config_file=", + "--ckpt_path=", + "--test_data=", + "--test_data_root=", + "--evaluate", + "--log_image_frequency=10", + "--eval_output_dir=", + "--test_batch_size=16", + // "--log_all_image", + "--dataset_path=", + ] + }, ] } diff --git a/CITATION.cff b/CITATION.cff deleted file mode 100644 index 150a075..0000000 --- a/CITATION.cff +++ /dev/null @@ -1,19 +0,0 @@ -cff-version: 1.2.0 -message: "If you use this software, please cite it 
as follows."
-authors:
-- family-names: "Karkus"
-  given-names: "Peter"
-  orcid: "https://orcid.org/0000-0002-1474-9771"
-- family-names: "Ivanovic"
-  given-names: "Boris"
-  orcid: "https://orcid.org/0000-0002-8698-202X"
-- family-names: "Mannor"
-  given-names: "Shie"
-  orcid: "https://orcid.org/0000-0003-4439-7647"
-- family-names: "Pavone"
-  given-names: "Marco"
-  orcid: "https://orcid.org/0000-0002-0206-4337"
-title: "DiffStack: A Differentiable and Modular Control Stack for Autonomous Vehicles."
-version: 0.0.1
-date-released: 2022-12-02
-url: "https://github.com/NVlabs/diffstack"
\ No newline at end of file
diff --git a/README.md b/README.md
index b0aa89d..0ffeb6e 100644
--- a/README.md
+++ b/README.md
@@ -1,143 +1,146 @@
-[![NVIDIA Source Code License](https://img.shields.io/badge/license-NSCL-blue.svg)](https://github.com/NVlabs/diffstack/blob/main/LICENSE.txt)
-![Python 3.9](https://img.shields.io/badge/python-3.9-green.svg)
+# Differentiable Stack

-# DiffStack
+Implements the Categorical Traffic Transformer (CTT) in the diffstack environment.

-![drawing](diffstack_modules.png)
+Paper [pdf](https://arxiv.org/abs/2311.18307)

-This repository contains the code for [DiffStack: A Differentiable and Modular Control Stack for Autonomous Vehicles](https://openreview.net/forum?id=teEnA3L4aRe) a CoRL 2022 paper by Peter Karkus, Boris Ivanovic, Shie Mannor, Marco Pavone.
+## Setup

-DiffStack is comprised of differentiable modules for prediction, planning, and control.
-Importantly, this means that gradients can propagate backwards all the way from the final planning
-objective, allowing upstream predictions to be optimized with respect to downstream decision making.
+Clone the repo with the desired branch. Use `--recurse-submodules` to also clone the various submodules.

-**Disclaimer** this code is for research purpose only. This is only an alpha release, not product quality code. Expect some rough edges and sparse documentation.
+For trajdata, we need the `vectorize` branch; there are two options:

-**Credits:** the code is built on [Trajectron++](https://github.com/StanfordASL/Trajectron-plus-plus), [Differentiable MPC](https://github.com/locuslab/mpc.pytorch), [Unified Trajctory Data Loader](https://github.com/NVlabs/trajdata), and we utilize several other standard libraries.

-## Setup
+1. Clone from NVlabs and then apply a patch:

+```
+git clone --recurse-submodules --branch main git@github.com:NVlabs/trajdata.git;
+cd trajdata;
+git fetch origin
+git reset --hard 748b8b1
+git apply ../patches/trajdata_vectorize.patch
+cd ..
+```

-### Install
+2. Clone from a forked repo of trajdata:

-Create a conda or virtualenv environment and clone the repository
-```bash
-conda create -n diffstack python==3.9
-conda activate diffstack
-git clone https://github.com/NVlabs/diffstack
+```
+git clone --recurse-submodules --branch vectorize git@github.com:chenyx09/trajdata.git
```

-Install diffstack (locally) with pip
+Then add Pplan:

-```bash
-cd diffstack
-pip install -e ./
+```
+git clone --recurse-submodules git@github.com:NVlabs/spline-planner.git
```

-This single step is sufficient to install all dependencies. For active development, you may prefer to clone and install [Trajectron++](https://github.com/StanfordASL/Trajectron-plus-plus), [Differentiable MPC](https://github.com/locuslab/mpc.pytorch), and [Unified Trajctory Data Loader](https://github.com/NVlabs/trajdata) manually.
+You can also sync the submodules later using:
+```
+git submodule update --remote
+```
+### Install diffstack

-### Prepare data
+We will install diffstack in a conda environment.

-Download and setup the NuScenes dataset following [https://github.com/NVlabs/trajdata/blob/main/DATASETS.md](https://github.com/NVlabs/trajdata/blob/main/DATASETS.md)
+Create a `conda` environment for `diffstack`:

-The path to the dataset can be specified with the `--data_loc_dict` argument, the default is `--data_loc_dict={\"nusc_mini\": \"./data/nuscenes\"}`
+```
+conda create -n diffstack python=3.9
+conda activate diffstack
+```

-## Usage
+Next, install a PyTorch build compatible with your CUDA setup, following the [PyTorch website](https://pytorch.org/get-started/locally/).

-We currently support training and evaluation from two data sources.
-- Cached data source (`--data_source=cache`) corresponds to a custom preprocessed dataset that allows reproducing results in the paper.
-- Trajdata source (`--data_source=trajdata`) is an interface to the [Unified Trajectory Data Loader](https://github.com/NVlabs/trajdata) that supports various data sources including nuScenes, Lyft, etc. This is the recommended way to train new models.
-We provide a couple of example commands below. For more argument options use `python ./diffstack/train.py --help` or look at [./diffstack/argument_parser.py](diffstack/argument_parser.py).
+Install the required Python packages for diffstack:
+```
+pip install -r requirements.txt
+```

-### Training a new model
+Install the submodules manually (use `-e` for developer mode):
+```
+pip install -e ./trajdata
+pip install -e ./spline-planner
+```

-To train a new model we recommend using the `trajdata` data source and the following default arguments. See [Unified Trajectory Data Loader](https://github.com/NVlabs/trajdata) for setting up different datasets.
-An example for data-parallel training on the nuScenes mini dataset using 1 gpu. Remember to change the path to the dataset. Use the full training set instead of minival for meaningful results.
-```bash
-python -m torch.distributed.run --nproc_per_node=1 ./diffstack/train.py \
-    --data_loc_dict={\"nusc_mini\": \"./data/nuscenes\"} \
-    --train_data=nusc_mini-mini_train \
-    --eval_data=nusc_mini-mini_val \
-    --predictor=tpp \
-    --plan_cost=corl_default_angle_fix \
-    --plan_loss_scaler=100 \
-    --plan_loss_scaler2=10 \
-    --device=cuda:0
+These additional steps might be necessary:
```
+# need to reinstall pathos, gets replaced by multiprocessing install
+pip uninstall pathos -y
+pip install pathos==0.2.9

-### Reproducing published results

+# Fix opencv compatibility issue https://github.com/opencv/opencv-python/issues/591
+pip uninstall opencv-python opencv-python-headless -y
+pip install "opencv-python-headless==4.2.0.34"
+# pip install "opencv-python-headless==4.7.0.72" # for python 3.9

-To reporduce results you will need to download our preprocessed dataset ([nuScenes-minival](https://drive.google.com/drive/folders/1GTEZdywSautvUbjY8eBzXEoT6VkII3E1?usp=sharing), [nuScenes-full](https://drive.google.com/drive/folders/1Ho2d6L57q4H9UHP53AyhD46zY-Vf2xxN?usp=sharing)) and optionally the pretrained models ([checkpoints](https://drive.google.com/drive/folders/1beTZwXqs0mb9Z1Sb346F2xHx9oeueoFm?usp=sharing)). Please only use the dataset and models if you have agreed to the terms for non-commercial use on [https://www.nuscenes.org/nuscenes](https://www.nuscenes.org/nuscenes). The preprocessed dataset and pretrained models are under the [CC BY-NC-SA 4.0 licence](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+# Sometimes you need to reinstall matplotlib with the correct version
+pip install matplotlib==3.3.4

-Evaluate a pretrained DiffStack model with jointly trained prediction-planning-control modules on the nuscenes minival dataset. Remember to update the path to the downloaded data and models. Use the full validation set to reproduce results in the paper.
-
-```bash
-python ./diffstack/train.py \
-    --data_source=cache \
-    --cached_data_dir=./data/cached_data/ \
-    --data_loc_dict={\"nusc_mini\": \"./data/nuscenes\"} \
-    --train_data=nusc_mini-mini_train \
-    --eval_data=nusc_mini-mini_val \
-    --predictor=tpp_cache \
-    --dynamic_edges=yes \
-    --plan_cost=corl_default \
-    --train_epochs=0 \
-    --load=./data/pretrained_models/diffstack_model.pt
```

-Retrain a DiffStack model (on cpu) using our preprocessed dataset. Remember to update the path to the downloaded data. When training a new model you should expect semantically similar results as in the paper, but exact reproduction is not possible due to different random seeding of the public code base.
+## Key files and code structure

-```bash
-python ./diffstack/train.py \
-    --data_source=cache \
-    --cached_data_dir=./data/cached_data/ \
-    --data_loc_dict={\"nusc_mini\": \"./data/nuscenes\"} \
-    --train_data=nusc_mini-mini_train \
-    --eval_data=nusc_mini-mini_val \
-    --predictor=tpp_cache \
-    --dynamic_edges=yes \
-    --plan_cost=corl_default \
-    --plan_loss_scaler=100
-```
+Diffstack uses a config system similar to [TBSIM](https://github.com/NVlabs/traffic-behavior-simulation): config templates are first defined in Python inside the [diffstack/configs](/diffstack/configs/) folder. We separate the configs for [data](/diffstack/configs/trajdata_config.py), [training](/diffstack/configs/base.py), and [models](/diffstack/configs/algo_config.py).

-To use 1 gpu and distributed data parallel pipeline use:
+The training and evaluation scripts take a JSON file as config. Running [generate_config_templates.py](/diffstack/scripts/generate_config_templates.py) generates all the template JSON configs, stored in the [config/templates](/config/templates/) folder, from the default values in the Python config files.

-```bash
-python -m torch.distributed.run --nproc_per_node=1 ./diffstack/train.py \
-    --data_source=cache \
-    --cached_data_dir=./data/cached_data/ \
-    --data_loc_dict={\"nusc_mini\": \"./data/nuscenes\"} \
-    --train_data=nusc_mini-mini_train \
-    --eval_data=nusc_mini-mini_val \
-    --predictor=tpp_cache \
-    --dynamic_edges=yes \
-    --plan_cost=corl_default \
-    --plan_loss_scaler=100 \
-    --device=cuda:0
+The models are defined separately in the [models](/diffstack/models/) and [modules](/diffstack/modules/) folders: the former defines the model architecture, while the latter wraps the torch model in a unified format called a module, defined in [diffstack/modules/module.py](/diffstack/modules/module.py).
+
+Modules can be chained together to form [stacks](/diffstack/stacks/), which can be trained/evaluated as a whole. This codebase only includes CTT, so the only type of stack is a prediction stack.
+
+A stack is wrapped as a PyTorch Lightning model for training and evaluation; see [train_pl.py](/diffstack/scripts/train_pl.py) for details.
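+
+As a rough illustration of the config flow (a sketch, not the exact contents of `generate_config_templates.py`), a registered experiment config can be pulled from the registry, adjusted, and dumped to a JSON template that the training script consumes:
+
+```python
+from diffstack.configs.registry import get_registered_experiment_config
+
+# fetch a clone of the registered "CTTPredStack" ExperimentConfig
+cfg = get_registered_experiment_config("CTTPredStack")
+cfg.train.dataset_path = "/path/to/nuscenes"  # hypothetical dataset location
+print(cfg.stack.predictor.name)  # "CTT"
+# write out a JSON template that diffstack/scripts/train_pl.py can consume
+cfg.dump(filename="config/templates/CTTPredStack.json")
+```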
+
+The main CTT files to look at are the [model file](/diffstack/models/CTT.py) and the [module file](/diffstack/modules/predictors/CTT.py).
+
+We also include a rich collection of [utility functions](/diffstack/utils/); many are not used by CTT, but they help make the code base convenient to work with.
+
+## Data
+
+CTT uses [trajdata](https://github.com/NVlabs/trajdata) as its dataloader, so technically you can train with any dataset supported by trajdata. Given the vectorized map support it requires, we have tested CTT with WOMD, nuScenes, and nuPlan.
+
+
+## Training and eval
+
+The following examples use the nuScenes trainval dataset; you'll need to prepare the nuScenes dataset following the instructions in [trajdata](https://github.com/NVlabs/trajdata).
+
+Training script:
+
+```
+python diffstack/scripts/train_pl.py
+--config_file=/config/templates/CTTPredStack.json
+--remove_exp_dir
+--dataset_path=
```

-Train a standard stack that only trains for a prediction objective by setting `--plan_loss_scaler=0`:
+Eval script:

-```bash
-python -m torch.distributed.run --nproc_per_node=1 ./diffstack/train.py \
-    --data_source=cache \
-    --cached_data_dir=./data/cached_data/ \
-    --data_loc_dict={\"nusc_mini\": \"./data/nuscenes\"} \
-    --train_data=nusc_mini-mini_train \
-    --eval_data=nusc_mini-mini_val \
-    --predictor=tpp_cache \
-    --dynamic_edges=yes \
-    --plan_cost=corl_default \
-    --plan_loss_scaler=0
-    --device=cuda:0
```
+python diffstack/scripts/train_pl.py
+--evaluate
+--config_file=
+--ckpt_path=
+--test_data=
+--test_data_root=
+--log_image_frequency=10
+--eval_output_dir=
+--test_batch_size=16
+--dataset_path=
+```
+
+Training and eval example commands are also included in the `.vscode/launch.json` file.
+
+## Trained models
+

-## Licence
-The source code is released under the [NSCL licence](https://github.com/NVlabs/diffstack/blob/main/LICENSE.txt). The preprocessed dataset and pretrained models are under the [CC BY-NC-SA 4.0 licence](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
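+
+Each downloaded config below can be loaded with the `Dict` helper from [diffstack/configs/config.py](/diffstack/configs/config.py). A minimal sketch (file names and paths are placeholders) for pointing a downloaded config at a local dataset before running the eval script above:
+
+```python
+import json
+
+from diffstack.configs.config import Dict
+
+# load a downloaded config JSON (placeholder filename) into the nested-attribute Dict
+with open("CTT_nusc_config.json", "r") as f:
+    cfg = Dict(json.load(f))
+
+print(cfg.registered_name, cfg.stack.predictor.step_time)  # e.g. "CTTPredStack" 0.25
+cfg.train.dataset_path = "/path/to/nuscenes"  # local dataset root
+cfg.dump(filename="CTT_nusc_config_local.json")  # edited copy to pass via --config_file=
+```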
+| Training dataset | Step time | History horizon | Future horizon | config | checkpoint | +|------------------|-------|------|----|--------|------------| +| nuScenes | 0.25s | 1.5s | 3s | [config](https://drive.google.com/file/d/1fnPX0o2qPVGszFxbYX_LDSYJ221IUFk7/view?usp=drive_link) | [ckpt](https://drive.google.com/file/d/1KvTdJQIEtk50cwiUzMFdl-ZtxpKg52kK/view?usp=drive_link) | +| nuPlan | 0.25s | 1.5s | 3s |[config](https://drive.google.com/file/d/1huNKKlTeT_i3oMOgPL6L1iUT2hKouKtt/view?usp=drive_link) |[ckpt](https://drive.google.com/file/d/1w66sf6sTaoLI-Rl6MFHpuegu0y-i5R0y/view?usp=drive_link) | +| WOMD | 0.2s | 1s | 8s | [config](https://drive.google.com/file/d/1QgsHm3UhY74245YbhsQ4GlTyxyOBpE5y/view?usp=drive_link) | [ckpt](https://drive.google.com/file/d/1qClV16V8jlSMMuPoAavFQeF7qJlb71GV/view?usp=drive_link) | diff --git a/config/diffstack_default.json b/config/diffstack_default.json deleted file mode 100644 index d367ec4..0000000 --- a/config/diffstack_default.json +++ /dev/null @@ -1,154 +0,0 @@ -{ - "learning_rate_style": "exp", - "learning_decay_rate": 0.9999, - - "map_encoder": { - "VEHICLE": { - "patch_size": 100, - "map_channels": 3, - "hidden_channels": [10, 20, 10, 1], - "output_size": 32, - "masks": [5, 5, 5, 3], - "strides": [2, 2, 1, 1], - "dropout": 0.5 - }, - "BICYCLE": { - "patch_size": 100, - "map_channels": 3, - "hidden_channels": [10, 20, 10, 1], - "output_size": 32, - "masks": [5, 5, 5, 3], - "strides": [2, 2, 1, 1], - "dropout": 0.5 - }, - "MOTORCYCLE": { - "patch_size": 100, - "map_channels": 3, - "hidden_channels": [10, 20, 10, 1], - "output_size": 32, - "masks": [5, 5, 5, 3], - "strides": [2, 2, 1, 1], - "dropout": 0.5 - }, - "PEDESTRIAN": { - "patch_size": 100, - "map_channels": 3, - "hidden_channels": [10, 20, 10, 1], - "output_size": 32, - "masks": [5, 5, 5, 5], - "strides": [1, 1, 1, 1], - "dropout": 0.5 - } - }, - - "kl_min": 0.07, - "kl_weight": 100.0, - "kl_weight_start": 0, - "kl_decay_rate": 0.99995, - "kl_crossover": 400, - "kl_sigmoid_divisor": 4, - - "rnn_kwargs": { - "dropout_keep_prob": 0.75 - }, - "MLP_dropout_keep_prob": 0.9, - "enc_rnn_dim_edge": 32, - "enc_rnn_dim_edge_influence": 32, - "enc_rnn_dim_history": 32, - "enc_rnn_dim_future": 32, - "dec_rnn_dim": 128, - - "q_z_xy_MLP_dims": 0, - "p_z_x_MLP_dims": 32, - "GMM_components": 1, - - "log_p_yt_xz_max": 6, - - "K": 25, - "k": 1, - "N": 1, - - "plan_agent_choice": "most_relevant", - "plan_lqr_max_iters": 5, - "plan_lqr_max_linesearch_iters": 5, - "dt": 0.5, - - "filter_plan_valid": true, - "filter_pred_not_parked": true, - "filter_pred_near_ego": false, - "filter_plan_converged": false, - "filter_plan_relevant": false, - "filter_lane_near": false, - - "tau_init": 2.0, - "tau_final": 0.05, - "tau_decay_rate": 0.997, - - "use_z_logit_clipping": true, - "z_logit_clip_start": 0.05, - "z_logit_clip_final": 5.0, - "z_logit_clip_crossover": 300, - "z_logit_clip_divisor": 5, - - "dynamic": { - "PEDESTRIAN": { - "name": "SingleIntegrator", - "distribution": true, - "limits": {} - }, - "VEHICLE": { - "name": "Unicycle", - "distribution": true, - "limits": { - "max_a": 4, - "min_a": -5, - "max_heading_change": 0.7, - "min_heading_change": -0.7 - } - }, - "BICYCLE": { - "name": "Unicycle", - "distribution": true, - "limits": { - "max_a": 4, - "min_a": -5, - "max_heading_change": 0.7, - "min_heading_change": -0.7 - } - }, - "MOTORCYCLE": { - "name": "Unicycle", - "distribution": true, - "limits": { - "max_a": 4, - "min_a": -5, - "max_heading_change": 0.7, - "min_heading_change": -0.7 - } - } - }, - - 
"state": { - "PEDESTRIAN": { - "position": ["x", "y"], - "velocity": ["x", "y"], - "acceleration": ["x", "y"], - "augment": ["ego_indicator"] - }, - "VEHICLE": { - "position": ["x", "y"], - "velocity": ["x", "y"], - "acceleration": ["x", "y"], - "heading": ["°", "d°"], - "augment": ["ego_indicator"] - } - }, - - "pred_state": { - "VEHICLE": { - "position": ["x", "y"] - } - }, - - "plan_node_types": ["VEHICLE"] -} \ No newline at end of file diff --git a/config/templates/CTTPredStack.json b/config/templates/CTTPredStack.json new file mode 100644 index 0000000..7d1ea83 --- /dev/null +++ b/config/templates/CTTPredStack.json @@ -0,0 +1,230 @@ +{ + "registered_name": "CTTPredStack", + "train": { + "logging": { + "terminal_output_to_txt": true, + "log_tb": false, + "log_wandb": true, + "wandb_project_name": "diffstack", + "log_every_n_steps": 10, + "flush_every_n_steps": 100 + }, + "save": { + "enabled": true, + "every_n_steps": 1000, + "best_k": 10, + "save_best_rollout": false, + "save_best_validation": true + }, + "rollout": { + "save_video": true, + "enabled": false, + "every_n_steps": 5000, + "warm_start_n_steps": 1 + }, + "training": { + "batch_size": 2, + "num_steps": 200000, + "num_data_workers": 8 + }, + "validation": { + "enabled": true, + "batch_size": 2, + "num_data_workers": 6, + "every_n_steps": 500, + "num_steps_per_epoch": 50 + }, + "parallel_strategy": "ddp", + "rebuild_cache": false, + "on_ngc": false, + "amp": false, + "auto_batch_size": false, + "max_batch_size": 1000, + "gradient_clip_val": 0.5, + "trajdata_source_train": "train", + "trajdata_source_valid": "val", + "trajdata_source_test": null, + "trajdata_source_root": "nusc_trainval", + "trajdata_val_source_root": null, + "trajdata_test_source_root": null, + "dataset_path": "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS", + "datamodule_class": "UnifiedDataModule", + "ego_only": true, + "test": { + "enabled": false, + "batch_size": 32, + "num_data_workers": 6, + "every_n_steps": 500, + "num_steps_per_epoch": 50 + } + }, + "env": { + "name": "nusc_trainval", + "rasterizer": { + "raster_size": 224, + "pixel_size": 0.5, + "ego_center": [ + -0.75, + 0.0 + ] + }, + "data_generation_params": { + "other_agents_num": 10, + "max_agents_distance": 30, + "yaw_correction_speed": 1.0 + }, + "simulation": { + "distance_th_close": 30, + "num_simulation_steps": 50, + "start_frame_index": 0, + "vectorize_lane": "ego" + }, + "incl_neighbor_map": false, + "incl_vector_map": true, + "incl_raster_map": false, + "remove_single_successor": false, + "calc_lane_graph": true, + "max_num_lanes": 15, + "num_lane_pts": 30, + "remove_parked": false + }, + "stack": { + "predictor": { + "name": "CTT", + "step_time": 0.25, + "history_num_frames": 4, + "future_num_frames": 12, + "n_embd": 256, + "n_head": 8, + "use_rpe_net": true, + "PE_mode": "RPE", + "enc_nblock": 3, + "dec_nblock": 3, + "encoder": { + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "edge_scale": 20, + "edge_clip": [ + -4, + 4 + ], + "mode_embed_dim": 64, + "jm_GNN_nblock": 2, + "num_joint_samples": 30, + "num_joint_factors": 6, + "null_lane_mode": true + }, + "edge_dim": { + "a2a": 14, + "a2l": 12, + "l2a": 12, + "l2l": 16 + }, + "a2l_edge_type": "attn", + "a2l_n_embd": 64, + "attn_ntype": { + "a2a": 2, + "a2l": 1, + "l2l": 2 + }, + "lane_GNN_num_layers": 4, + "homotpy_GNN_num_layers": 4, + "hist_lane_relation": "LaneRelation", + "fut_lane_relation": "SimpleLaneRelation", + "classify_a2l_4all_lanes": false, + "closed_loop": false, + "CL_Tf_mode": 6, + "CL_step": 2, + "decoder": { 
+ "arch": "mlp", + "lstm_hidden_size": 128, + "mlp_hidden_dims": [ + 128, + 256 + ], + "traj_dim": 4, + "num_layers": 2, + "dyn": { + "vehicle": "unicycle", + "pedestrian": "DI_unicycle" + }, + "attn_pdrop": 0.05, + "resid_pdrop": 0.05, + "pooling": "attn", + "decode_num_modes": 5, + "AR_step_size": 1, + "GNN_enabled": false, + "AR_update_mode": null, + "dec_rounds": 5 + }, + "num_lane_pts": 30, + "loss_weights": { + "marginal_lm_loss": 5.0, + "marginal_homo_loss": 5.0, + "joint_prob_loss": 5.0, + "xy_loss": 4.0, + "yaw_loss": 1.0, + "lm_consistency_loss": 5.0, + "homotopy_consistency_loss": 5.0, + "yaw_dev_loss": 30.0, + "y_dev_loss": 5.0, + "l2_reg": 0.1, + "coll_loss": 0.1, + "acce_reg_loss": 0.04, + "steering_reg_loss": 0.1, + "input_violation_loss": 20.0, + "jerk_loss": 0.05 + }, + "loss": { + "lm_margin_offset": 0.2 + }, + "weighted_consistency_loss": false, + "LR_sample_hack": true, + "scene_centric": true, + "max_joint_cardinality": 5, + "optim_params": { + "policy": { + "learning_rate": { + "initial": 0.0001, + "decay_factor": 0.1, + "epoch_schedule": [] + }, + "regularization": { + "L2": 0.0 + } + } + } + }, + "stack_type": "pred" + }, + "eval": { + "name": null, + "env": "nusc", + "dataset_path": null, + "eval_class": "", + "seed": 0, + "ckpt_root_dir": "checkpoints/", + "ckpt": { + "dir": null, + "ngc_job_id": null, + "ckpt_dir": null, + "ckpt_key": null + }, + "eval": { + "batch_size": 100, + "num_steps": null, + "num_data_workers": 8 + }, + "log_image_frequency": null, + "log_all_image": false, + "trajdata_source_root": "nusc_trainval", + "trajdata_source_eval": "val" + }, + "name": "test", + "root_dir": "predictor_CTT_trained_models/", + "seed": 1, + "devices": { + "num_gpus": 1 + } +} \ No newline at end of file diff --git a/diffstack/argument_parser.py b/diffstack/argument_parser.py deleted file mode 100644 index 7bc8ec4..0000000 --- a/diffstack/argument_parser.py +++ /dev/null @@ -1,351 +0,0 @@ -import argparse -import json -import torch -import os -import sys - - -parser = argparse.ArgumentParser() -parser.add_argument("--conf", - help="path to json config file for hyperparameters", - default="./config/diffstack_default.json", - type=str) - -parser.add_argument("--experiment", - help="name of experiment for wandb", - type=str, - default='diffstack-def') - -parser.add_argument("--debug", - help="disable all disk writing processes.", - action='store_true') - -parser.add_argument("--preprocess_workers", - help="number of processes to spawn for preprocessing", - type=int, - default=0) - -parser.add_argument('--seed', - help='manual seed to use, default is 123', - type=int, - default=123) - -parser.add_argument('--device', - help='what device to perform training on', - type=str, - default='cuda:0') - -# Data Parameters -parser.add_argument("--data_source", - help="Specifies the source of data [trajdata, trajdata-scene, cache]", - type=str, - default='trajdata') - -parser.add_argument("--data_loc_dict", - help="JSON dict of dataset locations", - type=str, - default='{"nusc_mini": "./data/nuscenes"}') - -parser.add_argument("--cached_data_dir", - help="What dir to look in for cached data when data_source==cache", - type=str, - default='./data/cached_data') - -parser.add_argument('--train_data', - help='name of data split to use for training', - type=str, - default="nusc_mini-mini_train") - -parser.add_argument("--eval_data", - help="name of data split to use for evaluation", - type=str, - default='nusc_mini-mini_val') - -parser.add_argument("--log_dir", - help="what dir to save 
training information (i.e., saved models, logs, etc)", - type=str, - default='./experiments/') - -parser.add_argument("--trajdata_cache_dir", - help="location of the unified dataloader cache", - type=str, - default='./cache/unified_data_cache') - -parser.add_argument('--rebuild_cache', - help="Rebuild trajdata cache files.", - action='store_true') - -# Training and eval parameters -parser.add_argument('--load', - help='Load pretrained model from checkpoint file if not empty.', - type=str, - default="") - -parser.add_argument('--batch_size', - help='training batch size', - type=int, - default=256) - -parser.add_argument('--eval_batch_size', - help='evaluation batch size', - type=int, - default=256) - -parser.add_argument("--learning_rate", - help="initial learning rate", - type=float, - default=0.003) - -parser.add_argument("--map_enc_learning_rate", - help="map encoder learning rate for trajectron", - type=float, - default=0.00003) - -parser.add_argument("--lr_step", - help="number of epochs after which to step down the LR by 0.1, the default (0) is no step downs", - type=int, - default=0) - -parser.add_argument("--grad_clip", - help="the maximum magnitude of gradients (enforced by clipping)", - type=float, - default=1.0) - -parser.add_argument("--train_epochs", - help="number of epochs to train for", - type=int, - default=20) - -parser.add_argument('--eval_every', - help='how often to evaluate during training, never if None', - type=int, - default=1) - -parser.add_argument('--vis_every', - help='how often to visualize during training, never if None', - type=int, - default=0) - -parser.add_argument('--save_every', - help='how often to save during training, never if None', - type=int, - default=1) - -parser.add_argument('--log_histograms', - help="Log historgrams during training.", - action='store_true') - -parser.add_argument('--cl_trajlen', - help='Length of trajectories for closed loop evaluation in discrete timesteps. No closed loop evaluation if negative.', - type=int, - default=-1) - -# Stack parameters -parser.add_argument('--predictor', - help='Choose prediction model [tpp, constvel]', - type=str, - default="tpp_origin") - -parser.add_argument('--planner', - help='Choice of planner [none, mpc, fan, fan_mpc, cvae]', - type=str, - default="fan_mpc") - -parser.add_argument("--history_sec", - help="required agent history (in seconds)", - type=float, - default=4.0) - -parser.add_argument("--prediction_sec", - help="prediction horizon (in seconds)", - default=3.0, - type=float) - -parser.add_argument('--plan_dt', - help='Custom delta t for planning. Will be the same as prediction dt if not defined or nonpositive.', - type=float, - default=0.) 
- -parser.add_argument('--plan_cost', - help='Replaces the cost function for planning if specified [[empty], corl_default, corl_default_angle_fix]', - type=str, - default="corl_default_angle_fix") - -parser.add_argument('--plan_init', - help='Options for initializing the planner: fitted, gtplan, zero.', - type=str, - default="zero") - -parser.add_argument('--plan_lqr_eps', - help='Maximum change in norm of control vector for iLQR to be considered converged.', - type=float, - default=0.01) - -parser.add_argument('--plan_loss', - help='Options for the planning loss [mse, joint_hcost2, class_hcost2, hcost]', - type=str, - default="joint_hcost2") - -parser.add_argument('--no_plan_train', - help="Disable planning during training.", - action='store_true') - -parser.add_argument('--no_train_pred', - help="Disable updating the prediction model during training.", - action='store_true') - -parser.add_argument('--train_plan_cost', - help="Train the planners cost function.", - action='store_true') - -parser.add_argument('--bias_predictions', - help="Add a bias to prediction targets if True.", - action='store_true') - -parser.add_argument('--pred_loss_scaler', - help='Scaler for prediction loss.', - type=float, - default=1.0) - -parser.add_argument('--plan_loss_scaler', - help='Scaler for planning loss when added to prediction loss.', - type=float, - default=100.0) - -parser.add_argument('--plan_loss_scaler2', - help='Scaler for the second planning loss term (when used).', - type=float, - default=10.0) - -parser.add_argument('--pred_loss_weights', - help='Custom scheme to weight prediction losses across samples in a batch [none, dist, grad]', - type=str, - default="none") - -parser.add_argument('--pred_loss_temp', - help='Temperature parameter for prediction loss weighting scheme.', - type=float, - default=1.0) - -parser.add_argument('--cost_grad_scaler', - help='Scaler for the gradient of cost weights when training the planning cost.', - type=float, - default=0.001) - -parser.add_argument('--plan_loss_scale_start', - help='Epoch where to start (linearly) increasing planning loss scaler.', - type=int, - default=-1) - -parser.add_argument('--plan_loss_scale_end', - help='Epoch where to end (linearly) increasing planning loss scaler.', - type=int, - default=-1) - -# Trajectron++ -parser.add_argument('--K', - help='how many CVAE discrete latent modes to have in the model', - type=int, - default=25) - -parser.add_argument('--k_eval', - help='how many samples to take during evaluation', - type=int, - default=25) - -parser.add_argument('--pred_ego_indicator', - help="The type of ego indicator to use in input to predictor [none, most_relevant].", - type=str, - default='most_relevant') - -parser.add_argument("--dynamic_edges", - help="whether to use dynamic edges or not, options are 'no' and 'yes'", - type=str, - default='no') - -parser.add_argument("--edge_state_combine_method", - help="the method to use for combining edges of the same type", - type=str, - default='sum') - -parser.add_argument("--edge_influence_combine_method", - help="the method to use for combining edge influences", - type=str, - default='attention') - -parser.add_argument('--edge_addition_filter', - nargs='+', - help="what scaling to use for edges as they're created", - type=float, - default=[0.25, 0.5, 0.75, 1.0]) # We don't automatically pad left with 0.0, if you want a sharp - # and short edge addition, then you need to have a 0.0 at the - # beginning, e.g. [0.0, 1.0]. 
- -parser.add_argument('--edge_removal_filter', - nargs='+', - help="what scaling to use for edges as they're removed", - type=float, - default=[1.0, 0.0]) # We don't automatically pad right with 0.0, if you want a sharp drop off like - # the default, then you need to have a 0.0 at the end. - -parser.add_argument('--incl_robot_node', - help="whether to include a robot node in the graph or simply model all agents", - action='store_true') - -parser.add_argument('--map_encoding', - help="Whether to use map encoding or not", - action='store_true') - -parser.add_argument('--augment_input_noise', - help="Standard deviation of Gaussian noise to add the inputs during training, not performed if 0.0", - type=float, - default=0.0) - -parser.add_argument('--no_edge_encoding', - help="Whether to use neighbors edge encoding", - action='store_true') - - -args = parser.parse_args() - - -def get_hyperparams(args): - # Load hyperparameters from json - if not os.path.exists(args.conf): - print('Config json not found!') - with open(args.conf, 'r', encoding='utf-8') as conf_json: - hyperparams = json.load(conf_json) - - # Add all arguments as hyperparams - hyperparams.update({k: v for k, v in vars(args).items() if v is not None}) - if torch.distributed.is_initialized(): - hyperparams['world_size'] = torch.distributed.get_world_size() # number of torch workers - else: - hyperparams['world_size'] = 1 - - # Special arguments - hyperparams['plan_dt'] = (args.plan_dt if args.plan_dt > 0. else hyperparams['dt']) - hyperparams['edge_encoding'] = not args.no_edge_encoding - hyperparams['plan_train'] = not args.no_plan_train - hyperparams['train_pred'] = not args.no_train_pred - - # Distributed LR Scaling - hyperparams['learning_rate'] *= hyperparams['world_size'] - - return hyperparams - - -def print_hyperparams_summary(hyperparams): - print('-----------------------') - print('| PARAMETERS |') - print('-----------------------') - print('| Experiment: %s' % hyperparams["experiment"]) - print('| Batch Size: %d' % hyperparams["batch_size"]) - print('| Eval Batch Size: %d' % hyperparams["eval_batch_size"]) - print('| Device: %s %d' % (hyperparams["device"], hyperparams["world_size"])) - print('| Learning Rate: %s' % hyperparams['learning_rate']) - print('| Learning Rate Step Every: %s' % hyperparams["lr_step"]) - print('| Max History: %ss' % hyperparams['history_sec']) - print('| Max Future: %ss' % hyperparams['prediction_sec']) - print('| Args: %s' % " ".join(sys.argv[1:])) - print('-----------------------') diff --git a/diffstack/closed_loop_eval.py b/diffstack/closed_loop_eval.py deleted file mode 100644 index da097d6..0000000 --- a/diffstack/closed_loop_eval.py +++ /dev/null @@ -1,225 +0,0 @@ -import torch -import numpy as np - -from collections import defaultdict -from tqdm import tqdm - -# Dataset related -from trajectron.trajectron.model.dataset import collate # collate needs to be imported like this for cached files to recognize the class. If changed caching has to be redone. 
- -# Model related - -from trajectron.trajectron.model.dataset.preprocessing import get_node_timestep_data, get_node_closest_to_robot, pred_state_to_plan_state, plan_state_to_pred_state # need to import like this to match class instance comparison -from trajectron.trajectron.model.dataset import restore # import from here for legacy reasons, if changed isinstance will fail -from mpc.util import get_traj - - -def simulate_scenarios_in_scene(trajectron_module, nusc_maps, env, scene, node_type, hyperparams, replan_every_ns=1, all_scenarios_in_scene=False): - max_hl = hyperparams['maximum_history_length'] - # Unroll once but calculate metrics at different unroll lengths defined by scene_eval_trajlens - scene_eval_trajlens = [hyperparams["cl_trajlen"], hyperparams["prediction_horizon"]] - # scene_eval_trajlens = [hyperparams["cl_trajlen"]] - - eval_loss = defaultdict(list) - - # Get valid timesteps and nodes for the scene - timestep = np.arange(scene.timesteps) - nodes_per_ts = scene.present_nodes(timestep, - type=node_type, - min_history_timesteps=max_hl, - min_future_timesteps=1, - return_robot=not hyperparams['incl_robot_node']) - valid_timesteps = np.array(sorted(nodes_per_ts.keys())) - - # Construct a set of closed-loop scenarios defined by the start timestep - if all_scenarios_in_scene: - valid_start_timesteps = valid_timesteps[:-(hyperparams["cl_trajlen"])] - else: - valid_start_timesteps = valid_timesteps[:1] - - # Iterate over scenarios in scene - for t_sim_start in valid_start_timesteps: - - # iterate over replan frequencies - for replan_every in replan_every_ns: - scenario_metrics, sim_hist, ego_sim_hist = simulate_scenario(trajectron_module, nusc_maps, env, scene, node_type, nodes_per_ts, t_sim_start, hyperparams, replan_every=replan_every) - if not scenario_metrics: - print ("Skip scene, no valid segment for planning") - continue - - # # # Visualize - # visualize_closed_loop(sim_hist, scenario_metrics, scene, nusc_maps) - - # Done with unroll, average metrics through time - for metric, values_through_time in scenario_metrics.items(): - # Mean through time and dummy batch dimension - for scene_eval_trajlen in scene_eval_trajlens: - eval_loss[f'cl{scene_eval_trajlen}r{replan_every}_' + metric].append( - np.array(values_through_time[:scene_eval_trajlen]).mean(axis=0)[None]) - eval_loss[f"clr{replan_every}_unroll_len"].append(np.array([ego_sim_hist.shape[0]], np.float32)) - eval_loss[f"clr{replan_every}_active_len"].append(np.array([len(scenario_metrics['mse_t0'])], np.float32)) - - return eval_loss - - -def simulate_scenario(trajectron_module, nusc_maps, env, scene, node_type, nodes_per_ts, t_sim_start, hyperparams, replan_every=1): - max_hl = hyperparams['maximum_history_length'] - - ego_sim_hist = None # "position": ["x", "y"], "velocity": ["x", "y"], "acceleration": ["x", "y"], "heading": ["°", "d°"], - all_sim_hist = defaultdict(lambda: defaultdict(list)) # all_sim_hist[t][node_type] --> [N][state_dim] - plan_hist = defaultdict(list) # all_sim_hist[var] --> [t] - scenario_metrics = defaultdict(list) - - steps_since_plan = replan_every - last_plan = None - - # Return if too short - if t_sim_start + hyperparams["cl_trajlen"] not in nodes_per_ts: - return scenario_metrics, all_sim_hist, ego_sim_hist - - # Unroll time - for t_sim in range(t_sim_start, t_sim_start + hyperparams["cl_trajlen"]): - # In closed-loop eval we first pick ego, then the predicted agent, which is the opposite of open-loop. 
- node = get_node_closest_to_robot(scene, t_sim, node_type, nodes=nodes_per_ts[t_sim]) - # TODO support closest nodes without complete history. Currently present_nodes() filters out these to avoid assert in predict_and_evaluate_batch(). - if node is None: - print ("No closest node found!") - return scenario_metrics, all_sim_hist, ego_sim_hist - - # Run preprocessing steps to construct a dummy batch - sample, (neighbors_data_not_st, logged_robot_data, robot_idx) = get_node_timestep_data( - env, scene, t_sim, node, trajectron_module.state, trajectron_module.pred_obj.pred_state, - env.get_edge_types(), max_hl, hyperparams["prediction_horizon"], hyperparams, nusc_maps=nusc_maps, - is_closed_loop=True, - closed_loop_ego_hist=(None if ego_sim_hist is None else ego_sim_hist[-(max_hl+1):])) - sample = trajectron_module.planner_obj.augment_sample_with_dummy_plan_info(sample) - batch = collate([sample]) - if logged_robot_data is not None: - logged_robot_data = logged_robot_data.to(trajectron_module.device) - batch_i = 0 - - # Log - for log_node_type in env.NodeType: - if (node_type, log_node_type) in neighbors_data_not_st: - for node_idx, node_hist in enumerate(neighbors_data_not_st[(node_type, log_node_type)]): - if node_idx == robot_idx: - all_sim_hist[t_sim]['ROBOT_SIM'].append(node_hist[-1].cpu().numpy()) - else: - all_sim_hist[t_sim][log_node_type].append(node_hist[-1].cpu().numpy()) - if logged_robot_data is not None: - all_sim_hist[t_sim]['ROBOT_LOG'].append(logged_robot_data[-1].cpu().numpy()) - x_t = sample[1] - all_sim_hist[t_sim]['PRED'].append(x_t[-1].cpu().numpy()) - - if robot_idx < 0: - # When history is incomplete we cannot plan, so we will just follow the log - continue - - ego_log_hist = neighbors_data_not_st[(node_type, 'VEHICLE')][robot_idx] - if ego_sim_hist is None: - ego_sim_hist = ego_log_hist.cpu().numpy() - - # Make sure past steps match - assert np.isclose(ego_log_hist.cpu().numpy(), ego_sim_hist[-ego_log_hist.shape[0]:]).all() - # # Track deviation - # print (np.isclose(logged_robot_data.cpu().numpy(), ego_sim_hist[-ego_log_hist.shape[0]:]).all(axis=1)) - # print (np.linalg.norm(logged_robot_data.cpu().numpy()[..., :2] - ego_sim_hist[-ego_log_hist.shape[0]:, ..., :2], axis=-1)) - plan_x0 = pred_state_to_plan_state(ego_sim_hist[-1])[..., :4] - plan_x0 = torch.tensor(plan_x0, dtype=torch.float, device=trajectron_module.device) # x, y, yaw, vel - - # Inference. Choose between replan or follow last plan. - if last_plan is None or steps_since_plan >= replan_every: - # Replan - eval_loss_node_type, plot_data = trajectron_module.predict_and_evaluate_batch(batch, node_type, max_hl, return_plot_data=True) - # Recover plan inputs - plan_metrics, plan_iters = plot_data['plan'] - # plan_batch_filter = plan_iters['plan_batch_filter'] - plan_x = plan_iters['x'][-1][:, batch_i] - plan_u = plan_iters['u'][-1][:, batch_i] - assert torch.isclose(plan_x0, plan_x[0]).all() - last_plan = (plan_x, plan_u) - steps_since_plan = 1 - - # Metrics for planning - scenario_metrics['plan_hcost'].append(eval_loss_node_type["plan_hcost"][0].cpu().numpy()) - scenario_metrics['ade_unbiased'].append(eval_loss_node_type["ade_unbiased"][0].cpu().numpy()) - scenario_metrics['nll_mean_unbiased'].append(eval_loss_node_type["nll_mean_unbiased"][0].cpu().numpy()) - # It's not meaningful to log class_hcost_valid, because whether plan is valid will be different depending on ego state. 
- # if 'class_hcost_valid' in eval_loss_node_type and eval_loss_node_type["class_hcost_valid"].shape[0] > 0: - # scenario_metrics['class_hcost_valid'].append(eval_loss_node_type["class_hcost_valid"][0].cpu().numpy()) - - # Log for final cost - plan_hist['plan_all_gt_neighbors_batch'].append(plan_iters['all_gt_neighbors'][:, :replan_every].squeeze(2)) - plan_hist['lanes'].append(plan_iters['lanes'][:replan_every].squeeze(1)) - plan_hist['plan_x'].append(plan_x[:replan_every]) - plan_hist['plan_u'].append(plan_u[:replan_every]) - - else: - # Use the last plan - plan_x, plan_u = last_plan - plan_x = plan_x[steps_since_plan:] - plan_u = plan_u[steps_since_plan:] - steps_since_plan += 1 - - # Unroll unnormalized state with first control step - plan_x_unroll = get_traj(2, plan_u[:2].unsqueeze(1), plan_x0.unsqueeze(0), trajectron_module.planner_obj.dyn_obj).squeeze(1) - new_ego_state = torch.concat([plan_x_unroll[1], plan_u[0]], dim=-1).cpu().numpy() - - # Convert from planner's state to predictor's state: - # (x, y, h, v), (acc, dh) --> x, y, vx, vy, ax, ay, heading, delta_heading - # TODO not sure if prediction state represent acceleration as v[t]-v[t-1] or v[t+1]-v[t] \ - # but it shouldnt matter much - new_ego_state = plan_state_to_pred_state(new_ego_state) - # Update sim state - ego_sim_hist = np.append(ego_sim_hist, new_ego_state[None], axis=0) - plan_hist['logged_x'].append(logged_robot_data[-1]) - - # Metrics - t_metric = 1 - hcost_t = plan_iters['hcost_components'][t_metric, batch_i] - for i in range(hcost_t.shape[0]): - scenario_metrics[f'hcost_t{t_metric}_comp{i}'].append(hcost_t[i].cpu().numpy()) - scenario_metrics[f'hcost_t{t_metric}'].append(hcost_t[:].sum(dim=0).cpu().numpy()) - - icost_t = plan_iters['icost_components'][t_metric, batch_i] - for i in range(icost_t.shape[0]): - scenario_metrics[f'icost_t{t_metric}_comp{i}'].append(icost_t[i].cpu().numpy()) - - mse_t0 = torch.linalg.norm(plan_x0[:2] - logged_robot_data[-1][:2], dim=-1) - scenario_metrics['mse_t0'].append(mse_t0.cpu().numpy()) - - # Eval cost for full trajectory - if len(plan_hist['plan_x']) >= 2: - plan_xu = torch.concat([torch.concat(plan_hist['plan_x'], dim=0), torch.concat(plan_hist['plan_u'], dim=0)], dim=-1).unsqueeze(1) # T+1, b - plan_all_gt_neighbors_batch = torch.concat(plan_hist['plan_all_gt_neighbors_batch'], dim=1).unsqueeze(2) # N, T+1, b - plan_all_gt_neighbors_batch = plan_all_gt_neighbors_batch[:, 1:] # normall we dont have t0, N, T, b - goal_batch = plan_hist['logged_x'][-1][..., :2].unsqueeze(0) # b - lanes = torch.concat(plan_hist['lanes'], dim=0).unsqueeze(1) # T+1, b - empty_mus_batch = torch.zeros((0, plan_all_gt_neighbors_batch.shape[1], 1, 1, 2), dtype=torch.float, device=trajectron_module.device) - empty_logp_batch = torch.zeros((0, 1, 1), dtype=torch.float, device=trajectron_module.device) - lane_points = None - - hcost_components = trajectron_module.planner_obj.cost_obj( - plan_xu, cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points), - keep_components=True) # (T, b, c) - hcost_components = hcost_components.squeeze(1) - - icost_components = trajectron_module.planner_obj.interpretable_cost_obj( - plan_xu, cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points), - keep_components=True) # (T, b, c) - icost_components = icost_components.squeeze(1) - - for i in range(hcost_components.shape[1]): - scenario_metrics[f'hcost_traj_comp{i}'] = hcost_components[:, i].cpu().numpy() - 
scenario_metrics[f'hcost_traj'] = hcost_components.sum(dim=1).cpu().numpy() - - for i in range(icost_components.shape[1]): - scenario_metrics[f'icost_traj_comp{i}'] = icost_components[:, i].cpu().numpy() - - sim_hist = {k: [el.cpu().numpy() for el in v] for k, v in plan_hist.items()} - - return scenario_metrics, sim_hist, ego_sim_hist - - - -# def run_closed_loop_eval(eval_data_loader, trajectron_module, nusc_maps, log_writer, env, hyperparams, epoch, rank, all_scenarios_in_scene=False): diff --git a/diffstack/configs/__init__.py b/diffstack/configs/__init__.py new file mode 100644 index 0000000..2c33122 --- /dev/null +++ b/diffstack/configs/__init__.py @@ -0,0 +1,2 @@ +from diffstack.configs.base import ExperimentConfig + diff --git a/diffstack/configs/algo_config.py b/diffstack/configs/algo_config.py new file mode 100644 index 0000000..cf9007e --- /dev/null +++ b/diffstack/configs/algo_config.py @@ -0,0 +1,150 @@ +import math + +from diffstack.configs.base import AlgoConfig + + +class BehaviorCloningConfig(AlgoConfig): + def __init__(self): + super(BehaviorCloningConfig, self).__init__() + self.eval_class = "BC" + + self.name = "bc" + self.model_architecture = "resnet18" + self.map_feature_dim = 256 + self.history_num_frames = 8 + self.history_num_frames_ego = 8 + self.history_num_frames_agents = 8 + self.future_num_frames = 6 + self.step_time = 0.5 + self.render_ego_history = False + + self.decoder.layer_dims = () + self.decoder.state_as_input = True + + self.dynamics.type = "Unicycle" + self.dynamics.max_steer = 0.5 + self.dynamics.max_yawvel = math.pi * 2.0 + self.dynamics.acce_bound = (-10, 8) + self.dynamics.ddh_bound = (-math.pi * 2.0, math.pi * 2.0) + self.dynamics.max_speed = 40.0 # roughly 90mph + + self.spatial_softmax.enabled = False + self.spatial_softmax.kwargs.num_kp = 32 + self.spatial_softmax.kwargs.temperature = 1.0 + self.spatial_softmax.kwargs.learnable_temperature = False + + self.loss_weights.prediction_loss = 1.0 + self.loss_weights.goal_loss = 0.0 + self.loss_weights.collision_loss = 0.0 + self.loss_weights.yaw_reg_loss = 0.001 + + self.optim_params.policy.learning_rate.initial = 1e-3 # policy learning rate + self.optim_params.policy.learning_rate.decay_factor = ( + 0.1 # factor to decay LR by (if epoch schedule non-empty) + ) + self.optim_params.policy.learning_rate.epoch_schedule = ( + [] + ) # epochs where LR decay occurs + self.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength + self.checkpoint.enabled = False + self.checkpoint.path = None + + +class CTTConfig(AlgoConfig): + def __init__(self): + super(CTTConfig, self).__init__() + self.name = "CTT" + self.step_time = 0.25 + self.history_num_frames = 4 + self.future_num_frames = 12 + + self.n_embd = 256 + self.n_head = 8 + self.use_rpe_net = True + self.PE_mode = "RPE" # "RPE" or "PE" + + self.enc_nblock = 3 + self.dec_nblock = 3 + + self.encoder.attn_pdrop = 0.05 + self.encoder.resid_pdrop = 0.05 + self.encoder.pooling = "attn" + self.encoder.edge_scale = 20 + self.encoder.edge_clip = [-4, 4] + self.encoder.mode_embed_dim = 64 + self.encoder.jm_GNN_nblock = 2 + self.encoder.num_joint_samples = 30 + self.encoder.num_joint_factors = 6 + self.encoder.null_lane_mode = True + + self.edge_dim.a2a = 14 + self.edge_dim.a2l = 12 + self.edge_dim.l2a = 12 + self.edge_dim.l2l = 16 + self.a2l_edge_type = "attn" + self.a2l_n_embd = 64 + + self.attn_ntype.a2a = 2 + self.attn_ntype.a2l = 1 + self.attn_ntype.l2l = 2 + + self.lane_GNN_num_layers = 4 + self.homotpy_GNN_num_layers = 4 + + 
self.hist_lane_relation = "LaneRelation" + self.fut_lane_relation = "SimpleLaneRelation" + self.classify_a2l_4all_lanes = ( + False # Alternative is 4all modes, find resulting lane + ) + + self.decoder.arch = "mlp" + self.decoder.lstm_hidden_size = 128 + self.decoder.mlp_hidden_dims = [128, 256] + self.decoder.traj_dim = 4 # x,y,v,yaw + self.decoder.num_layers = 2 + self.decoder.dyn.vehicle = "unicycle" + self.decoder.dyn.pedestrian = "DI_unicycle" + self.decoder.attn_pdrop = 0.05 + self.decoder.resid_pdrop = 0.05 + self.decoder.pooling = "attn" + self.decoder.decode_num_modes = 5 + self.decoder.AR_step_size = 1 + self.decoder.GNN_enabled = False + self.decoder.AR_update_mode = None + self.decoder.dec_rounds = 5 + + self.num_lane_pts = 30 + + self.loss_weights.marginal_lm_loss = 5.0 + self.loss_weights.marginal_homo_loss = 5.0 + self.loss_weights.joint_prob_loss = 5.0 + self.loss_weights.xy_loss = 4.0 + self.loss_weights.yaw_loss = 1.0 + self.loss_weights.lm_consistency_loss = 5.0 + self.loss_weights.homotopy_consistency_loss = 5.0 + self.loss_weights.yaw_dev_loss = 30.0 + self.loss_weights.y_dev_loss = 5.0 + self.loss_weights.l2_reg = 0.1 + self.loss_weights.coll_loss = 0.1 + self.loss_weights.acce_reg_loss = 0.04 + self.loss_weights.steering_reg_loss = 0.1 + self.loss_weights.input_violation_loss = 20.0 + self.loss_weights.jerk_loss = 0.05 + + self.loss.lm_margin_offset = 0.2 + + self.weighted_consistency_loss = False + self.LR_sample_hack = True + + self.scene_centric = True + + self.max_joint_cardinality = 5 + + self.optim_params.policy.learning_rate.initial = 1e-4 # policy learning rate + self.optim_params.policy.learning_rate.decay_factor = ( + 0.1 # factor to decay LR by (if epoch schedule non-empty) + ) + self.optim_params.policy.learning_rate.epoch_schedule = ( + [] + ) # epochs where LR decay occurs + self.optim_params.policy.regularization.L2 = 0.00 # L2 regularization strength diff --git a/diffstack/configs/base.py b/diffstack/configs/base.py new file mode 100644 index 0000000..31edf6f --- /dev/null +++ b/diffstack/configs/base.py @@ -0,0 +1,139 @@ +from diffstack.configs.config import Dict +from copy import deepcopy +from diffstack.configs.eval_config import EvaluationConfig + + +class TrainConfig(Dict): + def __init__(self): + super(TrainConfig, self).__init__() + self.logging.terminal_output_to_txt = True # whether to log stdout to txt file + self.logging.log_tb = False # enable tensorboard logging + self.logging.log_wandb = True # enable wandb logging + self.logging.wandb_project_name = "diffstack" + self.logging.log_every_n_steps = 10 + self.logging.flush_every_n_steps = 100 + + ## save config - if and when to save model checkpoints ## + self.save.enabled = True # whether model saving should be enabled or disabled + self.save.every_n_steps = 100 # save model every n epochs + self.save.best_k = 5 + self.save.save_best_rollout = False + self.save.save_best_validation = True + + ## evaluation rollout config ## + self.rollout.save_video = True + self.rollout.enabled = False # enable evaluation rollouts + self.rollout.every_n_steps = 1000 # do rollouts every @rate epochs + self.rollout.warm_start_n_steps = ( + 1 # number of steps to wait before starting rollouts + ) + + ## training config + self.training.batch_size = 100 + self.training.num_steps = 200000 + self.training.num_data_workers = 0 + + ## validation config + self.validation.enabled = True + self.validation.batch_size = 100 + self.validation.num_data_workers = 0 + self.validation.every_n_steps = 1000 + 
self.validation.num_steps_per_epoch = 100 + + ## Training parallelism (e.g., multi-GPU) + self.parallel_strategy = "ddp" + + self.rebuild_cache = False + + self.on_ngc = False + + # AMP + self.amp = False + + # auto batch size + self.auto_batch_size = False + self.max_batch_size = 1000 + # graidient clipping + self.gradient_clip_val = 0.5 + + +class EnvConfig(Dict): + def __init__(self): + super(EnvConfig, self).__init__() + self.name = "my_env" + + +class AlgoConfig(Dict): + def __init__(self): + super(AlgoConfig, self).__init__() + self.name = "my_algo" + + +class ExperimentConfig(Dict): + def __init__( + self, + train_config: TrainConfig, + env_config: EnvConfig, + module_configs: dict, + eval_config: EvaluationConfig = None, + registered_name: str = None, + stack_type: str = None, + name: str = None, + root_dir=None, + seed=None, + devices=None, + ): + """ + + Args: + train_config (TrainConfig): training config + env_config (EnvConfig): environment config + module_configs dict(AlgoConfig): algorithm configs for all modules + registered_name (str): name of the experiment config object in the global config registry + """ + super(ExperimentConfig, self).__init__() + self.registered_name = registered_name + + self.train = train_config + self.env = env_config + self.stack = module_configs + self.stack.stack_type = stack_type + self.eval = EvaluationConfig() if eval_config is None else eval_config + + # Write all results to this directory. A new folder with the timestamp will be created + # in this directory, and it will contain three subfolders - "log", "models", and "videos". + # The "log" directory will contain tensorboard and stdout txt logs. The "models" directory + # will contain saved model checkpoints. The "videos" directory contains evaluation rollout + # videos. + self.name = ( + "test" # name of the experiment (creates a subdirectory under root_dir) + ) + stack_name = "" + for key, config in self.stack.items(): + if isinstance(config, dict): + stack_name += key + "_" + config["name"] + + self.root_dir = ( + "{}_trained_models/".format(stack_name) if root_dir is None else root_dir + ) + self.seed = ( + 1 if seed is None else seed + ) # seed for everything (for reproducibility) + + self.devices = ( + Dict(num_gpus=1) if devices is None else devices + ) # Set to 0 to use CPU + + def clone(self): + return self.__class__( + train_config=deepcopy(self.train), + env_config=deepcopy(self.env), + module_configs=deepcopy(self.stack), + eval_config=deepcopy(self.eval), + registered_name=self.registered_name, + stack_type=self.stack.stack_type, + name=self.name, + root_dir=self.root_dir, + seed=self.seed, + devices=self.devices, + ) diff --git a/diffstack/configs/config.py b/diffstack/configs/config.py new file mode 100644 index 0000000..9dec7a4 --- /dev/null +++ b/diffstack/configs/config.py @@ -0,0 +1,190 @@ +""" +Basic config class - provides a convenient way to work with nested +dictionaries (by exposing keys as attributes) and to save / load from jsons. 
+ +Based on addict: https://github.com/mewwts/addict +""" + +import json +import copy +import contextlib +from copy import deepcopy + + +class Dict(dict): + + def __init__(__self, *args, **kwargs): + object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) + object.__setattr__(__self, '__key', kwargs.pop('__key', None)) + object.__setattr__(__self, '__frozen', False) + for arg in args: + if not arg: + continue + elif isinstance(arg, dict): + for key, val in arg.items(): + __self[key] = __self._hook(val) + elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): + __self[arg[0]] = __self._hook(arg[1]) + else: + for key, val in iter(arg): + __self[key] = __self._hook(val) + + for key, val in kwargs.items(): + __self[key] = __self._hook(val) + + def __setattr__(self, name, value): + if hasattr(self.__class__, name): + raise AttributeError("'Dict' object attribute " + "'{0}' is read-only".format(name)) + else: + self[name] = value + + def __setitem__(self, name, value): + isFrozen = (hasattr(self, '__frozen') and + object.__getattribute__(self, '__frozen')) + if isFrozen and name not in super(Dict, self).keys(): + raise KeyError(name) + super(Dict, self).__setitem__(name, value) + try: + p = object.__getattribute__(self, '__parent') + key = object.__getattribute__(self, '__key') + except AttributeError: + p = None + key = None + if p is not None: + p[key] = self + object.__delattr__(self, '__parent') + object.__delattr__(self, '__key') + + def __add__(self, other): + if not self.keys(): + return other + else: + self_type = type(self).__name__ + other_type = type(other).__name__ + msg = "unsupported operand type(s) for +: '{}' and '{}'" + raise TypeError(msg.format(self_type, other_type)) + + @classmethod + def _hook(cls, item): + if isinstance(item, dict): + return cls(item) + elif isinstance(item, (list, tuple)): + return type(item)(cls._hook(elem) for elem in item) + return item + + def __getattr__(self, item): + return self.__getitem__(item) + + def __missing__(self, name): + if object.__getattribute__(self, '__frozen'): + raise KeyError(name) + return Dict(__parent=self, __key=name) + + def __delattr__(self, name): + del self[name] + + def __repr__(self): + json_string = json.dumps(self.to_dict(), indent=4) + return json_string + + def to_dict(self): + base = {} + for key, value in self.items(): + if isinstance(value, type(self)): + base[key] = value.to_dict() + elif isinstance(value, (list, tuple)): + base[key] = type(value)( + item.to_dict() if isinstance(item, type(self)) else + item for item in value) + else: + base[key] = value + return base + + def copy(self): + return copy.copy(self) + + def deepcopy(self): + return copy.deepcopy(self) + + def __deepcopy__(self, memo): + other = self.__class__() + memo[id(self)] = other + for key, value in self.items(): + other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) + return other + + def update(self, *args, **kwargs): + other = {} + if args: + if len(args) > 1: + raise TypeError() + other.update(args[0]) + other.update(kwargs) + for k, v in other.items(): + if ((k not in self) or + (not isinstance(self[k], dict)) or + (not isinstance(v, dict))): + self[k] = v + else: + self[k].update(v) + + def __getnewargs__(self): + return tuple(self.items()) + + def __getstate__(self): + return self + + def __setstate__(self, state): + self.update(state) + + def __or__(self, other): + if not isinstance(other, (Dict, dict)): + return NotImplemented + new = Dict(self) + new.update(other) + return new + + def __ror__(self, 
other): + if not isinstance(other, (Dict, dict)): + return NotImplemented + new = Dict(other) + new.update(self) + return new + + def __ior__(self, other): + self.update(other) + return self + + def setdefault(self, key, default=None): + if key in self: + return self[key] + else: + self[key] = default + return default + + def lock(self, should_lock=True): + object.__setattr__(self, '__frozen', should_lock) + for key, val in self.items(): + if isinstance(val, Dict): + val.lock(should_lock) + + def dump(self, filename = None): + json_string = json.dumps(self.to_dict(), indent=4) + if filename is not None: + f = open(filename, "w") + f.write(json_string) + f.close() + return json_string + + def unlock(self): + self.lock(False) + + @contextlib.contextmanager + def unlocked(self): + self.unlock() + yield + self.lock() + + def clone(self): + return deepcopy(self) \ No newline at end of file diff --git a/diffstack/configs/eval_config.py b/diffstack/configs/eval_config.py new file mode 100644 index 0000000..cb479c4 --- /dev/null +++ b/diffstack/configs/eval_config.py @@ -0,0 +1,126 @@ +import numpy as np +from copy import deepcopy + +from diffstack.configs.config import Dict + + +class SimEvaluationConfig(Dict): + def __init__(self): + super(SimEvaluationConfig, self).__init__() + self.name = None + self.env = "nusc" + self.dataset_path = None + self.eval_class = "" + self.seed = 0 + self.num_scenes_per_batch = 4 + self.num_scenes_to_evaluate = 100 + + self.num_episode_repeats = 3 + self.start_frame_index_each_episode = None # if specified, should be the same length as num_episode_repeats + self.seed_each_episode = None # if specified, should be the same length as num_episode_repeats + + self.ego_only = False + self.agent_eval_class = None + + self.ckpt_root_dir = "checkpoints/" + self.experience_hdf5_path = None + self.results_dir = "results/" + + self.ckpt.policy.ngc_job_id = None + self.ckpt.policy.ckpt_dir = None + self.ckpt.policy.ckpt_key = None + + self.ckpt.planner.ngc_job_id = None + self.ckpt.planner.ckpt_dir = None + self.ckpt.planner.ckpt_key = None + + self.ckpt.predictor.ngc_job_id = None + self.ckpt.predictor.ckpt_dir = None + self.ckpt.predictor.ckpt_key = None + + self.ckpt.cvae_metric.ngc_job_id = None + self.ckpt.cvae_metric.ckpt_dir = None + self.ckpt.cvae_metric.ckpt_key = None + + self.ckpt.occupancy_metric.ngc_job_id = None + self.ckpt.occupancy_metric.ckpt_dir = None + self.ckpt.occupancy_metric.ckpt_key = None + + self.policy.mask_drivable = True + self.policy.num_plan_samples = 50 + self.policy.num_action_samples = 10 + self.policy.pos_to_yaw = True + self.policy.yaw_correction_speed = 1.0 + self.policy.diversification_clearance = None + self.policy.sample = True + + + self.policy.cost_weights.collision_weight = 10.0 + self.policy.cost_weights.lane_weight = 1.0 + self.policy.cost_weights.likelihood_weight = 0.0 # 0.1 + self.policy.cost_weights.progress_weight = 0.0 # 0.005 + + self.metrics.compute_analytical_metrics = True + self.metrics.compute_learned_metrics = False + + self.perturb.enabled = False + self.perturb.OU.theta = 0.8 + self.perturb.OU.sigma = [0.0, 0.1,0.2,0.5,1.0,2.0,4.0] + self.perturb.OU.scale = [1.0,1.0,0.2] + + self.rolling_perturb.enabled = False + self.rolling_perturb.OU.theta = 0.8 + self.rolling_perturb.OU.sigma = 0.5 + self.rolling_perturb.OU.scale = [1.0,1.0,0.2] + + self.occupancy.rolling = True + self.occupancy.rolling_horizon = [5,10,20] + + self.cvae.rolling = True + self.cvae.rolling_horizon = [5,10,20] + + self.nusc.eval_scenes = 
np.arange(100).tolist() + self.nusc.n_step_action = 5 + self.nusc.num_simulation_steps = 200 + self.nusc.skip_first_n = 0 + + + self.adjustment.random_init_plan=True + self.adjustment.remove_existing_neighbors = False + self.adjustment.initial_num_neighbors = 4 + self.adjustment.num_frame_per_new_agent = 20 + + def clone(self): + return deepcopy(self) + +class EvaluationConfig(Dict): + def __init__(self): + super(EvaluationConfig, self).__init__() + self.name = None + self.env = "nusc" + self.dataset_path = None + self.eval_class = "" + self.seed = 0 + self.ckpt_root_dir = "checkpoints/" + self.ckpt.dir = None + self.ckpt.ngc_job_id = None + self.ckpt.ckpt_dir = None + self.ckpt.ckpt_key = None + + self.eval.batch_size = 100 + self.eval.num_steps = None + self.eval.num_data_workers = 8 + self.log_image_frequency=None + self.log_all_image = False + + self.trajdata_source_root = "nusc_trainval" + self.trajdata_source_eval = "val" + + +class TrainTimeEvaluationConfig(SimEvaluationConfig): + def __init__(self): + super(TrainTimeEvaluationConfig, self).__init__() + + self.num_scenes_per_batch = 4 + self.nusc.eval_scenes = np.arange(0, 100, 10).tolist() + self.policy.sample = False diff --git a/diffstack/configs/registry.py b/diffstack/configs/registry.py new file mode 100644 index 0000000..18672b1 --- /dev/null +++ b/diffstack/configs/registry.py @@ -0,0 +1,32 @@ +"""A global registry for looking up named experiment configs""" +from diffstack.configs.base import ExperimentConfig +from diffstack.configs.config import Dict + + +from diffstack.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig + +from diffstack.configs.algo_config import ( + CTTConfig, +) + + +EXP_CONFIG_REGISTRY = dict() + + +EXP_CONFIG_REGISTRY["CTTPredStack"] = ExperimentConfig( + train_config=TrajdataTrainConfig(), + env_config=TrajdataEnvConfig(), + module_configs=Dict(predictor=CTTConfig()), + registered_name="CTTPredStack", + stack_type="pred", +) + + +def get_registered_experiment_config(registered_name): + if registered_name not in EXP_CONFIG_REGISTRY.keys(): + raise KeyError( + "'{}' is not a registered experiment config please choose from {}".format( + registered_name, list(EXP_CONFIG_REGISTRY.keys()) + ) + ) + return EXP_CONFIG_REGISTRY[registered_name].clone() diff --git a/diffstack/configs/trajdata_config.py b/diffstack/configs/trajdata_config.py new file mode 100644 index 0000000..def7b9c --- /dev/null +++ b/diffstack/configs/trajdata_config.py @@ -0,0 +1,99 @@ +from diffstack.configs.base import TrainConfig, EnvConfig + +MAX_POINTS_LANE = 5 + + +class TrajdataTrainConfig(TrainConfig): + def __init__(self): + super(TrajdataTrainConfig, self).__init__() + + self.trajdata_source_train = "train" + self.trajdata_source_valid = "val" + self.trajdata_source_test = None + self.trajdata_source_root = "nusc_trainval" + self.trajdata_val_source_root = None + self.trajdata_test_source_root = None + # self.trajdata_source_train = "mini_train" + # self.trajdata_source_valid = "mini_val" + # self.trajdata_source_root = "nusc_mini" + + self.dataset_path = "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS" + self.datamodule_class = "UnifiedDataModule" + self.ego_only = True + + self.rollout.enabled = False + self.rollout.save_video = True + self.rollout.every_n_steps = 5000 + + # training config + self.training.batch_size = 100 + self.training.num_steps = 200000 + self.training.num_data_workers = 8 + + self.save.every_n_steps = 1000 + self.save.best_k = 10 + + # validation config + self.validation.enabled = True + 
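+        # Illustrative note: these nested entries are ordinary `Dict` fields, so a run
+        # script can override them after fetching the registered experiment config,
+        # e.g. (sketch only):
+        #   cfg = get_registered_experiment_config("CTTPredStack")
+        #   cfg.train.validation.every_n_steps = 1000
+        #   cfg.lock()  # freeze keys so later typos raise instead of adding new fields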
self.validation.batch_size = 32 + self.validation.num_data_workers = 6 + self.validation.every_n_steps = 500 + self.validation.num_steps_per_epoch = 50 + + self.test.enabled = False + self.test.batch_size = 32 + self.test.num_data_workers = 6 + self.test.every_n_steps = 500 + self.test.num_steps_per_epoch = 50 + + +class TrajdataEnvConfig(EnvConfig): + def __init__(self): + super(TrajdataEnvConfig, self).__init__() + + self.name = "nusc_trainval" + + # raster image size [pixels] + self.rasterizer.raster_size = 224 + + # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. + self.rasterizer.pixel_size = 0.5 + + # where the agent is on the map, (0.0, 0.0) is the center + # WARNING: this should not be changed before resolving + self.rasterizer.ego_center = (-0.75, 0.0) + + # maximum number of agents to consider during training + self.data_generation_params.other_agents_num = 10 + + self.data_generation_params.max_agents_distance = 30 + + # correct for yaw (zero-out delta yaw) when speed is lower than this threshold + self.data_generation_params.yaw_correction_speed = 1.0 + + self.simulation.distance_th_close = 30 + + # maximum number of simulation steps to run (0.1sec / step) + self.simulation.num_simulation_steps = 50 + + # which frame to start an simulation episode with + self.simulation.start_frame_index = 0 + + # whether to get lane information + self.simulation.vectorize_lane = "ego" + + # whether include neighbor map patches + self.incl_neighbor_map = False + # whether include vectorized map + self.incl_vector_map = True + # whether include rasterized map + self.incl_raster_map = False + # whether to combine lanes with a single successor + self.remove_single_successor = False # be very careful, don't turn this on unless you know what you are doing + # whether to prepare lane graphs + self.calc_lane_graph = True + # maximum number of lanes to consider + self.max_num_lanes = 15 + # number of lane points to include in the polyline + self.num_lane_pts = 30 + self.remove_parked = False diff --git a/diffstack/data/cached_nusc_as_trajdata.py b/diffstack/data/cached_nusc_as_trajdata.py deleted file mode 100644 index 0f55fc5..0000000 --- a/diffstack/data/cached_nusc_as_trajdata.py +++ /dev/null @@ -1,684 +0,0 @@ -import os -import time -import torch -import numpy as np -import dill -import json - -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from tqdm import tqdm -from typing import List, Dict - -from trajdata.data_structures.batch import AgentBatch -from trajdata.data_structures.agent import AgentType -from trajdata.utils.arr_utils import batch_nd_transform_points_angles_pt, transform_matrices, batch_nd_transform_points_pt -from trajdata.utils.arr_utils import PadDirection, transform_matrices - -from diffstack.data.trajdata_lanes import LanesList -from diffstack.utils import visualization as plan_vis -from diffstack.utils.utils import batch_derivative_of, angle_wrap, collate - -from diffstack.modules.predictors.trajectron_utils.environment import Environment - - -standardization = { - 'PEDESTRIAN': { - 'position': { - 'x': {'mean': 0, 'std': 1}, - 'y': {'mean': 0, 'std': 1} - }, - 'velocity': { - 'x': {'mean': 0, 'std': 2}, - 'y': {'mean': 0, 'std': 2} - }, - 'acceleration': { - 'x': {'mean': 0, 'std': 1}, - 'y': {'mean': 0, 'std': 1} - } - }, - 'VEHICLE': { - 'position': { - 'x': {'mean': 0, 'std': 80}, - 'y': {'mean': 0, 'std': 80} - }, - 'velocity': { - 'x': {'mean': 0, 'std': 15}, - 'y': 
{'mean': 0, 'std': 15}, - 'norm': {'mean': 0, 'std': 15} - }, - 'acceleration': { - 'x': {'mean': 0, 'std': 4}, - 'y': {'mean': 0, 'std': 4}, - 'norm': {'mean': 0, 'std': 4} - }, - 'heading': { - 'x': {'mean': 0, 'std': 1}, - 'y': {'mean': 0, 'std': 1}, - '°': {'mean': 0, 'std': np.pi}, - 'd°': {'mean': 0, 'std': 1} - } - } -} - - -def prepare_cache_to_avdata(rank, hyperparams, args, diffstack): - """ - Loads cache file used for Corl22 paper and converts to AgentBatch. - - The data is identical apart from numerical differences which was verified by manually - comparing inputs/outputs for - - `corl22_unified_reproduce` branch commit 3703d7adb94f98a749465fc7dec7da4f4a2bb2e3 - - `clean` branch commit 4edf81c4ea742a862e9600558864d193ff9346c8 - I used the local `./cache/nuScenes_mini_val.pkl.6.mpc1.20.20.v6.cached.data.pkl` file. - Randomly picked inputs were the same, and planning metrics were very close over the - validation data when using the `Unified Train GTpred Fan-MPC manualdata compare` launch - config, - --predictor=gt - --planner=fan_mpc - --history_sec=4.0 - --prediction_sec=3.0 - --augment_input_noise=0.0 - - Fan related metrics: - fan_converged 100: 0.8839 | plan.fan_valid 100: 0.8839 - plan_class_hcost 100: 0.8021 | plan.class_hcost 100: 0.8023 - - Mpc related metrics: difference is larger but this could be due to small - numerical differences that have a large influence on mpc. - plan_hcost 100: 0.7412 | plan.hcost 100: 0.7417 - lan_converged 100: 0.6161 | plan.converged 100: 0.5893 - - It is hard to reproduce results with T++ training because the random seeds will differ. - """ - _, train_split = hyperparams["train_data"].split('-') - _, eval_split = hyperparams["eval_data"].split('-') - cached_train_data_path = os.path.join( - os.path.expanduser(hyperparams["cached_data_dir"]), - f"cached_nuScenes_{train_split}.pkl") - cached_eval_data_path = os.path.join( - os.path.expanduser(hyperparams["cached_data_dir"]), - f"cached_nuScenes_{eval_split}.pkl") - - if not os.path.exists(cached_eval_data_path): - raise ValueError(f"No file: {cached_eval_data_path}") - if not os.path.exists(cached_train_data_path): - raise ValueError(f"No file: {cached_train_data_path}") - - # # Load training data. 
- with open(cached_train_data_path, 'rb') as f: - train_dataset = dill.load(f) - train_dataset = list(train_dataset)[0] - - # Load eval data - with open(cached_eval_data_path, 'rb') as f: - eval_dataset = dill.load(f) - eval_dataset = list(eval_dataset)[0] - - # Filter data - filter_fn = get_filter_func( - node_type=train_dataset.node_type, - plan_node_types=hyperparams["plan_node_types"], - get_neighbor_idx_for_planning_fn=get_neighbor_idx_for_planning, - plan_valid=hyperparams["filter_plan_valid"], - plan_converged=hyperparams["filter_plan_converged"], plan_relevant=hyperparams["filter_plan_relevant"], - lane_near=hyperparams["filter_lane_near"]) - if filter_fn is not None: - train_dataset.filter(filter_fn, verbose=(rank == 0)) - - filter_fn = get_filter_func( - node_type=eval_dataset.node_type, - plan_node_types=hyperparams["plan_node_types"], - get_neighbor_idx_for_planning_fn=get_neighbor_idx_for_planning, - legacy_valid_set=True) - eval_dataset.filter(filter_fn, verbose=(rank == 0)) - - train_sampler = DistributedSampler( - train_dataset, - num_replicas=hyperparams["world_size"], - rank=rank - ) - eval_sampler = DistributedSampler( - eval_dataset, - num_replicas=hyperparams["world_size"], - rank=rank - ) - - def my_collate(*args): - manual_batch = collate(*args) - agent_batch = convert_manual_batch_to_agentbatch(manual_batch, hyperparams=hyperparams) - return agent_batch - - train_dataloader = DataLoader(train_dataset, - collate_fn=my_collate, - pin_memory=False if hyperparams['device'] == 'cpu' else True, - batch_size=hyperparams['batch_size'], - shuffle=False, - num_workers=hyperparams['preprocess_workers'], - sampler=train_sampler) - - eval_dataloader = DataLoader(eval_dataset, - collate_fn=my_collate, - pin_memory=False if hyperparams['device'] == 'cpu' else True, - batch_size=hyperparams['eval_batch_size'], - shuffle=False, - num_workers=hyperparams['preprocess_workers'], - sampler=eval_sampler) - - input_wrapper = lambda batch: {"batch": batch} - - return train_dataloader, train_sampler, train_dataset, eval_dataloader, eval_sampler, eval_dataset, input_wrapper - - -def get_filter_func(node_type, plan_node_types, get_neighbor_idx_for_planning_fn, plan_valid=False, plan_converged=False, plan_relevant=False, min_history_len=None, lane_near=False, legacy_valid_set=False): - """ - sample: single non-batched input from dataset - return: True if sample should be kept in dataset - """ - - # shortcut no filtering - if not plan_valid and not plan_converged and not plan_relevant and min_history_len is None and not legacy_valid_set: - return None - - def fn(sample): - (first_history_index, - x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, # dict of lists. edge_type -> [batch][neighbor]: Tensor(time, statedim). Represetns - neighbors_edge_value, - robot_traj_st_t, - map, neighbors_future_data, plan_data) = sample - - # Prediction reletad filters - if legacy_valid_set and x_t.shape[1] - first_history_index < 8: - # TODO this is only kept to reproduce validation set in the paper. x_t.shape is used incorrectly - # x_t.shape[1]=8 because its state dimensions, min_history_len=8, so in effect this requires first_history_index==0 - return False - - if min_history_len is not None and x_t.shape[0] - first_history_index < min_history_len: - return False - - # Planning related filters - if node_type not in plan_node_types: - # Don't filter if we don't plan for this node type. 
- return True - if plan_valid: - ni = get_neighbor_idx_for_planning_fn(plan_data) - if ni < 0: - return False - # TODO temp fix for preprocessing not filtering invalid futures - vehicle_future_f = neighbors_future_data[(str(node_type), 'VEHICLE')] - if vehicle_future_f[int(ni)].shape[1]>=9 and torch.isnan(vehicle_future_f[int(ni)][:, :(4+2+3)]).any(): - return False - if plan_converged: - # _, _, _, _, _, _, _, _, gtplan_x, gtplan_u, gtplan_converged = self.decode_plan_inputs(plan_data=plan_data.unsqueeze(0)) # dummy batch - gtplan_converged = plan_data['gt_plan_converged'] - if gtplan_converged < 0.5: - return False - if plan_relevant: - regret = plan_data["nopred_plan_hcost"] - plan_data["gt_plan_hcost"] - if regret <= 0.: # 0.001: - return False - if lane_near: - raise NotImplementedError() - return True - return fn - - -def get_neighbor_idx_for_planning(plan_data, plan_agent_choice="most_relevant"): - # Chose plan_node, context_nodes. - if plan_agent_choice == "most_relevant": - plan_neigbors = plan_data["most_relevant_idx"] - # plan_neigbors = plan_data[..., 0].int() - elif plan_agent_choice == "ego": - plan_neigbors = plan_data["robot_idx"] - # plan_neigbors = plan_data[..., 1].int() - else: - raise ValueError("Unknown plan_agent_choice: {}".format(plan_agent_choice)) - return plan_neigbors - - -def standardized_manual_state(x, x_origin, agent_type: str, dt: float, only2d=False): - attention_radius = dict() - attention_radius["PEDESTRIAN"] = 20.0 - attention_radius["VEHICLE"] = 30.0 - state_dict = { - "PEDESTRIAN": { - "position": ["x", "y"], - "velocity": ["x", "y"], - "acceleration": ["x", "y"], - # "augment": ["ego_indicator"] - }, - "VEHICLE": { - "position": ["x", "y"], - "velocity": ["x", "y"], - "acceleration": ["x", "y"], - "heading": ["°", "d°"], - # "augment": ["ego_indicator"] - } - } - - env = Environment(node_type_list=['VEHICLE', 'PEDESTRIAN'], standardization=standardization, dt=dt) - std_dict = {} - _, std_dict["PEDESTRIAN"] = env.get_standardize_params(state_dict["PEDESTRIAN"], "PEDESTRIAN") - _, std_dict["VEHICLE"] = env.get_standardize_params(state_dict["VEHICLE"], "VEHICLE") - - std = np.array(std_dict[agent_type], dtype=np.float32) - std[0:2] = attention_radius[agent_type] - rel_state = np.zeros_like(x, dtype=np.float32) - if only2d: - rel_state_dims = 2 - else: - rel_state_dims = np.min((x.shape[-1], x_origin.shape[-1])) - - rel_state[:, :rel_state_dims] = x_origin[:rel_state_dims] - x_standardized = (x - rel_state) / std[None, :x.shape[-1]] - return x_standardized - - -def unstandardized_manual_state(x_standardized, x_origin, agent_type: str, dt: float, only2d=False): - attention_radius = dict() - attention_radius["PEDESTRIAN"] = 20.0 - attention_radius["VEHICLE"] = 30.0 - state_dict = { - "PEDESTRIAN": { - "position": ["x", "y"], - "velocity": ["x", "y"], - "acceleration": ["x", "y"], - # "augment": ["ego_indicator"] - }, - "VEHICLE": { - "position": ["x", "y"], - "velocity": ["x", "y"], - "acceleration": ["x", "y"], - "heading": ["°", "d°"], - # "augment": ["ego_indicator"] - } - } - - env = Environment(node_type_list=['VEHICLE', 'PEDESTRIAN'], standardization=standardization, dt=dt) - std_dict = {} - _, std_dict["PEDESTRIAN"] = env.get_standardize_params(state_dict["PEDESTRIAN"], "PEDESTRIAN") - _, std_dict["VEHICLE"] = env.get_standardize_params(state_dict["VEHICLE"], "VEHICLE") - - std = np.array(std_dict[agent_type], dtype=np.float32) - std[0:2] = attention_radius[agent_type] - rel_state = np.zeros_like(x_standardized, dtype=np.float32) - if only2d: - 
rel_state_dims = 2 - else: - rel_state_dims = np.min((x_standardized.shape[-1], x_origin.shape[-1])) - - rel_state[:, :rel_state_dims] = x_origin[:rel_state_dims] - x_global = x_standardized * std[None] + rel_state - return x_global - - -def check_consistent(states, agent_type:str, dt): - if agent_type == "VEHICLE": - x, y, vx, vy, ax, ay, h, dh = torch.unbind(states[..., :8], dim=-1) - h_from_v = torch.atan2(vy, vx) - h_err = ((h-h_from_v + np.pi) % (2*np.pi) - np.pi).abs() - else: - x, y, vx, vy, ax, ay = torch.unbind(states[..., :6], dim=-1) - h_err = torch.zeros((1, )) - - delta_x = (x[1:] - x[:-1]) / dt - delta_y = (y[1:] - y[:-1]) / dt - vx_err = (vx[1:]-delta_x) - vy_err = (vy[1:]-delta_y) - - print ("h", h_err, h_err.max()) - print ("vx", vx_err, vx_err.max()) - print ("vy", vy_err, vy_err.max()) - return h_err, vx_err, vy_err - - -def convert_manual_hist_to_trajdata_hist(traj, agent_type: AgentType): - if agent_type == AgentType.VEHICLE: - # "VEHICLE": { "position": ["x", "y"], "velocity": ["x", "y"], "acceleration": ["x", "y"], "heading": ["°", "d°"], "augment": ["ego_indicator"] - # History - x, y, vx, vy, ax, ay, h, dh = torch.unbind(traj[..., :8], dim=-1) - trajdata_hist = torch.stack((x, y, vx, vy, ax, ay, torch.sin(h), torch.cos(h)), dim=-1) - elif agent_type == AgentType.PEDESTRIAN: - # PEDESTRIAN: "position": ["x", "y"], "velocity": ["x", "y"], "acceleration": ["x", "y"], "augment": ["ego_indicator"] - # History. There is no heading so recover it from vx/vy - x, y, vx, vy, ax, ay = torch.unbind(traj[..., :6], dim=-1) - h = torch.atan2(vy, vx) - trajdata_hist = torch.stack((x, y, vx, vy, ax, ay, torch.sin(h), torch.cos(h)), dim=-1) - else: - assert False - - return trajdata_hist - - -def convert_trajdata_hist_to_manual_hist(traj, agent_type: AgentType, dt: float): - if agent_type == AgentType.VEHICLE: - # "VEHICLE": { "position": ["x", "y"], "velocity": ["x", "y"], "acceleration": ["x", "y"], "heading": ["°", "d°"], "augment": ["ego_indicator"] - # History - x, y, vx, vy, ax, ay, sinh, cosh = torch.unbind(traj, dim=-1) - h = angle_wrap(torch.atan2(sinh, cosh)) - dh = angle_wrap(batch_derivative_of(h[..., None], dt)).squeeze(-1) - # Set dh to zero for steps that are nan - dh[torch.logical_not(h.isnan())] = torch.nan_to_num(dh, 0.)[torch.logical_not(h.isnan())] - manual_hist = torch.stack((x, y, vx, vy, ax, ay, h ,dh), dim=-1) - elif agent_type == AgentType.PEDESTRIAN: - # PEDESTRIAN: "position": ["x", "y"], "velocity": ["x", "y"], "acceleration": ["x", "y"], "augment": ["ego_indicator"] - # History. There is no heading so recover it from vx/vy - x, y, vx, vy, ax, ay, sinh, cosh = torch.unbind(traj, dim=-1) - manual_hist = torch.stack((x, y, vx, vy, ax, ay), dim=-1) - else: - assert False - - return manual_hist - - -def convert_manual_fut_to_trajdata_fut(traj, agent_type: AgentType, ph:int, dt:float, num_lane_points=16): - """ - traj: [..., T, future_state_dim] - """ - if agent_type == AgentType.VEHICLE: - x_fut, u_gt, lanes, x_proj, u_t_fitted_dh, u_t_fitted_a, lane_points = torch.split( - traj, (4, 2, 3, 2, ph+1, ph+1, num_lane_points*3), dim=-1,) - x, y, h, v = torch.unbind(x_fut, dim=-1) - vx = v * torch.cos(h) - vy = v * torch.sin(h) - ax = batch_derivative_of(vx[:, None], dt).squeeze(1) - ay = batch_derivative_of(vy[:, None], dt).squeeze(1) - trajdata_fut = torch.stack((x, y, vx, vy, ax, ay, torch.sin(h), torch.cos(h)), dim=-1) - elif agent_type == AgentType.PEDESTRIAN: - # There is no heading, nor velocity, so recover these. 
- x, y = torch.unbind(traj, dim=-1) - vx = batch_derivative_of(x[..., None], dt).squeeze(-1) - vy = batch_derivative_of(y[..., None], dt).squeeze(-1) - h = torch.atan2(vy, vx) - ax = batch_derivative_of(vx[..., None], dt).squeeze(-1) - ay = batch_derivative_of(vy[..., None], dt).squeeze(-1) - trajdata_fut = torch.stack((x, y, vx, vy, ax, ay, torch.sin(h), torch.cos(h)), dim=-1) - lanes = None - lane_points = None - else: - assert False - - return trajdata_fut, lanes, lane_points - - -class ManualInputsList(list): - def __to__(self, device, non_blocking=False): - # keep on cpu - return self - - -def origin_to_tf(origin_xyh): - translate_mat = transform_matrices( - angles=torch.zeros_like(origin_xyh[..., 2]), - translations=-origin_xyh[..., :2]) - - rot_mat = transform_matrices( - angles=-origin_xyh[..., 2], - translations=None) - - trans_rot_mat = torch.bmm(rot_mat, translate_mat) # first translate then rotate - - return trans_rot_mat, rot_mat - - -def transform_states_xyvvaahh(traj_xyvvaahh: torch.Tensor, origin_xyh: torch.Tensor) -> torch.Tensor: - """ - traj_xyvvaahh: [..., state_dim] where state_dim = [x, y, vx, vy, ax, ay, sinh, cosh] - """ - tf_mat, rot_mat = origin_to_tf(origin_xyh) - - xy, vv, aa, hh = torch.split(traj_xyvvaahh, (2, 2, 2, 2), dim=-1) - xy = batch_nd_transform_points_pt(xy, tf_mat) - vv = batch_nd_transform_points_pt(vv, rot_mat) - aa = batch_nd_transform_points_pt(aa, rot_mat) - # hh: sinh, cosh instead of cosh, sinh, so we use flip - hh = batch_nd_transform_points_pt(hh.flip(-1), rot_mat).flip(-1) - - return torch.concat((xy, vv, aa, hh), dim=-1) - -def transform_agentbatch_coordinate_frame(agent_batch: AgentBatch, origin_xyh: torch.Tensor, extras_transform_fn: Dict) -> AgentBatch: - """ - Args: - origin_xyh: desired origin (x, y, heading) defined in the current coordinate frame. 
- """ - if agent_batch.maps is not None: - raise NotImplementedError - - tf_mat, _ = origin_to_tf(origin_xyh) - - return AgentBatch( - data_idx=agent_batch.data_idx, - scene_ts=agent_batch.scene_ts, - dt=agent_batch.dt, - agent_name=agent_batch.agent_name, - agent_type=agent_batch.agent_type, - curr_agent_state=agent_batch.curr_agent_state, # this is always defined in the `global` coordinate frame for some reason - agent_hist=transform_states_xyvvaahh(agent_batch.agent_hist, origin_xyh), - agent_hist_extent=agent_batch.agent_hist_extent, - agent_hist_len=agent_batch.agent_hist_len, - agent_fut=transform_states_xyvvaahh(agent_batch.agent_fut, origin_xyh), - agent_fut_extent=agent_batch.agent_fut_extent, - agent_fut_len=agent_batch.agent_fut_len, - num_neigh=agent_batch.num_neigh, - neigh_types=agent_batch.neigh_types, - neigh_hist=transform_states_xyvvaahh(agent_batch.neigh_hist, origin_xyh), - neigh_hist_extents=agent_batch.neigh_hist_extents, - neigh_hist_len=agent_batch.neigh_hist_len, - neigh_fut=transform_states_xyvvaahh(agent_batch.neigh_fut, origin_xyh), - neigh_fut_extents=agent_batch.neigh_fut_extents, - neigh_fut_len=agent_batch.neigh_fut_len, - robot_fut=transform_states_xyvvaahh(agent_batch.robot_fut, origin_xyh) - if agent_batch.robot_fut is not None - else None, - robot_fut_len=agent_batch.robot_fut_len, - maps=agent_batch.maps, # TODO - maps_resolution=agent_batch.maps_resolution, # TODO - rasters_from_world_tf=agent_batch.rasters_from_world_tf, # TODO - agents_from_world_tf=torch.bmm(tf_mat, agent_batch.agents_from_world_tf), # TODO test - scene_ids=agent_batch.scene_ids, - history_pad_dir=agent_batch.history_pad_dir, - extras={ - key: extras_transform_fn[key](val) - for key, val in agent_batch.extras.items()}, - ) - -def convert_manual_batch_to_agentbatch(manual_data_batch: List, hyperparams: Dict, agent_centric_frame=False, pad_direction=PadDirection.AFTER) -> AgentBatch: - - (first_history_index, - x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, # dict of lists. edge_type -> [batch][neighbor]: Tensor(time, statedim). 
Represetns - neighbors_edge_value, - robot_traj_st_t, - map, neighbors_future_data, plan_data) = manual_data_batch - - batch_size = x_st_t.shape[0] - ph = hyperparams["prediction_horizon"] - dt = hyperparams["dt"] - num_lane_points = 16 - agent_hist_len = x_st_t.shape[1] - first_history_index - assert list(neighbors_data_st.keys())[0][0] == "VEHICLE" - agent_type = torch.tensor([AgentType.VEHICLE] * batch_size) - - # Combine different edge neighbors into a single list - # Deep copy - neigh_data_st = [list(temp) for temp in neighbors_data_st[("VEHICLE", "VEHICLE")]] - neigh_fut_data = [list(temp) for temp in neighbors_future_data[("VEHICLE", "VEHICLE")]] - neigh_types_list = [[AgentType.VEHICLE] * len(neigh_data_st[b_i]) for b_i in range(batch_size)] - for b_i in range(batch_size): - neigh_data_st[b_i].extend(neighbors_data_st[("VEHICLE", "PEDESTRIAN")][b_i]) - neigh_fut_data[b_i].extend(neighbors_future_data[("VEHICLE", "PEDESTRIAN")][b_i]) - neigh_types_list[b_i].extend([AgentType.PEDESTRIAN] * len(neighbors_data_st[("VEHICLE", "PEDESTRIAN")][b_i])) - num_neigh = torch.tensor([len(neigh) for neigh in neigh_data_st]) - max_num_neigh = torch.max(num_neigh).item() - # We dont need to orgainze edge values into a single list because it is already combined for vehicles and pedestrians - neigh_edge_list = list(neighbors_edge_value[("VEHICLE", "VEHICLE")]) - - # Unstandardize neigh_data - neigh_data = [[] for _ in range(batch_size)] - x_origin_batch = x_t[:, -1, :].cpu().numpy() - for b_i in range(batch_size): - # # For agent history - x_t_conv = unstandardized_manual_state(x_st_t[b_i], x_origin_batch[b_i], AgentType.VEHICLE.name, dt, only2d=True) - assert torch.logical_or(torch.isclose(x_t_conv, x_t[b_i]), torch.isnan(x_t_conv)).all() - # check_consistent(x_t_conv, "VEHICLE", dt) - - for n_i in range(len(neigh_data_st[b_i])): - x_conv = unstandardized_manual_state(neigh_data_st[b_i][n_i], x_origin_batch[b_i], AgentType(neigh_types_list[b_i][n_i]).name, dt, only2d=False) - # check_consistent(x_conv, AgentType(neigh_types[b_i][n_i]).name, dt) - # Check current pose is the same in history and future - assert torch.isclose(x_conv[-1, :2], neigh_fut_data[b_i][n_i][0, :2]).all() - neigh_data[b_i].append(x_conv) - - # Move plan agent to first neighbor - robot_ind = torch.full((batch_size, ), -1, dtype=torch.int) - for b_i in range(batch_size): - plan_i = plan_data['most_relevant_idx'][b_i] - if plan_i >= 0: - neigh_data[b_i] = [neigh_data[b_i][plan_i]] + [ - neigh_data[b_i][n_i] - for n_i in range(len(neigh_data[b_i])) if n_i != plan_i] - neigh_fut_data[b_i] = [neigh_fut_data[b_i][plan_i]] + [ - neigh_fut_data[b_i][n_i] - for n_i in range(len(neigh_fut_data[b_i])) if n_i != plan_i] - robot_ind[b_i] = 0 - neigh_edge_list[b_i] = [neigh_edge_list[b_i][plan_i]] + [ - neigh_edge_list[b_i][n_i] - for n_i in range(len(neigh_edge_list[b_i])) if n_i != plan_i] - - # Convert to agentbatch tensors - neigh_hist = torch.full((batch_size, max_num_neigh, hyperparams["maximum_history_length"]+1, 8), torch.nan) - neigh_fut = torch.full((batch_size, max_num_neigh, ph, 8), torch.nan) - neigh_hist_len = torch.full((batch_size, max_num_neigh), 0, dtype=torch.int64) - neigh_fut_len = torch.full((batch_size, max_num_neigh), 0, dtype=torch.int64) - neigh_types = torch.full((batch_size, max_num_neigh), -1, dtype=torch.int) - neigh_edge = torch.full((batch_size, max_num_neigh), torch.nan) - - lanes_batch = torch.full((batch_size, ph+1, 3), torch.nan) - lane_points_batch = torch.full((batch_size, ph+1, num_lane_points, 3), 
torch.nan) - - for b_i in range(len(neigh_data)): - for n_i in range(len(neigh_data[b_i])): - # Convert state representation - TODO move out function, apply to agent states too. - trajdata_hist = convert_manual_hist_to_trajdata_hist(neigh_data[b_i][n_i], neigh_types_list[b_i][n_i]) - trajdata_fut, lanes, lane_points = convert_manual_fut_to_trajdata_fut( - neigh_fut_data[b_i][n_i], neigh_types_list[b_i][n_i], ph, dt) - - # Find actual history/future length. Invalid states are represented by zeros. - invalid_t_mask = torch.logical_or( - torch.isclose(trajdata_hist[:, :2], torch.zeros(()), atol=1e-4).all(dim=1), - torch.isnan(trajdata_hist[:, 0]) - ) - invalid_t_count = invalid_t_mask.sum() - trajdata_hist = trajdata_hist[invalid_t_count:] - invalid_t_mask = torch.logical_or( - torch.isclose(trajdata_fut[:, :2], torch.zeros(()), atol=1e-4).all(dim=1), - torch.isnan(trajdata_fut[:, 0]) - ) - invalid_t_count = invalid_t_mask.sum() - trajdata_fut = trajdata_fut[:trajdata_fut.shape[0]-invalid_t_count] - - # Use padding_direction.AFTER - if pad_direction == pad_direction.AFTER: - neigh_hist[b_i, n_i, :trajdata_hist.shape[0]] = trajdata_hist - else: - neigh_hist[b_i, n_i, neigh_hist.shape[2]-trajdata_hist.shape[0]:] = trajdata_hist - neigh_fut[b_i, n_i, :trajdata_fut[1:].shape[0]] = trajdata_fut[1:] - neigh_hist_len[b_i, n_i] = trajdata_hist.shape[0] - neigh_fut_len[b_i, n_i] = trajdata_fut.shape[0] - neigh_types[b_i, n_i] = neigh_types_list[b_i][n_i] - neigh_edge[b_i, n_i] = neigh_edge_list[b_i][n_i] - - if n_i == robot_ind[b_i]: - lanes_batch[b_i] = lanes - - lane_points = lane_points.reshape(list(lane_points.shape[:-1]) + [num_lane_points, 3]) - lane_points_batch[b_i] = lane_points - - agent_hist = convert_manual_hist_to_trajdata_hist(x_t, AgentType.VEHICLE) - # Agenet future is only xy, same as the history for a pedestrian, so we - # purposefully use the history converter function with PEDESTRIAN the future of a VEHICLE agent. 
- agent_fut, _, _ = convert_manual_fut_to_trajdata_fut(y_t, AgentType.PEDESTRIAN, ph, dt) - - # Convert to PadDirection.AFTER - if pad_direction == PadDirection.AFTER: - for b_i in range(batch_size): - new_hist = torch.full_like(agent_hist[b_i], torch.nan) - new_hist[:agent_hist_len[b_i]] = agent_hist[b_i, (agent_hist.shape[1]-agent_hist_len[b_i]):] - agent_hist[b_i] = new_hist - - curr_agent_state_world = agent_hist[torch.arange(batch_size), agent_hist_len-1] - curr_agent_state_world[:, :2] += plan_data["scene_offset"] - - agents_from_world_tf = transform_matrices( - angles=torch.zeros((batch_size, )), - translations=-plan_data["scene_offset"], - ) - - lanes_near_goal = plan_data["most_relevant_nearby_lanes"] # shallow copy - - extras = { - "robot_ind": robot_ind, - "goal": neigh_fut[:, 0, -1, :], # xyvvaahh - "lane_projection_points": lanes_batch, # xyh - "lanes_near_goal": LanesList(lanes_near_goal), # list of xyh - "neigh_edge_weight": neigh_edge, - "manual_inputs": ManualInputsList(manual_data_batch), - } - - agent_batch = AgentBatch( - data_idx=torch.full((batch_size,), np.nan), - scene_ts=torch.tensor([0] * batch_size), - dt=torch.tensor([hyperparams["dt"]] * batch_size), - agent_name=["dummy"] * batch_size, - agent_type=agent_type, - curr_agent_state=curr_agent_state_world, - agent_hist=agent_hist, - agent_hist_len=agent_hist_len, - agent_fut=agent_fut, - agent_fut_len=torch.tensor([hyperparams["prediction_horizon"]] * batch_size, dtype=torch.int), - agent_hist_extent=torch.full((batch_size,), np.nan), - agent_fut_extent=torch.full((batch_size,), np.nan), - num_neigh=num_neigh, - neigh_types=neigh_types, - neigh_hist=neigh_hist, - neigh_hist_len=neigh_hist_len, - neigh_fut=neigh_fut, - neigh_fut_len=neigh_fut_len, - neigh_hist_extents=torch.full((batch_size,), np.nan), - neigh_fut_extents=torch.full((batch_size,), np.nan), - robot_fut=None, - robot_fut_len=None, - maps=None, - map_names=None, - vector_maps=None, - maps_resolution=None, - rasters_from_world_tf=torch.full((batch_size,), np.nan), - agents_from_world_tf=agents_from_world_tf, - scene_ids=[None for _ in range(batch_size)], - history_pad_dir=pad_direction, - extras=extras, - ) - - # # Convert everything to agent centric. 
- if agent_centric_frame: - agent_state_t = agent_hist[torch.arange(batch_size), agent_hist_len-1] - origin_xyh = torch.concat(( - agent_state_t[:, :2], - torch.atan2(agent_state_t[:, -2], agent_state_t[:, -1]).unsqueeze(-1) # h = atan2(sinh, cosh) - ), dim=-1) - tf_mat, rot_mat = origin_to_tf(origin_xyh) - - extras_tf_fn = { - "robot_ind": lambda x: x, - "goal": lambda x: transform_states_xyvvaahh(x, origin_xyh), - "lane_projection_points": lambda x: - batch_nd_transform_points_angles_pt(x, tf_mat), - "lanes_near_goal": lambda xlistbatch: LanesList( - [[batch_nd_transform_points_angles_pt(x, tf_mat[b_i]) for x in xlistbatch[b_i]] for b_i in range(batch_size)]) - } - - agent_batch = transform_agentbatch_coordinate_frame(agent_batch, origin_xyh, extras_tf_fn) - - return agent_batch - diff --git a/diffstack/data/scene_batch_extras.py b/diffstack/data/scene_batch_extras.py index 88e8507..efa2077 100644 --- a/diffstack/data/scene_batch_extras.py +++ b/diffstack/data/scene_batch_extras.py @@ -12,42 +12,46 @@ def role_selector(element: SceneBatchElement, pred_agent_types: List[AgentType] = (AgentType.VEHICLE, )): # Find ego agent_names = [agent.name for agent in element.agents] - ego_i = next(i for i, name in enumerate(agent_names) if name == "ego") - - # Find pred agent that is closest to ego - dists = [] - inds = [] - for n_i in range(len(element.agent_futures)): - if n_i == ego_i: - continue - - # Filter what agents we want to predict - if element.agent_types_np[n_i] not in pred_agent_types: - continue - - # Filter incomplete future - if element.agent_future_lens_np[n_i] < element.agent_future_lens_np[ego_i]: - continue - - # Filter parked vehicles - if element.agent_meta_dicts[n_i]['is_stationary']: - continue - - # Distance from predicted agent - dist = np.square(element.agent_futures[ego_i][:, :2] - element.agent_futures[n_i][:, :2]) - dist = np.min(dist.sum(axis=-1), axis=-1) # sum over states, min over time - inds.append(n_i) - dists.append(dist) - - if dists: - pred_i = inds[np.argmin(np.array(dists))] # neighbor that gets closest to current node + if "ego" in agent_names: + ego_i = next(i for i, name in enumerate(agent_names) if name == "ego") else: - # No neighbors or all futures are incomplete - ego_i = -1 - pred_i = -1 + ego_i = element.agents.index(element.centered_agent) + + + # # Find pred agent that is closest to ego + # dists = [] + # inds = [] + # for n_i in range(len(element.agent_futures)): + # if n_i == ego_i: + # continue + + # # Filter what agents we want to predict + # if element.agent_types_np[n_i] not in pred_agent_types: + # continue + + # # Filter incomplete future + # if element.agent_future_lens_np[n_i] < element.agent_future_lens_np[ego_i]: + # continue + + # # Filter parked vehicles + # if element.agent_meta_dicts[n_i]['is_stationary']: + # continue + + # # Distance from predicted agent + # dist = np.square(element.agent_futures[ego_i][:, :2] - element.agent_futures[n_i][:, :2]) + # dist = np.min(dist.sum(axis=-1), axis=-1) # sum over states, min over time + # inds.append(n_i) + # dists.append(dist) + + # if dists: + # pred_i = inds[np.argmin(np.array(dists))] # neighbor that gets closest to current node + # else: + # # No neighbors or all futures are incomplete + # ego_i = -1 + # pred_i = -1 element.extras['robot_ind'] = ego_i - element.extras['pred_agent_ind'] = pred_i + # element.extras['pred_agent_ind'] = pred_i return element diff --git a/diffstack/data/trajdata_datamodules.py b/diffstack/data/trajdata_datamodules.py new file mode 100644 index 
0000000..009583f --- /dev/null +++ b/diffstack/data/trajdata_datamodules.py @@ -0,0 +1,229 @@ +import os +import numpy as np +from collections import defaultdict +from torch.utils.data import Dataset +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from diffstack.configs.base import TrainConfig +from diffstack.data import scene_batch_extras, agent_batch_extras + +from trajdata import AgentBatch, AgentType, UnifiedDataset + + +class UnifiedDataModule(pl.LightningDataModule): + def __init__(self, data_config, train_config: TrainConfig): + super(UnifiedDataModule, self).__init__() + self._data_config = data_config + self._train_config = train_config + self.train_dataset = None + self.valid_dataset = None + self.test_dataset = None + self.train_batch_size = self._train_config.training.batch_size + self.val_batch_size = self._train_config.validation.batch_size + try: + self.test_batch_size = self._train_config.test["batch_size"] + self.test_num_workers = self._train_config.validation.num_data_workers + except: + self.test_batch_size = self.val_batch_size + self.test_num_workers = self._train_config.validation.num_data_workers + self._is_setup = False + + def setup(self, stage=None): + if self._is_setup: + return + data_cfg = self._data_config + future_sec = ( + data_cfg.future_num_frames * data_cfg.step_time + ) # python float precision bug + history_sec = ( + data_cfg.history_num_frames * data_cfg.step_time + ) # python float precision bug + neighbor_distance = data_cfg.max_agents_distance + attention_radius = defaultdict( + lambda: 20.0 + ) # Default range is 20m unless otherwise specified. + attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 + attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 + attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 + + if self._data_config.centric == "scene": + # Use actual ego as our planning agent, pick the closest other agent to predict + # for agent-centric prediction models (like T++). + # We only consider vehicles for prediction agent for now. + # Scene-centric prediction models can do prediction for all agents. 
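+            # Each entry below is a callable that trajdata applies to every
+            # SceneBatchElement as it is built; role_selector, for instance, stores the
+            # ego index in el.extras["robot_ind"].  An extra hook can be chained the
+            # same way, e.g. (illustrative; `my_tagger` is a hypothetical helper):
+            #   pre_filter_transforms += (lambda el: my_tagger(el),)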
+ pre_filter_transforms = ( + # lambda el: scene_batch_extras.remove_parked(el, keep_agent_ind=0), + lambda el: scene_batch_extras.role_selector( + el, pred_agent_types=[AgentType.VEHICLE] + ), + ) + if self._data_config.get("remove_parked", False): + pre_filter_transforms = pre_filter_transforms + ( + lambda el: scene_batch_extras.remove_parked(el, keep_agent_ind=0), + ) + transforms = pre_filter_transforms + ( + lambda el: scene_batch_extras.make_robot_the_first(el), + # lambda el: scene_batch_extras.augment_with_point_goal(el), + # lambda el: scene_batch_extras.augment_with_lanes(el, make_missing_lane_invalid=True), + # lambda el: scene_batch_extras.augment_with_goal_lanes(el), + ) + get_filter_func = scene_batch_extras.get_filter_func + else: + pre_filter_transforms = ( + lambda el: agent_batch_extras.remove_parked(el), + lambda el: agent_batch_extras.robot_selector(el), + ) + transforms = pre_filter_transforms + ( + lambda el: agent_batch_extras.make_robot_the_first(el), + lambda el: agent_batch_extras.augment_with_point_goal(el), + lambda el: agent_batch_extras.augment_with_lanes( + el, make_missing_lane_invalid=True + ), + lambda el: agent_batch_extras.augment_with_goal_lanes(el), + ) + get_filter_func = agent_batch_extras.get_filter_func + try: + cache_dir = data_cfg.get("cache_dir", os.environ["TRAJDATA_CACHE_DIR"]) + except: + cache_dir = "~/.unified_avdata_cache" + kwargs = dict( + centric=data_cfg.centric, + desired_data=[data_cfg.trajdata_source_train], + desired_dt=data_cfg.step_time, + future_sec=(future_sec, future_sec), + history_sec=(history_sec, history_sec), + data_dirs={ + data_cfg.trajdata_source_root: data_cfg.dataset_path, + }, + # only_types=[AgentType.VEHICLE,AgentType.PEDESTRIAN], + agent_interaction_distances=defaultdict(lambda: neighbor_distance), + incl_raster_map=data_cfg.get("incl_raster_map", False), + raster_map_params={ + "px_per_m": int(1 / data_cfg.pixel_size), + "map_size_px": data_cfg.raster_size, + "return_rgb": False, + "offset_frac_xy": data_cfg.raster_center, + "original_format": True, + }, + state_format="x,y,xd,yd,xdd,ydd,s,c", + obs_format="x,y,xd,yd,xdd,ydd,s,c", + incl_vector_map=data_cfg.get("incl_vector_map", False), + verbose=True, + max_agent_num=1 + data_cfg.other_agents_num, + # max_neighbor_num = data_cfg.other_agents_num, + num_workers=min(os.cpu_count(), 64), + # num_workers = 0, + ego_only=self._train_config.ego_only, + transforms=transforms, + rebuild_cache=self._train_config.rebuild_cache, + cache_location=cache_dir, + save_index=False, + ) + if kwargs["incl_vector_map"]: + kwargs["vector_map_params"] = { + "incl_road_lanes": True, + "incl_road_areas": False, + "incl_ped_crosswalks": False, + "incl_ped_walkways": False, + "max_num_lanes": data_cfg.max_num_lanes, + "num_lane_pts": data_cfg.num_lane_pts, + "remove_single_successor": data_cfg.remove_single_successor, + # Collation can be quite slow if vector maps are included, + # so we do not unless the user requests it. 
+ "collate": True, + "calc_lane_graph": data_cfg.get("calc_lane_graph", False), + } + if "waymo" in data_cfg.trajdata_source_root: + kwargs["vector_map_params"]["keep_in_memory"] = False + kwargs["vector_map_params"]["radius"] = 300 + print(kwargs) + self.train_dataset = UnifiedDataset(**kwargs) + + # prepare validation dataset + kwargs["desired_data"] = [data_cfg.trajdata_source_valid] + if data_cfg.trajdata_val_source_root is not None: + kwargs["data_dirs"] = { + data_cfg.trajdata_val_source_root: data_cfg.dataset_path, + } + self.valid_dataset = UnifiedDataset(**kwargs) + + # prepare test dataset if specified + if data_cfg.trajdata_source_test is not None: + kwargs["desired_data"] = [data_cfg.trajdata_source_test] + if data_cfg.trajdata_test_source_root is not None: + kwargs["data_dirs"] = { + data_cfg.trajdata_test_source_root: data_cfg.dataset_path, + } + self.test_dataset = UnifiedDataset(**kwargs) + self._is_setup = True + + def train_dataloader(self): + batch_name = ( + "scene_batch" if self._data_config.centric == "scene" else "agent_batch" + ) + collate_fn = lambda *args, **kwargs: { + batch_name: self.train_dataset.get_collate_fn(return_dict=False)( + *args, **kwargs + ) + } + + return DataLoader( + dataset=self.train_dataset, + shuffle=True, + batch_size=self.train_batch_size, + num_workers=self._train_config.training.num_data_workers, + drop_last=True, + collate_fn=collate_fn, + persistent_workers=True + if self._train_config.training.num_data_workers + else False, + ) + + def val_dataloader(self): + batch_name = ( + "scene_batch" if self._data_config.centric == "scene" else "agent_batch" + ) + collate_fn = lambda *args, **kwargs: { + batch_name: self.valid_dataset.get_collate_fn(return_dict=False)( + *args, **kwargs + ) + } + + return DataLoader( + dataset=self.valid_dataset, + shuffle=False, + batch_size=self.val_batch_size, + num_workers=self._train_config.validation.num_data_workers, + drop_last=True, + collate_fn=collate_fn, + persistent_workers=True + if self._train_config.validation.num_data_workers > 0 + else False, + ) + + def test_dataloader(self): + batch_name = ( + "scene_batch" if self._data_config.centric == "scene" else "agent_batch" + ) + collate_fn = lambda *args, **kwargs: { + batch_name: self.test_dataset.get_collate_fn(return_dict=False)( + *args, **kwargs + ) + } + return ( + self.val_dataloader() + if self.test_dataset is None + else DataLoader( + dataset=self.test_dataset, + shuffle=False, + batch_size=self.test_batch_size, + num_workers=self.test_num_workers, + drop_last=True, + collate_fn=collate_fn, + persistent_workers=True if self.test_num_workers > 0 else False, + ) + ) + + def predict_dataloader(self): + return self.val_dataloader diff --git a/diffstack/data/trajdata_interface.py b/diffstack/data/trajdata_interface.py deleted file mode 100644 index 88ae288..0000000 --- a/diffstack/data/trajdata_interface.py +++ /dev/null @@ -1,204 +0,0 @@ -import torch -import numpy as np -import os -import dill -import json - -from tqdm import tqdm -from collections import defaultdict, OrderedDict -from typing import Dict, Iterable, Union -from time import time - -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler - -from trajdata import AgentBatch, UnifiedDataset, AgentType -from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement -from trajdata.data_structures.agent import AgentType -from trajdata.augmentation import NoiseHistories - -from diffstack.data import agent_batch_extras, 
scene_batch_extras -from diffstack.utils.utils import all_gather - - -def prepare_avdata(rank, hyperparams, scene_centric=True, use_cache=False): - - # Load cached data or process and cache data if cache file does not exists. - cache_params = ".".join([str(hyperparams[k]) for k in - ["prediction_sec", "history_sec"]]) + ".v6" - cached_train_data_path = os.path.join(hyperparams['trajdata_cache_dir'], f"{hyperparams['train_data']}.{cache_params}.cached.trajdata.pkl") - cached_eval_data_path = os.path.join(hyperparams['trajdata_cache_dir'], f"{hyperparams['eval_data']}.{cache_params}.cached.trajdata.pkl") - - # Load training and evaluation environments and scenes - attention_radius = defaultdict(lambda: 20.0) # Default range is 20m unless otherwise specified. - attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 - attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 - attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 - attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 - - data_dirs: Dict[str, str] = json.loads(hyperparams['data_loc_dict']) - - augmentations = list() - if hyperparams['augment_input_noise'] > 0.0: - augmentations.append(NoiseHistories(stddev=hyperparams['augment_input_noise'])) - - map_params = {"px_per_m": 2, "map_size_px": 100, "offset_frac_xy": (-0.75, 0.0)} - - if hyperparams["plan_agent_choice"] == "most_relevant": - pass - elif hyperparams["plan_agent_choice"] == "ego": - raise NotImplementedError - else: - raise ValueError("Unknown plan_agent_choice: {}".format(hyperparams["plan_agent_choice"])) - - # # Debug with NuScenes map API - # nusc_maps = {data_name: - # {map_name: NuScenesMap(os.path.expanduser(data_path), map_name) for map_name in nusc_map_names} - # for data_name, data_path in data_dirs.items()} - # transforms = ( - # lambda el: remove_parked(el), - # lambda el: robot_selector(el), - # lambda el: make_robot_the_first(el), - # lambda el: augment_with_goal(el), - # # lambda el: augment_with_lanes_nusc(el, nusc_maps), - # # lambda el: augment_with_lanes_compare(el, nusc_maps), - # lambda el: augment_with_lanes(el), - # # lambda el: augment_with_goal_lanes_nusc(el, nusc_maps), - # lambda el: augment_with_goal_lanes(el), - # ) - - if scene_centric: - # Use actual ego as our planning agent, pick the closest other agent to predict - # for agent-centric prediction models (like T++). - # We only consider vehicles for prediction agent for now. - # Scene-centric prediction models can do prediction for all agents. 
- pre_filter_transforms = ( - lambda el: scene_batch_extras.remove_parked(el), - lambda el: scene_batch_extras.role_selector(el, pred_agent_types=[AgentType.VEHICLE]), - ) - transforms = pre_filter_transforms + ( - lambda el: scene_batch_extras.make_robot_the_first(el), - lambda el: scene_batch_extras.augment_with_point_goal(el), - lambda el: scene_batch_extras.augment_with_lanes(el, make_missing_lane_invalid=True), - lambda el: scene_batch_extras.augment_with_goal_lanes(el), - ) - get_filter_func = scene_batch_extras.get_filter_func - else: - pre_filter_transforms = ( - lambda el: agent_batch_extras.remove_parked(el), - lambda el: agent_batch_extras.robot_selector(el), - ) - transforms = pre_filter_transforms + ( - lambda el: agent_batch_extras.make_robot_the_first(el), - lambda el: agent_batch_extras.augment_with_point_goal(el), - lambda el: agent_batch_extras.augment_with_lanes(el, make_missing_lane_invalid=True), - lambda el: agent_batch_extras.augment_with_goal_lanes(el), - ) - get_filter_func = agent_batch_extras.get_filter_func - - eval_dataset = UnifiedDataset(desired_data=[hyperparams['eval_data']], - desired_dt=hyperparams['dt'], - centric="scene" if scene_centric else "agent", - history_sec=(hyperparams['history_sec'], hyperparams['history_sec']), - future_sec=(hyperparams['prediction_sec'], hyperparams['prediction_sec']), - agent_interaction_distances=attention_radius, - incl_robot_future=hyperparams['incl_robot_node'], - incl_raster_map=hyperparams['map_encoding'], - incl_vector_map=True, - raster_map_params=map_params, - only_predict=[node_type for node_type in AgentType if node_type.name in hyperparams['pred_state']], - no_types=[AgentType.UNKNOWN, AgentType.BICYCLE, AgentType.MOTORCYCLE], - num_workers=hyperparams['preprocess_workers'], - cache_location=hyperparams['trajdata_cache_dir'], - data_dirs=data_dirs, - # extras=OrderedDict(robot_ind=robot_selector), - transforms=pre_filter_transforms, - verbose=True, - rank=rank, - rebuild_cache=hyperparams['rebuild_cache']) - - train_dataset = UnifiedDataset(desired_data=[hyperparams['train_data']], - desired_dt=hyperparams['dt'], - centric="scene" if scene_centric else "agent", - history_sec=(0.1, hyperparams['history_sec']), - # future_sec=(0.1, hyperparams['prediction_sec']), # TODO support planning with partial predictions - future_sec=(hyperparams['prediction_sec'], hyperparams['prediction_sec']), - agent_interaction_distances=attention_radius, - incl_robot_future=hyperparams['incl_robot_node'], - incl_raster_map=hyperparams['map_encoding'], - incl_vector_map=True, - raster_map_params=map_params, - only_predict=[node_type for node_type in AgentType if node_type.name in hyperparams['pred_state']], - no_types=[AgentType.UNKNOWN, AgentType.BICYCLE, AgentType.MOTORCYCLE], - augmentations=augmentations if len(augmentations) > 0 else None, - num_workers=hyperparams['preprocess_workers'], - cache_location=hyperparams['trajdata_cache_dir'], - data_dirs=data_dirs, - # extras=OrderedDict(robot_ind=robot_selector), - transforms=pre_filter_transforms, - verbose=True, - rank=rank, - rebuild_cache=hyperparams['rebuild_cache']) - - - # Filter / cache dataset. - # TODO(pkarkus) Filtering is quite slow currently and inefficient for multi-GPU setup - # because each process has to scan through the entire dataset. 
- filter_fn_train = get_filter_func( - ego_valid=hyperparams['filter_plan_valid'], - pred_not_parked=hyperparams['filter_pred_not_parked'], - pred_near_ego=hyperparams['filter_pred_near_ego'], - ) - filter_fn_eval = filter_fn_train - - # # TODO(pkarkus) Temporarily disable filtering for agentcetric input to make debugging faster. - # if not scene_centric: - # filter_fn_train = None - # filter_fn_eval = None - - if use_cache: - eval_dataset.load_or_create_cache(cached_eval_data_path, num_workers=hyperparams['preprocess_workers'], - filter_fn=filter_fn_eval) - train_dataset.load_or_create_cache(cached_train_data_path, num_workers=hyperparams['preprocess_workers'], - filter_fn=filter_fn_train) - else: - eval_dataset.apply_filter(filter_fn=filter_fn_eval, num_workers=hyperparams['preprocess_workers'], max_count=64000, all_gather=all_gather) - train_dataset.apply_filter(filter_fn=filter_fn_train, num_workers=hyperparams['preprocess_workers'], max_count=512000, all_gather=all_gather) - - eval_dataset.transforms = transforms - train_dataset.transforms = transforms - - - # Create samplers - eval_sampler = DistributedSampler( - eval_dataset, - num_replicas=hyperparams["world_size"], - rank=rank - ) - train_sampler = DistributedSampler( - train_dataset, - num_replicas=hyperparams["world_size"], - rank=rank - ) - - - # Wrap in a dataloader that samples datapoints and constructs batches - eval_dataloader = DataLoader(eval_dataset, - collate_fn=eval_dataset.get_collate_fn(pad_format="right"), - pin_memory=False if hyperparams['device'] == 'cpu' else True, - batch_size=hyperparams['eval_batch_size'], - shuffle=False, - num_workers=hyperparams['preprocess_workers'], - sampler=eval_sampler) - train_dataloader = DataLoader(train_dataset, - collate_fn=train_dataset.get_collate_fn(pad_format="right"), - pin_memory=False if hyperparams['device'] == 'cpu' else True, - batch_size=hyperparams['batch_size'], - shuffle=False, - num_workers=hyperparams['preprocess_workers'], - sampler=train_sampler) - - input_wrapper = lambda batch: {"batch": batch} - - return train_dataloader, train_sampler, train_dataset, eval_dataloader, eval_sampler, eval_dataset, input_wrapper - diff --git a/diffstack/data/trajdata_lanes.py b/diffstack/data/trajdata_lanes.py index 75cf9eb..01bbdbb 100644 --- a/diffstack/data/trajdata_lanes.py +++ b/diffstack/data/trajdata_lanes.py @@ -2,34 +2,57 @@ import numpy as np from typing import Dict, Iterable +from time import time +from trajdata.caching.scene_cache import SceneCache from trajdata.utils.arr_utils import angle_wrap +from trajdata.utils.map_utils import get_polyline_headings from trajdata.data_structures.collation import CustomCollateData -from trajdata.maps.vec_map import VectorMap -from trajdata.utils.arr_utils import batch_nd_transform_points_angles_np, angle_wrap, batch_proj +from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement +from trajdata.maps.vec_map import VectorMap, RoadLane, Polyline + +# from trajdata.maps.map_kdtree import LaneCenterKDTree +from trajdata.utils.arr_utils import ( + batch_nd_transform_points_np, + batch_nd_transform_angles_np, + batch_nd_transform_points_angles_np, + angle_wrap, +) +from diffstack.utils.geometry_utils import batch_proj from diffstack.utils.utils import convert_state_pred2plan, lane_frenet_features_simple +# # For nuscenes-based implementation +# from nuscenes.map_expansion.map_api import NuScenesMap, locations as nusc_map_names +# from diffstack.data.manual_preprocess_plan_data import 
get_lane_reference_points + + class LanesList(list, CustomCollateData): @staticmethod def __collate__(elements: list) -> any: - return LanesList([[torch.as_tensor(pts) for pts in lanelist] for lanelist in elements]) + return LanesList( + [[torch.as_tensor(pts) for pts in lanelist] for lanelist in elements] + ) def __to__(self, device, non_blocking=False): # Always keep on cpu del device del non_blocking - return self - + return self -def get_lane_reference_points_from_polylines(global_xyh: np.ndarray, lane_polylines_xyh: Iterable[np.ndarray], max_num_points: int = 16, resolution_meters: float = 1.0): +def get_lane_reference_points_from_polylines( + global_xyh: np.ndarray, + lane_polylines_xyh: Iterable[np.ndarray], + max_num_points: int = 16, + resolution_meters: float = 1.0, +): # Previosly we did an interpolation of 1m. This was important for cost functions based on N ref points, # but its irrelevant for cost functions using projection of the expert state onto the lane segment. # The projection will remain the same regardless of interpolated points. - # The only difference would be in headings, which are not interpolated this way. One solution is to do + # The only difference would be in headings, which are not interpolated this way. One solution is to do # heading interpolation after projection. - + if not lane_polylines_xyh: return np.zeros((global_xyh.shape[0], 0, 3), dtype=np.float32) @@ -41,9 +64,10 @@ def get_lane_reference_points_from_polylines(global_xyh: np.ndarray, lane_polyli heading_diff_mat = pts_xyh[None, :, 2] - global_xyh[:, None, 2] heading_diff_mat = np.minimum( np.abs(angle_wrap(heading_diff_mat)), # same direction - np.abs(angle_wrap(heading_diff_mat + np.pi))) # opposite direction + np.abs(angle_wrap(heading_diff_mat + np.pi)), + ) # opposite direction # Add a large constant if heading differs by over 45 degrees - d2mat += (heading_diff_mat > np.pi/4).astype(d2mat.dtype) * 1e6 + d2mat += (heading_diff_mat > np.pi / 4).astype(d2mat.dtype) * 1e6 closest_ind = np.argsort(d2mat, axis=-1)[:, :max_num_points] pts_xyh = pts_xyh[closest_ind] # (traj, max_num_points, 4) @@ -51,74 +75,91 @@ def get_lane_reference_points_from_polylines(global_xyh: np.ndarray, lane_polyli def get_lane_projection_points( - vec_map: VectorMap, - ego_histories: np.ndarray, - ego_futures: np.ndarray, - agent_from_world_tf: np.ndarray, - ) -> np.ndarray: + vec_map: VectorMap, + ego_histories: np.ndarray, + ego_futures: np.ndarray, + agent_from_world_tf: np.ndarray, +) -> np.ndarray: """Points along a lane that are closest to GT future ego states.""" - + + if vec_map is None: + return None # Get present and future for the ego agent. 
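     # (The history slice keeps only the current state, so the projection covers the
     # current step plus the future horizon; states are converted to x, y, heading, v
     # and mapped into the world frame before querying the vector map for nearby lane
     # centerlines.)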
ego_pres_future = np.concatenate([ego_histories[-1:], ego_futures], axis=0) ego_xyhv = convert_state_pred2plan(ego_pres_future) world_from_agent_tf = np.linalg.inv(agent_from_world_tf) - ego_xyh_world = batch_nd_transform_points_angles_np(ego_xyhv[:, :3], world_from_agent_tf) + ego_xyh_world = batch_nd_transform_points_angles_np( + ego_xyhv[:, :3], world_from_agent_tf + ) # trajlen = ego_xyh_world.shape[0] # if trajlen == 0: # lane_points_xyh_world = [] # else: - ego_xyz_world = np.concatenate((ego_xyh_world[:, :2], np.zeros_like(ego_xyh_world[:, :1])), axis=-1) + ego_xyz_world = np.concatenate( + (ego_xyh_world[:, :2], np.zeros_like(ego_xyh_world[:, :1])), axis=-1 + ) closest_lanes = vec_map.get_closest_unique_lanes(ego_xyz_world) lane_polylines_xyh = [lane.center.points[:, (0, 1, 3)] for lane in closest_lanes] - lane_points_xyh_world = get_lane_reference_points_from_polylines(ego_xyh_world, lane_polylines_xyh, max_num_points=1) + lane_points_xyh_world = get_lane_reference_points_from_polylines( + ego_xyh_world, lane_polylines_xyh, max_num_points=1 + ) if lane_points_xyh_world.shape[1] == 0: # No lanes. return None else: # World to agent coordinates - lane_points_xyh = batch_nd_transform_points_angles_np(lane_points_xyh_world[..., :3], agent_from_world_tf[None]) + lane_points_xyh = batch_nd_transform_points_angles_np( + lane_points_xyh_world[..., :3], agent_from_world_tf[None] + ) - # Project state to lane - # TODO this is based on a single lane point and its heading. Instead we should use lane segment. - # Get the two neighboring segments, project to both and choose shorter. - # TODO in kdtree we whould interpolate polylines to a max_distance of e.g. 1m - lane_projection_points = lane_frenet_features_simple(ego_xyhv[..., :3], lane_points_xyh[:, :]) + lane_projection_points = lane_frenet_features_simple( + ego_xyhv[..., :3], lane_points_xyh[:, :] + ) assert lane_projection_points.shape == (ego_futures.shape[0] + 1, 3) - + return lane_projection_points.astype(np.float32) def get_goal_lanes( - vec_map: VectorMap, - goal_xyvvaahh: np.ndarray, - agent_from_world_tf: np.ndarray, - goal_to_lane_range: float = 20., - max_lateral_dist: float = 4.5, - max_heading_delta: float = np.pi/4, - ): + vec_map: VectorMap, + goal_xyvvaahh: np.ndarray, + agent_from_world_tf: np.ndarray, + goal_to_lane_range: float = 20.0, + max_lateral_dist: float = 4.5, + max_heading_delta: float = np.pi / 4, +): assert goal_xyvvaahh.shape[-1] == 8 # xyvvaahh goal_xyhv = convert_state_pred2plan(goal_xyvvaahh) world_from_agent_tf = np.linalg.inv(agent_from_world_tf) - goal_xyh_world = batch_nd_transform_points_angles_np(goal_xyhv[:3], world_from_agent_tf) + goal_xyh_world = batch_nd_transform_points_angles_np( + goal_xyhv[:3], world_from_agent_tf + ) # Find lanes in range - goal_xyz_world = np.concatenate((goal_xyh_world[:2], np.zeros_like(goal_xyh_world[:1])), axis=-1) + goal_xyz_world = np.concatenate( + (goal_xyh_world[:2], np.zeros_like(goal_xyh_world[:1])), axis=-1 + ) near_lanes = vec_map.get_lanes_within(goal_xyz_world, dist=goal_to_lane_range) near_lanes_xyh_world = [lane.center.points[:, (0, 1, 3)] for lane in near_lanes] - # Filter + # Filter lanes_xyh_world = [] for lane_xyh in near_lanes_xyh_world: delta_x, delta_y, dpsi = batch_proj(goal_xyh_world, lane_xyh) - if abs(dpsi[0]) < max_heading_delta and np.min(np.abs(delta_y)) < max_lateral_dist: + if ( + abs(dpsi[0]) < max_heading_delta + and np.min(np.abs(delta_y)) < max_lateral_dist + ): lanes_xyh_world.append(lane_xyh) # World to agent coordinates - goal_lanes = 
[batch_nd_transform_points_angles_np(
-        lane_xyh, agent_from_world_tf[None]) for lane_xyh in lanes_xyh_world]
+    goal_lanes = [
+        batch_nd_transform_points_angles_np(lane_xyh, agent_from_world_tf[None])
+        for lane_xyh in lanes_xyh_world
+    ]
-    return LanesList(goal_lanes)
\ No newline at end of file
+    return LanesList(goal_lanes)
diff --git a/diffstack/dynamics/__init__.py b/diffstack/dynamics/__init__.py
new file mode 100644
index 0000000..b0288a3
--- /dev/null
+++ b/diffstack/dynamics/__init__.py
@@ -0,0 +1,20 @@
+from typing import Union
+
+from diffstack.dynamics.single_integrator import SingleIntegrator
+from diffstack.dynamics.unicycle import Unicycle
+from diffstack.dynamics.bicycle import Bicycle
+from diffstack.dynamics.double_integrator import DoubleIntegrator
+from diffstack.dynamics.base import Dynamics, DynType
+
+
+def get_dynamics_model(dyn_type: Union[str, DynType]):
+    if dyn_type in ["Unicycle", DynType.UNICYCLE]:
+        return Unicycle
+    elif dyn_type in ["SingleIntegrator", DynType.SI]:
+        return SingleIntegrator
+    elif dyn_type in ["DoubleIntegrator", DynType.DI]:
+        return DoubleIntegrator
+    else:
+        raise NotImplementedError(
+            "Dynamics model {} is not implemented".format(dyn_type)
+        )
diff --git a/diffstack/dynamics/base.py b/diffstack/dynamics/base.py
new file mode 100644
index 0000000..b43e32e
--- /dev/null
+++ b/diffstack/dynamics/base.py
@@ -0,0 +1,87 @@
+import torch
+import numpy as np
+import math, copy, time
+import abc
+from copy import deepcopy
+
+
+class DynType:
+    """
+    Holds dynamics model types - one per dynamics class.
+    These act as identifiers for the different dynamics models.
+    """
+
+    UNICYCLE = 1
+    SI = 2
+    DI = 3
+    BICYCLE = 4
+    DDI = 5
+
+
+class Dynamics(abc.ABC):
+    @abc.abstractmethod
+    def __init__(self, dt, name, **kwargs):
+        self.dt = dt
+        self._name = name
+        self.xdim = 4
+        self.udim = 2
+
+    @abc.abstractmethod
+    def __call__(self, x, u):
+        return
+
+    @abc.abstractmethod
+    def step(self, x, u, bound=True):
+        return
+
+    def name(self):
+        return self._name
+
+    @abc.abstractmethod
+    def type(self):
+        return
+
+    @abc.abstractmethod
+    def ubound(self, x):
+        return
+
+    @staticmethod
+    def state2pos(x):
+        return
+
+    @staticmethod
+    def state2yaw(x):
+        return
+
+    @staticmethod
+    def get_state(pos, yaw, dt, mask):
+        return
+
+    def get_input_violation(self, x, u):
+        lb, ub = self.ubound(x)
+        return torch.maximum((lb - u).clip(min=0.0), (u - ub).clip(min=0.0))
+
+
+    def forward_dynamics(self, x0: torch.Tensor, u: torch.Tensor, include_step0: bool = False):
+        """
+        Integrate the state forward with initial state x0, action u
+        Args:
+            x0 (torch.Tensor): state tensor of size [B, (A), 4]
+            u (torch.Tensor): action tensor of size [B, (A), T, 2]
+            include_step0 (bool): the output trajectory will include the current state if true.
+ Returns: + state tensor of size [B, (A), T, 4] + """ + num_steps = u.shape[-2] + x = [x0] + [None] * num_steps + for t in range(num_steps): + x[t + 1] = self.step(x[t], u[..., t, :]) + + if include_step0: + x = torch.stack(x, dim=-2) + else: + x = torch.stack(x[1:], dim=-2) + return x + + + diff --git a/diffstack/dynamics/bicycle.py b/diffstack/dynamics/bicycle.py new file mode 100644 index 0000000..3590f37 --- /dev/null +++ b/diffstack/dynamics/bicycle.py @@ -0,0 +1,153 @@ +import torch +import math + +from diffstack.dynamics.base import Dynamics, DynType + + +def bicycle_model(state, acc, ddh, vehicle_length, dt, max_hdot=math.pi * 2.0, max_s=50.0): + """ + Simple differentiable bicycle model that does not allow reverse + Args: + state (torch.Tensor): a batch of current kinematic state [B, ..., 5] (x, y, yaw, speed, hdot) + acc (torch.Tensor): a batch of acceleration profile [B, ...] (acc) + ddh (torch.Tensor): a batch of heading acceleration profile [B, ...] (heading) + vehicle_length (torch.Tensor): a batch of vehicle length [B, ...] (length) + dt (float): time between steps + max_hdot (float): maximum change of heading (rad/s) + max_s (float): maximum speed (m/s) + + Returns: + New kinematic state (torch.Tensor) + """ + # state: (x, y, h, speed, hdot) + assert state.shape[-1] == 5 + newhdot = (state[..., 4] + ddh * dt).clamp(-max_hdot, max_hdot) + newh = state[..., 2] + dt * state[..., 3].abs() / vehicle_length * newhdot + news = (state[..., 3] + acc * dt).clamp(0.0, max_s) # no reverse + newy = state[..., 1] + news * newh.sin() * dt + newx = state[..., 0] + news * newh.cos() * dt + + newstate = torch.empty_like(state) + newstate[..., 0] = newx + newstate[..., 1] = newy + newstate[..., 2] = newh + newstate[..., 3] = news + newstate[..., 4] = newhdot + + return newstate + + +class Bicycle(Dynamics): + + def __init__( + self, + dt, + acc_bound=(-10, 8), + ddh_bound=(-math.pi * 2.0, math.pi * 2.0), + max_speed=50.0, + max_hdot=math.pi * 2.0 + ): + """ + A simple bicycle dynamics model + Args: + acc_bound (tuple): acceleration bound (m/s^2) + ddh_bound (tuple): angular acceleration bound (rad/s^2) + max_speed (float): maximum speed, must be positive + max_hdot (float): maximum turning speed, must be positive + """ + super(Bicycle, self).__init__(name="bicycle") + self.xdim = 6 + self.udim = 2 + self.dt = dt + assert max_speed >= 0 + assert max_hdot >= 0 + self.acc_bound = acc_bound + self.ddh_bound = ddh_bound + self.max_speed = max_speed + self.max_hdot = max_hdot + + def get_normalized_controls(self, u): + u = torch.sigmoid(u) # normalize to [0, 1] + acc = self.acc_bound[0] + (self.acc_bound[1] - self.acc_bound[0]) * u[..., 0] + ddh = self.ddh_bound[0] + (self.ddh_bound[1] - self.ddh_bound[0]) * u[..., 1] + return acc, ddh + + def get_clipped_controls(self, u): + acc = torch.clip(u[..., 0], self.acc_bound[0], self.acc_bound[1]) + ddh = torch.clip(u[..., 1], self.ddh_bound[0], self.ddh_bound[1]) + return acc, ddh + + def step(self, x, u, normalize=True): + """ + Take a step with the dynamics model + Args: + x (torch.Tensor): current state [B, ..., 6] (x, y, h, speed, dh, veh_length) + u (torch.Tensor): (un-normalized) actions [B, ..., 2] (acc, ddh) + dt (float): time between steps + normalize (bool): whether to normalize the actions + + Returns: + next_x (torch.Tensor): next state after taking the action + """ + assert x.shape[-1] == self.xdim + assert u.shape[:-1] == x.shape[:-1] + assert u.shape[-1] == self.udim + if normalize: + acc, ddh = self.get_normalized_controls(u) + else: + 
acc, ddh = self.get_clipped_controls(u) + next_x = x.clone() # keep the extent the same + next_x[..., :5] = bicycle_model( + state=x[..., :5], + acc=acc, + ddh=ddh, + vehicle_length=x[..., 5], + dt=self.dt, + max_hdot=self.max_hdot, + max_s=self.max_speed + ) + return next_x + + def calculate_vel(self,pos, yaw, mask): + + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / self.dt * torch.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / self.dt * torch.sin( + yaw[..., 1:, :] + ) + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + + return vel + + def type(self): + return DynType.BICYCLE + + def state2pos(self, x): + return x[..., :2] + + def state2yaw(self, x): + return x[..., 2:3] + + def __call__(self, x, u): + pass + + def ubound(self, x): + pass + + def name(self): + return self._name diff --git a/diffstack/dynamics/double_integrator.py b/diffstack/dynamics/double_integrator.py new file mode 100644 index 0000000..3e13630 --- /dev/null +++ b/diffstack/dynamics/double_integrator.py @@ -0,0 +1,202 @@ +from diffstack.dynamics.base import DynType, Dynamics +from diffstack.utils.math_utils import soft_sat +import torch +import numpy as np + + +class DoubleIntegrator(Dynamics): + def __init__(self, dt, name="DI_model", abound=None, vbound=None): + self._name = name + self._type = DynType.DI + self.xdim = 4 + self.udim = 2 + self.dt = dt + self.cyclic_state = list() + self.vbound = vbound + self.abound = abound + + def __call__(self, x, u): + assert x.shape[:-1] == u.shape[:, -1] + if isinstance(x, np.ndarray): + return np.hstack((x[..., 2:], u)) + elif isinstance(x, torch.Tensor): + return torch.cat((x[..., 2:], u), dim=-1) + else: + raise NotImplementedError + + def step(self, x, u, bound=True, return_jacobian=False): + if self.abound is None: + bound = False + + if isinstance(x, np.ndarray): + if bound: + lb, ub = self.ubound(x) + u = np.clip(u, lb, ub) + xn = np.hstack( + ( + (x[..., 2:4] + 0.5 * u * self.dt) * self.dt + x[..., 0:2], + x[..., 2:4] + u * self.dt, + ) + ) + if return_jacobian: + raise NotImplementedError + else: + return xn + elif isinstance(x, torch.Tensor): + if bound: + lb, ub = self.ubound(x) + u = torch.clip(u, min=lb, max=ub) + xn = torch.clone(x) + xn[..., 0:2] += (x[..., 2:4] + 0.5 * u * self.dt) * self.dt + xn[..., 2:4] += u * self.dt + if return_jacobian: + jacx = torch.cat( + [ + torch.zeros([4, 2]), + torch.cat([torch.eye(2) * self.dt, torch.zeros([2, 2])], 0), + ], + -1, + ) + torch.eye(4) + jacx = torch.tile(jacx, (*x.shape[:-1], 1, 1)).to(x.device) + + jacu = torch.cat( + [torch.eye(2) * self.dt**2 * 2, torch.eye(2) * self.dt], 0 + ) + jacu = torch.tile(jacu, (*x.shape[:-1], 1, 1)).to(x.device) + return xn, jacx, jacu + else: + return xn + else: + raise NotImplementedError + + def get_x_Gaussian_from_u(self, x, mu_u, var_u): + mu_x, _, jacu = self.step(x, mu_u, bound=False, return_jacobian=True) + + var_u_mat = torch.diag_embed(var_u) + var_x = torch.matmul(torch.matmul(jacu, var_u_mat), jacu.transpose(-1, -2)) + return mu_x, var_x + + def name(self): + return self._name + 
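+    # Usage sketch (illustrative; the bound values below are made-up examples).
+    # The state layout implied by step()/state2pos() is [x, y, vx, vy] with
+    # control [ax, ay]:
+    #     dyn = DoubleIntegrator(dt=0.1, abound=np.array([-4.0, 4.0]),
+    #                            vbound=np.array([-30.0, 30.0]))
+    #     x0 = torch.zeros(8, 4)                  # batch of 8 initial states
+    #     u = 0.5 * torch.ones(8, 12, 2)          # 12 steps of constant accel
+    #     traj = dyn.forward_dynamics(x0, u)      # -> [8, 12, 4]
+    # forward_dynamics soft-clips u to abound and the integrated velocity to
+    # vbound before integrating positions.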
+ def type(self): + return self._type + + def ubound(self, x): + if self.vbound is None: + if isinstance(x, np.ndarray): + lb = np.ones_like(x[..., 2:]) * self.abound[0] + ub = np.ones_like(x[..., 2:]) * self.abound[1] + + elif isinstance(x, torch.Tensor): + lb = torch.ones_like(x[..., 2:]) * torch.from_numpy( + self.abound[:, 0] + ).to(x.device) + ub = torch.ones_like(x[..., 2:]) * torch.from_numpy( + self.abound[:, 1] + ).to(x.device) + + else: + raise NotImplementedError + else: + if isinstance(x, np.ndarray): + lb = (x[..., 2:] > self.vbound[0]) * self.abound[0] + ub = (x[..., 2:] < self.vbound[1]) * self.abound[1] + + elif isinstance(x, torch.Tensor): + lb = ( + x[..., 2:] > torch.from_numpy(self.vbound[0]).to(x.device) + ) * torch.from_numpy(self.abound[0]).to(x.device) + ub = ( + x[..., 2:] < torch.from_numpy(self.vbound[1]).to(x.device) + ) * torch.from_numpy(self.abound[1]).to(x.device) + else: + raise NotImplementedError + return lb, ub + + @staticmethod + def state2pos(x): + return x[..., 0:2] + + @staticmethod + def state2yaw(x): + # return torch.atan2(x[..., 3:], x[..., 2:3]) + return torch.zeros_like(x[..., 0:1]) + + def inverse_dyn(self, x, xp): + return (xp[..., 2:] - x[..., 2:]) / self.dt + + def calculate_vel(self, pos, yaw, mask): + vel = (pos[..., 1:, :] - pos[..., :-1, :]) / self.dt + if isinstance(pos, torch.Tensor): + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + elif isinstance(pos, np.ndarray): + # right finite difference velocity + vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) + # left finite difference velocity + vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) + mask_r = np.roll(mask, 1, axis=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + mask_l = np.roll(mask, -1, axis=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + np.expand_dims(mask_l & mask_r, -1) * (vel_r + vel_l) / 2 + + np.expand_dims(mask_l & (~mask_r), -1) * vel_l + + np.expand_dims(mask_r & (~mask_l), -1) * vel_r + ) + else: + raise NotImplementedError + return vel + + def get_state(self, pos, yaw, dt, mask): + vel = self.calculate_vel(pos, yaw, mask) + if isinstance(vel, np.ndarray): + return np.concatenate((pos, vel), -1) + elif isinstance(vel, torch.Tensor): + return torch.cat((pos, vel), -1) + + def forward_dynamics( + self, + x0: torch.Tensor, + u: torch.Tensor, + include_step0: bool = False, + ): + if include_step0: + raise NotImplementedError + + if isinstance(u, np.ndarray): + u = np.clip(u, self.abound[0], self.abound[1]) + delta_v = np.cumsum(u * self.dt, -2) + vel = x0[..., np.newaxis, 2:] + delta_v + vel = np.clip(vel, self.vbound[0], self.vbound[1]) + delta_xy = np.cumsum(vel * self.dt, -2) + xy = x0[..., np.newaxis, :2] + delta_xy + + traj = np.concatenate((xy, vel), -1) + elif isinstance(u, torch.Tensor): + u = soft_sat(u, self.abound[0], self.abound[1]) + delta_v = torch.cumsum(u * self.dt, -2) + vel = x0[..., 2:].unsqueeze(-2) + delta_v + vel = soft_sat(vel, self.vbound[0], self.vbound[1]) + delta_xy = torch.cumsum(vel * self.dt, -2) + xy = x0[..., :2].unsqueeze(-2) 
+ delta_xy + + traj = torch.cat((xy, vel), -1) + return traj diff --git a/diffstack/dynamics/single_integrator.py b/diffstack/dynamics/single_integrator.py new file mode 100644 index 0000000..ad2b2ea --- /dev/null +++ b/diffstack/dynamics/single_integrator.py @@ -0,0 +1,52 @@ +from diffstack.dynamics.base import DynType, Dynamics +import torch +import numpy as np + + +class SingleIntegrator(Dynamics): + def __init__(self, dt, name, vbound): + self._name = name + self.dt = dt + self._type = DynType.SI + self.xdim = vbound.shape[0] + self.udim = vbound.shape[0] + self.cyclic_state = list() + self.vbound = np.array(vbound) + + def __call__(self, x, u): + assert x.shape[:-1] == u.shape[:, -1] + + return u + + def step(self, x, u, bound=True): + assert x.shape[:-1] == u.shape[:, -1] + if bound: + lb, ub = self.ubound(x) + if isinstance(x, np.ndarray): + u = np.clip(u, lb, ub) + elif isinstance(x, torch.Tensor): + u = torch.clip(u, min=lb, max=ub) + + return x + u * self.dt + + def name(self): + return self._name + + def type(self): + return self._type + + def ubound(self, x): + if isinstance(x, np.ndarray): + lb = np.ones_like(x) * self.vbound[:, 0] + ub = np.ones_like(x) * self.vbound[:, 1] + return lb, ub + elif isinstance(x, torch.Tensor): + lb = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 0]) + ub = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 1]) + return lb, ub + else: + raise NotImplementedError + + @staticmethod + def state2pos(x): + return x[..., 0:2] diff --git a/diffstack/dynamics/unicycle.py b/diffstack/dynamics/unicycle.py new file mode 100644 index 0000000..9bc5e36 --- /dev/null +++ b/diffstack/dynamics/unicycle.py @@ -0,0 +1,834 @@ +from diffstack.dynamics.base import DynType, Dynamics +from diffstack.utils.math_utils import soft_sat +import torch +import torch.nn as nn +import numpy as np +from copy import deepcopy +from torch.autograd.functional import jacobian +import diffstack.utils.geometry_utils as GeoUtils + + +class Unicycle(Dynamics): + def __init__( + self, + dt, + name=None, + max_steer=0.5, + max_yawvel=8, + acce_bound=[-6, 4], + vbound=[-10, 30], + ): + self.dt = dt + self._name = name + self._type = DynType.UNICYCLE + self.xdim = 4 + self.udim = 2 + self.cyclic_state = [3] + self.acce_bound = acce_bound + self.vbound = vbound + self.max_steer = max_steer + self.max_yawvel = max_yawvel + + def __call__(self, x, u): + assert x.shape[:-1] == u.shape[:, -1] + if isinstance(x, np.ndarray): + assert isinstance(u, np.ndarray) + theta = x[..., 3:4] + dxdt = np.hstack( + (np.cos(theta) * x[..., 2:3], np.sin(theta) * x[..., 2:3], u) + ) + elif isinstance(x, torch.Tensor): + assert isinstance(u, torch.Tensor) + theta = x[..., 3:4] + dxdt = torch.cat( + (torch.cos(theta) * x[..., 2:3], torch.sin(theta) * x[..., 2:3], u), + dim=-1, + ) + else: + raise NotImplementedError + return dxdt + + def step(self, x, u, bound=True, return_jacobian=False): + assert x.shape[:-1] == u.shape[:-1] + if isinstance(x, np.ndarray): + assert isinstance(u, np.ndarray) + if bound: + lb, ub = self.ubound(x) + u = np.clip(u, lb, ub) + + theta = x[..., 3:4] + cos_theta_p = np.cos(theta) - 0.5 * np.sin(theta) * u[..., 1:] * self.dt + sin_theta_p = np.sin(theta) + 0.5 * np.cos(theta) * u[..., 1:] * self.dt + vel_p = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + + dxdt = np.hstack( + ( + cos_theta_p * vel_p, + sin_theta_p * vel_p, + u, + ) + ) + xp = x + dxdt * self.dt + if return_jacobian: + d_cos_theta_p_d_theta = ( + -np.sin(theta) - 0.5 * np.cos(theta) * u[..., 1:] * self.dt + ) + 
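+                # d_cos_theta_p_d_theta above and the partials below differentiate
+                # the half-step (midpoint) terms cos_theta_p, sin_theta_p and vel_p
+                # of the Euler update; they are assembled into the analytic Jacobians
+                # jacx (w.r.t. the state) and jacu (w.r.t. the input) returned with xp.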
d_sin_theta_p_d_theta = ( + np.cos(theta) - 0.5 * np.sin(theta) * u[..., 1:] * self.dt + ) + d_vel_p_d_a = 0.5 * self.dt + d_cos_theta_p_d_yaw = -0.5 * np.sin(theta) * self.dt + d_sin_theta_p_d_yaw = 0.5 * np.cos(theta) * self.dt + jacx = np.tile(np.eye(4), (*x.shape[:-1], 1, 1)) + jacx[..., 0, 2:3] = cos_theta_p * self.dt + jacx[..., 0, 3:4] = vel_p * self.dt * d_cos_theta_p_d_theta + jacx[..., 1, 2:3] = sin_theta_p * self.dt + jacx[..., 1, 3:4] = vel_p * self.dt * d_sin_theta_p_d_theta + + jacu = np.zeros((*x.shape[:-1], 4, 2)) + jacu[..., 0, 0:1] = cos_theta_p * self.dt * d_vel_p_d_a + jacu[..., 0, 1:2] = vel_p * self.dt * d_cos_theta_p_d_yaw + jacu[..., 1, 0:1] = sin_theta_p * self.dt * d_vel_p_d_a + jacu[..., 1, 1:2] = vel_p * self.dt * d_sin_theta_p_d_yaw + jacu[..., 2, 0:1] = self.dt + jacu[..., 3, 1:2] = self.dt + + return xp, jacx, jacu + else: + return xp + elif isinstance(x, torch.Tensor): + assert isinstance(u, torch.Tensor) + if bound: + lb, ub = self.ubound(x) + # s = (u - lb) / torch.clip(ub - lb, min=1e-3) + # u = lb + (ub - lb) * torch.sigmoid(s) + u = torch.clip(u, lb, ub) + + theta = x[..., 3:4] + cos_theta_p = ( + torch.cos(theta) - 0.5 * torch.sin(theta) * u[..., 1:] * self.dt + ) + sin_theta_p = ( + torch.sin(theta) + 0.5 * torch.cos(theta) * u[..., 1:] * self.dt + ) + vel_p = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + dxdt = torch.cat( + ( + cos_theta_p * vel_p, + sin_theta_p * vel_p, + u, + ), + dim=-1, + ) + xp = x + dxdt * self.dt + if return_jacobian: + d_cos_theta_p_d_theta = ( + -torch.sin(theta) - 0.5 * torch.cos(theta) * u[..., 1:] * self.dt + ) + d_sin_theta_p_d_theta = ( + torch.cos(theta) - 0.5 * torch.sin(theta) * u[..., 1:] * self.dt + ) + d_vel_p_d_a = 0.5 * self.dt + d_cos_theta_p_d_yaw = -0.5 * torch.sin(theta) * self.dt + d_sin_theta_p_d_yaw = 0.5 * torch.cos(theta) * self.dt + eye4 = torch.tile(torch.eye(4, device=x.device), (*x.shape[:-1], 1, 1)) + jacxy = torch.zeros((*x.shape[:-1], 4, 2), device=x.device) + zeros21 = torch.zeros((*x.shape[:-1], 2, 1), device=x.device) + jacv = ( + torch.cat( + [cos_theta_p.unsqueeze(-2), sin_theta_p.unsqueeze(-2), zeros21], + -2, + ) + * self.dt + ) + jactheta = ( + torch.cat( + [ + (vel_p * d_cos_theta_p_d_theta).unsqueeze(-2), + (vel_p * d_sin_theta_p_d_theta).unsqueeze(-2), + zeros21, + ], + -2, + ) + * self.dt + ) + jacx = torch.cat([jacxy, jacv, jactheta], -1) + eye4 + # jacx = torch.tile(torch.eye(4,device=x.device), (*x.shape[:-1], 1, 1)) + # jacx[...,0,2:3] = cos_theta_p*self.dt + # jacx[...,0,3:4] = vel_p*self.dt*d_cos_theta_p_d_theta + # jacx[...,1,2:3] = sin_theta_p*self.dt + # jacx[...,1,3:4] = vel_p*self.dt*d_sin_theta_p_d_theta + + jacxy_a = ( + torch.cat( + [cos_theta_p.unsqueeze(-2), sin_theta_p.unsqueeze(-2)], -2 + ) + * self.dt + * d_vel_p_d_a + ) + jacxy_yaw = ( + torch.cat( + [ + (vel_p * d_cos_theta_p_d_yaw).unsqueeze(-2), + (vel_p * d_sin_theta_p_d_yaw).unsqueeze(-2), + ], + -2, + ) + * self.dt + ) + eye2 = torch.tile(torch.eye(2, device=x.device), (*x.shape[:-1], 1, 1)) + jacu = torch.cat( + [torch.cat([jacxy_a, jacxy_yaw], -1), eye2 * self.dt], -2 + ) + # jacu = torch.zeros((*x.shape[:-1], 4, 2),device=x.device) + # jacu[...,0,0:1] = cos_theta_p*self.dt*d_vel_p_d_a + # jacu[...,0,1:2] = vel_p*self.dt*d_cos_theta_p_d_yaw + # jacu[...,1,0:1] = sin_theta_p*self.dt*d_vel_p_d_a + # jacu[...,1,1:2] = vel_p*self.dt*d_sin_theta_p_d_yaw + # jacu[...,2,0:1] = self.dt + # jacu[...,3,1:2] = self.dt + return xp, jacx, jacu + else: + return xp + else: + raise NotImplementedError + + def 
get_x_Gaussian_from_u(self, x, mu_u, var_u): + mu_x, _, jacu = self.step(x, mu_u, bound=False, return_jacobian=True) + + var_u_mat = torch.diag_embed(var_u) + var_x = torch.matmul(torch.matmul(jacu, var_u_mat), jacu.transpose(-1, -2)) + return mu_x, var_x + + def name(self): + return self._name + + def type(self): + return self._type + + def ubound(self, x): + if isinstance(x, np.ndarray): + v = x[..., 2:3] + vclip = np.clip(np.abs(v), a_min=0.1, a_max=None) + + yawbound = np.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + acce_lb = np.clip( + np.clip(self.vbound[0] - v, None, self.acce_bound[1]), + self.acce_bound[0], + None, + ) + acce_ub = np.clip( + np.clip(self.vbound[1] - v, self.acce_bound[0], None), + None, + self.acce_bound[1], + ) + lb = np.concatenate((acce_lb, -yawbound), -1) + ub = np.concatenate((acce_ub, yawbound), -1) + return lb, ub + elif isinstance(x, torch.Tensor): + v = x[..., 2:3] + vclip = torch.clip(torch.abs(v), min=0.1) + yawbound = torch.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + yawbound = torch.clip(yawbound, min=0.1) + acce_lb = torch.clip( + torch.clip(self.vbound[0] - v, max=self.acce_bound[1]), + min=self.acce_bound[0], + ) + acce_ub = torch.clip( + torch.clip(self.vbound[1] - v, min=self.acce_bound[0]), + max=self.acce_bound[1], + ) + lb = torch.cat((acce_lb, -yawbound), dim=-1) + ub = torch.cat((acce_ub, yawbound), dim=-1) + return lb, ub + + else: + raise NotImplementedError + + def uniform_sample_xp(self, x, num_sample): + if isinstance(x, torch.Tensor): + u_lb, u_ub = self.ubound(x) + u_sample = torch.rand( + *x.shape[:-1], num_sample, self.udim, device=x.device + ) * (u_ub - u_lb).unsqueeze(-2) + u_lb.unsqueeze(-2) + xp = self.step( + x.unsqueeze(-2).repeat_interleave(num_sample, -2), u_sample, bound=False + ) + elif isinstance(x, np.ndarray): + u_lb, u_ub = self.ubound(x) + u_sample = np.random.uniform( + u_lb[..., None, :], + u_ub[..., None, :], + (*x.shape[:-1], num_sample, self.udim), + ) + xp = self.step( + x[..., None, :].repeat(num_sample, -2), u_sample, bound=False + ) + else: + raise NotImplementedError + return xp + + @staticmethod + def state2pos(x): + return x[..., 0:2] + + @staticmethod + def state2yaw(x): + return x[..., 3:] + + @staticmethod + def state2vel(x): + return x[..., 2:3] + + @staticmethod + def state2xyvsc(x): + return torch.cat([x[..., :3], torch.sin(x[..., 3:]), torch.cos(x[..., 3:])], -1) + + @staticmethod + def combine_to_state(xy, vel, yaw): + if isinstance(xy, torch.Tensor): + return torch.cat((xy, vel, yaw), -1) + elif isinstance(xy, np.ndarray): + return np.concatenate((xy, vel, yaw), -1) + + def calculate_vel(self, pos, yaw, mask, dt=None): + if dt is None: + dt = self.dt + if isinstance(pos, torch.Tensor): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * torch.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * torch.sin( + yaw[..., 1:, :] + ) + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + elif isinstance(pos, np.ndarray): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 
0:1]) / dt * np.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * np.sin(yaw[..., 1:, :]) + # right finite difference velocity + vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) + # left finite difference velocity + vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) + mask_r = np.roll(mask, 1, axis=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + mask_l = np.roll(mask, -1, axis=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + np.expand_dims(mask_l & mask_r, -1) * (vel_r + vel_l) / 2 + + np.expand_dims(mask_l & (~mask_r), -1) * vel_l + + np.expand_dims(mask_r & (~mask_l), -1) * vel_r + ) + else: + raise NotImplementedError + return vel + + def inverse_dyn(self, x, xp, dt=None, mask=None): + if dt is None: + dt = self.dt + dx = torch.cat( + [xp[..., 2:3] - x[..., 2:3], GeoUtils.round_2pi(xp[..., 3:] - x[..., 3:])], + -1, + ) + u = dx / dt + if mask is not None: + u = u * mask[..., None] + return u + + def get_state(self, pos, yaw, mask, dt=None): + if dt is None: + dt = self.dt + vel = self.calculate_vel(pos, yaw, mask, dt) + if isinstance(vel, np.ndarray): + return np.concatenate((pos, vel, yaw), -1) + elif isinstance(vel, torch.Tensor): + return torch.cat((pos, vel, yaw), -1) + + @staticmethod + def get_axay(x, u): + yaw = Unicycle.state2yaw(x) + vel = Unicycle.state2vel(x) + acce = u[..., 0:1] + r = u[..., 1:] + stack_fun = torch.stack if isinstance(x, torch.Tensor) else np.stack + sin = torch.sin if isinstance(x, torch.Tensor) else np.sin + cos = torch.cos if isinstance(x, torch.Tensor) else np.cos + return stack_fun( + [ + acce * cos(yaw) - vel * r * sin(yaw), + acce * sin(yaw) + vel * r * cos(yaw), + ], + -1, + ) + + def forward_dynamics( + self, + x0: torch.Tensor, + u: torch.Tensor, + mode="parallel", + bound=True, + ): + """ + Integrate the state forward with initial state x0, action u + Args: + initial_states (Torch.tensor): state tensor of size [B, (A), 4] + actions (Torch.tensor): action tensor of size [B, (A), T, 2] + Returns: + state tensor of size [B, (A), T, 4] + """ + if mode == "chain": + num_steps = u.shape[-2] + x = [x0] + [None] * num_steps + for t in range(num_steps): + x[t + 1] = self.step(x[t], u[..., t, :], bound=bound) + + return torch.stack(x[1:], dim=-2) + + else: + assert mode in ["parallel", "partial_parallel"] + with torch.no_grad(): + num_steps = u.shape[-2] + b = x0.shape[0] + device = x0.device + + mat = torch.ones(num_steps + 1, num_steps + 1, device=device) + mat = torch.tril(mat) + mat = mat.repeat(b, 1, 1) + + mat2 = torch.ones(num_steps, num_steps + 1, device=device) + mat2_h = torch.tril(mat2, diagonal=1) + mat2_l = torch.tril(mat2, diagonal=-1) + mat2 = torch.logical_xor(mat2_h, mat2_l).float() * 0.5 + mat2 = mat2.repeat(b, 1, 1) + if x0.ndim == 3: + mat = mat.unsqueeze(1) + mat2 = mat2.unsqueeze(1) + + acc = u[..., :1] + yaw = u[..., 1:] + if bound: + acc_clipped = soft_sat(acc, self.acce_bound[0], self.acce_bound[1]) + else: + acc_clipped = acc + + if mode == "parallel": + acc_paded = torch.cat( + (x0[..., -2:-1].unsqueeze(-2), acc_clipped * self.dt), dim=-2 + ) + + v_raw = torch.matmul(mat, acc_paded) + v_clipped = soft_sat(v_raw, self.vbound[0], self.vbound[1]) + else: + v_clipped = [x0[..., 2:3]] + [None] * num_steps + for t in range(num_steps): + vt = v_clipped[t] + acc_clipped_t = soft_sat( + acc_clipped[:, t], self.vbound[0] - vt, self.vbound[1] - vt + ) + v_clipped[t + 1] = vt + acc_clipped_t * self.dt + v_clipped = torch.stack(v_clipped, dim=-2) + + v_avg = torch.matmul(mat2, 
v_clipped) + + v = v_clipped[..., 1:, :] + if bound: + with torch.no_grad(): + v_earlier = v_clipped[..., :-1, :] + + yawbound = torch.minimum( + self.max_steer * torch.abs(v_earlier), + self.max_yawvel / torch.clip(torch.abs(v_earlier), min=0.1), + ) + yawbound_clipped = torch.clip(yawbound, min=0.1) + + yaw_clipped = soft_sat(yaw, -yawbound_clipped, yawbound_clipped) + else: + yaw_clipped = yaw + yawvel_paded = torch.cat( + (x0[..., -1:].unsqueeze(-2), yaw_clipped * self.dt), dim=-2 + ) + yaw_full = torch.matmul(mat, yawvel_paded) + yaw = yaw_full[..., 1:, :] + + # print('before clip', torch.cat((acc[0], yawvel[0]), dim=-1)) + # print('after clip', torch.cat((acc_clipped[0], yawvel_clipped[0]), dim=-1)) + + yaw_earlier = yaw_full[..., :-1, :] + vx = v_avg * torch.cos(yaw_earlier) + vy = v_avg * torch.sin(yaw_earlier) + v_all = torch.cat((vx, vy), dim=-1) + + # print('initial_states[0, -2:]', initial_states[0, -2:]) + # print('vx[0, :5]', vx[0, :5]) + + v_all_paded = torch.cat( + (x0[..., :2].unsqueeze(-2), v_all * self.dt), dim=-2 + ) + x_and_y = torch.matmul(mat, v_all_paded) + x_and_y = x_and_y[..., 1:, :] + + x_all = torch.cat((x_and_y, v, yaw), dim=-1) + return x_all + + # def propagate_and_linearize(self,x0,u,dt=None): + # if dt is None: + # dt = self.dt + # xp,_,_ = self.forward_dynamics(x0,u,dt,mode="chain") + # xl = torch.cat([x0.unsqueeze(1),xp[:,:-1]],1) + # A,B = jacobian(lambda x,u: self.step(x,u,dt),(xl,u)) + # A = A.diagonal(dim1=0,dim2=3).diagonal(dim1=0,dim2=2).permute(2,3,0,1) + # B = B.diagonal(dim1=0,dim2=3).diagonal(dim1=0,dim2=2).permute(2,3,0,1) + # C = xp - (A@xl.unsqueeze(-1)+B@u.unsqueeze(-1)).squeeze(-1) + # return xp,A,B,C + + +class Unicycle_xyvsc(Dynamics): + def __init__( + self, dt,name = None, max_steer=0.5, max_yawvel=8, acce_bound=[-6, 4], vbound=[-10, 30] + ): + self.dt = dt + self._name = name + self._type = DynType.UNICYCLE + self.xdim = 5 + self.udim = 2 + self.acce_bound = acce_bound + self.vbound = vbound + self.max_steer = max_steer + self.max_yawvel = max_yawvel + + def step(self, x, u, bound=True, return_jacobian=False): + assert x.shape[:-1] == u.shape[:-1] + if isinstance(x, np.ndarray): + assert isinstance(u, np.ndarray) + if bound: + lb, ub = self.ubound(x) + u = np.clip(u, lb, ub) + + s = x[..., 3:4] + c = x[..., 4:5] + c_step = c-0.5*s*u[...,1:]*self.dt + s_step = s+0.5*c*u[...,1:]*self.dt + vel_step = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + yaw = u[..., 1:2] + cp = c*np.cos(yaw*self.dt)-s*np.sin(yaw*self.dt) + sp = s*np.cos(yaw*self.dt)+c*np.sin(yaw*self.dt) + dx = np.hstack( + ( + c_step * vel_step*self.dt, + s_step * vel_step*self.dt, + u[...,]*self.dt, + sp-s, + cp-c, + ) + ) + xp = x + dx + if return_jacobian: + raise NotImplementedError + # d_cos_theta_p_d_theta = -np.sin(theta)-0.5*np.cos(theta)*u[...,1:]*self.dt + # d_sin_theta_p_d_theta = np.cos(theta)-0.5*np.sin(theta)*u[...,1:]*self.dt + # d_vel_p_d_a = 0.5*self.dt + # d_cos_theta_p_d_yaw = -0.5*np.sin(theta)*self.dt + # d_sin_theta_p_d_yaw = 0.5*np.cos(theta)*self.dt + # jacx = np.tile(np.eye(4), (*x.shape[:-1], 1, 1)) + # jacx[...,0,2:3] = cos_theta_p*self.dt + # jacx[...,0,3:4] = vel_p*self.dt*d_cos_theta_p_d_theta + # jacx[...,1,2:3] = sin_theta_p*self.dt + # jacx[...,1,3:4] = vel_p*self.dt*d_sin_theta_p_d_theta + + # jacu = np.zeros((*x.shape[:-1], 4, 2)) + # jacu[...,0,0:1] = cos_theta_p*self.dt*d_vel_p_d_a + # jacu[...,0,1:2] = vel_p*self.dt*d_cos_theta_p_d_yaw + # jacu[...,1,0:1] = sin_theta_p*self.dt*d_vel_p_d_a + # jacu[...,1,1:2] = 
vel_p*self.dt*d_sin_theta_p_d_yaw + # jacu[...,2,0:1] = self.dt + # jacu[...,3,1:2] = self.dt + + # return xp, jacx, jacu + else: + return xp + elif isinstance(x, torch.Tensor): + assert isinstance(u, torch.Tensor) + if bound: + lb, ub = self.ubound(x) + # s = (u - lb) / torch.clip(ub - lb, min=1e-3) + # u = lb + (ub - lb) * torch.sigmoid(s) + u = torch.clip(u, lb, ub) + + s = x[..., 3:4] + c = x[..., 4:5] + c_step = c-0.5*s*u[...,1:]*self.dt + s_step = s+0.5*c*u[...,1:]*self.dt + vel_step = x[..., 2:3] + u[..., 0:1] * self.dt * 0.5 + yaw = u[..., 1:2] + cp = c*torch.cos(yaw*self.dt)-s*torch.sin(yaw*self.dt) + sp = s*torch.cos(yaw*self.dt)+c*torch.sin(yaw*self.dt) + dx = torch.cat( + ( + c_step * vel_step*self.dt, + s_step * vel_step*self.dt, + u[...,:1]*self.dt, + sp-s, + cp-c, + ),-1 + ) + xp = x + dx + if return_jacobian: + raise NotImplementedError + # d_cos_theta_p_d_theta = -torch.sin(theta)-0.5*torch.cos(theta)*u[...,1:]*self.dt + # d_sin_theta_p_d_theta = torch.cos(theta)-0.5*torch.sin(theta)*u[...,1:]*self.dt + # d_vel_p_d_a = 0.5*self.dt + # d_cos_theta_p_d_yaw = -0.5*torch.sin(theta)*self.dt + # d_sin_theta_p_d_yaw = 0.5*torch.cos(theta)*self.dt + # eye4 = torch.tile(torch.eye(4,device=x.device), (*x.shape[:-1], 1, 1)) + # jacxy = torch.zeros((*x.shape[:-1], 4, 2),device=x.device) + # zeros21 = torch.zeros((*x.shape[:-1], 2, 1),device=x.device) + # jacv = torch.cat([cos_theta_p.unsqueeze(-2),sin_theta_p.unsqueeze(-2),zeros21],-2)*self.dt + # jactheta = torch.cat([(vel_p*d_cos_theta_p_d_theta).unsqueeze(-2),(vel_p*d_sin_theta_p_d_theta).unsqueeze(-2),zeros21],-2)*self.dt + # jacx = torch.cat([jacxy,jacv,jactheta],-1)+eye4 + # # jacx = torch.tile(torch.eye(4,device=x.device), (*x.shape[:-1], 1, 1)) + # # jacx[...,0,2:3] = cos_theta_p*self.dt + # # jacx[...,0,3:4] = vel_p*self.dt*d_cos_theta_p_d_theta + # # jacx[...,1,2:3] = sin_theta_p*self.dt + # # jacx[...,1,3:4] = vel_p*self.dt*d_sin_theta_p_d_theta + + + + # jacxy_a = torch.cat([cos_theta_p.unsqueeze(-2),sin_theta_p.unsqueeze(-2)],-2)*self.dt*d_vel_p_d_a + # jacxy_yaw = torch.cat([(vel_p*d_cos_theta_p_d_yaw).unsqueeze(-2),(vel_p*d_sin_theta_p_d_yaw).unsqueeze(-2)],-2)*self.dt + # eye2 = torch.tile(torch.eye(2,device=x.device), (*x.shape[:-1], 1, 1)) + # jacu = torch.cat([torch.cat([jacxy_a,jacxy_yaw],-1),eye2*self.dt],-2) + # # jacu = torch.zeros((*x.shape[:-1], 4, 2),device=x.device) + # # jacu[...,0,0:1] = cos_theta_p*self.dt*d_vel_p_d_a + # # jacu[...,0,1:2] = vel_p*self.dt*d_cos_theta_p_d_yaw + # # jacu[...,1,0:1] = sin_theta_p*self.dt*d_vel_p_d_a + # # jacu[...,1,1:2] = vel_p*self.dt*d_sin_theta_p_d_yaw + # # jacu[...,2,0:1] = self.dt + # # jacu[...,3,1:2] = self.dt + return xp, jacx, jacu + else: + return xp + else: + raise NotImplementedError + + # def get_x_Gaussian_from_u(self,x,mu_u,var_u): + # mu_x,_,jacu = self.step(x, mu_u, bound=False, return_jacobian=True) + + # var_u_mat = torch.diag_embed(var_u) + # var_x = torch.matmul(torch.matmul(jacu,var_u_mat),jacu.transpose(-1,-2)) + # return mu_x,var_x + + def __call__(self, x, u): + return self.step(x, u) + + def name(self): + return self._name + + def type(self): + return self._type + + def ubound(self, x): + if isinstance(x, np.ndarray): + v = x[..., 2:3] + vclip = np.clip(np.abs(v), a_min=0.1, a_max=None) + yawbound = np.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + acce_lb = np.clip( + np.clip(self.vbound[0] - v, None, self.acce_bound[1]), + self.acce_bound[0], + None, + ) + acce_ub = np.clip( + np.clip(self.vbound[1] - v, self.acce_bound[0], 
None), + None, + self.acce_bound[1], + ) + lb = np.concatenate((acce_lb, -yawbound),-1) + ub = np.concatenate((acce_ub, yawbound),-1) + return lb, ub + elif isinstance(x, torch.Tensor): + v = x[..., 2:3] + vclip = torch.clip(torch.abs(v),min=0.1) + yawbound = torch.minimum( + self.max_steer * vclip, + self.max_yawvel / vclip, + ) + yawbound = torch.clip(yawbound, min=0.1) + acce_lb = torch.clip( + torch.clip(self.vbound[0] - v, max=self.acce_bound[1]), + min=self.acce_bound[0], + ) + acce_ub = torch.clip( + torch.clip(self.vbound[1] - v, min=self.acce_bound[0]), + max=self.acce_bound[1], + ) + lb = torch.cat((acce_lb, -yawbound), dim=-1) + ub = torch.cat((acce_ub, yawbound), dim=-1) + return lb, ub + + else: + raise NotImplementedError + def uniform_sample_xp(self,x,num_sample): + if isinstance(x,torch.Tensor): + u_lb,u_ub = self.ubound(x) + u_sample = torch.rand(*x.shape[:-1],num_sample,self.udim).to(x.device)*(u_ub-u_lb).unsqueeze(-2)+u_lb.unsqueeze(-2) + xp = self.step(x.unsqueeze(-2).repeat_interleave(num_sample,-2),u_sample,bound=False) + elif isinstance(x,np.ndarray): + u_lb,u_ub = self.ubound(x) + u_sample = np.random.uniform(u_lb[...,None,:],u_ub[...,None,:],(*x.shape[:-1],num_sample,self.udim)) + xp = self.step(x[...,None,:].repeat(num_sample,-2),u_sample,bound=False) + else: + raise NotImplementedError + return xp + + @staticmethod + def state2pos(x): + return x[..., 0:2] + + @staticmethod + def state2sc(x): + return x[..., 3:] + + @staticmethod + def state2yaw(x): + arctanfun = GeoUtils.ratan2 if isinstance(x,torch.Tensor) else np.arctan2 + return arctanfun(x[...,3:4],x[...,4:5]) + + @staticmethod + def state2vel(x): + return x[..., 2:3] + + @staticmethod + def combine_to_state(xy,vel,yaw): + if isinstance(xy,torch.Tensor): + return torch.cat((xy,vel,torch.sin(yaw),torch.cos(yaw)),-1) + elif isinstance(xy,np.ndarray): + return np.concatenate((xy,vel,np.sin(yaw),np.cos(yaw)),-1) + + def calculate_vel(self, pos, yaw, mask, dt=None): + if dt is None: + dt = self.dt + if isinstance(pos, torch.Tensor): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * torch.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * torch.sin( + yaw[..., 1:, :] + ) + # right finite difference velocity + vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) + # left finite difference velocity + vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) + mask_r = torch.roll(mask, 1, dims=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + + mask_l = torch.roll(mask, -1, dims=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 + + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l + + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r + ) + elif isinstance(pos, np.ndarray): + vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * np.cos( + yaw[..., 1:, :] + ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * np.sin(yaw[..., 1:, :]) + # right finite difference velocity + vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) + # left finite difference velocity + vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) + mask_r = np.roll(mask, 1, axis=-1) + mask_r[..., 0] = False + mask_r = mask_r & mask + mask_l = np.roll(mask, -1, axis=-1) + mask_l[..., -1] = False + mask_l = mask_l & mask + vel = ( + np.expand_dims(mask_l & mask_r,-1) * (vel_r + vel_l) / 2 + + np.expand_dims(mask_l & (~mask_r),-1) * vel_l + + np.expand_dims(mask_r & (~mask_l),-1) * vel_r + ) + else: + raise NotImplementedError + return vel + + def 
inverse_dyn(self,x,xp,dt=None): + if dt is None: + dt = self.dt + acce = (xp[2:3]-x[2:3])/dt + arctanfun = GeoUtils.ratan2 if isinstance(x,torch.Tensor) else np.arctan2 + catfun = torch.cat if isinstance(x,torch.Tensor) else np.concatenate + yawrate = (arctanfun(xp[3:4],xp[4:5])-arctanfun(x[3:4],x[4:5]))/dt + return catfun([acce,yawrate],-1) + + + def get_state(self,pos,yaw,mask,dt=None): + if dt is None: + dt = self.dt + vel = self.calculate_vel(pos, yaw, mask,dt) + return self.combine_to_state(pos,vel,yaw) + + def forward_dynamics(self, + x0: torch.Tensor, + u: torch.Tensor, + mode="chain", + bound = True, + ): + + """ + Integrate the state forward with initial state x0, action u + Args: + initial_states (Torch.tensor): state tensor of size [B, (A), 4] + actions (Torch.tensor): action tensor of size [B, (A), T, 2] + Returns: + state tensor of size [B, (A), T, 4] + """ + if mode=="chain": + num_steps = u.shape[-2] + x = [x0] + [None] * num_steps + for t in range(num_steps): + x[t + 1] = self.step(x[t], u[..., t, :],bound=bound) + + return torch.stack(x[1:], dim=-2) + + else: + raise NotImplementedError + + + + + + +def test(): + model = Unicycle(0.1) + x0 = torch.tensor([[1, 2, 3, 4]]).repeat_interleave(3, 0) + u = torch.tensor([[1, 2]]).repeat_interleave(3, 0) + x, jacx, jacu = model.step(x0, u, return_jacobian=True) + + +if __name__ == "__main__": + test() diff --git a/diffstack/models/CTT.py b/diffstack/models/CTT.py new file mode 100644 index 0000000..d6fc3b2 --- /dev/null +++ b/diffstack/models/CTT.py @@ -0,0 +1,2758 @@ +import enum +from collections import OrderedDict, defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from trajdata.data_structures import AgentType + +import diffstack.utils.geometry_utils as GeoUtils +import diffstack.utils.lane_utils as LaneUtils +import diffstack.utils.model_utils as ModelUtils +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.dynamics.unicycle import Unicycle, Unicycle_xyvsc +from diffstack.models.base_models import MLP +from diffstack.models.RPE_simple import sAuxRPEAttention, sAuxRPECrossAttention +from diffstack.models.TypeTransformer import * +from diffstack.utils.diffusion_utils import zero_module +from diffstack.utils.dist_utils import categorical_psample_wor +from diffstack.utils.geometry_utils import ratan2 +from diffstack.utils.homotopy import HOMOTOPY_THRESHOLD, HomotopyType + + +class FeatureAxes(enum.Enum): + """ + Axes of the features + """ + + B = enum.auto() # batch + A = enum.auto() # agent + T = enum.auto() # time + L = enum.auto() # lanes + F = enum.auto() # features + + +class TFvars(enum.Enum): + """ + Variables of the transformer + """ + + Agent_hist = enum.auto() # Agent history trajectories + Agent_future = enum.auto() # Agent future trajectories + Lane = enum.auto() # Lanes + + +class GNNedges(enum.Enum): + """edges of the GNN""" + + Agenthist2Lane = enum.auto() + Agentfuture2Lane = enum.auto() + Agenthist2Agenthist = enum.auto() + Agentfuture2Agentfuture = enum.auto() + Lane2Lane = enum.auto() + + +class FactorizedAttentionBlock(nn.Module): + def __init__( + self, + n_embd: int, + n_head: int, + PE_mode: str, + use_rpe_net: bool, + attn_attributes: OrderedDict, + var_axes: dict, + attn_pdrop: float = 0, + resid_pdrop: float = 0, + nominal_axes_order: list = None, + MAX_T=50, + ): + """a factorized attention block + + Args: + n_embd (int): embedding dimension + n_head (int): number of heads + PE_mode (str): Positional embedding mode, "RPE" or "PE" + use_rpe_net (bool): 
For RPE, whether use RPE (True) or ProdPENet (False) + attn_attributes (OrderedDict): recipe for attention blocks + var_axes (dict): variables and their axes + attn_pdrop (float, optional): dropout rate for attention. Defaults to 0. + resid_pdrop (float, optional): dropout rate for residue connection. Defaults to 0. + nominal_axes_order (list, optional): axes' order when arranging tensors. Defaults to None. + MAX_T (int, optional): maximum teporal axis length for PE. Defaults to 50. + + """ + super().__init__() + + self.vars = list(var_axes.keys()) + self.var_axes = dict() + for var in self.vars: + self.var_axes[var] = [ + FeatureAxes[var_axis] for var_axis in var_axes[var].split(",") + ] + if nominal_axes_order is None: + nominal_axes_order = list([var for var in FeatureAxes]) + self.nominal_axes_order = nominal_axes_order + self.attn_attributes = attn_attributes + self.attn = nn.ModuleDict() + self.mlp = nn.ModuleDict() + self.bn_1 = nn.ModuleDict() + self.bn_2 = nn.ModuleDict() + # key of attn_attributes are the two variables and the attention axis, + # value of attn_attributes are the edge dimension, edge function (if defined), ntype, and whether to normalize the embedding + for (var1, var2, axis), attributes in attn_attributes.items(): + edge_dim, edge_func, ntype, normalization = attributes + + if var1 == var2: + # self attention + attn_name = var1.name + "_" + var2.name + "_" + axis.name + if axis in [FeatureAxes.A, FeatureAxes.L]: + # agent/lane axis attention + attn_net = TypeSelfAttention( + ntype=ntype, + n_embd=n_embd, + n_head=n_head, + edge_dim=edge_dim, + aux_edge_func=edge_func, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + ) + elif axis == FeatureAxes.T: + # time axis attention + if edge_func is None: + if PE_mode == "RPE": + attn_net = sAuxRPEAttention( + n_embd=n_embd, + num_heads=n_head, + aux_vardim=edge_dim, + use_checkpoint=False, + use_rpe_net=use_rpe_net, + ) + elif PE_mode == "PE": + attn_net = AuxSelfAttention( + n_embd=n_embd, + n_head=n_head, + edge_dim=edge_dim, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + PE_len=MAX_T, + ) + + else: + if PE_mode == "RPE": + attn_net = sAuxRPECrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + use_checkpoint=False, + use_rpe_net=use_rpe_net, + ) + elif PE_mode == "PE": + attn_net = AuxCrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + attn_pdrop, + resid_pdrop, + PE_len=MAX_T, + ) + + else: + raise NotImplementedError + if normalization: + self.bn_1[attn_name + "_" + var1.name] = nn.BatchNorm1d(n_embd) + self.bn_2[attn_name] = nn.BatchNorm1d(n_embd) + else: + # cross attention + attn_name = ( + var1.name + + "_" + + var2.name + + "_" + + axis[0].name + + "->" + + axis[1].name + ) + if axis == (FeatureAxes.A, FeatureAxes.L): + # cross attention between agent and lane + attn_net = TypeCrossAttention( + ntype=ntype, + n_embd=n_embd, + n_head=n_head, + edge_dim=edge_dim, + aux_edge_func=edge_func, + attn_pdrop=attn_pdrop, + resid_pdrop=resid_pdrop, + ) + elif var1 == TFvars.Agent_future and var2 == TFvars.Agent_hist: + assert axis == (FeatureAxes.T, FeatureAxes.T) + if PE_mode == "RPE": + attn_net = sAuxRPECrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + use_checkpoint=False, + use_rpe_net=use_rpe_net, + ) + elif PE_mode == "PE": + attn_net = AuxCrossAttention( + n_embd, + n_head, + edge_dim, + edge_func, + attn_pdrop, + resid_pdrop, + PE_len=MAX_T, + ) + else: + raise NotImplementedError + if normalization: + self.bn_1[attn_name + "_" + var1.name] = nn.BatchNorm1d(n_embd) + 
self.bn_1[attn_name + "_" + var2.name] = nn.BatchNorm1d(n_embd) + self.bn_2[attn_name] = nn.BatchNorm1d(n_embd) + + self.attn[attn_name] = attn_net + self.mlp[attn_name] = tfmlp(n_embd, 4 * n_embd, resid_pdrop) + + def generate_agent_attn_mask(self, agent_mask, ntype): + # return 2 attention masks for self, and neightbors, respectively, or only return neighbor mask + # agent_mask: [B, N] + assert ntype in [1, 2] + B, N = agent_mask.shape + cross_mask = agent_mask[:, :, None] * agent_mask[:, None] # [B,N,N] + # mask to attend to self + I = torch.eye(N).to(agent_mask.device)[None] + notI = torch.logical_not(I) + + self_mask = cross_mask * I + # mask to attend to others + neighbor_mask = cross_mask * notI + if ntype == 1: + return neighbor_mask + else: + return self_mask, neighbor_mask + + def get_cross_mask( + self, var1, var2, axis, x1, x2, var_masks, cross_masks, ignore_var1_mask=True + ): + """get the attention mask for cross attention""" + if (var1, var2, axis) in cross_masks: + mask = cross_masks[(var1, var2, axis)] + else: + a1 = [a.name for a in self.var_axes[var1] if a != FeatureAxes.F] + a2 = [a.name for a in self.var_axes[var2] if a != FeatureAxes.F] + ag = [ + a.name + for a in self.nominal_axes_order + if a != FeatureAxes.F and ((a.name in a1) or (a.name in a2)) + ] + + # swap the attention axes to the last two + ag = [a for a in ag if a != axis[0].name and a != axis[1].name] + [ + axis[0].name, + axis[1].name, + ] + if axis[0] == axis[1]: + # deal with repeated axes + assert axis[1].name.swapcase() not in a1 + assert axis[0].name.swapcase() not in a2 + a2[a2.index(axis[1].name)] = axis[1].name.swapcase() + ag[-1] = ag[-1].swapcase() + cmd = "".join(a1) + "," + "".join(a2) + "->" + "".join(ag) + if var1 in var_masks and not ignore_var1_mask: + mask1 = var_masks[var1] + else: + mask1 = torch.ones_like(x1[..., 0]) + if var2 in var_masks: + mask2 = var_masks[var2] + else: + mask2 = torch.ones_like(x2[..., 0]) + mask = torch.einsum(cmd, mask1, mask2) + mask = mask.reshape(-1, *mask.shape[-2:]) + return mask + + def forward( + self, + vars, + aux_xs=None, + var_masks=dict(), + cross_masks=dict(), + frame_indices=dict(), + edges=dict(), + ): + if TFvars.Agent_hist in vars: + B, N, Th, D = vars[TFvars.Agent_hist].shape + if TFvars.Lane in vars: + B, L = vars[TFvars.Lane].shape[:2] + if TFvars.Agent_future in vars: + Tf = vars[TFvars.Agent_future].shape[2] + if TFvars.Agent_hist in vars: + assert N == vars[TFvars.Agent_future].shape[1] + + for (var1, var2, axis), attributes in self.attn_attributes.items(): + edge = edges[(var1, var2, axis)] if (var1, var2, axis) in edges else None + edge_dim, edge_func, ntype, normalization = attributes + if var1 == var2: + attn_name = var1.name + "_" + var2.name + "_" + axis.name + mlp = self.mlp[attn_name] + # self attention + + attn_axis_idx = self.var_axes[var1].index(axis) + x = vars[var1] + aux_x = aux_xs[var1] if var1 in aux_xs else None + if (var1, var2, axis) in cross_masks: + mask = cross_masks[(var1, var2, axis)] + elif var1 in var_masks: + # provide the feature mask instead of attention mask + mask = var_masks[var1] + mask = mask.transpose(attn_axis_idx, -1) + mask = mask.reshape(-1, mask.size(-1)) + + if axis in [FeatureAxes.A, FeatureAxes.L]: + mask = self.generate_agent_attn_mask(mask, ntype) + else: + mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + else: + mask = None + if attn_axis_idx != len(self.var_axes[var1]) - 2: + # permute the attention axis to second last + x = x.transpose(attn_axis_idx, -2) + aux_x = ( + 
aux_x.transpose(attn_axis_idx, -2) + if aux_x is not None + else None + ) + orig_shape = x.shape + if axis in [FeatureAxes.A, FeatureAxes.L]: + x = x.reshape(-1, orig_shape[-2], orig_shape[-1]) + aux_x = ( + aux_x.reshape(-1, aux_x.shape[-2], aux_x.shape[-1]) + if aux_x is not None + else None + ) + + if normalization: + bn_1 = self.bn_1[attn_name + "_" + var1.name] + bn_2 = self.bn_2[attn_name] + xn = bn_1(x.view(-1, x.shape[-1])).view(*x.shape) + resid = self.attn[attn_name](xn, aux_x, mask, edge=edge) + x = x + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x = x + mlp(self.attn[attn_name](x, aux_x, mask, edge=edge)) + elif axis == FeatureAxes.T: + T = x.shape[-2] + x = x.reshape(-1, T, D) + aux_x = ( + aux_x.reshape(-1, T, aux_x.shape[-1]) + if aux_x is not None + else None + ) + frame_index = frame_indices[var1] + frame_index = frame_index.reshape(-1, T) + if isinstance(self.attn[attn_name], sAuxRPEAttention) or isinstance( + self.attn[attn_name], AuxSelfAttention + ): + if normalization: + bn_1 = self.bn_1[attn_name + "_" + var1.name] + bn_2 = self.bn_2[attn_name] + xn = bn_1(x.view(-1, x.shape[-1])).view(*x.shape) + resid = self.attn[attn_name]( + xn, aux_x, mask, frame_indices=frame_index, edge=edge + ) + x = x + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x = x + mlp( + self.attn[attn_name]( + x, aux_x, mask, frame_indices=frame_index, edge=edge + ) + ) + + elif isinstance( + self.attn[attn_name], sAuxRPECrossAttention + ) or isinstance(self.attn[attn_name], AuxCrossAttention): + if normalization: + bn_1 = self.bn_1[attn_name + "_" + var1.name] + bn_2 = self.bn_2[attn_name] + xn = bn_1(x.view(-1, x.shape[-1])).view(*x.shape) + resid = self.attn[attn_name]( + xn, + xn, + mask, + aux_x, + frame_index, + aux_x, + frame_index, + edge=edge, + ) + x = x + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x = x + mlp( + self.attn[attn_name]( + x, + x, + mask, + aux_x, + aux_x, + frame_indices1=frame_index, + frame_indices2=frame_index, + edge=edge, + ) + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + x = x.reshape(*orig_shape) + if attn_axis_idx != len(self.var_axes[var1]) - 2: + # permute the attention axis back + x = x.transpose(attn_axis_idx, -2) + vars[var1] = x + else: + # cross attention + attn_name = ( + var1.name + + "_" + + var2.name + + "_" + + axis[0].name + + "->" + + axis[1].name + ) + mlp = self.mlp[attn_name] + x1 = vars[var1] + x2 = vars[var2] + aux_x1 = aux_xs[var1] if var1 in aux_xs else None + aux_x2 = aux_xs[var2] if var2 in aux_xs else None + mask = self.get_cross_mask( + var1, var2, axis, x1, x2, var_masks, cross_masks + ) + + if ( + var1 in [TFvars.Agent_hist, TFvars.Agent_future] + and var2 == TFvars.Lane + ): + # cross attention between agent and lane + assert x1.ndim == 4 # B,N,T,D + assert x2.ndim == 3 # B,L,D + T = x1.shape[2] + x1 = x1.transpose(1, 2).reshape(-1, N, D) # BT,N,D + x2 = x2.repeat_interleave(T, 0) # BT,L,D + aux_x1 = aux_x1.transpose(1, 2).reshape(B * T, N, -1) + aux_x2 = aux_x2.repeat_interleave(T, 0) + if normalization: + bn_11 = self.bn_1[attn_name + "_" + var1.name] + bn_12 = self.bn_1[attn_name + "_" + var2.name] + bn_2 = self.bn_2[attn_name] + xn1 = bn_11(x1.view(-1, x1.shape[-1])).view(*x1.shape) + xn2 = bn_12(x2.view(-1, x2.shape[-1])).view(*x2.shape) + resid = self.attn[attn_name]( + xn1, xn2, mask, aux_x1, aux_x2, edge=edge + ) + x1 = x1 + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + 
x1 = x1 + mlp( + self.attn[attn_name]( + x1, x2, mask, aux_x1, aux_x2, edge=edge + ) + ) + x1 = x1.reshape(B, T, N, D).transpose(1, 2) # B,N,T,D + elif var1 == TFvars.Lane and var2 in [ + TFvars.Agent_hist, + TFvars.Agent_future, + ]: + # cross attention between agent and lane + assert x2.ndim == 4 # B,N,T,D + assert x1.ndim == 3 # B,L,D + T = x2.shape[2] + L = x1.shape[1] + x2 = x2.transpose(1, 2).reshape(-1, N, D) # BT,N,D + x1 = x1.repeat_interleave(T, 0) # BT,L,D + aux_x2 = aux_x2.transpose(1, 2).reshape(-1, N, D) + aux_x1 = aux_x1.repeat_interleave(T, 0) + if normalization: + bn_11 = self.bn_1[attn_name + "_" + var1.name] + bn_12 = self.bn_1[attn_name + "_" + var2.name] + bn_2 = self.bn_2[attn_name] + xn1 = bn_11(x1.view(-1, x1.shape[-1])).view(*x1.shape) + xn2 = bn_12(x2.view(-1, x2.shape[-1])).view(*x2.shape) + resid = self.attn[attn_name]( + xn1, xn2, mask, aux_x1, aux_x2, edge=edge + ) + x1 = x1 + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x1 = x1 + mlp( + self.attn[attn_name]( + x1, x2, mask, aux_x1, aux_x2, edge=edge + ) + ) + x1 = x1.reshape(B, T, L, D).max(1)[0] # B,L,D + elif var1 == TFvars.Agent_future and var2 == TFvars.Agent_hist: + assert axis == (FeatureAxes.T, FeatureAxes.T) + x2 = vars[TFvars.Agent_hist].reshape(B * N, Th, D) + x1 = vars[TFvars.Agent_future].reshape(B * N, Tf, D) + aux_x2 = aux_xs[TFvars.Agent_hist].reshape(B * N, Th, -1) + aux_x1 = aux_xs[TFvars.Agent_future].reshape(B * N, Tf, -1) + frame_indices1 = frame_indices[TFvars.Agent_future].reshape( + B * N, Tf + ) + frame_indices2 = frame_indices[TFvars.Agent_hist].reshape(B * N, Th) + if normalization: + bn_11 = self.bn_1[attn_name + "_" + var1.name] + bn_12 = self.bn_1[attn_name + "_" + var2.name] + bn_2 = self.bn_2[attn_name] + xn1 = bn_11(x1.view(-1, x1.shape[-1])).view(*x1.shape) + xn2 = bn_12(x2.view(-1, x2.shape[-1])).view(*x2.shape) + resid = self.attn[attn_name]( + xn1, + xn2, + mask, + aux_x1, + aux_x2, + frame_indices1, + frame_indices2, + edge=edge, + ).reshape(B, N, Tf, D) + x1 = x1 + mlp( + bn_2(resid.view(-1, resid.shape[-1])).view(*resid.shape) + ) + else: + x1 = mlp( + self.attn[attn_name]( + x1, + x2, + mask, + aux_x1, + aux_x2, + frame_indices1, + frame_indices2, + edge=edge, + ) + ).reshape(B, N, Tf, D) + + else: + raise NotImplementedError + vars[var1] = x1 + return vars + + +class FactorizedGNN(nn.Module): + def __init__( + self, + var_axes: dict, + edge_var: dict, + GNN_attributes: OrderedDict, + node_n_embd: int, + edge_n_embd: int, + nominal_axes_order: list = None, + ): + """a factorized GNN + + Args: + var_axes (dict): variables and their axes + edge_var (dict): edges -> the variables it connects + GNN_attributes (OrderedDict): recipe for GNN + node_n_embd (int): embedding dimension for node variables + edge_n_embd (int): embedding dimension for edge variables + nominal_axes_order (list, optional): the order of axes when arranging tensors. Defaults to None. 
+ """ + super().__init__() + self.vars = list(var_axes.keys()) + self.var_axes = dict() + for var in self.vars: + self.var_axes[var] = [ + FeatureAxes[var_axis] for var_axis in var_axes[var].split(",") + ] + self.edge_var = edge_var + + if nominal_axes_order is None: + nominal_axes_order = list([var for var in FeatureAxes]) + self.nominal_axes_order = nominal_axes_order + self.node_n_embd = node_n_embd + self.edge_n_embd = edge_n_embd + + self.GNN_attributes = GNN_attributes + self.GNN_nets = nn.ModuleDict() + self.node_attn = nn.ParameterDict() + self.pe_net = nn.ParameterDict() + self.batch_norm = nn.ModuleDict() + + for (edge, gtype, var), attributes in self.GNN_attributes.items(): + hidden_dim, activation, pooling_method, Nmax = attributes + + if gtype == "edge": + net_name = f"edge_{edge.name}" + var1, var2 = self.edge_var[edge] + nd1, nd2 = self.node_n_embd[var1], self.node_n_embd[var2] + ed = self.edge_n_embd[edge] + # message udpate of edge + + self.GNN_nets[net_name] = MLP( + input_dim=nd1 + nd2 + ed, + output_dim=ed, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=activation if activation is not None else nn.ReLU, + dropouts=None, + normalization=False, + ) + elif gtype == "node": + # message update of node + net_name = f"node_{edge.name}_{var.name}" + ed = self.edge_n_embd[edge] + nd = self.node_n_embd[var] + # message udpate of edge + + self.GNN_nets[net_name] = MLP( + input_dim=nd + ed, + output_dim=nd, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=activation if activation is not None else nn.ReLU, + dropouts=None, + normalization=False, + ) + + if pooling_method == "attn": + self.node_attn[net_name] = CrossAttention(nd, n_head=4) + self.pe_net[net_name] = ( + nn.Parameter(torch.randn(Nmax, nd)) + if Nmax is not None + else None + ) + + def pooling(self, node_feat, node_feat_new, mask, net_name, pooling_method): + """pooling the node features when aggregating the messages for node varaibles""" + # node_feat: [B,N,node_dim] + # node_feat_new: [B,N,N,node_dim] + if pooling_method == "max": + node_feat_new.masked_fill_( + mask[..., None] == 0, node_feat_new.min().detach() - 1 + ) + node_feat_new = node_feat_new.max(2)[0] + elif pooling_method == "mean": + node_feat_new = (node_feat_new * mask[..., None]).sum(2) / mask.sum(2)[ + ..., None + ].clip(min=1e-3) + elif pooling_method == "attn": + N, D = node_feat_new.shape[-2:] + if self.pe_net[net_name] is not None: + node_feat_new = node_feat_new + self.pe_net[net_name][None, None, :N] + node_feat_new = self.node_attn[net_name]( + node_feat.reshape(-1, 1, D), + node_feat_new.reshape(-1, N, D), + mask.reshape(-1, 1, N), + ).reshape(*node_feat.shape) + + return node_feat_new + + def get_cross_mask( + self, var1, var2, edge, x1, x2, var_masks, cross_masks, axis=None + ): + """get the mask for message passing""" + if edge in cross_masks: + mask = cross_masks[edge] + else: + a1 = [a.name for a in self.var_axes[var1] if a != FeatureAxes.F] + a2 = [a.name for a in self.var_axes[var2] if a != FeatureAxes.F] + ag = [ + a.name + for a in self.nominal_axes_order + if a != FeatureAxes.F and ((a.name in a1) or (a.name in a2)) + ] + if var1 == var2: + # self mask + # need to provide the axis on which the mask is applied + assert axis is not None + assert axis.name.swapcase() not in a1 + assert axis.name.swapcase() not in a2 + a2[a2.index(axis.name)] = axis.name.swapcase() + ag.insert(ag.index(axis.name) + 1, axis.name.swapcase()) + + cmd = "".join(a1) + "," + 
"".join(a2) + "->" + "".join(ag) + if var1 in var_masks: + mask1 = var_masks[var1] + else: + mask1 = torch.ones_like(x1[..., 0]) + if var2 in var_masks: + mask2 = var_masks[var2] + else: + mask2 = torch.ones_like(x2[..., 0]) + mask = torch.einsum(cmd, mask1, mask2) + return mask + + def forward(self, vars, var_masks, cross_masks): + nodes = {var: x for var, x in vars.items() if var in TFvars} + edges = {edge: x for edge, x in vars.items() if edge in GNNedges} + for (edge, gtype, var), attributes in self.GNN_attributes.items(): + _, _, pooling_method, _ = attributes + if gtype == "edge": + net_name = f"edge_{edge.name}" + var1, var2 = self.edge_var[edge] + nx1, nx2 = nodes[var1], nodes[var2] + ex = edges[edge] + if edge in [GNNedges.Agenthist2Lane, GNNedges.Agentfuture2Lane]: + B, N, T = nx1.shape[:3] + M = nx2.size(1) + aggr_feature = torch.cat( + [ + nx1.unsqueeze(3).expand(B, N, T, M, -1), + nx2.view(B, 1, 1, M, -1).expand(B, N, T, M, -1), + ex, + ], + dim=-1, + ) + edges[edge] = edges[edge] + self.GNN_nets[net_name](aggr_feature) + elif edge in [ + GNNedges.Agenthist2Agenthist, + GNNedges.Agentfuture2Agentfuture, + ]: + B, N, T = nx1.shape[:3] + aggr_feature = torch.cat( + [ + nx1.unsqueeze(2).expand(B, N, N, T, -1), + nx1.unsqueeze(1).expand(B, N, N, T, -1), + ex, + ], + dim=-1, + ) + edges[edge] = edges[edge] + self.GNN_nets[net_name](aggr_feature) + elif edge == GNNedges.Lane2Lane: + B, M = nx1.shape[:2] + aggr_feature = torch.cat( + [ + nx1.unsqueeze(2).expand(B, M, M, -1), + nx1.unsqueeze(3).expand(B, M, M, -1), + ex, + ], + dim=-1, + ) + edges[edge] = edges[edge] + self.GNN_nets[net_name](aggr_feature) + else: + raise NotImplementedError + elif gtype == "node": + net_name = f"node_{edge.name}_{var.name}" + ex = edges[edge] + var1, var2 = self.edge_var[edge] + nx1, nx2 = nodes[var1], nodes[var2] + nx = nodes[var] + + if edge in [GNNedges.Agenthist2Lane, GNNedges.Agentfuture2Lane]: + mask = self.get_cross_mask( + var1, var2, edge, nx1, nx2, var_masks, cross_masks + ) + if var in [TFvars.Agent_future, TFvars.Agent_hist]: + B, N, T = nx.shape[:3] + aggr_feature = torch.cat( + [nx.unsqueeze(3).expand(B, N, T, M, -1), ex], dim=-1 + ) + + new_nx = self.GNN_nets[net_name](aggr_feature) + + nodes[var] = nodes[var] + self.pooling( + nx.reshape(B, N * T, -1), + new_nx.view(B, N * T, M, -1), + mask.view(B, N * T, M), + net_name, + pooling_method=pooling_method, + ).view(B, N, T, -1) + elif var == TFvars.Lane: + B, M = nx.shape[:2] + aggr_feature = torch.cat( + [nx[:, None, None].expand(B, N, T, M, -1), ex], dim=-1 + ) + + new_nx = self.GNN_nets[net_name](aggr_feature) # [B,N,T,M,D] + nodes[var] = nodes[var] + self.pooling( + nx, + new_nx.view(B, N * T, M, -1).transpose(1, 2), + mask.view(B, N * T, M).transpose(1, 2), + net_name, + pooling_method=pooling_method, + ) + + elif edge in [ + GNNedges.Agenthist2Agenthist, + GNNedges.Agentfuture2Agentfuture, + ]: + mask = self.get_cross_mask( + var1, + var2, + edge, + nx1, + nx2, + var_masks, + cross_masks, + axis=FeatureAxes.A, + ) + B, N, T = nx.shape[:3] + aggr_feature = torch.cat( + [nx.unsqueeze(2).expand(B, N, N, T, -1), ex], dim=-1 + ) + new_nx = self.GNN_nets[net_name](aggr_feature) # [B,N,N,T,D] + nodes[var] = nodes[var] + self.pooling( + nx.transpose(1, 2).reshape(B * T, N, -1), + new_nx.permute(0, 3, 1, 2, 4).reshape(B * T, N, N, -1), + mask.permute(0, 3, 1, 2).reshape(B * T, N, N), + net_name, + pooling_method=pooling_method, + ).view(B, T, N, -1).transpose(1, 2) + + elif edge == GNNedges.Lane2Lane: + raise NotImplementedError + + return 
{**nodes, **edges} + + +class tfmlp(nn.Module): + def __init__(self, n_embd, ff_dim, dropout): + """An MLP with GELU activation and dropout for transformer residual connection + + Args: + n_embd (int): embedding dimension + ff_dim (int): feed forward dimension + dropout (float): dropout rate + """ + super().__init__() + self.c_fc = nn.Linear(n_embd, ff_dim) + self.c_proj = zero_module(nn.Linear(ff_dim, n_embd)) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.dropout(self.c_proj(self.act(self.c_fc(x)))) + + +class CTTBlock(nn.Module): + def __init__(self, TF_kwargs, GNN_kwargs): + """a CTT block with a transformer block and a GNN block""" + super().__init__() + self.TFblock = ( + FactorizedAttentionBlock(**TF_kwargs) if TF_kwargs is not None else None + ) + self.GNNblock = FactorizedGNN(**GNN_kwargs) if GNN_kwargs is not None else None + + def forward( + self, vars, aux_xs, var_masks, cross_masks=None, frame_indices=None, edges={} + ): + if self.TFblock is not None: + vars = self.TFblock( + vars, aux_xs, var_masks, cross_masks, frame_indices, edges + ) + if self.GNNblock is not None: + vars = self.GNNblock(vars, var_masks, cross_masks) + return vars + + +class CTTEncoder(nn.Module): + def __init__(self, nblock, TF_kwargs, GNN_kwargs): + super().__init__() + self.blocks = nn.ModuleList( + [CTTBlock(TF_kwargs, GNN_kwargs) for _ in range(nblock)] + ) + + def forward(self, vars, enc_kwargs): + for block in self.blocks: + vars = block(vars, **enc_kwargs) + return vars + + +class CTT(nn.Module): + def __init__( + self, + n_embd: int, + embed_funcs: dict, + enc_nblock: int, + dec_nblock: int, + enc_transformer_kwargs: dict, + enc_GNN_kwargs: dict, + dec_transformer_kwargs: dict, + dec_GNN_kwargs: dict, + enc_output_params: dict, + dec_output_params: dict, + hist_lane_relation, + fut_lane_relation, + max_joint_cardinality: int, + classify_a2l_4all_lanes: bool = False, + edge_func=dict(), + ): + """Categorical Traffic Transformer + + Args: + n_embd (int): embedding dimensions + embed_funcs (dict): embedding functions for variables + enc_nblock (int): number of blocks in encoder + dec_nblock (int): number of blocks in decoder + enc_transformer_kwargs (dict): recipe for transformer in CTT encoder + enc_GNN_kwargs (dict): recipe for GNN in CTT encoder + dec_transformer_kwargs (dict): recipe for transformer in CTT decoder + dec_GNN_kwargs (dict): recipe for GNN in CTT decoder + enc_output_params (dict): parameters for mode prediction of CTT encoder + dec_output_params (dict): parameters for trajectory prediction of CTT decoder + hist_lane_relation: class of lane relation between agent history and lane segments + fut_lane_relation: class of lane relation between agent future and lane segments + max_joint_cardinality (int): maximum number of cardinality for each factor during importance sampling for joint scene mode + classify_a2l_4all_lanes (bool, optional): whether to predict lane mode for every agent-lane pair. Defaults to False. + edge_func (dict, optional): edge function between node variables. Defaults to dict(). 
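+ Note: as wired up below, the encoder embeds the raw agent and lane variables, runs the factorized attention/GNN blocks, and predicts discrete scene modes (agent-to-lane relations and pairwise homotopies) in predict_from_context; the decoder then produces trajectories conditioned on the chosen modes, autoregressively (decode_trajectory_AR) or in one shot (decode_trajectory_OS) depending on dec_output_params.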
+ """ + super().__init__() + # build transformer encoder and decoder + self.embed_funcs = nn.ModuleDict({k.name: v for k, v in embed_funcs.items()}) + self.edge_var = {} + if enc_GNN_kwargs is not None: + self.edge_var.update(enc_GNN_kwargs["edge_var"]) + if dec_GNN_kwargs is not None: + self.edge_var.update(dec_GNN_kwargs["edge_var"]) + self.encoder = CTTEncoder(enc_nblock, enc_transformer_kwargs, enc_GNN_kwargs) + self.decoder = CTTEncoder(dec_nblock, dec_transformer_kwargs, dec_GNN_kwargs) + self.n_embd = n_embd + self.max_joint_cardinality = max_joint_cardinality + self.hist_lane_relation = hist_lane_relation + self.fut_lane_relation = fut_lane_relation + self.classify_a2l_4all_lanes = classify_a2l_4all_lanes + self.build_enc_pred_net(enc_output_params) + self.build_dec_output_net(dec_output_params) + self.use_hist_mode = False + self.enc_output_params = enc_output_params + self.dec_output_params = dec_output_params + self.enc_vars = ( + [var for var in enc_transformer_kwargs["var_axes"].keys()] + if enc_transformer_kwargs is not None + else [] + ) + self.enc_edges = ( + [edge for edge in enc_GNN_kwargs["edge_var"].keys()] + if enc_GNN_kwargs is not None + else [] + ) + self.dec_vars = ( + [var for var in dec_transformer_kwargs["var_axes"].keys()] + if dec_transformer_kwargs is not None + else [] + ) + self.dec_edges = ( + [edge for edge in dec_GNN_kwargs["edge_var"].keys()] + if dec_GNN_kwargs is not None + else [] + ) + self.edge_func = edge_func + self.dec_T = self.Tf + + def build_enc_pred_net(self, enc_output_params, use_hist_mode=False): + # Nlr = len(self.fut_lane_relation) + # Nhm = len(HomotopyType) + self.Th = enc_output_params["Th"] + self.pooling_T = enc_output_params["pooling_T"] + self.null_lane_mode = enc_output_params["null_lane_mode"] + mode_embed_dim = enc_output_params.get("mode_embed_dim", 64) + self.lm_embed = nn.Embedding(len(self.fut_lane_relation), mode_embed_dim) + self.homotopy_embed = nn.Embedding(len(HomotopyType), mode_embed_dim) + hidden_dim = ( + enc_output_params["lane_mode"]["hidden_dim"] + if "hidden_dim" in enc_output_params["lane_mode"] + else [self.n_embd * 4, self.n_embd * 2] + ) + # marginal probability of lane relation + self.lane_mode_net = MLP( + input_dim=self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + output_dim=len(self.fut_lane_relation) + - ( + 1 - int(self.classify_a2l_4all_lanes) + ), # If loss over lanes, no NOTON relation + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + + if self.pooling_T == "attn": + n_head = enc_output_params["lane_mode"]["n_head"] + if enc_output_params["PE_mode"] == "RPE": + self.lane_mode_attn_T = sAuxRPECrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + use_checkpoint=False, + use_rpe_net=False, + ) + + self.homotopy_attn_T = sAuxRPECrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + use_checkpoint=False, + use_rpe_net=False, + ) + + self.agent_hist_attn_T = sAuxRPECrossAttention( + self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + use_checkpoint=False, + use_rpe_net=False, + ) + elif enc_output_params["PE_mode"] == "PE": + self.lane_mode_attn_T = AuxCrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + PE_len=self.Th + 1, + ) + 
self.homotopy_attn_T = AuxCrossAttention( + self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + PE_len=self.Th + 1, + ) + self.agent_hist_attn_T = AuxCrossAttention( + self.n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + PE_len=self.Th + 1, + ) + + hidden_dim = ( + enc_output_params["homotopy"]["hidden_dim"] + if "hidden_dim" in enc_output_params["homotopy"] + else [self.n_embd * 4, self.n_embd * 2] + ) + + # marginal probability of homotopy + self.homotopy_net = MLP( + input_dim=self.n_embd + mode_embed_dim if use_hist_mode else self.n_embd, + output_dim=len(HomotopyType), + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + + GNN_kwargs = enc_output_params["joint_mode"]["GNN_kwargs"] + jm_GNN_nblock = enc_output_params["joint_mode"].get("jm_GNN_nblock", 2) + + self.JM_GNN = nn.ModuleList( + [FactorizedGNN(**GNN_kwargs) for i in range(jm_GNN_nblock)] + ) + hidden_dim = enc_output_params["joint_mode"].get("hidden_dim", [256, 128]) + self.JM_lane_mode_factor = MLP( + input_dim=self.n_embd + mode_embed_dim * 2 + if use_hist_mode + else self.n_embd + mode_embed_dim, + output_dim=1, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + self.JM_homotopy_factor = MLP( + input_dim=self.n_embd + mode_embed_dim * 2 + if use_hist_mode + else self.n_embd + mode_embed_dim, + output_dim=1, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + self.num_joint_samples = enc_output_params["joint_mode"].get( + "num_joint_samples", 30 + ) + self.num_joint_factors = enc_output_params["joint_mode"].get( + "num_joint_factors", 6 + ) + + def build_dec_output_net(self, dec_output_params): + arch = dec_output_params["arch"] + self.Tf = dec_output_params["Tf"] + self.arch = arch + dyn = dec_output_params.get("dyn", None) + self.dyn = dyn + self.dt = dec_output_params.get("dt", 0.1) + traj_dim = dec_output_params["traj_dim"] + self.decode_num_modes = dec_output_params.get("decode_num_modes", 1) + self.AR_step_size = dec_output_params.get("AR_step_size", 1) + self.AR_update_mode = dec_output_params.get("AR_update_mode", "step") + self.LR_sample_hack = dec_output_params.get("LR_sample_hack", False) + self.dec_rounds = dec_output_params.get("dec_rounds", 3) + assert self.Tf % self.AR_step_size == 0 + if arch == "lstm": + self.output_rnn = nn.ModuleDict() + num_layers = dec_output_params.get("num_layers", 1) + hidden_size = dec_output_params.get("lstm_hidden_size", self.n_embd) + for k, v in self.dyn.items(): + if v is not None: + proj_size = v.udim + else: + proj_size = traj_dim + self.output_rnn[k.name] = nn.LSTM( + self.n_embd, + hidden_size, + batch_first=True, + num_layers=num_layers, + proj_size=proj_size, + ) + elif arch == "mlp": + self.output_mlp = nn.ModuleDict() + hidden_dim = dec_output_params.get( + "mlp_hidden_dims", [self.n_embd * 2, self.n_embd * 4] + ) + for k, v in self.dyn.items(): + self.output_mlp[k.name] = MLP( + input_dim=self.n_embd, + output_dim=traj_dim if v is None else v.udim, + layer_dims=hidden_dim, + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ) + + def embed_raw(self, raw_vars, aux_xs, 
aux_edges={}): + vars = dict() + for kname, v in self.embed_funcs.items(): + if hasattr(TFvars, kname): + # embed a TFvar + k = getattr(TFvars, kname) + if k in raw_vars: + vars[k] = self.embed_funcs[k.name](raw_vars[k]) + + elif hasattr(GNNedges, kname): + k = getattr(GNNedges, kname) + var1, var2 = self.edge_var[k] + if var1 in aux_xs and var2 in aux_xs: + aux_edge = aux_edges.get(k, None) + vars[k] = self.embed_funcs[k.name]( + aux_xs[var1], aux_xs[var2], aux_edge + ) + return vars + + def predict_from_context( + self, + context_vars, + aux_xs, + frame_indices, + var_masks, + cross_masks, + enc_edges, + GT_lane_mode=None, + GT_homotopy=None, + prev_lm_pred_dict=defaultdict(lambda: None), + prev_homo_pred_dict=defaultdict(lambda: None), + num_samples=None, + GT_batch_mask=None, + ): + # GT_batch_mask: identify which indices (among batch dimension) are given the GT modes + + device = context_vars[TFvars.Agent_hist].device + if num_samples is None: + num_samples = self.num_joint_samples + # predict lane mode + agent_feat = context_vars[TFvars.Agent_hist] + lane_feat = context_vars[TFvars.Lane] + a2l_edge = context_vars[GNNedges.Agenthist2Lane] + a2a_edge = context_vars[GNNedges.Agenthist2Agenthist] + B, N, Th = agent_feat.shape[:3] + if GT_batch_mask is None and GT_lane_mode is not None: + GT_batch_mask = torch.ones(B, device=device, dtype=bool) + M = lane_feat.shape[1] + if self.use_hist_mode: + prev_lm_pred = [ + torch.zeros(B, N, M, dtype=torch.int, device=device) for t in range(Th) + ] + prev_homo_pred = [ + torch.zeros(B, N, N, dtype=torch.int, device=device) for t in range(Th) + ] + + for t, v in prev_lm_pred_dict.items(): + prev_lm_pred[t] = v + for t, v in prev_homo_pred_dict.items(): + if v is not None: + prev_homo_pred[t] = v + prev_lm_embed = self.lm_embed(torch.stack(prev_lm_pred, 2)) + prev_homo_embed = self.homotopy_embed(torch.stack(prev_homo_pred, 3)) + a2l_edge = torch.cat([a2l_edge, prev_lm_embed], -1) + a2a_edge = torch.cat([a2a_edge, prev_homo_embed], -1) + if self.pooling_T == "max": + a2l_edge_pool = a2l_edge.max(2)[0] + a2a_edge_pool = a2a_edge.max(3)[0] + agent_feat_pool = agent_feat.max(-2)[0] + + elif self.pooling_T == "attn": + query = a2l_edge[:, :, -1].reshape(B * N * M, 1, -1) + frame_indices2 = ( + frame_indices[TFvars.Agent_hist] + .reshape(B * N, -1) + .repeat_interleave(M, 0) + ) + frame_indices1 = frame_indices2[:, -1:] + hist_mask = var_masks[TFvars.Agent_hist] # B,N,T + mask = ( + (hist_mask[..., :1, None] * hist_mask.unsqueeze(2)) + .repeat_interleave(M, 1) + .reshape(B * N * M, 1, Th) + ) + a2l_edge_pool = self.lane_mode_attn_T( + query, + a2l_edge.permute(0, 1, 3, 2, 4).reshape(B * N * M, Th, -1), + mask, + None, + None, + frame_indices1=frame_indices1, + frame_indices2=frame_indices2, + ).reshape(B, N, M, -1) + + query = a2a_edge[:, :, :, -1].reshape(B * N * N, 1, -1) + frame_indices2 = ( + frame_indices[TFvars.Agent_hist] + .reshape(B * N, -1) + .repeat_interleave(N, 0) + ) + frame_indices1 = frame_indices2[:, -1:] + mask = ( + (hist_mask[..., :1, None] * hist_mask.unsqueeze(2)) + .repeat_interleave(N, 1) + .reshape(B * N * N, 1, Th) + ) + a2a_edge_pool = self.homotopy_attn_T( + query, + a2a_edge.reshape(B * N * N, Th, -1), + mask, + None, + None, + frame_indices1, + frame_indices2, + ).reshape(B, N, N, -1) + + query = agent_feat[:, :, -1].reshape(B * N, 1, -1) + frame_indices2 = frame_indices[TFvars.Agent_hist].reshape(B * N, -1) + frame_indices1 = frame_indices2[:, -1:] + mask = (hist_mask[..., :1, None] * hist_mask.unsqueeze(2)).reshape( + B * 
N, 1, Th + ) + agent_feat_pool = self.agent_hist_attn_T( + query, + agent_feat.reshape(B * N, Th, -1), + mask, + None, + None, + frame_indices1, + frame_indices2, + ).view(B, N, -1) + + lane_mode_pred = self.lane_mode_net(a2l_edge_pool) + + if not self.classify_a2l_4all_lanes: # if self.per_mode + lane_mode_pred = lane_mode_pred.swapaxes( + -1, -2 + ) # [..., M, nbr_modes] -> [..., nbr_modes, M] + if self.null_lane_mode: + # add a null lane mode + lane_mode_pred = torch.cat( + [lane_mode_pred, torch.zeros_like(lane_mode_pred[..., 0:1])], -1 + ) + + lane_mode_prob = F.softmax(lane_mode_pred, dim=-1) + + homotopy_pred = self.homotopy_net(a2a_edge_pool) + homotopy_asym = homotopy_pred.transpose(1, 2) + homotopy_pred = (homotopy_pred + homotopy_asym) / 2 + + homotopy_prob = F.softmax(homotopy_pred, dim=-1) + + ########### Compute joint mode probability ############ + D_lane = lane_mode_pred.shape[ + -1 + ] # Either M (M+1) or len(self.fut_lane_relation) + D_homo = len(HomotopyType) + D = min([max([D_lane, D_homo]), self.max_joint_cardinality]) + M_lm = lane_mode_pred.shape[ + -2 + ] # number of possible modes for each lane,agent pair + + # factor masks for joint mode prediction + agent_mask = var_masks[TFvars.Agent_hist].any(-1).float() + lane_mask = var_masks[TFvars.Lane] + lm_factor_mask = agent_mask.unsqueeze(2) * lane_mask.unsqueeze(1) # (B,N,M) + if self.null_lane_mode and not self.classify_a2l_4all_lanes: + lane_mask_ext = torch.cat( + [lane_mask, torch.ones(*lane_mask.shape[:-1], 1, device=device)], -1 + ) + lm_factor_mask = torch.cat([lm_factor_mask, agent_mask.unsqueeze(-1)], -1) + else: + lane_mask_ext = lane_mask + + homo_factor_mask = agent_mask.unsqueeze(2) * agent_mask.unsqueeze(1) # (B,N,N) + if not self.classify_a2l_4all_lanes: + mode_mask = torch.ones( + [agent_mask.shape[0], M_lm], device=agent_mask.device + ) # All modes active + lm_factor_mask_per_mode = agent_mask.unsqueeze(2) * mode_mask.unsqueeze(1) + combined_factor_mask = torch.cat( + [ + lm_factor_mask_per_mode.view(B, N * M_lm), + homo_factor_mask.view(B, N * N), + ], + -1, + ) + else: + combined_factor_mask = torch.cat( + [lm_factor_mask.view(B, N * M_lm), homo_factor_mask.view(B, N * N)], -1 + ) + + if D_homo < D: + homotopy_pred_padded = torch.cat( + [ + homotopy_pred, + torch.ones(*homotopy_pred.shape[:-1], D - D_homo, device=device) + * -torch.inf, + ], + -1, + ) + elif D_homo > D: + raise NotImplementedError("We can only consider all homotopies") + else: + homotopy_pred_padded = homotopy_pred + + if ( + self.LR_sample_hack and GT_lane_mode is not None + ): # hack to favor sampling left and right lane of each agent + if ( + self.hist_lane_relation == LaneUtils.LaneRelation + and self.fut_lane_relation == LaneUtils.SimpleLaneRelation + ): + lane_mode_pred_m = lane_mode_pred.clone().detach() + try: + hist_lane_flag = enc_edges[ + (TFvars.Agent_hist, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)) + ][..., -len(LaneUtils.LaneRelation) :] + hist_lane_flag = hist_lane_flag.reshape(B, Th, N, M, -1)[ + :, -1 + ].type(torch.bool) + # pick out left and right lane and modify the probability + lm_logpi_max = lane_mode_pred_m.max(-1)[0] + left_lane_mask = hist_lane_flag[ + ..., LaneUtils.LaneRelation.LEFTOF + ].unsqueeze(2) + right_lane_mask = hist_lane_flag[ + ..., LaneUtils.LaneRelation.RIGHTOF + ].unsqueeze(2) + on_lane_mask = hist_lane_flag[ + ..., LaneUtils.LaneRelation.ON + ].unsqueeze(2) + lane_mode_pred_m[..., :M] = ( + lane_mode_pred_m[..., :M] * (~left_lane_mask) + + lm_logpi_max.unsqueeze(-1) * left_lane_mask + ) + 
lane_mode_pred_m[..., :M] = ( + lane_mode_pred_m[..., :M] * (~right_lane_mask) + + lm_logpi_max.unsqueeze(-1) * right_lane_mask + ) + lane_mode_pred_m[..., :M] = ( + lane_mode_pred_m[..., :M] * (~on_lane_mask) + + lm_logpi_max.unsqueeze(-1) * on_lane_mask + ) + except: + pass + + else: + lane_mode_pred_m = None + else: + lane_mode_pred_m = None + + if D_lane < D: + lane_mode_pred_padded = torch.cat( + [ + lane_mode_pred, + torch.ones(*lane_mode_pred.shape[:-1], D - D_lane, device=device) + * -torch.inf, + ], + -1, + ) + indices = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + elif D_lane > D: + # For each agent & mode, find the D likeliest lanes + assert ( + not self.classify_a2l_4all_lanes + ), "Currently only implemented for per_mode setting" + + lane_mode_pred_masked = lane_mode_pred.masked_fill( + torch.logical_not(lane_mask_ext[:, None, None, :]), -torch.inf + ) + + sorted, indices = lane_mode_pred_masked.topk( + D, dim=-1 + ) # Only consider most likely lanes! + lane_mode_pred_padded = torch.gather( + lane_mode_pred, -1, indices + ) # (...,M) -> (...,D) + else: + lane_mode_pred_padded = lane_mode_pred + indices = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + + combined_logpi = torch.cat( + [ + lane_mode_pred_padded.view(B, N * M_lm, -1), + homotopy_pred_padded.view(B, N * N, -1), + ], + 1, + ) + # for all invalid factors, turn the logpi to [0,-inf,-inf,-inf...] + combined_logpi[..., 1:].masked_fill_( + ~(combined_factor_mask.bool().unsqueeze(-1)), -torch.inf + ) + + if lane_mode_pred_m is not None: + if D_lane < D: + lane_mode_pred_padded_m = torch.cat( + [ + lane_mode_pred_m, + torch.ones( + *lane_mode_pred_m.shape[:-1], D - D_lane, device=device + ) + * -torch.inf, + ], + -1, + ) + indices_m = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + elif D_lane > D: + # For each agent & mode, find the D likeliest lanes + assert ( + not self.classify_a2l_4all_lanes + ), "Currently only implemented for per_mode setting" + lane_mode_pred_masked_m = lane_mode_pred_m.masked_fill( + torch.logical_not(lane_mask_ext[:, None, None, :]), -torch.inf + ) + sorted, indices_m = lane_mode_pred_masked_m.topk( + D, dim=-1 + ) # Only consider most likely lanes! + lane_mode_pred_padded_m = torch.gather( + lane_mode_pred_m, -1, indices_m + ) # (...,M) -> (...,D) + + else: + lane_mode_pred_padded_m = lane_mode_pred_m + indices_m = torch.arange(D_lane, device=device).expand(B, N, 1, D_lane) + + combined_logpi_m = torch.cat( + [ + lane_mode_pred_padded_m.view(B, N * M_lm, -1), + homotopy_pred_padded.clone().detach().view(B, N * N, -1), + ], + 1, + ) + # for all invalid factors, turn the logpi to [0,-inf,-inf,-inf...] 
+ combined_logpi_m[..., 1:].masked_fill_( + ~(combined_factor_mask.bool().unsqueeze(-1)), -torch.inf + ) + else: + combined_logpi_m = None + + # relevance score + # lane mode relevance score + # first, factors with one dominant mode is less relevant (or one dominant lane) + lm_dominance_score = -( + lane_mode_prob.max(dim=-1)[0] - 1 / lane_mode_prob.shape[-1] + ) # B,N,M_lm + homo_dominance_score = -( + homotopy_prob.max(dim=-1)[0] - 1 / homotopy_prob.shape[-1] + ) # B,N,N + # second, lanes that are far away from the agents are less important + a2ledge_raw = ModelUtils.agent2lane_edge_proj( + aux_xs[TFvars.Agent_hist][:, :, -1], aux_xs[TFvars.Lane] + ) + + if not self.classify_a2l_4all_lanes: + a2l_dis_score = 0 # For per mode, we don't consider lanes specifically so no a2l distance score + + else: + a2l_dis = a2ledge_raw[..., :2].norm(dim=-1) # B,N,M + a2l_dis_score = 1 / (a2l_dis + 1) + # lanes that are closer to ego is more important + + # third, agents far from ego is less relevant + a2eedge_raw = ModelUtils.agent2agent_edge( + aux_xs[TFvars.Agent_hist][:, 0:1, -1], aux_xs[TFvars.Agent_hist][:, :, -1] + ).squeeze(1) + a2e_dis = a2eedge_raw[..., :2].norm(dim=-1) + a2e_dis.masked_fill_(torch.logical_not(agent_mask.bool()), torch.inf) + a2e_dis_score = 1 / (a2e_dis.clip(min=0.5)) + a2e_dis_score_homo = a2e_dis_score.unsqueeze(1) * a2e_dis_score.unsqueeze(2) + a2e_dis_score_homo.masked_fill_( + torch.eye(N, device=a2e_dis_score.device, dtype=torch.bool).unsqueeze(0), 0 + ) + + lm_factor_score = ( + lm_dominance_score * 0.1 + + a2l_dis_score + + a2e_dis_score.unsqueeze(-1).repeat_interleave(M_lm, -1) * 2 + ) + homo_factor_score = homo_dominance_score + a2e_dis_score_homo + + # mask out half of homo_factor due to symmetry + sym_mask = torch.tril(torch.ones(B, N, N, dtype=torch.bool, device=device)) + homo_factor_score.masked_fill_(sym_mask, homo_factor_score.min().detach() - 5) + combined_factor_mask[:, -N * N :].masked_fill_(sym_mask.view(B, -1), 0) + if not self.classify_a2l_4all_lanes: + # Set the relevance score for ego agent (0) and the on lane index such that this is always a factor + max_score = torch.max(lm_factor_score.max(), homo_factor_score.max()) + on_lane_idx = ( + self.fut_lane_relation.ON - 1 + if self.fut_lane_relation.ON > self.fut_lane_relation.NOTON + else self.fut_lane_relation.ON + ) # We removed lane_pred + lm_factor_score[:, 0, on_lane_idx] = ( + 1.1 * max_score + ) # Temp, maybe make customizable + + combined_factor_score = torch.cat( + [lm_factor_score.view(B, N * M_lm), homo_factor_score.view(B, N * N)], -1 + ) + + num_factor = min(self.num_joint_factors, N * (M_lm + N)) + # modify the loglikelihood so that more important factors get more samples + temperature = 2 + modfied_combined_logpi = combined_logpi / torch.exp( + temperature + * (combined_factor_score - combined_factor_score.min()) + / (combined_factor_score.max() - combined_factor_score.min()) + ).unsqueeze(-1) + # Do importance sampling (sampling only the factors) the remaining marginals are chosen to be their argmax value + joint_sample, factor_idx = categorical_psample_wor( + modfied_combined_logpi, + num_samples, + num_factor, + factor_mask=combined_factor_mask, + relevance_score=combined_factor_score, + ) + + if combined_logpi_m is not None: + modfied_combined_logpi_m = combined_logpi_m / torch.exp( + temperature + * (combined_factor_score - combined_factor_score.min()) + / (combined_factor_score.max() - combined_factor_score.min()) + ).unsqueeze(-1) + joint_sample_m, factor_idx_m = 
categorical_psample_wor( + modfied_combined_logpi_m, + self.decode_num_modes, + num_factor, + factor_mask=combined_factor_mask, + relevance_score=combined_factor_score, + ) + else: + joint_sample_m = None + + # Turn indices of factors of importance back to combined_logpi size, then sum over num factors to get actual one hot (assuming nonrepetitive factors) + + ( + lm_sample, + homo_sample, + factor_mask, + ) = self.restore_lm_homotopy_from_joint_sample( + joint_sample, + indices, + factor_idx, + lm_factor_mask, + N, + lane_mask_ext.shape[-1], + M_lm, + ) + if joint_sample_m is not None: + ( + lm_sample_m, + homo_sample_m, + factor_mask_m, + ) = self.restore_lm_homotopy_from_joint_sample( + joint_sample_m, + indices_m, + factor_idx_m, + lm_factor_mask, + N, + lane_mask_ext.shape[-1], + M_lm, + ) + lm_sample = torch.cat([lm_sample, lm_sample_m], 1) + homo_sample = torch.cat([homo_sample, homo_sample_m], 1) + num_samples += lm_sample_m.size(1) + factor_mask = factor_mask | factor_mask_m + + if self.null_lane_mode: + # remove the factor_mask for the null lane + factor_mask = factor_mask.view(B, N, -1) + factor_mask = torch.cat( + [factor_mask[..., :M], factor_mask[..., M + 1 :]], -1 + ).reshape(B, -1) + + if ( + GT_lane_mode is not None and GT_homotopy is not None + ): # If we have GT, add it as sample for training + # for all the irrelevant entries, make them exactly the same as GT + if self.classify_a2l_4all_lanes: + lm_sample = lm_sample * lm_factor_mask.bool().unsqueeze(1) + ( + GT_lane_mode * torch.logical_not(lm_factor_mask.bool()) + ).unsqueeze(1) + + homo_sample = homo_sample * homo_factor_mask.bool().unsqueeze(1) + ( + GT_homotopy * torch.logical_not(homo_factor_mask.bool()) + ).unsqueeze(1) + # mask modes that are the same as GT + + lm_GT_flag = (lm_sample == GT_lane_mode.unsqueeze(1)).all(dim=3) + lm_GT_flag = lm_GT_flag.masked_fill( + torch.logical_not(agent_mask[:, None]), 1 + ).all(dim=2) + homo_GT_flag = homo_sample == GT_homotopy.unsqueeze(1) + homo_GT_flag = ( + homo_GT_flag.masked_fill( + torch.logical_not(homo_factor_mask)[:, None], 1 + ) + .all(-1) + .all(-1) + ) + GT_flag = (lm_GT_flag & homo_GT_flag) * GT_batch_mask[:, None] + # make sure that at least one entry of GT_flag is true for all scene, we can then remove 1 sample from every scene + num_samples = num_samples - GT_flag.sum(-1).max().item() + + lm_sample = torch.stack( + [ + lm_sample[i][torch.where(~GT_flag[i])[0]][:num_samples] + for i in range(B) + ], + 0, + ) + homo_sample = torch.stack( + [ + homo_sample[i][torch.where(~GT_flag[i])[0]][:num_samples] + for i in range(B) + ], + 0, + ) + if GT_batch_mask.all(): + lm_sample = torch.cat([GT_lane_mode.unsqueeze(1), lm_sample], 1) + homo_sample = torch.cat([GT_homotopy.unsqueeze(1), homo_sample], 1) + else: + lm_sample_GT = torch.cat([GT_lane_mode.unsqueeze(1), lm_sample], 1) + homo_sample_GT = torch.cat([GT_homotopy.unsqueeze(1), homo_sample], 1) + # for instances where GT is not provided, simply pad the last index + lm_sample_nonGT = torch.cat([lm_sample, lm_sample[:, :1]], 1) + homo_sample_nonGT = torch.cat([homo_sample, homo_sample[:, :1]], 1) + lm_sample = lm_sample_GT * GT_batch_mask[ + :, None, None, None + ] + lm_sample_nonGT * torch.logical_not( + GT_batch_mask[:, None, None, None] + ) + homo_sample = homo_sample_GT * GT_batch_mask[ + :, None, None, None + ] + homo_sample_nonGT * torch.logical_not( + GT_batch_mask[:, None, None, None] + ) + num_samples += 1 + + lm_embedding = self.lm_embed(lm_sample[..., :M]) + + homo_embedding = 
self.homotopy_embed(homo_sample) + a2l_edge_with_mode = torch.cat( + [ + a2l_edge_pool.repeat_interleave(num_samples, 0), + lm_embedding.view(B * num_samples, N, M, -1), + ], + -1, + ) + a2a_edge_with_mode = torch.cat( + [ + a2a_edge_pool.repeat_interleave(num_samples, 0), + homo_embedding.view(B * num_samples, N, N, -1), + ], + -1, + ) + + JM_GNN_vars = { + TFvars.Agent_hist: agent_feat_pool.repeat_interleave( + num_samples, 0 + ).unsqueeze(-2), + TFvars.Lane: context_vars[TFvars.Lane].repeat_interleave(num_samples, 0), + GNNedges.Agenthist2Lane: a2l_edge_with_mode.unsqueeze(2), + GNNedges.Agenthist2Agenthist: a2a_edge_with_mode.unsqueeze(3), + } + JM_GNN_var_masks = { + TFvars.Agent_hist: var_masks[TFvars.Agent_hist] + .any(2, keepdim=True) + .repeat_interleave(num_samples, 0), + TFvars.Lane: var_masks[TFvars.Lane].repeat_interleave(num_samples, 0), + } + for i in range(len(self.JM_GNN)): + JM_GNN_vars = self.JM_GNN[i]( + JM_GNN_vars, JM_GNN_var_masks, cross_masks=dict() + ) + + lm_factor = self.JM_lane_mode_factor(JM_GNN_vars[GNNedges.Agenthist2Lane]).view( + B, num_samples, N, M + ) + + homotopy_factor = self.JM_homotopy_factor( + JM_GNN_vars[GNNedges.Agenthist2Agenthist] + ).view(B, num_samples, N, N) + + total_factor = torch.cat([lm_factor, homotopy_factor], -1).reshape( + B, num_samples, N * (M + N) + ) + total_factor_mask = ( + torch.cat([lm_factor_mask[..., :M], homo_factor_mask], -1).reshape(B, -1) + * factor_mask + ) + joint_logpi = ( + total_factor * (total_factor_mask * factor_mask).unsqueeze(1) + ).sum(-1) + + if joint_sample_m is not None: + num_sample_m = joint_sample_m.size(1) + ml_idx = joint_logpi.argmax(dim=1, keepdim=True) + # decode the GT, the most likely mode other than GT, and the modes sampled with modified probability + if GT_lane_mode is not None: + dec_idx = torch.cat( + [ + torch.zeros(B, 1, device=device), + ml_idx, + torch.arange( + num_samples - num_sample_m, num_samples, device=device + )[None].repeat_interleave(B, 0), + ], + dim=1, + ).type(torch.int64) + else: + dec_idx = torch.cat( + [ + ml_idx, + torch.arange( + num_samples - num_sample_m, num_samples, device=device + )[None].repeat_interleave(B, 0), + ], + dim=1, + ).type(torch.int64) + else: + dec_idx = joint_logpi.topk( + min(self.decode_num_modes, joint_logpi.shape[1]), dim=1 + ).indices + if GT_lane_mode is not None: + dec_idx = torch.cat( + [torch.zeros(B, 1, device=device), dec_idx], 1 + ).type(torch.int64) + if self.null_lane_mode: + dec_cond_lm = torch.gather( + lm_sample, 1, dec_idx.view(B, -1, 1, 1).repeat(1, 1, N, M + 1) + ) + else: + dec_cond_lm = torch.gather( + lm_sample, 1, dec_idx.view(B, -1, 1, 1).repeat(1, 1, N, M) + ) + dec_cond_homotopy = torch.gather( + homo_sample, 1, dec_idx.view(B, -1, 1, 1).repeat(1, 1, N, N) + ) + joint_mode_prob = torch.softmax(joint_logpi, -1) + dec_cond_prob = torch.gather(joint_mode_prob, 1, dec_idx) + # homotopy result test + # dec_cond_homotopy[0,1,0,2]=2 + # dec_cond_homotopy[0,1,2,0]=2 + return dict( + lane_mode_pred=lane_mode_pred, + homotopy_pred=homotopy_pred, + joint_pred=joint_logpi, + lane_mode_sample=lm_sample, + homotopy_sample=homo_sample, + dec_cond_lm=dec_cond_lm, + dec_cond_homotopy=dec_cond_homotopy, + dec_cond_prob=dec_cond_prob, + ) + + def decode_trajectory_AR( + self, + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ): + ARS = self.AR_step_size + B, N, Th = enc_vars[TFvars.Agent_hist].shape[:3] + M = aux_xs[TFvars.Lane].shape[1] + 
device = enc_vars[TFvars.Agent_hist].device + raw_agent_hist = raw_vars[TFvars.Agent_hist] # [B,N,Th,F] + raw_agent_fut = list( + raw_agent_hist[:, :, -1:] + .repeat_interleave(1 + self.dec_T, 2) + .split(1, dim=2) + ) + state_out = {k: list() for k in self.dyn} + input_out = {k: list() for k in self.dyn} + Tf = self.dec_T + curr_yaw = ratan2( + aux_xs[TFvars.Agent_future][:, :, 0, 3:4], + aux_xs[TFvars.Agent_future][:, :, 0, 4:5], + ) + curr_yaw = curr_yaw * var_masks[TFvars.Agent_future][:, :, 0:1] + ar_mask = torch.zeros([1, 1, Tf + 1], dtype=bool, device=device) + xt = x0 + dec_kwargs_ar = dict( + aux_xs=aux_xs, + cross_masks=cross_masks, + frame_indices=frame_indices, + ) + lane_xysc = aux_xs[TFvars.Lane].reshape(B, M, -1, 4) + future_xyvsc = list(torch.split(aux_xs[TFvars.Agent_future][..., :5], 1, dim=2)) + for i in range(Tf // ARS): + dec_raw_vars = { + TFvars.Agent_future: torch.cat(raw_agent_fut, 2), + TFvars.Lane: raw_vars[TFvars.Lane], + } + + ar_mask[:, :, : i * ARS + 1] = 1 + x_mask = var_masks[TFvars.Agent_future] * ar_mask + + cross_masks = dict() + Agent_fut_cross_mask = var_masks[TFvars.Agent_future].unsqueeze( + -1 + ) * var_masks[TFvars.Agent_future].unsqueeze(-2) + Agent_fut_cross_mask = torch.tril(Agent_fut_cross_mask) * ar_mask.unsqueeze( + -2 + ) + + cross_masks[ + TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.T + ] = Agent_fut_cross_mask + + dec_kwargs_ar["var_masks"] = var_masks + dec_kwargs_ar["cross_masks"] = cross_masks + future_xyvsc_tensor = torch.cat(future_xyvsc, dim=2) + future_aux = torch.cat( + [future_xyvsc_tensor, aux_xs[TFvars.Agent_future][..., 5:]], -1 + ) + aux_xs[TFvars.Agent_future] = future_aux + future_xysc = future_xyvsc_tensor[..., [0, 1, 3, 4]] + lane_margins = ( + self.fut_lane_relation.get_all_margins( + TensorUtils.join_dimensions(future_xysc, 0, 2), + lane_xysc.repeat_interleave(N, 0), + ) + .reshape(B, N, M, 1 + Tf, -1) + .clip(-20, 20) + ) + lane_margins.masked_fill_( + ~( + x_mask[:, :, None, :, None] + * var_masks[TFvars.Lane][:, None, :, None, None] + ).bool(), + 0, + ) + # mask out all unselected lanes + lane_margins.masked_fill_( + lane_mode[..., self.fut_lane_relation.NOTON, None, None].bool(), 0 + ) + agent_edge = future_xysc[..., :2].unsqueeze(2) - future_xysc[ + ..., :2 + ].unsqueeze(1) + rel_angle = ratan2(agent_edge[..., [1]], agent_edge[..., [0]]) + rel_angle_offset = torch.cat( + [rel_angle[..., :1, :], rel_angle[..., :-1, :]], -2 + ) + + angle_diff = (rel_angle - rel_angle_offset).cumsum(-2) + homotopy_margins = torch.cat( + [ + HOMOTOPY_THRESHOLD - angle_diff.abs(), + -angle_diff - HOMOTOPY_THRESHOLD, + angle_diff - HOMOTOPY_THRESHOLD, + ], + -1, + ) + homotopy_margins.masked_fill_( + torch.logical_not(x_mask.unsqueeze(1) * x_mask.unsqueeze(2)).unsqueeze( + -1 + ), + 0, + ) + aux_dec_edges = { + GNNedges.Agentfuture2Agentfuture: torch.cat( + [ + homotopy.unsqueeze(3).repeat_interleave(1 + self.dec_T, 3), + homotopy_margins, + ], + -1, + ), + GNNedges.Agentfuture2Lane: torch.cat( + [ + lane_mode.unsqueeze(2).repeat_interleave(1 + self.dec_T, 2), + lane_margins.transpose(2, 3), + ], + -1, + ), + } + dec_vars = self.embed_raw( + dec_raw_vars, + {TFvars.Agent_future: future_aux, TFvars.Lane: aux_xs[TFvars.Lane]}, + aux_dec_edges, + ) + vars = {**enc_vars, **dec_vars} + # update edges due to change of aux_xs + dec_agent_aux = future_aux.permute(0, 2, 1, 3).reshape(B * (Tf + 1), N, -1) + + a2a_edge = torch.cat( + [ + self.edge_func["a2a"](dec_agent_aux, dec_agent_aux), + homotopy.repeat_interleave(Tf + 1, 0), + 
TensorUtils.join_dimensions( + homotopy_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + a2l_edge = torch.cat( + [ + self.edge_func["a2l"]( + dec_agent_aux, aux_xs[TFvars.Lane].repeat_interleave(Tf + 1, 0) + ), + lane_mode.repeat_interleave(Tf + 1, 0), + TensorUtils.join_dimensions( + lane_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + edges = { + (TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.A): a2a_edge, + (TFvars.Agent_future, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)): [ + a2l_edge + ], + } + dec_kwargs_ar["edges"] = edges + vars = self.decoder(vars, dec_kwargs_ar) + out = dict() + + if self.AR_update_mode == "step": + if self.arch == "lstm": + raise ValueError("LSTM arch is not working!") + + elif self.arch == "mlp": + for k in self.dyn: + out[k] = self.output_mlp[k.name]( + vars[TFvars.Agent_future].reshape(B * N, Tf + 1, -1)[ + :, i * ARS : (i + 1) * ARS + ] + ).view(B, N, ARS, -1) + else: + raise NotImplementedError + + elif self.AR_update_mode == "all": + if self.arch == "lstm": + raise ValueError("LSTM arch is not working!") + elif self.arch == "mlp": + for k in self.dyn: + out[k] = self.output_mlp[k.name]( + vars[TFvars.Agent_future].reshape(B * N, Tf + 1, -1)[ + :, : (i + 1) * ARS + ] + ).view(B, N, ARS * (i + 1), -1) + else: + raise NotImplementedError + + agent_xyvsc_by_type = dict() + agent_acce = dict() + agent_r = dict() + for k, dyn in self.dyn.items(): + if dyn is None: + # output assumed to take form [x,y,v,h] + if center_from_agents is None: + agent_xyvsc_by_type[k] = torch.cat( + [ + out[k][..., :3], + torch.sin(out[k][..., 3:]), + torch.cos(out[k][..., 3:]), + ], + -1, + ) + + else: + # transform to global coordinate + xy_local = out[k][..., :2] + agent_v = out[k][..., 2:3] + h_local = out[k][..., 3:4] + xy_global = GeoUtils.batch_nd_transform_points( + xy_local, center_from_agents.unsqueeze(2) + ) + h_global = h_local + curr_yaw.unsqueeze(2) + agent_xyvsc_by_type[k] = torch.cat( + [ + xy_global, + agent_v, + torch.sin(h_global), + torch.cos(h_global), + ], + -1, + ) + + agent_v = out[k][..., 2:3] + if self.AR_update_mode == "step": + agent_v_pre = torch.cat( + future_xyvsc[i * ARS : (i + 1) * ARS], 2 + )[..., 2:3] + # here we assume input is always 2 dimensional + input_out[k].append(torch.zeros_like(out[k][..., :2])) + elif self.AR_update_mode == "all": + agent_v_pre = torch.cat(future_xyvsc[: (i + 1) * ARS], 2)[ + ..., 2:3 + ] + + agent_acce[k] = (agent_v - agent_v_pre) / self.dt + agent_r[k] = torch.zeros_like(agent_v) + + else: + if isinstance(dyn, Unicycle): + xseq = dyn.forward_dynamics( + xt[k], out[k], mode="chain", bound=False + ) + + if self.AR_update_mode == "step": + agent_xyvsc_by_type[k] = dyn.state2xyvsc(xseq) + state_out[k].append(xseq) + xt[k] = xseq[:, :, -1] + elif self.AR_update_mode == "all": + agent_xyvsc_by_type[k] = dyn.state2xyvsc(xseq) + + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + + elif isinstance(dyn, Unicycle_xyvsc): + xseq = dyn.forward_dynamics(xt[k], out[k], bound=False) + + agent_xyvsc_by_type[k] = xseq + if self.AR_update_mode == "step": + xt[k] = xseq[:, :, -1] + state_out[k].append(xseq) + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + else: + raise NotImplementedError + input_out[k].append(out[k]) + agent_xyvsc_combined = sum( + [ + agent_mask[k][..., None, None] * agent_xyvsc_by_type[k] + for k in agent_xyvsc_by_type + ] + ) + + acce_combined = sum( + [agent_mask[k][..., None, None] * agent_acce[k] for k in agent_acce] + ) + r_combined = sum( + 
[agent_mask[k][..., None, None] * agent_r[k] for k in agent_r] + ) + agent_v = agent_xyvsc_combined[..., 2:3] + # update the agent future aux + + # update the agent future raw [v,acce,r,static_features] + + agent_vu = torch.cat([agent_v, acce_combined, r_combined], -1) + if self.AR_update_mode == "step": + future_xyvsc[i * ARS + 1 : (i + 1) * ARS + 1] = list( + torch.split(agent_xyvsc_combined, 1, dim=2) + ) + for j in range(ARS): + raw_agent_fut[i * ARS + 1 + j] = torch.cat( + [ + agent_vu[:, :, j : j + 1], + raw_agent_fut[i * ARS + 1 + j][..., 3:], + ], + -1, + ) + elif self.AR_update_mode == "all": + future_xyvsc[: (i + 1) * ARS + 1] = list( + torch.split(agent_xyvsc_combined, 1, dim=2) + ) + for j in range(ARS * (i + 1)): + raw_agent_fut[j] = torch.cat( + [agent_vu[:, :, j : j + 1], raw_agent_fut[j][..., 3:]], -1 + ) + + if self.AR_update_mode == "step": + future_xyvsc = torch.cat(future_xyvsc, dim=2) + input_out = { + k: torch.cat(input_type, dim=2) if len(input_type) > 0 else None + for k, input_type in input_out.items() + } + state_out = { + k: torch.cat(state_type, dim=2) + if len(state_type) > 0 + else future_xyvsc[..., 1:, :] + for k, state_type in state_out.items() + } + elif self.AR_update_mode == "all": + future_xyvsc = agent_xyvsc_combined + for k, v in out.items(): + if v is not None: + input_out[k] = v + else: + input_out[k] = torch.zeros_like(out[k][..., :2]) + + trajectories = torch.cat( + [ + future_xyvsc[:, :, 1:, :2], + ratan2(future_xyvsc[:, :, 1:, 3:4], future_xyvsc[:, :, 1:, 4:5]), + ], + -1, + ) + + input_violation = dict() + jerk = dict() + for k, dyn in self.dyn.items(): + if dyn is not None: + xl = torch.cat([x0[k].unsqueeze(-2), state_out[k][..., :-1, :]], -2) + input_violation[k] = ( + dyn.get_input_violation(xl, input_out[k]) + * agent_mask[k][..., None, None] + ) + inputs_extended = torch.cat([u0[k].unsqueeze(-2), input_out[k]], -2) + jerk[k] = ( + (inputs_extended[..., 1:, :] - inputs_extended[..., :-1, :]) + / self.dt + ) * agent_mask[k][..., None, None] + return trajectories, input_out, state_out, input_violation, jerk + + def decode_trajectory_OS( + self, + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ): + B, N, Th = enc_vars[TFvars.Agent_hist].shape[:3] + M = aux_xs[TFvars.Lane].shape[1] + device = enc_vars[TFvars.Agent_hist].device + raw_agent_hist = raw_vars[TFvars.Agent_hist] # [B,N,Th,F] + raw_agent_fut = raw_agent_hist[:, :, -1:].repeat_interleave(1 + self.dec_T, 2) + Tf = self.dec_T + curr_yaw = ratan2( + aux_xs[TFvars.Agent_future][:, :, 0, 3:4], + aux_xs[TFvars.Agent_future][:, :, 0, 4:5], + ) + curr_yaw = curr_yaw * var_masks[TFvars.Agent_future][:, :, 0:1] + dec_kwargs = dict( + aux_xs=aux_xs, + cross_masks=cross_masks, + frame_indices=frame_indices, + ) + lane_xysc = aux_xs[TFvars.Lane].reshape(B, M, -1, 4) + future_xyvsc = aux_xs[TFvars.Agent_future][..., :5] + input_out = dict() + state_out = dict() + for i in range(self.dec_rounds): + dec_raw_vars = { + TFvars.Agent_future: raw_agent_fut, + TFvars.Lane: raw_vars[TFvars.Lane], + } + + x_mask = var_masks[TFvars.Agent_future] + + cross_masks = dict() + Agent_fut_cross_mask = var_masks[TFvars.Agent_future].unsqueeze( + -1 + ) * var_masks[TFvars.Agent_future].unsqueeze(-2) + + cross_masks[ + TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.T + ] = torch.tril(Agent_fut_cross_mask) + + dec_kwargs["var_masks"] = var_masks + dec_kwargs["cross_masks"] = cross_masks + future_aux = 
torch.cat( + [future_xyvsc, aux_xs[TFvars.Agent_future][..., 5:]], -1 + ) + aux_xs[TFvars.Agent_future] = future_aux + future_xysc = future_xyvsc[..., [0, 1, 3, 4]] + lane_margins = ( + self.fut_lane_relation.get_all_margins( + TensorUtils.join_dimensions(future_xysc, 0, 2), + lane_xysc.repeat_interleave(N, 0), + ) + .reshape(B, N, M, 1 + Tf, -1) + .clip(-20, 20) + ) + lane_margins.masked_fill_( + ~( + x_mask[:, :, None, :, None] + * var_masks[TFvars.Lane][:, None, :, None, None] + ).bool(), + 0, + ) + # mask out all unselected lanes + lane_margins.masked_fill_( + lane_mode[..., self.fut_lane_relation.NOTON, None, None].bool(), 0 + ) + agent_edge = future_xysc[..., :2].unsqueeze(2) - future_xysc[ + ..., :2 + ].unsqueeze(1) + rel_angle = ratan2(agent_edge[..., [1]], agent_edge[..., [0]]) + rel_angle_offset = torch.cat( + [rel_angle[..., :1, :], rel_angle[..., :-1, :]], -2 + ) + + angle_diff = (rel_angle - rel_angle_offset).cumsum(-2) + homotopy_margins = torch.cat( + [ + HOMOTOPY_THRESHOLD - angle_diff.abs(), + -angle_diff - HOMOTOPY_THRESHOLD, + angle_diff - HOMOTOPY_THRESHOLD, + ], + -1, + ) + homotopy_margins.masked_fill_( + torch.logical_not(x_mask.unsqueeze(1) * x_mask.unsqueeze(2)).unsqueeze( + -1 + ), + 0, + ) + aux_dec_edges = { + GNNedges.Agentfuture2Agentfuture: torch.cat( + [ + homotopy.unsqueeze(3).repeat_interleave(1 + self.dec_T, 3), + homotopy_margins, + ], + -1, + ), + GNNedges.Agentfuture2Lane: torch.cat( + [ + lane_mode.unsqueeze(2).repeat_interleave(1 + self.dec_T, 2), + lane_margins.transpose(2, 3), + ], + -1, + ), + } + dec_vars = self.embed_raw( + dec_raw_vars, + {TFvars.Agent_future: future_aux, TFvars.Lane: aux_xs[TFvars.Lane]}, + aux_dec_edges, + ) + vars = {**enc_vars, **dec_vars} + # update edges due to change of aux_xs + dec_agent_aux = future_aux.permute(0, 2, 1, 3).reshape(B * (Tf + 1), N, -1) + + a2a_edge = torch.cat( + [ + self.edge_func["a2a"](dec_agent_aux, dec_agent_aux), + homotopy.repeat_interleave(Tf + 1, 0), + TensorUtils.join_dimensions( + homotopy_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + a2l_edge = torch.cat( + [ + self.edge_func["a2l"]( + dec_agent_aux, aux_xs[TFvars.Lane].repeat_interleave(Tf + 1, 0) + ), + lane_mode.repeat_interleave(Tf + 1, 0), + TensorUtils.join_dimensions( + lane_margins.permute(0, 3, 1, 2, 4), 0, 2 + ), + ], + -1, + ) + edges = { + (TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.A): a2a_edge, + (TFvars.Agent_future, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)): [ + a2l_edge + ], + } + dec_kwargs["edges"] = edges + vars = self.decoder(vars, dec_kwargs) + out = dict() + + if self.arch == "mlp": + for k in self.dyn: + out[k] = self.output_mlp[k.name]( + vars[TFvars.Agent_future].reshape(B, N, Tf + 1, -1) + ) + else: + raise NotImplementedError + + agent_xyvsc_by_type = dict() + agent_acce = dict() + agent_r = dict() + for k, dyn in self.dyn.items(): + if dyn is None: + # output assumed to take form [x,y,v,h] + if center_from_agents is None: + agent_xyvsc_by_type[k] = torch.cat( + [ + out[k][..., :3], + torch.sin(out[k][..., 3:]), + torch.cos(out[k][..., 3:]), + ], + -1, + )[..., 1:, :] + + else: + # transform to global coordinate + xy_local = out[k][..., :2] + agent_v = out[k][..., 2:3] + h_local = out[k][..., 3:4] + xy_global = GeoUtils.batch_nd_transform_points( + xy_local, center_from_agents.unsqueeze(2) + ) + h_global = h_local + curr_yaw.unsqueeze(2) + agent_xyvsc_by_type[k] = torch.cat( + [ + xy_global, + agent_v, + torch.sin(h_global), + torch.cos(h_global), + ], + -1, + )[..., 1:, :] + + agent_v = 
out[k][..., 2:3] + + agent_acce[k] = ( + out[k][..., 1:, 2:3] - out[k][..., :Tf, 2:3] + ) / self.dt + agent_acce[k] = torch.cat( + [agent_acce[k], agent_acce[k][..., -1:, :]], -2 + ) + agent_r[k] = torch.zeros_like(agent_v) + + else: + if isinstance(dyn, Unicycle): + xseq = dyn.forward_dynamics( + x0[k], out[k][..., :Tf, :], mode="chain", bound=False + ) + + agent_xyvsc_by_type[k] = dyn.state2xyvsc(xseq) + + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + + elif isinstance(dyn, Unicycle_xyvsc): + xseq = dyn.forward_dynamics( + x0[k], out[k][..., :Tf, :], bound=False + ) + + agent_xyvsc_by_type[k] = xseq + agent_acce[k] = out[k][..., :1] + agent_r[k] = out[k][..., 1:2] + else: + raise NotImplementedError + state_out[k] = xseq + future_xyvsc[..., 1:, :] = sum( + [ + agent_mask[k][..., None, None] * agent_xyvsc_by_type[k] + for k in agent_xyvsc_by_type + ] + ) + + acce_combined = sum( + [agent_mask[k][..., None, None] * agent_acce[k] for k in agent_acce] + ) + r_combined = sum( + [agent_mask[k][..., None, None] * agent_r[k] for k in agent_r] + ) + agent_v = future_xyvsc[..., 2:3] + # update the agent future aux + + # update the agent future raw [v,acce,r,static_features] + + agent_vu = torch.cat([agent_v, acce_combined, r_combined], -1) + raw_agent_fut = torch.cat([agent_vu, raw_agent_fut[..., 3:]], -1) + for k, v in out.items(): + if v is not None: + input_out[k] = v[..., :Tf, :] + else: + input_out[k] = torch.zeros_like(out[k][..., :Tf, :2]) + + trajectories = torch.cat( + [ + future_xyvsc[:, :, 1:, :2], + ratan2(future_xyvsc[:, :, 1:, 3:4], future_xyvsc[:, :, 1:, 4:5]), + ], + -1, + ) + + input_violation = dict() + jerk = dict() + for k, dyn in self.dyn.items(): + if dyn is not None: + xl = torch.cat([x0[k].unsqueeze(-2), state_out[k][..., :-1, :]], -2) + input_violation[k] = ( + dyn.get_input_violation(xl, input_out[k][..., :Tf, :]) + * agent_mask[k][..., None, None] + ) + inputs_extended = torch.cat([u0[k].unsqueeze(-2), input_out[k]], -2) + jerk[k] = ( + (inputs_extended[..., 1:, :] - inputs_extended[..., :-1, :]) + / self.dt + ) * agent_mask[k][..., None, None] + return trajectories, input_out, state_out, input_violation, jerk + + def decode_trajectory( + self, + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type, + lane_mode, + homotopy, + x0, + u0, + center_from_agents, + num_samples=None, + ): + device = enc_vars[TFvars.Agent_hist].device + B, N, Th = enc_vars[TFvars.Agent_hist].shape[:3] + assert AgentType.VEHICLE in self.dyn + agent_mask = dict() + for k, v in self.dyn.items(): + if k == AgentType.VEHICLE: + # this is the default mode + other_types = [k.value for k in self.dyn if k != AgentType.VEHICLE] + + agent_mask[k] = torch.logical_not( + agent_type[..., other_types].any(-1) + ) * agent_type.any(-1) + else: + agent_mask[k] = agent_type[..., k.value] + if lane_mode.ndim == 4: + # sample mode + sample_mode = True + num_samples = homotopy.shape[1] + lane_mode, homotopy = TensorUtils.join_dimensions( + (lane_mode, homotopy), 0, 2 + ) + B0 = B + B = B0 * num_samples + enc_vars = { + k: v.repeat_interleave(num_samples, 0) for k, v in enc_vars.items() + } + raw_vars = { + k: v.repeat_interleave(num_samples, 0) for k, v in raw_vars.items() + } + aux_xs = {k: v.repeat_interleave(num_samples, 0) for k, v in aux_xs.items()} + var_masks = { + k: v.repeat_interleave(num_samples, 0) for k, v in var_masks.items() + } + cross_masks = { + k: v.repeat_interleave(num_samples, 0) for k, v in cross_masks.items() + } + frame_indices = { + k: 
v.repeat_interleave(num_samples, 0) for k, v in frame_indices.items() + } + for k, v in x0.items(): + if v is not None: + x0[k] = v.repeat_interleave(num_samples, 0) + for k, v in u0.items(): + if v is not None: + u0[k] = v.repeat_interleave(num_samples, 0) + center_from_agents = center_from_agents.repeat_interleave(num_samples, 0) + for k, v in agent_mask.items(): + agent_mask[k] = v.repeat_interleave(num_samples, 0) + + else: + sample_mode = False + + lane_mode = F.one_hot(lane_mode, len(self.fut_lane_relation)).float() + if self.null_lane_mode: + lane_mode = lane_mode[..., :-1, :] + homotopy = F.one_hot(homotopy, len(HomotopyType)).float() + + Tf = self.dec_T + # create agent_future from agent_hist for auto-regressive decoding + + var_masks[TFvars.Agent_future] = ( + var_masks[TFvars.Agent_hist] + .any(2, keepdim=True) + .repeat_interleave(1 + Tf, 2) + ) + aux_xs[TFvars.Agent_future] = aux_xs[TFvars.Agent_hist][ + :, :, -1: + ].repeat_interleave(1 + Tf, 2) + + frame_indices[TFvars.Agent_future] = Th + torch.arange( + 1 + self.dec_T, device=device + )[None, None, :].repeat(B, N, 1) + + future_xyvsc = list(torch.split(aux_xs[TFvars.Agent_future][..., :5], 1, dim=2)) + + if self.AR_update_mode is not None: + ( + trajectories, + input_out, + state_out, + input_violation, + jerk, + ) = self.decode_trajectory_AR( + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ) + else: + ( + trajectories, + input_out, + state_out, + input_violation, + jerk, + ) = self.decode_trajectory_OS( + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_mask, + x0, + u0, + lane_mode, + homotopy, + center_from_agents, + ) + + if sample_mode: + trajectories = trajectories.view(B0, num_samples, N, Tf, -1) + state_out = { + k: v.view(B0, num_samples, N, Tf, -1) if v is not None else None + for k, v in state_out.items() + } + input_out = { + k: v.view(B0, num_samples, N, Tf, -1) if v is not None else None + for k, v in input_out.items() + } + agent_mask = {k: v.view(B0, num_samples, N) for k, v in agent_mask.items()} + input_violation = { + k: v.view(B0, num_samples, N, Tf, -1) + for k, v in input_violation.items() + } + jerk = {k: v.view(B0, num_samples, N, Tf, -1) for k, v in jerk.items()} + + return dict( + trajectories=trajectories, + states=state_out, + inputs=input_out, + input_violation=input_violation, + jerk=jerk, + type_mask=agent_mask, + future_xyvsc=future_xyvsc, + ) + + def restore_lm_homotopy_from_joint_sample( + self, joint_sample, indices, factor_idx, lm_factor_mask, N, M, M_lm + ): + device = joint_sample.device + B, num_samples = joint_sample.shape[:2] + factor_mask = F.one_hot(factor_idx, N * (N + M_lm)).sum( + -2 + ) # (B, |combined_log_pi|) + if not self.classify_a2l_4all_lanes: + factor_mask = torch.cat( + [ + torch.ones(B, N * M, device=device, dtype=torch.int), + factor_mask[:, N * M_lm :], + ], + -1, + ) + candidate_flag = torch.zeros(B, N, M, dtype=torch.bool, device=device) + for i in range(M_lm): + candidate_flag.scatter_(-1, indices[:, :, i], 1) # Back to lanes + lm_factor_mask = lm_factor_mask * candidate_flag + + if not self.classify_a2l_4all_lanes: + lm_sample_per_mode = joint_sample[..., : N * M_lm].view( + B, -1, N, M_lm + ) # First N*M_lm are mode contributions + lm_sample = torch.zeros( + [B, joint_sample.size(1), N, M], dtype=torch.long, device=device + ) + + i = 0 + lm_sample_per_mode = lm_sample_per_mode.clip(0, indices.shape[-1] - 1) + for mode in 
self.fut_lane_relation: + if mode == self.fut_lane_relation.NOTON: + continue + # restore from the important sampling of of lane segments + restored_lm_samples_i = torch.gather( + indices[:, :, i].unsqueeze(1).repeat_interleave(num_samples, 1), + -1, + lm_sample_per_mode[..., i : i + 1], + ).squeeze(-1) + # value of the lane mode + mode_enum = i + 1 if mode > self.fut_lane_relation.NOTON else i + lm_sample = lm_sample + (lm_sample == 0) * ( + torch.scatter( + lm_sample, -1, restored_lm_samples_i.unsqueeze(-1), mode_enum + ) + ) + i += 1 + else: + lm_sample = joint_sample[..., : N * M].view( + B, -1, N, M + ) # .clip(0,len(self.fut_lane_relation)-1) + lm_sample.masked_fill_(~lm_factor_mask.bool().unsqueeze(1), 0) + homo_sample = joint_sample[..., N * M_lm :].view(B, -1, N, N) + homo_sample = HomotopyType.enforce_symmetry(homo_sample) + return lm_sample, homo_sample, factor_mask + + def forward( + self, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type, + enc_edges=dict(), + dec_edges=dict(), + GT_lane_mode=None, + GT_homotopy=None, + center_from_agents=None, + num_samples=None, + ): + enc_raw_vars = { + k: v + for k, v in raw_vars.items() + if k in self.enc_vars or k in self.enc_edges + } + enc_vars = self.embed_raw(enc_raw_vars, aux_xs) + + x0 = dict() + u0 = dict() + for k, dyn in self.dyn.items(): + if dyn is not None: + # get x0 and u0 from aux_xs + if isinstance(dyn, Unicycle_xyvsc): + x0[k] = aux_xs[TFvars.Agent_hist][:, :, -1, :5] + xprev = aux_xs[TFvars.Agent_hist][:, :, -2, :5] + u0[k] = dyn.inverse_dyn(x0[k], xprev, self.dt) + elif isinstance(dyn, Unicycle): + xyvsc = aux_xs[TFvars.Agent_hist][:, :, -2:, :5] + h = ratan2(xyvsc[..., 3:4], xyvsc[..., 4:5]) + x0[k] = torch.cat([xyvsc[..., -1, :3], h[..., -1, :]], -1) + xprev = torch.cat([xyvsc[..., -2, :3], h[..., -2, :]], -1) + u0[k] = dyn.inverse_dyn(x0[k], xprev, self.dt) + else: + raise NotImplementedError + else: + x0[k] = None + u0[k] = None + + enc_kwargs = dict( + aux_xs=aux_xs, + var_masks=var_masks, + cross_masks=cross_masks, + frame_indices=frame_indices, + edges=enc_edges, + ) + enc_vars = self.encoder(enc_vars, enc_kwargs) + mode_pred = self.predict_from_context( + enc_vars, + aux_xs, + frame_indices, + var_masks, + cross_masks, + enc_edges, + GT_lane_mode, + GT_homotopy, + num_samples=num_samples, + ) + + # decode trajectory predictions + # prepare modes + + if GT_lane_mode is not None and GT_homotopy is not None: + # train mode + if self.decode_num_modes == 1: + lane_mode = GT_lane_mode + homotopy = GT_homotopy + mode_pred["dec_cond_lm"] = GT_lane_mode.unsqueeze(1) + mode_pred["dec_cond_homotopy"] = GT_homotopy.unsqueeze(1) + mode_pred["dec_cond_prob"] = torch.ones( + [GT_lane_mode.shape[0], 1], device=GT_lane_mode.device + ) + else: + lane_mode = mode_pred["dec_cond_lm"] + homotopy = mode_pred["dec_cond_homotopy"] + + else: + # infer mode + lane_mode = mode_pred["lane_mode_sample"] + homotopy = mode_pred["homotopy_sample"] + + dec_vars = self.decode_trajectory( + enc_vars, + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type, + lane_mode, + homotopy, + x0=x0, + u0=u0, + center_from_agents=center_from_agents, + ) + return dec_vars, mode_pred diff --git a/diffstack/models/RPE_simple.py b/diffstack/models/RPE_simple.py new file mode 100644 index 0000000..0bd885e --- /dev/null +++ b/diffstack/models/RPE_simple.py @@ -0,0 +1,441 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +from diffstack.configs.config import Dict +from 
diffstack.models.TypeTransformer import NewGELU, zero_module +from diffstack.utils.model_utils import * + +""" +modified from RPE.py with the following changes: +1. removed diffusion time embedding +2. removed the D dimension +3. reversed the order of T and C +4. changed attn_mask to be [B,T,T] instead of [B,T] +5. removed the residual connection, instead, output the residual itself +6. removed the batchnorm +""" +class sRPENet(nn.Module): + def __init__(self, n_embd, num_heads): + super().__init__() + self.embed_distances = nn.Linear(3, n_embd) + + self.gelu = NewGELU() + self.out = nn.Linear(n_embd, n_embd) + self.out.weight.data *= 0. + self.out.bias.data *= 0. + self.n_embd = n_embd + self.num_heads = num_heads + + def forward(self, relative_distances): + distance_embs = torch.stack( + [torch.log(1+(relative_distances).clamp(min=0)), + torch.log(1+(-relative_distances).clamp(min=0)), + (relative_distances == 0).float()], + dim=-1 + ) # BxTxTx3 + if self.embed_distances.weight.dtype==torch.float16: + distance_embs = distance_embs.half() + emb = self.embed_distances(distance_embs) + return self.out(self.gelu(emb)).view(*relative_distances.shape, self.num_heads, self.n_embd//self.num_heads) + +class ProdPENet(nn.Module): + # embed the tuple of two positions [P1,P2] into positional embedding + def __init__(self, n_embd, num_heads): + super().__init__() + self.embed_distances = nn.Linear(5, n_embd) + + self.gelu = NewGELU() + self.out = nn.Linear(n_embd, n_embd) + self.out.weight.data *= 0. + self.out.bias.data *= 0. + self.n_embd = n_embd + self.num_heads = num_heads + + def forward(self, relative_distances,P1,P2): + distance_embs = torch.stack( + [torch.log(1+(relative_distances).clamp(min=0)), + torch.log(1+(-relative_distances).clamp(min=0)), + (relative_distances == 0).float(), + torch.log(1+P1), + torch.log(1+P2)], + dim=-1 + ) # BxTxTx3 + if self.embed_distances.weight.dtype==torch.float16: + distance_embs = distance_embs.half() + emb = self.embed_distances(distance_embs) + return self.out(self.gelu(emb)).view(*relative_distances.shape, self.num_heads, self.n_embd//self.num_heads) + +class sRPE(nn.Module): + # Based on https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py + def __init__(self, n_embd, num_heads, use_rpe_net=False): + """ This module handles the relative positional encoding. + Args: + channels (int): Number of input channels. + num_heads (int): Number of attention heads. 
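+            use_rpe_net (bool): if True, an sRPENet embeds the signed relative
+                frame distances; otherwise a ProdPENet embeds the relative
+                distance together with the two absolute positions (P1, P2).
+                (``channels`` above refers to the ``n_embd`` argument.)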
+ """ + super().__init__() + self.num_heads = num_heads + self.head_dim = n_embd // self.num_heads + self.use_rpe_net = use_rpe_net + if use_rpe_net: + self.rpe_net = sRPENet(n_embd, num_heads) + else: + self.prod_pe_net = ProdPENet(n_embd, num_heads) + # raise NotImplementedError + # self.lookup_table_weight = nn.Parameter( + # torch.zeros(2 * self.beta + 1, + # self.num_heads, + # self.head_dim)) + + def get_R(self, pairwise_distances,P1=None,P2=None): + if self.use_rpe_net: + return self.rpe_net(pairwise_distances) + else: + return self.prod_pe_net(pairwise_distances,P1,P2) + # return self.lookup_table_weight[pairwise_distances] # BxTxTxHx(C/H) + + def forward(self, x, pairwise_distances, mode,P1=None,P2=None): + if mode == "qk": + return self.forward_qk(x, pairwise_distances,P1=P1,P2=P2) + elif mode == "v": + return self.forward_v(x, pairwise_distances,P1=P1,P2=P2) + else: + raise ValueError(f"Unexpected RPE attention mode: {mode}") + + def forward_qk(self, qk, pairwise_distances,P1=None,P2=None): + # qv is either of q or k and has shape BxHxTx(C/H) + # Output shape should be # BxHxTxT + R = self.get_R(pairwise_distances,P1=P1,P2=P2) + if qk.ndim==4: + return torch.einsum( # See Eq. 16 in https://arxiv.org/pdf/2107.14222.pdf + "bhtf,btshf->bhts", qk, R # BxHxTxT + ) + elif qk.ndim==5: + return torch.einsum( # See Eq. 16 in https://arxiv.org/pdf/2107.14222.pdf + "bhtsf,btshf->bhts", qk, R # BxHxTxT + ) + + def forward_v(self, attn, pairwise_distances,P1=None,P2=None): + # attn has shape BxHxT1xT2 + # Output shape should be # BxHxT1x(C/H) + R = self.get_R(pairwise_distances,P1=P1,P2=P2) + return torch.einsum( # See Eq. 16ish in https://arxiv.org/pdf/2107.14222.pdf + "bhts,btshf->bhtf", attn, R # BxHxTxT + ) + + def forward_safe_qk(self, x, pairwise_distances): + R = self.get_R(pairwise_distances) + B, T, _, H, F = R.shape + res = x.new_zeros(B, H, T, T) # attn shape + for b in range(B): + for h in range(H): + for i in range(T): + for j in range(T): + res[b, h, i, j] = x[b, h, i].dot(R[b, i, j, h]) + return res + +class sRPEAttention(nn.Module): + # Based on https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py#L42 + def __init__(self, n_embd, num_heads, use_checkpoint=False,use_rpe_net=None, + use_rpe_q=True, use_rpe_k=True, use_rpe_v=True, + ): + super().__init__() + self.num_heads = num_heads + head_dim = n_embd // num_heads + self.scale = head_dim ** -0.5 + self.use_checkpoint = use_checkpoint + + self.qkv = nn.Linear(n_embd, n_embd * 3) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + + if use_rpe_q or use_rpe_k or use_rpe_v: + assert use_rpe_net is not None + def make_rpe_func(): + return sRPE( + n_embd=n_embd, num_heads=num_heads, use_rpe_net=use_rpe_net, + ) + self.rpe_q = make_rpe_func() if use_rpe_q else None + self.rpe_k = make_rpe_func() if use_rpe_k else None + self.rpe_v = make_rpe_func() if use_rpe_v else None + + def forward(self, x, attn_mask, frame_indices, attn_weights_list=None): + out, attn = checkpoint(self._forward, (x, attn_mask, frame_indices), self.parameters(), self.use_checkpoint) + if attn_weights_list is not None: + B, T, C = x.shape + attn_weights_list.append(attn.detach().view(B, -1, T, T).mean(dim=1).abs()) # logging attn weights + return out + + def _forward(self, x, attn_mask, frame_indices): + B, T, C = x.shape + qkv = self.qkv(x).reshape(B, T, 3, self.num_heads, C // self.num_heads) + qkv = torch.einsum("BTtHF -> tBHTF", qkv) + q, k, v = qkv[0], qkv[1], qkv[2] # make 
torchscript happy (cannot use tensor as tuple) + # q, k, v shapes: BxHxTx(C/H) + q *= self.scale + attn = (q @ k.transpose(-2, -1)) # BxDxHxTxT + if self.rpe_q is not None or self.rpe_k is not None or self.rpe_v is not None: + pairwise_distances = (frame_indices.unsqueeze(-1) - frame_indices.unsqueeze(-2)) # BxTxT + # relative position on keys + if self.rpe_k is not None: + attn += self.rpe_k(q, pairwise_distances, mode="qk") + # relative position on queries + if self.rpe_q is not None: + attn += self.rpe_q(k * self.scale, pairwise_distances, mode="qk") + + # softmax where all elements with mask==0 can attend to eachother and all with mask==1 + # can attend to eachother (but elements with mask==0 can't attend to elements with mask==1) + def softmax(w, attn_mask): + if attn_mask is not None: + allowed_interactions = attn_mask + inf_mask = (1-allowed_interactions) + inf_mask[inf_mask == 1] = torch.inf + w = w - inf_mask.view(B, 1, 1, T, T) # BxDxHxTxT + w.masked_fill_((attn_mask==0).all(-1).view(B,1,1,T,1),0) + finfo = torch.finfo(w.dtype) + w = w.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + return torch.softmax(w.float(), dim=-1).type(w.dtype) + if attn_mask is None: + attn_mask = torch.ones_like(attn[:,0]) + attn = softmax(attn, attn_mask.type(attn.dtype)) + out = attn @ v + # relative position on values + if self.rpe_v is not None: + out += self.rpe_v(attn, pairwise_distances, mode="v") + out = torch.einsum("BHTF -> BTHF", out).reshape(B, T, C) + out = self.proj_out(out) + return out, attn + +class AuxPEAttention(nn.Module): + def __init__(self, n_embd, num_heads, max_len = 100,attn_pdrop=0.0,resid_pdrop = 0.0,aux_vardim=0): + super().__init__() + + self.num_heads = num_heads + self.PE_q = nn.Embedding(max_len, n_embd) + self.PE_k = nn.Embedding(max_len, n_embd) + self.attn_pdrop = attn_pdrop + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.aux_vardim = aux_vardim + self.qnet = nn.Linear(n_embd, n_embd) + self.knet = nn.Linear(n_embd+aux_vardim, n_embd) + self.vnet = nn.Linear(n_embd+aux_vardim, n_embd) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + + def forward(self, x, aux_x, attn_mask, frame_indices): + + B, T, C = x.shape + if aux_x is not None: + x_aug = torch.cat([x, aux_x], dim=2) + else: + x_aug = x + q,k,v = self.qnet(x).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2),\ + self.knet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2),\ + self.vnet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + + # q, k, v shapes: BxHxTx(C/H) + q = q + self.PE_q(frame_indices).view(B,T,self.num_heads,C//self.num_heads) + k = k + self.PE_k(frame_indices).view(B,T,self.num_heads,C//self.num_heads) + + out = F.scaled_dot_product_attention(q,k,v,attn_mask[:,None].repeat_interleave(self.num_heads,1),dropout=0.0) + out = out.nan_to_num(0.0) + out = out.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + out = self.resid_dropout(self.proj_out(out)) + return out + + + +class sAuxRPEAttention(nn.Module): + # Based on https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py#L42 + def __init__(self, n_embd, num_heads, aux_vardim, use_checkpoint=False, + use_rpe_net=None, use_rpe_q=True, use_rpe_k=True, use_rpe_v=True + ): + super().__init__() + assert aux_vardim>0 + self.aux_vardim = aux_vardim + self.num_heads = num_heads + head_dim = n_embd // num_heads + self.scale = head_dim ** -0.5 + 
self.use_checkpoint = use_checkpoint + + self.qnet = nn.Linear(n_embd, n_embd) + self.knet = nn.Linear(n_embd+aux_vardim, n_embd) + self.vnet = nn.Linear(n_embd+aux_vardim, n_embd) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + + if use_rpe_q or use_rpe_k or use_rpe_v: + assert use_rpe_net is not None + def make_rpe_func(): + return sRPE( + n_embd=n_embd, num_heads=num_heads, use_rpe_net=use_rpe_net, + ) + self.rpe_q = make_rpe_func() if use_rpe_q else None + self.rpe_k = make_rpe_func() if use_rpe_k else None + self.rpe_v = make_rpe_func() if use_rpe_v else None + + def forward(self, x, aux_x, attn_mask, frame_indices, attn_weights_list=None,**kwargs): + out, attn = checkpoint(self._forward, (x, aux_x, attn_mask, frame_indices), self.parameters(), self.use_checkpoint) + if attn_weights_list is not None: + B, T, C = x.shape + attn_weights_list.append(attn.detach().view(B, -1, T, T).mean(dim=1).abs()) # logging attn weights + return out + + def _forward(self, x, aux_x, attn_mask, frame_indices): + B, T, C = x.shape + x_aug = torch.cat([x, aux_x], dim=2) + + q = self.qnet(x).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + k = self.knet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + v = self.vnet(x_aug).reshape(B,T,self.num_heads,C//self.num_heads).transpose(1,2) + # q, k, v shapes: BxHxTx(C/H) + q *= self.scale + attn = (q @ k.transpose(-2, -1)) # BxHxTxT + if self.rpe_q is not None or self.rpe_k is not None or self.rpe_v is not None: + pairwise_distances = (frame_indices.unsqueeze(-1) - frame_indices.unsqueeze(-2)) # BxTxT + # relative position on keys + if self.rpe_k is not None: + attn += self.rpe_k(q, pairwise_distances, mode="qk") + # relative position on queries + if self.rpe_q is not None: + attn += self.rpe_q(k * self.scale, pairwise_distances, mode="qk") + + # softmax where all elements with mask==0 can attend to eachother and all with mask==1 + # can attend to eachother (but elements with mask==0 can't attend to elements with mask==1) + def softmax(w, attn_mask): + if attn_mask is not None: + allowed_interactions = attn_mask + inf_mask = (1-allowed_interactions) + inf_mask[inf_mask == 1] = torch.inf + w = w - inf_mask.view(B, 1, T, T) # BxHxTxT + w.masked_fill_((attn_mask[:,None]==0).all(-1).unsqueeze(-1),0) + finfo = torch.finfo(w.dtype) + w = w.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + return torch.softmax(w.float(), dim=-1).type(w.dtype) + if attn_mask is None: + attn_mask = torch.ones_like(attn[:,0]) + attn = softmax(attn, attn_mask.type(attn.dtype)) + out = attn @ v + # relative position on values + if self.rpe_v is not None: + out += self.rpe_v(attn, pairwise_distances, mode="v") + out = torch.einsum("BHTF -> BTHF", out).reshape(B, T, C) + out = self.proj_out(out) + return out, attn + +class sAuxRPECrossAttention(nn.Module): + """ RPE cross attention with auxillary variable + + """ + def __init__(self, n_embd, num_heads, edge_dim,aux_edge_func=None, use_checkpoint=False, + use_rpe_net=None, use_rpe_k=True, use_rpe_v=True + ): + super().__init__() + # assert edge_dim% num_heads==0 + self.edge_dim = edge_dim + self.num_heads = num_heads + head_dim = n_embd // num_heads + self.scale = head_dim ** -0.5 + self.use_checkpoint = use_checkpoint + + self.qnet = nn.Linear(n_embd, n_embd) + self.knet = nn.Linear(n_embd+edge_dim, n_embd) + self.vnet = nn.Linear(n_embd+edge_dim, n_embd) + self.proj_out = zero_module(nn.Linear(n_embd, n_embd)) + self.aux_edge_func = aux_edge_func + + if use_rpe_k or use_rpe_v: + assert 
use_rpe_net is not None + def make_rpe_func(): + return sRPE( + n_embd=n_embd, num_heads=num_heads, use_rpe_net=use_rpe_net, + ) + self.rpe_k = make_rpe_func() if use_rpe_k else None + self.rpe_v = make_rpe_func() if use_rpe_v else None + + def forward(self, x1, x2, attn_mask, aux_x1,aux_x2, frame_indices1, frame_indices2, edge = None,attn_weights_list=None): + out, attn = checkpoint(self._forward, (x1, x2, attn_mask, aux_x1, aux_x2, frame_indices1, frame_indices2,edge), self.parameters(), self.use_checkpoint) + if attn_weights_list is not None: + B, T1, C = x1.shape + T2 = x2.shape[1] + attn_weights_list.append(attn.detach().view(B, -1, T1, T2).mean(dim=1).abs()) # logging attn weights + return out + + def _forward(self, x1, x2, attn_mask, aux_x1, aux_x2, frame_indices1, frame_indices2,edge=None): + B, T1, C = x1.shape + T2 = x2.shape[1] + q = self.qnet(x1) + if self.edge_dim>0: + if edge is None: + if self.aux_edge_func is None: + # default to concatenation + edge = torch.cat([aux_x1.unsqueeze(2).repeat_interleave(T2,2), + aux_x2.unsqueeze(1).repeat_interleave(T1,1)],-1) # (B,T1,T2,auxdim1+auxdim2) + else: + edge = self.aux_edge_func(aux_x1,aux_x2) # (B,T1,T2,aux_vardim) + if self.knet.weight.dtype==torch.float16: + edge = edge.half() + aug_x2 = torch.cat([x2.unsqueeze(1).repeat_interleave(T1,1), edge], dim=-1) + else: + aug_x2 = x2.unsqueeze(1).repeat_interleave(T1,1) + + k = self.knet(aug_x2).reshape(B,T1, T2, self.num_heads,C//self.num_heads).permute(0, 3, 1, 2,4) #B,nh,T1,T2,C//nh + v = self.vnet(aug_x2).reshape(B,T1,T2,self.num_heads,C//self.num_heads).permute(0, 3, 1, 2,4) #B,nh,T1,T2,C//nh + + q = (q*self.scale).view(B, T1, self.num_heads, 1, C // self.num_heads).transpose(1, 2).repeat_interleave(T2,3) # (B, nh, T1, T2, hs) + attn = (q * k).sum(-1) * (1.0 / math.sqrt(k.size(-1))) + if self.rpe_k is not None or self.rpe_v is not None: + pairwise_distances = (frame_indices1.unsqueeze(-1) - frame_indices2.unsqueeze(-2)) # BxT1xT2 + # relative position on keys + if self.rpe_k is not None: + attn += self.rpe_k(q, pairwise_distances, mode="qk",P1=frame_indices1.unsqueeze(-1).expand(*pairwise_distances.shape),\ + P2=frame_indices2.unsqueeze(-2).expand(*pairwise_distances.shape)) + + # softmax where all elements with mask==0 can attend to eachother and all with mask==1 + # can attend to eachother (but elements with mask==0 can't attend to elements with mask==1) + def softmax(w, attn_mask): + if attn_mask is not None: + allowed_interactions = attn_mask + inf_mask = (1-allowed_interactions) + inf_mask[inf_mask == 1] = torch.inf + w = w - inf_mask.view(B, 1, T1, T2) # BxHxTxT + w.masked_fill_((attn_mask[:,None]==0).all(-1).unsqueeze(-1),0) + finfo = torch.finfo(w.dtype) + w = w.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + return torch.softmax(w.float(), dim=-1).type(w.dtype) + if attn_mask is None: + attn_mask = torch.ones_like(attn[:,0]) + attn = softmax(attn, attn_mask.type(attn.dtype)) + out = attn @ v if v.ndim==4 else (attn.unsqueeze(-1)*v).sum(-2) + # relative position on values + if self.rpe_v is not None: + out += self.rpe_v(attn, pairwise_distances, mode="v",P1=frame_indices1.unsqueeze(-1).expand(*pairwise_distances.shape),\ + P2=frame_indices2.unsqueeze(-2).expand(*pairwise_distances.shape)) + out = torch.einsum("BHTF -> BTHF", out).reshape(B, T1, C) + out = self.proj_out(out) + return out, attn + +def testAux(): + + net = sAuxRPEAttention(n_embd=32, num_heads=4, aux_vardim=4,use_rpe_net=True) + B,C,T = 3,32,10 + x = torch.randn(B,T,C) + aux_x = torch.randn(B,T,4) + 
frame_indices = torch.arange(T).unsqueeze(0).repeat(B,1) + out = net(x, aux_x, None, frame_indices) + +def testAuxCross(): + net = sAuxRPECrossAttention(n_embd=32, num_heads=4, edge_dim=4,use_rpe_net=True) + B,C,T1,T2 = 3,32,10,5 + x1 = torch.randn(B,T1,C) + x2 = torch.randn(B,T2,C) + aux_x1 = torch.randn(B,T1,4) + aux_x2 = torch.randn(B,T2,4) + frame_indices1 = torch.arange(T1).unsqueeze(0).repeat(B,1) + frame_indices2 = torch.arange(T2).unsqueeze(0).repeat(B,1)+T1 + out = net(x1,x2, None, aux_x1, aux_x2, frame_indices1, frame_indices2) + print("123") + +if __name__ == "__main__": + testAuxCross() \ No newline at end of file diff --git a/diffstack/models/Transformer.py b/diffstack/models/Transformer.py new file mode 100644 index 0000000..faffc5b --- /dev/null +++ b/diffstack/models/Transformer.py @@ -0,0 +1,861 @@ +from logging import raiseExceptions +import numpy as np + +import torch +import math, copy +from typing import Dict +from collections import OrderedDict + +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import diffstack.utils.tensor_utils as TensorUtils + + +def clones(module, n): + "Produce N identical layers." + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +class FactorizedEncoderDecoder(nn.Module): + """ + A encoder-decoder transformer model with Factorized encoder and decoder + """ + + def __init__(self, encoder, decoder, src_embed, tgt_embed, generator, src2posfun): + """ + Args: + encoder: FactorizedEncoder + decoder: FactorizedDecoder + src_embed: source embedding network + tgt_embed: target embedding network + generator: network used to generate output from target + src2posfun: extract positional info from the src + """ + super(FactorizedEncoderDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.generator = generator + self.src2posfun = src2posfun + + def src2pos(self, src, dyn_type): + "extract positional info from src for all datatypes, e.g., for vehicles, the first two dimensions are x and y" + + pos = torch.zeros([*src.shape[:-1], 2]).to(src.device) + for dt, fun in self.src2posfun.items(): + pos += fun(src) * (dyn_type == dt).view([*(dyn_type.shape), 1, 1]) + + return pos + + def forward( + self, + src, + tgt, + src_mask, + tgt_mask, + tgt_mask_agent, + dyn_type, + map_emb=None, + ): + "Take in and process masked src and target sequences." 
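+        # Expected shapes (inferred from the FactorizedEncoder docstring below,
+        # stated here as an assumption): src/tgt are [B, Num_agent, T, feat],
+        # the masks are [B, Num_agent, T], and dyn_type is [B, Num_agent].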
+ src_pos = self.src2pos(src, dyn_type) + "for decoders, we only use position at the last time step of the src" + return self.decode( + self.encode(src, src_mask, src_pos, map_emb), + src_mask, + tgt, + tgt_mask, + tgt_mask_agent, + src_pos[:, :, -1:], + ) + + def encode(self, src, src_mask, src_pos, map_emb): + return self.encoder(self.src_embed(src), src_mask, src_pos, map_emb) + + def decode(self, memory, src_mask, tgt, tgt_mask, tgt_mask_agent, pos): + + return self.decoder( + self.tgt_embed(tgt), + memory, + src_mask, + tgt_mask, + tgt_mask_agent, + pos, + ) + + +class DynamicGenerator(nn.Module): + "Incorporating dynamics to the generator to generate dynamically feasible output, not used yet" + + def __init__(self, d_model, dt, dyns, state2feature, feature2state): + super(DynamicGenerator, self).__init__() + self.dyns = dyns + self.proj = dict() + self.dt = dt + self.state2feature = state2feature + self.feature2state = feature2state + for dyn in self.dyns: + self.proj[dyn.type()] = nn.Linear(d_model, dyn.udim) + + def forward(self, x, tgt, type_index): + Nagent = tgt.shape[0] + tgt_next = [None] * Nagent + for dyn in self.dyns: + index = type_index[dyn.type()] + state = self.feature2state[dyn.name](tgt[index]) + input = self.proj[dyn.type()](x) + state_next = dyn.step(state, input, self.dt) + x_next_raw = self.state2feature[dyn.name](state_next) + for i in range(len(index)): + tgt_next[index[i]] = x_next_raw[i] + return torch.stack(tgt_next, dim=0) + + +class Encoder(nn.Module): + "Core encoder is a stack of N layers" + + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, mask, mask1=None): + "Pass the input (and mask) through each layer in turn." + for i, layer in enumerate(self.layers): + if i == 0: + x = layer(x, mask) + else: + if mask1 is None: + x = layer(x, mask) + else: + x = layer(x, mask1) + return self.norm(x) + + +class FactorizedEncoder(nn.Module): + def __init__(self, temporal_enc, agent_enc, temporal_pe, XY_pe, N_layer=1): + """ + Factorized encoder and agent axis + Args: + temporal_enc: encoder with attention over temporal axis + agent_enc: encoder with attention over agent axis + temporal_pe: positional encoding over time + XY_pe: positional encoding over XY coordinates + """ + super(FactorizedEncoder, self).__init__() + self.N_layer = N_layer + self.temporal_encs = clones(temporal_enc, N_layer) + self.agent_encs = clones(agent_enc, N_layer) + self.temporal_pe = temporal_pe + self.XY_pe = XY_pe + + def forward(self, x, src_mask, src_pos, map_emb): + """Pass the input (and mask) through each layer in turn. 
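+
+        The input is concatenated with its XY positional encoding, a temporal
+        positional encoding and the map embedding, then passed through N_layer
+        rounds of agent-axis attention followed by temporal-axis attention.
+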
+ Args: + x:[B,Num_agent,T,d_model] + src_mask:[B,Num_agent,T] + src_pos:[B,Num_agent,T,2] + map_emb: [B,Num_agent,1,map_emb_dim] output of the CNN ROI map encoder + Returns: + embedding of size [B,Num_agent,T,d_model] + """ + + if map_emb.ndim == 3: + map_emb = map_emb.unsqueeze(2).repeat(1, 1, x.size(2), 1) + x = ( + torch.cat( + ( + x, + self.XY_pe(x, src_pos), + self.temporal_pe(x).repeat(x.size(0), x.size(1), 1, 1), + map_emb, + ), + dim=-1, + ) + * src_mask.unsqueeze(-1) + ) + for i in range(self.N_layer): + x = self.agent_encs[i](x, src_mask) + x = self.temporal_encs[i](x, src_mask) + return x + + +class StaticEncoder(nn.Module): + def __init__(self, agent_enc, XY_pe, N_layer=1): + """ + Factorized encoder and agent axis + Args: + temporal_enc: encoder with attention over temporal axis + agent_enc: encoder with attention over agent axis + temporal_pe: positional encoding over time + XY_pe: positional encoding over XY coordinates + """ + super(StaticEncoder, self).__init__() + self.N_layer = N_layer + self.agent_encs = clones(agent_enc, N_layer) + self.XY_pe = XY_pe + + def forward(self, x, src_mask, src_pos, map_emb=None): + """Pass the input (and mask) through each layer in turn. + Args: + x:[B,Num_agent,T,d_model] + src_mask:[B,Num_agent,T] + src_pos:[B,Num_agent,T,2] + map_emb: [B,Num_agent,1,map_emb_dim] output of the CNN ROI map encoder + Returns: + embedding of size [B,Num_agent,T,d_model] + """ + inputs = [x, self.XY_pe(x, src_pos)] + if map_emb is not None: + inputs.append(map_emb) + + x = ( + torch.cat( + ( + inputs + ), + dim=-1, + ) + * src_mask.unsqueeze(-1) + ) + for i in range(self.N_layer): + x = self.agent_encs[i](x, src_mask) + return x + + +class LayerNorm(nn.Module): + "Construct a layernorm module" + + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward" + + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "self attention followed by feedforward, residual and batch norm in between layers" + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + + +class Decoder(nn.Module): + "Generic N layer decoder with masking." 
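+    # Each constituent DecoderLayer runs masked self-attention over the target,
+    # then cross-attention to the encoder memory, then a position-wise MLP.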
+ + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, memory, src_mask, tgt_mask): + "cross attention to the embedding generated by the encoder" + for layer in self.layers: + x = layer(x, memory, src_mask, tgt_mask) + return self.norm(x) + + +class SummaryModel(nn.Module): + """ + map the scene information to attributes that summarizes the scene + """ + + def __init__(self, encoder, decoder, src_embed, src2posfun): + super(SummaryModel, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.src2posfun = src2posfun + + def src2pos(self, src, dyn_type): + "extract positional info from src for all datatypes, e.g., for vehicles, the first two dimensions are x and y" + + pos = torch.zeros([*src.shape[:-1], 2]).to(src.device) + for dt, fun in self.src2posfun.items(): + pos += fun(src) * (dyn_type == dt).view([*(dyn_type.shape), 1, 1]) + + return pos + + def forward( + self, + src, + src_mask, + dyn_type, + map_emb, + ): + "Take in and process masked src and target sequences." + src_pos = self.src2pos(src, dyn_type) + return self.decode( + self.encode(src, src_mask, src_pos, map_emb), + src_mask, + ) + + def encode(self, src, src_mask, src_pos, map_emb): + return self.encoder(self.src_embed(src), src_mask, src_pos, map_emb) + + def decode(self, memory, src_mask): + return self.decoder(memory, src_mask) + + +class SummaryDecoder(nn.Module): + """ + Map the encoded tensor to a description of the whole scene, e.g., the likelihood of certain modes + """ + + def __init__( + self, temporal_attn, agent_attn, ff, emb_dim, output_dim, static=False + ): + super(SummaryDecoder, self).__init__() + self.temporal_attn = temporal_attn + self.agent_attn = agent_attn + self.ff = ff + self.output_dim = output_dim + self.static = static + self.MLP = nn.Sequential(nn.Linear(emb_dim, output_dim), nn.Sigmoid()) + + def forward(self, x, mask): + x = self.agent_attn(x, x, x, mask) + x = self.ff(torch.max(x, dim=-3)[0]).unsqueeze(1) + if not self.static: + x = self.temporal_attn(x, x, x) + x = torch.max(x, dim=-2)[0].squeeze(1) + x = self.MLP(x) + return x + + +class FactorizedDecoder(nn.Module): + """ + Args: + temporal_dec: decoder with attention over temporal axis + agent_enc: decoder with attention over agent axis + temporal_pe: positional encoding for time axis + XY_pe: positional encoding for XY axis + """ + + def __init__( + self, + temporal_dec, + agent_enc, + temporal_enc, + temporal_pe, + XY_pe, + N_layer_enc=1, + N_layer_dec=1, + ): + super(FactorizedDecoder, self).__init__() + self.temporal_dec = clones(temporal_dec, N_layer_dec) + self.agent_enc = clones(agent_enc, N_layer_enc) + self.temporal_enc = clones(temporal_enc, N_layer_enc) + self.N_layer_enc = N_layer_enc + self.N_layer_dec = N_layer_dec + self.temporal_pe = temporal_pe + self.XY_pe = XY_pe + + def forward(self, x, memory, src_mask, tgt_mask, tgt_mask_agent, pos): + """ + Pass the input (and mask) through each layer in turn. 
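+
+        The target is concatenated with an XY positional encoding (anchored at
+        the last observed source position) and a temporal positional encoding,
+        then decoded with temporal cross-attention to the encoder memory.
+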
+ Args: + x (torch.tensor)): [batch,Num_agent,T_tgt,d_model] + memory (torch.tensor): [batch,Num_agent,T_src,d_model] + src_mask (torch.tensor): [batch,Num_agent,T_src] + tgt_mask (torch.tensor): [batch,Num_agent,T_tgt] + tgt_mask_agent (torch.tensor): [batch,Num_agent,T_tgt] + pos (torch.tensor): [batch,Num_agent,1,2] + + Returns: + torch.tensor: [batch,Num_agent,T_tgt,d_model] + """ + T = x.size(-2) + tgt_pos = pos.repeat([1, 1, T, 1]) + + x = ( + torch.cat( + ( + x, + self.XY_pe(x, tgt_pos), + self.temporal_pe(x).repeat(x.size(0), x.size(1), 1, 1), + ), + dim=-1, + ) + * tgt_mask_agent.unsqueeze(-1) + ) + + for i in range(self.N_layer_dec): + x = self.temporal_dec[i](x, memory, src_mask, tgt_mask) + prob = torch.ones(x.shape[0]).to(x.device) + return x * tgt_mask_agent.unsqueeze(-1), prob + + +class MultimodalFactorizedDecoder(nn.Module): + """ + Args: + temporal_dec: decoder with attention over temporal axis + agent_enc: decoder with attention over agent axis + temporal_pe: positional encoding for time axis + XY_pe: positional encoding for XY axis + """ + + def __init__( + self, + temporal_dec, + agent_enc, + temporal_enc, + temporal_pe, + XY_pe, + M, + summary_dec, + N_layer_enc=1, + N_layer_dec=1, + ): + super(MultimodalFactorizedDecoder, self).__init__() + self.M = M + self.temporal_dec = clones(temporal_dec, N_layer_dec) + self.agent_enc = clones(agent_enc, N_layer_enc) + self.temporal_enc = clones(temporal_enc, N_layer_enc) + self.N_layer_enc = N_layer_enc + self.N_layer_dec = N_layer_dec + self.temporal_pe = temporal_pe + self.XY_pe = XY_pe + self.summary_dec = summary_dec + + def forward(self, x, memory, src_mask, tgt_mask, tgt_mask_agent, pos): + """ + Pass the input (and mask) through each layer in turn. + Args: + x (torch.tensor)): [batch,Num_agent,T_tgt,d_model] + memory (torch.tensor): [batch,Num_agent,T_src,d_model] + src_mask (torch.tensor): [batch,Num_agent,T_src] + tgt_mask (torch.tensor): [batch,Num_agent,T_tgt] + tgt_mask_agent (torch.tensor): [batch,Num_agent,T_tgt] + pos (torch.tensor): [batch,Num_agent,1,2] + + Returns: + torch.tensor: [batch,Num_agent,T_tgt,d_model] + """ + T = x.size(-2) + tgt_pos = pos.repeat([1, 1, T, 1]) + + x = ( + torch.cat( + ( + x, + self.XY_pe(x, tgt_pos), + self.temporal_pe(x).repeat(x.size(0), x.size(1), 1, 1), + ), + dim=-1, + ) + * tgt_mask_agent.unsqueeze(-1) + ) + + # adding one-hot encoding of the modes + modes_enc = ( + F.one_hot(torch.arange(0, self.M)) + .view(1, self.M, 1, 1, self.M) + .repeat(x.size(0), 1, x.size(1), x.size(2), 1) + ).to(x.device) + + x = torch.cat((x.unsqueeze(1).repeat(1, self.M, 1, 1, 1), modes_enc), dim=-1) + + memory_M = memory.unsqueeze(1).repeat(1, self.M, 1, 1, 1) + src_mask_M = src_mask.unsqueeze(1).repeat(1, self.M, 1, 1) + tgt_mask_M = tgt_mask.unsqueeze(1).repeat(1, self.M, 1, 1, 1) + tgt_mask_agent_M = tgt_mask_agent.unsqueeze(1).repeat(1, self.M, 1, 1) + for i in range(self.N_layer_enc): + x = self.agent_enc[i](x, tgt_mask_agent_M) + x = self.temporal_enc[i](x, tgt_mask_M) + for i in range(self.N_layer_dec): + x = self.temporal_dec[i]( + x, + memory_M, + src_mask_M, + tgt_mask_M, + ) + + prob = self.summary_dec(x, tgt_mask_agent_M).squeeze(-1) + prob = F.softmax(prob, dim=-1) + return x * tgt_mask_agent_M.unsqueeze(-1), prob + + +class DecoderLayer(nn.Module): + "Decoder is made of self-attn, src-attn, and feed forward (defined below)" + + def __init__(self, size, self_attn, src_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + 
self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 3) + + def forward(self, x, memory, src_mask, tgt_mask): + "self attention followed by cross attention with the encoder output" + + m = memory + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) + x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) + return self.sublayer[2](x, self.feed_forward) + + +def subsequent_mask(size): + "Mask out subsequent positions." + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") + return torch.from_numpy(subsequent_mask) == 0 + + +def attention(query, key, value, mask=None, dropout=None): + "Compute 'Scaled Dot Product Attention'" + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + + p_attn = F.softmax(scores, dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1, pooling_dim=None): + "Take in model size and number of heads." + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.pooling_dim = pooling_dim + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + "Implements Figure 2" + if self.pooling_dim is None: + pooling_dim = -2 + else: + pooling_dim = self.pooling_dim + + if mask is not None: + # Same mask applied to all h heads. + if mask.ndim == query.ndim - 1: + mask = mask.view([*mask.shape, 1, 1]).transpose(-1, pooling_dim - 1) + elif mask.ndim == query.ndim: + mask = mask.unsqueeze(-2).transpose(-2, pooling_dim - 1) + else: + raise Exception("mask dimension mismatch") + + # 1) Do all the linear projections in batch from d_model => h x d_k + + query, key, value = [ + l(x).view(*x.shape[:-1], self.h, self.d_k) + for l, x in zip(self.linears, (query, key, value)) + ] + + # 2) Apply attention on all the projected vectors in batch. + x, self.attn = attention( + query.transpose(-2, pooling_dim - 1), + key.transpose(-2, pooling_dim - 1), + value.transpose(-2, pooling_dim - 1), + mask, + dropout=self.dropout, + ) + + x = x.transpose(-2, pooling_dim - 1).contiguous() + x = x.view(*x.shape[:-2], self.h * self.d_k) + + # 3) "Concat" using a view and apply a final linear. + return self.linears[-1](x) + + +class PositionwiseFeedForward(nn.Module): + "Implements FFN equation." + + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class PositionalEncoding(nn.Module): + "Implement the PE function." + + def __init__(self, dim, dropout, max_len=5000, flipped=False): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + self.flipped = flipped + + # Compute the positional encodings once in log space. 
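+        #   pe[pos, 2i]   = sin(pos / 10000^(2i / dim))
+        #   pe[pos, 2i+1] = cos(pos / 10000^(2i / dim))
+        # With flipped=True the positions run ..., -2, -1, 0, so the most
+        # recent timestep receives phase zero.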
+ pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + if self.flipped: + position = -position.flip(dims=[0]) + div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + pe_shape = [1] * (x.ndim - 2) + list(x.shape[-2:-1]) + [self.dim] + if self.flipped: + return self.dropout( + Variable(self.pe[:, -x.size(-2) :].view(pe_shape), requires_grad=False) + ) + else: + return self.dropout( + Variable(self.pe[:, : x.size(-2)].view(pe_shape), requires_grad=False) + ) + + +class PositionalEncodingNd(nn.Module): + "extension of the PE function, works for N dimensional position input" + + def __init__(self, dim, dropout, step_size=[1]): + """ + step_size: scale of each dimension, pos/step_size = phase for the sinusoidal PE + """ + super(PositionalEncodingNd, self).__init__() + assert dim % 2 == 0 + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + self.step_size = step_size + self.D = len(step_size) + self.pe = list() + + # Compute the positional encodings once in log space. + self.div_term = torch.exp(torch.arange(0, dim, 2) * -(math.log(10000.0) / dim)) + + def forward(self, x, pos): + rep_size = [1] * (x.ndim) + rep_size[-1] = int(self.dim / 2) + pe_shape = [*x.shape[:-1], self.dim] + for i in range(self.D): + pe = torch.zeros(pe_shape).to(x.device) + + pe[..., 0::2] = torch.sin( + pos[..., i : i + 1].repeat(*rep_size) + / self.step_size[i] + * self.div_term.to(x.device) + ) + pe[..., 1::2] = torch.sin( + pos[..., i : i + 1].repeat(*rep_size) + / self.step_size[i] + * self.div_term.to(x.device) + ) + return self.dropout(Variable(pe, requires_grad=False)) + + +def make_transformer_model( + src_dim, + tgt_dim, + out_dim, + dyn_list, + N_t=6, + N_a=3, + d_model=384, + XY_pe_dim=64, + temporal_pe_dim=64, + map_emb_dim=128, + d_ff=2048, + head=8, + dropout=0.1, + step_size=[0.1, 0.1], + N_layer_enc=1, + N_layer_tgt_enc=1, + N_layer_tgt_dec=1, + M=1, + use_GAN=False, + GAN_static=True, + N_layer_enc_discr=1, +): + "first generate the building blocks, attn networks, encoders, decoders, PEs and Feedforward nets" + c = copy.deepcopy + temporal_attn = MultiHeadedAttention(head, d_model) + agent_attn = MultiHeadedAttention(head, d_model, pooling_dim=-3) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + temporal_pe = PositionalEncoding(temporal_pe_dim, dropout) + temporal_pe_flip = PositionalEncoding(temporal_pe_dim, dropout, flipped=True) + XY_pe = PositionalEncodingNd(XY_pe_dim, dropout, step_size=step_size) + temporal_enc = Encoder(EncoderLayer(d_model, c(temporal_attn), c(ff), dropout), N_t) + agent_enc = Encoder(EncoderLayer(d_model, c(agent_attn), c(ff), dropout), N_a) + + src_emb = nn.Linear(src_dim, d_model - XY_pe_dim - temporal_pe_dim - map_emb_dim) + if M == 1: + tgt_emb = nn.Linear(tgt_dim, d_model - XY_pe_dim - temporal_pe_dim) + else: + tgt_emb = nn.Linear(tgt_dim, d_model - XY_pe_dim - temporal_pe_dim - M) + generator = nn.Linear(d_model, out_dim) + + temporal_dec = Decoder( + DecoderLayer(d_model, c(temporal_attn), c(temporal_attn), c(ff), dropout), N_t + ) + "gather src2posfun from all agent types" + src2posfun = {D.type(): D.state2pos for D in dyn_list} + + Factorized_Encoder = FactorizedEncoder( + c(temporal_enc), c(agent_enc), temporal_pe_flip, XY_pe, N_layer_enc + ) + if M == 1: + Factorized_Decoder = FactorizedDecoder( + c(temporal_dec), + 
c(agent_enc), + c(temporal_enc), + temporal_pe, + XY_pe, + N_layer_tgt_enc, + N_layer_tgt_dec, + ) + else: + mode_summary_dec = SummaryDecoder( + c(temporal_attn), c(agent_attn), c(ff), d_model, 1 + ) + Factorized_Decoder = MultimodalFactorizedDecoder( + temporal_dec, + agent_enc, + temporal_enc, + temporal_pe, + XY_pe, + M, + mode_summary_dec, + N_layer_enc=1, + N_layer_dec=1, + ) + Factorized_Encoder = FactorizedEncoder( + c(temporal_enc), c(agent_enc), temporal_pe_flip, XY_pe, N_layer_enc + ) + if use_GAN: + if GAN_static: + Summary_Encoder = StaticEncoder( + c(agent_enc), + XY_pe, + N_layer_enc_discr, + ) + Summary_Decoder = SummaryDecoder( + c(temporal_attn), c(agent_attn), c(ff), d_model, 1, static=True + ) + static_src_emb = nn.Linear(src_dim, d_model - XY_pe_dim - map_emb_dim) + Summary_Model = SummaryModel( + Summary_Encoder, + Summary_Decoder, + c(static_src_emb), + src2posfun, + ) + else: + Summary_Encoder = Summary_Encoder = FactorizedEncoder( + c(temporal_enc), + c(agent_enc), + temporal_pe_flip, + XY_pe, + N_layer_enc_discr, + ) + Summary_Decoder = SummaryDecoder( + c(temporal_attn), c(agent_attn), c(ff), d_model, 1, static=True + ) + Summary_Model = SummaryModel( + Summary_Encoder, + Summary_Decoder, + c(src_emb), + src2posfun, + ) + + else: + Summary_Model = None + "use a simple nn.Linear as the generator as our output is continuous" + + Transformer_Model = FactorizedEncoderDecoder( + Factorized_Encoder, + Factorized_Decoder, + c(src_emb), + c(tgt_emb), + c(generator), + src2posfun, + ) + + return Transformer_Model, Summary_Model + + +class SimpleTransformer(nn.Module): + def __init__( + self, + src_dim, + N_a=3, + d_model=384, + XY_pe_dim=64, + d_ff=2048, + head=8, + dropout=0.1, + step_size=[0.1, 0.1], + ): + super(SimpleTransformer, self).__init__() + c = copy.deepcopy + agent_attn = MultiHeadedAttention(head, d_model, pooling_dim=-3) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + XY_pe = PositionalEncodingNd(XY_pe_dim, dropout, step_size=step_size) + self.agent_enc = StaticEncoder(EncoderLayer(d_model, c(agent_attn), c(ff), dropout), XY_pe, N_a) + self.pre_emb = nn.Linear(src_dim, d_model - XY_pe_dim) + self.post_emb = nn.Linear(d_model, src_dim) + + def forward(self, feats, avails, pos): + x = self.pre_emb(feats) + x = self.agent_enc(x, avails, pos) + return self.post_emb(x) + + +class simplelinear(nn.Module): + def __init__(self, input_dim, output_dim, hidden_dim=[64, 32]): + super(simplelinear, self).__init__() + self.hidden_dim = hidden_dim + self.hidden_layers = len(hidden_dim) + self.fhidden = nn.ModuleList() + + for i in range(1, self.hidden_layers): + self.fhidden.append(nn.Linear(hidden_dim[i - 1], hidden_dim[i])) + + self.f1 = nn.Linear(input_dim, hidden_dim[0]) + self.f2 = nn.Linear(hidden_dim[-1], output_dim) + + def forward(self, x): + hidden = self.f1(x) + for i in range(1, self.hidden_layers): + hidden = self.fhidden[i - 1](F.relu(hidden)) + return self.f2(F.relu(hidden)) diff --git a/diffstack/models/TypeTransformer.py b/diffstack/models/TypeTransformer.py new file mode 100644 index 0000000..d517f4f --- /dev/null +++ b/diffstack/models/TypeTransformer.py @@ -0,0 +1,791 @@ +""" +Transformer that accepts multiple types with separate attention +Adopted from the minGPT project https://github.com/karpathy/minGPT + + +""" + +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ----------------------------------------------------------------------------- + + +class NewGELU(nn.Module): + """ + Implementation of 
the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 + """ + + def forward(self, x): + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class SelfAttention(nn.Module): + """ + A vanilla multi-head self-attention layer with a projection at the end. + """ + + def __init__(self, n_embd, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(n_embd, 3 * n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(attn_pdrop) + self.attn_pdrop = attn_pdrop + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x, mask=None): + # mask: (B,T,T) + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class AuxSelfAttention(nn.Module): + """ + A vanilla multi-head self-attention layer with a projection at the end. 
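+
+    Unlike the vanilla version, keys and values are additionally conditioned on
+    a per-token auxiliary input (aux_x) and, when edge_dim > 0, on pairwise
+    edge features computed by aux_edge_func (or passed in via ``edge``).
+
+    Illustrative usage (shapes only; the edge function here is a made-up
+    example, not part of the model):
+
+        attn = AuxSelfAttention(n_embd=64, n_head=4, edge_dim=4,
+                                aux_edge_func=lambda a, b: a.unsqueeze(2) - b.unsqueeze(1))
+        x, aux_x = torch.randn(2, 10, 64), torch.randn(2, 10, 4)
+        y = attn(x, aux_x, mask=None)   # -> (2, 10, 64)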
+ """ + + def __init__( + self, + n_embd, + n_head, + edge_dim, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + PE_len=None, + ): + super().__init__() + assert n_embd % n_head == 0 + # assert edge_dim % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.q_net = nn.Linear(n_embd, n_embd) + self.k_net = nn.Linear(n_embd + edge_dim, n_embd) + self.v_net = nn.Linear(n_embd + edge_dim, n_embd) + self.aux_edge_func = aux_edge_func + + self.PE_len = PE_len + if PE_len is not None: + self.PE_q = nn.Embedding(PE_len, n_embd) + self.PE_k = nn.Embedding(PE_len, n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(attn_pdrop) + self.attn_pdrop = attn_pdrop + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x, aux_x, mask=None, edge=None, frame_indices=None): + # mask: (B,T,T) + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.q_net(x) + if edge is not None or self.aux_edge_func is not None: + T = aux_x.size(1) + if edge is None: + edge = self.aux_edge_func(aux_x, aux_x) # (B,T,T,aux_vardim) + if self.k_net.weight.dtype == torch.float16: + edge = edge.half() + aug_x = torch.cat([x.unsqueeze(2).repeat_interleave(T, 2), edge], dim=-1) + k = self.k_net(aug_x) + v = self.v_net(aug_x) + k = k.view(B, T, T, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T, T, hs) + q = ( + q.view(B, T, self.n_head, 1, C // self.n_head) + .transpose(1, 2) + .repeat_interleave(T, 3) + ) # (B, nh, T, T, hs) + v = v.view(B, T, T, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T, T, hs) + if self.PE_len is not None: + q = q + self.PE_q(frame_indices).view( + B, T, self.n_head, 1, C // self.n_head + ).transpose(1, 2).repeat_interleave(T, 3) + k = k + self.PE_k(frame_indices).view( + B, T, self.n_head, C // self.n_head + ).transpose(1, 2).unsqueeze(2).repeat_interleave(T, 2) + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q * k).sum(-1) * (1.0 / math.sqrt(k.size(-1))) + else: + aug_x = torch.cat([x, aux_x], dim=-1) + k = self.k_net(aug_x) + v = self.v_net(aug_x) + k = k.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + if self.PE_len is not None: + q = q + self.PE_q(frame_indices).view( + B, T, self.n_head, C // self.n_head + ).transpose(1, 2) + k = k + self.PE_k(frame_indices).view( + B, T, self.n_head, C // self.n_head + ).transpose(1, 2) + # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + # if v.ndim==4, (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) elif v.ndim==5, (B, nh, T, T) x (B, nh, T, T, hs) -> (B, nh, T, hs) + y = att @ v if v.ndim == 4 else (att.unsqueeze(-1) * v).sum(-2) + + y = ( + 
y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class AuxCrossAttention(nn.Module): + """ + A vanilla multi-head self-attention layer with a projection at the end. + """ + + def __init__( + self, + n_embd, + n_head, + edge_dim, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + PE_len=None, + ): + super().__init__() + assert n_embd % n_head == 0 + # assert edge_dim % n_head == 0 + self.edge_dim = edge_dim + # key, query, value projections for all heads, but in a batch + self.q_net = nn.Linear(n_embd, n_embd) + self.k_net = nn.Linear(n_embd + edge_dim, n_embd) + self.v_net = nn.Linear(n_embd + edge_dim, n_embd) + self.PE_len = PE_len + if PE_len is not None: + self.PE_q = nn.Embedding(PE_len, n_embd) + self.PE_k = nn.Embedding(PE_len, n_embd) + self.aux_edge_func = aux_edge_func + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(attn_pdrop) + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward( + self, + x1, + x2, + mask, + aux_x1, + aux_x2, + frame_indices1=None, + frame_indices2=None, + edge=None, + ): + # mask: (B,T,T) + ( + B, + T1, + C, + ) = x1.size() # batch size, sequence length, embedding dimensionality (n_embd) + T2 = x2.size(1) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.q_net(x1) + if self.edge_dim > 0: + if edge is None: + if self.aux_edge_func is None: + # default to concatenation + edge = torch.cat( + [ + aux_x1.unsqueeze(2).repeat_interleave(T2, 2), + aux_x2.unsqueeze(1).repeat_interleave(T1, 1), + ], + -1, + ) # (B,T1,T2,auxdim1+auxdim2) + else: + edge = self.aux_edge_func(aux_x1, aux_x2) # (B,T1,T2,aux_vardim) + if self.k_net.weight.dtype == torch.float16: + edge = edge.half() + aug_x2 = torch.cat([x2.unsqueeze(1).repeat_interleave(T1, 1), edge], dim=-1) + else: + aug_x2 = x2.unsqueeze(1).repeat_interleave(T1, 1) + + k = self.k_net(aug_x2) + v = self.v_net(aug_x2) + k = k.view(B, T1, T2, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T1, T2, hs) + q = ( + q.view(B, T1, self.n_head, 1, C // self.n_head) + .transpose(1, 2) + .repeat_interleave(T2, 3) + ) # (B, nh, T1, T2, hs) + v = v.view(B, T1, T2, self.n_head, C // self.n_head).permute( + 0, 3, 1, 2, 4 + ) # (B, nh, T1, T2, hs) + if self.PE_len is not None: + q = q + self.PE_q(frame_indices1).view( + B, T1, self.n_head, 1, C // self.n_head + ).transpose(1, 2).repeat_interleave(T2, 3) + k = k + self.PE_k(frame_indices2).view( + B, T2, self.n_head, C // self.n_head + ).transpose(1, 2).unsqueeze(2).repeat_interleave(T1, 2) + # # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q * k).sum(-1) * (1.0 / math.sqrt(k.size(-1))) + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + # if v.ndim==4, (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) elif v.ndim==5, (B, nh, T, T) x (B, nh, T, T, hs) -> (B, nh, T, hs) + y = att @ v if v.ndim == 4 else (att.unsqueeze(-1) * v).sum(-2) + y = ( + y.transpose(1, 2).contiguous().view(B, T1, C) + ) # re-assemble all head outputs side by side + + # output 
projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class CrossAttention(nn.Module): + """ + A vanilla multi-head cross-attention layer with a projection at the end. + """ + + def __init__(self, n_embd, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_q = nn.Linear(n_embd, n_embd) + self.c_k = nn.Linear(n_embd, n_embd) + self.c_v = nn.Linear(n_embd, n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_pdrop = attn_pdrop + self.attn_dropout = nn.Dropout(attn_pdrop) + self.resid_dropout = nn.Dropout(resid_pdrop) + + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x, z, mask=None): + # x: (B, T_in, C), query + # z: (B, T_m, C), memory + # mask: (B, T_in, T_m), mask for memory + ( + B, + T_in, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + T_m = z.size(1) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.c_q(x) + k, v = self.c_k(z), self.c_v(z) + + k = k.view(B, T_m, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T_m, hs) + q = q.view(B, T_in, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T_in, hs) + v = v.view(B, T_m, self.n_head, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T_m, hs) + + # cross-attention; Cross-attend: (B, nh, T_in, hs) x (B, nh, hs, T_m) -> (B, nh, T_in, T_m) + + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if mask is not None: + att = att.masked_fill(mask[:, None] == 0, float("-inf")) + att = att.masked_fill((mask == 0).all(-1)[:, None, :, None], 0.0) + finfo = torch.finfo(att.dtype) + att = att.nan_to_num(nan=0.0, posinf=finfo.max, neginf=finfo.min) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T_in, T_m) x (B, nh, T_m, hs) -> (B, nh, T_in, hs) + y = ( + y.transpose(1, 2).contiguous().view(B, T_in, C) + ) # re-assemble all head outputs side by side + + y = self.resid_dropout(self.c_proj(y)) + + return y + + +class TypeSelfAttention(nn.Module): + """a composite attention block supporting multiple attention types""" + + def __init__( + self, + ntype, + n_embd, + n_head, + edge_dim=0, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + ): + super().__init__() + self.ntype = ntype + self.edge_dim = edge_dim + if edge_dim > 0: + self.attn = nn.ModuleList( + [ + AuxSelfAttention( + n_embd, n_head, edge_dim, aux_edge_func, attn_pdrop, resid_pdrop + ) + for _ in range(self.ntype) + ] + ) + else: + self.attn = nn.ModuleList( + [ + SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + + def forward(self, x, aux_x, masks, edge=None): + if self.ntype == 1 and isinstance(masks, torch.Tensor): + masks = [masks] + assert len(masks) == self.ntype + if edge is not None: + if isinstance(edge, list): + assert len(edge) == self.ntype + elif isinstance(edge, torch.Tensor): + edge = [edge] * self.ntype + else: + edge = [None] * self.ntype + resid = torch.zeros_like(x) + if self.edge_dim > 0: + assert aux_x is not None + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x, aux_x, mask, edge=edge[i]) + else: + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x, mask, edge=edge[i]) + return resid + + +class TypeCrossAttention(nn.Module): + """a composite attention block supporting multiple attention types""" + + def __init__( + self, 
+ ntype, + n_embd, + n_head, + edge_dim, + aux_edge_func=None, + attn_pdrop=0, + resid_pdrop=0, + ): + super().__init__() + self.ntype = ntype + self.edge_dim = edge_dim + if edge_dim > 0: + self.attn = nn.ModuleList( + [ + AuxCrossAttention( + n_embd, n_head, edge_dim, aux_edge_func, attn_pdrop, resid_pdrop + ) + for _ in range(self.ntype) + ] + ) + else: + self.attn = nn.ModuleList( + [ + CrossAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + + def forward(self, x1, x2, masks, aux_x1=None, aux_x2=None, edge=None): + if self.ntype == 1 and isinstance(masks, torch.Tensor): + masks = [masks] + assert len(masks) == self.ntype + if edge is not None: + if isinstance(edge, list): + assert len(edge) == self.ntype + elif isinstance(edge, torch.Tensor): + edge = [edge] * self.ntype + else: + edge = [None] * self.ntype + resid = torch.zeros_like(x1) + if self.edge_dim > 0: + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x1, x2, mask, aux_x1, aux_x2, edge=edge[i]) + else: + for i, mask in enumerate(masks): + resid = resid + self.attn[i](x1, x2, mask, edge=edge[i]) + return resid + + +class EncBlock(nn.Module): + """an unassuming Transformer encoder block""" + + def __init__(self, n_embd, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.attn = SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, mask=None): + x = x + self.attn(self.ln_1(x), mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class DecBlock(nn.Module): + """an unassuming Transformer decoder block""" + + def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.attn = CrossAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, z, mask=None): + x = x + self.attn(self.ln_1(x), z, mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class TypeEncBlock(nn.Module): + """an unassuming Transformer encoder block supporting multiple attention types""" + + def __init__(self, n_embd, ntype, n_head, attn_pdrop=0, resid_pdrop=0): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.ntype = ntype + self.attn = nn.ModuleList( + [ + SelfAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, masks): + assert len(masks) == self.ntype + for i, mask in enumerate(masks): + x = x + self.attn[i](self.ln_1(x), mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class TypeDecBlock(nn.Module): + """an unassuming Transformer decoder block supporting multiple attention types""" + + 
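# Editorial note (not part of the original patch): like DecBlock, this block keeps one
# CrossAttention module per attention "type" and applies them in sequence, each with its
# own mask, so different relation types can use different attention weights while sharing
# the same residual stream and feed-forward MLP.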
def __init__(self, n_embd, ntype, n_head, attn_pdrop, resid_pdrop): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.ntype = ntype + self.attn = nn.ModuleList( + [ + CrossAttention(n_embd, n_head, attn_pdrop, resid_pdrop) + for _ in range(self.ntype) + ] + ) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(n_embd, 4 * n_embd), + c_proj=nn.Linear(4 * n_embd, n_embd), + act=NewGELU(), + dropout=nn.Dropout(resid_pdrop), + ) + ) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x, z, masks): + assert len(masks) == self.ntype + for i, mask in enumerate(masks): + x = x + self.attn[i](self.ln_1(x), z, mask) + x = x + self.mlpf(self.ln_2(x)) + return x + + +class TransformerEncoder(nn.Module): + def __init__( + self, + n_embd, + ntype, + n_head, + attn_pdrop, + resid_pdrop, + nblock, + indim=None, + outdim=None, + ): + super().__init__() + self.nblock = nblock + self.indim = indim + self.outdim = outdim + self.in_embd = ( + nn.Linear(self.indim, n_embd) + if self.indim is not None + else torch.nn.Identity() + ) + self.out_proj = ( + nn.Linear(n_embd, self.outdim) + if self.outdim is not None + else torch.nn.Identity() + ) + self.ntype = ntype + if self.ntype == 1: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): EncBlock( + n_embd, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + else: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): TypeEncBlock( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + + def forward(self, x, mask=None): + x = self.in_embd(x) + for i in range(self.nblock): + x = self.blocks["block_{}".format(i)](x, mask) + x = self.out_proj(x) + return x + + +class TransformerDecoder(nn.Module): + def __init__( + self, + n_embd, + ntype, + n_head, + attn_pdrop, + resid_pdrop, + nblock, + indim=None, + outdim=None, + ): + super().__init__() + self.nblock = nblock + self.indim = indim + self.outdim = outdim + self.in_embd = ( + nn.Linear(self.indim, n_embd) + if self.indim is not None + else torch.nn.Identity() + ) + self.out_proj = ( + nn.Linear(n_embd, self.outdim) + if self.outdim is not None + else torch.nn.Identity() + ) + self.ntype = ntype + if self.ntype == 1: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): DecBlock( + n_embd, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + else: + self.blocks = nn.ModuleDict( + { + "block_{}".format(i): TypeDecBlock( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop + ) + for i in range(self.nblock) + } + ) + + def forward(self, x, z, mask=None): + x = self.in_embd(x) + for i in range(self.nblock): + x = self.blocks["block_{}".format(i)](x, z, mask) + x = self.out_proj(x) + return x + + +def test(): + n_embd = 256 + ntype = 3 + n_head = 8 + attn_pdrop = 0.1 + resid_pdrop = 0.1 + enc = TransformerEncoder( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop, nblock=2, indim=15 + ).cuda() + dec = TransformerDecoder( + n_embd, ntype, n_head, attn_pdrop, resid_pdrop, nblock=2, indim=15, outdim=40 + ).cuda() + b = 2 + T_in = 10 + T_m = 20 + xin = torch.randn(b, T_in, 15).cuda() + + type = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2]).cuda() + type_flag = [(type == i)[None].repeat_interleave(b, 0) for i in range(3)] + enc_masks = [fg[..., None] * fg[..., None, :] for fg in type_flag] + + mask1 = torch.ones(2, T_in, T_m).bool().cuda() + mask2 = torch.tril(mask1) + mask3 = torch.triu(mask1) + dec_masks = [mask1, mask2, 
mask3] + mem = enc(xin, mask=enc_masks) + mem = torch.repeat_interleave(mem, 2, 1) + x1 = dec(xin, mem, [mask1, mask2, mask3]) + print("done") + + +def test_edge_func(): + n_embd = 256 + n_head = 8 + aux_vardim = 16 + aux_edge_func = lambda x, y: x.unsqueeze(2) * y.unsqueeze(1) + + attn = AuxCrossAttention(n_embd, n_head, aux_vardim, aux_edge_func) + b = 2 + N1 = 3 + N2 = 4 + x1 = torch.randn([b, N1, n_embd]) + aux_x1 = torch.randn([b, N1, aux_vardim]) + + x2 = torch.randn([b, N2, n_embd]) + aux_x2 = torch.randn([b, N2, aux_vardim]) + xx = attn(x1, x2, None, aux_x1, aux_x2) + print("123") + + +if __name__ == "__main__": + test_edge_func() diff --git a/__init__.py b/diffstack/models/__init__.py similarity index 100% rename from __init__.py rename to diffstack/models/__init__.py diff --git a/diffstack/models/base_models.py b/diffstack/models/base_models.py new file mode 100644 index 0000000..ce0b671 --- /dev/null +++ b/diffstack/models/base_models.py @@ -0,0 +1,1688 @@ +import numpy as np +import math +import textwrap +from collections import OrderedDict +from typing import Dict, Union, List +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.models.resnet import resnet18, resnet50 +from torchvision.models.feature_extraction import create_feature_extractor +from torchvision.ops import RoIAlign +from diffstack.models.Transformer import SimpleTransformer + +from diffstack.utils.tensor_utils import reshape_dimensions, flatten +import diffstack.utils.tensor_utils as TensorUtils +import diffstack.dynamics as dynamics + + +class MLP(nn.Module): + """ + Base class for simple Multi-Layer Perceptrons. + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + layer_dims: tuple = (), + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ): + """ + Args: + input_dim (int): dimension of inputs + output_dim (int): dimension of outputs + layer_dims ([int]): sequence of integers for the hidden layers sizes + layer_func: mapping per layer - defaults to Linear + layer_func_kwargs (dict): kwargs for @layer_func + activation: non-linearity per layer - defaults to ReLU + dropouts ([float]): if not None, adds dropout layers with the corresponding probabilities + after every layer. Must be same size as @layer_dims. 
+ normalization (bool): if True, apply layer normalization after each layer + output_activation: if provided, applies the provided non-linearity to the output layer + """ + super(MLP, self).__init__() + layers = [] + dim = input_dim + if layer_func_kwargs is None: + layer_func_kwargs = dict() + if dropouts is not None: + assert len(dropouts) == len(layer_dims) + for i, l in enumerate(layer_dims): + layers.append(layer_func(dim, l, **layer_func_kwargs)) + if normalization: + layers.append(nn.LayerNorm(l)) + layers.append(activation()) + if dropouts is not None and dropouts[i] > 0.0: + layers.append(nn.Dropout(dropouts[i])) + dim = l + layers.append(layer_func(dim, output_dim)) + if output_activation is not None: + if isinstance(output_activation, nn.Module): + layers.append(output_activation) + else: + layers.append(output_activation()) + self._layer_func = layer_func + self.nets = layers + self._model = nn.Sequential(*layers) + + self._layer_dims = layer_dims + self._input_dim = input_dim + self._output_dim = output_dim + self._dropouts = dropouts + self._act = activation + self._output_act = output_activation + + def output_shape(self, input_shape=None): + """ + Function to compute output shape from inputs to this module. + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + return [self._output_dim] + + def forward(self, inputs): + """ + Forward pass. + """ + return self._model(inputs) + + def __repr__(self): + """Pretty print network.""" + header = str(self.__class__.__name__) + act = None if self._act is None else self._act.__name__ + output_act = None if self._output_act is None else self._output_act.__name__ + + indent = " " * 4 + msg = "input_dim={}\noutput_shape={}\nlayer_dims={}\nlayer_func={}\ndropout={}\nact={}\noutput_act={}".format( + self._input_dim, + self.output_shape(), + self._layer_dims, + self._layer_func.__name__, + self._dropouts, + act, + output_act, + ) + msg = textwrap.indent(msg, indent) + msg = header + "(\n" + msg + "\n)" + return msg + + +class SplitMLP(MLP): + """ + A multi-output MLP network: The model split and reshapes the output layer to the desired output shapes + """ + + def __init__( + self, + input_dim: int, + output_shapes: OrderedDict, + layer_dims: tuple = (), + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ): + """ + Args: + input_dim (int): dimension of inputs + output_shapes (dict): named dictionary of output shapes + layer_dims ([int]): sequence of integers for the hidden layers sizes + layer_func: mapping per layer - defaults to Linear + layer_func_kwargs (dict): kwargs for @layer_func + activation: non-linearity per layer - defaults to ReLU + dropouts ([float]): if not None, adds dropout layers with the corresponding probabilities + after every layer. Must be same size as @layer_dims. 
+ normalization (bool): if True, apply layer normalization after each layer + output_activation: if provided, applies the provided non-linearity to the output layer + """ + + assert isinstance(output_shapes, OrderedDict) + output_dim = 0 + for v in output_shapes.values(): + output_dim += np.prod(v) + self._output_shapes = output_shapes + + super(SplitMLP, self).__init__( + input_dim=input_dim, + output_dim=output_dim, + layer_dims=layer_dims, + layer_func=layer_func, + layer_func_kwargs=layer_func_kwargs, + activation=activation, + dropouts=dropouts, + normalization=normalization, + output_activation=output_activation, + ) + + def output_shape(self, input_shape=None): + return self._output_shapes + + def forward(self, inputs): + outs = super(SplitMLP, self).forward(inputs) + out_dict = dict() + ind = 0 + for k, v in self._output_shapes.items(): + v_dim = int(np.prod(v)) + out_dict[k] = reshape_dimensions( + outs[:, ind : ind + v_dim], begin_axis=1, end_axis=2, target_dims=v + ) + ind += v_dim + return out_dict + + +class MIMOMLP(SplitMLP): + """ + A multi-input, multi-output MLP: The model flattens and concatenate the input before feeding into an MLP + """ + + def __init__( + self, + input_shapes: OrderedDict, + output_shapes: OrderedDict, + layer_dims: tuple = (), + layer_func=nn.Linear, + layer_func_kwargs=None, + activation=nn.ReLU, + dropouts=None, + normalization=False, + output_activation=None, + ): + """ + Args: + input_shapes (OrderedDict): named dictionary of input shapes + output_shapes (OrderedDict): named dictionary of output shapes + layer_dims ([int]): sequence of integers for the hidden layers sizes + layer_func: mapping per layer - defaults to Linear + layer_func_kwargs (dict): kwargs for @layer_func + activation: non-linearity per layer - defaults to ReLU + dropouts ([float]): if not None, adds dropout layers with the corresponding probabilities + after every layer. Must be same size as @layer_dims. + normalization (bool): if True, apply layer normalization after each layer + output_activation: if provided, applies the provided non-linearity to the output layer + """ + assert isinstance(input_shapes, OrderedDict) + input_dim = 0 + for v in input_shapes.values(): + input_dim += np.prod(v) + + self._input_shapes = input_shapes + + super(MIMOMLP, self).__init__( + input_dim=input_dim, + output_shapes=output_shapes, + layer_dims=layer_dims, + layer_func=layer_func, + layer_func_kwargs=layer_func_kwargs, + activation=activation, + dropouts=dropouts, + normalization=normalization, + output_activation=output_activation, + ) + + def forward(self, inputs): + flat_inputs = [] + for k in self._input_shapes.keys(): + flat_inputs.append(flatten(inputs[k])) + return super(MIMOMLP, self).forward(torch.cat(flat_inputs, dim=1)) + + +class DoubleConv(nn.Module): + """(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + # (Mohit): argh... forgot to remove this batchnorm + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + # (Mohit): argh... 
forgot to remove this batchnorm + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) + + def forward(self, x): + return self.double_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d( + in_channels, in_channels // 2, kernel_size=2, stride=2 + ) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + """ + U-Net forward by concatenating input feature (x1) with mirroring encoding feature maps channel-wise (x2) + Args: + x1 (torch.Tensor): [B, C1, H1, W1] + x2 (torch.Tensor): [B, C2, H2, W2] + + Returns: + output (torch.Tensor): [B, out_channels, H2, W2] + """ + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class IdentityBlock(nn.Module): + def __init__( + self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True + ): + super(IdentityBlock, self).__init__() + self.final_relu = final_relu + self.batchnorm = batchnorm + + filters1, filters2, filters3 = filters + self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() + self.conv2 = nn.Conv2d( + filters1, + filters2, + kernel_size=kernel_size, + dilation=1, + stride=stride, + padding=1, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() + self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += x + if self.final_relu: + out = F.relu(out) + return out + + +class ConvBlock(nn.Module): + def __init__( + self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True + ): + super(ConvBlock, self).__init__() + self.final_relu = final_relu + self.batchnorm = batchnorm + + filters1, filters2, filters3 = filters + self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() + self.conv2 = nn.Conv2d( + filters1, + filters2, + kernel_size=kernel_size, + dilation=1, + stride=stride, + padding=1, + bias=False, + ) + self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() + self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() + + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, filters3, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity(), + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = 
F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + if self.final_relu: + out = F.relu(out) + return out + + +class UNetDecoder(nn.Module): + """UNet part based on https://github.com/milesial/Pytorch-UNet/tree/master/unet""" + + def __init__( + self, + input_channel, + output_channel, + encoder_channels, + up_factor, + bilinear=True, + batchnorm=True, + ): + super(UNetDecoder, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d( + input_channel, 1024, kernel_size=3, stride=1, padding=1, bias=False + ), + nn.ReLU(True), + ) + + self.up1 = Up(1024 + encoder_channels[-1], 512, bilinear=True) + + self.up2 = Up(512 + encoder_channels[-2], 512 // up_factor, bilinear) + + self.up3 = Up(256 + encoder_channels[-3], 256 // up_factor, bilinear) + + self.layer1 = nn.Sequential( + ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=batchnorm), + IdentityBlock( + 64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ) + + self.layer2 = nn.Sequential( + ConvBlock(64, [32, 32, 32], kernel_size=3, stride=1, batchnorm=batchnorm), + IdentityBlock( + 32, [32, 32, 32], kernel_size=3, stride=1, batchnorm=batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ) + + self.layer3 = nn.Sequential( + ConvBlock(32, [16, 16, 16], kernel_size=3, stride=1, batchnorm=batchnorm), + IdentityBlock( + 16, [16, 16, 16], kernel_size=3, stride=1, batchnorm=batchnorm + ), + nn.UpsamplingBilinear2d(scale_factor=2), + ) + + self.conv2 = nn.Sequential(nn.Conv2d(16, output_channel, kernel_size=1)) + + def forward( + self, + feat_to_decode: torch.Tensor, + encoder_feats: List[torch.Tensor], + target_hw: tuple, + ): + assert len(encoder_feats) >= 3 + x = self.conv1(feat_to_decode) + x = self.up1(x, encoder_feats[-1]) + x = self.up2(x, encoder_feats[-2]) + x = self.up3(x, encoder_feats[-3]) + + for layer in [self.layer1, self.layer2, self.layer3, self.conv2]: + x = layer(x) + + x = F.interpolate(x, size=(target_hw[0], target_hw[1]), mode="bilinear") + return x + + +class SpatialSoftmax(nn.Module): + """ + Spatial Softmax Layer. + Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al. + https://rll.berkeley.edu/dsae/dsae.pdf + """ + + def __init__( + self, + input_shape, + num_kp=None, + temperature=1.0, + learnable_temperature=False, + output_variance=False, + noise_std=0.0, + ): + """ + Args: + input_shape (list, tuple): shape of the input feature (C, H, W) + num_kp (int): number of keypoints (None for not use spatialsoftmax) + temperature (float): temperature term for the softmax. 
+ learnable_temperature (bool): whether to learn the temperature + output_variance (bool): treat attention as a distribution, and compute second-order statistics to return + noise_std (float): add random spatial noise to the predicted keypoints + """ + super(SpatialSoftmax, self).__init__() + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape # (C, H, W) + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._num_kp = num_kp + else: + self.nets = None + self._num_kp = self._in_c + self.learnable_temperature = learnable_temperature + self.output_variance = output_variance + self.noise_std = noise_std + + if self.learnable_temperature: + # temperature will be learned + temperature = torch.nn.Parameter( + torch.ones(1) * temperature, requires_grad=True + ) + self.register_parameter("temperature", temperature) + else: + # temperature held constant after initialization + temperature = torch.nn.Parameter( + torch.ones(1) * temperature, requires_grad=False + ) + self.register_buffer("temperature", temperature) + + pos_x, pos_y = np.meshgrid( + np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) + ) + pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h * self._in_w)).float() + pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h * self._in_w)).float() + self.register_buffer("pos_x", pos_x) + self.register_buffer("pos_y", pos_y) + + self.kps = None + + def __repr__(self): + """Pretty print network.""" + header = format(str(self.__class__.__name__)) + return header + "(num_kp={}, temperature={}, noise={})".format( + self._num_kp, self.temperature.item(), self.noise_std + ) + + def output_shape(self, input_shape): + """ + Function to compute output shape from inputs to this module. + Args: + input_shape (iterable of int): shape of input. Does not include batch dimension. + Some modules may not need this argument, if their output does not depend + on the size of the input, or if they assume fixed size input. + Returns: + out_shape ([int]): list of integers corresponding to output shape + """ + assert len(input_shape) == 3 + assert input_shape[0] == self._in_c + return [self._num_kp, 2] + + def forward(self, feature): + """ + Forward pass through spatial softmax layer. For each keypoint, a 2D spatial + probability distribution is created using a softmax, where the support is the + pixel locations. This distribution is used to compute the expected value of + the pixel location, which becomes a keypoint of dimension 2. K such keypoints + are created. 
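        Illustrative sketch (editorial addition, not part of the original patch): the
        expectation described above, computed by hand on a single 2x2, one-keypoint
        feature map with unit temperature.

            import torch
            import torch.nn.functional as F

            feat = torch.tensor([0.0, 0.0, 0.0, 10.0])    # flattened 2x2 map, hot pixel at (x=1, y=1)
            pos_x = torch.tensor([-1.0, 1.0, -1.0, 1.0])   # normalized x-coordinate of each pixel
            pos_y = torch.tensor([-1.0, -1.0, 1.0, 1.0])   # normalized y-coordinate of each pixel
            attn = F.softmax(feat, dim=-1)                 # spatial softmax over pixel locations
            kp = torch.stack([(pos_x * attn).sum(), (pos_y * attn).sum()])
            # kp is approximately (1, 1): the keypoint collapses onto the high-activation pixel.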
+ Returns: + out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly + keypoint variance of shape [B, K, 2, 2] corresponding to the covariance + under the 2D spatial softmax distribution + """ + assert feature.shape[1] == self._in_c + assert feature.shape[2] == self._in_h + assert feature.shape[3] == self._in_w + if self.nets is not None: + feature = self.nets(feature) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + feature = feature.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(feature / self.temperature, dim=-1) + # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions + expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True) + expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True) + # stack to [B * K, 2] + expected_xy = torch.cat([expected_x, expected_y], 1) + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._num_kp, 2) + + if self.training: + noise = torch.randn_like(feature_keypoints) * self.noise_std + feature_keypoints += noise + + if self.output_variance: + # treat attention as a distribution, and compute second-order statistics to return + expected_xx = torch.sum( + self.pos_x * self.pos_x * attention, dim=1, keepdim=True + ) + expected_yy = torch.sum( + self.pos_y * self.pos_y * attention, dim=1, keepdim=True + ) + expected_xy = torch.sum( + self.pos_x * self.pos_y * attention, dim=1, keepdim=True + ) + var_x = expected_xx - expected_x * expected_x + var_y = expected_yy - expected_y * expected_y + var_xy = expected_xy - expected_x * expected_y + # stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix + feature_covar = torch.cat([var_x, var_xy, var_xy, var_y], 1).reshape( + -1, self._num_kp, 2, 2 + ) + feature_keypoints = (feature_keypoints, feature_covar) + + if isinstance(feature_keypoints, tuple): + self.kps = (feature_keypoints[0].detach(), feature_keypoints[1].detach()) + else: + self.kps = feature_keypoints.detach() + return feature_keypoints + + +class RasterizedMapEncoder(nn.Module): + """A basic image-based rasterized map encoder""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + feature_dim: int = None, + use_spatial_softmax=False, + spatial_softmax_kwargs=None, + output_activation=nn.ReLU, + ) -> None: + super().__init__() + self.model_arch = model_arch + self.num_input_channels = input_image_shape[0] + self._feature_dim = feature_dim + if output_activation is None: + self._output_activation = nn.Identity() + else: + self._output_activation = output_activation() + + # configure conv backbone + if model_arch == "resnet18": + self.map_model = resnet18() + out_h = int(math.ceil(input_image_shape[1] / 32.0)) + out_w = int(math.ceil(input_image_shape[2] / 32.0)) + self.conv_out_shape = (512, out_h, out_w) + elif model_arch == "resnet50": + self.map_model = resnet50() + out_h = int(math.ceil(input_image_shape[1] / 32.0)) + out_w = int(math.ceil(input_image_shape[2] / 32.0)) + self.conv_out_shape = (2048, out_h, out_w) + else: + raise NotImplementedError(f"Model arch {model_arch} unknown") + + # configure spatial reduction pooling layer + if use_spatial_softmax: + pooling = SpatialSoftmax( + input_shape=self.conv_out_shape, **spatial_softmax_kwargs + ) + self.pool_out_dim = int(np.prod(pooling.output_shape(self.conv_out_shape))) + else: + pooling = nn.AdaptiveAvgPool2d((1, 1)) + self.pool_out_dim = self.conv_out_shape[0] + + 
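# Editorial note (not part of the original patch): the stock ResNet stem expects 3-channel
# RGB input; the line below swaps in a first conv sized to the rasterized map's channel
# count, and avgpool / fc are replaced further down with the chosen pooling layer and
# output head.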
self.map_model.conv1 = nn.Conv2d( + self.num_input_channels, + 64, + kernel_size=(7, 7), + stride=(2, 2), + padding=(3, 3), + bias=False, + ) + self.map_model.avgpool = pooling + if feature_dim is not None: + self.map_model.fc = nn.Linear( + in_features=self.pool_out_dim, out_features=feature_dim + ) + else: + self.map_model.fc = nn.Identity() + + def output_shape(self, input_shape=None): + if self._feature_dim is not None: + return [self._feature_dim] + else: + return [self.pool_out_dim] + + def feature_channels(self): + if self.model_arch in ["resnet18", "resnet34"]: + channels = OrderedDict( + { + "layer1": 64, + "layer2": 128, + "layer3": 256, + "layer4": 512, + } + ) + else: + channels = OrderedDict( + { + "layer1": 256, + "layer2": 512, + "layer3": 1024, + "layer4": 2048, + } + ) + return channels + + def feature_scales(self): + return OrderedDict( + {"layer1": 1 / 4, "layer2": 1 / 8, "layer3": 1 / 16, "layer4": 1 / 32} + ) + + def forward(self, map_inputs) -> torch.Tensor: + feat = self.map_model(map_inputs) + feat = self._output_activation(feat) + return feat + + +class RotatedROIAlign(nn.Module): + def __init__(self, roi_feature_size, roi_scale): + super(RotatedROIAlign, self).__init__() + from diffstack.models.roi_align import ROI_align + + self.roi_align = lambda feat, rois: ROI_align(feat, rois, roi_feature_size[0]) + self.roi_scale = roi_scale + + def forward(self, feats, list_of_rois): + scaled_rois = [] + for rois in list_of_rois: + sroi = rois.clone() + sroi[..., :-1] = rois[..., :-1] * self.roi_scale + scaled_rois.append(sroi) + list_of_feats = self.roi_align(feats, scaled_rois) + return torch.cat(list_of_feats, dim=0) + + +class RasterizeROIEncoder(nn.Module): + """Use RoI Align to crop map feature for each agent""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + agent_feature_dim: int = None, + global_feature_dim: int = None, + roi_feature_size: tuple = (7, 7), + roi_layer_key: str = "layer4", + output_activation=nn.ReLU, + use_rotated_roi=True, + ) -> None: + super(RasterizeROIEncoder, self).__init__() + encoder = RasterizedMapEncoder( + model_arch=model_arch, + input_image_shape=input_image_shape, + feature_dim=global_feature_dim, + ) + feat_nodes = { + "map_model.layer1": "layer1", + "map_model.layer2": "layer2", + "map_model.layer3": "layer3", + "map_model.layer4": "layer4", + "map_model.fc": "final", + } + self.encoder_heads = create_feature_extractor(encoder, feat_nodes) + + self.roi_layer_key = roi_layer_key + roi_scale = encoder.feature_scales()[roi_layer_key] + roi_channel = encoder.feature_channels()[roi_layer_key] + if use_rotated_roi: + self.roi_align = RotatedROIAlign(roi_feature_size, roi_scale=roi_scale) + else: + self.roi_align = RoIAlign( + output_size=roi_feature_size, + spatial_scale=roi_scale, + sampling_ratio=-1, + aligned=True, + ) + + self.activation = output_activation() + if agent_feature_dim is not None: + self.agent_net = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), # [B, C, 1, 1] + nn.Flatten(start_dim=1), + nn.Linear(roi_channel, agent_feature_dim), + self.activation, + ) + else: + self.agent_net = nn.Identity() + + def forward(self, map_inputs: torch.Tensor, rois: torch.Tensor): + """ + + Args: + map_inputs (torch.Tensor): [B, C, H, W] + rois (torch.Tensor): [num_boxes, 5] + + Returns: + agent_feats (torch.Tensor): [num_boxes, ...] + roi_feats (torch.Tensor): [num_boxes, ...] + global_feats (torch.Tensor): [B, ....] 
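        Example (editorial sketch, not part of the original patch; assumes
        use_rotated_roi=False so torchvision's RoIAlign box format applies, that the
        classes in this module are importable, and that all tensor values are made up):

            import torch
            enc = RasterizeROIEncoder(
                model_arch="resnet18",
                input_image_shape=(3, 224, 224),
                agent_feature_dim=64,
                global_feature_dim=128,
                use_rotated_roi=False,
            )
            maps = torch.randn(2, 3, 224, 224)
            # one row per box: [batch_index, x1, y1, x2, y2] in input-image pixel coordinates
            rois = torch.tensor(
                [[0.0, 10.0, 10.0, 60.0, 60.0], [1.0, 30.0, 40.0, 90.0, 120.0]]
            )
            agent_feats, roi_feats, global_feats, all_feats = enc(maps, rois)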
+ """ + feats = self.encoder_heads(map_inputs) + pre_roi_feats = feats[self.roi_layer_key] + roi_feats = self.roi_align(pre_roi_feats, rois) # [num_boxes, C, H, W] + agent_feats = self.agent_net(roi_feats) + global_feats = feats["final"] + + return agent_feats, roi_feats, self.activation(global_feats), feats + + +class RasterizedMapKeyPointNet(RasterizedMapEncoder): + """Predict a list of keypoints [x, y] given a rasterized map""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + spatial_softmax_kwargs=None, + ) -> None: + super().__init__( + model_arch=model_arch, + input_image_shape=input_image_shape, + feature_dim=None, + use_spatial_softmax=True, + spatial_softmax_kwargs=spatial_softmax_kwargs, + ) + + def forward(self, map_inputs) -> torch.Tensor: + outputs = super(RasterizedMapKeyPointNet, self).forward(map_inputs) + # reshape back to kp format + return outputs.reshape(outputs.shape[0], -1, 2) + + +class RasterizedMapUNet(nn.Module): + """Predict a spatial map same size as the input rasterized map""" + + def __init__( + self, + model_arch: str, + input_image_shape: tuple = (3, 224, 224), + output_channel=4, + use_spatial_softmax=False, + spatial_softmax_kwargs=None, + ) -> None: + super(RasterizedMapUNet, self).__init__() + encoder = RasterizedMapEncoder( + model_arch=model_arch, + input_image_shape=input_image_shape, + use_spatial_softmax=use_spatial_softmax, + spatial_softmax_kwargs=spatial_softmax_kwargs, + ) + self.input_image_shape = input_image_shape + # build graph for extracting intermediate features + feat_nodes = { + "map_model.layer1": "layer1", + "map_model.layer2": "layer2", + "map_model.layer3": "layer3", + "map_model.layer4": "layer4", + } + self.encoder_heads = create_feature_extractor(encoder, feat_nodes) + encoder_channels = list(encoder.feature_channels().values()) + self.decoder = UNetDecoder( + input_channel=encoder_channels[-1], + encoder_channels=encoder_channels[:-1], + output_channel=output_channel, + up_factor=2, + ) + + def forward(self, map_inputs, encoder_feats=None): + if encoder_feats is None: + encoder_feats = self.encoder_heads(map_inputs) + encoder_feats = [ + encoder_feats[k] for k in ["layer1", "layer2", "layer3", "layer4"] + ] + return self.decoder.forward( + feat_to_decode=encoder_feats[-1], + encoder_feats=encoder_feats[:-1], + target_hw=self.input_image_shape[1:], + ) + + +class RNNTrajectoryEncoder(nn.Module): + def __init__( + self, + trajectory_dim, + rnn_hidden_size, + feature_dim=None, + mlp_layer_dims: tuple = (), + mode="last", + ): + super(RNNTrajectoryEncoder, self).__init__() + self.mode = mode + self.lstm = nn.LSTM( + trajectory_dim, hidden_size=rnn_hidden_size, batch_first=True + ) + if feature_dim is not None: + self.mlp = MLP( + input_dim=rnn_hidden_size, + output_dim=feature_dim, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + self._feature_dim = feature_dim + else: + self.mlp = nn.Identity() + self._feature_dim = rnn_hidden_size + + def output_shape(self, input_shape=None): + return [self._feature_dim] + + def forward(self, input_trajectory): + if self.mode == "last": + traj_feat = self.lstm(input_trajectory)[0][:, -1, :] + elif self.mode == "all": + traj_feat = self.lstm(input_trajectory)[0] + traj_feat = self.mlp(traj_feat) + return traj_feat + + +class RNNFeatureRoller(nn.Module): + def __init__(self, trajectory_dim, feature_dim): + super(RNNFeatureRoller, self).__init__() + self.gru = nn.GRU(trajectory_dim, hidden_size=feature_dim, batch_first=True) + self._feature_dim 
= feature_dim + + def output_shape(self, input_shape=None): + return [self._feature_dim] + + def forward(self, feature, input_trajectory): + _, hn = self.gru(input_trajectory, feature.unsqueeze(0)) + return hn[0] + feature + + +class PosteriorEncoder(nn.Module): + """Posterior Encoder (x, x_c -> q) for CVAE""" + + def __init__( + self, + condition_dim: int, + trajectory_shape: tuple, # [T, D] + output_shapes: OrderedDict, + mlp_layer_dims: tuple = (128, 128), + rnn_hidden_size: int = 100, + ) -> None: + super(PosteriorEncoder, self).__init__() + self.trajectory_shape = trajectory_shape + + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_dim=trajectory_shape[-1], rnn_hidden_size=rnn_hidden_size + ) + self.mlp = SplitMLP( + input_dim=(rnn_hidden_size + condition_dim), + output_shapes=output_shapes, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, inputs, condition_features) -> Dict[str, torch.Tensor]: + traj_feat = self.traj_encoder(inputs["trajectories"]) + feat = torch.cat((traj_feat, condition_features), dim=-1) + return self.mlp(feat) + + +class ScenePosteriorEncoder(nn.Module): + """Scene Posterior Encoder (x, x_c -> q) for CVAE""" + + def __init__( + self, + condition_dim: int, + trajectory_shape: tuple, # [T, D] + output_shapes: OrderedDict, + aggregate_func="max", + mlp_layer_dims: tuple = (128, 128), + rnn_hidden_size: int = 100, + ) -> None: + super(ScenePosteriorEncoder, self).__init__() + self.trajectory_shape = trajectory_shape + self.transformer = SimpleTransformer(src_dim=rnn_hidden_size + condition_dim) + self.aggregate_func = aggregate_func + + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_dim=trajectory_shape[-1], rnn_hidden_size=rnn_hidden_size + ) + self.mlp = SplitMLP( + input_dim=(rnn_hidden_size + condition_dim), + output_shapes=output_shapes, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, inputs, condition_features, mask, pos) -> Dict[str, torch.Tensor]: + bs, Na, T = inputs["trajectories"].shape[:3] + + traj_feat = self.traj_encoder( + TensorUtils.join_dimensions(inputs["trajectories"], 0, 2) + ) + feat = torch.cat( + (traj_feat, TensorUtils.join_dimensions(condition_features, 0, 2)), dim=-1 + ).reshape(bs, Na, -1) + + feat = self.transformer(feat, mask, pos) + feat + if self.aggregate_func == "max": + feat = feat.max(1)[0] + elif self.aggregate_func == "mean": + feat = feat.mean(1) + + return self.mlp(feat) + + +class ConditionEncoder(nn.Module): + """Condition Encoder (x -> c) for CVAE""" + + def __init__( + self, + map_encoder: nn.Module, + trajectory_shape: tuple, # [T, D] + condition_dim: int, + mlp_layer_dims: tuple = (128, 128), + goal_encoder=None, + agent_traj_encoder=None, + ) -> None: + super(ConditionEncoder, self).__init__() + self.map_encoder = map_encoder + self.trajectory_shape = trajectory_shape + self.goal_encoder = goal_encoder + self.agent_traj_encoder = agent_traj_encoder + + visual_feature_size = self.map_encoder.output_shape()[0] + self.mlp = MLP( + input_dim=visual_feature_size, + output_dim=condition_dim, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, condition_inputs): + map_feat = self.map_encoder(condition_inputs["image"]) + c_feat = self.mlp(map_feat) + if self.goal_encoder is not None: + goal_feat = self.goal_encoder(condition_inputs["goal"]) + c_feat = torch.cat([c_feat, goal_feat], dim=-1) + if self.agent_traj_encoder is not None and "agent_traj" in condition_inputs: + agent_traj_feat = 
self.agent_traj_encoder(condition_inputs["agent_traj"]) + c_feat = torch.cat([c_feat, agent_traj_feat], dim=-1) + return c_feat + + +class ECEncoder(nn.Module): + """Condition Encoder (x -> c) for CVAE""" + + def __init__( + self, + map_encoder, + trajectory_shape: tuple, # [T, D] + EC_dim: int, + condition_dim: int, + mlp_layer_dims: tuple = (128, 128), + goal_encoder=None, + rnn_hidden_size: int = 100, + ) -> None: + super(ECEncoder, self).__init__() + if isinstance(map_encoder, nn.Module): + self.map_encoder = map_encoder + visual_feature_size = self.map_encoder.output_shape()[0] + elif isinstance(map_encoder, int): + visual_feature_size = map_encoder + self.map_encoder = None + self.trajectory_shape = trajectory_shape + self.goal_encoder = goal_encoder + self.EC_dim = EC_dim + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_shape[-1], rnn_hidden_size, feature_dim=EC_dim + ) + goal_dim = 0 if goal_encoder is None else goal_encoder.output_shape()[0] + self.mlp = MLP( + input_dim=visual_feature_size + goal_dim + EC_dim, + output_dim=condition_dim, + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, condition_inputs): + if self.map_encoder is None: + c_feat = condition_inputs["map_feature"] + else: + c_feat = self.map_encoder(condition_inputs["image"]) + if self.goal_encoder is not None: + goal_feat = self.goal_encoder(condition_inputs["goal"]) + c_feat = torch.cat([c_feat, goal_feat], dim=-1) + if ( + "cond_traj" in condition_inputs + and condition_inputs["cond_traj"] is not None + ): + if condition_inputs["cond_traj"].ndim == 3: + EC_feat = self.traj_encoder(condition_inputs["cond_traj"]).unsqueeze(1) + elif condition_inputs["cond_traj"].ndim == 4: + bs, M, T, D = condition_inputs["cond_traj"].shape + EC_feat = self.traj_encoder( + condition_inputs["cond_traj"].reshape(-1, T, D) + ).reshape(bs, M, -1) + + EC_feat = EC_feat * (condition_inputs["cond_traj"][..., 0] != 0).any( + -1 + ).unsqueeze(-1) + EC_feat = torch.cat( + (EC_feat, torch.zeros([bs, 1, self.EC_dim]).to(EC_feat.device)), 1 + ) + else: + bs = c_feat.shape[0] + EC_feat = torch.zeros([bs, 1, self.EC_dim]).to(c_feat.device) + if c_feat.ndim == 2: + c_feat = c_feat.unsqueeze(1).repeat(1, EC_feat.shape[1], 1) + else: + assert c_feat.ndim == 3 and c_feat.shape[1] == EC_feat.shape[1] + c_feat = torch.cat((c_feat, EC_feat), -1) + c_feat = self.mlp(c_feat) + + return c_feat + + +class AgentTrajEncoder(nn.Module): + """Condition Encoder (x -> c) for CVAE""" + + def __init__( + self, + trajectory_shape: tuple, # [T, D] + feature_dim: int, + mlp_layer_dims: tuple = (128, 128), + rnn_hidden_size: int = 100, + use_transformer=True, + ) -> None: + super(AgentTrajEncoder, self).__init__() + self.trajectory_shape = trajectory_shape + self.feature_dim = feature_dim + self.traj_encoder = RNNTrajectoryEncoder( + trajectory_shape[-1], + rnn_hidden_size, + feature_dim=feature_dim, + mlp_layer_dims=mlp_layer_dims, + ) + if use_transformer: + self.transformer = SimpleTransformer(src_dim=feature_dim) + else: + self.transformer = None + + def forward(self, agent_trajs): + bs, M, T, D = agent_trajs.shape + feat = self.traj_encoder(agent_trajs.reshape(-1, T, D)).reshape(bs, M, -1) + if self.transformer is not None: + agent_pos = agent_trajs[:, :, 0, :2] + avails = (agent_pos != 0).any(-1) + feat = self.transformer(feat, avails, agent_pos) + + return torch.max(feat, 1)[0] + + +class PosteriorNet(nn.Module): + def __init__( + self, + input_shapes: OrderedDict, + condition_dim: int, + param_shapes: OrderedDict, + 
mlp_layer_dims: tuple = (), + ): + super(PosteriorNet, self).__init__() + all_shapes = deepcopy(input_shapes) + all_shapes["condition_features"] = (condition_dim,) + self.mlp = MIMOMLP( + input_shapes=all_shapes, + output_shapes=param_shapes, + layer_dims=mlp_layer_dims, + output_activation=None, + ) + + def forward(self, inputs: dict, condition_features: torch.Tensor): + all_inputs = dict(inputs) + all_inputs["condition_features"] = condition_features + return self.mlp(all_inputs) + + +class ConditionNet(nn.Module): + def __init__( + self, + condition_input_shapes: OrderedDict, + condition_dim: int, + mlp_layer_dims: tuple = (), + ): + super(ConditionNet, self).__init__() + self.mlp = MIMOMLP( + input_shapes=condition_input_shapes, + output_shapes=OrderedDict(feat=(condition_dim,)), + layer_dims=mlp_layer_dims, + output_activation=nn.ReLU, + ) + + def forward(self, inputs: dict): + return self.mlp(inputs)["feat"] + + +class ConditionDecoder(nn.Module): + """Decoding (z, c) -> x' using a flat MLP""" + + def __init__(self, decoder_model: nn.Module): + super(ConditionDecoder, self).__init__() + self.decoder_model = decoder_model + + def forward(self, latents, condition_features, **decoder_kwargs): + return self.decoder_model( + torch.cat((latents, condition_features), dim=-1), **decoder_kwargs + ) + + +class TrajectoryDecoder(nn.Module): + def __init__( + self, + feature_dim: int, + state_dim: int = 3, + num_steps: int = None, + dynamics_type: Union[str, dynamics.DynType] = None, + dynamics_kwargs: dict = None, + step_time: float = None, + network_kwargs: dict = None, + Gaussian_var=False, + ): + """ + A class that predict future trajectories based on input features + Args: + feature_dim (int): dimension of the input feature + state_dim (int): dimension of the output trajectory at each step + num_steps (int): (optional) number of future state to predict + dynamics_type (str, dynamics.DynType): (optional) if specified, the network predicts action + for the dynamics model instead of future states. The actions are then used to predict + the future trajectories. + step_time (float): time between steps. 
required for using dynamics models + network_kwargs (dict): keyword args for the decoder networks + Gaussian_var (bool): whether output the variance of the predicted trajectory + """ + super(TrajectoryDecoder, self).__init__() + self.feature_dim = feature_dim + self.state_dim = state_dim + self.num_steps = num_steps + self.step_time = step_time + self._network_kwargs = network_kwargs + self._dynamics_type = dynamics_type + self._dynamics_kwargs = dynamics_kwargs + self.Gaussian_var = Gaussian_var + self._create_dynamics() + self._create_networks() + + def _create_dynamics(self): + if self._dynamics_type in ["Unicycle", dynamics.DynType.UNICYCLE]: + self.dyn = dynamics.Unicycle( + self.step_time, + max_steer=self._dynamics_kwargs["max_steer"], + max_yawvel=self._dynamics_kwargs["max_yawvel"], + acce_bound=self._dynamics_kwargs["acce_bound"], + ) + elif self._dynamics_type in ["Bicycle", dynamics.DynType.BICYCLE]: + self.dyn = dynamics.Bicycle( + acc_bound=self._dynamics_kwargs["acce_bound"], + ddh_bound=self._dynamics_kwargs["ddh_bound"], + max_hdot=self._dynamics_kwargs["max_yawvel"], + max_speed=self._dynamics_kwargs["max_speed"], + ) + else: + self.dyn = None + + def _create_networks(self): + raise NotImplementedError + + def _forward_networks(self, inputs, current_states=None, num_steps=None): + raise NotImplementedError + + def _forward_dynamics(self, current_states, actions): + assert self.dyn is not None + assert current_states.shape[-1] == self.dyn.xdim + assert actions.shape[-1] == self.dyn.udim + assert isinstance(self.step_time, float) and self.step_time > 0 + x = self.dyn.forward_dynamics( + x0=current_states, + u=actions, + bound=False, + ) + pos = self.dyn.state2pos(x) + yaw = self.dyn.state2yaw(x) + traj = torch.cat((pos, yaw), dim=-1) + return traj, x + + def forward(self, inputs, current_states=None, num_steps=None): + preds = self._forward_networks( + inputs, current_states=current_states, num_steps=num_steps + ) + if self.dyn is not None: + preds["controls"] = preds["trajectories"] + preds["trajectories"], x = self._forward_dynamics( + current_states=current_states, actions=preds["trajectories"] + ) + preds["terminal_state"] = x[..., -1, :] + return preds + + +class MLPTrajectoryDecoder(TrajectoryDecoder): + def _create_networks(self): + net_kwargs = ( + dict() if self._network_kwargs is None else dict(self._network_kwargs) + ) + if self._network_kwargs is None: + net_kwargs = dict() + + assert isinstance(self.num_steps, int) + if self.dyn is None: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.state_dim)) + else: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.dyn.udim)) + if self.Gaussian_var: + pred_shapes["logvar"] = (self.num_steps, self.state_dim) + + state_as_input = net_kwargs.pop("state_as_input") + if self.dyn is not None: + assert state_as_input # TODO: deprecated, set default to True and remove from configs + + if state_as_input and self.dyn is not None: + feature_dim = self.feature_dim + self.dyn.xdim + else: + feature_dim = self.feature_dim + + self.mlp = SplitMLP( + input_dim=feature_dim, + output_shapes=pred_shapes, + output_activation=None, + **net_kwargs, + ) + + def _forward_networks(self, inputs, current_states=None, num_steps=None): + if self._network_kwargs["state_as_input"] and self.dyn is not None: + inputs = torch.cat((inputs, current_states), dim=-1) + + if inputs.ndim == 2: + # [B, D] + preds = self.mlp(inputs) + elif inputs.ndim == 3: + # [B, A, D] + preds = TensorUtils.time_distributed(inputs, self.mlp) + else: + 
raise ValueError( + "Expecting inputs to have ndim == 2 or 3, got {}".format(inputs.ndim) + ) + return preds + + +class MLPECTrajectoryDecoder(TrajectoryDecoder): + def __init__( + self, + feature_dim: int, + state_dim: int = 3, + num_steps: int = None, + dynamics_type: Union[str, dynamics.DynType] = None, + dynamics_kwargs: dict = None, + step_time: float = None, + EC_feature_dim=64, + network_kwargs: dict = None, + Gaussian_var=False, + ): + """ + A class that predict future trajectories based on input features + Args: + feature_dim (int): dimension of the input feature + state_dim (int): dimension of the output trajectory at each step + num_steps (int): (optional) number of future state to predict + dynamics_type (str, dynamics.DynType): (optional) if specified, the network predicts action + for the dynamics model instead of future states. The actions are then used to predict + the future trajectories. + step_time (float): time between steps. required for using dynamics models + network_kwargs (dict): keyword args for the decoder networks + Gaussian_var (bool): whether output the variance of the predicted trajectory + """ + super(TrajectoryDecoder, self).__init__() + self.feature_dim = feature_dim + self.state_dim = state_dim + self.num_steps = num_steps + self.step_time = step_time + self.EC_feature_dim = EC_feature_dim + self._network_kwargs = network_kwargs + self._dynamics_type = dynamics_type + self._dynamics_kwargs = dynamics_kwargs + self.Gaussian_var = Gaussian_var + self._create_dynamics() + self._create_networks() + + def _create_networks(self): + net_kwargs = ( + dict() if self._network_kwargs is None else dict(self._network_kwargs) + ) + if self._network_kwargs is None: + net_kwargs = dict() + + assert isinstance(self.num_steps, int) + if self.dyn is None: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.state_dim)) + else: + pred_shapes = OrderedDict(trajectories=(self.num_steps, self.dyn.udim)) + if self.Gaussian_var: + pred_shapes["logvar"] = (self.num_steps, self.state_dim) + + state_as_input = net_kwargs.pop("state_as_input") + if self.dyn is not None: + assert state_as_input # TODO: deprecated, set default to True and remove from configs + + if state_as_input and self.dyn is not None: + feature_dim = self.feature_dim + self.dyn.xdim + else: + feature_dim = self.feature_dim + + self.mlp = SplitMLP( + input_dim=feature_dim, + output_shapes=pred_shapes, + output_activation=None, + **net_kwargs, + ) + self.offsetmlp = SplitMLP( + input_dim=feature_dim + self.EC_feature_dim, + output_shapes=pred_shapes, + output_activation=None, + **net_kwargs, + ) + + def _forward_networks( + self, inputs, EC_feat=None, current_states=None, num_steps=None + ): + if self._network_kwargs["state_as_input"] and self.dyn is not None: + inputs = torch.cat((inputs, current_states), dim=-1) + if inputs.ndim == 2: + # [B, D] + + preds = self.mlp(inputs) + if EC_feat is not None: + bs, M = EC_feat.shape[:2] + + # EC_feat = self.traj_encoder(cond_traj.reshape(-1,T,D)).reshape(bs,M,-1) + inputs_tile = inputs.unsqueeze(1).tile(1, M, 1) + EC_feat = torch.cat((inputs_tile, EC_feat), dim=-1) + EC_preds = TensorUtils.time_distributed(EC_feat, self.offsetmlp) + EC_preds["trajectories"] = EC_preds["trajectories"] + preds[ + "trajectories" + ].unsqueeze(1) + else: + EC_preds = None + + elif inputs.ndim == 3: + # [B, A, D] + preds = TensorUtils.time_distributed(inputs, self.mlp) + if EC_feat is not None: + assert EC_feat.ndim == 4 + bs, A, M = EC_feat.shape[:3] + # EC_feat = 
self.traj_encoder(cond_traj.reshape(-1,T,D)).reshape(bs,M,-1) + inputs_tile = inputs.tile(1, M, 1) + EC_feat = torch.cat((inputs_tile, EC_feat), dim=-1) + EC_preds = TensorUtils.time_distributed(EC_feat, self.offsetmlp) + EC_preds = reshape_dimensions(EC_preds, 1, 2, (A, M)) + EC_preds["trajectories"] = EC_preds["trajectories"] + preds[ + "trajectories" + ].unsqueeze(2) + else: + EC_preds = None + else: + raise ValueError( + "Expecting inputs to have ndim == 2 or 3, got {}".format(inputs.ndim) + ) + return preds, EC_preds + + def _forward_dynamics(self, current_states, actions, **kwargs): + assert self.dyn is not None + assert current_states.shape[-1] == self.dyn.xdim + assert actions.shape[-1] == self.dyn.udim + assert isinstance(self.step_time, float) and self.step_time > 0 + + x = self.dyn.forward_dynamics( + x0=current_states, + u=actions, + bound=False, + ) + pos = self.dyn.state2pos(x) + yaw = self.dyn.state2yaw(x) + traj = torch.cat((pos, yaw), dim=-1) + return traj + + def forward(self, inputs, current_states=None, EC_feat=None, num_steps=None): + preds, EC_preds = self._forward_networks( + inputs, EC_feat, current_states=current_states, num_steps=num_steps + ) + if self.dyn is not None: + preds["controls"] = preds["trajectories"] + if EC_preds is None: + preds["trajectories"] = self._forward_dynamics( + current_states=current_states, actions=preds["trajectories"] + ) + else: + total_actions = torch.cat( + (preds["trajectories"].unsqueeze(1), EC_preds["trajectories"]), 1 + ) + bs, A, T, D = total_actions.shape + current_states_tiled = current_states.unsqueeze(1).repeat( + 1, total_actions.size(1), 1 + ) + total_trajectories = self._forward_dynamics( + current_states=current_states_tiled.reshape(bs * A, -1), + actions=total_actions.reshape(bs * A, T, D), + ).reshape(*total_actions.shape[:-1], -1) + preds["trajectories"] = total_trajectories[:, 0] + preds["EC_trajectories"] = total_trajectories[:, 1:] + else: + preds["EC_trajectories"] = EC_preds["trajectories"] + + return preds + + +class clique_gibbs_distr(nn.Module): + def __init__( + self, + state_enc_dim, + edge_encoding_dim, + z_dim, + edge_types, + node_types, + node_hidden_dim=[64, 64], + edge_hidden_dim=[64, 64], + ): + super(clique_gibbs_distr, self).__init__() + self.edge_encoding_dim = edge_encoding_dim + self.z_dim = z_dim + self.state_enc_dim = state_enc_dim + self.node_types = node_types + self.edge_types = edge_types + self.et_name = dict() + for et in self.edge_types: + self.et_name[et] = et[0].name + "->" + et[1].name + + self.node_factor = nn.ModuleDict() + self.edge_factor = nn.ModuleDict() + for node_type in self.node_types: + self.node_factor[node_type.name] = MLP( + state_enc_dim[node_type], z_dim, node_hidden_dim + ) + + for edge_type in self.edge_types: + self.edge_factor[self.et_name[edge_type]] = MLP( + edge_encoding_dim[edge_type], z_dim**2, edge_hidden_dim + ) + + def forward( + self, + node_types, + node_enc, + edge_enc, + clique_index, + clique_node_index, + clique_order_index, + clique_is_robot=None, + ): + device = node_enc.device + bs, Na = node_types.shape[:2] + Nc, max_cs = clique_node_index.shape[1:] + + # clique_edge_mask = (clique_index.unsqueeze(1)==clique_index.unsqueeze(2))*(clique_index>=0).unsqueeze(1) + # clique_edge_mask = clique_edge_mask*(~torch.eye(Na,dtype=torch.bool,device=device).unsqueeze(0)) + node_factor = torch.zeros([bs, Na, self.z_dim]).to(device) + for nt in self.node_types: + node_factor += self.node_factor[nt.name](node_enc) * ( + node_types == nt.value + ).unsqueeze(-1) + 
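# Editorial note (not part of the original patch): at this point node_factor holds the
# unary log-potential over the discrete latent z for every agent, computed by the MLP
# matching each agent's node type. The block below builds the pairwise log-potentials:
# one (z_dim x z_dim) table per agent pair, produced by the edge-type-specific MLPs and
# masked so only pairs whose node types match that edge type contribute.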
+ edge_factor = torch.zeros([bs, Na, Na, self.z_dim**2]).to(device) + for et in self.edge_types: + et_flag = (node_types == et[0].value).unsqueeze(1) * ( + node_types == et[1].value + ).unsqueeze(2) + edge_factor += self.edge_factor[self.et_name[et]]( + edge_enc[..., -1, :] + ) * et_flag.unsqueeze(-1) + edge_factor = edge_factor.reshape(bs, Na, Na, self.z_dim, self.z_dim) + + node_factor_extended = torch.cat( + (node_factor, torch.zeros_like(node_factor[:, 0:1])), 1 + ) + edge_factor_extended = torch.cat( + ( + torch.cat( + ( + edge_factor, + torch.zeros([bs, 1, Na, self.z_dim, self.z_dim], device=device), + ), + 1, + ), + torch.zeros([bs, Na + 1, 1, self.z_dim, self.z_dim], device=device), + ), + 2, + ) + edge_factor_extended_flat = edge_factor_extended.reshape( + bs, -1, self.z_dim, self.z_dim + ) + + logpi = torch.zeros([bs, Nc, *(max_cs * [self.z_dim])]).to(device) + z = torch.stack(torch.meshgrid(*(max_cs * [torch.arange(self.z_dim)])), -1).to( + device + ) + + # add the node factor to the Gibbs distribution + idx = ( + clique_node_index.reshape(bs, -1) + .unsqueeze(-1) + .repeat_interleave(self.z_dim, -1) + ) + idx = idx.masked_fill(idx == -1, Na) + node_factor_clique = torch.gather(node_factor_extended, 1, idx).reshape( + bs, Nc, max_cs, self.z_dim + ) + for i in range(max_cs): + logpi += node_factor_clique[..., i, :].reshape( + [bs, Nc, *(i * [1]), self.z_dim, *((max_cs - i - 1) * [1])] + ) + + # add the edge factor to the Gibbs distribution + for i in range(max_cs - 1): + for j in range(i + 1, max_cs): + idx = clique_node_index[:, :, [i, j]] + idx = idx.masked_fill(idx == -1, Na) + idx_flat = idx[..., 0] * (Na + 1) + idx[..., 1] + idx_flat = ( + idx_flat.reshape(bs, -1, 1, 1) + .repeat_interleave(self.z_dim, 2) + .repeat_interleave(self.z_dim, 3) + ) + edge_factor_ij = torch.gather(edge_factor_extended_flat, 1, idx_flat) + logpi += edge_factor_ij.reshape( + bs, + Nc, + *(i * [1]), + self.z_dim, + *((j - i - 1) * [1]), + self.z_dim, + *((max_cs - j - 1) * [1]), + ) + + return logpi, z + + +class AdditiveAttention(nn.Module): + # Implementing the attention module of Bahdanau et al. 2015 where + # score(h_j, s_(i-1)) = v . 
tanh(W_1 h_j + W_2 s_(i-1)) + def __init__( + self, encoder_hidden_state_dim, decoder_hidden_state_dim, internal_dim=None + ): + super(AdditiveAttention, self).__init__() + self.encoder_hidden_state_dim = encoder_hidden_state_dim + self.decoder_hidden_state_dim = decoder_hidden_state_dim + if internal_dim is None: + internal_dim = int( + (encoder_hidden_state_dim + decoder_hidden_state_dim) / 2 + ) + self.internal_dim = internal_dim + self.w1 = nn.Linear(encoder_hidden_state_dim, internal_dim, bias=False) + self.w2 = nn.Linear(decoder_hidden_state_dim, internal_dim, bias=False) + self.v = nn.Linear(internal_dim, 1, bias=False) + + def score(self, encoder_state, decoder_state): + # encoder_state is of shape (batch, enc_dim) + # decoder_state is of shape (batch, dec_dim) + # return value should be of shape (batch, 1) + return self.v(torch.tanh(self.w1(encoder_state) + self.w2(decoder_state))) + + def forward(self, encoder_states, decoder_state): + # encoder_states is of shape (batch, num_enc_states, enc_dim) + # decoder_state is of shape (batch, dec_dim) + score_vec = torch.cat( + [ + self.score(encoder_states[:, i], decoder_state) + for i in range(encoder_states.shape[1]) + ], + dim=1, + ) + # score_vec is of shape (batch, num_enc_states) + + attention_probs = torch.unsqueeze(F.softmax(score_vec, dim=1), dim=2) + # attention_probs is of shape (batch, num_enc_states, 1) + + final_context_vec = torch.sum(attention_probs * encoder_states, dim=1) + # final_context_vec is of shape (batch, enc_dim) + + return final_context_vec, attention_probs + + +if __name__ == "__main__": + # model = RasterizedMapUNet(model_arch="resnet18", input_image_shape=(15, 224, 224), output_channel=4) + t = torch.randn(2, 15, 224, 224) + centers = torch.randint(low=10, high=214, size=(10, 2)).float() + yaws = torch.zeros(size=(10, 1)) + extents = torch.ones(10, 2) * 2 + from diffstack.utils.geometry_utils import ( + get_box_world_coords, + get_square_upright_box, + ) + + boxes = get_box_world_coords(pos=centers, yaw=yaws, extent=extents) + boxes_aligned = torch.flatten(boxes[:, [0, 2], :], start_dim=1) + box_aligned = get_square_upright_box(centers, 2) + from IPython import embed + + embed() + boxes_indices = torch.zeros(10, 1) + boxes_indices[5:] = 1 + boxes_indexed = torch.cat((boxes_indices, boxes_aligned), dim=1) + + model = RasterizeROIEncoder(model_arch="resnet50", input_image_shape=(15, 224, 224)) + output = model(t, boxes_indexed) + for f in output: + print(f.shape) diff --git a/diffstack/models/cnn_roi_encoder.py b/diffstack/models/cnn_roi_encoder.py new file mode 100644 index 0000000..ea197f9 --- /dev/null +++ b/diffstack/models/cnn_roi_encoder.py @@ -0,0 +1,558 @@ +from logging import raiseExceptions +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffstack.utils.geometry_utils import batch_nd_transform_points + + +class CNNROIMapEncoder(nn.Module): + def __init__( + self, + map_channels, + hidden_channels, + ROI_outdim, + output_size, + kernel_size, + strides, + input_size, + ): + """ + multi-layer CNN with ROI align for the output + Args: + map_channels (int): map channel numbers + ROI (list): list of ROIs + ROI_outdim (int): ROI points along each dim total interpolating points: ROI_outdim x ROI_outdim + output_size (int): output feature size + kernel_size (list): CNN kernel size for each layer + strides (list): CNN strides for each layer + input_size (tuple): map size + + """ + super(CNNROIMapEncoder, self).__init__() + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + 
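+        # Conv2d + BatchNorm2d stack defined by hidden_channels, kernel_size and
+        # strides; forward() ROI-aligns the final feature map and flattens it
+        # through a fully connected layer of size output_size.
+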
self.num_channel_last = hidden_channels[-1] + self.ROI_outdim = ROI_outdim + x_dummy = torch.ones([map_channels, *input_size]).unsqueeze(0) * torch.tensor( + float("nan") + ) + + for i, hidden_size in enumerate(hidden_channels): + self.convs.append( + nn.Conv2d( + map_channels if i == 0 else hidden_channels[i - 1], + hidden_channels[i], + kernel_size[i], + stride=strides[i], + padding=int((kernel_size[i] - 1) / 2), + ) + ) + self.bns.append(nn.BatchNorm2d(hidden_size)) + x_dummy = self.convs[i](x_dummy) + + "fully connected layer after ROI align" + self.fc = nn.Linear( + ROI_outdim * ROI_outdim * self.num_channel_last, output_size + ) + + def forward(self, x, ROI): + """ + + Args: + x (torch.tensor): image + ROI (list): ROIs + + Returns: + out (list): ROI align result for each ROI + """ + + for conv, bn in zip(self.convs, self.bns): + x0 = x + x = F.leaky_relu(conv(x), 0.2) + x = bn(x) + x = ROI_align(x, ROI, self.ROI_outdim) + out = [None] * len(x) + for i in range(len(x)): + out[i] = self.fc(x[i].flatten(start_dim=-3)) + + return out + + +# def bilinear_interpolate(img, x, y, floattype=torch.float): +# """Return bilinear interpolation of 4 nearest pts w.r.t to x,y from img +# Args: +# img (torch.Tensor): Tensor of size cxwxh. Usually one channel of feature layer +# x (torch.Tensor): Float dtype, x axis location for sampling +# y (torch.Tensor): Float dtype, y axis location for sampling +# batched version + +# Returns: +# torch.Tensor: interpolated value +# """ +# bs = img.size(0) +# x0 = torch.floor(x).type(torch.cuda.LongTensor) +# x1 = x0 + 1 + +# y0 = torch.floor(y).type(torch.cuda.LongTensor) +# y1 = y0 + 1 + +# x0 = torch.clamp(x0, 0, img.shape[-2] - 1) +# x1 = torch.clamp(x1, 0, img.shape[-2] - 1) +# y0 = torch.clamp(y0, 0, img.shape[-1] - 1) +# y1 = torch.clamp(y1, 0, img.shape[-1] - 1) + +# Ia = [None] * bs +# Ib = [None] * bs +# Ic = [None] * bs +# Id = [None] * bs +# for i in range(bs): +# Ia[i] = img[i, ..., y0[i], x0[i]] +# Ib[i] = img[i, ..., y1[i], x0[i]] +# Ic[i] = img[i, ..., y0[i], x1[i]] +# Id[i] = img[i, ..., y1[i], x1[i]] + +# Ia = torch.stack(Ia, dim=0) +# Ib = torch.stack(Ib, dim=0) +# Ic = torch.stack(Ic, dim=0) +# Id = torch.stack(Id, dim=0) + +# step = (x1.type(floattype) - x0.type(floattype)) * ( +# y1.type(floattype) - y0.type(floattype) +# ) +# step = torch.clamp(step, 1e-3, 2) +# norm_const = 1 / step + +# wa = (x1.type(floattype) - x) * (y1.type(floattype) - y) * norm_const +# wb = (x1.type(floattype) - x) * (y - y0.type(floattype)) * norm_const +# wc = (x - x0.type(floattype)) * (y1.type(floattype) - y) * norm_const +# wd = (x - x0.type(floattype)) * (y - y0.type(floattype)) * norm_const +# return ( +# Ia * wa.unsqueeze(1) +# + Ib * wb.unsqueeze(1) +# + Ic * wc.unsqueeze(1) +# + Id * wd.unsqueeze(1) +# ) +def bilinear_interpolate(img, x, y, floattype=torch.float, flip_y=False): + """Return bilinear interpolation of 4 nearest pts w.r.t to x,y from img + Args: + img (torch.Tensor): Tensor of size cxwxh. 
Usually one channel of feature layer + x (torch.Tensor): Float dtype, x axis location for sampling + y (torch.Tensor): Float dtype, y axis location for sampling + + Returns: + torch.Tensor: interpolated value + """ + if flip_y: + y = img.shape[-2] - 1-y + if img.device.type == "cuda": + x0 = torch.floor(x).type(torch.cuda.LongTensor) + y0 = torch.floor(y).type(torch.cuda.LongTensor) + elif img.device.type == "cpu": + x0 = torch.floor(x).type(torch.LongTensor) + y0 = torch.floor(y).type(torch.LongTensor) + else: + raise ValueError("device not recognized") + x1 = x0 + 1 + y1 = y0 + 1 + + x0 = torch.clamp(x0, 0, img.shape[-1] - 1) + x1 = torch.clamp(x1, 0, img.shape[-1] - 1) + y0 = torch.clamp(y0, 0, img.shape[-2] - 1) + y1 = torch.clamp(y1, 0, img.shape[-2] - 1) + + Ia = img[..., y0, x0] + Ib = img[..., y1, x0] + Ic = img[..., y0, x1] + Id = img[..., y1, x1] + + step = (x1.type(floattype) - x0.type(floattype)) * ( + y1.type(floattype) - y0.type(floattype) + ) + step = torch.clamp(step, 1e-3, 2) + norm_const = 1 / step + + wa = (x1.type(floattype) - x) * (y1.type(floattype) - y) * norm_const + wb = (x1.type(floattype) - x) * (y - y0.type(floattype)) * norm_const + wc = (x - x0.type(floattype)) * (y1.type(floattype) - y) * norm_const + wd = (x - x0.type(floattype)) * (y - y0.type(floattype)) * norm_const + return ( + Ia * wa.unsqueeze(0) + + Ib * wb.unsqueeze(0) + + Ic * wc.unsqueeze(0) + + Id * wd.unsqueeze(0) + ) + + +def ROI_align(features, ROI, outdim): + """Given feature layers and proposals return bilinear interpolated + points in feature layer + + Args: + features (torch.Tensor): Tensor of shape channels x width x height + proposal (list of torch.Tensor): x0,y0,W1,W2,H1,H2,psi + """ + + bs, num_channels, h, w = features.shape + + xg = ( + torch.cat( + ( + torch.arange(0, outdim).view(-1, 1) - (outdim - 1) / 2, + torch.zeros([outdim, 1]), + ), + dim=-1, + ) + / outdim + ) + yg = ( + torch.cat( + ( + torch.zeros([outdim, 1]), + torch.arange(0, outdim).view(-1, 1) - (outdim - 1) / 2, + ), + dim=-1, + ) + / outdim + ) + gg = xg.view(1, -1, 2) + yg.view(-1, 1, 2) + gg = gg.to(features.device) + res = list() + for i in range(bs): + if ROI[i] is not None: + W1 = ROI[i][..., 2:3] + W2 = ROI[i][..., 3:4] + H1 = ROI[i][..., 4:5] + H2 = ROI[i][..., 5:6] + psi = ROI[i][..., 6:] + WH = torch.cat((W1 + W2, H1 + H2), dim=-1) + offset = torch.cat(((W1 - W2) / 2, (H1 - H2) / 2), dim=-1) + s = torch.sin(psi).unsqueeze(-1) + c = torch.cos(psi).unsqueeze(-1) + rotM = torch.cat( + (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 + ) + ggi = gg * WH[..., None, None, :] - offset[..., None, None, :] + ggi = ggi @ rotM[..., None, :, :] + ROI[i][..., None, None, 0:2] + + x_sample = ggi[..., 0].flatten() + y_sample = ggi[..., 1].flatten() + res.append( + bilinear_interpolate(features[i], x_sample, y_sample).view( + ggi.shape[0], num_channels, *ggi.shape[1:-1] + ) + ) + else: + res.append(None) + + return res + + +def generate_ROIs_deprecated( + pos, + yaw, + centroid, + scene_yaw, + raster_from_world, + mask, + patch_size, + mode="last", +): + """ + This version generates ROI for all agents only at most recent time step unless specified otherwise + """ + if mode == "all": + bs = pos.shape[0] + yaw = yaw.type(torch.float) + scene_yaw = scene_yaw.type(torch.float) + s = torch.sin(scene_yaw).reshape(-1, 1, 1, 1) + c = torch.cos(scene_yaw).reshape(-1, 1, 1, 1) + rotM = torch.cat( + (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 + ) + world_xy = ((pos.unsqueeze(-2)) @ 
(rotM.transpose(-1, -2))).squeeze(-2) + world_xy += centroid.view(-1, 1, 1, 2).type(torch.float) + + Mat = raster_from_world.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(world_xy, Mat) + raster_mult = torch.linalg.norm( + raster_from_world[0, 0, 0:2], dim=[-1]).item() + patch_size = patch_size.type(torch.float) + patch_size *= raster_mult + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii, jj = torch.where(mask[i]) + index[i] = (ii, jj) + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, jj], + patches_size, + yaw[i, ii, jj], + ), + dim=-1, + ).to(pos.device) + return ROI, index + elif mode == "last": + num = torch.arange(0, mask.shape[2]).view(1, 1, -1).to(mask.device) + nummask = num * mask + last_idx, _ = torch.max(nummask, dim=2) + bs = pos.shape[0] + scene_yaw = scene_yaw.type(torch.float) + s = torch.sin(scene_yaw).reshape(-1, 1, 1, 1) + c = torch.cos(scene_yaw).reshape(-1, 1, 1, 1) + rotM = torch.cat( + (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 + ) + world_xy = ((pos.unsqueeze(-2)) @ (rotM.transpose(-1, -2))).squeeze(-2) + world_xy += centroid.view(-1, 1, 1, 2).type(torch.float) + Mat = raster_from_world.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(world_xy, Mat) + agent_mask = mask.any(dim=2) + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii = torch.where(agent_mask[i])[0] + index[i] = ii + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, last_idx[i, ii]], + patches_size, + yaw[i, ii, last_idx[i, ii]], + ), + dim=-1, + ) + return ROI, index + else: + raise ValueError("mode must be 'all' or 'last'") + + +def generate_ROIs( + pos, + yaw, + raster_from_agent, + mask, + patch_size, + mode="last", +): + """ + This version generates ROI for all agents only at most recent time step unless specified otherwise + """ + if mode == "all": + bs = pos.shape[0] + yaw = yaw.type(torch.float) + Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(pos, Mat) + raster_mult = torch.linalg.norm( + raster_from_agent[0, 0, 0:2], dim=[-1]).item() + patch_size = patch_size.type(torch.float) + patch_size *= raster_mult + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii, jj = torch.where(mask[i]) + index[i] = (ii, jj) + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, jj], + patches_size, + yaw[i, ii, jj], + ), + dim=-1, + ).to(pos.device) + return ROI, index + elif mode == "last": + num = torch.arange(0, mask.shape[2]).view(1, 1, -1).to(mask.device) + nummask = num * mask + last_idx, _ = torch.max(nummask, dim=2) + bs = pos.shape[0] + Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) + raster_xy = batch_nd_transform_points(pos, Mat) + raster_mult = torch.linalg.norm( 
+ raster_from_agent[0, 0, 0:2], dim=[-1]).item() + patch_size = patch_size.type(torch.float) + patch_size *= raster_mult + agent_mask = mask.any(dim=2) + ROI = [None] * bs + index = [None] * bs + for i in range(bs): + ii = torch.where(agent_mask[i])[0] + index[i] = ii + if patch_size.ndim == 1: + patches_size = patch_size.repeat(ii.shape[0], 1) + else: + sizes = patch_size[i, ii] + patches_size = torch.cat( + ( + sizes[:, 0:1] * 0.5, + sizes[:, 0:1] * 0.5, + sizes[:, 1:2] * 0.5, + sizes[:, 1:2] * 0.5, + ), + dim=-1, + ) + ROI[i] = torch.cat( + ( + raster_xy[i, ii, last_idx[i, ii]], + patches_size, + yaw[i, ii, last_idx[i, ii]], + ), + dim=-1, + ) + return ROI, index + else: + raise ValueError("mode must be 'all' or 'last'") + + +def Indexing_ROI_result(CNN_out, index, emb_size): + """put the lists of ROI align result into embedding tensor with the help of index""" + bs = len(CNN_out) + map_emb = torch.zeros(emb_size).to(CNN_out[0].device) + if map_emb.ndim == 3: + for i in range(bs): + map_emb[i, index[i]] = CNN_out[i] + elif map_emb.ndim == 4: + for i in range(bs): + ii, jj = index[i] + map_emb[i, ii, jj] = CNN_out[i] + else: + raise ValueError("wrong dimension for the map embedding!") + + return map_emb + + +def rasterized_ROI_align( + lane_mask, pos, yaw, raster_from_agent, mask, patch_size, out_dim +): + if pos.ndim == 4: + ROI, index = generate_ROIs( + pos, + yaw, + raster_from_agent, + mask, + patch_size.type(torch.float), + mode="all", + ) + lane_flags = ROI_align(lane_mask.unsqueeze(1), ROI, out_dim) + lane_flags = [x.mean([-2, -1]).view(x.size(0), 1) for x in lane_flags] + lane_flags = Indexing_ROI_result( + lane_flags, index, [*pos.shape[:3], 1]) + elif pos.ndim == 5: + lane_flags = list() + emb_size = (*pos[:, 0].shape[:-1], 1) + for i in range(pos.size(1)): + ROI, index = generate_ROIs( + pos[:, i], + yaw[:, i], + raster_from_agent, + mask, + patch_size.type(torch.float), + mode="all", + ) + lane_flag_i = ROI_align(lane_mask.unsqueeze(1), ROI, out_dim) + lane_flag_i = [x.mean([-2, -1]).view(x.size(0), 1) + for x in lane_flag_i] + lane_flags.append(Indexing_ROI_result( + lane_flag_i, index, emb_size)) + lane_flags = torch.stack(lane_flags, dim=1) + else: + raise ValueError("wrong shape") + return lane_flags + + +def obtain_map_enc( + image, + map_encoder, + pos, + yaw, + raster_from_agent, + mask, + patch_size, + output_size, + mode, +): + ROI, index = generate_ROIs( + pos, + yaw, + raster_from_agent, + mask, + patch_size, + mode, + ) + CNN_out = map_encoder(image, ROI) + if mode == "all": + emb_size = (*pos.shape[:-1], output_size) + elif mode == "last": + emb_size = (*pos.shape[:-2], output_size) + + # put the CNN output in the right location of the embedding + map_emb = Indexing_ROI_result(CNN_out, index, emb_size) + return map_emb + + +if __name__ == "__main__": + import numpy as np + from torchvision.ops.roi_align import RoIAlign + + + device = torch.device("cuda") + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + + # create feature layer, proposals and targets + num_proposals = 10 + + bs = 1 + features = torch.randn(bs, 10, 32, 32) + + xy = torch.rand((bs, 5, 2)) * torch.tensor([32, 32]) + WH = torch.ones((bs, 5, 1)) * torch.tensor([1, 1, 1, 1]).view(1, 1, -1) + psi = torch.zeros(bs, 5, 1) + ROI = torch.cat((xy, WH, psi), dim=-1) + ROI = [ROI[i] for i in range(ROI.shape[0])] + res1 = ROI_align(features, ROI, 6)[0].transpose(0, 1) + + ROI_star = torch.cat( + (xy - WH[..., [0, 2]], xy + WH[..., [1, 3]]), dim=-1)[0] + + roi_align_obj = RoIAlign(6, 1, 
sampling_ratio=2, aligned=False) + res2 = roi_align_obj(features, [ROI_star]) + + res1 - res2 diff --git a/diffstack/models/layers.py b/diffstack/models/layers.py new file mode 100644 index 0000000..848dff5 --- /dev/null +++ b/diffstack/models/layers.py @@ -0,0 +1,782 @@ +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + +"""Commonly used network layers and functions.""" + +import logging +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +logger = logging.getLogger(__name__) + + +def _pair(x): + if hasattr(x, "__iter__"): + return x + return (x, x) + + +def _triple(x): + if hasattr(x, "__iter__"): + return x + return (x, x, x) + + +def init_last_conv_layer(module, b_prior, w_std=0.01): + """Initializes parameters of a convolutional layer. + + Uses normal distribution for the kernel weights and constant + computed using bias_prior for the prior. + See Focal Loss paper for more details. + """ + for m in module.modules(): + if isinstance(m, (nn.Conv2d, nn.Conv3d)): + nn.init.normal_(m.weight, mean=0, std=w_std) + nn.init.constant_(m.bias, -np.log((1.0 - b_prior) / b_prior)) + break + + +def soft_threshold(x, thresh, steepness=1024): + """Returns a soft-thresholded version of the tensor {x}. + + Elements less than {thresh} are set to zero, + elements appreciably larger than {thresh} are unchanged, + and elements approaching {thresh} from the positive side + steeply ramp to zero (faster for higher {steepness}). + + Behavior is very similar to th.threshold(x, thresh, 0), + but soft_threshold is composed of operations that are supported + by both the PyTorch 1.4 ONNX exporter and the TRT 6 ONNX parser. + + For comparison, th.threshold with nonzero thresh / value + cannot be exported by PyTorch 1.4, and other formulas + (like x * (x > thresh).float()) use operations + which are not supported by the TensorRT 6 importer. + """ + return x * ((x - thresh) * steepness).clamp(0, 1) + + +# noqa pylint: disable=R0911 +def Activation(act="elu"): + """Create activation function.""" + if act == "elu": + return nn.ELU(inplace=True) + if act == "smooth_elu": + return SmoothELU() + if act == "sigmoid": + return nn.Sigmoid() + if act == "tanh": + return nn.Tanh() + if act == "relu": + return nn.ReLU(inplace=True) + if act == "lrelu": + return nn.LeakyReLU(0.1, inplace=True) + if act is not None: + raise ValueError("Activation is not supported: {}.".format(act)) + return nn.Sequential() + + +def Normalization2D(norm, num_features): + """Create 2D layer normalization (4D inputs).""" + if norm == "bn": + return nn.BatchNorm2d(num_features) + if norm is not None: + raise ValueError("Normalization is not supported: {}.".format(norm)) + return nn.Sequential() + + +def Normalization3D(norm, num_features): + """Create 3D layer normalization (5D inputs).""" + if norm == "bn": + return nn.BatchNorm3d(num_features) + if norm is not None: + raise ValueError("Normalization is not supported: {}.".format(norm)) + return nn.Sequential() + + +def Upsample2D( + mode, in_channels, out_channels, kernel_size, stride, padding, output_padding +): + """Create upsampling layer for 4D input (2D convolution). + + Currently 3 interpolation modes are suported: interp, interp2, and deconv. + interp mode uses nearest neighbor interpolation + convolution. + interp2 mode also uses nearest neighbor interpolation + convolution, + but with a custom implementation that performs better in TRT 5 / 6. + deconv mode uses transposed convolution. 
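+    With stride == 1 the upsampling step is skipped entirely and a plain
+    convolution is applied regardless of the requested mode.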
+ """ + + if stride == 1: + # Skip all upsampling + mode = "interp" + + if mode == "interp": + return UpsampleConv2D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + use_pth_interp=True, + ) + if mode == "interp2": + return UpsampleConv2D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + use_pth_interp=False, + ) + if mode == "deconv": + layer = nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=stride + 2 * (stride // 2), + stride=stride, + padding=stride // 2, + output_padding=0, + ) + + # fix initialization to not have checkerboarding + with torch.no_grad(): + layer.weight.data[:] = layer.weight.data[:, :, :1, :1] * 0.5 + layer.weight.data[:, :, 1:-1, 1:-1] *= 2 + return layer + raise ValueError("Mode is not supported: {}".format(mode)) + + +def Upsample3D( + mode, in_channels, out_channels, kernel_size, stride, padding, output_padding +): + """Create upsampling layer for 5D input (3D convolution). + + Currently 2 interpolation modes are suported: interp and deconv. + interp mode uses nearest neighbor interpolation + convolution. + deconv mode uses transposed convolution. + """ + + if mode == "interp": + return UpsampleConv3D( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + if mode == "deconv": + return nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + ) + + raise ValueError("Mode is not supported: {}".format(mode)) + + +class SmoothELU(nn.Module): + """Smooth version of ELU-like activation function. + + ELU derivative is continuous but not smooth. + See Improved Training of WGANs paper for more details. + """ + + def forward(self, x): + """Forward pass.""" + return F.softplus(2.0 * x + 2.0) / 2.0 - 1.0 + + +class UpsampleConv2D(nn.Module): + """Upsampling that uses nearest neighbor + 2D convolution.""" + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + use_pth_interp=True, + ): + """Creates the upsampler. + + use_pth_interp forces using PyTorch interpolation (torch.nn.functional.interpolate) + when applicable, rather than custom interpolation implementation. + """ + super(UpsampleConv2D, self).__init__() + + stride = _pair(stride) + if len(stride) != 2: + raise ValueError( + "Stride must be either int or 2-tuple but got {}".format(stride) + ) + if stride[0] != stride[1]: + raise ValueError("H and W strides must be equal but got {}".format(stride)) + + self._scale_factor = stride[0] + self._use_pth_interp = use_pth_interp + + self.interp = ( + self.interpolation(self._scale_factor) if self._scale_factor > 1 else None + ) + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, # noqa: disable=E221 + stride=1, + padding=padding, + ) + + def interpolation(self, scale_factor=2, mode="nearest"): + """Returns interpolation module.""" + if self._use_pth_interp: + return nn.Upsample(scale_factor=scale_factor, mode=mode) + + return self._int_upsample + + def _int_upsample(self, x): + """Alternative implementation of nearest-neighbor interpolation. + + The main motivation is suboptimal performance of TRT NN interpolation (as of TRT 5/6) + for certain tensor dimensions. The implementation below is about 50% faster + than the current TRT implementation on V100/FP16. 
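+        Only a scale factor of 2 is supported by this path (enforced by an
+        assert below).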
+ """ + assert x.dim() == 4, "Expected NCHW tensor but got {}.".format(x.size()) + # in TRT have to limit to a basic case of 2x upsample. + assert ( + self._scale_factor == 2 + ), "Only scale factor == 2 is currently supported" " but got {}.".format( + self._scale_factor + ) + + n, c, h, w = [int(d) for d in x.size()] + # 1. Upsample in W dim. + # x = x.view(-1, 1).repeat(1, self._scale_factor) + x = x.reshape(n, c, -1, 1) + # Note: when using interp2, Alfred TRT 6.2 ONNX parser requires 4D tensors. + x = torch.cat((x, x), dim=-1) + # 2. Upsample in H dim. + x = x.reshape(n, c, h, w * 2) + # y = x.repeat(1, 1, 1, self._scale_factor) + y = torch.cat((x, x), dim=3) + y = y.reshape(n, c, h * 2, w * 2) + return y + + def forward(self, x): + """Forward pass.""" + res = self.interp(x) if self.interp is not None else x + return self.conv(res) + + +class UpsampleConv3D(nn.Module): + """Upsampling that uses nearest neighbor + 3D convolution.""" + + def __init__( + self, in_channels, out_channels, kernel_size, stride, padding, output_padding + ): + """Creates the upsampler.""" + super(UpsampleConv3D, self).__init__() + + stride = _pair(stride) + if len(stride) != 3: + raise ValueError( + "Stride must be either int or 3-tuple but got {}".format(stride) + ) + if stride[0] != 1: + raise ValueError( + "Upsampling in D dimension is not supported ({}).".format(stride) + ) + if stride[1] != stride[2]: + raise ValueError("H and W strides must be equal but got {}".format(stride)) + + self.interp = self.interpolation(stride[1]) if stride[1] > 1 else None + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size=kernel_size, # noqa: disable=E221 + stride=1, + padding=padding, + ) + + def interpolation(self, scale_factor=2, mode="nearest"): + """Returns interpolation module.""" + return nn.Upsample(scale_factor=scale_factor, mode=mode) + + def forward(self, x): + """Forward pass.""" + res = x + if self.interp is not None: + res = self.interp( + x.view(x.size(0), x.size(1) * x.size(2), x.size(3), x.size(4)) + ) + res = res.view( + res.size(0), x.size(1), x.size(2), res.size(-2), res.size(-1) + ) + return self.conv(res) + + +class ASPPBlock(nn.Module): + """Adaptive spatial pyramid pooling block.""" + + def __init__( + self, + in_channels, + out_channels, + k_hw=3, + norm=None, + act="relu", + dilation_rates=(1, 6, 12), + act_before_combine=False, + act_after_combine=True, + combine_type="add", + combine_conv_k_hw=(3,), + ): + """Creates the ASPP block.""" + super().__init__() + + # Layers with varying dilation rate + self._aspp_layers = nn.ModuleList() + for rate in dilation_rates: + self._aspp_layers.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=k_hw, + padding=rate * (k_hw // 2), + dilation=(rate, rate), + ), + ) + + # Optional norm / activation before combining + if act_before_combine: + self._before_combine = nn.Sequential( + Normalization2D(norm, out_channels), + Activation(act), + ) + else: + self._before_combine = nn.Sequential() + + self._combine_type = combine_type + + # Layers to apply after combining + post_combine_layers = [] + if combine_type == "concat": + assert len(combine_conv_k_hw) > 0 + post_combine_layers.extend( + [ + nn.Conv2d( + out_channels * len(dilation_rates), + out_channels, + kernel_size=combine_conv_k_hw[0], + padding=combine_conv_k_hw[0] // 2, + ), + Normalization2D(norm, out_channels), + Activation(act), + ] + ) + combine_conv_k_hw = combine_conv_k_hw[1:] + + for c_c_k_hw in combine_conv_k_hw: + post_combine_layers.extend( + [ + nn.Conv2d( + 
out_channels, + out_channels, + kernel_size=c_c_k_hw, + padding=c_c_k_hw // 2, + ), + Normalization2D(norm, out_channels), + Activation(act), + ] + ) + if not act_after_combine and post_combine_layers: + # Remove the final activation + post_combine_layers = post_combine_layers[:-1] + + self._final_layer = nn.Sequential(*post_combine_layers) + + def _combine_op(self, xs): + if self._combine_type == "concat": + res = torch.cat(xs, -3) + elif self._combine_type == "add": + res = xs[0] + for x in xs[1:]: + res = res + x + else: + raise ValueError( + "Combine type {} is not supported.".format(self._combine_type) + ) + return res + + def forward(self, x): + """Forward pass.""" + aspp_outs = [] + for layer in self._aspp_layers: + aspp_outs.append(self._before_combine(layer(x))) + return self._final_layer(self._combine_op(aspp_outs)) + + +class ResnetBlock2D(nn.Module): + """Residual block with 2D convolutions. + + Supports both basic and bottleneck configurations. + """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + s_hw=1, + bottleneck=True, + norm="bn", + act="relu", + ): + """Creates the residual block.""" + super(ResnetBlock2D, self).__init__() + + if k_hw < 3 or (k_hw % 2) == 0: + raise ValueError( + "ResnetBlock2D requires kernel size " + "to be odd and >= 3 but got {}".format(k_hw) + ) + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 + if bottleneck: + layers += [ + # 1x1 squeeze. + nn.Conv2d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, dim), + Activation(act), + # 3x3 (or kxk) + nn.Conv2d(dim, dim, kernel_size=k_hw, stride=s_hw, padding=p_hw), + Normalization2D(norm, dim), + Activation(act), + # 1x1 expand. + nn.Conv2d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + nn.Conv2d( + in_channels, dim, kernel_size=k_hw, stride=s_hw, padding=p_hw + ), + Normalization2D(norm, dim), + Activation(act), + # Second 3x3 (or kxk). + nn.Conv2d(dim, out_channels, kernel_size=k_hw, stride=1, padding=p_hw), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1: + self.shortcut = nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=s_hw, padding=0 + ), + Normalization2D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) + + +class ResnetBlockTran2D(nn.Module): + """Transposed residual block with 2D convolutions. + + Supports both basic and bottleneck configurations. + """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + s_hw=1, + bottleneck=True, + norm="bn", + act="relu", + upsample="interp", + ): + """Creates the transposed residual block.""" + super(ResnetBlockTran2D, self).__init__() + + if k_hw < 3 or (k_hw % 2) == 0: + raise ValueError( + "ResnetBlockTran2D requires kernel size " + "to be odd and >= 3 but got {}".format(k_hw) + ) + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 + o_p_hw = p_hw if s_hw > 1 else 0 + if bottleneck: + layers += [ + # 1x1 squeeze. 
+ nn.Conv2d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, dim), + Activation(act), + # 3x3 (or kxk) + Upsample2D( + upsample, + dim, + dim, + kernel_size=k_hw, + stride=s_hw, + padding=p_hw, + output_padding=o_p_hw, + ), + Normalization2D(norm, dim), + Activation(act), + # 1x1 expand. + nn.Conv2d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + Upsample2D( + upsample, + in_channels, + dim, + kernel_size=k_hw, + stride=s_hw, + padding=p_hw, + output_padding=o_p_hw, + ), + Normalization2D(norm, dim), + Activation(act), + # Second 3x3 (or kxk). + nn.Conv2d(dim, out_channels, kernel_size=k_hw, stride=1, padding=p_hw), + Normalization2D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1: + self.shortcut = nn.Sequential( + Upsample2D( + upsample, + in_channels, + out_channels, + kernel_size=1, + stride=s_hw, + padding=0, + output_padding=o_p_hw, + ), + Normalization2D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) + + +class ResnetBlock3D(nn.Module): + """Residual block with 3D convolutions. + + Supports both basic and bottleneck configurations. + """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + k_d=1, + s_hw=1, + s_d=1, + bottleneck=True, + norm="bn", + act="relu", + ): + """Creates the residual block.""" + super(ResnetBlock3D, self).__init__() + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 # noqa: disable=E221 + if bottleneck: + layers += [ + # 1x1 squeeze. + nn.Conv3d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, dim), + self.act, + # 3x3 (or kxk) + nn.Conv3d( + dim, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, dim), + self.act, + # 1x1 expand. + nn.Conv3d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + nn.Conv3d( + in_channels, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, dim), + self.act, + # Second 3x3 (or kxk). + nn.Conv3d( + dim, + out_channels, + kernel_size=(k_d, k_hw, k_hw), + stride=1, + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1 or s_d > 1: + self.shortcut = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=1, + stride=(s_d, s_hw, s_hw), + padding=0, + ), + Normalization3D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) # noqa: disable=E221 + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) + + +class ResnetBlockTran3D(nn.Module): + """Transposed residual block with 3D convolutions. + + Supports both basic and bottleneck configurations. 
+ """ + + def __init__( + self, + in_channels, + dim, + out_channels, + k_hw=3, + k_d=1, + s_hw=1, + s_d=1, + bottleneck=True, + norm="bn", + act="relu", + upsample="interp", + ): + """Creates the transposed residual block.""" + super(ResnetBlockTran3D, self).__init__() + + self.act = Activation(act) + layers = [] + p_hw = k_hw // 2 # noqa: disable=E221 + o_p_hw = p_hw if s_hw > 1 else 0 + if bottleneck: + layers += [ + # 1x1 squeeze. + nn.Conv3d(in_channels, dim, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, dim), + self.act, + # 3x3 (or kxk) + Upsample3D( + upsample, + dim, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + output_padding=(0, o_p_hw, o_p_hw), + ), + Normalization3D(norm, dim), + self.act, + # 1x1 expand. + nn.Conv3d(dim, out_channels, kernel_size=1, stride=1, padding=0), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + else: + layers += [ + # First 3x3 (or kxk). + Upsample3D( + upsample, + in_channels, + dim, + kernel_size=(k_d, k_hw, k_hw), + stride=(s_d, s_hw, s_hw), + padding=(0, p_hw, p_hw), + output_padding=(0, o_p_hw, o_p_hw), + ), + Normalization3D(norm, dim), + self.act, + # Second 3x3 (or kxk). + nn.Conv3d( + dim, + out_channels, + kernel_size=(k_d, k_hw, k_hw), + stride=1, + padding=(0, p_hw, p_hw), + ), + Normalization3D(norm, out_channels), + ] + self.block = nn.Sequential(*layers) + + self.shortcut = None + if in_channels != out_channels or s_hw > 1 or s_d > 1 or k_d > 1: + self.shortcut = nn.Sequential( + Upsample3D( + upsample, + in_channels, + out_channels, + kernel_size=(k_d, 1, 1), + stride=(s_d, s_hw, s_hw), + padding=(0, 0, 0), + output_padding=(0, o_p_hw, o_p_hw), + ), + Normalization3D(norm, out_channels), + ) + + def forward(self, x): + """Forward pass.""" + res = self.block(x) # noqa: disable=E221 + res += x if self.shortcut is None else self.shortcut(x) + return self.act(res) diff --git a/diffstack/models/learned_metrics.py b/diffstack/models/learned_metrics.py new file mode 100644 index 0000000..079ae89 --- /dev/null +++ b/diffstack/models/learned_metrics.py @@ -0,0 +1,85 @@ +from typing import Dict + +import torch +import torch.nn as nn + +import diffstack.models.base_models as base_models +import diffstack.utils.tensor_utils as TensorUtils + + +class PermuteEBM(nn.Module): + """Raster-based model for planning. 
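+
+    Embeds a rasterized map and a trajectory separately, scores every
+    (map, trajectory) pairing within a batch, and is trained with an
+    InfoNCE-style contrastive loss where the matching pair is the positive.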
+ """ + + def __init__( + self, + model_arch: str, + input_image_shape, + map_feature_dim: int, + traj_feature_dim: int, + embedding_dim: int, + embed_layer_dims: tuple + ) -> None: + + super().__init__() + self.map_encoder = base_models.RasterizedMapEncoder( + model_arch=model_arch, + input_image_shape=input_image_shape, + feature_dim=map_feature_dim, + use_spatial_softmax=False, + output_activation=nn.ReLU + ) + self.traj_encoder = base_models.RNNTrajectoryEncoder( + trajectory_dim=3, + rnn_hidden_size=100, + feature_dim=traj_feature_dim + ) + self.embed_net = base_models.MLP( + input_dim=traj_feature_dim + map_feature_dim, + output_dim=embedding_dim, + output_activation=nn.ReLU, + layer_dims=embed_layer_dims + ) + self.score_net = nn.Linear(embedding_dim, 1) + + def forward(self, data_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + image_batch = data_batch["image"] + trajs = torch.cat((data_batch["target_positions"], data_batch["target_yaws"]), dim=2) + bs = image_batch.shape[0] + + map_feat = self.map_encoder(image_batch) # [B, D_m] + traj_feat = self.traj_encoder(trajs) # [B, D_t] + + # construct contrastive samples + map_feat_rep = TensorUtils.unsqueeze_expand_at(map_feat, size=bs, dim=1) # [B, B, D_m] + traj_feat_rep = TensorUtils.unsqueeze_expand_at(traj_feat, size=bs, dim=0) # [B, B, D_t] + cat_rep = torch.cat((map_feat_rep, traj_feat_rep), dim=-1) # [B, B, D_m + D_t] + ebm_rep = TensorUtils.time_distributed(cat_rep, self.embed_net) # [B, B, D] + + # calculate embeddings and scores for InfoNCE loss + scores = TensorUtils.time_distributed(ebm_rep, self.score_net).squeeze(-1) # [B, B] + out_dict = dict(features=ebm_rep, scores=scores) + + return out_dict + + def get_scores(self, data_batch): + image_batch = data_batch["image"] + trajs = torch.cat((data_batch["target_positions"], data_batch["target_yaws"]), dim=2) + + map_feat = self.map_encoder(image_batch) # [B, D_m] + traj_feat = self.traj_encoder(trajs) # [B, D_t] + cat_rep = torch.cat((map_feat, traj_feat), dim=-1) # [B, D_m + D_t] + ebm_rep = self.embed_net(cat_rep) + scores = self.score_net(ebm_rep) + out_dict = dict(features=ebm_rep, scores=scores) + + return out_dict + + def compute_losses(self, pred_batch, data_batch): + scores = pred_batch["scores"] + bs = scores.shape[0] + labels = torch.arange(bs).to(scores.device) + loss = nn.CrossEntropyLoss()(scores, labels) + losses = dict(infoNCE_loss=loss) + + return losses \ No newline at end of file diff --git a/diffstack/models/unet.py b/diffstack/models/unet.py new file mode 100644 index 0000000..2ef73fc --- /dev/null +++ b/diffstack/models/unet.py @@ -0,0 +1,512 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +# adapted from predictionnet by Nvidia +from diffstack.models.layers import ( + Activation, + Normalization2D, + Upsample2D) +class UResBlock(nn.Module): + """UNet ResBlock. + + Uses dense 3x3 convs only, single resize *after* skip, + and other graph tweaks specific for UNet. 
+ """ + + def __init__(self, in_channels, out_channels, stride, upsample=None, + norm='bn', act='relu', out_act='relu'): + """Construct UNet ResBlock.""" + super().__init__() + self.act = Activation(out_act) + self.block = nn.Sequential( + nn.Conv2d(in_channels, in_channels, 3, padding=1), + Normalization2D(norm, in_channels), + Activation(act), + nn.Conv2d(in_channels, in_channels, 3, padding=1), + Activation(act), + nn.Conv2d(in_channels, in_channels, 3, padding=1), + ) + self.resize = nn.Conv2d(in_channels, out_channels, 3, padding=1) + if stride > 1: + if upsample is not None: + self.resize = Upsample2D( + upsample, + in_channels, + out_channels, + stride * 2 - 1, + stride=stride, + padding=stride // 2, + output_padding=0 + ) + else: + self.resize = nn.Conv2d( + in_channels, + out_channels, + stride * 2 - 1, + stride=stride, + padding=stride // 2 + ) + + def forward(self, x): + """Forward pass.""" + x = self.block(x) + x + return self.act(self.resize(x)) + +class UNet(nn.Module): + """UNet model.""" + + def __init__(self, in_frames, out_frames, in_channels, out_channels, channels, strides, + decoder_strides=None, skip_type='add', upsample='interp', norm='bn', + activation='relu', last_layer_prior=0.0001, last_layer_mvec_bias=0.0, + last_layer_vel_bias=0.0, enc_out_dropout=0.0, dec_in_dropout=0.0, + trajectory_candidates=0, enable_nll=False, desired_size=None, **kwargs): + """Initialize the model. + + Args: + in_frames: number of input frames. + out_frames: number of output frames. + in_channels: list of [static input channels, dynamic input channels]. + out_channels: list of [output 0 channels, ..., output N channels] for raster outputs. + channels: number of channels in convolutional layers specified with a list. + strides: encoder's convolutional strides specified with a list. + decoder_strides: decoder's convolutional strides specified with a list. + channels, strides, and decoder strides are specified with the same sized list. + skip_type: connection between corresponding encoder/decoder blocks. + Can be add (default), none. + upsample: upsampling type in decoder. Passed to mode parameter of layers.Upsample2D. + norm: normalization type. Can be batch norm (bn) or none. + activation: activation function for all layers. See layers.Activation. + last_layer_prior: last layer bias initialization prior to work properly with + focal loss. + last_layer_mvec_bias: last layer motion vector output initial bias. + last_layer_vel_bias: last layer velocity output initial bias. + enc_out_dropout: encoder output dropout. + dec_in_dropout: decoder input dropout. + trajectory_candidates: number of trajectories to regress directly. + enable_nll: regress output for nll loss if true + desired_size: desired size for the input + """ + super(UNet, self).__init__() + self._in_frames = in_frames # noqa: disable=E221 + self._out_frames = out_frames + self._out_channels = out_channels + if isinstance(out_channels, int): + self._out_channels = [out_channels] + self._skip_type = self._get_str_arg(skip_type, 'skip_type', ['add', 'none']) + self._norm = self._get_str_arg(norm, 'norm', ['bn', 'none']) + self._act = self._get_str_arg(activation, 'activation', ['lrelu', 'relu', 'elu', + 'smooth_elu', 'sigmoid', + 'tanh', 'none']) + self._channels = list(channels) + self._strides = list(strides) # noqa: disable=E221 + if decoder_strides is None: + # symmetrical case, use enc. 
strides + self._dec_strides = list(reversed(self._strides)) + else: + self._dec_strides = list(decoder_strides) + if len(self._channels) != len(self._strides): + raise ValueError('Number of channels {} must be equal to number of ' + 'strides {}'.format(len(self._channels), len(self._strides))) + self._num_resnet_blocks = len(self._channels) - 1 + static_in_channels, dynamic_in_channels = in_channels + self._in_channels = dynamic_in_channels * in_frames + static_in_channels + # Encoder. + self._e0 = nn.Sequential( + nn.Conv2d(self._in_channels, self._channels[0], kernel_size=3, + stride=self._strides[0], padding=1), + Activation(self._act) + ) + for i in range(1, self._num_resnet_blocks + 1): + block = UResBlock( + self._channels[i - 1], + self._channels[i], + stride=self._strides[i], + upsample=None, + norm=self._norm, + act=self._act, + out_act=self._act + ) + setattr(self, '_e{}'.format(i), block) + self._enc_drop = nn.Dropout2d(p=enc_out_dropout) + self._dec_drop = nn.Dropout2d(p=dec_in_dropout) + output_per_frame = 2 # acce, yaw rate + if enable_nll: + output_per_frame += 3 # log_stdx, log_stdy, rho + total_out_channels = ( + sum(self._out_channels) * out_frames + # BEV Image outputs + trajectory_candidates * out_frames * output_per_frame # Trajectory outputs + ) + # Decoder. + min_last_dec_channel = 64 + for i in range(1, self._num_resnet_blocks + 1): + dec_in_channels_i = self._channels[i] + dec_out_channels_i = self._channels[i - 1] + if i == 1: + # Last decoder channel doesn't need to match encoder because there's no skip + # connection after + dec_out_channels_i = max(dec_out_channels_i, min_last_dec_channel) + block = UResBlock( + dec_in_channels_i, + dec_out_channels_i, + stride=self._dec_strides[self._num_resnet_blocks - i], + upsample=upsample, + norm=self._norm, + act=self._act, + out_act=None + ) + setattr(self, '_d{}'.format(i), block) + self._dec_act = Activation(self._act) + self._trajectory_candidates = trajectory_candidates + # Output layers. + # These are combined into one for efficiency; + # each output head just slices the combined result + output_layer = nn.Sequential( + Upsample2D(upsample, max(min_last_dec_channel, self._channels[0]), + total_out_channels, + kernel_size=3, stride=self._dec_strides[-1], + padding=1, output_padding=0) + ) + # Focal loss init for last layer occupancy channel + last_bias_layer = [m for m in output_layer.modules() if hasattr(m, 'bias')][-1] + with torch.no_grad(): + # most channels are zero-initialized + last_bias_layer.bias.data[:] = 0 + # channels corresponding to occupancy get focal loss initialization + occ_bias = -torch.log((torch.ones(1) - last_layer_prior) / last_layer_prior) + last_bias_layer.bias.data[:out_frames] = occ_bias + # Initialize biases for mvecs and velocity outputs. + # If there are 2 or more outputs, mvecs will be in the last 2 channels. + if len(out_channels) > 1: + last_bias_layer.bias.data[-2 * out_frames:] = last_layer_mvec_bias + # If there are 3 or more outputs, velocity will be in the second group of channels. + if len(out_channels) > 2: + last_bias_layer.bias.data[1*out_frames:3 * out_frames] = last_layer_vel_bias + # Final outputs will be slices of _d0_core + setattr(self, '_d0_core', output_layer) + self.desired_size = desired_size + + @property + def out_channels(self): + """Returns output channels configuration.""" + return self._out_channels + + @staticmethod + def _get_str_arg(val, name, allowed_values): + if val not in allowed_values: + raise ValueError('{} has invalid value {}. 
Supported values: {}'.format( + name, val, ', '.join(allowed_values))) + if val == 'none': + val = None + return val + + def _skip_op(self, enc, dec): + if self._skip_type == 'add': + res = enc + dec + elif self._skip_type is None: + res = dec + else: + raise ValueError('Skip type {} is not supported.'.format(self._skip_type)) + return res + + def _get_skip_key_for_tensor(self, x): + """Get a key for determining valid skip connections. + + Args: + x: torch Tensor of NCHW data to be used in skip connection. + Returns a key k(x) such that, if k(x) == k(y), then + self._skip_op(x, y) is a valid skip connection. + This is useful when the encoder / decoder are not using matching + block counts / filter counts / strides, and symmetric UNet + skip connections based only on the block index won't work. + Example usage for a simple case: + skips = {} + # during encoder(x) + skips[self._get_skip_key_for_tensor(x)] = x + # ... + # during decoder(y) + y_key = self._get_skip_key_for_tensor(y) + if y_key in skips: + y = self._skip_op(skips.pop(y_key), y) # pop() to prevent reuse + """ + assert x.ndim == 4, f'Invalid {x.shape}' + # N is assumed to always match, + # and H is assumed to match if W does + # (also - broadcasting between H / TH is allowed), + # so create a key just based on C and W. + return (int(x.size(1)), int(x.size(-1))) + + def _apply_encoder(self, x, s): + """Run the encoder tower on a batch of inputs. + + Args: + x: N[CT]HW tensor, representing N episodes of T timesteps of + C channels at HW spatial locations (aka dynamic context). + s: NCHW tensor, representing N episodes of C channels + at HW spatial locations (aka static context). + Returns a tuple ( + a N[CT]HW tensor, + containing final output features from the encoder, + a dictionary of {size: NCHW tensor}, + containing intermediate features from each encoder block; + these are only for the first timestep, and can therefore + be safely used in decoder "skip connections" for all timesteps + without leaking any information backwards in time. + ) + """ + # Check sizes are reasonable + assert ( + x.ndim == 4 + ), f'Invalid x {x.size()}, should be N[CT]HW' + x = torch.cat([s, x], 1) + e_in = self._e0(x) + # Run all encoder blocks + skip_out_dict = {} + # import pdb + # pdb.set_trace() + for i in range(self._num_resnet_blocks): + block = getattr(self, '_e{}'.format(i + 1)) + e_out = block(e_in) + e_in = e_out # noqa: disable=E221 + skip_out_key = self._get_skip_key_for_tensor(e_out) + skip_out_dict[skip_out_key] = skip_out_dict.get(skip_out_key, []) + [e_out] + # Apply dropout on encoder output + if self._enc_drop.p > 0: + e_out = self._enc_drop(e_out) + return e_out, skip_out_dict + + def _apply_decoder(self, x, skip_connections): + """Run the decoder tower on a batch of inputs. + + Args: + x: N[CT]HW tensor, representing N episodes of T timesteps of + C channels at HW spatial locations (aka dynamic context). + skip_connections: a dictionary of {size: NCHW tensor}, + representing output context features from each encoder stage. + Returns an N[CT]HW tensor, + containing decoder output. Output "heads" should + be sliced from the channels of this tensor. 
+ """ + d_in = x + # Apply dropout on decoder input + if self._dec_drop.p > 0: + d_in = self._dec_drop(d_in) + # Run all decoder blocks, applying skip connections + dec_skip_keys = [] + for i in range(self._num_resnet_blocks): + # Apply skip connection + skip_key = self._get_skip_key_for_tensor(d_in) + dec_skip_keys.append(skip_key) + if skip_key in skip_connections and len(skip_connections[skip_key]) > 0: + skip = skip_connections[skip_key].pop() + # do not add the same thing to itself! + if d_in is not skip: + d_in = self._dec_act(self._skip_op(skip, d_in)) + else: + print('failed skip for', skip_key, 'in', skip_connections.keys()) + # Apply block + block = getattr(self, '_d{}'.format(self._num_resnet_blocks - i)) + d_out = block(d_in) + d_in = d_out # noqa: disable=E221 + # Apply shared decoder layer + d_out = self._d0_core(self._dec_act(d_out)) + return d_out + + def is_recurrent(self): + """Check if model is recurrent.""" + return False + + def forward(self, x): + """Forward pass.""" + + # The input contains separate static / dynamic tensors + assert len(x) == 2, f'Invalid x of type {type(x)} length {len(x)}' + s, x = x # Separate static and dynamic parts + assert ( + s.shape[-2:] == x.shape[-2:] + ), f'Invalid static / dynamic shapes: ({s.shape}, {x.shape})' + # In general, 5D tensors are consumed during training, while 4D - during inference. + if x.dim() not in [4, 5]: + raise ValueError(f'Expected 4D(N[CT]HW) or 5D(NCTHW) tensor but got {x.dim()}') + if self.desired_size is not None: + assert s.shape[-2]<=self.desired_size[0] and s.shape[-1]<=self.desired_size[1] + w,h = s.shape[-2:] + pw,ph = self.desired_size[0]-w,self.desired_size[1]-h + pad_w = torch.zeros([*s.shape[:-2],pw,h],dtype=s.dtype,device=s.device) + pad_h = torch.zeros([*s.shape[:-2],self.desired_size[0],ph],dtype=s.dtype,device=s.device) + s = torch.cat((torch.cat((s,pad_w),-2),pad_h),-1) + + pad_w = torch.zeros([*x.shape[:-2],pw,h],dtype=x.dtype,device=x.device) + pad_h = torch.zeros([*x.shape[:-2],self.desired_size[0],ph],dtype=x.dtype,device=x.device) + x = torch.cat((torch.cat((x,pad_w),-2),pad_h),-1) + + + # Save input H & W resolution + + dim_h_input, dim_w_input = x.size()[-2:] + dim_n = int(x.size(0)) + input_is_5d = x.dim() == 5 + if input_is_5d: + # Transform from NCTHW to N[CT]HW. 
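+            # (folding the T frames into the channel axis lets the 2D conv
+            # encoder see the whole input history in a single pass)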
+ s = s.view(dim_n, -1, dim_h_input, dim_w_input) + x = x.view(dim_n, -1, dim_h_input, dim_w_input) + + # At this point, everything is 4D N[CT]HW + # (regardless of training or export or whatever); + # we'll convert back to NCTHW at the end if input_is_5d is True + assert ( + s.size(1) + x.size(1) == self._in_channels + ), f'Channel counts in input {s.size(), x.size()} do not match model' + + # Apply encoder / decoder + enc_past, enc_past_skip = self._apply_encoder(x, s=s) + res = self._apply_decoder(enc_past, enc_past_skip) + + # During export - don't do any resizing or slicing; + # those steps will be handled by the DW inference code + if torch.onnx.is_in_onnx_export(): + return res + + # During training, we upsample and slice as needed + res = F.interpolate( + res, + size=(dim_h_input, dim_w_input), + mode='bilinear', + align_corners=False + ) + + # Slice outputs for each output head + res_sliced = [] + slice_start = 0 + dim_t_out = self._out_frames + dim_h_out, dim_w_out = (int(d) for d in res.size()[-2:]) + for dim_c_i in self._out_channels: + slice_end = slice_start + dim_c_i * dim_t_out + res_sliced_i = res[:, slice_start:slice_end] + slice_start = slice_end + # Slice is, again, N[CT]HW + assert res_sliced_i.size() == (dim_n, dim_c_i * dim_t_out, dim_h_out, dim_w_out) + if input_is_5d: + # But, we convert the slice back to NCTHW (5d) if needed + res_sliced_i = res_sliced_i.view(dim_n, dim_c_i, dim_t_out, dim_h_out, dim_w_out) + res_sliced.append(res_sliced_i) + + # Optional trajectory regression output + if self._trajectory_candidates > 0: + # fill in Nones for any other outputs + res_sliced.append(res[:, slice_end:]) + return tuple(res_sliced) + + +class DirectRegressionPostProcessor(torch.nn.Module): + """Direct-regression trajectory extractor. + + This provides a differentiable module for converting a dense DNN output + tensor into a trajectories, given initial positions. + """ + + def __init__( + self, + num_in_frames, + image_height_px, + image_width_px, + sampling_time, + trajectory_candidates, + dyn + ): + """Initialize the post processor. + + Args: + num_in_frames: the value of in_frames used for the predictions. + image_height_px: the vertical spatial resolution of tensors to expect + in postprocessor forward() - i.e. output resolution of prediction DNN. + image_width_px: the corresponding horizontal spatial resolution. + frame_duration_s: the time delta (seconds) between successive input / output frames + trajectory_candidates: number of trajectory candidates to extract. + """ + + super().__init__() + self._num_in_frames = num_in_frames + self._image_height_px = image_height_px + self._image_width_px = image_width_px + assert ( + self._image_height_px == self._image_width_px + ), f"Assumed square pixels, but {self._image_height_px} != {self._image_width_px}" + + self._dt_s = sampling_time + assert self._dt_s > 0, f"Invalid frame duration (s): {self._dt_s}" + self._dyn = dyn + self._trajectory_candidates = trajectory_candidates + + def forward(self, obj_trajectory, query_pts, curr_state, actual_size=None): + """Convert a predicted trajectory tensor and initial positions into a list of FrameTrajectories. + + Args: + obj_trajectory: Torch tensor of dimensions [N, C, H, W] representing + predicted trajectories at each location. For K candidates, C = K * (1 + T * 2), + corresponding to 1 overall confidence score and T (ax, ay) acceleration pairs. 
+ query_pts: location of the agents on the image [N,Na,2] + + Returns a length-N list of FrameTrajectories for each episode in the batch. + """ + bs,Na = query_pts.shape[:2] + if actual_size is None: + pos_xy_rel = query_pts.unsqueeze(1)/(obj_trajectory.shape[-1]-1)*2-1 + else: + pos_xy_rel = query_pts.unsqueeze(1)/(actual_size-1)*2-1 + + pred_per_agent = ( + torch.nn.functional.grid_sample( + obj_trajectory, + pos_xy_rel, + mode="bilinear", + align_corners=True, + ) + .squeeze(2) + .transpose(1,2) + ) + + input_pred = pred_per_agent.reshape( + bs*Na, self._trajectory_candidates, -1, 2 + ) # AxTxCandidatesx2 + traj = self._dyn.forward_dynamics(curr_state.reshape(bs*Na,1,-1).repeat_interleave(self._trajectory_candidates,1),input_pred) + pos = self.dyn.state2pos(traj) + yaw = self.dyn.state2yaw(traj) + + return traj,pos,yaw,input_pred + + +def main(): + num_in_frames = 10 + num_out_frames = 10 + model = UNet(in_frames=num_in_frames, out_frames=num_out_frames, in_channels=[2,3], + out_channels=[1], channels=[32, 64, 128, 128, 256, 256, 256], + strides=[2, 2, 2, 2, 2, 2, 2], decoder_strides=[2, 2, 2, 2, 2, 1, 1], + skip_type='add', upsample='interp', norm='bn', + activation='relu', last_layer_prior=0.0001, last_layer_mvec_bias=0.0, + last_layer_vel_bias=0.0, enc_out_dropout=0.0, dec_in_dropout=0.0, + trajectory_candidates=2, enable_nll=False) + + static_input = torch.ones([10,2,255,255]) + dynamic_input = torch.zeros([10,3,10,255,255]) + out = model((static_input,dynamic_input)) + query_pts = torch.rand([10,5,2])*255 + from diffstack.dynamics import Unicycle + dyn = Unicycle(0.1, vbound=[-1, 40.0]) + curr_state=torch.randn([10,5,4]) + + postprocessor = DirectRegressionPostProcessor(num_in_frames, + 255, + 255, + 0.1, + 2, + dyn, + ) + pred_logits, pred_traj = out + traj = postprocessor(pred_traj,query_pts,curr_state) + import pdb + pdb.set_trace() + + + +if __name__=="__main__": + main() \ No newline at end of file diff --git a/diffstack/models/vaes.py b/diffstack/models/vaes.py new file mode 100644 index 0000000..4747a32 --- /dev/null +++ b/diffstack/models/vaes.py @@ -0,0 +1,1348 @@ +"""Variants of Conditional Variational Autoencoder (C-VAE)""" +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffstack.utils.loss_utils import KLD_0_1_loss, KLD_gaussian_loss, KLD_discrete +from diffstack.utils.torch_utils import reparameterize +from diffstack.models.base_models import MLP, SplitMLP +import diffstack.utils.tensor_utils as TensorUtils +from torch.distributions import Categorical, kl_divergence + + +class Prior(nn.Module): + def __init__(self, latent_dim, input_dim=None, device=None): + """ + A generic prior class + Args: + latent_dim: dimension of the latent code (e.g., mu, logvar) + input_dim: (Optional) dimension of the input feature vector, for conditional prior + device: + """ + super(Prior, self).__init__() + self._latent_dim = latent_dim + self._input_dim = input_dim + self._net = None + self._device = device + + def forward(self, inputs: torch.Tensor = None): + """ + Get a batch of prior parameters. + + Args: + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. 
+ + Returns: + params (dict): A dictionary of prior parameters with the same batch size as the inputs (1 if inputs is None) + """ + raise NotImplementedError + + @staticmethod + def get_mean(prior_params): + """ + Extract the "mean" of a prior distribution (not supported by all priors) + + Args: + prior_params (torch.Tensor): a batch of prior parameters + + Returns: + mean (torch.Tensor): the "mean" of the distribution + """ + raise NotImplementedError + + def sample(self, n: int, inputs: torch.Tensor = None): + """ + Take samples with the prior distribution + + Args: + n (int): number of samples to take + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + samples (torch.Tensor): a batch of latent samples with shape [input_batch_size, n, latent_dim] + """ + prior_params = self.forward(inputs=inputs) + return self.sample_with_parameters(prior_params, n=n) + + @staticmethod + def sample_with_parameters(params: dict, n: int): + """ + Take samples using given a batch of distribution parameters, e.g., mean and logvar of a unit Gaussian + + Args: + params (dict): a batch of distribution parameters + n (int): number of samples to take + + Returns: + samples (torch.Tensor): a batch of latent samples with shape [param_batch_size, n, latent_dim] + """ + raise NotImplementedError + + def kl_loss(self, posterior_params: dict, inputs: torch.Tensor = None) -> torch.Tensor: + """ + Compute kl loss between the prior and the posterior distributions. + + Args: + posterior_params (dict): a batch of distribution parameters + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + kl_loss (torch.Tensor): kl divergence value + """ + raise NotImplementedError + + @property + def posterior_param_shapes(self) -> dict: + """ + Shape of the posterior parameters + + Returns: + shapes (dict): a dictionary of parameter shapes + """ + raise NotImplementedError + + @property + def latent_dim(self): + """ + Shape of the latent code + + Returns: + latent_dim (int) + """ + return self._latent_dim + + +class FixedGaussianPrior(Prior): + """An unassuming unit Gaussian Prior""" + def __init__(self, latent_dim, input_dim=None, device=None): + super(FixedGaussianPrior, self).__init__( + latent_dim=latent_dim, input_dim=input_dim, device=device) + self._params = nn.ParameterDict({ + "mu": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=False), + "logvar": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=False) + }) + + @staticmethod + def get_mean(prior_params): + return prior_params["mu"] + + def forward(self, inputs: torch.Tensor = None): + """ + Get a batch of prior parameters. + + Args: + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. 
+ + Returns: + params (dict): A dictionary of prior parameters with the same batch size as the inputs (1 if inputs is None) + """ + + batch_size = 1 if inputs is None else inputs.shape[0] + params = TensorUtils.unsqueeze_expand_at(self._params, size=batch_size, dim=0) + return params + + @staticmethod + def sample_with_parameters(params, n: int): + """ + Take samples using given a batch of distribution parameters, e.g., mean and logvar of a unit Gaussian + + Args: + params (dict): a batch of distribution parameters + n (int): number of samples to take + + Returns: + samples (torch.Tensor): a batch of latent samples with shape [param_batch_size, n, latent_dim] + """ + + batch_size = params["mu"].shape[0] + params_tiled = TensorUtils.repeat_by_expand_at(params, repeats=n, dim=0) + samples = reparameterize(params_tiled["mu"], params_tiled["logvar"]) + samples = TensorUtils.reshape_dimensions(samples, begin_axis=0, end_axis=1, target_dims=(batch_size, n)) + return samples + + def kl_loss(self, posterior_params, inputs=None): + """ + Compute kl loss between the prior and the posterior distributions. + + Args: + posterior_params (dict): a batch of distribution parameters + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + kl_loss (torch.Tensor): kl divergence value + """ + + assert posterior_params["mu"].shape[1] == self._latent_dim + assert posterior_params["logvar"].shape[1] == self._latent_dim + return KLD_0_1_loss( + mu=posterior_params["mu"], + logvar=posterior_params["logvar"] + ) + + @property + def posterior_param_shapes(self) -> OrderedDict: + return OrderedDict(mu=(self._latent_dim,), logvar=(self._latent_dim,)) + + +class ConditionalCategoricalPrior(Prior): + """ + A class that holds functionality for learning categorical priors for use + in VAEs. + """ + def __init__(self, latent_dim, input_dim=None, device=None): + """ + Args: + latent_dim (int): size of latent dimension for the prior + device (torch.Device): where the module should live (i.e. cpu, gpu) + """ + super(ConditionalCategoricalPrior, self).__init__(latent_dim=latent_dim, input_dim=input_dim, device=device) + assert isinstance(input_dim, int) and input_dim > 0 + self.device = device + self._latent_dim = latent_dim + self._prior_net = MLP(input_dim=input_dim, output_dim=latent_dim) + + def sample(self, n: int, inputs: torch.Tensor = None): + """ + Returns a batch of samples from the prior distribution. + Args: + n (int): this argument is used to specify the number + of samples to generate from the prior. + inputs (torch.Tensor): conditioning feature for prior + Returns: + z (torch.Tensor): batch of sampled latent vectors. 
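FixedGaussianPrior delegates to `reparameterize` and `KLD_0_1_loss` from diffstack's utils. The snippet below is a hedged stand-in showing what those two operations conventionally compute (the real implementations may differ in reduction details):

```python
import torch

def reparameterize_sketch(mu, logvar):
    # z = mu + sigma * eps, with eps ~ N(0, I); keeps sampling differentiable
    eps = torch.randn_like(mu)
    return mu + eps * torch.exp(0.5 * logvar)

def kld_to_unit_gaussian(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, I) ), summed over latent dims, averaged over batch
    return 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0).sum(-1).mean()

mu, logvar = torch.zeros(4, 8), torch.zeros(4, 8)
z = reparameterize_sketch(mu, logvar)
print(z.shape, kld_to_unit_gaussian(mu, logvar))   # KL is 0 for mu=0, logvar=0
```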
+ """ + + # check consistency between n and obs_dict + if self.learnable: + + # forward to get parameters + out = self.forward(batch_size=n, obs_dict=obs_dict, goal_dict=goal_dict) + prior_logits = out["logit"] + + # sample one-hot latents from categorical distribution + dist = Categorical(logits=prior_logits) + z = TensorUtils.to_one_hot(dist.sample(), num_class=self.categorical_dim) + + else: + # try to include a categorical sample for each class if possible (ensuring rough uniformity) + if (self.latent_dim == 1) and (self.categorical_dim <= n): + # include samples [0, 1, ..., C - 1] and then repeat until batch is filled + dist_samples = torch.arange(n).remainder(self.categorical_dim).unsqueeze(-1).to(self.device) + else: + # sample one-hot latents from uniform categorical distribution for each latent dimension + probs = torch.ones(n, self.latent_dim, self.categorical_dim).float().to(self.device) + dist_samples = Categorical(probs=probs).sample() + z = TensorUtils.to_one_hot(dist_samples, num_class=self.categorical_dim) + + # reshape [B, D, C] to [B, D * C] to be consistent with other priors that return flat latents + z = z.reshape(*z.shape[:-2], -1) + return z + + def kl_loss(self, posterior_params, inputs = None): + """ + Computes KL divergence loss between the Categorical distribution + given by the unnormalized logits @logits and the prior distribution. + Args: + posterior_params (dict): dictionary with key "logits" corresponding + to torch.Tensor batch of unnormalized logits of shape [B, D * C] + that corresponds to the posterior categorical distribution + Returns: + kl_loss (torch.Tensor): KL divergence loss + """ + prior_logits = self._prior_net(inputs) + + prior_dist = Categorical(logits=prior_logits) + posterior_dist = Categorical(logits=posterior_params["logits"]) + + # sum over latent dimensions, but average over batch dimension + kl_loss = kl_divergence(posterior_dist, prior_dist) + assert len(kl_loss.shape) == 2 + return kl_loss.sum(-1).mean() + + def forward(self, inputs: torch.Tensor = None): + """ + Get a batch of prior parameters. + + Args: + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. + + Returns: + params (dict): A dictionary of prior parameters with the same batch size as the inputs (1 if inputs is None) + """ + prior_logits = self._prior_net(inputs) + return prior_logits + + +class LearnedGaussianPrior(FixedGaussianPrior): + """A Gaussian prior with learnable parameters""" + def __init__(self, latent_dim, input_dim=None, device=None): + super(LearnedGaussianPrior, self).__init__( + latent_dim=latent_dim, input_dim=input_dim, device=device) + self._params = nn.ParameterDict({ + "mu": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=True), + "logvar": nn.Parameter(data=torch.zeros(self._latent_dim), requires_grad=True) + }) + + def kl_loss(self, posterior_params, inputs=None): + """ + Compute kl loss between the prior and the posterior distributions. + + Args: + posterior_params (dict): a batch of distribution parameters + inputs (torch.Tensor): (Optional) A feature vector for priors that are input-conditional. 
+ + Returns: + kl_loss (torch.Tensor): kl divergence value + """ + + assert posterior_params["mu"].shape[1] == self._latent_dim + assert posterior_params["logvar"].shape[1] == self._latent_dim + + batch_size = posterior_params["mu"].shape[0] + prior_params = TensorUtils.unsqueeze_expand_at(self._params, size=batch_size, dim=0) + return KLD_gaussian_loss( + mu_1=posterior_params["mu"], + logvar_1=posterior_params["logvar"], + mu_2=prior_params["mu"], + logvar_2=prior_params["logvar"] + ) + + +class CVAE(nn.Module): + def __init__( + self, + q_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + prior: Prior + ): + """ + A basic Conditional Variational Autoencoder Network (C-VAE) + + Args: + q_net (nn.Module): a model that encodes data (x) and condition inputs (x_c) to posterior (q) parameters + c_net (nn.Module): a model that encodes condition inputs (x_c) into condition feature (c) + decoder (nn.Module): a model that decodes latent (z) and condition feature (c) to data (x') + prior (nn.Module): a model containing information about distribution prior (kl-loss, prior params, etc.) + """ + super(CVAE, self).__init__() + self.q_net = q_net + self.c_net = c_net + self.decoder = decoder + self.prior = prior + + def sample(self, condition_inputs, n: int, condition_feature=None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + n (int): number of samples to draw + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + z = self.prior.sample(n=n, inputs=c) # z of shape [B (from c), N, ...] + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * N, ...] + c_samples = TensorUtils.repeat_by_expand_at(c, repeats=n, dim=0) # [B * N, ...] + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(c.shape[0], n)) + return x_out + + def predict(self, condition_inputs, condition_feature=None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + + prior_params = self.prior(c) # [B, ...] + mu = self.prior.get_mean(prior_params) # [B, ...] 
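Assuming KLD_gaussian_loss implements the standard closed-form KL between diagonal Gaussians, the quantity computed by LearnedGaussianPrior.kl_loss above is:

$$
\mathrm{KL}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big)
= \frac{1}{2}\sum_{i=1}^{d}\left(\log\frac{\sigma_{2,i}^2}{\sigma_{1,i}^2}
+ \frac{\sigma_{1,i}^2 + (\mu_{1,i}-\mu_{2,i})^2}{\sigma_{2,i}^2} - 1\right),
$$

with $\sigma^2 = \exp(\mathrm{logvar})$; the unit-Gaussian case used by FixedGaussianPrior is recovered by setting $\mu_2 = 0$ and $\sigma_2^2 = 1$.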
+ decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=mu, condition_features=c, **decoder_kwargs) + return x_out + + def forward(self, inputs, condition_inputs, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + c = self.c_net(condition_inputs) # [B, ...] + q_params = self.q_net(inputs=inputs, condition_features=c) + z = self.prior.sample_with_parameters(q_params, n=1).squeeze(dim=1) + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=z, condition_features=c, **decoder_kwargs) + return {"x_recons": x_out, "q_params": q_params, "z": z, "c": c} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + return self.prior.kl_loss(outputs["q_params"], inputs=outputs["c"]) + + +class DiscreteCVAE(nn.Module): + def __init__( + self, + q_net: nn.Module, + p_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + K: int, + recon_loss_fun=None, + logpi_clamp = None, + ): + """ + A basic Conditional Variational Autoencoder Network (C-VAE) + + Args: + q_net (nn.Module): a model that encodes data (x) and condition inputs (x_c) to posterior (q) parameters + p_net (nn.Module): a model that encodes condition feature (c) to latent (p) parameters + c_net (nn.Module): a model that encodes condition inputs (x_c) into condition feature (c) + decoder (nn.Module): a model that decodes latent (z) and condition feature (c) to data (x') + K (int): cardinality of the discrete latent + recon_loss: loss function handle for reconstruction loss + logpi_clamp (float): lower bound of the logpis, for numerical stability + """ + super(DiscreteCVAE, self).__init__() + self.q_net = q_net + self.p_net = p_net + self.c_net = c_net + self.decoder = decoder + self.K = K + self.logpi_clamp= logpi_clamp + if recon_loss_fun is None: + self.recon_loss_fun = nn.MSELoss(reduction="none") + else: + self.recon_loss_fun = recon_loss_fun + + def sample(self, condition_inputs, n: int, condition_feature=None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + n (int): number of samples to draw + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + logp = self.p_net(c)["logp"] + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + p = p/p.sum(dim=-1,keepdim=True) + # z = (-logp).argsort()[...,:n] + # z = F.one_hot(z,self.K) + + dis_p = Categorical(probs=p) # [n_sample, batch] -> [batch, n_sample] + z = dis_p.sample((n,)).permute(1, 0) + z = F.one_hot(z, self.K) + + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * N, ...] 
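Taken together, the CVAE's forward pass (posterior sampling plus decoding) and compute_kl_loss realize the usual conditional ELBO; the reconstruction term is supplied by the caller's loss on x_recons:

$$
\mathcal{L}(x, x_c) \;=\; \underbrace{\mathbb{E}_{q(z \mid x, c)}\big[\log p(x \mid z, c)\big]}_{\text{reconstruction}}
\;-\; \beta\,\mathrm{KL}\big(q(z \mid x, c)\,\|\,p(z \mid c)\big),
\qquad c = \mathrm{c\_net}(x_c),
$$

where $\beta$ plays the role of the KL weight (the `gamma` argument in the discrete variants below).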
+ c_samples = TensorUtils.repeat_by_expand_at(c, repeats=n, dim=0) # [B * N, ...] + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(c.shape[0], n)) + x_out["z"] = z_samples + return x_out + + def predict(self, condition_inputs, condition_feature=None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + condition_feature (torch.Tensor): Optional - externally supply condition code (c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + if condition_feature is not None: + c = condition_feature + else: + c = self.c_net(condition_inputs) # [B, ...] + + logp = self.p_net(c)["logp"] + z = logp.argmax(dim=-1) + + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + x_out = self.decoder(latents=F.one_hot(z,self.K), condition_features=c, **decoder_kwargs) + return x_out + + def forward(self, inputs, condition_inputs, n=None, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + if n is None: + n = self.K + c = self.c_net(condition_inputs) # [B, ...] 
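The one-hot latent sampling used by DiscreteCVAE.sample boils down to the following pattern (illustrative sizes; B scenes, K latent classes, n samples per scene):

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

B, K, n = 4, 8, 3
logp = torch.log_softmax(torch.randn(B, K), dim=-1)
p = logp.exp()

z = Categorical(probs=p).sample((n,)).permute(1, 0)   # [B, n] class indices
z = F.one_hot(z, K)                                    # [B, n, K] one-hot codes
print(z.shape, z.sum(-1))                              # each code sums to 1
```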
+ logq = self.q_net(inputs=inputs, condition_features=c)["logq"] + logp = self.p_net(c)["logp"] + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=2.0) + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + + q = torch.exp(logq) + q = q/q.sum(dim=-1,keepdim=True) + + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + z = (-logq).argsort()[...,:n] + z = F.one_hot(z,self.K) + decoder_kwargs = dict() if decoder_kwargs is None else decoder_kwargs + c_tiled = c.unsqueeze(1).repeat(1,n,1) + x_out = self.decoder(latents=z.reshape(-1,self.K), condition_features=c_tiled.reshape(-1,c.shape[-1]), **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(z.shape[0],n)) + return {"x_recons": x_out, "q": q, "p": p, "z": z, "c": c} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + p = outputs["p"] + q = outputs["q"] + return (p*(torch.log(p)-torch.log(q))).sum(dim=-1).mean() + + def compute_losses(self,outputs,targets,gamma=1): + recon_loss = 0 + for k,v in outputs['x_recons'].items(): + if k in targets: + if isinstance(self.recon_loss_fun,dict): + loss_v = self.recon_loss_fun[k](v,targets[k].unsqueeze(1)) + else: + loss_v = self.recon_loss_fun(v,targets[k].unsqueeze(1)) + sum_dim=tuple(range(2,loss_v.ndim)) + loss_v = loss_v.sum(dim=sum_dim) + loss_v_detached = loss_v.detach() + min_flag = (loss_v==loss_v.min(dim=1,keepdim=True)[0]) + nonmin_flag = torch.logical_not(min_flag) + recon_loss +=(loss_v*min_flag*outputs["q"]).sum(dim=1)+(loss_v_detached*nonmin_flag*outputs["q"]).sum(dim=1) + + KL_loss = self.compute_kl_loss(outputs) + return recon_loss + gamma*KL_loss + + +class ECDiscreteCVAE(DiscreteCVAE): + def sample(self, condition_inputs, n: int,cond_traj = None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + n (int): number of samples to draw + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj + c = self.c_net(condition_inputs) # [B, ...] + + + bs,Na = c.shape[:2] + c_joined = TensorUtils.join_dimensions(c,0,2) + logp = self.p_net(c_joined)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + # z = (-logp).argsort()[...,:n] + # z = F.one_hot(z,self.K) + + dis_p = Categorical(probs=p) # [n_sample, batch] -> [batch, n_sample] + + z = dis_p.sample((n,)).permute(1, 0) + z = F.one_hot(z, self.K) + + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * Na * n, ...] + c_samples = TensorUtils.repeat_by_expand_at(c_joined, repeats=n, dim=0) # [B * Na * n, ...] 
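compute_losses uses a winner-takes-all weighting over the K decoded modes: every mode's reconstruction loss is weighted by the posterior q, but gradients only flow through the best (minimum-loss) mode; the others are detached. A minimal sketch of that trick with made-up numbers:

```python
import torch

B, K = 4, 5
loss_per_mode = torch.rand(B, K, requires_grad=True)   # per-mode recon losses
q = torch.softmax(torch.randn(B, K), dim=-1)           # posterior over modes

min_flag = loss_per_mode == loss_per_mode.min(dim=1, keepdim=True)[0]
recon = (loss_per_mode * min_flag * q).sum(1) \
      + (loss_per_mode.detach() * (~min_flag) * q).sum(1)   # [B]
print(recon.shape)
```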
+ + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,Na,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,n,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na,n)) + + return x_out + + def predict(self, condition_inputs, cond_traj = None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj + c = self.c_net(condition_inputs) # [B, ...] + + bs,Na = c.shape[:2] + c_joined = TensorUtils.join_dimensions(c,0,2) + logp = self.p_net(c_joined)["logp"] + z = logp.argmax(dim=-1) + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,Na,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,2) + x_out = self.decoder(latents=F.one_hot(z,self.K), condition_features=c_joined, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na)) + return x_out + + def forward(self, inputs, condition_inputs, cond_traj, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + condition_inputs["cond_traj"] = cond_traj + c = self.c_net(condition_inputs) # [B, ...] 
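The decoder_kwargs handling above tiles per-scene tensors so they line up with the flattened latent/condition batch. In plain torch (an illustrative stand-in for unsqueeze_expand_at / join_dimensions) it amounts to:

```python
import torch

# A per-scene current_states tensor [B, Na, D] is expanded to one copy per
# latent sample n and flattened to [B * Na * n, D].
B, Na, n, D = 2, 3, 4, 5
current_states = torch.randn(B, Na, D)
tiled = current_states.unsqueeze(2).expand(B, Na, n, D).reshape(B * Na * n, D)
print(tiled.shape)
```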
+ c_joined = TensorUtils.join_dimensions(c,0,2) + bs,Na = c.shape[:2] + logp = TensorUtils.reshape_dimensions(self.p_net(c_joined)["logp"],0,1,(bs,Na)) + if inputs is not None: + inputs_tiled = TensorUtils.unsqueeze_expand_at(inputs,Na,1) + inputs_joined = TensorUtils.join_dimensions(inputs_tiled,0,2) + logq = TensorUtils.reshape_dimensions(self.q_net(inputs=inputs_joined, condition_features=c_joined)["logq"],0,1,(bs,Na)) + else: + logq = logp + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=2.0) + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + + q = torch.exp(logq) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q/q.sum(dim=-1,keepdim=True) + p = p/p.sum(dim=-1,keepdim=True) + + z = torch.arange(self.K).to(q.device).tile(*q.shape[:-1],1) + z = F.one_hot(z,self.K) + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,Na,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.K,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + + c_tiled = c[:,:,None].repeat(1,1,self.K,1) + x_out = self.decoder(latents=z.reshape(-1,self.K), condition_features=TensorUtils.join_dimensions(c_tiled,0,3), **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs,Na,self.K)) + return {"x_recons": x_out, "q": q, "p": p, "z": z} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + p = outputs["p"] + q = outputs["q"] + return (p*(torch.log(p)-torch.log(q))).sum(dim=-1).mean() + + def compute_losses(self,outputs,targets,gamma=1): + recon_loss = 0 + for k,v in outputs['x_recons'].items(): + if k in targets: + if isinstance(self.recon_loss_fun,dict): + loss_v = self.recon_loss_fun[k](v,targets[k].unsqueeze(1)) + else: + loss_v = self.recon_loss_fun(v,targets[k].unsqueeze(1)) + sum_dim=tuple(range(2,loss_v.ndim)) + loss_v = loss_v.sum(dim=sum_dim) + loss_v_detached = loss_v.detach() + min_flag = (loss_v==loss_v.min(dim=1,keepdim=True)[0]) + nonmin_flag = torch.logical_not(min_flag) + recon_loss +=(loss_v*min_flag*outputs["q"]).sum(dim=1)+(loss_v_detached*nonmin_flag*outputs["q"]).sum(dim=1) + + KL_loss = self.compute_kl_loss(outputs) + return recon_loss + gamma*KL_loss + + +class SceneDiscreteCVAE(DiscreteCVAE): + def __init__(self, + q_net: nn.Module, + p_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + transformer: nn.Module, + K: int, + aggregate_func="max", + recon_loss_fun=None, + logpi_clamp = None, + num_latent_sample = None, + ): + super(SceneDiscreteCVAE,self).__init__(q_net, p_net, c_net, decoder, K, recon_loss_fun, logpi_clamp) + self.transformer = transformer + self.aggregate_func = aggregate_func + if num_latent_sample is None: + num_latent_sample = self.K + assert num_latent_sample<=self.K + self.num_latent_sample = num_latent_sample + + def sample(self, condition_inputs, mask, pos, n: int,cond_traj = None, decoder_kwargs=None): + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. 
+ + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) [B, Na, D] + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples to draw + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj #[B,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + + c = self.transformer(c,mask,pos)+c + if self.aggregate_func == "max": + c_agg = c.max(1)[0] + elif self.aggregate_func == "mean": + c_agg = c.mean(1) + logp = self.p_net(c_agg)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + + dis_p = Categorical(probs=p) # [n_sample, batch] -> [batch, n_sample] + + z = dis_p.sample((n,)).permute(1, 0) + z = F.one_hot(z, self.K) #[B,n,K] + z = TensorUtils.repeat_by_expand_at(z,repeats=Na,dim=0) + + z_samples = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * Na * n, ...] + c_samples = TensorUtils.repeat_by_expand_at(c, repeats=n, dim=0) # [B * Na * n, ...] + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,n,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(latents=z_samples, condition_features=c_samples, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na,n)) + + return x_out + + def predict(self, condition_inputs, mask, pos, cond_traj = None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj #[B,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + + c = self.transformer(c,mask,pos)+c + if self.aggregate_func == "max": + c_agg = c.max(1)[0] + elif self.aggregate_func == "mean": + c_agg = c.mean(1) + logp = self.p_net(c_agg)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + + + z = p.argmax(-1) #[B] + z = F.one_hot(z, self.K) #[B,K] + z = TensorUtils.repeat_by_expand_at(z,repeats=Na,dim=0) #[B*Na,K] + + z = TensorUtils.join_dimensions(z, begin_axis=0, end_axis=2) # [B * Na , ...] 
+ + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,2) + x_out = self.decoder(latents=z, condition_features=c, **decoder_kwargs) + x_out = TensorUtils.reshape_dimensions(x_out, begin_axis=0, end_axis=1, target_dims=(bs,Na)) + + return x_out + + def forward(self, inputs, condition_inputs, mask, pos, cond_traj=None, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = TensorUtils.join_dimensions(cond_traj,0,2) #[B*Na,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] + + c = self.transformer(c,mask,pos)+c + if self.aggregate_func == "max": + c_agg = c.max(1)[0] + elif self.aggregate_func == "mean": + c_agg = c.mean(1) + logp = self.p_net(c_agg)["logp"] + p = torch.exp(logp) + p = p/p.sum(dim=-1,keepdim=True) + if inputs is not None: + # inputs_joined = TensorUtils.join_dimensions(inputs,0,2) + logq = self.q_net(inputs,c,mask,pos)["logq"] + else: + logq = logp + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=2.0) + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + + q = torch.exp(logq) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + + _,z = torch.topk(q,self.num_latent_sample,dim=-1) + p = p.take_along_dim(z,-1) + q = q.take_along_dim(z,-1) + q = q/q.sum(dim=-1,keepdim=True) + p = p/p.sum(dim=-1,keepdim=True) + z = F.one_hot(z,self.K) + + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if "current_states" in decoder_kwargs: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.num_latent_sample,2) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + + c_tiled = c[:,:,None].repeat(1,1,self.num_latent_sample,1) + x_out = self.decoder( + latents=TensorUtils.unsqueeze_expand_at(z,Na,1).reshape(-1,self.K), + condition_features=TensorUtils.join_dimensions(c_tiled,0,3), + **decoder_kwargs + ) + + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs,Na,self.num_latent_sample)) + return {"x_recons": x_out, "q": q, "p": p, "z": z} + + def compute_kl_loss(self, outputs: dict): + """ + Compute KL Divergence loss + + Args: + outputs (dict): outputs of the self.forward() call + + Returns: + a dictionary of loss values + """ + p = outputs["p"] + q = outputs["q"] + return (p*(torch.log(p)-torch.log(q))).sum(dim=-1).mean() + + def compute_losses(self,outputs,targets,gamma=1): + recon_loss = 0 + for k,v in outputs['x_recons'].items(): + if k in targets: + if isinstance(self.recon_loss_fun,dict): + loss_v = self.recon_loss_fun[k](v,targets[k].unsqueeze(1)) + else: + 
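The latent selection in SceneDiscreteCVAE.forward keeps only the num_latent_sample modes with the highest posterior mass, gathers the matching prior entries, and renormalizes both. The core tensor pattern, with assumed sizes:

```python
import torch

B, K, k = 2, 8, 3
q = torch.softmax(torch.randn(B, K), dim=-1)   # posterior over K modes
p = torch.softmax(torch.randn(B, K), dim=-1)   # prior over K modes

_, z = torch.topk(q, k, dim=-1)                # [B, k] indices of the best modes
p_k = p.take_along_dim(z, -1)
q_k = q.take_along_dim(z, -1)
p_k = p_k / p_k.sum(-1, keepdim=True)          # renormalize over the kept modes
q_k = q_k / q_k.sum(-1, keepdim=True)
print(z.shape, p_k.shape, q_k.shape)
```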
loss_v = self.recon_loss_fun(v,targets[k].unsqueeze(1)) + sum_dim=tuple(range(2,loss_v.ndim)) + loss_v = loss_v.sum(dim=sum_dim) + loss_v_detached = loss_v.detach() + min_flag = (loss_v==loss_v.min(dim=1,keepdim=True)[0]) + nonmin_flag = torch.logical_not(min_flag) + recon_loss +=(loss_v*min_flag*outputs["q"]).sum(dim=1)+(loss_v_detached*nonmin_flag*outputs["q"]).sum(dim=1) + + KL_loss = self.compute_kl_loss(outputs) + return recon_loss + gamma*KL_loss + + +class SceneDiscreteCVAEDiverse(SceneDiscreteCVAE): + def __init__(self, + q_net: nn.Module, + p_net: nn.Module, + c_net: nn.Module, + decoder: nn.Module, + transformer: nn.Module, + latent_embeding: nn.Module, + K: int, + aggregate_func="max", + recon_loss_fun=None, + logpi_clamp = None, + num_latent_sample=None): + super(SceneDiscreteCVAEDiverse,self).__init__(q_net,p_net,c_net, decoder, transformer, K, aggregate_func, recon_loss_fun, logpi_clamp, num_latent_sample) + self.latent_embeding = latent_embeding + def sample(self, condition_inputs, mask, pos, n: int,cond_traj = None, decoder_kwargs=None): + + """ + Draw data samples (x') given a batch of condition inputs (x_c) and the VAE prior. + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) [B, Na, D] + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples to draw + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') of size [B, n, ...] + """ + assert n<=self.K + res = self.forward(None, condition_inputs, mask, pos, cond_traj, decoder_kwargs) + dis_p = Categorical(probs=res["p"]) # [n_sample, batch] -> [batch, n_sample] + + z = dis_p.sample((n,)).permute(1, 0) + x_out = {k:v[z] for k,v in res["x_recons"].items()} + return x_out + + def predict(self, condition_inputs, mask, pos, cond_traj = None, decoder_kwargs=None): + """ + Generate a prediction based on latent prior (instead of sample) and condition inputs + + Args: + condition_inputs (dict, torch.Tensor): condition inputs (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched predictions (x') of size [B, ...] + + """ + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = cond_traj #[B,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] 
+ z_enum = torch.eye(self.K).repeat(bs,Na,1,1).to(c.device) + z_emb = self.latent_embeding(z_enum) + latent_emb_dim = z_emb.shape[-1] + c_tiled = c.unsqueeze(-2).repeat_interleave(self.K,-2) + cz_tiled = torch.cat((c_tiled,z_emb),-1) + + cz_tiled = TensorUtils.join_dimensions(cz_tiled.transpose(1,2),0,2) + mask_tiled = mask.repeat_interleave(self.K,0) + pos_tiled = pos.repeat_interleave(self.K,0) + + cz = self.transformer(cz_tiled,mask_tiled,pos_tiled)+cz_tiled + if self.aggregate_func == "max": + c_agg = cz.max(1)[0][...,-latent_emb_dim:] + elif self.aggregate_func == "mean": + c_agg = cz.mean(1)[...,-latent_emb_dim:] + logp = self.p_net(c_agg)["logp"] + logp = logp.reshape(bs,self.K) + + + p = torch.exp(logp) + + if self.logpi_clamp is not None: + logp = logp.clamp(min=self.logpi_clamp,max=2.0) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + p = p/p.sum(dim=-1,keepdim=True) + z_idx = p.argmax(dim=1) + z_idx = torch.arange(bs).to(c.device)*self.K+z_idx + + + + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if "current_states" in decoder_kwargs: + if decoder_kwargs["current_states"].ndim==2: + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.K,1) + else: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(TensorUtils.join_dimensions(cz[z_idx],0,2), **decoder_kwargs) + + + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs, Na)) + + return x_out + + def forward(self, inputs, condition_inputs, mask, pos, cond_traj=None, decoder_kwargs=None): + """ + Pass the input through encoder and decoder (using posterior parameters) + Args: + inputs (dict, torch.Tensor): encoder inputs (x) + condition_inputs (dict, torch.Tensor): condition inputs - (x_c) + mask (torch.Tensor): mask of the agents in the scene [B,Na] + pos (torch.Tensor): position of the agents in the scene [B,Na,2] + n (int): number of samples, if not given, then n=self.K + decoder_kwargs (dict): Extra keyword args for decoder (e.g., dynamics model states) + + Returns: + dictionary of batched samples (x') + """ + + + bs,Na = next(iter(condition_inputs.values())).shape[:2] + condition_inputs = TensorUtils.join_dimensions(condition_inputs,0,2) + if cond_traj is not None: + condition_inputs["cond_traj"] = TensorUtils.join_dimensions(cond_traj,0,2) #[B*Na,T,3] + + c = self.c_net(condition_inputs).reshape(bs,Na,-1) # [B*Na, ...] 
+ z_enum = torch.eye(self.K).repeat(bs,Na,1,1).to(c.device) + z_emb = self.latent_embeding(z_enum) + latent_emb_dim = z_emb.shape[-1] + c_tiled = c.unsqueeze(-2).repeat_interleave(self.K,-2) + cz_tiled = torch.cat((c_tiled,z_emb),-1) + + cz_tiled = TensorUtils.join_dimensions(cz_tiled.transpose(1,2),0,2) + mask_tiled = mask.repeat_interleave(self.K,0) + pos_tiled = pos.repeat_interleave(self.K,0) + + cz = self.transformer(cz_tiled,mask_tiled,pos_tiled)+cz_tiled + if self.aggregate_func == "max": + c_agg = cz.max(1)[0][...,-latent_emb_dim:] + elif self.aggregate_func == "mean": + c_agg = cz.mean(1)[...,-latent_emb_dim:] + logp = self.p_net(c_agg)["logp"] + logp = logp.reshape(bs,self.K) + + + p = torch.exp(logp) + if inputs is not None: + # inputs_joined = TensorUtils.join_dimensions(inputs,0,2) + inputs_tiled = {k:v.repeat_interleave(self.K,0) for k,v in inputs.items()} + logq = self.q_net(inputs_tiled,cz[...,-latent_emb_dim:],mask_tiled,pos_tiled)["logq"] + logq = logq.reshape(bs,self.K) + else: + logq = logp + if self.logpi_clamp is not None: + logq = logq.clamp(min=self.logpi_clamp,max=3.0) + logp = logp.clamp(min=self.logpi_clamp,max=3.0) + + q = torch.exp(logq) + p = torch.exp(logp) + p = p.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + q = q.nan_to_num(nan=0.0, posinf=1.0, neginf=0.0) + + _,z = torch.topk(q,self.num_latent_sample,dim=-1) + p = p.take_along_dim(z,-1) + q = q.take_along_dim(z,-1) + + cz = cz.reshape(bs,Na,self.K,-1) + cz = cz.take_along_dim(z[:,None,:,None],2) + cz = cz.transpose(1,2) + q = q/q.sum(dim=-1,keepdim=True) + p = p/p.sum(dim=-1,keepdim=True) + + + if decoder_kwargs is None: + decoder_kwargs = dict() + else: + if "current_states" in decoder_kwargs: + assert decoder_kwargs["current_states"].ndim==3 and decoder_kwargs["current_states"].shape[1]==Na + decoder_kwargs = TensorUtils.unsqueeze_expand_at(decoder_kwargs,self.num_latent_sample,1) + decoder_kwargs = TensorUtils.join_dimensions(decoder_kwargs,0,3) + x_out = self.decoder(TensorUtils.join_dimensions(cz,0,3), **decoder_kwargs) + + + x_out = TensorUtils.reshape_dimensions(x_out,0,1,(bs, self.num_latent_sample, Na)) + x_out = {k:v.transpose(1,2) for k,v in x_out.items()} + return {"x_recons": x_out, "q": q, "p": p, "z": z} + + + + +def main(): + import diffstack.models.base_models as l5m + + inputs = OrderedDict(trajectories=torch.randn(10, 50, 3)) + condition_inputs = OrderedDict(image=torch.randn(10, 3, 224, 224)) + + condition_dim = 16 + latent_dim = 4 + + prior = FixedGaussianPrior(latent_dim=4) + + map_encoder = l5m.RasterizedMapEncoder( + model_arch="resnet18", + num_input_channels=3, + feature_dim=128 + ) + + q_encoder = l5m.PosteriorEncoder( + condition_dim=condition_dim, + trajectory_shape=(50, 3), + output_shapes=OrderedDict(mu=(latent_dim,), logvar=(latent_dim,)) + ) + c_encoder = l5m.ConditionEncoder( + map_encoder=map_encoder, + trajectory_shape=(50, 3), + condition_dim=condition_dim + ) + decoder = l5m.ConditionDecoder( + condition_dim=condition_dim, + latent_dim=latent_dim, + output_shapes=OrderedDict(trajectories=(50, 3)) + ) + + model = CVAE( + q_net=q_encoder, + c_net=c_encoder, + decoder=decoder, + prior=prior, + target_criterion=nn.MSELoss(reduction="none") + ) + + + outputs = model(inputs=inputs, condition_inputs=condition_inputs) + losses = model.compute_losses(outputs=outputs, targets=inputs) + samples = model.sample(condition_inputs=condition_inputs, n=10) + print() + + traj_encoder = l5m.RNNTrajectoryEncoder( + trajectory_dim=3, + rnn_hidden_size=100 + ) + + c_net = l5m.ConditionNet( 
+ condition_input_shapes=OrderedDict( + map_feature=(map_encoder.output_shape()[-1],) + ), + condition_dim=condition_dim, + ) + + q_net = l5m.PosteriorNet( + input_shapes=OrderedDict( + traj_feature=(traj_encoder.output_shape()[-1],) + ), + condition_dim=condition_dim, + param_shapes=prior.posterior_param_shapes, + ) + + lean_model = CVAE( + q_net=q_net, + c_net=c_net, + decoder=decoder, + prior=prior, + target_criterion=nn.MSELoss(reduction="none") + ) + + map_feats = map_encoder(condition_inputs["image"]) + traj_feats = traj_encoder(inputs["trajectories"]) + input_feats = dict(traj_feature=traj_feats) + condition_feats = dict(map_feature=map_feats) + + outputs = lean_model(inputs=input_feats, condition_inputs=condition_feats) + losses = lean_model.compute_losses(outputs=outputs, targets=inputs) + samples = lean_model.sample(condition_inputs=condition_feats, n=10) + print() + + +def main_discrete(): + import diffstack.models.base_models as l5m + + inputs = OrderedDict(trajectories=torch.randn(10, 50, 3)) + condition_inputs = OrderedDict(image=torch.randn(10, 3, 224, 224)) + + condition_dim = 16 + latent_dim = 20 + + map_encoder = l5m.RasterizedMapEncoder( + model_arch="resnet18", + feature_dim=128 + ) + + q_encoder = l5m.PosteriorEncoder( + condition_dim=condition_dim, + trajectory_shape=(50, 3), + output_shapes=OrderedDict(logq=(latent_dim,)) + ) + p_encoder = l5m.SplitMLP( + input_dim=condition_dim, + layer_dims=(128,128), + output_shapes=OrderedDict(logp=(latent_dim,)) + ) + c_encoder = l5m.ConditionEncoder( + map_encoder=map_encoder, + trajectory_shape=(50, 3), + condition_dim=condition_dim + ) + decoder_MLP = l5m.SplitMLP( + input_dim=condition_dim+latent_dim, + output_shapes=OrderedDict(trajectories=(50, 3)), + layer_dims=(128,128), + output_activation=nn.ReLU, + ) + decoder = l5m.ConditionDecoder(decoder_model=decoder_MLP) + + model = DiscreteCVAE( + q_net=q_encoder, + p_net=p_encoder, + c_net=c_encoder, + decoder=decoder, + K=latent_dim, + ) + + + outputs = model(inputs=inputs, condition_inputs=condition_inputs) + losses = model.compute_losses(outputs=outputs, targets = inputs) + samples = model.sample(condition_inputs=condition_inputs, n=10) + KL_loss = model.compute_kl_loss(outputs) + + + # traj_encoder = l5m.RNNTrajectoryEncoder( + # trajectory_dim=3, + # rnn_hidden_size=100 + # ) + + # c_net = l5m.ConditionNet( + # condition_input_shapes=OrderedDict( + # map_feature=(map_encoder.output_shape()[-1],) + # ), + # condition_dim=condition_dim, + # ) + + # q_net = l5m.PosteriorNet( + # input_shapes=OrderedDict( + # traj_feature=(traj_encoder.output_shape()[-1],) + # ), + # condition_dim=condition_dim, + # param_shapes=prior.posterior_param_shapes, + # ) + + # lean_model = CVAE( + # q_net=q_net, + # c_net=c_net, + # decoder=decoder, + # prior=prior, + # target_criterion=nn.MSELoss(reduction="none") + # ) + + # map_feats = map_encoder(condition_inputs["image"]) + # traj_feats = traj_encoder(inputs["trajectories"]) + # input_feats = dict(traj_feature=traj_feats) + # condition_feats = dict(map_feature=map_feats) + + # outputs = lean_model(inputs=input_feats, condition_inputs=condition_feats) + # losses = lean_model.compute_losses(outputs=outputs, targets=inputs) + # samples = lean_model.sample(condition_inputs=condition_feats, n=10) + # print() + +if __name__ == "__main__": + main_discrete() \ No newline at end of file diff --git a/diffstack/modules/cost_functions/__init__.py b/diffstack/modules/cost_functions/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git 
a/diffstack/modules/cost_functions/cost_functions.py b/diffstack/modules/cost_functions/cost_functions.py deleted file mode 100644 index fd495d0..0000000 --- a/diffstack/modules/cost_functions/cost_functions.py +++ /dev/null @@ -1,589 +0,0 @@ -import torch -import numpy as np - -from typing import Dict, Iterable, Optional, Union, Any, Tuple -from diffstack.modules.cost_functions.linear_base_cost import LinearBaseCost -from diffstack.utils.utils import pt_rbf, angle_wrap - -# Traceable torch functions are meant to be used when tracing computation graph -# for faster autodifferentiation when computing quadratic cost approximations. -# However, in the default LinearCost1 cost every component has analytic gradient -# implmenentations, so these traceable functions are not useful, hence replacing them -# with their original torch counterpart. -tracable_norm = torch.linalg.norm -tracable_rbf = pt_rbf - - -class LinearCost1(LinearBaseCost): - """Linear cost function used in the CoRL 2022 paper. - - The cost is a linear combination of the following terms (in this order): - - lane lateral cost - - lane heading cost - - goal cost - - control cost - - collision cost - """ - theta_dim = 5 - - def forward(self, xu, cost_inputs, keep_components=False): - """Compute cost.""" - gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, _ = cost_inputs - return self._compute_cost(xu, self.theta, gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, rbf_scale=self.rbf_scale_long, keep_components=keep_components) - - def approx_quadratic(self, x, u, cost_inputs, diff=True): - """Compute the quadratic approximation of the cost.""" - gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, _ = cost_inputs - - return self._approx_quadratic(x, u, self.theta, gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, rbf_scale=self.rbf_scale_long, diff=diff) - - @classmethod - def _compute_cost( - cls, - xu: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, x_dims+u_dims) - theta: torch.Tensor, - gt_neighbors: torch.Tensor, # (N-1 or 0, T, K, 2) - mus: torch.Tensor, # (1 or N, T, K, 2) - probs: torch.Tensor, # (1 or N, T, K) - goal: torch.Tensor, # (2, ) - lanes: torch.Tensor, - rbf_scale: Union[float, torch.Tensor] = 2.0, - keep_components: bool = False, - ): - - # Cost terms - cost_terms = cls._compute_cost_terms(xu, gt_neighbors, mus, probs, goal, lanes, rbf_scale=rbf_scale) - - # Theta - if xu.ndim == 3: - theta_vec = theta.unsqueeze(1).unsqueeze(0) # 1, theta_dim, 1 - elif xu.ndim == 2: - theta_vec = theta.unsqueeze(1) # theta_dim, 1 - else: - raise ValueError - - # Weighted sum or weighted components - if keep_components: - return (cost_terms * theta_vec.transpose(-1, -2)) - else: - cost = torch.matmul(cost_terms, theta_vec).squeeze(-1) - return cost # (T, ) or (T, b) - - @staticmethod - def _compute_cost_terms( - xu: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, xu) - gt_neighbors: torch.Tensor, # (N-1 or 0, T, K, 2) or list of same for batch - mus: torch.Tensor, # (1 or N, T, K, 2) or list of same for batch - probs: torch.Tensor, # (1 or N, T, K) or list of same for batch - goal: torch.Tensor, # (2, ) or (b, 2) - lanes: torch.Tensor, - rbf_scale: float, - ): #(T+1, 3, ) or (T+1, b, 3) - - x, u = torch.split(xu, (4, 2), dim=-1) # x, y, orient, vel, d_orient, acc - - # Deal with different time resolutions for prediction and planning - predh = lanes.shape[0]-1 - planh = x.shape[0]-1 - if planh != predh: - assert planh % predh == 0 and planh > predh - num_repeat = planh // predh - 
gt_neighbors = None if gt_neighbors is None else torch.repeat_interleave(gt_neighbors, num_repeat, dim=1) - mus = None if mus is None else torch.repeat_interleave(mus, num_repeat, dim=1) - lanes = torch.repeat_interleave(lanes, num_repeat, dim=0)[num_repeat-1:] - - assert lanes.shape[:-1] == x.shape[:-1] - assert len(goal.shape) == len(x.shape)-1 - - ego_lane_lat = torch.square(x[..., :2] - lanes[..., :2]).sum(dim=-1) - ego_lane_heading = torch.square(angle_wrap(x[..., 2] - lanes[..., 2])) - # ego_lane_heading = torch.square(x[..., 2] - lanes[..., 2]) - ego_goal = torch.cat((torch.zeros_like(ego_lane_heading)[:-1], - torch.square(x[-1, ..., :2] - goal).sum(dim=-1).unsqueeze(0)), dim=0) * 0.1 - control_cost = torch.square(u).sum(dim=-1) - - collision_reward = LinearCost1._collision_reward( - mus, probs, x, gt_neighbors=gt_neighbors, rbf_scale=rbf_scale) - - cost_terms = [ - ego_lane_lat, - ego_lane_heading, - ego_goal, - control_cost, - -collision_reward] - - cost_terms = torch.stack(cost_terms, dim=-1) # t, (b), theta_dim - return cost_terms - - @staticmethod - def _collision_reward(pred_mus: torch.Tensor, pred_probs: torch.Tensor, ego_x: torch.Tensor, gt_neighbors: torch.Tensor = None, rbf_scale = 2.0, return_grad: bool = False): - # The collision term (with analyitic gradient) is only implemented for a single sample - # so for batched inputs we iterate over the batch. This is highly inefficient. - is_batched = (len(ego_x.shape) == 3) - if is_batched: - reward_outputs = [] - for b_i in range(ego_x.shape[1]): - reward_outputs.append(LinearCost1._collision_reward_single_analytic_grad( # recursive call to itslef - pred_mus[:, :, b_i], pred_probs[:, b_i], ego_x[:, b_i], - gt_neighbors=(None if gt_neighbors is None else gt_neighbors[:, :, b_i]), - rbf_scale=rbf_scale, - return_grad=return_grad)) - - if return_grad: - # TODO refactor with zip - prediction_reward = torch.stack([output[0] for output in reward_outputs], dim=1) - gradients = torch.stack([output[1] for output in reward_outputs], dim=1) - hessians = torch.stack([output[2] for output in reward_outputs], dim=1) - return prediction_reward, gradients, hessians - else: - prediction_reward = torch.stack(reward_outputs, dim=1) - return prediction_reward - else: - return LinearCost1._collision_reward_single_analytic_grad(pred_mus, pred_probs, ego_x, gt_neighbors, rbf_scale=rbf_scale, return_grad=return_grad) - - @staticmethod - def _collision_reward_single_analytic_grad(pred_mus: torch.Tensor, pred_probs: torch.Tensor, x: torch.Tensor, gt_neighbors: torch.Tensor = None, rbf_scale = 2.0, return_grad: bool = False): - """Computes the collision cost and optionally its first and second order gradients. - - Here the gradients are analytically computed. The experimental_costs.py provides alternative implementations - using autodiff that is more general but much less efficient. There have also been attempts to trace the autodiff function, i.e. - create the computation graph for the gradient only once. This does speed up the autodiff variant but the analytic implementation - is still faster. - """ - # For now only support single sample, no batch. - assert len(x.shape) == 2 - assert len(pred_mus.shape) == 4 - assert len(pred_probs.shape) == 2 - - x_mu_delta = x[1:pred_mus.shape[1]+1, :2].unsqueeze(0).unsqueeze(2) - torch.nan_to_num(pred_mus, nan=1e6) - comp_dists = tracable_norm(x_mu_delta, dim=-1) # Distance from predictions (nodes, predhorizon, K) - - # 2) Keep time, produce one value for each time step. 
- expected_dist = torch.sum(comp_dists * pred_probs.unsqueeze(1), dim=-1) # Factoring in traj. probabilities -> expected closest encounter (nodes, predhorizon) - - # Add in gt - if gt_neighbors is not None: - x_gt_delta = x[1:gt_neighbors.shape[1]+1, :2].unsqueeze(0) - torch.nan_to_num(gt_neighbors, nan=1e6) - gt_comp_dists = tracable_norm(x_gt_delta, dim=-1) # Distance from predictions (nodes, predhorizon) - expected_dist = torch.cat((expected_dist, gt_comp_dists), axis=0) - - # Instead of closest agent, add distance term for all agents - expected_dist = expected_dist.unsqueeze(-1) # (N, T, 1) - full_cost = -tracable_rbf(expected_dist, scale=rbf_scale) # (N, T, ) - ret = full_cost.sum(dim=0) # sum over agents (T, ) - ret = torch.cat((torch.zeros((1, ), dtype=ret.dtype, device=ret.device), ret), dim=0) # extend to (T+1, ) - - if return_grad: - cost_over_scale = full_cost.unsqueeze(-1) / rbf_scale # expected_dist (N, T, 1). Denoted F over a in symbolic_diff. - pred_probs_ext = pred_probs.unsqueeze(1).unsqueeze(-1) # N_pred, 1, K, 1. Denoted p in symbolic_diff. - comp_dists_ext = comp_dists.unsqueeze(-1) # N_pred, T, K, 1. Denoted d in symbolic_diff. - # expected_dist_pred_agents: This is sum of d over K weighted with p. Denoted s_dp on paper. - expected_dist_pred_agents = expected_dist[:x_mu_delta.shape[0]] # N_pred, T, 1. - # precompute common terms - # x_mu_delta: x-x_ak # N_pred, T, K, 2 - x_mu_delta_over_comp_dists = x_mu_delta / comp_dists_ext # (x-x_ak)/d - x_mu_delta_over_comp_dists_squared = x_mu_delta_over_comp_dists.square() # ((x-x_ak)/d)^2 - pred_probs_over_comp_dists = pred_probs_ext / comp_dists_ext - - # --- Gradient - # Gradient for predicted agents - grad_pred = x_mu_delta * pred_probs_over_comp_dists # N_pred, T, K, 2 - grad_pred = expected_dist_pred_agents * grad_pred.sum(-2) # sum over K, remains: N_pred, T, 2 - if gt_neighbors is not None: - # Gradient for gt agents - grad_gt = x_gt_delta # N_gt, T, 2 - # Gradient combine predicted and gt - grad = torch.cat((grad_pred, grad_gt), axis=0) # N, T, 2 - else: - grad = grad_pred - - grad = -cost_over_scale * grad # N, T, 2 - grad = grad.sum(0) # sum over agents (T, 2) - grad = torch.cat((torch.zeros((1, 2), dtype=grad.dtype, device=grad.device), grad), dim=0) # extend to (T+1, ) - - # --- Hessian for diagonals H11 and H22 - # Hessian diagonals for predicted agents - hess_d1 = pred_probs_over_comp_dists * (x_mu_delta_over_comp_dists_squared - 1.) # N_pred, T, K, 2 - hess_d1 = expected_dist_pred_agents * hess_d1.sum(-2) # sum over K, remains: N_pred, T, 2 - - hess_d2 = pred_probs_over_comp_dists * x_mu_delta # N_pred, T, K, 2 - hess_d2 = hess_d2.sum(-2) # sum over K, remains N_pred, T, 2 - shared_term1 = (expected_dist_pred_agents.square() / rbf_scale - 1.) 
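For reference, assuming pt_rbf(d, σ) = exp(−d² / (2σ)) (a form consistent with the analytic gradient and Hessian terms in this function), the collision component evaluated here pushes the probability-weighted expected distance to each predicted agent through that RBF:

$$
\bar d_{i,t} \;=\; \sum_{k} p_{i,k}\,\big\lVert x_t - \mu_{i,t,k} \big\rVert,
\qquad
c_{\text{coll}}(x_t) \;=\; \sum_{i} \exp\!\Big(-\frac{\bar d_{i,t}^{\,2}}{2\sigma}\Big),
$$

with ground-truth neighbours entering through their actual distance instead of $\bar d_{i,t}$. The planner cost term is the negated collision reward, i.e. $+c_{\text{coll}}$ scaled by $\theta_4$.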
- hess_d2 = hess_d2.square() * shared_term1 # N_pred, T, 2 - - hess_d_pred = (hess_d1 + hess_d2) # N_pred, T, 2 - - # --- Hessian for antidiagonals H12 == H21 - # Hessian antidiagonals for predicted agents - hess_a1 = x_mu_delta.prod(dim=-1, keepdim=True) # (x-x_a)(y-y_a), remains N_pred, T, K, 1 - hess_a1 = hess_a1 * pred_probs_over_comp_dists / comp_dists_ext.square() # N_pred, T, K, 1 - hess_a1 = expected_dist_pred_agents * hess_a1.sum(-2) # sum over K, remains N_pred, T, 1 - - hess_a2 = pred_probs_over_comp_dists * x_mu_delta # N_pred, T, K, 2 - hess_a2 = hess_a2.sum(-2) # sum over K, remains: N_pred, T, 2 - hess_a2 = hess_a2.prod(dim=-1, keepdim=True) # prod over xy, remains: N_pred, T, 1 - hess_a2 = hess_a2 * shared_term1 - - hess_a_pred = (hess_a1 + hess_a2) # N_pred, T, 2 - - # ---- Combine pred and gt - if gt_neighbors is not None: - # Hessian diagonals for gt agents - hess_d_gt = (x_gt_delta.square()/rbf_scale - 1.) # N_gt, T, 2 - - # combine predicted and gt agents - hess_d = cost_over_scale * torch.cat((hess_d_pred, hess_d_gt), axis=0) # N, T, 2 - - # Hessian antidiagonals for gt agents - hess_a_gt = x_gt_delta.prod(dim=-1, keepdim=True)/rbf_scale # N_gt, T, 1 - - # combine predicted and gt agents - hess_a = cost_over_scale * torch.cat((hess_a_pred, hess_a_gt), axis=0) # N, T, 1 - - else: - hess_d = cost_over_scale * hess_d_pred # N, T, 2 - hess_a = cost_over_scale * hess_a_pred # N, T, 1 - - # Build Hessian matrix from H11, H22, H12 - hess_11, hess_22 = torch.split(hess_d, (1, 1), dim=-1) - hess_12 = hess_a - hess_21 = hess_a - - hess = torch.stack(( - torch.cat((hess_11, hess_12), dim=-1), # first column - torch.cat((hess_21, hess_22), dim=-1)), # second column - dim=-1) # N, T, 2, 2 - - hess = hess.sum(0) # sum over agents (T, 2) - hess = torch.cat((torch.zeros((1, 2, 2), dtype=hess.dtype, device=hess.device), hess), dim=0) # extend to (T+1, ) - - return ret, grad, hess - - return ret - - @staticmethod - def _approx_quadratic( - x: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, xu) - u: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, xu) - theta: torch.Tensor, - gt_neighbors_batch: torch.Tensor, # (N-1 or 0, T, K, 2) or list of same for batch - mus_batch: torch.Tensor, # (1 or N, T, K, 2) or list of same for batch - probs_batch: torch.Tensor, # (1 or N, T, K) or list of same for batch - goal: torch.Tensor, # (2, ) or (b, 2) - lanes: torch.Tensor, - rbf_scale=2.0, - diff=True): #(T+1, 3, ) or (T+1, b, 3) - """Directly computes a quadratic approximation of the cost.""" - - # Deal with different time resolutions for prediction and planning - predh = 6 - planh = x.shape[0]-1 - if planh != predh: - num_repeat = planh // predh - gt_neighbors_batch = None if gt_neighbors_batch is None else torch.repeat_interleave(gt_neighbors_batch, num_repeat, dim=1) - mus_batch = None if mus_batch is None else torch.repeat_interleave(mus_batch, num_repeat, dim=1) - lanes = torch.repeat_interleave(lanes, num_repeat, dim=0)[num_repeat-1:] - - # We are looking for a cost in the form: - # cost = 1/2 * x^T*C*x + c^T*x - x = x.detach() - u = u.detach() - - # the pad function is in inverse order dimensions - # (1, 2) will pad the last dimension with 1 and 2 - # (1, 2, 3, 4) will also pad the second to last dimension with 3 and 4 - pad = torch.nn.functional.pad - - # ego_lane_lat = torch.square(x[..., :2] - lanes[..., :2]).sum(dim=-1) - # ego_lane_heading = torch.square(x[..., 2] - lanes[..., 2]) - # We need to wrap the lane heading lanes[..., 2] with 2pi such that it is close to x[..., 2]. 
- lane_h = x[..., 2] - angle_wrap(x[..., 2] - lanes[..., 2]) - lanes = torch.stack((lanes[..., 0], lanes[..., 1], lane_h), dim=-1) - # Now we can treat heading error term as if it was simply ego_lane_heading = torch.square(x[..., 2] - lanes[..., 2]) - zero_tensor = torch.zeros_like(theta[0]) - ego_lane_C = torch.diag(torch.stack((theta[0], theta[0], theta[1], zero_tensor))).unsqueeze(0).unsqueeze(1).type_as(x) # ..., 4, 4 - ego_lane_c = -2. * lanes[..., :3] * torch.stack((theta[0], theta[0], theta[1])).unsqueeze(0).type_as(x) # ..., 3 - ego_lane_c = torch.cat((ego_lane_c, torch.zeros_like(ego_lane_c)[..., :1]), dim=-1) # ..., 4 - - # goal only for last state - # ego_goal = torch.square(x[-1, ..., :2] - goal).sum(dim=-1) - ego_goal_C = torch.diag(torch.stack((theta[2], theta[2], zero_tensor, zero_tensor))).unsqueeze(0).unsqueeze(1).type_as(x) # 1, 1, 4, 4 - ego_goal_c = -2. * goal.unsqueeze(0) * torch.stack((theta[2], theta[2])).unsqueeze(0).type_as(x) # 1, 1, 2 - # pad along theta_dim and along time, so we only have nonzero only for last step - ego_goal_C = pad(ego_goal_C, (0, 0, 0, 0, 0, 0, x.shape[0]-1, 0)) # T, 1, 4, 4 - ego_goal_c = pad(ego_goal_c, (0, 2, 0, 0, x.shape[0]-1, 0)) # T, 1, 4 - - control_C = torch.diag(torch.stack((theta[3], theta[3]))).unsqueeze(0).unsqueeze(1).type_as(x) # ..., 2, 2 - - _, grads, hessians = LinearCost1._collision_reward( - mus_batch, probs_batch, x, gt_neighbors=gt_neighbors_batch, rbf_scale=rbf_scale, return_grad=True) - grads = -grads * theta[4] - hessians = -hessians * theta[4] - - # hessian matrix * tau vector. Using matmul to do this for last two dims, keeping T, batch - # This comes from the Taylor approximation, expanding the quadratic term with Hessian will give a linear term. - grads = grads - torch.matmul(hessians, x[..., :2].unsqueeze(-1)).squeeze(-1) - if not diff: - hessians = hessians.detach() - grads = grads.detach() - - # Combine elements - state_C = 2. * (ego_lane_C + ego_goal_C) # t, 1, 4, 4 - state_c = ego_lane_c + ego_goal_c # t, b, 4 - - C = state_C - C = torch.nn.functional.pad(C, (0, 2, 0, 2)) # t, 1, 6, 6 - C = C.repeat(1, hessians.shape[1], 1, 1) # t, b, 6, 6 - C[..., :2, :2] += hessians - C[..., -2:, -2:] += 2. 
* control_C - c = torch.nn.functional.pad(state_c, (0, 2)) # t, b, 6 - c[..., :2] += grads - - return C, c - - -class LinearCostAngleBug(LinearCost1): - - @staticmethod - def _compute_cost_terms( - xu: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, xu) - gt_neighbors: torch.Tensor, # (N-1 or 0, T, K, 2) or list of same for batch - mus: torch.Tensor, # (1 or N, T, K, 2) or list of same for batch - probs: torch.Tensor, # (1 or N, T, K) or list of same for batch - goal: torch.Tensor, # (2, ) or (b, 2) - lanes: torch.Tensor, - rbf_scale: float, - ): #(T+1, 3, ) or (T+1, b, 3) - - x, u = torch.split(xu, (4, 2), dim=-1) # x, y, orient, vel, d_orient, acc - - # Deal with different time resolutions for prediction and planning - predh = lanes.shape[0]-1 - planh = x.shape[0]-1 - if planh != predh: - assert planh % predh == 0 and planh > predh - num_repeat = planh // predh - gt_neighbors = None if gt_neighbors is None else torch.repeat_interleave(gt_neighbors, num_repeat, dim=1) - mus = None if mus is None else torch.repeat_interleave(mus, num_repeat, dim=1) - lanes = torch.repeat_interleave(lanes, num_repeat, dim=0)[num_repeat-1:] - - assert lanes.shape[:-1] == x.shape[:-1] - assert len(goal.shape) == len(x.shape)-1 - - ego_lane_lat = torch.square(x[..., :2] - lanes[..., :2]).sum(dim=-1) - # ego_lane_heading = torch.square(angle_wrap(x[..., 2] - lanes[..., 2])) - ego_lane_heading = torch.square(x[..., 2] - lanes[..., 2]) - ego_goal = torch.cat((torch.zeros_like(ego_lane_heading)[:-1], - torch.square(x[-1, ..., :2] - goal).sum(dim=-1).unsqueeze(0)), dim=0) * 0.1 - control_cost = torch.square(u).sum(dim=-1) - - collision_reward = LinearCost1._collision_reward( - mus, probs, x, gt_neighbors=gt_neighbors, rbf_scale=rbf_scale) - - cost_terms = [ - ego_lane_lat, - ego_lane_heading, - ego_goal, - control_cost, - -collision_reward] - - cost_terms = torch.stack(cost_terms, dim=-1) # t, (b), theta_dim - return cost_terms - - - @staticmethod - def _approx_quadratic( - x: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, xu) - u: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, xu) - theta: torch.Tensor, - gt_neighbors_batch: torch.Tensor, # (N-1 or 0, T, K, 2) or list of same for batch - mus_batch: torch.Tensor, # (1 or N, T, K, 2) or list of same for batch - probs_batch: torch.Tensor, # (1 or N, T, K) or list of same for batch - goal: torch.Tensor, # (2, ) or (b, 2) - lanes: torch.Tensor, - rbf_scale=2.0, - diff=True): #(T+1, 3, ) or (T+1, b, 3) - """Directly computes a quadratic approximation of the cost.""" - - # Deal with different time resolutions for prediction and planning - predh = 6 - planh = x.shape[0]-1 - if planh != predh: - num_repeat = planh // predh - gt_neighbors_batch = None if gt_neighbors_batch is None else torch.repeat_interleave(gt_neighbors_batch, num_repeat, dim=1) - mus_batch = None if mus_batch is None else torch.repeat_interleave(mus_batch, num_repeat, dim=1) - lanes = torch.repeat_interleave(lanes, num_repeat, dim=0)[num_repeat-1:] - - # We are looking for a cost in the form: - # cost = 1/2 * x^T*C*x + c^T*x - x = x.detach() - u = u.detach() - - # the pad function is in inverse order dimensions - # (1, 2) will pad the last dimension with 1 and 2 - # (1, 2, 3, 4) will also pad the second to last dimension with 3 and 4 - pad = torch.nn.functional.pad - - # ego_lane_lat = torch.square(x[..., :2] - lanes[..., :2]).sum(dim=-1) - # ego_lane_heading = torch.square(x[..., 2] - lanes[..., 2]) - # # We need to wrap the lane heading lanes[..., 2] with 2pi such that it is close to x[..., 2]. 
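- # NOTE: in this class the lane-heading wrap below is kept commented out, so the heading error stays unwrapped.
- # This is the "angle bug" the class name refers to, preserved to reproduce the original 'corl_default' cost;
- # 'corl_default_angle_fix' selects LinearCost1 instead, which applies the wrap (see cost_selector.py).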
- # lane_h = x[..., 2] - angle_wrap(x[..., 2] - lanes[..., 2]) - # lanes = torch.stack((lanes[..., 0], lanes[..., 1], lane_h), dim=-1) - # Now we can treat heading error term as if it was simply ego_lane_heading = torch.square(x[..., 2] - lanes[..., 2]) - zero_tensor = torch.zeros_like(theta[0]) - ego_lane_C = torch.diag(torch.stack((theta[0], theta[0], theta[1], zero_tensor))).unsqueeze(0).unsqueeze(1).type_as(x) # ..., 4, 4 - ego_lane_c = -2. * lanes[..., :3] * torch.stack((theta[0], theta[0], theta[1])).unsqueeze(0).type_as(x) # ..., 3 - ego_lane_c = torch.cat((ego_lane_c, torch.zeros_like(ego_lane_c)[..., :1]), dim=-1) # ..., 4 - - # goal only for last state - # ego_goal = torch.square(x[-1, ..., :2] - goal).sum(dim=-1) - ego_goal_C = torch.diag(torch.stack((theta[2], theta[2], zero_tensor, zero_tensor))).unsqueeze(0).unsqueeze(1).type_as(x) # 1, 1, 4, 4 - ego_goal_c = -2. * goal.unsqueeze(0) * torch.stack((theta[2], theta[2])).unsqueeze(0).type_as(x) # 1, 1, 2 - # pad along theta_dim and along time, so we only have nonzero only for last step - ego_goal_C = pad(ego_goal_C, (0, 0, 0, 0, 0, 0, x.shape[0]-1, 0)) # T, 1, 4, 4 - ego_goal_c = pad(ego_goal_c, (0, 2, 0, 0, x.shape[0]-1, 0)) # T, 1, 4 - - control_C = torch.diag(torch.stack((theta[3], theta[3]))).unsqueeze(0).unsqueeze(1).type_as(x) # ..., 2, 2 - - _, grads, hessians = LinearCost1._collision_reward( - mus_batch, probs_batch, x, gt_neighbors=gt_neighbors_batch, rbf_scale=rbf_scale, return_grad=True) - grads = -grads * theta[4] - hessians = -hessians * theta[4] - - # hessian matrix * tau vector. Using matmul to do this for last two dims, keeping T, batch - # This comes from the Taylor approximation, expanding the quadratic term with Hessian will give a linear term. - grads = grads - torch.matmul(hessians, x[..., :2].unsqueeze(-1)).squeeze(-1) - if not diff: - hessians = hessians.detach() - grads = grads.detach() - - # Combine elements - state_C = 2. * (ego_lane_C + ego_goal_C) # t, 1, 4, 4 - state_c = ego_lane_c + ego_goal_c # t, b, 4 - - C = state_C - C = torch.nn.functional.pad(C, (0, 2, 0, 2)) # t, 1, 6, 6 - C = C.repeat(1, hessians.shape[1], 1, 1) # t, b, 6, 6 - C[..., :2, :2] += hessians - C[..., -2:, -2:] += 2. * control_C - c = torch.nn.functional.pad(state_c, (0, 2)) # t, b, 6 - c[..., :2] += grads - - return C, c - -class InterpretableLinearCost1(LinearCost1): - """Interpretable counterpart of LinearCost1. - - This class computes more interpretable cost terms, and effectively assumes theta equals one. - The intended usage is to call forward with keep_components=True, which will return N interpretable cost terms. 
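-    A minimal usage sketch (the instance name is illustrative; cost_inputs is the usual
-    (gt_neighbors, mus, probs, goal, lanes, ...) tuple used by the other linear costs):
-        terms = interpretable_cost(xu, cost_inputs, keep_components=True)
-    which returns the individual cost terms (theta_dim = 8) instead of their weighted sum.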
- - # TODO Add progress term instead of goal lateral distance - # TODO Remove hardcoded dt=0.5 and predh=6 - """ - theta_dim = 8 - - def forward(self, xu, cost_inputs, keep_components=False): - # return torch.zeros([xu.shape[0], xu.shape[1]], device=xu.device) - gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, _ = cost_inputs - dummy_theta = torch.ones_like(self.theta) - return self._compute_cost(xu, dummy_theta, gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, rbf_scale=self.rbf_scale_long, keep_components=keep_components) - - @staticmethod - def _compute_cost_terms( - xu: torch.Tensor, # (T+1, x_dims+u_dims) or (T+1, b, xu) - gt_neighbors: torch.Tensor, # (N-1 or 0, T, K, 2) or list of same for batch - mus: torch.Tensor, # (1 or N, T, K, 2) or list of same for batch - probs: torch.Tensor, # (1 or N, T, K) or list of same for batch - goal: torch.Tensor, # (2, ) or (b, 2) - lanes: torch.Tensor, - rbf_scale: float, - ): #(T+1, 3, ) or (T+1, b, 3) - - x, u = torch.split(xu, (4, 2), dim=-1) # x, y, orient, vel, d_orient, acc - - # Deal with different time resolutions for prediction and planning - predh = lanes.shape[0]-1 - planh = x.shape[0]-1 - if planh != predh: - num_repeat = planh // predh - gt_neighbors = None if gt_neighbors is None else torch.repeat_interleave(gt_neighbors, num_repeat, dim=1) - mus = None if mus is None else torch.repeat_interleave(mus, num_repeat, dim=1) - lanes = torch.repeat_interleave(lanes, num_repeat, dim=0)[num_repeat-1:] - - assert lanes.shape[:-1] == x.shape[:-1] - assert len(goal.shape) == len(x.shape)-1 - - ego_lane_lat = torch.sqrt(torch.square(x[..., :2] - lanes[..., :2]).sum(dim=-1)) - ego_lane_heading = torch.sqrt(torch.square(x[..., 2] - lanes[..., 2])) - ego_goal = torch.repeat_interleave( - torch.sqrt(torch.square(x[-1, ..., :2] - goal).sum(dim=-1)).unsqueeze(0), ego_lane_heading.shape[0], dim=0) - control_cost = torch.sqrt(torch.square(u).sum(dim=-1)) - control_cost1 = torch.sqrt(torch.square(u[..., 0])) - control_cost2 = torch.sqrt(torch.square(u[..., 1])) - - prediction_reward = InterpretableLinearCost1._collision_reward( - mus, probs, x, gt_neighbors=gt_neighbors, rbf_scale=rbf_scale) - distance = InterpretableLinearCost1._distance_to_closest( - mus, probs, x, gt_neighbors=gt_neighbors, rbf_scale=rbf_scale) - - cost_terms = [ - ego_lane_lat, - ego_lane_heading, - ego_goal, - control_cost, - control_cost1, - control_cost2, - prediction_reward, - distance] - - cost_terms = torch.stack(cost_terms, dim=-1) # t, (b), theta_dim - return cost_terms - - @staticmethod - def _distance_to_closest(pred_mus: torch.Tensor, pred_probs: torch.Tensor, ego_x: torch.Tensor, gt_neighbors: torch.Tensor = None, rbf_scale = 2.0, return_grad: bool = False): - """Distance to closest agent. 
For batched input we simply iterate over batch.""" - is_batched = (len(ego_x.shape) == 3) - if is_batched: - reward_outputs = [] - for b_i in range(ego_x.shape[1]): - reward_outputs.append(InterpretableLinearCost1._distance_to_closest_single( # recursive call to itslef - pred_mus[:, :, b_i], pred_probs[:, b_i], ego_x[:, b_i], - gt_neighbors=(None if gt_neighbors is None else gt_neighbors[:, :, b_i]), - rbf_scale=rbf_scale, - return_grad=return_grad)) - - if return_grad: - # TODO would be nicer with zip - prediction_reward = torch.stack([output[0] for output in reward_outputs], dim=1) - gradients = torch.stack([output[1] for output in reward_outputs], dim=1) - hessians = torch.stack([output[2] for output in reward_outputs], dim=1) - return prediction_reward, gradients, hessians - else: - prediction_reward = torch.stack(reward_outputs, dim=1) - return prediction_reward - else: - return InterpretableLinearCost1._distance_to_closest_single(pred_mus, pred_probs, ego_x, gt_neighbors, rbf_scale=rbf_scale, return_grad=return_grad) - - @staticmethod - def _distance_to_closest_single(pred_mus: torch.Tensor, pred_probs: torch.Tensor, x: torch.Tensor, gt_neighbors: torch.Tensor = None, rbf_scale = 2.0, return_grad: bool = False): - """Distance to closest agent for single input.""" - assert pred_mus.shape[0] == 0 - - # Add in gt - if gt_neighbors is not None: - x_gt_delta = x[1:gt_neighbors.shape[1]+1, :2].unsqueeze(0) - torch.nan_to_num(gt_neighbors, nan=1e6) - gt_comp_dists = tracable_norm(x_gt_delta, dim=-1) # Distance from predictions (nodes, predhorizon) - expected_dist = gt_comp_dists - - # Instead of closest agent, add distance term for all agents - expected_dist = expected_dist.unsqueeze(-1) # (N, T, 1) - closest_dist = expected_dist.min(dim=0).values.min(dim=0).values # (1,) - - ret = closest_dist.repeat_interleave(expected_dist.shape[1], dim=0) # (T, ) - ret = torch.cat((torch.zeros((1, ), dtype=ret.dtype, device=ret.device), ret), dim=0) # extend to (T+1, ) - - if return_grad: - return ret, None, None - - return ret - diff --git a/diffstack/modules/cost_functions/cost_selector.py b/diffstack/modules/cost_functions/cost_selector.py deleted file mode 100644 index 83ba5ab..0000000 --- a/diffstack/modules/cost_functions/cost_selector.py +++ /dev/null @@ -1,36 +0,0 @@ - -import torch -import numpy as np - -from diffstack.modules.cost_functions.cost_functions import LinearCost1, LinearCostAngleBug, InterpretableLinearCost1 - - -def get_cost_object(plan_cost_mode, control_limits, is_trainable, device): - - if plan_cost_mode == 'default': - cost_class = LinearCost1 - # With the distance-based prediction term, num_trajs 128, motion_plan ego preds, L-BFGS lr=0.95 - cost_kwargs = dict( - theta=torch.from_numpy(np.array([1.721, 0.652, 0., 8.755, 1.362, 1.440])).float().to(device)) - elif plan_cost_mode == 'interpretable': - # Mpc1 cost - cost_class = InterpretableLinearCost1 - cost_kwargs = dict( - theta=torch.from_numpy(np.array([1.0] * 8)).float().to(device), - ) - elif plan_cost_mode == 'corl_default': # default for corl paper - cost_class = LinearCostAngleBug - cost_kwargs = dict( - theta=torch.from_numpy(np.array([0.3, 0.3, 0.5, 1.0, 20.0])).float().to(device), - control_limits=control_limits, - ) - elif plan_cost_mode == 'corl_default_angle_fix': # fixed angle wrap - cost_class = LinearCost1 - cost_kwargs = dict( - theta=torch.from_numpy(np.array([0.3, 0.3, 0.5, 1.0, 20.0])).float().to(device), - control_limits=control_limits, - ) - else: - raise NotImplementedError(plan_cost_mode) - - return 
cost_class(**cost_kwargs, is_trainable=is_trainable) \ No newline at end of file diff --git a/diffstack/modules/cost_functions/linear_base_cost.py b/diffstack/modules/cost_functions/linear_base_cost.py deleted file mode 100644 index 53d933c..0000000 --- a/diffstack/modules/cost_functions/linear_base_cost.py +++ /dev/null @@ -1,157 +0,0 @@ -import torch -from typing import Dict, Iterable, Optional, Union, Any, Tuple - - -class LinearBaseCost(torch.nn.Module): - """Cost function that is a linear combination of N cost terms. - """ - def __init__(self, - theta: torch.TensorType, - normalized: Optional[bool] = False, - scaler: Optional[bool] = False, - is_trainable: Optional[bool] = False, - trainable_dims: Optional[Iterable] = None, - rbf_scale_long: Optional[float] = 2.0, - rbf_scale_lat: Optional[float] = 2.0, - trainable_rbf_scaler: Optional[bool] = False, - control_limits: Optional[tuple] = None): - super().__init__() - - self.control_limits = control_limits - - # Rbf - if trainable_rbf_scaler: - self.rbf_scale_long = torch.nn.Parameter(torch.Tensor([rbf_scale_long]).to(device=theta.device).squeeze().detach(), requires_grad=is_trainable) - self.rbf_scale_lat = torch.nn.Parameter(torch.Tensor([rbf_scale_lat]).to(device=theta.device).squeeze().detach(), requires_grad=is_trainable) - else: - self.rbf_scale_long = rbf_scale_long - self.rbf_scale_lat = rbf_scale_lat - - # Theta - assert len(theta) == self.theta_dim - if trainable_dims is None: - trainable_dims = [True] * self.theta_dim - assert len(trainable_dims) == self.theta_dim - self.normalized = normalized - self.trainable_dims = trainable_dims - - if normalized: - # Normalized parameterization, theta always sums to the initial theta sum. - theta_sum = (theta * torch.Tensor(trainable_dims).float().to(device=theta.device)).sum() - log_theta = torch.log(theta / theta_sum).detach() # normalize to sum to one - self.log_theta = [] - self.fixed_log_theta = [] - for i in range(log_theta.shape[0]): - attr_name = f"log_theta_{i+1}" - if trainable_dims[i]: - setattr(self, attr_name, torch.nn.Parameter(log_theta[i], requires_grad=is_trainable)) - self.log_theta.append(getattr(self, attr_name)) - else: - setattr(self, attr_name, torch.nn.Parameter(log_theta[i], requires_grad=False)) - self.fixed_log_theta.append(getattr(self, attr_name)) - self.theta_params = None - if scaler: - self.theta_scaler = torch.nn.Parameter(theta_sum.unsqueeze(0), requires_grad=is_trainable) - else: - self.theta_scaler = theta_sum.detach() - - assert (torch.isclose(theta, self.theta).all().detach()) - - else: - # Direct parameterization - self.log_theta = None - if trainable_dims is not None and not all(trainable_dims): - raise NotImplementedError - else: - self.theta_params = [torch.nn.Parameter(theta[i], requires_grad=is_trainable) for i in range(theta.shape[0])] - self.theta_scaler = 1. - assert not scaler - - def forward(self, xu, cost_inputs, keep_components=False): - raise NotImplementedError - - def approx_quadratic(self, x, u, cost_inputs, diff=True): - raise NotImplementedError - - @property - def theta(self): - if self.normalized: - # We cannot have the softmax op created in the init. Accessing theta through - # this getter function will create a new op every time. 
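- # Effectively: theta_i = softmax(trainable log_thetas)_i * theta_scaler for the trainable dims, and
- # theta_i = exp(log_theta_i) * theta_scaler for dims marked fixed in trainable_dims.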
- th_train = iter(torch.unbind(torch.softmax(torch.stack(self.log_theta, dim=0), dim=0), dim=0)) - th_fixed = iter(self.fixed_log_theta) - th = [] - for is_dim_trainable in self.trainable_dims: - if is_dim_trainable: - th.append(next(th_train)) - else: - th.append(torch.exp(next(th_fixed))) - return torch.stack(th, dim=0) * self.theta_scaler - else: - assert self.trainable_dims is None or all(self.trainable_dims) - return torch.stack(self.theta_params, dim=0) - - @property - def theta_standardized(self): - # Decouple theta normalized to sum to one, and an overal scaler - with torch.no_grad(): - theta_sum = self.theta.sum() - return torch.cat((self.theta / theta_sum, theta_sum.unsqueeze(0))) - - def get_params_log(self): - theta = self.theta_standardized.detach() - log_dict = {} - for i in range(theta.shape[0]): - log_dict[f"theta_{i+1}"] = theta[i].item() - if isinstance(self.rbf_scale_long, torch.nn.Parameter): - log_dict[f"theta_rbf_lat"] = self.rbf_scale_lat.item() - log_dict[f"theta_rbf_long"] = self.rbf_scale_long.item() - return log_dict - - def get_params_summary_str(self): - s = f"Plan cost theta: {self.theta.detach().cpu().numpy()}" - s += f"\nPlan cost theta: {self.theta_standardized.detach().cpu().numpy()}" - if isinstance(self.rbf_scale_long, torch.nn.Parameter): - s += f" RBF scaler long={str(self.rbf_scale_long.detach().cpu().numpy())} lat={str(self.rbf_scale_lat.detach().cpu().numpy())}" - return s - - def approximate_quadratic_autodiff_naive(self, x, u, cost_inputs=None, diff=True): - """Adopted from mpc.approximate_cost""" - with torch.enable_grad(): - tau = torch.cat((x, u), dim=2).detach() - tau = torch.autograd.Variable(tau, requires_grad=True) - - grads = torch.autograd.functional.jacobian( - lambda tau: self.forward(tau, cost_inputs).sum(), tau, create_graph=True, vectorize=True) - hessians = list() - for v_i in range(tau.shape[2]): # over state dimensions - hessians.append( - torch.autograd.grad(grads[..., v_i].sum(), tau, retain_graph=True)[0] - ) - hessians = torch.stack(hessians, dim=-1) # 7, 209, 6, 6 - - # hessian matrix * tau vector. Using matmul to do this for last two dims, keeping T, batch - grads = grads - torch.matmul(hessians, tau.unsqueeze(-1)).squeeze(-1) - - if not diff: - return hessians.detach(), grads.detach() - return hessians, grads - - def gt_neighbors_gradient(self, xu, cost_inputs, gt_neighbors, insert_function=None): - # Gradient of cost function wrt. gt_neighbors (gt poses of other agents). 
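- # Typical call (sketch): grads = cost_obj.gt_neighbors_gradient(plan_xu_gt, cost_inputs, pred_gt_neighbor),
- # as done in DiffStack.prediction_importance_weights with the predicted agent's gt future as pred_gt_neighbor.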
- # TODO we could make it more efficient by computing only the collision term of the cost, - # the gradient of other terms (independent of gt_neighbors) are always zero - - if insert_function is None: - # The default way to insert the agent_xy into cost inputs is to replace gt_neighbors_batch - # cost_inputs = gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, _ - insert_function = lambda cost_inputs, agent_xy: ([agent_xy] + list(cost_inputs[1:])) - - with torch.enable_grad(): - tau = torch.autograd.Variable(gt_neighbors, requires_grad=True) - - grads = torch.autograd.functional.jacobian( - lambda tau: self.forward(xu, insert_function(cost_inputs, tau)).sum(), tau, create_graph=True, vectorize=True) - # grads should have the same shape as gt_neighbors - - return grads.detach() diff --git a/diffstack/modules/diffstack.py b/diffstack/modules/diffstack.py deleted file mode 100644 index 7cdc98b..0000000 --- a/diffstack/modules/diffstack.py +++ /dev/null @@ -1,419 +0,0 @@ -import torch - -from collections import OrderedDict -from typing import Dict, Optional, Union, Any, List, Set - -from trajdata.data_structures.batch import AgentBatch -from trajdata.data_structures.batch_element import AgentBatchElement - -from diffstack.modules.module import Module, ModuleSequence, DataFormat, RunMode -from diffstack.modules.predictors.constvel_predictor import ConstVelPredictor -from diffstack.modules.predictors.trajectron_predictor import TrajectronPredictor, TrajectronPredictorWithCacheData -from diffstack.modules.planners.fan_mpc_planner import FanMpcPlanner - -from diffstack.utils.utils import merge_dicts_with_prefix, restore -from diffstack.utils.visualization import visualize_plan_batch - - -class DiffStack(ModuleSequence): - - @property - def input_format(self) -> DataFormat: - return DataFormat(["batch"]) - - @property - def output_format(self) -> DataFormat: - return DataFormat(["pred.pred_dist", "pred.pred_ml", "plan.plan_xu", "loss:train", "loss:validate"]) - - def __init__(self, model_registrar, hyperparams, log_writer, device): - super().__init__([], model_registrar, hyperparams, log_writer, device, input_mappings={}) - - # Build the stack into the components ordered dictionary. 
- self.components = OrderedDict() - - # Initialize predictor - if self.hyperparams["predictor"] == "constvel": - self.pred_obj = ConstVelPredictor( - model_registrar, hyperparams, log_writer, device, - input_mappings={"agent_batch": "batch"}) - assert not self.hyperparams["train_pred"] - scene_centric = False - elif self.hyperparams["predictor"] in ["tpp", "gt", "nopred"]: - self.pred_obj = TrajectronPredictor( - model_registrar, hyperparams, log_writer, device, - input_mappings={"agent_batch": "batch"}) - scene_centric = False - elif self.hyperparams["predictor"] in ["tpp_cache"]: - self.pred_obj = TrajectronPredictorWithCacheData( - model_registrar, hyperparams, log_writer, device, - input_mappings={"agent_batch": "input.batch"}) - scene_centric = False - else: - raise ValueError(f"Unknown predictor={self.hyperparams['predictor']}") - - self.components["pred"] = self.pred_obj - - # Initialize planner/controller - if self.hyperparams["planner"] in ["", "none"] or not self.hyperparams["plan_train"]: - # TODO Implement a dummy planner with zero outputs, and support different training and inference configs for plan_train parameter - raise NotImplementedError - elif self.hyperparams["planner"] in ["none", "mpc", "fan", "fan_mpc"]: - self.planner_obj = FanMpcPlanner( - model_registrar, hyperparams, log_writer, device, - input_mappings={"agent_batch": "pred.agent_batch" if scene_centric else "input.batch"}) - else: - raise ValueError(f"Unknown planner={self.hyperparams['planner']}") - self.components["plan"] = self.planner_obj - - # Verify inputs/outputs with dry run - self.dry_run(input_keys=['batch', 'loss_weights']) - - # Initialize training counters - self.curr_plan_loss_scaler = None - self.set_curr_epoch(self.curr_epoch) # will initialize curr_plan_loss_scaler - self.set_annealing_params() - - def set_curr_epoch(self, curr_epoch): - super().set_curr_epoch(curr_epoch) - - # update plan_loss_scaler - scale_start = self.hyperparams['plan_loss_scale_start'] - scale_end = self.hyperparams['plan_loss_scale_end'] - nominal_scaler = self.hyperparams['plan_loss_scaler'] - if scale_end < 0 or curr_epoch >= scale_end: - self.curr_plan_loss_scaler = nominal_scaler - elif curr_epoch <= scale_start: - self.curr_plan_loss_scaler = 0. 
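- # Otherwise ramp the scaler linearly between the two epochs; e.g. with scale_start=0, scale_end=10 and
- # nominal_scaler=1.0 (illustrative values only) the scaler is 0.5 at epoch 5.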
- else: - assert scale_start < scale_end - # get percentage where we are in full range - self.curr_plan_loss_scaler = nominal_scaler * float(curr_epoch-scale_start) / float(scale_end-scale_start) - print ("Setting plan_loss_scaler=%f"%self.curr_plan_loss_scaler) - - def train(self, inputs: Dict): - batch: AgentBatch = inputs["batch"] - batch.to(self.device) - - # Add bias to prediction targets optionally - agent_fut_unbiased = batch.agent_fut.clone() - batch.agent_fut = self.bias_labels_maybe(agent_fut_unbiased) - - # Extend input with prediction loss importance weights - pred_loss_weights = self.prediction_importance_weights( - batch, self.hyperparams["pred_loss_weights"], self.hyperparams["pred_loss_temp"]) - inputs['loss_weights'] = pred_loss_weights - - # Run forward pass of the stack - outputs = self.sequence_components(inputs, run_mode=RunMode.TRAIN) - - # Losses - loss_components = [] - if self.hyperparams["train_pred"]: - loss_components.append(outputs['pred.loss'] * self.hyperparams["pred_loss_scaler"]) - loss_components.append(outputs['plan.loss'] * self.curr_plan_loss_scaler) - loss = torch.stack(loss_components).sum() - outputs.update({"loss": loss}) - - # Log - self.write_training_logs(batch, outputs) - - return outputs - - def validate(self, inputs: Dict): - """Run the stack forward and compute metrics. - """ - batch: AgentBatch = inputs["batch"] - batch.to(self.device) - - # Add bias to prediction targets optionally - agent_fut_unbiased = batch.agent_fut.clone() - batch.agent_fut = self.bias_labels_maybe(agent_fut_unbiased) - - # Run forward pass of the stack - outputs = self.sequence_components(inputs, run_mode=RunMode.VALIDATE) - - outputs["metrics"] = merge_dicts_with_prefix({"pred": outputs['pred.metrics'], "plan": outputs['plan.metrics']}) - - return outputs - - def infer(self, inputs: Dict): - """Run the stack forward without computing metrics or losses. - """ - batch: AgentBatch = inputs["batch"] - batch.to(self.device) - - # Add bias to prediction targets optionally - agent_fut_unbiased = batch.agent_fut.clone() - batch.agent_fut = self.bias_labels_maybe(agent_fut_unbiased) - - # Run forward pass of the stack - outputs = self.sequence_components(inputs, run_mode=RunMode.INFER) - - return outputs - - def write_training_logs(self, batch, outputs): - # Log planning loss. 
Logs for prediction loss are added in mgcvae.loss() - if self.log_writer is not None: - plan_active = float(outputs['plan.loss_batch'].shape[0])/float(batch.agent_hist.shape[0]) - plan_converged_rate = outputs['plan.converged'].float().mean() - plan_converged_mean_loss = outputs['plan.loss_batch'][outputs['plan.converged']].mean() - - self.log_writer.log({ - 'Plan/loss/plan': outputs['plan.loss'].detach().item(), - 'Plan/loss/plan_conv': plan_converged_mean_loss.detach().item(), - 'Plan/loss/pred': outputs['pred.loss'].detach().item(), - 'Plan/loss/total': outputs['loss'].detach().item(), - 'Plan/loss/scaler': self.curr_plan_loss_scaler, - 'Plan/active': plan_active, - 'Plan/converged': plan_converged_rate.detach().item(), - }, step=self.curr_iter, commit=False) - - self.log_writer.log({ - 'Plan/'+k: v for k, v in self.planner_obj.cost_obj.get_params_log().items() - }, step=self.curr_iter, commit=False) - - def get_params_summary_text(self): - return self.planner_obj.cost_obj.get_params_summary_str() - - def _deprecated_validation(self, batch: AgentBatch, return_plot_data=False): - batch.to(self.device) - plot_data = {} - - # Add bias to prediction targets optionally - agent_fut_unbiased = batch.agent_fut.clone() - batch.agent_fut = self.bias_labels_maybe(agent_fut_unbiased) - - node_types = batch.agent_types() - if len(node_types) > 1: - raise NotImplementedError("Mixing agent types for prediction in a batch is not supported.") - node_type = node_types[0] - - # Prediction, use the most likely latent mode for predictions, all modes for y_dists. - predictions, y_dist, pred_extra = self.pred_obj._run_prediction(batch, prediction_horizon=batch.agent_fut.shape[1]) - - # planning - if self.hyperparams["planner"] not in ["", "none"] and node_type.name in self.hyperparams["plan_node_types"]: - # TODO support planning for pedestrian prediction - # we can only plan for a vehicle but we can use pedestrian prediction. 
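- # The predictor setting is mapped to the planner's future_mode below: "nopred" and "gt" bypass the learned
- # predictions, "blind" plans with no agent futures at all, and anything else consumes the predicted futures.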
- - if self.hyperparams["predictor"] == "nopred": - future_mode="nopred" - elif self.hyperparams["predictor"] == "gt": - future_mode="gt" - elif self.hyperparams["predictor"] == "blind": - future_mode="none" - else: - future_mode="pred" - - plan_loss_batch, plan_converged_batch, plan_metrics, plan_info = self.planner_obj._plan_loss( - batch, y_dist, init_mode=self.hyperparams["plan_init"], loss_mode=self.hyperparams["plan_loss"], future_mode=future_mode, pred_extra=pred_extra, return_iters=True) #(b, ) - if not plan_metrics: - # emptly plan batch - metrics_dict = {} - plan_and_fan_valid = None - else: - metrics_dict = dict( - plan_loss=plan_loss_batch, - plan_converged=plan_converged_batch, - plan_cost=plan_metrics['cost'], - plan_unbiased_d1=plan_metrics['unbiased_d1'], - plan_unbiased_d2=plan_metrics['unbiased_d2'], - # TODO better names - # plan_mse_d1=plan_metrics['unbiased_d1'], - # plan_mse=plan_metrics['mse'], - plan_hcost=plan_metrics['hcost'], - plan_valid=plan_info['plan_batch_filter'], - ego_pred_gt_dist=plan_info['ego_pred_gt_dist'], - fan_converged=plan_info['fan_converged'], - ) - if 'class_goaldist' in plan_metrics: - metrics_dict.update({"plan_"+k: plan_metrics[k] for k in ["class_goaldist", "class_hcost", "dist_hcost_loss", "maxmargin_goaldist", "class_mse", "fan_hcost"]}) - plan_and_fan_valid=plan_info['plan_and_fan_valid'] - - # HACK - cost_components_batch = plan_info['cost_components'][1:].mean(0) # mean over future - hcost_components_batch = plan_info['hcost_components'][1:].mean(0) # mean over future - icost_components_batch = plan_info['icost_components'][1:].mean(0) # mean over future - for i in range(cost_components_batch.shape[-1]): - metrics_dict[f"costcomp_{i}"] = cost_components_batch[:, i] - metrics_dict[f"hcostcomp_{i}"] = hcost_components_batch[:, i] - for i in range(icost_components_batch.shape[-1]): - metrics_dict[f"icostcomp_{i}"] = icost_components_batch[:, i] * 1000.0 # scale by 10-3, sum over 6 future steps - - # Debug comparisons with different planner inputs - _, _, nopred_plan_metrics, nopred_plan_info = self.planner_obj._plan_loss( - batch, y_dist, init_mode=self.hyperparams["plan_init"], loss_mode=self.hyperparams["plan_loss"], future_mode="nopred", pred_extra=pred_extra, return_iters=True) #(b, ) - - # _, _, nof_plan_metrics, nof_plan_info = self.planner_obj._plan_loss( - # batch, y_dists, init_mode=self.hyperparams["plan_init"], loss_mode=self.hyperparams["plan_loss"], future_mode="none", return_iters=True) #(b, ) - - _, gt_plan_converged, gt_plan_metrics, gt_plan_info = self.planner_obj._plan_loss( - batch, y_dist, init_mode=self.hyperparams["plan_init"], loss_mode=self.hyperparams["plan_loss"], future_mode="gt", pred_extra=pred_extra, return_iters=True) #(b, ) - - metrics_dict.update(dict( - plan_unbiased_nopred_d1=nopred_plan_metrics['unbiased_d1'], plan_unbiased_nopred_d2=nopred_plan_metrics['unbiased_d2'], plan_nopred_hcost=nopred_plan_metrics['hcost'], - # plan_unbiased_nof_d1=nof_plan_metrics['unbiased_d1'], plan_unbiased_nof_d2=nof_plan_metrics['unbiased_d2'], plan_nof_hcost=nof_plan_metrics['hcost'], - plan_unbiased_gt_d1=gt_plan_metrics['unbiased_d1'], plan_unbiased_gt_d2=gt_plan_metrics['unbiased_d2'], plan_gt_hcost=gt_plan_metrics['hcost'], - plan_gt_converged=gt_plan_converged, - )) - - valid_filter = plan_metrics["fan_valid"] - metrics_dict.update({k+"_valid": v[valid_filter] for k, v in plan_metrics.items()}) - - if return_plot_data: - plot_data = { - 'predictions': predictions.detach(), - 'y_dists': y_dist, - # 'y_for_pred': 
y_for_pred.detach(), - 'plan': (plan_metrics, plan_info), - 'nopred_plan': (nopred_plan_metrics, nopred_plan_info,), - # 'nof_plan': (nof_plan_metrics, nof_plan_info,), - 'gt_plan': (gt_plan_metrics, gt_plan_info), - } - else: - metrics_dict = {} - plan_and_fan_valid = None - - # Compute default metrics - pred_metrics_dict = compute_prediction_metrics(predictions, batch.agent_fut[..., :2], y_dists=y_dist) - metrics_dict.update(pred_metrics_dict) - unbiased_pred_metrics = compute_prediction_metrics(predictions, agent_fut_unbiased[..., :2], y_dists=y_dist) - metrics_dict.update({k + "_unbiased": v for k, v in unbiased_pred_metrics.items()}) - - if plan_and_fan_valid is not None: - for pred_metric_name in ['ml_ade', 'ml_fde', 'nll_mean', 'nll_final']: - metrics_dict[pred_metric_name+"_valid"] = metrics_dict[pred_metric_name][plan_and_fan_valid] - - return metrics_dict, plot_data - - def augment_states_with_ego_indicator(self, x, x_st_t, neighbor_states, plan_data): - """Set an ego-indicator for prediction based on which agent we will treat as ego in the planner.""" - - raise NotImplementedError("Not implemented for trajdata input") - - if self.hyperparams['pred_ego_indicator'] == 'none': - # Do nothing - return x, x_st_t, neighbor_states - - # Choose which neighbor is ego - if self.hyperparams['pred_ego_indicator'] == 'robot': - ego_node_type = "VEHICLE" - # ego_inds = torch.unbind(plan_data[:, 1].int(), 0) - ego_inds = plan_data['robot_idx'].int() - elif self.hyperparams['pred_ego_indicator'] == 'most_relevant': - ego_node_type = "VEHICLE" - # ego_inds = torch.unbind(plan_data[:, 0].int(), 0) - ego_inds = plan_data['most_relevant_idx'].int() - else: - raise ValueError(f"Unknown pred_ego_indicator {self.hyperparams['pred_ego_indicator']}") - - ext_neighbor_states = {} - for edge_type, node_neighbor_states in neighbor_states.items(): - ext_node_neighbor_states = [] - for batch_i, neighbor_state in enumerate(node_neighbor_states): - ego_ind = ego_inds[batch_i] - # TODO only supports vehicle - if edge_type[1] == ego_node_type: - ext_node_neighbor_states.append([torch.nn.functional.pad(neighbor_state[i], (0, 1), value=(ego_ind == i).float()) for i in range(len(neighbor_state))]) - else: - ext_node_neighbor_states.append([torch.nn.functional.pad(neighbor_state[i], (0, 1), value=0.) for i in range(len(neighbor_state))]) - ext_neighbor_states[edge_type] = ext_node_neighbor_states - - x = torch.nn.functional.pad(x, (0, 1)) - x_st_t = torch.nn.functional.pad(x_st_t, (0, 1)) - - return x, x_st_t, ext_neighbor_states - - def bias_labels_maybe(self, y_t): - if self.hyperparams["bias_predictions"]: - y_t = y_t + 2. 
- return y_t - - def prediction_importance_weights(self, batch: AgentBatchElement, pred_loss_weights: str, temperature: float): - if pred_loss_weights == "none": - return None - - raise NotImplementedError("Not implemented for trajdata input") - - batch_size = y_gt.shape[0] - zero_weights = torch.zeros((batch_size, ), dtype=y_gt.dtype, device=y_gt.device) - - node_types = batch.agent_types() - if len(node_types) > 1: - raise NotImplementedError("Mixing agent types for prediction in a batch is not supported.") - node_type = node_types[0] - - plan_batch_filter, plan_ego_batch, plan_mus_batch, plan_logp_batch, plan_gt_neighbors_batch, plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch = self.prepare_plan_instance( - node_type, neighbors_future_data, plan_data, y_gt, None, update_fan_inputs=False) - - if plan_ego_batch.shape[1] == 0: - print ("Warning: no valid planning example in batch, keeping all prediction weights zero.") - return zero_weights - - if pred_loss_weights == "dist": - ego_gt_future = plan_ego_batch[1:, ..., :2] # (T, b_plan, 2) - agent_gt_future = y_gt[plan_batch_filter, ..., :2].transpose(1, 0) # (T, b,_plan, 2) - - dists = torch.square(agent_gt_future[..., :2] - ego_gt_future[:, :, :2]).sum(dim=-1).sqrt() # (T, b) - dists = torch.min(dists, dim=0).values # min over time, (b, ) - weights = torch.exp(-dists * temperature) - - elif pred_loss_weights == "grad": - x_gt, u_gt, lanes, x_proj, u_fitted, x_init_batch, goal_batch, lane_points, _, _, _ = self.decode_plan_inputs(plan_ego_batch) - plan_xu_gt = plan_ego_batch[..., :6] # decode_plan removes t0 but we need it here, so take it from the original feature vector - - cost_inputs = (None, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points) # gt neighbors will be replaced - # Take only the gt future for the predicted agent, which is the last in plan_all_gt_neighbors_batch - pred_gt_neighbor = plan_all_gt_neighbors_batch[-1:] - - grads = self.planner_obj.cost_obj.gt_neighbors_gradient(plan_xu_gt, cost_inputs, pred_gt_neighbor) # (1, t, b, 2) - - grad_norm = torch.linalg.norm(grads.squeeze(0), dim=-1) # norm over x, y (t, b) - grad_norm = torch.mean(grad_norm, dim=0) # over time (b, ) - dists = grad_norm # for debug - - # temperature: equivalent of weights = exp(log(grad^2) * temp). - # The 2x scaler creates a similar histogram for temp=1 in the case of dist and grad. 
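- # Note torch.pow(grad_norm, 2 * temperature) == exp(temperature * log(grad_norm ** 2)), so larger cost
- # gradients receive larger weights, mirroring how smaller distances do in the "dist" branch above.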
- weights = torch.pow(grad_norm, 2 * temperature) - - else: - raise ValueError(f"Unknown setting pred_loss_weights={pred_loss_weights}") - - # extend to full batch, examples with no valid ego will get zero weight - weights_full = zero_weights - weights_full[plan_batch_filter] = weights - # normalize - weights_full = weights_full / weights_full.sum() - - return weights_full - - def visualize_plan(self, nusc_maps, scenes, batch, plot_data, titles, plot_styles, num_plots=-1): - - raise NotImplementedError("Not implemented for trajdata input") - - # shortcut - if not plot_styles: - return {}, [] - - (first_history_index, - x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, - neighbors_edge_value, - robot_traj_st_t, - map, neighbors_future_data, plan_data) = batch - - # Restore encodings - neighbors_data_st = restore(neighbors_data_st) - neighbors_edge_value = restore(neighbors_edge_value) - plan_data = restore(plan_data) - - x = x_t.to(self.device) - x_st_t = x_st_t.to(self.device) - y_t = y_t.to(self.device) - - x, x_st_t, neighbors_data_st = self.augment_states_with_ego_indicator(x, x_st_t, neighbors_data_st, plan_data) - - output, plotted_inds = visualize_plan_batch(nusc_maps, scenes, x_t, y_t, plan_data, plot_data, titles, plot_styles, num_plots, self.hyperparams['planner'], self.ph, self.planh) - - return output, plotted_inds - diff --git a/diffstack/modules/dynamics_functions.py b/diffstack/modules/dynamics_functions.py deleted file mode 100644 index 7037d5c..0000000 --- a/diffstack/modules/dynamics_functions.py +++ /dev/null @@ -1,280 +0,0 @@ -import numpy as np -import torch -from typing import Dict, Iterable, Optional, Union, Any, Tuple - -from diffstack.utils.utils import angle_wrap - - -class LinearizedDynamics(torch.nn.Module): - # Wrapper around nn.Module, just so that we can identify it by type in mpc.py. - def linearized(self, x, u, diff): - raise NotImplementedError - - -class ExtendedUnicycleDynamics(LinearizedDynamics): - """ - This class borrows from mpc the linearization with analytically computed gradients. - """ - def __init__(self, dt): - super().__init__() - self.dt = dt - - def dyn_fn(self, x, u, return_grad=False): - x_p, y_p, phi, v = torch.unbind(x, dim=-1) - omega, a = torch.unbind(u, dim=-1) - dt = self.dt - - # if ego_pred_type == 'const_vel': - # return torch.stack([x_p + v * torch.cos(phi) * dt, - # y_p + v * torch.sin(phi) * dt, - # phi * torch.ones_like(a), - # v], dim=-1) - - mask = torch.abs(omega) <= 1e-2 - omega = ~mask * omega + mask * 1e-2 # TODO why 1? shouldnt it be 0? i guess doesnt matter because we will not use it - - phi_p_omega_dt = angle_wrap(phi + omega * dt) - - sin_phi = torch.sin(phi) - cos_phi = torch.cos(phi) - sin_phi_p_omega_dt = torch.sin(phi_p_omega_dt) - cos_phi_p_omega_dt = torch.cos(phi_p_omega_dt) - a_over_omega = a / omega - dsin_domega = (sin_phi_p_omega_dt - sin_phi) / omega - dcos_domega = (cos_phi_p_omega_dt - cos_phi) / omega - - with torch.no_grad(): - # "This function cannot be differentiated " - # "because for small omega we lost the dependency on omega. " - # "Instead we need to manually compute the limit of gradient wrt. omega when omega --> 0." 
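- # d1 is the closed-form constant turn-rate (and constant acceleration) integration of the unicycle over dt;
- # d2 is the small-omega limit: straight-line motion with constant heading and a second-order velocity term.
- # next_states below selects d2 wherever |omega| <= 1e-2 (the mask computed above) to avoid dividing by omega.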
- d1 = torch.stack([(x_p - + (a_over_omega) * dcos_domega - + v * dsin_domega - + (a_over_omega) * sin_phi_p_omega_dt * dt), - (y_p - - v * dcos_domega - + (a_over_omega) * dsin_domega - - (a_over_omega) * cos_phi_p_omega_dt * dt), - phi_p_omega_dt, - v + a * dt], dim=-1) - d2 = torch.stack([x_p + v * cos_phi * dt + (a / 2) * cos_phi * dt ** 2, - y_p + v * sin_phi * dt + (a / 2) * sin_phi * dt ** 2, - phi * torch.ones_like(a), - v + a * dt], dim=-1) - next_states = torch.where(~mask.unsqueeze(-1), d1, d2) - - if return_grad: - one = torch.ones_like(x_p) - zero = torch.zeros_like(x_p) - a_over_omega2 = a_over_omega / omega - v_over_omega = v / omega - - # derivatives wrt. state dimensions - g_state1 = torch.stack([ - torch.stack([one, zero, zero, zero], axis=-1), - torch.stack([zero, one, zero, zero], axis=-1), - torch.stack([ - dcos_domega*v + dt*a_over_omega*cos_phi_p_omega_dt + a_over_omega2*(sin_phi - sin_phi_p_omega_dt), - dcos_domega*a_over_omega + dt*a_over_omega*sin_phi_p_omega_dt - (v_over_omega)*(sin_phi - sin_phi_p_omega_dt), - one, - zero], axis=-1), - torch.stack([ - dsin_domega, - (cos_phi - cos_phi_p_omega_dt) / omega, - zero, - one], axis=-1), - ], axis=-1) # ..., 4, 4 - - g_control1 = torch.stack([ - # derivatives wrt. control dimensions - torch.stack([ - - 2.* dcos_domega * a_over_omega2 - dsin_domega * v_over_omega + a_over_omega * (dt * dt) * cos_phi_p_omega_dt - 2. * dt * a_over_omega2 * sin_phi_p_omega_dt + dt * v_over_omega * cos_phi_p_omega_dt, - - 2.* dsin_domega * a_over_omega2 + dcos_domega * v_over_omega + a_over_omega * (dt * dt) * sin_phi_p_omega_dt + 2. * dt * a_over_omega2 * cos_phi_p_omega_dt + dt * v_over_omega * sin_phi_p_omega_dt, - one * dt, - zero], axis=-1), - torch.stack([ - (dcos_domega + dt * sin_phi_p_omega_dt) / omega, - (dsin_domega - dt * cos_phi_p_omega_dt) / omega, - zero, - one * dt], axis=-1), - - ], axis=-1) # ...., 4 (states), 2 (controls) - - # derivatives wrt. state dimensions - g_state2 = torch.stack([ - torch.stack([one, zero, zero, zero], axis=-1), - torch.stack([zero, one, zero, zero], axis=-1), - torch.stack([ - (-0.5*dt*a - v) * sin_phi*dt, - (0.5*dt*a + v) * cos_phi*dt, - one, - zero], axis=-1), - torch.stack([ - dt * cos_phi, - dt * sin_phi, - zero, - one], axis=-1), - ], axis=-1) # ..., 4, 4 - - g_control2 = torch.stack([ - # derivatives wrt. 
control dimensions - torch.stack([ - -(1./3.*dt*a + 0.5*v) * dt*dt*sin_phi, - (1./3.*dt*a + 0.5*v) * dt*dt*cos_phi, - one * dt, - zero], axis=-1), - torch.stack([ - 0.5*dt*dt*cos_phi, - 0.5*dt*dt*sin_phi, - zero, - one * dt], axis=-1), - - ], axis=-1) # ...., 4 (states), 2 (controls) - - g_state = torch.where(~mask.unsqueeze(-1).unsqueeze(-1), g_state1, g_state2) - g_control = torch.where(~mask.unsqueeze(-1).unsqueeze(-1), g_control1, g_control2) - - return next_states, g_state, g_control - else: - return next_states - - def forward(self, x, u): - squeeze = x.ndimension() == 1 - if squeeze: - x = x.unsqueeze(0) - u = u.unsqueeze(0) - - assert x.ndimension() == 2 - assert x.shape[0] == u.shape[0] - assert u.ndimension() == 2 - - # TODO clamp control - # u = torch.clamp(u, -self.max_torque, self.max_torque)[:,0] - - state = self.dyn_fn(x, u) - - if squeeze: - state = state.squeeze(0) - - return state - - def linearized(self, x, u, diff): - # unroll trajectory through time - grad_x_list = [] - grad_u_list = [] - T = u.shape[0] - x_unroll = [x[0]] - for t in range(T-1): - new_x, grad_x, grad_u = self.dyn_fn(x_unroll[-1], u[t], return_grad=True) - grad_x_list.append(grad_x) - grad_u_list.append(grad_u) - x_unroll.append(new_x) - grad_x = torch.stack(grad_x_list, dim=0) - grad_u = torch.stack(grad_u_list, dim=0) - x_unroll = torch.stack(x_unroll, dim=0) - - F = torch.cat((grad_x, grad_u), dim=-1) - f = x_unroll[1:] - torch.matmul(grad_x, x_unroll[:-1].unsqueeze(-1)).squeeze(-1) - torch.matmul(grad_u, u[:-1].unsqueeze(-1)).squeeze(-1) - - if not diff: - F = F.detach() - f = f.detach() - - return F, f - - def linearized_autodiff(self, x, u, diff): - # This is a duplacate of mpc.lineare_dynamics - - dynamics = self.dyn_fn - n_batch = x[0].size(0) - T = u.shape[0] - - with torch.enable_grad(): - # TODO: This is inefficient and confusing. - x_init = x[0] - x = [x_init] - F, f = [], [] - for t in range(T): - if t < T-1: - xt = torch.autograd.Variable(x[t], requires_grad=True) - ut = torch.autograd.Variable(u[t], requires_grad=True) - # xut = torch.cat((xt, ut), 1) - new_x = dynamics(xt, ut) - - # Linear dynamics approximation. 
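- # Rt collects d(new_x_j)/d(x_t) and St collects d(new_x_j)/d(u_t), one autograd.grad call per state
- # dimension j; together they form the per-step linearized dynamics F_t = [Rt | St] with offset f_t.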
- Rt, St = [], [] - for j in range(new_x.shape[1]): # n_state - Rj, Sj = torch.autograd.grad( - new_x[:,j].sum(), [xt, ut], - retain_graph=True) - if not diff: - Rj, Sj = Rj.data, Sj.data - Rt.append(Rj) - St.append(Sj) - Rt = torch.stack(Rt, dim=1) - St = torch.stack(St, dim=1) - - Ft = torch.cat((Rt, St), 2) - F.append(Ft) - - if not diff: - xt, ut, new_x = xt.data, ut.data, new_x.data - ft = new_x - Rt.bmm(xt.unsqueeze(2)).squeeze(2) - St.bmm(ut.unsqueeze(2)).squeeze(2) - f.append(ft) - - if t < T-1: - x.append(new_x if not new_x.requires_grad else new_x.detach()) - - F = torch.stack(F, 0) - f = torch.stack(f, 0) - if not diff: - F, f = list(map(torch.autograd.Variable, [F, f])) - return F, f - - - -def extended_unicycle_dyn_fn(x: Union[torch.Tensor, np.ndarray], - u: Union[torch.Tensor, np.ndarray], - dt: float, - ret_np: bool, - ego_pred_type: str = 'motion_plan'): - x_p = torch.as_tensor(x[..., 0]) - y_p = torch.as_tensor(x[..., 1]) - phi = torch.as_tensor(x[..., 2]) - v = torch.as_tensor(x[..., 3]) - dphi = torch.as_tensor(u[..., 0]) - a = torch.as_tensor(u[..., 1]) - - if ego_pred_type == 'const_vel': - return torch.stack([x_p + v * torch.cos(phi) * dt, - y_p + v * torch.sin(phi) * dt, - phi * torch.ones_like(a), - v], dim=-1) - - mask = torch.abs(dphi) <= 1e-2 - dphi = ~mask * dphi + mask * 1 - - phi_p_omega_dt = angle_wrap(phi + dphi * dt) - dsin_domega = (torch.sin(phi_p_omega_dt) - torch.sin(phi)) / dphi - dcos_domega = (torch.cos(phi_p_omega_dt) - torch.cos(phi)) / dphi - - d1 = torch.stack([(x_p - + (a / dphi) * dcos_domega - + v * dsin_domega - + (a / dphi) * torch.sin(phi_p_omega_dt) * dt), - (y_p - - v * dcos_domega - + (a / dphi) * dsin_domega - - (a / dphi) * torch.cos(phi_p_omega_dt) * dt), - phi_p_omega_dt, - v + a * dt], dim=-1) - d2 = torch.stack([x_p + v * torch.cos(phi) * dt + (a / 2) * torch.cos(phi) * dt ** 2, - y_p + v * torch.sin(phi) * dt + (a / 2) * torch.sin(phi) * dt ** 2, - phi * torch.ones_like(a), - v + a * dt], dim=-1) - - next_states = torch.where(~mask.unsqueeze(-1), d1, d2) - if ret_np: - return next_states.numpy() - else: - return next_states diff --git a/diffstack/modules/metric_models/metrics.py b/diffstack/modules/metric_models/metrics.py new file mode 100644 index 0000000..a448251 --- /dev/null +++ b/diffstack/modules/metric_models/metrics.py @@ -0,0 +1,321 @@ +from collections import OrderedDict +import numpy as np + +import torch +import torch.nn as nn +import torch.optim as optim +import pytorch_lightning as pl + +from diffstack.models.learned_metrics import PermuteEBM +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.batch_utils import batch_utils +from diffstack.utils.geometry_utils import transform_points_tensor +from diffstack.models.base_models import RasterizedMapUNet, SplitMLP +from diffstack.models.Transformer import SimpleTransformer +import diffstack.utils.algo_utils as AlgoUtils + + + +class EBMMetric(pl.LightningModule): + def __init__(self, algo_config, modality_shapes, do_log=True): + """ + Creates networks and places them into @self.nets. 
+ """ + super(EBMMetric, self).__init__() + self.algo_config = algo_config + self.nets = nn.ModuleDict() + self._do_log = do_log + assert modality_shapes["image"][0] == 15 + + self.nets["ebm"] = PermuteEBM( + model_arch=algo_config.model_architecture, + input_image_shape=modality_shapes["image"], # [C, H, W] + map_feature_dim=algo_config.map_feature_dim, + traj_feature_dim=algo_config.traj_feature_dim, + embedding_dim=algo_config.embedding_dim, + embed_layer_dims=algo_config.embed_layer_dims, + ) + + @property + def checkpoint_monitor_keys(self): + return {"valLoss": "val/losses_infoNCE_loss"} + + def forward(self, obs_dict): + return self.nets["ebm"](obs_dict) + + def _compute_metrics(self, pred_batch, data_batch): + scores = pred_batch["scores"] + pred_inds = torch.argmax(scores, dim=1) + gt_inds = torch.arange(scores.shape[0]).to(scores.device) + cls_acc = torch.mean((pred_inds == gt_inds).float()).item() + + return dict(cls_acc=cls_acc) + + def training_step(self, batch, batch_idx): + """ + Training on a single batch of data. + + Args: + batch (dict): dictionary with torch.Tensors sampled + from a data loader and filtered by @process_batch_for_training + + batch_idx (int): training step number - required by some Algos that need + to perform staged training and early stopping + + Returns: + info (dict): dictionary of relevant inputs, outputs, and losses + that might be relevant for logging + """ + batch = batch_utils().parse_batch(batch) + pout = self.nets["ebm"](batch) + losses = self.nets["ebm"].compute_losses(pout, batch) + total_loss = 0.0 + for lk, l in losses.items(): + losses[lk] = l * self.algo_config.loss_weights[lk] + total_loss += losses[lk] + + metrics = self._compute_metrics(pout, batch) + + for lk, l in losses.items(): + self.log("train/losses_" + lk, l) + for mk, m in metrics.items(): + self.log("train/metrics_" + mk, m) + + return { + "loss": total_loss, + "all_losses": losses, + "all_metrics": metrics + } + + def validation_step(self, batch, batch_idx): + batch = batch_utils().parse_batch(batch) + pout = self.nets["ebm"](batch) + losses = TensorUtils.detach(self.nets["ebm"].compute_losses(pout, batch)) + metrics = self._compute_metrics(pout, batch) + return {"losses": losses, "metrics": metrics} + + def validation_epoch_end(self, outputs) -> None: + for k in outputs[0]["losses"]: + m = torch.stack([o["losses"][k] for o in outputs]).mean() + self.log("val/losses_" + k, m) + + for k in outputs[0]["metrics"]: + m = np.stack([o["metrics"][k] for o in outputs]).mean() + self.log("val/metrics_" + k, m) + + def configure_optimizers(self): + optim_params = self.algo_config.optim_params["policy"] + return optim.Adam( + params=self.parameters(), + lr=optim_params["learning_rate"]["initial"], + weight_decay=optim_params["regularization"]["L2"], + ) + + def get_metrics(self, obs_dict): + preds = self.forward(obs_dict) + return dict( + scores=preds["scores"].detach() + ) + + +class OccupancyMetric(pl.LightningModule): + def __init__(self, algo_config, modality_shapes): + super(OccupancyMetric, self).__init__() + self.algo_config = algo_config + self.nets = nn.ModuleDict() + self.agent_future_cond = algo_config.agent_future_cond.enabled + if algo_config.agent_future_cond.enabled: + self.agent_future_every_n_frame = algo_config.agent_future_cond.every_n_frame + self.future_num_frames = int(np.floor(algo_config.future_num_frames/self.agent_future_every_n_frame)) + C,H,W = modality_shapes["image"] + modality_shapes["image"] = (C+self.future_num_frames,H,W) + + self.nets["policy"] = 
RasterizedMapUNet( + model_arch=algo_config.model_architecture, + input_image_shape=modality_shapes["image"], # [C, H, W] + output_channel=algo_config.future_num_frames + ) + + + + @property + def checkpoint_monitor_keys(self): + keys = {"posErr": "val/metrics_pos_selection_err"} + if self.algo_config.loss_weights.pixel_bce_loss > 0: + keys["valBCELoss"] = "val/losses_pixel_bce_loss" + if self.algo_config.loss_weights.pixel_ce_loss > 0: + keys["valCELoss"] = "val/losses_pixel_ce_loss" + return keys + + def rasterize_agent_future(self,obs_dict): + + b, t_h, h, w = obs_dict["image"].shape # [B, C, H, W] + t_f = self.future_num_frames + + # create spatial supervisions + agent_positions = obs_dict["all_other_agents_future_positions"][:,:,::self.agent_future_every_n_frame] + + pos_raster = transform_points_tensor( + agent_positions.reshape(b,-1,2), + obs_dict["raster_from_agent"].float() + ).reshape(b,-1,t_f,2).long() # [B, T, 2] + # make sure all pixels are within the raster image + pos_raster[..., 0] = pos_raster[..., 0].clip(0, w - 1e-5) + pos_raster[..., 1] = pos_raster[..., 1].clip(0, h - 1e-5) + + # compute flattened pixel location + hist_image = torch.zeros([b,t_f,h*w],dtype=obs_dict["image"].dtype,device=obs_dict["image"].device) + raster_hist_pos_flat = pos_raster[..., 1] * w + pos_raster[..., 0] # [B, T, A] + raster_hist_pos_flat = (raster_hist_pos_flat * obs_dict["all_other_agents_future_availability"][:,:,::self.agent_future_every_n_frame]).long() + + hist_image.scatter_(dim=2, index=raster_hist_pos_flat.transpose(1,2), src=torch.ones_like(hist_image)) # mark other agents with -1 + + hist_image[:, :, 0] = 0 # correct the 0th index from invalid positions + hist_image[:, :, -1] = 0 # correct the maximum index caused by out of bound locations + + return hist_image.reshape(b, t_f, h, w) + + + def forward(self, obs_dict, mask_drivable=False, num_samples=None, clearance=None): + if self.agent_future_cond: + hist_image = self.rasterize_agent_future(obs_dict) + image = torch.cat([obs_dict["image"],hist_image],1) + else: + image = obs_dict["image"] + + pred_map = self.nets["policy"](image) + + return { + "occupancy_map": pred_map + } + + def compute_likelihood(self, occupancy_map, traj_pos, raster_from_agent): + b, t, h, w = occupancy_map.shape # [B, C, H, W] + + # create spatial supervisions + pos_raster = transform_points_tensor( + traj_pos, + raster_from_agent.float() + ) # [B, T, 2] + # make sure all pixels are within the raster image + pos_raster[..., 0] = pos_raster[..., 0].clip(0, w - 1e-5) + pos_raster[..., 1] = pos_raster[..., 1].clip(0, h - 1e-5) + + pos_pixel = torch.floor(pos_raster).float() # round down pixels + + # compute flattened pixel location + pos_pixel_flat = pos_pixel[..., 1] * w + pos_pixel[..., 0] # [B, T] + occupancy_map_flat = occupancy_map.reshape(b, t, -1) + + joint_li_map = torch.softmax(occupancy_map_flat, dim=-1) # [B, T, H * W] + joint_li = torch.gather(joint_li_map, dim=2, index=pos_pixel_flat[:, :, None].long()).squeeze(-1) + indep_li_map = torch.sigmoid(occupancy_map_flat) + indep_li = torch.gather(indep_li_map, dim=2, index=pos_pixel_flat[:, :, None].long()).squeeze(-1) + return { + "indep_likelihood": indep_li, + "joint_likelihood": joint_li + } + + def _compute_losses(self, pred_batch, data_batch): + losses = dict() + pred_map = pred_batch["occupancy_map"] + b, t, h, w = pred_map.shape + + spatial_sup = data_batch["spatial_sup"] + mask = data_batch["target_availabilities"] # [B, T] + # compute pixel classification loss + bce_loss = 
torch.binary_cross_entropy_with_logits( + input=pred_map, # [B, T, H, W] + target=spatial_sup["traj_spatial_map"], # [B, T, H, W] + ) * mask[..., None, None] + losses["pixel_bce_loss"] = bce_loss.mean() + + ce_loss = torch.nn.CrossEntropyLoss(reduction="none")( + input=pred_map.reshape(b * t, h * w), + target=spatial_sup["traj_position_pixel_flat"].long().reshape(b * t), + ) * mask.reshape(b * t) + + losses["pixel_ce_loss"] = ce_loss.mean() + + return losses + + def _compute_metrics(self, pred_batch, data_batch): + metrics = dict() + spatial_sup = data_batch["spatial_sup"] + + pixel_pred = torch.argmax( + torch.flatten(pred_batch["occupancy_map"], start_dim=2), dim=2 + ) # [B, T] + metrics["pos_selection_err"] = torch.mean( + (spatial_sup["traj_position_pixel_flat"].long() != pixel_pred).float() + ) + + likelihood = self.compute_likelihood( + pred_batch["occupancy_map"], + data_batch["target_positions"], + data_batch["raster_from_agent"] + ) + + metrics["joint_likelihood"] = likelihood["joint_likelihood"].mean() + metrics["indep_likelihood"] = likelihood["indep_likelihood"].mean() + + metrics = TensorUtils.to_numpy(metrics) + for k, v in metrics.items(): + metrics[k] = float(v) + return metrics + + def training_step(self, batch, batch_idx): + batch = batch_utils().parse_batch(batch) + pout = self.forward(batch) + batch["spatial_sup"] = AlgoUtils.get_spatial_trajectory_supervision(batch) + losses = self._compute_losses(pout, batch) + total_loss = 0.0 + for lk, l in losses.items(): + loss = l * self.algo_config.loss_weights[lk] + self.log("train/losses_" + lk, loss) + total_loss += loss + + with torch.no_grad(): + metrics = self._compute_metrics(pout, batch) + for mk, m in metrics.items(): + self.log("train/metrics_" + mk, m) + + return total_loss + + def validation_step(self, batch, batch_idx): + batch = batch_utils().parse_batch(batch) + pout = self(batch) + batch["spatial_sup"] = AlgoUtils.get_spatial_trajectory_supervision(batch) + losses = TensorUtils.detach(self._compute_losses(pout, batch)) + metrics = self._compute_metrics(pout, batch) + return {"losses": losses, "metrics": metrics} + + def validation_epoch_end(self, outputs) -> None: + for k in outputs[0]["losses"]: + m = torch.stack([o["losses"][k] for o in outputs]).mean() + self.log("val/losses_" + k, m) + + for k in outputs[0]["metrics"]: + m = np.stack([o["metrics"][k] for o in outputs]).mean() + self.log("val/metrics_" + k, m) + + def configure_optimizers(self): + optim_params = self.algo_config.optim_params["policy"] + return optim.Adam( + params=self.parameters(), + lr=optim_params["learning_rate"]["initial"], + weight_decay=optim_params["regularization"]["L2"], + ) + + def get_metrics(self, obs_dict,horizon=None): + occup_map = self.forward(obs_dict)["occupancy_map"] + b, t, h, w = occup_map.shape # [B, C, H, W] + if horizon is None: + horizon = t + else: + assert horizon<=t + li = self.compute_likelihood(occup_map, obs_dict["target_positions"], obs_dict["raster_from_agent"]) + li["joint_likelihood"] = li["joint_likelihood"][:,:horizon].mean(dim=-1).detach() + li["indep_likelihood"] = li["indep_likelihood"][:,:horizon].mean(dim=-1).detach() + + return li \ No newline at end of file diff --git a/diffstack/modules/module.py b/diffstack/modules/module.py index 0e031a7..efea2cf 100644 --- a/diffstack/modules/module.py +++ b/diffstack/modules/module.py @@ -1,4 +1,4 @@ -from re import I, S +import dill import numpy as np import torch @@ -20,7 +20,6 @@ def __init__(self, required_elements: Set[str]) -> None: self.required_elements 
= required_elements def satisfied_by(self, data_dict: Dict[str, Any]) -> bool: - # TODO(pkarkus) check types / shapes not only namess return all(x in data_dict for x in self.required_elements) def __iter__(self) -> str: @@ -39,21 +38,25 @@ def for_run_mode(self, run_mode: RunMode): return DataFormat(elements) - class Module(torch.nn.Module): """Abstract module in a differentiable stack. Inheriting classes need to implement: - input_format - output_format + - train_step() + - validate_step() + - infer_step() + - train/validate/infer methods. """ @property - def name(self) -> str: self.__class__.__name__ + def name(self) -> str: + self.__class__.__name__ @property - def input_format(self) -> DataFormat: + def input_format(self) -> DataFormat: """Required input keys specified as a set of strings wrapped as a DataFormat. The naming convention `my_input:run_mode` is used to identify a key @@ -62,10 +65,10 @@ def input_format(self) -> DataFormat: Example: return DataFormat(["rgb_image", "pointcloud", "label:train"]) """ - raise NotImplementedError + return None @property - def output_format(self) -> DataFormat: + def output_format(self) -> DataFormat: """Output keys specified as a set of strings wrapped as a DataFormat. The naming convention `my_output:run_mode` is used to identify a key @@ -74,25 +77,25 @@ def output_format(self) -> DataFormat: Example: return DataFormat(["prediction", "loss:train", "ml_prediction:infer"]) """ - raise NotImplementedError + return None def __init__( - self, - model_registrar: ModelRegistrar, - hyperparams: Dict[str, Any], - log_writer: Optional[Any], - device: str, - input_mappings: Dict[str, str] = {} + self, + model_registrar: Optional[ModelRegistrar] = None, + hyperparams: Optional[Dict[str, Any]] = None, + log_writer: Optional[Any] = None, + device: Optional[str] = None, + input_mappings: Dict[str, str] = {}, ) -> None: """ Args: model_registrar (ModelRegistrar): handles the registration of trainable parameters hyperparams (dict): config parameters - log_writer (wandb.Run): `wandb.Run` object for logging or None. - device (str): torch device - input_mappings (dict): a remapping of input names of the format {target_name: source_name}. - This is used when connecting multiple modules and we need to rename the outputs of the - previous module to the inputs of this module. + log_writer (wandb.Run): `wandb.Run` object for logging or None. + device (str): torch device + input_mappings (dict): a remapping of input names of the format {target_name: source_name}. + This is used when connecting multiple modules and we need to rename the outputs of the + previous module to the inputs of this module. 
""" super().__init__() self.model_registrar = model_registrar @@ -103,9 +106,11 @@ def __init__( # Initialize epoch counter self.curr_iter = 0 - self.curr_epoch = 0 + self.curr_epoch = 0 - def apply_input_mappings(self, inputs: Dict, run_mode: RunMode, allow_partial: bool = False): + def apply_input_mappings( + self, inputs: Dict, run_mode: RunMode, allow_partial: bool = True + ): mapped_inputs = {} for k in self.input_format.for_run_mode(run_mode): if k in self.input_mappings: @@ -113,17 +118,21 @@ def apply_input_mappings(self, inputs: Dict, run_mode: RunMode, allow_partial: b mapped_inputs[k] = inputs[self.input_mappings[k]] elif not allow_partial: raise ValueError( - f"Key `{k}` is remapped `{k}`<--`{self.input_mappings[k]}` but " + - f"there is no key `{self.input_mappings[k]}` in inputs.\n " + - f"inputs={list(inputs.keys())};\n " + - f"input_mappings={self.input_mappings}") + f"Key `{k}` is remapped `{k}`<--`{self.input_mappings[k]}` but " + + f"there is no key `{self.input_mappings[k]}` in inputs.\n " + + f"inputs={list(inputs.keys())};\n " + + f"input_mappings={self.input_mappings}" + ) elif k in inputs: mapped_inputs[k] = inputs[k] + elif "input." + k in inputs: + mapped_inputs[k] = inputs["input." + k] elif not allow_partial: raise ValueError( - f"Key `{k}` is not found in inputs and input_mappings.\n " + - f"inputs={list(inputs.keys())};\n " + - f"input_mappings={self.input_mappings}") + f"Key `{k}` is not found in inputs and input_mappings.\n " + + f"inputs={list(inputs.keys())};\n " + + f"input_mappings={self.input_mappings}" + ) return mapped_inputs # Optional functions for tracking training iteration/epoch and annealers. @@ -139,36 +148,43 @@ def set_annealing_params(self): def step_annealers(self, node_type=None): pass - - def forward(self, inputs: Dict) -> Dict: - return self.train(inputs) - - # The user is expected to interact with a module through the train/validate/infer methods. - # Childrens of this abstract class should either override the train/validate/infer - # methods or implement a _run_forward function. - def train(self, inputs: Dict) -> Dict: - return self._run_forward(inputs, RunMode.TRAIN) - - def validate(self, inputs: Dict) -> Dict: - return self._run_forward(inputs, RunMode.VALIDATE) - - def infer(self, inputs: Dict) -> Dict: - return self._run_forward(inputs, RunMode.INFER) - - def _run_forward(self, inputs: Dict, run_mode: RunMode) -> Dict: + + def forward(self, inputs: Dict, **kwargs) -> Dict: + return self.train_step(inputs, **kwargs) + + # Abstract methods that children classes should implement. + def train_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.TRAIN, **kwargs) + + def validate_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.VALIDATE, **kwargs) + + def infer_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.INFER, **kwargs) + + def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: raise NotImplementedError + def set_eval(): + pass + + def set_train(): + pass + + def reset(self): + pass + class ModuleSequence(Module): - """ A sequence of stack modules defined by an ordered dict of {name: Module}. + """A sequence of stack modules defined by an ordered dict of {name: Module}. The output of component N will be fed to the input of component N+1. When the output names of module N does not match the input names of module N+1 we can specify the name mapping with the `input_mappings` argument of module N+1. 
- The outputs of all previous modules can be referenced by `modulename.outputname`. + The outputs of all previous modules can be referenced by `modulename.outputname`. The overall inputs can be referenced by `input.inputname`. - The output will the sequence will contain the outputs of the last component, as + The output will the sequence will contain the outputs of the last component, as well as the outputs of all components in the `modulename.outputname` format. Example: @@ -178,15 +194,15 @@ class ModuleSequence(Module): We can sequence them in the followin way. stack = ModuleSequence(OrderedDict( - pred=MyPredictor(), + pred=MyPredictor(), plan=MyPlanner(input_mappings={ 'prediction': 'pred.most_likely_pred' - 'ego_state': 'input.ego_state', + 'ego_state': 'input.ego_state', 'goal': 'input.ego_goal', }))) input_dict = {'agent_history', 'ego_state', 'ego_goal'} - output_dict = stack.train(input_dict) + output_dict = stack.train_step(input_dict) Now `output_dict.keys()` will contain [ 'input.agent_history', @@ -196,29 +212,38 @@ class ModuleSequence(Module): 'plan.plan_x', 'plan_x' ] - """ + """ + @property - def input_format(self) -> DataFormat: self.components[0].input_format + def input_format(self) -> DataFormat: + list(self.components.values())[0].input_format @property - def output_format(self) -> DataFormat: self.components[-1].input_format + def output_format(self) -> DataFormat: + list(self.components.values())[-1].output_format def __init__( - self, - components: OrderedDict, + self, + components: OrderedDict, model_registrar, - hyperparams, + hyperparams, log_writer, - device, - input_mappings: Dict[str, str] = {} + device, + input_mappings: Dict[str, str] = {}, ) -> None: - - super().__init__(model_registrar, hyperparams, log_writer, device, input_mappings=input_mappings) + super().__init__( + model_registrar, + hyperparams, + log_writer, + device, + input_mappings=input_mappings, + ) self.components = components - if 'input' in self.components: - raise ValueError('Name "input" is reserved for the overall input to the module sequence.') - # TODO check there is no "." in the name + if "input" in self.components: + raise ValueError( + 'Name "input" is reserved for the overall input to the module sequence.' + ) def validate_interfaces(self, desired_input: DataFormat) -> bool: data_dict = {k: None for k in desired_input} @@ -226,49 +251,71 @@ def validate_interfaces(self, desired_input: DataFormat) -> bool: if component.input_format.satisfied_by(data_dict): data_dict.update({k: None for k in component.output_format}) else: - missing_inputs = [x for x in component.input_format.required_elements if x not in data_dict] - raise ValueError(f"Component '{component.name}' missing input(s): {missing_inputs}") + missing_inputs = [ + x + for x in component.input_format.required_elements + if x not in data_dict + ] + raise ValueError( + f"Component '{component.name}' missing input(s): {missing_inputs}" + ) return data_dict - - def sequence_components(self, inputs: Dict, run_mode: RunMode): - """ Sequence the ordered dict of components by connecting outputs to inputs. - """ + + def sequence_components( + self, + inputs: Dict, + run_mode: RunMode, + pre_module_cb: Optional[callable] = None, + post_module_cb: Optional[callable] = None, + **kwargs, + ): + """Sequence the ordered dict of components by connecting outputs to inputs.""" all_outputs = {"input." 
+ k: v for k, v in inputs.items()} - next_inputs = {**all_outputs, **inputs} + next_inputs = {**all_outputs, **inputs, **kwargs} - for name, component in self.components.items(): + for comp_i, (name, component) in enumerate(self.components.items()): component: Module - inputs = component.apply_input_mappings(next_inputs, run_mode) - if run_mode == RunMode.TRAIN: - output = component.train(inputs) + if pre_module_cb is not None: + pre_module_cb(name, component) + + if component.input_format is None: + inputs = next_inputs + else: + inputs = component.apply_input_mappings(next_inputs, run_mode) + + if run_mode == RunMode.TRAIN: + output = component.train_step(inputs, **kwargs) elif run_mode == RunMode.VALIDATE: - output = component.validate(inputs) + output = component.validate_step(inputs, **kwargs) elif run_mode == RunMode.INFER: - output = component.infer(inputs) + output = component.infer_step(inputs, **kwargs) else: raise ValueError(f"Unknown mode {run_mode}") - - # Add to all_outputs + + if post_module_cb is not None: + post_module_cb(name, component) + + # Add to all_outputs all_outputs.update({f"{name}.{k}": v for k, v in output.items()}) # Construct input from all_outputs and the previous outputs. next_inputs = {**all_outputs, **output} - - return next_inputs + + return next_inputs def dry_run( - self, - input_keys: List[str] = None, - run_mode: Optional[RunMode] = None, - check_output: bool = True, - raise_error: bool = True - ) -> List[str]: + self, + input_keys: List[str] = None, + run_mode: Optional[RunMode] = None, + check_output: bool = True, + raise_error: bool = True, + ) -> List[str]: """Checks that all inputs are defined in module sequence. We only verify input and output based on module.input_format and - module.output_format. We do not check if the modules correctly + module.output_format. We do not check if the modules correctly define their respective input_format and output_format. Args: @@ -285,37 +332,49 @@ def dry_run( issues = [] if run_mode is None: for run_mode in RunMode: - issues.extend(self.dry_run(input_keys, run_mode, check_output, raise_error=False)) - + issues.extend( + self.dry_run(input_keys, run_mode, check_output, raise_error=False) + ) + if issues: + break + else: if input_keys is None: input_keys = list(self.input_format) inputs = {k: None for k in input_keys} all_outputs = {"input." + k: v for k, v in inputs.items()} - next_inputs = {**all_outputs, **inputs} + next_inputs = {**all_outputs, **inputs} component: Module for name, component in self.components.items(): - inputs = component.apply_input_mappings(next_inputs, run_mode, allow_partial=True) + inputs = component.apply_input_mappings( + next_inputs, run_mode, allow_partial=True + ) input_format = component.input_format.for_run_mode(run_mode) if not input_format.satisfied_by(inputs): - issues.append(f"Component {name}({run_mode.name}): \n Required inputs: {list(input_format)}\n Provided inputs: {list(inputs.keys())}") - - output = {k: None for k in component.output_format.for_run_mode(run_mode)} - # Add to all_outputs + issues.append( + f"Component {name}({run_mode.name}): \n Required inputs: {list(input_format)}\n Provided inputs: {list(inputs.keys())}" + ) + + output = { + k: None for k in component.output_format.for_run_mode(run_mode) + } + # Add to all_outputs all_outputs.update({f"{name}.{k}": v for k, v in output.items()}) # Construct input from all_outputs and the previous outputs. 
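# Illustrative sketch of the wiring described above: outputs are namespaced as
# "<module>.<output>", the overall inputs as "input.<name>", and an input_mappings dict
# renames a previous module's output to the next module's expected input. This is a
# standalone toy re-implementation mirroring the docstring example (pred/plan names),
# not the repo's ModuleSequence.
from collections import OrderedDict
from typing import Callable, Dict

def run_sequence(components: "OrderedDict[str, Callable]",
                 input_mappings: Dict[str, Dict[str, str]],
                 inputs: Dict) -> Dict:
    all_outputs = {"input." + k: v for k, v in inputs.items()}
    next_inputs = {**all_outputs, **inputs}
    for name, fn in components.items():
        mapping = input_mappings.get(name, {})
        # Rename keys for this component, e.g. {"prediction": "pred.most_likely_pred"}.
        local_inputs = {target: next_inputs[src] for target, src in mapping.items()}
        local_inputs.update({k: v for k, v in next_inputs.items() if k not in local_inputs})
        output = fn(local_inputs)
        # Namespace outputs so later components and the caller can reference them.
        all_outputs.update({f"{name}.{k}": v for k, v in output.items()})
        next_inputs = {**all_outputs, **output}
    return next_inputs

components = OrderedDict(
    pred=lambda d: {"most_likely_pred": d["agent_history"] + 1},
    plan=lambda d: {"plan_x": d["prediction"] * 10},
)
out = run_sequence(components, {"plan": {"prediction": "pred.most_likely_pred"}},
                   {"agent_history": 1})
# Keys now include: input.agent_history, pred.most_likely_pred, plan.plan_x, plan_x
assert out["plan.plan_x"] == 20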
- next_inputs = {**all_outputs, **output} + next_inputs = {**all_outputs, **output} if check_output: output_format = self.output_format.for_run_mode(run_mode) if not output_format.satisfied_by(next_inputs): - issues.append(f"Sequence {self.name}({run_mode.name}): \n Required outputs: {list(output_format)}\n Provided outputs: {list(next_inputs.keys())}") - + issues.append( + f"Sequence {self.name}({run_mode.name}): \n Required outputs: {list(output_format)}\n Provided outputs: {list(next_inputs.keys())}" + ) + if raise_error and issues: - print ("\n".join(issues)) + print("\n".join(issues)) raise ValueError("\n".join(issues)) - + return issues def set_curr_iter(self, curr_iter): @@ -338,16 +397,42 @@ def step_annealers(self, node_type=None): for comp in self.components.values(): comp.step_annealers(node_type) - def train(self, inputs: Dict) -> Dict: - return self._run_forward(inputs, RunMode.TRAIN) - - def validate(self, inputs: Dict) -> Dict: - return self._run_forward(inputs, RunMode.VALIDATE) - - def infer(self, inputs: Dict) -> Dict: - return self._run_forward(inputs, RunMode.INFER) + def train_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.TRAIN, **kwargs) + + def validate_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.VALIDATE, **kwargs) + + def infer_step(self, inputs: Dict, **kwargs) -> Dict: + return self._run_forward(inputs, RunMode.INFER, **kwargs) + + def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: + return self.sequence_components(inputs, run_mode=run_mode, **kwargs) + + def __getstate__(self): + # Custom getstate for pickle that allows for lambda functions. + import dill # reimport in case not available in forked function + + return dill.dumps(self.__dict__) + + def __setstate__(self, state): + # Custom setstate for pickle that allows for lambda functions. 
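# Illustrative sketch of the dill-backed (de)serialization used here: the object's __dict__
# is round-tripped through dill so attributes holding lambdas survive plain pickling.
# Standalone toy example, assuming the `dill` package is installed; the class and attribute
# names are made up for illustration.
import pickle
import dill

class LambdaHolder:
    def __init__(self):
        self.scale = lambda x: 2 * x  # plain pickle cannot serialize this attribute

    def __getstate__(self):
        return dill.dumps(self.__dict__)

    def __setstate__(self, state):
        self.__dict__.update(dill.loads(state))

restored = pickle.loads(pickle.dumps(LambdaHolder()))
assert restored.scale(21) == 42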
+ state = dill.loads(state) + self.__dict__.update(state) + + def __str__(self): + return f"Module sequence (device={self.device}) \n " + "\n ".join( + self.components.keys() + ) - def _run_forward(self, inputs: Dict, run_mode: RunMode) -> Dict: - raise self.sequence_components(inputs, run_mode=run_mode) + def set_eval(self): + for k, v in self.components.items(): + v.set_eval() + def set_train(self): + for k, v in self.components.items(): + v.set_train() + def reset(self): + for k, v in self.components.items(): + v.reset() diff --git a/diffstack/modules/planners/__init__.py b/diffstack/modules/planners/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/diffstack/modules/planners/fan_mpc_planner.py b/diffstack/modules/planners/fan_mpc_planner.py deleted file mode 100644 index 701959e..0000000 --- a/diffstack/modules/planners/fan_mpc_planner.py +++ /dev/null @@ -1,828 +0,0 @@ -import torch -import numpy as np -from typing import Dict, Optional, Union, Any - -from trajdata.data_structures.batch import AgentBatch -from trajdata.data_structures.agent import AgentType - -from diffstack.modules.cost_functions.cost_selector import get_cost_object -from diffstack.modules.cost_functions.linear_base_cost import LinearBaseCost -from diffstack.modules.dynamics_functions import ExtendedUnicycleDynamics -from diffstack.modules.planners.fan_planner import FanPlanner -from diffstack.modules.planners.mpc_utils.trajcost_mpc import TrajCostMPC, GradMethods - -from diffstack.utils.utils import CudaTimer, subsample_traj, convert_state_pred2plan, all_gather, restore, batchable_dict, batchable_nonuniform_tensor -from diffstack.modules.module import Module, DataFormat, RunMode - - -MAX_PLAN_NEIGHBORS = 16 - -timer = CudaTimer(enabled=False) -mse_fn = torch.nn.MSELoss(reduction='none') - - -class FanMpcPlanner(Module): - @property - def input_format(self) -> DataFormat: - return DataFormat(["agent_batch", "pred_dist"]) - - @property - def output_format(self) -> DataFormat: - return DataFormat(["plan_xu", "valid", "converged", "loss:train", "loss:validate", "loss_batch:train", "loss:validate", "metrics:train", "metrics:validate"]) - - def __init__(self, model_registrar, hyperparams, log_writer, device, input_mappings = {}): - super().__init__(model_registrar, hyperparams, log_writer, device, input_mappings) - - self.ph = int(hyperparams['prediction_sec'] / hyperparams['dt']) - self.planh = self.ph * int(hyperparams['dt'] / hyperparams['plan_dt']) - - # Dynamics and cost - self.dyn_obj = ExtendedUnicycleDynamics(dt=self.hyperparams['plan_dt']) - # Load cost object from model registrat. - # TODO we should check if the loaded model had the same cost function setting. 
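# Illustrative sketch of the "key:run_mode" naming convention described in the Module
# docstrings above: a key such as "label:train" is only required for the TRAIN run mode,
# while un-suffixed keys apply to every mode. Standalone toy re-implementation for
# illustration only, not the repo's DataFormat/RunMode.
from enum import Enum
from typing import Set

class ToyRunMode(Enum):
    TRAIN = "train"
    VALIDATE = "validate"
    INFER = "infer"

def keys_for_run_mode(required: Set[str], run_mode: ToyRunMode) -> Set[str]:
    keys = set()
    for key in required:
        name, _, mode = key.partition(":")
        if mode == "" or mode == run_mode.value:
            keys.add(name)
    return keys

fmt = {"rgb_image", "pointcloud", "label:train"}
assert keys_for_run_mode(fmt, ToyRunMode.TRAIN) == {"rgb_image", "pointcloud", "label"}
assert keys_for_run_mode(fmt, ToyRunMode.INFER) == {"rgb_image", "pointcloud"}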
- self.cost_obj = self.model_registrar.get_model("planner_cost", self.get_plan_cost(hyperparams['plan_cost'], is_trainable=hyperparams["train_plan_cost"])) - self.interpretable_cost_obj = self.model_registrar.get_model("interpretable_cost", self.get_plan_cost('interpretable', is_trainable=False)) - - self.u_lower = torch.from_numpy(np.array([self.cost_obj.control_limits['min_heading_change'], self.cost_obj.control_limits['min_a']])).float().to(self.device) - self.u_upper = torch.from_numpy(np.array([self.cost_obj.control_limits['max_heading_change'], self.cost_obj.control_limits['max_a']])).float().to(self.device) - - # TODO dummy normalization - self.normalize_std = torch.ones((2, ), dtype=torch.float, device=self.device) - - # Fan planner - self.fan_obj = FanPlanner(ph=self.planh, dt=hyperparams['plan_dt'], device=device) - - - # Controller - if self.cost_obj.control_limits is not None: - u_lower = self.u_lower - u_upper = self.u_upper - else: - u_lower = None - u_upper = None - - self.mpc_obj = TrajCostMPC( - 4, 2, self.planh+1, - u_init=None, - # TODO are there actual limits? SLSQP doesn't seem to constrain u, only through cost. - # u_lower=None, u_upper=None, - u_lower=u_lower, - u_upper=u_upper, - lqr_iter=self.hyperparams['plan_lqr_max_iters'], # 50 def for pendulum - verbose=-1, # no output, 0 warnings - exit_unconverged=False, - detach_unconverged=False, # Manually detach instead - linesearch_decay=0.2, - max_linesearch_iter=self.hyperparams['plan_lqr_max_linesearch_iters'], - grad_method=GradMethods.AUTO_DIFF, - eps=self.hyperparams['plan_lqr_eps'], - n_batch=1, # we will update this manually every time before calling MPC - ) - - def get_plan_cost(self, plan_cost_mode: str, is_trainable: bool) -> LinearBaseCost: - control_limits = self.hyperparams["dynamic"]["VEHICLE"]["limits"] - return get_cost_object(plan_cost_mode, control_limits, is_trainable=is_trainable, device=self.device) - - def train(self, inputs: Dict) -> Dict: - batch: AgentBatch = inputs["agent_batch"] - pred_dist = inputs["pred_dist"] - - # TODO currently we use future_mode to decide what prediction gets used. Instead we should - # implement these as different predictors. - if self.hyperparams["predictor"] == "nopred": - future_mode="nopred" - elif self.hyperparams["predictor"] == "gt": - future_mode="gt" - elif self.hyperparams["predictor"] == "blind": - future_mode="none" - elif self.hyperparams["predictor"] == "tpp_nogt": - future_mode="pred_nogt" - else: - future_mode="pred" - init_mode=self.hyperparams["plan_init"] - loss_mode=self.hyperparams["plan_loss"] - - # For now we only support one type of prediction agent in the batch. - node_types = batch.agent_types() - if len(node_types) > 1: - raise NotImplementedError("Mixing agent types for prediction in a batch is not supported.") - node_type = node_types[0] - - if node_type.name in self.hyperparams["plan_node_types"]: - plan_loss_batch, plan_converged, metrics, plan_info = self._plan_loss(batch, pred_dist, None, init_mode, future_mode, loss_mode, return_iters=False) - plan_valid = plan_info['plan_batch_filter'] - plan_xu = plan_info['plan_xu'] # TODO this is a detached tensor - else: - plan_loss_batch = torch.zeros((0,), device=self.device) - plan_valid = torch.zeros((0,), dtype=torch.bool, device=self.device) - plan_converged = torch.zeros((0,), dtype=torch.bool, device=self.device) - metrics = {} - - # loss as metric - metrics["plan_loss"] = plan_loss_batch - - # Use simple mean as loss. It makes sense in that gradients in the batch will be averaged. 
- plan_loss = plan_loss_batch.mean() - outputs = {"plan_xu": plan_xu, "valid": plan_valid, "converged": plan_converged, "loss": plan_loss, "loss_batch": plan_loss_batch, "metrics": metrics} - - # TODO later the fan planner and mpc should be different subcomponents. For now just hack the output - # as if the fan planner produced its candidates. - if "traj_xu" in plan_info: - outputs["fan.candidates_xu"] = plan_info["traj_xu"] - - return outputs - - def validate(self, inputs: Dict) -> Dict: - return self.train(inputs) - - def infer(self, inputs: Dict) -> Dict: - return self.train(inputs) - - def _plan_loss(self, batch: AgentBatch, y_dist, pred_extra=None, init_mode="fitted", future_mode="pred", loss_mode="mse", return_iters=False): - timer.start("prepare") - plan_batch_filter, plan_ego_batch, lanes, goal_batch, plan_mus_batch, plan_logp_batch, plan_gt_neighbors_batch, plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, fan_inputs = self.prepare_plan_instance( - batch, y_dist, build_fan_inputs=(self.hyperparams["planner"] in ["fan", "fan_mpc"])) - lane_points = None - x_gt = plan_ego_batch[1:] - timer.end("prepare") - - if plan_ego_batch.shape[1] == 0: - # If we cannot plan for anything in batch, create a dummy zero loss - plan_loss_batch = torch.zeros((1,), device=self.device) - plan_converged = torch.zeros((1, ), dtype=torch.bool, device=self.device) - plan_info = {} - metrics = {} - print ("Could not plan for anything in batch.") - - else: - # Plan - timer.start("plan") - plan_x, plan_u, plan_cost, plan_converged, plan_info = self.plan_batch( - plan_ego_batch[0], lanes, goal_batch, plan_mus_batch, plan_logp_batch, - empty_mus_batch, empty_logp_batch, plan_gt_neighbors_batch, plan_all_gt_neighbors_batch, - future_mode=future_mode, init_mode=init_mode, planner=self.hyperparams["planner"], - pred_extra=pred_extra, - fan_inputs=fan_inputs, - plan_batch_filter=plan_batch_filter, - return_iters=return_iters) # (T, b, ...) - - plan_sub_x = subsample_traj(plan_x, self.ph, self.planh) - # plan_sub_u = subsample_traj(plan_u, self.ph, self.planh) - timer.end("plan") - - timer.start("metrics") - # Plan error metrics - plan_mse_batch = mse_fn(plan_sub_x[1:, :, :2], x_gt[..., :2]).sum(dim=-1) # (T, b) - plan_unbiased_d1_batch = torch.sqrt(plan_mse_batch).mean(dim=0) - plan_unbiased_d2_batch = plan_mse_batch.mean(dim=0) # (b, ) - # For fan planner mse is only meaningful for 'converged' cases. - # We will replace with zero for unconverged samples, so batch dimension is preserved, but this does introduce bias. - if self.hyperparams["planner"] == "fan": - plan_unbiased_d1_batch[torch.logical_not(plan_converged)] = 0. - plan_unbiased_d2_batch[torch.logical_not(plan_converged)] = 0. 
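# Illustrative sketch of the metric computation above: the plan is rolled out on a finer
# time grid than the ground-truth future, so it is subsampled onto the coarse grid before
# the per-sample displacement error is computed, and unconverged samples are zeroed so the
# batch dimension is preserved (at the price of some bias, as noted above). The helper
# `subsample_every` is a hypothetical stand-in; the repo's subsample_traj may differ.
import torch

def subsample_every(traj: torch.Tensor, coarse_steps: int, fine_steps: int) -> torch.Tensor:
    # Keep every (fine_steps // coarse_steps)-th step along dim 0, excluding t=0.
    stride = fine_steps // coarse_steps
    return traj[stride::stride]

fine_steps, coarse_steps, batch = 12, 6, 4
plan_x = torch.randn(fine_steps + 1, batch, 4)                   # [T_fine+1, B, state]
x_gt = torch.randn(coarse_steps, batch, 4)                       # [T_coarse, B, state]
converged = torch.tensor([True, True, False, True])

plan_sub = subsample_every(plan_x, coarse_steps, fine_steps)            # [T_coarse, B, state]
dist2 = (plan_sub[..., :2] - x_gt[..., :2]).square().sum(-1)            # [T_coarse, B]
d1 = dist2.sqrt().mean(dim=0)                                           # per-sample displacement
d1 = torch.where(converged, d1, torch.zeros_like(d1))                   # zero unconverged samples
assert d1.shape == (batch,)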
- - # Cost(plan, gtfutures) - plan_xu = torch.cat((plan_x, plan_u), -1) - plan_hcost_batch = self.cost_obj(plan_xu, cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points)) # (T, b) - plan_info['plan_xu'] = plan_xu.detach() - - # Trajectory fan related metrics - if self.hyperparams['planner'] in ['fan', 'fan_mpc']: - # Pad and stack trajectory candidates - traj_xu = torch.nn.utils.rnn.pad_sequence(plan_info['traj_xu'], batch_first=True, padding_value=torch.inf) # b, N, T+1, 6 - traj_xu_sub = subsample_traj(traj_xu.transpose(2, 0), self.ph, self.planh) # T, N, b, 6 - traj_cost = torch.nn.utils.rnn.pad_sequence(plan_info['traj_cost'], batch_first=True, padding_value=torch.inf) # b, N - - ########## - # Label: closest based on last state distance - trajdist2 = torch.square(traj_xu_sub[1:, :, :, :2] - x_gt[:, :, :2].unsqueeze(1)).sum(-1) # T, N, b - trajmse = torch.mean(trajdist2, dim=0) - label_mse = torch.argmin(trajmse, dim=0) # b, - plan_info["label_mse"] = label_mse.detach() - - # goaldist2 = torch.square(traj_xu[:, :, -1, :2] - x_gt[-1, :, :2].unsqueeze(1)).sum(-1) # b, N - # goaldist2 = trajdist2[-1] # b, N - # label_goaldist = torch.argmin(goaldist2, dim=1) # b, - # plan_info["label_goaldist"] = label_goaldist.detach() - - goaldist2 = trajdist2[-1] # N, b - label_goaldist = torch.argmin(goaldist2, dim=0) # b, - plan_info["label_goaldist"] = label_goaldist.detach() - - # label_xy = traj_xu[torch.arange(traj_xu.shape[0]), label_goaldist] # b, T, 2 - # label_plan_mse = mse_fn(label_xy.transpose(1, 0)[1:, :, :2], x_gt[..., :2]).sum(dim=-1) # (T, b) - - # Cross entropy, based on lowest mse - class_mse_loss = torch.nn.functional.cross_entropy(-traj_cost, label_mse, reduction='none') - class_mse_loss = torch.nan_to_num(class_mse_loss, nan=0.) - - # Cross entropy, based on closest at goal label - class_goaldist_loss = torch.nn.functional.cross_entropy(-traj_cost, label_goaldist, reduction='none') - class_goaldist_loss = torch.nan_to_num(class_goaldist_loss, nan=0.) - - ########## - timer.start("hcost") - # TODO this is now slow (50% of compute spent here) - # it would be easy to speed up by skipping padding and batch loop, and instead concatenate all while remembering lengths, and recover with split - # Get hindsight costs for all traj candidates - traj_hcost = self.fan_obj.get_cost_for_trajs( - plan_info['traj_xu'], self.cost_obj, - cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points)) # (b)(N) - plan_info["traj_hcost"] = [x.detach() for x in traj_hcost] - traj_hcost = torch.nn.utils.rnn.pad_sequence(traj_hcost, batch_first=True, padding_value=torch.inf) # b, N - timer.end("hcost") - # print (len(plan_info['traj_xu']), plan_info['traj_xu'][0].device, plan_all_gt_neighbors_batch.device, goal_batch.device, lane_points.device) - - # Label: based on lowest hindsight cost - label_hcost = torch.argmin(traj_hcost, dim=1) # b, - plan_info["label_hcost"] = label_hcost.detach() - - # Hcost of fan's choice - fan_choice = torch.argmin(traj_cost, dim=1) # b, - fan_choice_hcost = traj_hcost[torch.arange(traj_hcost.shape[0]).to(self.device), fan_choice] - - # Cross entropy, based on hindsight cost label - class_hcost_loss = torch.nn.functional.cross_entropy(-traj_cost, label_hcost, reduction='none') - class_hcost_loss = torch.nan_to_num(class_hcost_loss, nan=0.) 
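# Illustrative sketch of the candidate-classification idea above: negative candidate costs
# are used as logits and the supervision label is the index of the candidate closest to the
# ground-truth future (here by final-position distance). Shapes and names are illustrative.
import torch
import torch.nn.functional as F

B, N, T = 3, 5, 6
cand_xy = torch.randn(B, N, T, 2)        # candidate trajectories
gt_xy = torch.randn(B, T, 2)             # ground-truth future
cand_cost = torch.randn(B, N)            # learned cost per candidate (lower = better)

# Label: candidate whose final position is closest to the ground-truth endpoint.
goal_dist2 = (cand_xy[:, :, -1] - gt_xy[:, None, -1]).square().sum(-1)   # [B, N]
label = goal_dist2.argmin(dim=1)                                         # [B]

# Cross-entropy with -cost as logits pushes the labelled candidate's cost below the rest.
loss = F.cross_entropy(-cand_cost, label, reduction="none")              # [B]
loss = torch.nan_to_num(loss, nan=0.0).mean()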
- - target_probs = torch.nn.functional.softmax(-traj_cost, dim=-1) - logits = -traj_cost.nan_to_num(posinf=1e10) - dist_hcost_loss = torch.nn.functional.cross_entropy(logits, target_probs, reduction='none') - dist_hcost_loss = torch.nan_to_num(dist_hcost_loss, nan=0.) - - # Maxmargin - if traj_cost.shape[1] < 2: - # Need at least two candidates - maxmargin_goaldist = torch.full((traj_xu.shape[0], ), torch.inf, dtype=torch.float, device=self.device) - else: - # Lowest cost other than the closest - sorted_idx = torch.argsort(traj_cost, dim=1) - top1 = sorted_idx[:, 0] - top2 = sorted_idx[:, 1] - lowest_other = torch.where(top1 == label_goaldist, top2, top1) - lowest_other_cost = torch.gather(traj_cost, dim=1, index=lowest_other.unsqueeze(-1)).squeeze(-1) - - closest_cost = torch.gather(traj_cost, dim=1, index=label_goaldist.unsqueeze(-1)).squeeze(-1) - - # Maxmargin loss - maxmargin_goaldist = closest_cost - lowest_other_cost - maxmargin_goaldist = torch.nan_to_num(maxmargin_goaldist, nan=0., posinf=0., neginf=0.) - - else: - class_goaldist_loss = torch.zeros_like(plan_unbiased_d1_batch) - maxmargin_goaldist = torch.zeros_like(plan_unbiased_d1_batch) - - # # Debug converged rate - # print (plan_converged.float().mean(), plan_converged) - # print (plan_x[:, 0]) - - # Resolve multi-component loss modes - if loss_mode in ["joint_hcost", "joint-hcost"]: - loss_modes = ["class_hcost", "hcost"] # order is important for scaler2 - elif loss_mode == "joint_hcost2": - loss_modes = ["class_hcost2", "hcost"] # order is important for scaler2 - elif loss_mode in ["joint_mse", "joint-mse"]: - loss_modes = ["class_mse", "mse"] # order is important for scaler2 - elif loss_mode == "joint_mse2": - loss_modes = ["class_mse2", "mse"] # order is important for scaler2 - else: - loss_modes = [loss_mode] - - # Loss function - plan_loss_components = [] - for loss_mode_i in loss_modes: - - #### Trajfan planner losses - if loss_mode_i == "class_goaldist": - plan_loss_batch = class_goaldist_loss - # Mask valid: the first two candidate costs should be finite - if traj_cost.shape[1] < 2: - plan_converged = torch.zeros_like(plan_converged) - else: - plan_converged = plan_converged & traj_cost[:, 1].isfinite() - # Not sure this is needed, could be already taken care of by nan replacement - plan_loss_batch = plan_loss_batch * plan_converged.float() + plan_loss_batch.detach() * (1.-plan_converged.float()) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "class_mse": - plan_loss_batch = class_mse_loss - # Mask valid: the first two candidate costs should be finite - if traj_cost.shape[1] < 2: - plan_converged = torch.zeros_like(plan_converged) - else: - plan_converged = plan_converged & traj_cost[:, 1].isfinite() - # Not sure this is needed, could be already taken care of by nan replacement - plan_loss_batch = plan_loss_batch * plan_converged.float() + plan_loss_batch.detach() * (1.-plan_converged.float()) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "class_mse2": - plan_loss_batch = class_mse_loss - # Mask valid: the first two candidate costs should be finite - if traj_cost.shape[1] < 2: - fan_converged = torch.zeros_like(plan_info["fan_converged"]) - else: - fan_converged = plan_info["fan_converged"] & traj_cost[:, 1].isfinite() - # Not sure this is needed, could be already taken care of by nan replacement - plan_loss_batch = plan_loss_batch * fan_converged.float() + plan_loss_batch.detach() * (1.-fan_converged.float()) - plan_loss_components.append(plan_loss_batch) - - elif 
loss_mode_i == "class_hcost": - plan_loss_batch = class_hcost_loss - # Mask valid: the first two candidate costs should be finite - if traj_cost.shape[1] < 2: - plan_converged = torch.zeros_like(plan_converged) - else: - plan_converged = plan_converged & traj_cost[:, 1].isfinite() - # Not sure this is needed, could be already taken care of by nan replacement - plan_loss_batch = plan_loss_batch * plan_converged.float() + plan_loss_batch.detach() * (1.-plan_converged.float()) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "class_hcost2": - plan_loss_batch = class_hcost_loss - # Mask valid: the first two candidate costs should be finite - if traj_cost.shape[1] < 2: - fan_converged = torch.zeros_like(plan_info["fan_converged"]) - else: - fan_converged = plan_info["fan_converged"] & traj_cost[:, 1].isfinite() - # Not sure this is needed, could be already taken care of by nan replacement - plan_loss_batch = plan_loss_batch * fan_converged.float() + plan_loss_batch.detach() * (1.-fan_converged.float()) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "dist_hcost": - plan_loss_batch = dist_hcost_loss - # Mask valid: the first two candidate costs should be finite - if traj_cost.shape[1] < 2: - plan_converged = torch.zeros_like(plan_converged) - else: - plan_converged = plan_converged & traj_cost[:, 1].isfinite() - # Not sure this is needed, could be already taken care of by nan replacement - plan_loss_batch = plan_loss_batch * plan_converged.float() + plan_loss_batch.detach() * (1.-plan_converged.float()) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "maxmargin_goaldist": - plan_loss_batch = maxmargin_goaldist - # Mask valids - plan_converged = plan_converged & lowest_other_cost.isfinite() - # Not sure this is needed, could be already taken care of by nan replacement - plan_loss_batch = plan_loss_batch * plan_converged.float() + plan_loss_batch.detach() * (1.-plan_converged.float()) - plan_loss_components.append(plan_loss_batch) - - #### MPC planner losses - elif loss_mode_i == "mse": - plan_loss_batch = plan_mse_batch - plan_loss_batch = plan_loss_batch.mean(dim=0) # reduce over time keep batch (b,) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "mse-bias": - gtplan_mse_batch = mse_fn(gtplan_x[1:, :, :2], x_gt[..., :2]).sum(dim=-1) - plan_loss_batch = (plan_mse_batch - gtplan_mse_batch) - plan_loss_batch = plan_loss_batch.mean(dim=0) # reduce over time keep batch (b,) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "hcost": - assert not self.hyperparams['train_plan_cost'] # loss function cannot depend on other learned parameter - plan_loss_batch = plan_hcost_batch - plan_loss_batch = plan_loss_batch.mean(dim=0) # reduce over time keep batch (b,) - plan_loss_components.append(plan_loss_batch) - - elif loss_mode_i == "hcost-bias": - assert not self.hyperparams['train_plan_cost'] # loss function cannot depend on other learned parameter - # # Rerun planner with gt future instead of using cached result - # gtplan_x, gtplan_u, _, _, _ = self.plan_batch( - # plan_ego_time_batch, plan_mus_batch, plan_logp_batch, - # empty_mus_batch, empty_logp_batch, plan_gt_neighbors_batch, plan_all_gt_neighbors_batch, - # future_mode="gt", init_mode=init_mode, planner=self.hyperparams["planner"], plan_gt_u=gtplan_u, - # plan_data=plan_data, return_iters=False) # (T, b, ...) 
- - gtplan_xu = torch.cat((gtplan_x, gtplan_u), -1) - gtplan_hcost_batch = self.cost_obj(gtplan_xu, cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points)) - # plan_loss_batch = plan_hcost_batch - gtplan_hcost_batch - plan_loss_batch = subsample_traj(plan_hcost_batch, self.ph, self.planh) - gtplan_hcost_batch - plan_loss_batch = plan_loss_batch.mean(dim=0) # reduce over time keep batch (b,) - plan_loss_components.append(plan_loss_batch) - - else: - raise ValueError("Unknown plan_loss / loss_mode arg: %s"%loss_mode_i) - - # Combine loss components - if len(plan_loss_components) == 1: - plan_loss_batch = plan_loss_components[0] - elif len(plan_loss_components) == 2: - # Second scaler will be used for mpc loss - plan_loss_batch = plan_loss_components[0] + self.hyperparams['plan_loss_scaler2'] * plan_loss_components[1] - else: - raise NotImplementedError("More than 2 plan loss components") - - # Return metrics and planning internals - plan_info['plan_batch_filter'] = plan_batch_filter.detach() - plan_info['converged'] = plan_converged.detach() - plan_info['fan_converged'] = plan_info["fan_converged"].detach() # will be set to all True if not fan planner - - if return_iters: - plan_hcost_components = self.cost_obj( - plan_xu, cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points), - keep_components=True) # (T, b, c) - plan_icost_components = self.interpretable_cost_obj( - plan_xu, cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points), - keep_components=True) # (T, b, c) - - plan_info['x_gt'] = x_gt.detach() - # FIXN - plan_info['all_gt_neighbors'] = plan_all_gt_neighbors_batch.detach() - plan_info['gt_neighbors'] = plan_gt_neighbors_batch.detach() - # detach_list = lambda l: [t.detach() for t in l] - # plan_iters['all_gt_neighbors'] = detach_list(plan_all_gt_neighbors_batch) - # plan_iters['gt_neighbors'] = detach_list(plan_gt_neighbors_batch) - plan_info['lanes'] = lanes.detach() - plan_info['lane_points'] = lane_points.detach() if lane_points is not None else None - plan_info['plan_loss'] = plan_loss_batch.detach() - plan_info['hcost_components'] = plan_hcost_components.detach() - plan_info['icost_components'] = plan_icost_components.detach() - plan_info['ego_pred_gt_dist'] = torch.square(torch.square(batch.agent_fut[plan_batch_filter, :, :2].transpose(1, 0) - x_gt[:, :, :2]).sum(dim=-1).min(dim=0).values) - - plan_and_fan_valid = plan_batch_filter.detach().clone() - plan_and_fan_valid[plan_batch_filter] = torch.logical_and(plan_batch_filter[plan_batch_filter], plan_info["fan_converged"]) - plan_info['plan_and_fan_valid'] = plan_and_fan_valid.detach() - - metrics = { - "unbiased_d1": plan_unbiased_d1_batch, # (b, ) - "unbiased_d2": plan_unbiased_d2_batch, # (b, ) # TODO this is the same as mse - "cost": plan_cost, # (b,) - "mse": plan_mse_batch.mean(0), # (b,) - "hcost": plan_hcost_batch.mean(0), #(b,) - "fan_valid": plan_info["fan_converged"].detach(), - "converged": plan_info["converged"].detach(), - } - if self.hyperparams['planner'] in ['fan', 'fan_mpc']: - metrics.update({ - "class_goaldist": class_goaldist_loss, - "class_hcost": class_hcost_loss, - "dist_hcost_loss": dist_hcost_loss, - "maxmargin_goaldist": maxmargin_goaldist, - "class_mse": class_mse_loss, - "fan_hcost": fan_choice_hcost.detach(), - }) - - timer.end("metrics") - - timer.print() - - return plan_loss_batch, plan_converged, metrics, plan_info - - def 
prepare_plan_instance(self, batch: AgentBatch, y_dist, build_fan_inputs: bool = False): - # Prepare planning instance. - batch_size = batch.agent_fut.shape[0] - - if y_dist is not None: - mus = y_dist.mus.squeeze(0) # (b, t, K, 2) - log_pis = y_dist.log_pis.squeeze(0) # (b, t, K) - # # Component logprobability should be the same through time - # assert torch.isclose(log_pis[:, 0, :], log_pis[:, -1, :]).all() - log_pis = log_pis[:, 0] # (b, K) - else: - mus = torch.full((batch_size, batch.agent_fut.shape[1], 1, 2), torch.nan, device=self.device) # b, T, K, 2 - log_pis = torch.full((batch_size, batch.agent_fut.shape[1]), torch.nan, device=self.device) - - # # Assume that robot is always the first neighbor, i.e. robot_ind is always 0 or negative - # assert (batch.extras['robot_ind'] <= 0).all() - - plan_batch_filter = (batch.extras['robot_ind'] >= 0) - - plan_mus_batch = mus.unsqueeze(0).transpose(1, 2)[:, :, plan_batch_filter] # N, T, b, K, 2 - plan_logp_batch = log_pis.unsqueeze(0)[:, plan_batch_filter] # N, b, K - - # TODO(pkarkus) In preprocessing neighbors should be ordered according to dist from ego - # # Choose N most relevant - # if len(others_x) > MAX_PLAN_NEIGHBORS: - # # Choose the most relevant, based on minimum distance to gt - # plan_ego_x = plan_ego_f[..., 1:, :2].unsqueeze(0) # (1, T, 2) - # dists = torch.square(others_x - plan_ego_x).sum(-1) # (N, T) - # dists = torch.min(dists, dim=-1).values # (N, ) - # sorted_idx = torch.argsort(dists) - # others_x = others_x[sorted_idx[:MAX_PLAN_NEIGHBORS]] - - # Combine current step and future - # neigh_pres = batch.neigh_hist[:, :, batch.neigh_hist_len-1].unsqueeze(2) - # neigh_pres_fut = torch.concat([neigh_pres, batch.neigh_fut], dim=2) # b, N, T+1, 8 - # agent_pres = batch.agent_hist[:, batch.agent_hist_len].unsqueeze(1) - # agent_pres_fut = torch.concat([agent_pres, batch.agent_fut], dim=1) # # b, T+1, 8 - - - # we only want neigh_fut for vehicles - neigh_fut = batch.neigh_fut.clone() - nonvehicle_filter = (batch.neigh_types.int() != AgentType.VEHICLE) - neigh_fut[nonvehicle_filter] = torch.nan - - # Convert state representation - # ['x', 'y', 'vx', 'vy', 'ax', 'ay', 'sintheta', 'costheta'] --> [x, y, theta, v] - pred_fut = convert_state_pred2plan(batch.agent_fut[plan_batch_filter]) # b, N, T, 4 - neigh_fut = convert_state_pred2plan(neigh_fut[plan_batch_filter]) # b, N, T, 4 - ego_hist = convert_state_pred2plan(batch.neigh_hist[plan_batch_filter, 0]) # b, T, 4 - # To get the present time step we need to index with history len, Because history is right-padded - pres_history_ind = batch.neigh_hist_len.to(self.device)[plan_batch_filter, 0].unsqueeze(-1).unsqueeze(-1)-1 - ego_pres = torch.gather(ego_hist, 1, torch.repeat_interleave(pres_history_ind, 4, dim=2)).squeeze(1) # b, 4 - - # # Choose N most relevant - neigh_fut = neigh_fut[:, :MAX_PLAN_NEIGHBORS+1] # b, N, T, 8 - # Extend to fix size. 
pad syntax: (last_l, last_r, second_last_l, second_last_r) - neigh_fut = torch.nn.functional.pad(neigh_fut, (0, 0, 0, 0, 0, MAX_PLAN_NEIGHBORS+1-neigh_fut.shape[1]), 'constant', torch.nan) - plan_ego_batch = neigh_fut[:, 0, :, :].transpose(0, 1) # T, b, 4 - plan_gt_neighbors_batch = neigh_fut[:, 1:, :, :2].transpose(0,1).transpose(1,2) # N, T, b, 2 - # Combine predicted agent gt future with neighbor futures - plan_all_gt_neighbors_batch = torch.concat([pred_fut.unsqueeze(0)[..., :2].transpose(1 ,2), plan_gt_neighbors_batch], dim=0) # N+1, T, b, 2 - # Extend plan agent state with present - plan_ego_batch = torch.concat([ego_pres.unsqueeze(0), plan_ego_batch], dim=0) # T+1, b, 4 - - empty_mus_batch = torch.zeros((0, plan_mus_batch.shape[1], batch_size, plan_mus_batch.shape[3], 2), dtype=plan_mus_batch.dtype, device=self.device) - empty_logp_batch = torch.zeros((0, batch_size, plan_mus_batch.shape[3]), dtype=plan_mus_batch.dtype, device=self.device) - - # Fill relevant lane data if not cached - if build_fan_inputs: - lanes_near_goal_filtered = [batch.extras['lanes_near_goal'][i] for i in range(batch_size) if plan_batch_filter[i]] - else: - lanes_near_goal_filtered = None - fan_inputs = {'lanes_near_goal_filtered': lanes_near_goal_filtered} - - lanes = batch.extras['lane_projection_points'][plan_batch_filter].transpose(1, 0) # b, T, 3 --> T, b, 3 - goal = batch.extras['goal'][plan_batch_filter, ..., :2] # b, 2 - - return plan_batch_filter, plan_ego_batch, lanes, goal, plan_mus_batch, plan_logp_batch, plan_gt_neighbors_batch, plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, fan_inputs - - def plan_batch(self, x_init_batch, lanes, goal_batch, mus_batch, logp_batch, empty_mus_batch, empty_logp_batch, gt_neighbors_batch, all_gt_neighbors_batch, future_mode, init_mode, planner="mpc", pred_extra=None, fan_inputs=None, plan_batch_filter=None, gt_plan_u=None, nopred_plan_u=None, return_iters=False): - """ - """ - batch_size = x_init_batch.shape[0] - - u_fitted = None - lane_points = None - x_goal_batch = None - plan_data = None - - # Choose initialization - if init_mode == "fitted": - assert not torch.isnan(u_fitted).any() - u_init = u_fitted - assert nopred_plan_u.shape[0] == self.planh+1, "u_init trajlen does not match planh" - elif init_mode == "gtplan": - u_init = gt_plan_u # (ph, b, 2) - assert nopred_plan_u.shape[0] == self.planh+1, "u_init trajlen does not match planh" - elif init_mode == "nopred_plan": - u_init = nopred_plan_u # (ph, b, 2) - assert nopred_plan_u.shape[0] == self.planh+1, "u_init trajlen does not match planh" - elif init_mode == "zero": - u_init = torch.zeros((self.planh+1, batch_size, 2), device=self.device) # (ph, b, 2) - else: - raise NotImplementedError("Unknown plan_init: %s"%init_mode) - - # Chooise planner inputs - if future_mode == "pred": - pass - elif future_mode == "nopred": - # drop the predicted agent, include the other gt agents - mus_batch = empty_mus_batch - logp_batch = empty_logp_batch - elif future_mode == "gt": - # replace prediction with gt - mus_batch = empty_mus_batch - logp_batch = empty_logp_batch - gt_neighbors_batch = all_gt_neighbors_batch - elif future_mode == "none": - # drop all predicted and gt agent futures - mus_batch = empty_mus_batch - logp_batch = empty_logp_batch - # MAXN - # gt_neighbors_batch = [torch.zeros((0, self.ph, 2), dtype=u_init.dtype, device=self.device)] * len(mus_batch) - gt_neighbors_batch = None # torch.zeros((MAX_PLAN_NEIGHBORS, self.ph, mus_batch.shape[2], 2), dtype=u_init.dtype, device=self.device) - elif 
future_mode == "pred_nogt": - # drop gt agent futures but keep prediction - gt_neighbors_batch = None - else: - raise ValueError("Unknown value use_future=%s"%str(future_mode)) - - # MAXN - # probs_batch = [torch.exp(logp) for logp in logp_batch] - probs_batch = torch.exp(logp_batch) - - cost_inputs = (gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, lane_points) - - # Run planner - if planner == "mpc": - self.mpc_obj.n_batch = batch_size - self.mpc_obj.u_init = u_init.detach() - planned_x, planned_u, planned_cost, is_converged, iters = self.mpc_obj(x_init_batch, self.cost_obj, self.dyn_obj, cost_inputs, return_converged=True, return_iters=return_iters) - planned_x = self.mpc_obj.detach_unconverged_tensor(planned_x, is_converged) - planned_u = self.mpc_obj.detach_unconverged_tensor(planned_u, is_converged) - self.mpc_obj.u_init = None - iters['fan_converged'] = torch.ones_like(is_converged) # only to report for eval stats - - elif planner == "fan": - assert (self.hyperparams['plan_agent_choice'] == 'most_relevant'), "Only supported because relevant_lanes are for the most relevant agent." - if 'fan_candidate_trajs_filtered' in fan_inputs: - planned_x, planned_u, planned_cost, is_converged, iters = self.fan_obj( - x_init_batch, x_goal_batch, self.cost_obj, self.dyn_obj, cost_inputs, candidate_trajs_batch=fan_inputs['fan_candidate_trajs_filtered'], is_valid_batch=fan_inputs['fan_valid_filtered']) - else: - planned_x, planned_u, planned_cost, is_converged, iters = self.fan_obj( - x_init_batch, x_goal_batch, self.cost_obj, self.dyn_obj, cost_inputs, relevant_lanes_batch=fan_inputs['lanes_near_goal_filtered']) - iters['fan_converged'] = is_converged # only for compatibility with fan_mpc planner when used in class_hcost2 loss - - elif planner == "fan_mpc": - # Do trajectory-fan planning, and initialize MPC with the planned trajectory. - # Almost the same code duplicated as above. - assert (self.hyperparams['plan_agent_choice'] == 'most_relevant'), "Only supported because relevant_lanes are for the most relevant agent." 
- if 'fan_candidate_trajs_filtered' in fan_inputs: - fan_planned_x, fan_planned_u, fan_planned_cost, fan_converged, fan_iters = self.fan_obj( - x_init_batch, x_goal_batch, self.cost_obj, self.dyn_obj, cost_inputs, candidate_trajs_batch=fan_inputs['fan_candidate_trajs_filtered'], is_valid_batch=fan_inputs['fan_valid_filtered']) - else: - fan_planned_x, fan_planned_u, fan_planned_cost, fan_converged, fan_iters = self.fan_obj( - x_init_batch, x_goal_batch, self.cost_obj, self.dyn_obj, cost_inputs, relevant_lanes_batch=fan_inputs['lanes_near_goal_filtered']) - - # Use planner's output as initialization for MPC - # fan_planned_u is already a best control that approximates the spline when unrolled, taking u[t] = (u[t] + u[t+1])/2 - # detach: fan_planned_u is chosen from a set of candidates, its normally not a function of learnable parameter, so we could only - # backprop if we used a weighted sum type of output in the trajfan planner - u_init = fan_planned_u.detach() - - # # Debug - # torch.set_printoptions(precision=10, linewidth=160) - # print (u_init.sum()) - # print (x_init_batch.sum()) - # print (mus_batch.sum()) - # print (probs_batch.sum()) - # print (goal_batch.sum()) - # print (torch.nan_to_num(lanes, 0.19).sum()) - # print (torch.nan_to_num(gt_neighbors_batch, 0.0019).sum()) - - # # Manually fix seed - # import random - # seed = 100 - # random.seed(seed) - # np.random.seed(seed) - # torch.manual_seed(seed) - # if torch.cuda.is_available(): - # torch.cuda.manual_seed_all(seed) - - # MPC - self.mpc_obj.n_batch = batch_size - self.mpc_obj.u_init = u_init.detach() - planned_x, planned_u, planned_cost, mpc_converged, iters = self.mpc_obj(x_init_batch, self.cost_obj, self.dyn_obj, cost_inputs, return_converged=True, return_iters=return_iters) - self.mpc_obj.u_init = None - - # print (planned_x.sum()) - - # Consider converged only those that are converged (valid) for fan planner and converged for MPC - is_converged = torch.logical_and(fan_converged, mpc_converged) - planned_x = self.mpc_obj.detach_unconverged_tensor(planned_x, is_converged) - planned_u = self.mpc_obj.detach_unconverged_tensor(planned_u, is_converged) - - # Merge planning info - iters.update(fan_iters) - # # Intentionally not detaching these so we can backprop through a fan-planner loss optionally - # iters['fan_planned_x'] = fan_planned_x - # iters['fan_planned_u'] = fan_planned_u - # iters['fan_planned_cost'] = fan_planned_cost - iters['fan_converged'] = fan_converged - if return_iters: - iters['mpc_converged'] = mpc_converged - - else: - raise NotImplementedError("Unknown planner: %s"%str(planner)) - - # Add detailed cost info - if return_iters: - cost_components = self.cost_obj(torch.cat((planned_x, planned_u), dim=-1), cost_inputs=cost_inputs, keep_components=True) - iters['cost_components'] = cost_components.detach() - if planner in ["fan_mpc"]: - fan_cost_components = self.cost_obj(torch.cat((fan_planned_x, fan_planned_u), dim=-1), cost_inputs=cost_inputs, keep_components=True) - iters['fan_cost_components'] = fan_cost_components.detach() - if planner in ["mpc", "fan_mpc"]: - iters['lane_targers'] = lanes.detach() - - # # Gradcheck - # from torch.autograd import gradcheck - # from torch.autograd.gradcheck import get_analytical_jacobian, get_numerical_jacobian - # for batch_i in tqdm(range(105, batch_size)): # 121 has large error - # self.mpc_obj.n_batch = 1 - # self.mpc_obj.u_init = u_init.detach().double()[:, batch_i:batch_i+1] - # convert_list = lambda l: l[batch_i].double() - # def wrapped_func(mus_batch, 
probs_batch): - # planx, planu, cost, converged = self.mpc_obj(x_init_batch[batch_i:batch_i+1].double(), self.cost_obj, self.dyn_obj, (self.cost_theta.double(), [convert_list(gt_neighbors_batch)], [mus_batch], [probs_batch], goal_batch[batch_i:batch_i+1].double(), lanes[:, batch_i:batch_i+1].double()), return_converged=True) - # if not converged[0]: - # planx = planx * 0 - # planu = planu * 0 - # return planx, planu - # inputs = (convert_list(mus_batch), convert_list(probs_batch)) - # assert gradcheck(wrapped_func, inputs, eps=1e-6, atol=0.1, rtol=0.01) - # # assert gradcheck(wrapped_func, inputs, eps=1e-6, atol=1e-4) - # print ("Gradient check passed") - - return planned_x, planned_u, planned_cost, is_converged, iters - - def augment_sample_with_dummy_plan_info(self, sample, ego_traj=None): - ph = self.ph - - (first_history_index, - x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, - neighbors_edge_value, - robot_traj_st_t, - map_input, neighbors_future_data, plan_data) = sample - assert isinstance(plan_data, batchable_dict) - - if plan_data['most_relevant_idx'] < 0: - ego_gt_x = torch.zeros((ph+1, 4)) - else: - ego_gt_xy = neighbors_future_data[('VEHICLE', 'VEHICLE')][plan_data['most_relevant_idx']][..., :2] - ego_gt_x = torch.cat((ego_gt_xy, torch.zeros_like(ego_gt_xy)), dim=-1) # (..., 4) - plan_data['gt_plan_x']=ego_gt_x # torch.zeros((ph+1, 4)) - plan_data['gt_plan_u']=torch.zeros((ph+1, 2)) - plan_data['gt_plan_hcost']=torch.zeros(()) - plan_data['gt_plan_converged']=torch.zeros(()) - plan_data['nopred_plan_x']=ego_gt_x # torch.zeros((ph+1, 4)) - plan_data['nopred_plan_u']=torch.zeros((ph+1, 2)) - plan_data['nopred_plan_hcost']=torch.zeros(()) - plan_data['nopred_plan_converged']=torch.zeros(()) - - if ego_traj is not None: - # Find ego state in neighbors - all_neighbors = np.stack(neighbors_future_data[('VEHICLE', 'VEHICLE')], axis=0) - # last ego_traj step should match current state - dists = np.abs(all_neighbors[:, 0, :2] - ego_traj[None, -1, :2]).sum(1) - assert np.isclose(dists.min(), 0), "could not find ego in neighbours structure" - plan_neighbor = np.argmin(dists) - - plan_data['most_relevant_idx'] = torch.Tensor([int(plan_neighbor)]).int().squeeze(0) - plan_data['robot_idx'] = torch.Tensor([int(plan_neighbor)]).int().squeeze(0) - - return (first_history_index, - x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, - neighbors_edge_value, - robot_traj_st_t, - map_input, neighbors_future_data, plan_data) - - def plan_for_gt(self, batch, node_type, exclude_pred_agent=False, fan_candidates=False): - """ - """ - (first_history_index, - x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, # dict of lists. edge_type -> [batch][neighbor]: Tensor(time, statedim). Represetns - neighbors_edge_value, - robot_traj_st_t, - map, neighbors_future_data, plan_data) = batch - batch_size = x_t.shape[0] - - # x = x_t.to(self.device) - y_gt = y_t.to(self.device) - if robot_traj_st_t is not None: - robot_traj_st_t = robot_traj_st_t.to(self.device) - if type(map) == torch.Tensor: - map = map.to(self.device) - - # Restore encodings - neighbors_future_data = restore(neighbors_future_data) - - assert node_type == "VEHICLE" - # TODO support planning for pedestrian prediction - # we can only plan for a vehicle but we can use pedestrian prediction. - - # Prepare planning instance. 
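# Illustrative sketch of the ego-matching step in augment_sample_with_dummy_plan_info above:
# the ego is identified as the neighbor whose current position coincides with the last state
# of the provided ego trajectory. Array shapes here are illustrative only.
import numpy as np

num_neighbors, horizon = 4, 6
neighbors = np.random.randn(num_neighbors, horizon, 2)    # [N, T, xy] neighbor futures
ego_traj = np.random.randn(8, 2)                          # ego trajectory, last row = current state
neighbors[2, 0] = ego_traj[-1]                            # neighbor 2 is actually the ego

dists = np.abs(neighbors[:, 0, :2] - ego_traj[None, -1, :2]).sum(axis=1)   # L1 distance per neighbor
assert np.isclose(dists.min(), 0.0), "could not find ego in neighbours structure"
ego_index = int(np.argmin(dists))
assert ego_index == 2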
- plan_batch_filter, plan_ego_batch, plan_mus_batch, plan_logp_batch, plan_gt_neighbors_batch, plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch = self.prepare_plan_instance( - node_type, neighbors_future_data, plan_data, y_gt, y_dist=None, update_fan_inputs=((self.hyperparams['planner'] in ['fan', 'fan_mpc']) or fan_candidates)) - pred_T = y_gt.shape[1] # b, T, state_dim - - plan_x_full = torch.full((pred_T + 1, batch_size, 4), torch.nan, dtype=torch.float, device=self.device) - plan_u_full = torch.full((pred_T + 1, batch_size, 2), torch.nan, dtype=torch.float, device=self.device) - plan_hcost_full = torch.full((batch_size, ), torch.nan, dtype=torch.float, device=self.device) - plan_converged_full = torch.zeros((batch_size, ), dtype=torch.bool, device=self.device) - fan_candidates_list = [torch.zeros((1, pred_T+1, 6), device=self.device)] * batch_size - fan_valid_full = torch.zeros((batch_size, ), dtype=torch.bool, device=self.device) - - # MAXN - if plan_ego_batch.shape[1] > 0: - # plan_ego_time_batch = torch.stack(plan_ego_batch, dim=1) # (T, b, ...) - plan_ego_time_batch = plan_ego_batch # (T, b, ...) - x_gt, u_gt, lanes, x_proj, u_fitted, x_init_batch, goal_batch, lane_points, _, _, _ = self.decode_plan_inputs(plan_ego_time_batch) - - plan_x, plan_u, plan_cost, plan_converged, _ = self.plan_batch( - plan_ego_time_batch, empty_mus_batch, empty_logp_batch, - empty_mus_batch, empty_logp_batch, plan_gt_neighbors_batch, (plan_gt_neighbors_batch if exclude_pred_agent else plan_all_gt_neighbors_batch), - future_mode="gt", init_mode="zero", planner="mpc", gt_plan_u=None, nopred_plan_u=None, - plan_data=plan_data, return_iters=False) # (T, b, ...) - - plan_xu = torch.cat((plan_x, plan_u), -1) - plan_hcost = self.cost_obj(plan_xu, cost_inputs=(plan_all_gt_neighbors_batch, empty_mus_batch, empty_logp_batch, goal_batch, lanes, lane_points)) # (T, b) - plan_hcost = torch.mean(plan_hcost, dim=0) # (b,) - - plan_x_full[:, plan_batch_filter] = plan_x - plan_u_full[:, plan_batch_filter] = plan_u - plan_hcost_full[plan_batch_filter] = plan_hcost - plan_converged_full[plan_batch_filter] = plan_converged - - if fan_candidates: - fan_ctrl_xu_batch, is_converged_batch = self.fan_obj.get_candidate_trajectories(x_init_batch, relevant_lanes_batch=plan_data['most_relevant_nearby_lanes_filtered']) - - for i, batch_i in enumerate(np.arange(batch_size)[plan_batch_filter.cpu().numpy()]): - fan_candidates_list[batch_i] = fan_ctrl_xu_batch[i] - fan_valid_full[batch_i] = is_converged_batch[i] - - return plan_x_full, plan_u_full, plan_converged_full, plan_hcost_full, plan_batch_filter, fan_candidates_list, fan_valid_full diff --git a/diffstack/modules/planners/fan_planner.py b/diffstack/modules/planners/fan_planner.py deleted file mode 100644 index d4986b1..0000000 --- a/diffstack/modules/planners/fan_planner.py +++ /dev/null @@ -1,160 +0,0 @@ -import torch -import functools - -from diffstack.modules.planners.fan_planner_utils import SplinePlanner -from nuscenes.map_expansion.map_api import NuScenesMap -from nuscenes.map_expansion import arcline_path_utils - -import matplotlib.pyplot as plt - -VISUALIZE = False -USE_CPU = True - - -class FanPlanner(torch.nn.Module): - - def __init__(self, ph, dt, device): - super().__init__() - self.spline_planner = SplinePlanner(device=('cpu' if USE_CPU else device), dt=dt) - self.ph = ph - self.dt = dt - self.scenes = None - - @staticmethod - @functools.lru_cache(maxsize=None) - def get_map(raw_data_path, map_name): - return NuScenesMap(dataroot=raw_data_path, 
map_name=map_name) - - def forward(self, x_init_batch, x_goal_batch, cost_obj, dyn_obj, cost_inputs, relevant_lanes_batch=None, candidate_trajs_batch=None, is_valid_batch=None): - device = x_init_batch.device - batch_size = x_init_batch.shape[0] - plan_info_dict = {'traj_xu': [], 'traj_cost': []} - - if candidate_trajs_batch is not None: - # Optionall pass in candidates directly - assert len(candidate_trajs_batch) == batch_size - assert len(is_valid_batch) == batch_size - - fan_ctrl_xu_batch = candidate_trajs_batch - is_valid_batch = is_valid_batch - - else: - assert len(relevant_lanes_batch) == batch_size - fan_ctrl_xu_batch, is_valid_batch = self.get_candidate_trajectories(x_init_batch, relevant_lanes_batch) - - # Get costs - traj_cost_batch = self.get_cost_for_trajs(fan_ctrl_xu_batch, cost_obj, cost_inputs) - plan_info_dict['traj_xu'] = fan_ctrl_xu_batch # intentionally dont detach because we will backprop through these - plan_info_dict['traj_cost'] = traj_cost_batch # intentionally dont detach because we will backprop through these - - # Choose the lowest cost as a default planning output - # but we also return all candidates and all costs for differnet loss functions. - plan_xu_batch = [] - plan_cost_batch = [] - for batch_i in range(batch_size): - fan_ctrl_xu = plan_info_dict['traj_xu'][batch_i] - traj_cost = traj_cost_batch[batch_i] - if is_valid_batch[batch_i]: - best_ind = torch.argmin(traj_cost, dim=0) - plan_cost = traj_cost[best_ind] - plan_xu = fan_ctrl_xu[best_ind] - else: - plan_cost = traj_cost.squeeze(0) - plan_xu = fan_ctrl_xu.squeeze(0) - - plan_xu_batch.append(plan_xu) - plan_cost_batch.append(plan_cost) - - plan_xu = torch.stack(plan_xu_batch, dim=1) # T, b, 6 - plan_x, plan_u = torch.split(plan_xu, (4, 2), dim=-1) - plan_cost = torch.stack(plan_cost_batch, dim=0) - if not isinstance(is_valid_batch, torch.Tensor): - is_valid_batch = torch.tensor(is_valid_batch, device=device).bool() - - return plan_x, plan_u, plan_cost, is_valid_batch, plan_info_dict - - def get_candidate_trajectories(self, x_init_batch, relevant_lanes_batch): - device = x_init_batch.device - batch_size = x_init_batch.shape[0] - - x, y, h, v = torch.unbind(x_init_batch, dim=-1) - x_init_xyvh_batch = torch.stack([x, y, v, h], dim=-1) - - if USE_CPU: - relevant_lanes_batch = [[pts.to('cpu') for pts in relevant_lanes] for relevant_lanes in relevant_lanes_batch] - x_init_xyvh_batch = x_init_xyvh_batch.to('cpu') - - # Split over batch - fan_ctrl_xu_batch = [] - is_converged_batch = [] - for batch_i in range(batch_size): - relevant_lanes = relevant_lanes_batch[batch_i] - if len(relevant_lanes)==0: - fan_ctrl_xu_batch.append(torch.zeros((1, self.ph+1, 6), device=device).detach()) - is_converged_batch.append(False) - continue - - x_init_xyvh = x_init_xyvh_batch[batch_i].unsqueeze(0) - fan_trajs, _ = self.spline_planner.gen_trajectory_batch(x_init_xyvh, self.ph * self.dt, relevant_lanes) - assert len(fan_trajs) == 1 - fan_trajs = fan_trajs[0] - - if fan_trajs.shape[0] > 0: - xy = fan_trajs[..., :2] - vel = fan_trajs[..., 2:3] - acce = fan_trajs[..., 3:4] - yaw = fan_trajs[..., 4:5] - yaw_rate = fan_trajs[..., 5:6] - fan_xu = torch.cat((xy, yaw, vel, yaw_rate, acce), -1) # (N, T+1, xu=6) - else: - fan_xu = torch.zeros((0, self.ph+1, 8), dtype=torch.float, device=device) # (N, T+1, xu) - - if USE_CPU: - # Move back to the gpu - # fan_trajs = fan_trajs.to(device) - fan_xu = fan_xu.to(device) - - num_candidates = fan_xu.shape[0] - if num_candidates == 0: - # Skip if there are no valid candidates. 
- # This happens quite frequently e.g. at low velocities, where - # reaching points from the lane center are dynamically infeasible - fan_ctrl_xu_batch.append(torch.zeros((1, self.ph+1, 6), device=device)) - is_converged_batch.append(False) - continue - - # Convert fan xu splines to state and control. - # The control part of the spline is the acceleration and steering at time t, not the same as control command. - # We can best track the trajectory by commanding at t for the target control as (u_t + u_{t+1})/2 - # In a previous version of the code u[t] = u[t+1] was used - fan_x, fan_u = torch.split(fan_xu, (4, 2), dim=-1) - ctrl_u = torch.cat(((fan_u[:,:-1] + fan_u[:, 1:])*0.5, fan_u[:, -1:]), dim=1) # last control doesnt matter - fan_ctrl_xu = torch.cat((fan_x, ctrl_u), dim=-1) - fan_ctrl_xu_batch.append(fan_ctrl_xu) - is_converged_batch.append(True) - - return fan_ctrl_xu_batch, is_converged_batch # TODO there is no need to return trees - - def get_cost_for_trajs(self, fan_ctrl_xu_batch, cost_obj, cost_inputs): - - batch_size = len(fan_ctrl_xu_batch) - gt_neighbors_batch, mus_batch, probs_batch, goal_batch, lanes, lane_points = cost_inputs - - traj_cost_batch = [] - for batch_i in range(batch_size): - fan_ctrl_xu = fan_ctrl_xu_batch[batch_i] - num_candidates = fan_ctrl_xu.shape[0] - cost_inputs_i = (None if gt_neighbors_batch is None else gt_neighbors_batch[:, :, batch_i].unsqueeze(2).tile((1, 1, num_candidates, 1)), - None if mus_batch is None else mus_batch[:, :, batch_i].unsqueeze(2).tile((1, 1, num_candidates, 1, 1)), - None if probs_batch is None else probs_batch[:, batch_i].unsqueeze(1).tile((1, num_candidates, 1)), - goal_batch[batch_i].unsqueeze(0).tile((num_candidates, 1)), - lanes[:, batch_i].unsqueeze(1).tile((1, num_candidates, 1)), - lane_points[:, batch_i].unsqueeze(1).tile((1, num_candidates, 1, 1)) if lane_points is not None else None, - ) - - traj_cost = cost_obj(fan_ctrl_xu.transpose(1, 0), cost_inputs_i) # T, b - traj_cost = torch.sum(traj_cost, dim=0) # b, - traj_cost_batch.append(traj_cost) - - return traj_cost_batch - diff --git a/diffstack/modules/planners/fan_planner_utils.py b/diffstack/modules/planners/fan_planner_utils.py deleted file mode 100644 index 94e5852..0000000 --- a/diffstack/modules/planners/fan_planner_utils.py +++ /dev/null @@ -1,345 +0,0 @@ -import numpy as np -import torch -from scipy.interpolate import interp1d -from typing import List, Optional, Tuple - -from trajdata.maps.map_api import VectorMap -from trajdata.maps.vec_map_elements import RoadLane, Polyline -from diffstack.utils.utils import angle_wrap - - -def cubic_spline_coefficients(x0, dx0, xf, dxf, tf): - return (x0, dx0, -2 * dx0 / tf - dxf / tf - 3 * x0 / tf ** 2 + 3 * xf / tf ** 2, - dx0 / tf ** 2 + dxf / tf ** 2 + 2 * x0 / tf ** 3 - 2 * xf / tf ** 3) - - -def compute_interpolating_spline(state_0, state_f, tf): - dx0, dy0 = state_0[..., 2] * \ - torch.cos(state_0[..., 3]), state_0[..., 2] * \ - torch.sin(state_0[..., 3]) - dxf, dyf = state_f[..., 2] * \ - torch.cos(state_f[..., 3]), state_f[..., 2] * \ - torch.sin(state_f[..., 3]) - tf = tf * torch.ones_like(state_0[..., 0]) - return ( - torch.stack(cubic_spline_coefficients( - state_0[..., 0], dx0, state_f[..., 0], dxf, tf), -1), - torch.stack(cubic_spline_coefficients( - state_0[..., 1], dy0, state_f[..., 1], dyf, tf), -1), - tf, - ) - - -def compute_spline_xyvaqrt(x_coefficients, y_coefficients, tf, N=10): - t = torch.arange(N).unsqueeze(0).to(tf.device) * tf.unsqueeze(-1) / (N - 1) - tp = t[..., None] ** 
torch.arange(4).to(tf.device) - dtp = t[..., None] ** torch.tensor([0, 0, 1, 2] - ).to(tf.device) * torch.arange(4).to(tf.device) - ddtp = t[..., None] ** torch.tensor([0, 0, 0, 1]).to( - tf.device) * torch.tensor([0, 0, 2, 6]).to(tf.device) - x_coefficients = x_coefficients.unsqueeze(-1) - y_coefficients = y_coefficients.unsqueeze(-1) - vx = dtp @ x_coefficients - vy = dtp @ y_coefficients - v = torch.hypot(vx, vy) - v_pos = torch.clip(v, min=1e-4) - ax = ddtp @ x_coefficients - ay = ddtp @ y_coefficients - a = (ax * vx + ay * vy) / v_pos - r = (-ax * vy + ay * vx) / (v_pos ** 2) - yaw = torch.atan2(vy, vx) # TODO(pkarkus) this is invalid for v=0 - return torch.cat(( - tp @ x_coefficients, - tp @ y_coefficients, - v, - a, - yaw, - r, - t.unsqueeze(-1), - ), -1) - - -def interp_lanes(lane, extrapolate=True): - """ generate interpolants for lanes - - Args: - lane (np.array()): [Nx3] - - Returns: - - """ - if isinstance(lane, torch.Tensor): - lane = lane.cpu().numpy() - ds = np.cumsum( - np.hstack([0., np.linalg.norm(lane[1:, :2]-lane[:-1, :2], axis=-1)])) - - if extrapolate: - # Allow extrapolation: - return interp1d(ds, lane, fill_value="extrapolate", assume_sorted=True, axis=0), lane[0] - else: - # Nans for extrapolation - return interp1d(ds, lane, bounds_error=False, assume_sorted=True, axis=0), lane[0] - - -def batch_rotate_2D(xy, theta): - if isinstance(xy, torch.Tensor): - x1 = xy[..., 0] * torch.cos(theta) - xy[..., 1] * torch.sin(theta) - y1 = xy[..., 1] * torch.cos(theta) + xy[..., 0] * torch.sin(theta) - return torch.stack([x1, y1], dim=-1) - elif isinstance(xy, np.ndarray): - x1 = xy[..., 0] * np.cos(theta) - xy[..., 1] * np.sin(theta) - y1 = xy[..., 1] * np.cos(theta) + xy[..., 0] * np.sin(theta) - return np.concatenate((x1[..., None], y1[..., None]), axis=-1) - - -class SplinePlanner(object): - def __init__(self, device, dx_grid=None, dy_grid=None, acce_grid=None, dyaw_grid=None, max_steer=0.5, max_rvel=8, - acce_bound=[-6, 4], vbound=[-10, 30], spline_order=3, dt=0.5): - self.spline_order = spline_order - self.device = device - assert spline_order == 3 - if dx_grid is None: - self.dx_grid = torch.tensor([-4., 0, 4.]).to(self.device) - # self.dx_grid = torch.tensor([0.]).to(self.device) - else: - self.dx_grid = dx_grid - if dy_grid is None: - self.dy_grid = torch.tensor([-4., -2., 0, 2., 4.]).to(self.device) - else: - self.dy_grid = dy_grid - if acce_grid is None: - self.acce_grid = torch.tensor([-3., -2., -1., -0.5, 0., 0.5, 1., 2.]).to(self.device) - # self.acce_grid = torch.tensor([-1., 0., 1.]).to(self.device) - else: - self.acce_grid = acce_grid - self.d_lane_lat_grid = torch.tensor([-0.5, 0, 0.5]).to(self.device) - - if dyaw_grid is None: - self.dyaw_grid = torch.tensor( - [-np.pi / 12, 0, np.pi / 12]).to(self.device) - else: - self.dyaw_grid = dyaw_grid - self.max_steer = max_steer - self.max_rvel = max_rvel - self.acce_bound = acce_bound - self.vbound = vbound - self.dt = dt - - def calc_trajectories(self, x0, tf, xf): - if x0.ndim == 1: - x0_tile = x0.tile(xf.shape[0], 1) - xc, yc, tf_vect = compute_interpolating_spline(x0_tile, xf, tf) - elif x0.ndim == xf.ndim: - xc, yc, tf_vect = compute_interpolating_spline(x0, xf, tf) - else: - raise ValueError("wrong dimension for x0") - traj = compute_spline_xyvaqrt(xc, yc, tf_vect, N=round(tf/self.dt) + 1) # +1 for t0 - return traj - - def gen_terminals_lane_original(self, x0, tf, lanes): - if lanes is None: - return self.gen_terminals(x0, tf) - - gs = [self.dx_grid.shape[0], self.acce_grid.shape[0]] - dx = self.dx_grid[:, 
None, None, None].repeat(1, 1, gs[1], 1).flatten() - dv = self.acce_grid[None, None, :, None].repeat( - gs[0], 1, 1, 1).flatten()*tf - - delta_x = list() - - assert x0.ndim in [1, 2], "x0 must have dimension 1 or 2" - - is_batched = (x0.ndim > 1) - if x0.ndim == 1: - x0 = x0.unsqueeze(0) - - for lane in lanes: - f, p_start = lane - p_start = torch.from_numpy(p_start).to(x0.device) - offset = x0[:, :2]-p_start[None, :2] - s_offset = offset[:, 0] * \ - torch.cos(p_start[2])+offset[:, 1]*torch.sin(p_start[2]) - ds = (dx+dv/2*tf).unsqueeze(0)+x0[:, 2:3]*tf - ss = ds + s_offset.unsqueeze(-1) - xyyaw = torch.from_numpy(f(ss.cpu().numpy())).type( - torch.float).to(x0.device) - delta_x.append(torch.cat((xyyaw[..., :2], dv.tile( - x0.shape[0], 1).unsqueeze(-1)+x0[:, None, 2:3], xyyaw[..., 2:]), -1)) - - delta_x = torch.cat(delta_x, -2) - - if not is_batched: - delta_x = delta_x.squeeze(0) - return delta_x - - def gen_terminals_lane(self, x0, tf, lanes): - if lanes is None: - return self.gen_terminals(x0, tf) - - gs = [self.d_lane_lat_grid.shape[0], self.acce_grid.shape[0]] - dlat = self.d_lane_lat_grid[:, None, None, None].repeat(1, 1, gs[1], 1).flatten() - dv = self.acce_grid[None, None, :, None].repeat( - gs[0], 1, 1, 1).flatten()*tf - - delta_x = list() - - assert x0.ndim in [1, 2], "x0 must have dimension 1 or 2" - is_batched = (x0.ndim > 1) - if x0.ndim == 1: - x0 = x0.unsqueeze(0) - - for lane in lanes: - f, p_start = lane # f: interplation function f(ds)--> lane_x,y,yaw - if isinstance(p_start, np.ndarray): - p_start = torch.from_numpy(p_start) - p_start = p_start.to(x0.device) - offset = x0[:, :2]-p_start[None, :2] - s_offset = offset[:, 0] * \ - torch.cos(p_start[2])+offset[:, 1]*torch.sin(p_start[2]) # distance projected onto lane, from its starting point - - # TODO this can be wildly inaccurate when lane is strongly curved - # instead we should use the projection of current state onto lane, and find the distance along lane - # we can do that by storing cumsum along lane, and assuming straight path from the closest lane point - - v0 = x0[:, 2:3] - vf = v0 + dv.unsqueeze(0) - # Replace negative velocity with stopping - vf = torch.maximum(vf, torch.zeros_like(vf)) - - ds = (v0 + vf) * 0.5 * tf # delta distance along lane. dx is from grid, dv is final velocity from grid, average delta vel is dv/2, x0[:, 2:3] current velo - ss = ds + s_offset.unsqueeze(-1) # target distance along lane - xyyaw = torch.from_numpy(f(ss.cpu().numpy())).type( - torch.float).to(x0.device) # interpolate lane for target ds - - # y offset in the direction of lane normal - dlat_xy = dlat.unsqueeze(0).unsqueeze(-1) * torch.stack([torch.sin(xyyaw[..., 2]), -torch.cos(xyyaw[..., 2])], dim=-1) - - target_xyvh = torch.cat(( - xyyaw[..., :1] + dlat_xy[..., :1], - xyyaw[..., 1:2] + dlat_xy[..., 1:2], - x0[:, None, 2:3] + dv.tile(x0.shape[0], 1).unsqueeze(-1), - xyyaw[..., 2:] - ), -1) # xyh --> xyvh. 
insert target velo = d0+dv - - # Filter nans (extrapolation) and negative vel target - assert target_xyvh.shape[0] == 1, "No batch support for now, it would need ragged tensor" - target_xyvh = target_xyvh.squeeze(0) # N, 4 - target_xyvh = target_xyvh[torch.logical_not(target_xyvh[:, 2].isnan()) & (target_xyvh[:, 2] >= 0)] # N*, 4 - target_xyvh = target_xyvh.unsqueeze(0) # 1, N*, 4 - - delta_x.append(target_xyvh) - - delta_x = torch.cat(delta_x, -2) - - if not is_batched: - delta_x = delta_x.squeeze(0) - return delta_x - - def gen_terminals(self, x0, tf): - gs = [self.dx_grid.shape[0], self.dy_grid.shape[0], - self.acce_grid.shape[0], self.dyaw_grid.shape[0]] - dx = self.dx_grid[:, None, None, None].repeat( - 1, gs[1], gs[2], gs[3]).flatten() - dy = self.dy_grid[None, :, None, None].repeat( - gs[0], 1, gs[2], gs[3]).flatten() - dv = tf * self.acce_grid[None, None, :, - None].repeat(gs[0], gs[1], 1, gs[3]).flatten() - dyaw = self.dyaw_grid[None, None, None, :].repeat( - gs[0], gs[1], gs[2], 1).flatten() - delta_x = torch.stack([dx, dy, dv, dyaw], -1) - - if x0.ndim == 1: - xy = torch.cat( - (delta_x[:, 0:1] + delta_x[:, 2:3] / 2 * tf + x0[2:3] * tf, delta_x[:, 1:2]), -1) - rotated_xy = batch_rotate_2D(xy, x0[3]) + x0[:2] - return torch.cat((rotated_xy, delta_x[:, 2:] + x0[2:]), -1) + x0[None, :] - elif x0.ndim == 2: - - delta_x = torch.tile(delta_x, [x0.shape[0], 1, 1]) - xy = torch.cat( - (delta_x[:, :, 0:1] + delta_x[:, :, 2:3] / 2 * tf + x0[:, None, 2:3] * tf, delta_x[:, :, 1:2]), -1) - rotated_xy = batch_rotate_2D( - xy, x0[:, 3:4]) + x0[:, None, :2] - - return torch.cat((rotated_xy, delta_x[:, :, 2:] + x0[:, None, 2:]), -1) + x0[:, None, :] - else: - raise ValueError("x0 must have dimension 1 or 2") - - def feasible_flag(self, traj): - feas_flag = ((traj[..., 2] >= self.vbound[0]) & (traj[..., 2] < self.vbound[1]) & - (traj[..., 3] >= self.acce_bound[0]) & (traj[..., 3] <= self.acce_bound[1]) & - (torch.abs(traj[..., 5] * traj[..., 2]) <= self.max_rvel) & ( - torch.abs(traj[..., 2]) * self.max_steer >= torch.abs(traj[..., 5]))).all(1) - return feas_flag - - def gen_trajectories(self, x0, tf, lanes=None, dyn_filter=True): - if lanes is None: - xf = self.gen_terminals(x0, tf) - else: - lane_interp = [interp_lanes(lane, extrapolate=False) for lane in lanes] - - xf = self.gen_terminals_lane( - x0, tf, lane_interp) - - # x, y, v, a, yaw,r, t - traj = self.calc_trajectories(x0, tf, xf) - if dyn_filter: - feas_flag = self.feasible_flag(traj) - return traj[feas_flag, :], xf[feas_flag, :] - else: - return traj, xf - - def gen_trajectory_batch(self, x0_set, tf, lanes=None, dyn_filter=True): - # x0_set states (n, 4) for x, y, vel, yaw - - if lanes is None: - xf_set = self.gen_terminals(x0_set, tf) - else: - # Do not allow extrapolation, will return nan - lane_interp = [interp_lanes(lane, extrapolate=False) for lane in lanes] - # x, y, v, yaw - xf_set = self.gen_terminals_lane(x0_set, tf, lane_interp) - - num_node = x0_set.shape[0] - num = xf_set.shape[1] - x0_tiled = torch.tile(x0_set, [num, 1]) - xf_tiled = xf_set.reshape(-1, xf_set.shape[-1]) - # x, y, v, a, yaw,r, t - traj = self.calc_trajectories(x0_tiled, tf, xf_tiled) - - # yaw values are incorrect when v=0, correct it by taking yaw at t-1 - yaw_tm1 = x0_tiled[:, 3] # this is x, y, v, yaw - for t in range(traj.shape[1]): - # traj is x, y, v, a, yaw,r, t - invalid_yaw_flag = torch.isclose(traj[:, t, 2], torch.zeros((), dtype=traj.dtype, device=traj.device)) - traj[invalid_yaw_flag, t, 4] = yaw_tm1[invalid_yaw_flag] - yaw_tm1 = traj[:, t, 4] - 
- if dyn_filter: - feas_flag = self.feasible_flag(traj) - else: - feas_flag = torch.ones( - num * num_node, dtype=torch.bool).to(x0_set.device) - feas_flag = feas_flag.reshape(num_node, num) - traj = traj.reshape(num_node, num, *traj.shape[1:]) - return [traj[i, feas_flag[i]] for i in range(num_node)], xf_tiled - - def gen_trajectory_tree(self, x0, tf, n_layers, dyn_filter=True): - trajs = list() - nodes = [x0[None, :]] - for i in range(n_layers): - xf = self.gen_terminals(nodes[i], tf) - x0i = torch.tile(nodes[i], [xf.shape[1], 1]) - xf = xf.reshape(-1, xf.shape[-1]) - - traj = self.calc_trajectories(x0i, tf, xf) - if dyn_filter: - feas_flag = self.feasible_flag(traj) - traj = traj[feas_flag] - xf = xf[feas_flag] - - trajs.append(traj) - - nodes.append(xf.reshape(-1, xf.shape[-1])) - return trajs, nodes[1:] - diff --git a/diffstack/modules/planners/mpc_utils/LICENSE.mit b/diffstack/modules/planners/mpc_utils/LICENSE.mit deleted file mode 100644 index be3e911..0000000 --- a/diffstack/modules/planners/mpc_utils/LICENSE.mit +++ /dev/null @@ -1,22 +0,0 @@ -Copyright (c) 2018, Carnegie Mellon University - -Permission is hereby granted, free of charge, to any person -obtaining a copy of this software and associated documentation -files (the "Software"), to deal in the Software without -restriction, including without limitation the rights to use, -copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/diffstack/modules/planners/mpc_utils/README.md b/diffstack/modules/planners/mpc_utils/README.md deleted file mode 100644 index ccaba08..0000000 --- a/diffstack/modules/planners/mpc_utils/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# MPC utilities - -This folder contains code that is partially copied, modified, or extended from [Differentiable MPC](https://github.com/locuslab/mpc.pytorch) \ No newline at end of file diff --git a/diffstack/modules/planners/mpc_utils/__init__.py b/diffstack/modules/planners/mpc_utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/diffstack/modules/planners/mpc_utils/lqr_step_refactored.py b/diffstack/modules/planners/mpc_utils/lqr_step_refactored.py deleted file mode 100644 index 5aa74e1..0000000 --- a/diffstack/modules/planners/mpc_utils/lqr_step_refactored.py +++ /dev/null @@ -1,506 +0,0 @@ -import torch -from torch.autograd import Function, Variable -from torch.nn import Module -from torch.nn.parameter import Parameter - -import numpy as np -import numpy.random as npr - -from collections import namedtuple - -from mpc import mpc - -from mpc.pnqp import pnqp -from mpc import util -from mpc.util import get_traj - -LqrBackOut = namedtuple('lqrBackOut', 'n_total_qp_iter') -LqrForOut = namedtuple( - 'lqrForOut', - 'objs full_du_norm alpha_du_norm mean_alphas costs' -) - - -import warnings -warnings.filterwarnings("default", category=UserWarning) - - -class LQRStepClass(Module): - def __init__(self, - n_state, - n_ctrl, - T, - u_lower=None, - u_upper=None, - u_zero_I=None, - delta_u=None, - linesearch_decay=0.2, - max_linesearch_iter=10, - # true_cost=None, - # true_dynamics=None, - delta_space=True, - # current_x=None, - # current_u=None, - verbose=0, - back_eps=1e-3): - """A single step of the box-constrained iLQR solver. - - Required Args: - n_state, n_ctrl, T - x_init: The initial state [n_batch, n_state] - - Optional Args: - u_lower, u_upper: The lower- and upper-bounds on the controls. - These can either be floats or shaped as [T, n_batch, n_ctrl] - TODO: Better support automatic expansion of these. 
- TODO - """ - super().__init__() - - self.n_state = n_state - self.n_ctrl = n_ctrl - self.T = T - self.u_lower = u_lower - self.u_upper = u_upper - self.u_zero_I = u_zero_I - self.delta_u = delta_u - self.linesearch_decay = linesearch_decay - self.max_linesearch_iter = max_linesearch_iter - self.delta_space = delta_space - self.verbose = verbose - self.back_eps = back_eps - - def forward(ctx, x, u, true_cost, true_dynamics, no_op_forward, x_init, C, c, F, f=None, cost_inputs=None): - params = (ctx.n_state, ctx.n_ctrl, ctx.T, ctx.u_lower, ctx.u_upper, ctx.u_zero_I, ctx.delta_u, ctx.linesearch_decay, ctx.max_linesearch_iter, ctx.delta_space, ctx.verbose, ctx.back_eps) - return LQRStepFunction.apply(x, u, true_cost, true_dynamics, params, no_op_forward, x_init, C, c, F, f, cost_inputs) - - - -# @profile -def lqr_backward(C, c, F, f, u, params): - n_state, n_ctrl, T, u_lower, u_upper, u_zero_I, delta_u, linesearch_decay, max_linesearch_iter, delta_space, verbose, back_eps = params - n_batch = C.size(1) - - Ks = [] - ks = [] - prev_kt = None - n_total_qp_iter = 0 - Vtp1 = vtp1 = None - for t in range(T-1, -1, -1): - if t == T-1: - Qt = C[t] - qt = c[t] - else: - Ft = F[t] - Ft_T = Ft.transpose(1,2) - Qt = C[t] + Ft_T.bmm(Vtp1).bmm(Ft) - if f is None or f.nelement() == 0: - qt = c[t] + Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2) - else: - ft = f[t] - qt = c[t] + Ft_T.bmm(Vtp1).bmm(ft.unsqueeze(2)).squeeze(2) + \ - Ft_T.bmm(vtp1.unsqueeze(2)).squeeze(2) - - Qt_xx = Qt[:, :n_state, :n_state] - Qt_xu = Qt[:, :n_state, n_state:] - Qt_ux = Qt[:, n_state:, :n_state] - Qt_uu = Qt[:, n_state:, n_state:] - qt_x = qt[:, :n_state] - qt_u = qt[:, n_state:] - - if u_lower is None: - if n_ctrl == 1 and u_zero_I is None: - Kt = -(1./Qt_uu)*Qt_ux - kt = -(1./Qt_uu.squeeze(2))*qt_u - else: - if u_zero_I is None: - Qt_uu_inv = [ - torch.pinverse(Qt_uu[i]) for i in range(Qt_uu.shape[0]) - ] - Qt_uu_inv = torch.stack(Qt_uu_inv) - Kt = -Qt_uu_inv.bmm(Qt_ux) - kt = util.bmv(-Qt_uu_inv, qt_u) - - # Debug - # assert not Qt_uu.isnan().any() - # if Qt_uu_inv.isnan().any(): - # print (torch.nonzero(Qt_uu_inv.isnan().any(dim=2).any(dim=1))) - # assert False - # assert not Qt_ux.isnan().any() - # assert not Kt.isnan().any() - # assert not kt.isnan().any() - - # Qt_uu_LU = Qt_uu.lu() - # Kt = -Qt_ux.lu_solve(*Qt_uu_LU) - # kt = -qt_u.lu_solve(*Qt_uu_LU) - else: - # Solve with zero constraints on the active controls. - I = u_zero_I[t].float() - notI = 1-I - - qt_u_ = qt_u.clone() - qt_u_[I.bool()] = 0 - - Qt_uu_ = Qt_uu.clone() - - if I.is_cuda: - notI_ = notI.float() - Qt_uu_I = (1-util.bger(notI_, notI_)).type_as(I) - else: - Qt_uu_I = 1-util.bger(notI, notI) - - Qt_uu_[Qt_uu_I.bool()] = 0. - Qt_uu_[bdiag(I).bool()] += 1e-8 - - Qt_ux_ = Qt_ux.clone() - Qt_ux_[I.unsqueeze(2).repeat(1,1,Qt_ux.size(2)).bool()] = 0. 
- - if n_ctrl == 1: - Kt = -(1./Qt_uu_)*Qt_ux_ - kt = -(1./Qt_uu.squeeze(2))*qt_u_ - else: - Qt_uu_LU_ = Qt_uu_.lu() - Kt = -Qt_ux_.lu_solve(*Qt_uu_LU_) - kt = -qt_u_.unsqueeze(2).lu_solve(*Qt_uu_LU_).squeeze(2) - else: - assert delta_space - lb = get_bound('lower', t, u_lower, u_upper) - u[t] - ub = get_bound('upper', t, u_lower, u_upper) - u[t] - if delta_u is not None: - lb[lb < -delta_u] = -delta_u - ub[ub > delta_u] = delta_u - kt, Qt_uu_free_LU, If, n_qp_iter = pnqp( - Qt_uu, qt_u, lb, ub, - x_init=prev_kt, n_iter=20) - if verbose > 1: - print(' + n_qp_iter: ', n_qp_iter+1) - n_total_qp_iter += 1+n_qp_iter - prev_kt = kt - Qt_ux_ = Qt_ux.clone() - Qt_ux_[(1-If).unsqueeze(2).repeat(1,1,Qt_ux.size(2)).bool()] = 0 - if n_ctrl == 1: - # Bad naming, Qt_uu_free_LU isn't the LU in this case. - Kt = -((1./Qt_uu_free_LU)*Qt_ux_) - else: - Kt = -Qt_ux_.lu_solve(*Qt_uu_free_LU) - - Kt_T = Kt.transpose(1,2) - - Ks.append(Kt) - ks.append(kt) - - Vtp1 = Qt_xx + Qt_xu.bmm(Kt) + Kt_T.bmm(Qt_ux) + Kt_T.bmm(Qt_uu).bmm(Kt) - vtp1 = qt_x + Qt_xu.bmm(kt.unsqueeze(2)).squeeze(2) + \ - Kt_T.bmm(qt_u.unsqueeze(2)).squeeze(2) + \ - Kt_T.bmm(Qt_uu).bmm(kt.unsqueeze(2)).squeeze(2) - - return Ks, ks, LqrBackOut(n_total_qp_iter=n_total_qp_iter) - - -# @profile -def lqr_forward(x_init, C, c, F, f, Ks, ks, cost_inputs, x, u, true_cost, true_dynamics, params): - n_state, n_ctrl, T, u_lower, u_upper, u_zero_I, delta_u, linesearch_decay, max_linesearch_iter, delta_space, verbose, back_eps = params - n_batch = C.size(1) - - old_cost = get_cost(T, u, true_cost, true_dynamics, x=x, cost_inputs=cost_inputs) - - current_cost = None - alphas = torch.ones(n_batch, dtype=C.dtype, device=C.device) - full_du_norm = None - # not_improved_first = None # debug - - # Not implemented, see comment below where to add logic - assert not ((delta_u is not None) and (u_lower is None)) - - i = 0 - while (current_cost is None or \ - (old_cost is not None and \ - torch.any((current_cost > old_cost)).cpu().item() == 1)) and \ - i < max_linesearch_iter: - - # We continue the linesearch logic for the full batch, but we will be only updating alpha where cost is not improving - new_u = [] - new_x = [x_init] - dx = [torch.zeros_like(x_init)] - objs = [] - for t in range(T): - t_rev = T-1-t - Kt = Ks[t_rev] - kt = ks[t_rev] - new_xt = new_x[t] - xt = x[t] - ut = u[t] - dxt = dx[t] - new_ut = util.bmv(Kt, dxt) + ut + torch.diag(alphas).mm(kt) - - # This is where we should deal with delta_u and u_lower - if u_zero_I is not None: - new_ut[u_zero_I[t]] = 0. 
- - if u_lower is not None: - lb = get_bound('lower', t, u_lower, u_upper) - ub = get_bound('upper', t, u_lower, u_upper) - - if delta_u is not None: - lb_limit, ub_limit = lb, ub - lb = u[t] - delta_u - ub = u[t] + delta_u - I = lb < lb_limit - lb[I] = lb_limit if isinstance(lb_limit, float) else lb_limit[I] - I = ub > ub_limit - ub[I] = ub_limit if isinstance(lb_limit, float) else ub_limit[I] - - if isinstance(lb, float): - # TODO(pkarkus) added this hack - new_ut = util.eclamp(new_ut, lb, ub) - elif lb.ndim == 1: - new_ut = util.eclamp( - new_ut, - lb.type_as(new_ut).unsqueeze(0).expand_as(new_ut), - ub.type_as(new_ut).unsqueeze(0).expand_as(new_ut)) - else: - new_ut = util.eclamp( - new_ut, lb.type_as(new_ut), ub.type_as(new_ut)) - new_u.append(new_ut) - - new_xut = torch.cat((new_xt, new_ut), dim=1) - if t < T-1: - new_xtp1 = true_dynamics( - Variable(new_xt), Variable(new_ut)).data - - new_x.append(new_xtp1) - dx.append(new_xtp1 - x[t+1]) - - objs.append(new_xut) - - objs = torch.stack(objs) - objs = true_cost(objs, cost_inputs=cost_inputs) - - current_cost = torch.sum(objs, dim=0) - - new_u = torch.stack(new_u) - new_x = torch.stack(new_x) - if full_du_norm is None: - full_du_norm = torch.linalg.norm((u-new_u).transpose(0, 1).reshape(n_batch, -1), dim=-1) - # not_improved_first = (current_cost > old_cost) - # full_du_norm_original = (u-new_u).transpose(1,2).contiguous().view( - # n_batch, -1).norm(2, 1) - - alphas[current_cost > old_cost] *= linesearch_decay - i += 1 - - # # Debug - # not_improved_last = (current_cost > old_cost) - # if (torch.logical_not(not_improved_first) & not_improved_last).any(): - # print (not_improved_first) - # print (not_improved_last) - # raise ValueError() - - # If the iteration limit is hit, some alphas - # are one step too small. 
- alphas[current_cost > old_cost] /= linesearch_decay - alpha_du_norm = (u-new_u).transpose(0,1).reshape(n_batch, -1).norm(2, 1) - # alpha_du_norm_original = (u-new_u).transpose(1,2).contiguous().view( - # n_batch, -1).norm(2, 1) - - return new_x, new_u, LqrForOut( - objs, full_du_norm, - alpha_du_norm, - torch.mean(alphas), - current_cost - ) - - -def get_cost(T, u, cost, dynamics=None, x_init=None, x=None, cost_inputs=None): - assert x_init is not None or x is not None - - if x is None: - x = get_traj(T, u, x_init, dynamics) - - xu = torch.cat((x, u), -1) - objs = cost(xu, cost_inputs) - total_obj = torch.sum(objs, dim=0) - return total_obj - - -def get_bound(side, t, u_lower, u_upper): - if side == 'lower': - v = u_lower - if side == 'upper': - v = u_upper - if isinstance(v, float): - return v - elif v.ndim == 1: - return v - else: - return v[t] - - -def bdiag(d): - assert d.ndimension() == 2 - nBatch, sz = d.size() - dtype = d.type() if not isinstance(d, Variable) else d.data.type() - D = torch.zeros(nBatch, sz, sz).type(dtype) - I = torch.eye(sz).repeat(nBatch, 1, 1).bool() - D[I] = d.view(-1) - return D - - -class LQRStepFunction(Function): - # @profile - def forward(ctx, x, u, true_cost, true_dynamics, params, no_op_forward, x_init, C, c, F, f=None, cost_inputs=None): - n_state, n_ctrl, T, u_lower, u_upper, u_zero_I, delta_u, linesearch_decay, max_linesearch_iter, delta_space, verbose, back_eps = params - - # Save for backward - ctx.n_state = n_state - ctx.n_ctrl = n_ctrl - ctx.T = T - ctx.u_lower = u_lower - ctx.u_upper = u_upper - ctx.delta_space = delta_space - ctx.back_eps = back_eps - - assert all([not isinstance(var, torch.Tensor) and not isinstance(var, Variable) for var in [n_state, n_ctrl, T, delta_space, back_eps]]) - # assert all([not isinstance(var, Variable) for var in [u_lower, u_upper]]) - - if no_op_forward: - ctx.save_for_backward(x_init, C, c, F, f, x, u) - - return x, u, None, None - - if delta_space: - # Taylor-expand the objective to do the backward pass in - # the delta space. - assert x is not None - assert u is not None - c_back = [] - for t in range(T): - xt = x[t] - ut = u[t] - xut = torch.cat((xt, ut), 1) - c_back.append(util.bmv(C[t], xut) + c[t]) - c_back = torch.stack(c_back) - f_back = None - else: - assert False - - Ks, ks, back_out = lqr_backward(C, c_back, F, f_back, u, params) - # if Ks[-1].isnan().any(): - # print ("nan") - # Ks, ks, back_out = lqr_backward(C, c_back, F, f_back, u, params) - - new_x, new_u, for_out = lqr_forward( - x_init, C, c, F, f, Ks, ks, cost_inputs, x, u, true_cost, true_dynamics, params) - ctx.save_for_backward(x_init, C, c, F, f, new_x, new_u) - - return new_x, new_u, back_out, for_out - - def backward(ctx, dl_dx, dl_du, temp=None, temp2=None): - n_state = ctx.n_state - n_ctrl = ctx.n_ctrl - T = ctx.T - u_lower = ctx.u_lower - u_upper = ctx.u_upper - delta_space = ctx.delta_space - back_eps = ctx.back_eps - - # start = time.time() - x_init, C, c, F, f, new_x, new_u = ctx.saved_tensors - - r = [] - for t in range(T): - rt = torch.cat((dl_dx[t], dl_du[t]), 1) - r.append(rt) - r = torch.stack(r) - - if u_lower is None: - I = None - else: - I = (torch.abs(new_u - u_lower) <= 1e-8) | \ - (torch.abs(new_u - u_upper) <= 1e-8) - dx_init = Variable(torch.zeros_like(x_init)) - _mpc = mpc.MPC( - n_state, n_ctrl, T, - u_zero_I=I, - u_init=None, - lqr_iter=1, - verbose=-1, - n_batch=C.size(1), - delta_u=None, - # exit_unconverged=True, # It's really bad if this doesn't converge. 
- exit_unconverged=False, # It's really bad if this doesn't converge. - eps=back_eps, - ) - - # Suppress warnings about deprecated torch functions - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - dx, du, _ = _mpc(dx_init, mpc.QuadCost(C, -r), mpc.LinDx(F, None)) - - dx, du = dx.detach(), du.detach() - dxu = torch.cat((dx, du), 2) - xu = torch.cat((new_x, new_u), 2) - - dC = torch.zeros_like(C) - for t in range(T): - xut = torch.cat((new_x[t], new_u[t]), 1) - dxut = dxu[t] - dCt = -0.5*(util.bger(dxut, xut) + util.bger(xut, dxut)) - dC[t] = dCt - - dc = -dxu - - lams = [] - prev_lam = None - for t in range(T-1, -1, -1): - Ct_xx = C[t,:,:n_state,:n_state] - Ct_xu = C[t,:,:n_state,n_state:] - ct_x = c[t,:,:n_state] - xt = new_x[t] - ut = new_u[t] - lamt = util.bmv(Ct_xx, xt) + util.bmv(Ct_xu, ut) + ct_x - if prev_lam is not None: - Fxt = F[t,:,:,:n_state].transpose(1, 2) - lamt += util.bmv(Fxt, prev_lam) - lams.append(lamt) - prev_lam = lamt - lams = list(reversed(lams)) - - dlams = [] - prev_dlam = None - for t in range(T-1, -1, -1): - dCt_xx = C[t,:,:n_state,:n_state] - dCt_xu = C[t,:,:n_state,n_state:] - drt_x = -r[t,:,:n_state] - dxt = dx[t] - dut = du[t] - dlamt = util.bmv(dCt_xx, dxt) + util.bmv(dCt_xu, dut) + drt_x - if prev_dlam is not None: - Fxt = F[t,:,:,:n_state].transpose(1, 2) - dlamt += util.bmv(Fxt, prev_dlam) - dlams.append(dlamt) - prev_dlam = dlamt - dlams = torch.stack(list(reversed(dlams))) - - dF = torch.zeros_like(F) - for t in range(T-1): - xut = xu[t] - lamt = lams[t+1] - - dxut = dxu[t] - dlamt = dlams[t+1] - - dF[t] = -(util.bger(dlamt, xut) + util.bger(lamt, dxut)) - - if f.nelement() > 0: - _dlams = dlams[1:] - assert _dlams.shape == f.shape - df = -_dlams - else: - df = torch.Tensor() - - dx_init = -dlams[0] - - # backward_time = time.time()-start - return None, None, None, None, None, None, dx_init, dC, dc, dF, df, None - diff --git a/diffstack/modules/planners/mpc_utils/trajcost_mpc.py b/diffstack/modules/planners/mpc_utils/trajcost_mpc.py deleted file mode 100644 index dc01d81..0000000 --- a/diffstack/modules/planners/mpc_utils/trajcost_mpc.py +++ /dev/null @@ -1,296 +0,0 @@ -import torch -from torch.autograd import Variable - -from .lqr_step_refactored import LQRStepClass, get_cost - -from mpc import util -from mpc.mpc import GradMethods - - -class TrajCostMPC(torch.nn.Module): - """A differentiable box-constrained iLQR solver. - - This code is an extension of the mpc module from - https://github.com/locuslab/mpc.pytorch - for costs defined over trajectories, and custom nonlinear dynamics. - - This provides a differentiable solver for the following box-constrained - control problem with a quadratic cost (defined by C and c) and - non-linear dynamics (defined by f): - - min_{tau={x,u}} sum_t 0.5 tau_t^T C_t tau_t + c_t^T tau_t - s.t. x_{t+1} = f(x_t, u_t) - x_0 = x_init - u_lower <= u <= u_upper - - This implements the Control-Limited Differential Dynamic Programming - paper with a first-order approximation to the non-linear dynamics: - https://homes.cs.washington.edu/~todorov/papers/TassaICRA14.pdf - - Some of the notation here is from Sergey Levine's notes: - http://rll.berkeley.edu/deeprlcourse/f17docs/lecture_8_model_based_planning.pdf - - Required Args: - n_state, n_ctrl, T - - Optional Args: - u_lower, u_upper: The lower- and upper-bounds on the controls. 
- These can either be floats or shaped as [T, n_batch, n_ctrl] - u_init: The initial control sequence, useful for warm-starting: - [T, n_batch, n_ctrl] - lqr_iter: The number of LQR iterations to perform. - grad_method: The method to compute the Jacobian of the dynamics. - GradMethods.ANALYTIC: Use a manually-defined Jacobian. - + Fast and accurate, use this if possible - GradMethods.AUTO_DIFF: Use PyTorch's autograd. - + Slow - GradMethods.FINITE_DIFF: Use naive finite differences - + Inaccurate - delta_u (float): The amount each component of the controls - is allowed to change in each LQR iteration. - verbose (int): - -1: No output or warnings - 0: Warnings - 1+: Detailed iteration info - eps: Termination threshold, on the norm of the full control - step (without line search) - back_eps: `eps` value to use in the backwards pass. - n_batch: May be necessary for now if it can't be inferred. - TODO: Infer, potentially remove this. - linesearch_decay (float): Multiplicative decay factor for the - line search. - max_linesearch_iter (int): Can be used to disable the line search - if 1 is used for some problems the line search can - be harmful. - exit_unconverged: Assert False if a fixed point is not reached. - detach_unconverged: Detach examples from the graph that do - not hit a fixed point so they are not differentiated through. - backprop: Allow the solver to be differentiated through. - slew_rate_penalty (float): Penalty term applied to - ||u_t - u_{t+1}||_2^2 in the objective. - prev_ctrl: The previous nominal control sequence to initialize - the solver with. - not_improved_lim: The number of iterations to allow that don't - improve the objective before returning early. - best_cost_eps: Absolute threshold for the best cost - to be updated. - """ - - def __init__( - self, n_state, n_ctrl, T, - u_lower=None, u_upper=None, - u_zero_I=None, - u_init=None, - lqr_iter=10, - grad_method=GradMethods.ANALYTIC, - delta_u=None, - verbose=0, - eps=1e-7, - back_eps=1e-7, - n_batch=None, - linesearch_decay=0.2, - max_linesearch_iter=10, - exit_unconverged=True, - detach_unconverged=True, - backprop=True, - slew_rate_penalty=None, - prev_ctrl=None, - not_improved_lim=5, - best_cost_eps=1e-4 - ): - super().__init__() - - assert (u_lower is None) == (u_upper is None) - assert max_linesearch_iter > 0 - - self.n_state = n_state - self.n_ctrl = n_ctrl - self.T = T - self.u_lower = u_lower - self.u_upper = u_upper - - if not isinstance(u_lower, float): - self.u_lower = util.detach_maybe(self.u_lower) - - if not isinstance(u_upper, float): - self.u_upper = util.detach_maybe(self.u_upper) - - self.u_zero_I = util.detach_maybe(u_zero_I) - self.u_init = util.detach_maybe(u_init) - self.lqr_iter = lqr_iter - self.grad_method = grad_method - self.delta_u = delta_u - self.verbose = verbose - self.eps = eps - self.back_eps = back_eps - self.n_batch = n_batch - self.linesearch_decay = linesearch_decay - self.max_linesearch_iter = max_linesearch_iter - self.exit_unconverged = exit_unconverged - self.detach_unconverged = detach_unconverged - self.backprop = backprop - self.not_improved_lim = not_improved_lim - self.best_cost_eps = best_cost_eps - - self.slew_rate_penalty = slew_rate_penalty - self.prev_ctrl = prev_ctrl - - self._lqr = LQRStepClass( - n_state=self.n_state, - n_ctrl=self.n_ctrl, - T=self.T, - u_lower=self.u_lower, - u_upper=self.u_upper, - u_zero_I=self.u_zero_I, - delta_u=self.delta_u, - linesearch_decay=self.linesearch_decay, - max_linesearch_iter=self.max_linesearch_iter, - delta_space=True, - 
back_eps=self.back_eps, - ) - - # @profile - def forward(self, x_init, cost, dx, cost_inputs, return_converged=False, return_iters=False): - if self.n_batch is not None: - n_batch = self.n_batch - else: - raise ValueError('MPC Error: Could not infer batch size, pass in as n_batch') - - assert x_init.ndimension() == 2 and x_init.size(0) == n_batch - - if self.u_init is None: - u = torch.zeros(self.T, n_batch, self.n_ctrl).type_as(x_init.data) - else: - u = self.u_init - if u.ndimension() == 2: - u = u.unsqueeze(1).expand(self.T, n_batch, -1).clone() - u = u.type_as(x_init.data) - - if self.verbose > 0: - print('Initial mean(cost): {:.4e}'.format( - torch.mean(get_cost( - self.T, u, cost, dx, x_init=x_init - )).item() - )) - - best = None - - # Add init trajectory - if return_iters: - iters = dict(x=[], u=[], cost=[]) - x = util.get_traj(self.T, u, x_init=x_init, dynamics=dx) - iters['x'].append(x.detach()) - iters['u'].append(u.detach()) - iters['cost'].append(get_cost(self.T, u, cost, dx, x_init=x_init, cost_inputs=cost_inputs).detach()) - else: - iters = dict() - - n_not_improved = 0 - for i in range(self.lqr_iter): - u = Variable(util.detach_maybe(u), requires_grad=True) - # Linearize the dynamics around the current trajectory. - x = util.get_traj(self.T, u, x_init=x_init, dynamics=dx) - F, f = dx.linearized(x, util.detach_maybe(u), diff=False) - C, c = cost.approx_quadratic(x, u, cost_inputs, diff=False) - - x, u, back_out, for_out = self.solve_lqr_subproblem( - cost_inputs, x_init, C, c, F, f, cost, dx, x, u) - # back_out, for_out = _lqr.back_out, _lqr.for_out - n_not_improved += 1 - - assert x.ndimension() == 3 - assert u.ndimension() == 3 - - # Add init trajectory - if return_iters: - iters['x'].append(x.detach()) - iters['u'].append(u.detach()) - iters['cost'].append(for_out.costs.detach()) - - # Not improved means nothing in the batch is improved. - if best is None: - best = { - 'x': list(torch.split(x, split_size_or_sections=1, dim=1)), - 'u': list(torch.split(u, split_size_or_sections=1, dim=1)), - 'costs': for_out.costs, - 'full_du_norm': for_out.full_du_norm, - 'alpha_du_norm': for_out.alpha_du_norm, - } - # TODO pkarkus this should set n_not_improved=1 - else: - for j in range(n_batch): - # if for_out.costs[j] <= best['costs'][j] + self.best_cost_eps: - if for_out.costs[j] <= best['costs'][j] - self.best_cost_eps: - n_not_improved = 0 - if for_out.costs[j] <= best['costs'][j]: - best['x'][j] = x[:,j].unsqueeze(1) - best['u'][j] = u[:,j].unsqueeze(1) - best['costs'][j] = for_out.costs[j] - best['full_du_norm'][j] = for_out.full_du_norm[j] - best['alpha_du_norm'][j] = for_out.alpha_du_norm[j] - - if self.verbose > 0: - util.table_log('lqr', ( - ('iter', i), - ('mean(cost)', torch.mean(best['costs']).item(), '{:.4e}'), - ('||full_du||_max', max(for_out.full_du_norm).item(), '{:.2e}'), - # ('||alpha_du||_max', max(for_out.alpha_du_norm), '{:.2e}'), - # TODO: alphas, total_qp_iters here is for the current - # iterate, not the best - ('mean(alphas)', for_out.mean_alphas.item(), '{:.2e}'), - ('total_qp_iters', back_out.n_total_qp_iter), - )) - - # eps defines a max change of control to continue. We use the largest change over the batch. 
- if max(for_out.full_du_norm) < self.eps or \ - n_not_improved > self.not_improved_lim: - break - - - x = torch.cat(best['x'], dim=1) - u = torch.cat(best['u'], dim=1) - full_du_norm = best['full_du_norm'] - alpha_du_norm = best['alpha_du_norm'] - - F, f = dx.linearized(x, u, diff=True) - - C, c = cost.approx_quadratic(x, u, cost_inputs, diff=True) - - # Run LQR without updating x and u - x, u, _, _ = self.solve_lqr_subproblem( - cost_inputs, x_init, C, c, F, f, cost, dx, x, u, no_op_forward=True) - - converged_mask = full_du_norm < self.eps - # converged_mask = alpha_du_norm < self.eps - if self.detach_unconverged: - if not converged_mask.all(): - if self.exit_unconverged: - assert False - - if self.verbose >= 0: - print("LQR Warning: All examples did not converge to a fixed point.") - print("Detaching and *not* backpropping through the bad examples.") - - x = self.detach_unconverged_tensor(x, converged_mask) - u = self.detach_unconverged_tensor(u, converged_mask) - - costs = best['costs'] - if return_converged: - return (x, u, costs, converged_mask.detach(), iters) - else: - return (x, u, costs) - - @staticmethod - def detach_unconverged_tensor(t, converged_mask): - It = Variable(converged_mask.unsqueeze(0).unsqueeze(2)).type_as(t.data) - t = t*It + t.clone().detach()*(1.-It) - return t - - def solve_lqr_subproblem(self, cost_inputs, x_init, C, c, F, f, cost, dynamics, x, u, - no_op_forward=False): - assert self.slew_rate_penalty is None or isinstance(cost, torch.nn.Module) - e = Variable(torch.Tensor()) - x, u, back_out, for_out = self._lqr(x, u, cost, dynamics, no_op_forward, x_init, C, c, F, f if f is not None else e, cost_inputs) - - return x, u, back_out, for_out - diff --git a/diffstack/modules/predictors/CTT.py b/diffstack/modules/predictors/CTT.py new file mode 100644 index 0000000..f788fe8 --- /dev/null +++ b/diffstack/modules/predictors/CTT.py @@ -0,0 +1,1673 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import wandb +import os +from functools import partial +from typing import Dict +from collections import OrderedDict, defaultdict + +from diffstack.dynamics.unicycle import Unicycle, Unicycle_xyvsc +from diffstack.dynamics.unicycle import Unicycle +from diffstack.modules.module import Module, DataFormat, RunMode + +from diffstack.utils.utils import removeprefix +import diffstack.utils.tensor_utils as TensorUtils +import diffstack.utils.geometry_utils as GeoUtils +import diffstack.utils.metrics as Metrics +from diffstack.utils.geometry_utils import ratan2 +import diffstack.utils.lane_utils as LaneUtils +from diffstack.utils.batch_utils import batch_utils +import diffstack.utils.model_utils as ModelUtils + +from trajdata.utils.map_utils import LaneSegRelation +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D +from diffstack.utils.vis_utils import plot_scene_open_loop, animate_scene_open_loop + +from diffstack.utils.loss_utils import collision_loss, loss_clip +from pytorch_lightning.loggers import WandbLogger +from pathlib import Path + + +from diffstack.models.CTT import ( + CTT, + TFvars, + FeatureAxes, + GNNedges, +) +from trajdata.data_structures import AgentType +from diffstack.utils.homotopy import ( + HomotopyType, + identify_pairwise_homotopy, + HOMOTOPY_THRESHOLD, +) +from bokeh.plotting import figure, output_file, save +from bokeh.models import Range1d +from bokeh.io import export_png + + +class CTTTrafficModel(Module): + @property + def input_format(self) -> DataFormat: + return 
DataFormat(["scene_batch"]) + + @property + def output_format(self) -> DataFormat: + return DataFormat(["pred", "pred_dist", "pred_ml", "sample_modes", "step_time"]) + + @property + def checkpoint_monitor_keys(self): + return {"valLoss": "val/losses_predictor_marginal_lm_loss"} + + def __init__(self, model_registrar, cfg, log_writer, device, input_mappings={}): + super(CTTTrafficModel, self).__init__( + model_registrar, cfg, log_writer, device, input_mappings=input_mappings + ) + + self.bu = batch_utils(rasterize_mode="none") + self.modality_shapes = self.bu.get_modality_shapes(cfg) + + self.hist_lane_relation = LaneUtils.LaneRelationFromCfg(cfg.hist_lane_relation) + self.fut_lane_relation = LaneUtils.LaneRelationFromCfg(cfg.fut_lane_relation) + self.cfg = cfg + self.nets = nn.ModuleDict() + + self.create_nets(cfg) + self.fp16 = cfg.fp16 if "fp16" in cfg else False + + if "checkpoint" in cfg and cfg.checkpoint["enabled"]: + checkpoint = torch.load(cfg.checkpoint["path"]) + predictor_dict = { + removeprefix(k, "components.predictor."): v + for k, v in checkpoint["state_dict"].items() + if k.startswith("components.predictor.") + } + self.load_state_dict(predictor_dict) + + self.device = device + self.nets["policy"] = self.nets["policy"].to(self.device) + + # setup loss functions + + self.lane_mode_loss = nn.CrossEntropyLoss( + reduction="none", label_smoothing=0.01 + ) + self.homotopy_loss = nn.CrossEntropyLoss(reduction="none", label_smoothing=0.01) + self.joint_mode_loss = nn.CrossEntropyLoss( + reduction="none", label_smoothing=0.05 + ) + self.traj_loss = nn.MSELoss(reduction="none") + self.unicycle_model = Unicycle(cfg.step_time) + + def create_nets(self, cfg): + enc_var_axes = { + TFvars.Agent_hist: "B,A,T,F", + TFvars.Lane: "B,L,F", + } + agent_ntype = len(AgentType) + auxdim_l = 4 * cfg.num_lane_pts # (x,y,s,c) x num_pts + + agent_raw_dim = agent_ntype + 5 # agent type, l,w,v,a,r + + enc_attn_attributes = OrderedDict() + # each attention operation is done on a pair of variables, e.g. (agent_hist, agent_hist), on a specific axis, e.g. T + # to specify the recipe for factorized attention, we need to specify the following: + # 1. the dimension of the edge feature + # 2. the edge function (if any) + # 3. the number of separate attention mechanisms (in case we need to separate the self-attention from the cross-attention or more generally, if we need to separate the attention for different types of edges) + # 4. 
whether to use normalization for the embedding + + d_lm_hist = len(self.hist_lane_relation) + d_lm_fut = len(self.fut_lane_relation) + d_ll = len(LaneSegRelation) + d_static = 2 + 1 + len(AgentType) + a2a_edge_func = partial( + ModelUtils.agent2agent_edge, + scale=cfg.encoder.edge_scale, + clip=cfg.encoder.edge_clip, + ) + l2l_edge_func = partial( + ModelUtils.lane2lane_edge, + scale=cfg.encoder.edge_scale, + clip=cfg.encoder.edge_clip, + ) + if cfg.a2l_edge_type == "proj": + a2l_edge_func = partial( + ModelUtils.agent2lane_edge_proj, + scale=cfg.encoder.edge_scale, + clip=cfg.encoder.edge_clip, + ) + a2l_edge_dim = cfg.edge_dim.a2l + elif cfg.a2l_edge_type == "attn": + self.nets["a2l_edge"] = ModelUtils.Agent2Lane_emb_attn( + edge_dim=4, + n_embd=cfg.a2l_n_embd, + agent_feat_dim=d_static, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + Lmax=cfg.num_lane_pts, + ) + a2l_edge_dim = cfg.a2l_n_embd + a2l_edge_func = self.nets["a2l_edge"].forward + + self.edge_func = dict(a2a=a2a_edge_func, l2l=l2l_edge_func, a2l=a2l_edge_func) + enc_attn_attributes[(TFvars.Agent_hist, TFvars.Agent_hist, FeatureAxes.T)] = ( + cfg.edge_dim.a2a, + a2a_edge_func, + 1, + False, + ) + enc_attn_attributes[(TFvars.Agent_hist, TFvars.Agent_hist, FeatureAxes.A)] = ( + cfg.edge_dim.a2a, + a2a_edge_func, + cfg.attn_ntype.a2a, + False, + ) + enc_attn_attributes[(TFvars.Lane, TFvars.Lane, FeatureAxes.L)] = ( + cfg.edge_dim.l2l + d_ll, + None, + cfg.attn_ntype.l2l, + False, + ) + enc_attn_attributes[ + (TFvars.Agent_hist, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)) + ] = (a2l_edge_dim + d_lm_hist, None, cfg.attn_ntype.a2l, False) + + enc_transformer_kwargs = dict( + n_embd=cfg.n_embd, + n_head=cfg.n_head, + PE_mode=cfg.PE_mode, + use_rpe_net=cfg.use_rpe_net, + attn_attributes=enc_attn_attributes, + var_axes=enc_var_axes, + attn_pdrop=cfg.encoder.attn_pdrop, + resid_pdrop=cfg.encoder.resid_pdrop, + MAX_T=cfg.future_num_frames + cfg.history_num_frames + 2, + ) + + lane_margins_dim = 5 + + # decoder design + dec_attn_attributes = OrderedDict() + + dec_attn_attributes[ + (TFvars.Agent_future, TFvars.Agent_hist, (FeatureAxes.T, FeatureAxes.T)) + ] = (cfg.edge_dim.a2a, a2a_edge_func, 1, False) + + dec_attn_attributes[ + (TFvars.Agent_future, TFvars.Agent_future, FeatureAxes.A) + ] = (cfg.edge_dim.a2a + len(HomotopyType) * 2, None, 2, False) + dec_attn_attributes[ + (TFvars.Agent_future, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)) + ] = (a2l_edge_dim + d_lm_fut + lane_margins_dim, None, 1, False) + dec_var_axes = { + TFvars.Agent_hist: "B,A,T,F", + TFvars.Lane: "B,L,F", + TFvars.Agent_future: "B,A,T,F", + } + dec_transformer_kwargs = dict( + n_embd=cfg.n_embd, + n_head=cfg.n_head, + PE_mode=cfg.PE_mode, + use_rpe_net=cfg.use_rpe_net, + attn_attributes=dec_attn_attributes, + var_axes=dec_var_axes, + attn_pdrop=cfg.decoder.attn_pdrop, + resid_pdrop=cfg.decoder.resid_pdrop, + ) + # embedding functions of raw variables + embed_funcs = { + TFvars.Agent_hist: ModelUtils.Agent_emb(agent_raw_dim, cfg.n_embd), + TFvars.Lane: ModelUtils.Lane_emb( + auxdim_l, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + TFvars.Agent_future: ModelUtils.Agent_emb(agent_raw_dim, cfg.n_embd), + GNNedges.Agenthist2Agenthist: ModelUtils.Agent2Agent_emb( + cfg.edge_dim.a2a, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + GNNedges.Agentfuture2Agentfuture: ModelUtils.Agent2Agent_emb( + cfg.edge_dim.a2a + len(HomotopyType) * 2, + cfg.n_embd, + 
xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + } + if cfg.a2l_edge_type == "proj": + embed_funcs.update( + { + GNNedges.Agenthist2Lane: ModelUtils.Agent2Lane_emb_proj( + cfg.edge_dim.a2l, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + GNNedges.Agentfuture2Lane: ModelUtils.Agent2Lane_emb_proj( + cfg.edge_dim.a2l + d_lm_fut + lane_margins_dim, + cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + } + ) + else: + embed_funcs.update( + { + GNNedges.Agenthist2Lane: ModelUtils.Agent2Lane_emb_attn( + 4, + cfg.a2l_n_embd, + d_static, + output_dim=cfg.n_embd, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + GNNedges.Agentfuture2Lane: ModelUtils.Agent2Lane_emb_attn( + 4, + cfg.a2l_n_embd, + d_static, + output_dim=cfg.n_embd, + aux_edge_dim=d_lm_fut + lane_margins_dim, + xy_scale=cfg.encoder.edge_scale, + xy_clip=cfg.encoder.edge_clip, + ), + } + ) + + enc_GNN_attributes = OrderedDict() + enc_GNN_attributes[(GNNedges.Agenthist2Agenthist, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + enc_GNN_attributes[(GNNedges.Agenthist2Lane, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + enc_GNN_attributes[(GNNedges.Agenthist2Lane, "node", TFvars.Agent_hist)] = ( + [cfg.n_embd * 2], + None, + cfg.encoder.pooling, + None, + ) + enc_GNN_attributes[ + (GNNedges.Agenthist2Agenthist, "node", TFvars.Agent_hist) + ] = ([cfg.n_embd * 2], None, cfg.encoder.pooling, None) + + node_n_embd = defaultdict(lambda: cfg.n_embd) + edge_n_embd = defaultdict(lambda: cfg.n_embd) + + enc_edge_var = { + GNNedges.Agenthist2Agenthist: (TFvars.Agent_hist, TFvars.Agent_hist), + GNNedges.Agenthist2Lane: (TFvars.Agent_hist, TFvars.Lane), + } + dec_edge_var = { + GNNedges.Agenthist2Agenthist: (TFvars.Agent_hist, TFvars.Agent_hist), + GNNedges.Agenthist2Lane: (TFvars.Agent_hist, TFvars.Lane), + GNNedges.Agentfuture2Agentfuture: ( + TFvars.Agent_future, + TFvars.Agent_future, + ), + GNNedges.Agentfuture2Lane: (TFvars.Agent_future, TFvars.Lane), + } + enc_GNN_kwargs = dict( + var_axes=enc_var_axes, + GNN_attributes=enc_GNN_attributes, + node_n_embd=node_n_embd, + edge_n_embd=edge_n_embd, + edge_var=enc_edge_var, + ) + + JM_GNN_attributes = OrderedDict() + JM_GNN_attributes[(GNNedges.Agenthist2Agenthist, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + JM_GNN_attributes[(GNNedges.Agenthist2Lane, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + JM_GNN_attributes[(GNNedges.Agenthist2Lane, "node", TFvars.Agent_hist)] = ( + [cfg.n_embd * 2], + None, + cfg.encoder.pooling, + None, + ) + JM_GNN_attributes[(GNNedges.Agenthist2Agenthist, "node", TFvars.Agent_hist)] = ( + [cfg.n_embd * 2], + None, + cfg.encoder.pooling, + None, + ) + + JM_edge_n_embd = { + GNNedges.Agenthist2Agenthist: cfg.n_embd + cfg.encoder.mode_embed_dim, + GNNedges.Agenthist2Lane: cfg.n_embd + cfg.encoder.mode_embed_dim, + } + + JM_GNN_kwargs = dict( + var_axes=enc_var_axes, + GNN_attributes=JM_GNN_attributes, + node_n_embd=node_n_embd, + edge_n_embd=JM_edge_n_embd, + edge_var=enc_edge_var, + ) + + enc_output_params = dict( + pooling_T=cfg.encoder.pooling, + Th=cfg.history_num_frames + 1, + lane_mode=dict( + n_head=4, + hidden_dim=[cfg.n_embd], + ), + homotopy=dict( + n_head=4, + hidden_dim=[cfg.n_embd], + ), + joint_mode=dict( + n_head=4, + jm_GNN_nblock=cfg.encoder.jm_GNN_nblock, + GNN_kwargs=JM_GNN_kwargs, + num_joint_samples=cfg.encoder.num_joint_samples, + 
num_joint_factors=cfg.encoder.num_joint_factors, + ), + mode_embed_dim=cfg.encoder.mode_embed_dim, + null_lane_mode=cfg.encoder.null_lane_mode, + PE_mode=cfg.PE_mode, + ) + self.null_lane_mode = cfg.encoder.null_lane_mode + + dec_GNN_attributes = OrderedDict() + if cfg.decoder.GNN_enabled: + dec_GNN_attributes[(GNNedges.Agentfuture2Agentfuture, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + dec_GNN_attributes[(GNNedges.Agentfuture2Lane, "edge", None)] = ( + [cfg.n_embd * 2], + None, + None, + None, + ) + dec_GNN_attributes[ + (GNNedges.Agentfuture2Lane, "node", TFvars.Agent_future) + ] = ([cfg.n_embd * 2], None, cfg.decoder.pooling, None) + dec_GNN_attributes[ + (GNNedges.Agentfuture2Agentfuture, "node", TFvars.Agent_future) + ] = ([cfg.n_embd * 2], None, cfg.decoder.pooling, None) + + dec_GNN_kwargs = dict( + var_axes=dec_var_axes, + GNN_attributes=dec_GNN_attributes, + node_n_embd=node_n_embd, + edge_n_embd=edge_n_embd, + edge_var=dec_edge_var, + ) + + dyn = dict() + name_type_table = { + "vehicle": AgentType.VEHICLE, + "pedestrian": AgentType.PEDESTRIAN, + "bicycle": AgentType.BICYCLE, + "motorcycle": AgentType.MOTORCYCLE, + } + + for k, v in cfg.decoder.dyn.items(): + if v == "unicycle": + dyn[name_type_table[k]] = Unicycle(cfg.step_time) + elif v == "unicycle_xyvsc": + dyn[name_type_table[k]] = Unicycle_xyvsc(cfg.step_time) + elif v == "DI_unicycle": + dyn[name_type_table[k]] = Unicycle(cfg.step_time, max_steer=1e3) + else: + dyn[name_type_table[k]] = None + + self.weighted_consistency_loss = cfg.weighted_consistency_loss + dec_output_params = dict( + arch=cfg.decoder.arch, + dyn=dyn, + num_layers=cfg.decoder.num_layers, + lstm_hidden_size=cfg.decoder.lstm_hidden_size, + mlp_hidden_dims=cfg.decoder.mlp_hidden_dims, + traj_dim=cfg.decoder.traj_dim, + dt=cfg.step_time, + Tf=cfg.future_num_frames, + decode_num_modes=cfg.decoder.decode_num_modes, + AR_step_size=cfg.decoder.AR_step_size, + AR_update_mode=cfg.decoder.AR_update_mode, + LR_sample_hack=cfg.LR_sample_hack, + dec_rounds=cfg.decoder.dec_rounds, + ) + + assert ( + cfg.decoder.AR_step_size == 1 + ) # for now, we only support AR_step_size=1 for non-auto-regressive mode + self.Tf_mode = cfg.future_num_frames + self.nets["policy"] = CTT( + n_embd=cfg.n_embd, + embed_funcs=embed_funcs, + enc_nblock=cfg.enc_nblock, + dec_nblock=cfg.dec_nblock, + enc_transformer_kwargs=enc_transformer_kwargs, + enc_GNN_kwargs=enc_GNN_kwargs, + dec_transformer_kwargs=dec_transformer_kwargs, + dec_GNN_kwargs=dec_GNN_kwargs, + enc_output_params=enc_output_params, + dec_output_params=dec_output_params, + hist_lane_relation=self.hist_lane_relation, + fut_lane_relation=self.fut_lane_relation, + max_joint_cardinality=self.cfg.max_joint_cardinality, + classify_a2l_4all_lanes=self.cfg.classify_a2l_4all_lanes, + edge_func=self.edge_func, + ) + + def set_eval(self): + self.nets["policy"].eval() + + def set_train(self): + self.nets["policy"].train() + + def _run_forward(self, inputs: Dict, run_mode: RunMode, **kwargs) -> Dict: + if "parsed_batch" in inputs: + parsed_batch = inputs["parsed_batch"] + else: + if ( + isinstance(inputs["scene_batch"], dict) + and "batch" in inputs["scene_batch"] + ): + parsed_batch = self.bu.parse_batch( + inputs["scene_batch"]["batch"].astype(torch.float) + ) + else: + parsed_batch = self.bu.parse_batch( + inputs["scene_batch"].astype(torch.float) + ) + + if self.fp16: + for k, v in parsed_batch.items(): + if isinstance(v, torch.Tensor) and v.dtype == torch.float32: + parsed_batch[k] = v.half() + + # tic = 
torch_utils.tic(timer=self.hyperparams.debug.timer) + parsed_batch_torch = { + k: v.to(self.device) + for k, v in parsed_batch.items() + if isinstance(v, torch.Tensor) + } + + parsed_batch.update(parsed_batch_torch) + dt = parsed_batch["dt"][0].item() + B, N, Th = parsed_batch["agent_hist"].shape[:3] + + Tf = parsed_batch["agent_fut"].size(2) + device = parsed_batch["agent_hist"].device + agent_type = ( + F.one_hot( + parsed_batch["agent_type"].masked_fill( + parsed_batch["agent_type"] < 0, 0 + ), + len(AgentType), + ) + .float() + .to(device) + ) + agent_type.masked_fill_((parsed_batch["agent_type"] < 0).unsqueeze(-1), 0) + agent_size = parsed_batch["extent"][..., :2] + static_feature = torch.cat([agent_size, agent_type], dim=-1) + hist_mask = parsed_batch["hist_mask"] + static_feature_tiled_h = static_feature.unsqueeze(2).repeat_interleave( + Th, 2 + ) * hist_mask.unsqueeze(-1) + + hist_xy = parsed_batch["agent_hist"][..., :2] + hist_yaw = torch.arctan2( + parsed_batch["agent_hist"][..., 6:7], parsed_batch["agent_hist"][..., 7:8] + ) + hist_sc = parsed_batch["agent_hist"][..., 6:8] + # normalize hist_sc + hist_sc = GeoUtils.normalize_sc(hist_sc) + + x_hist_uni = self.unicycle_model.get_state(hist_xy, hist_yaw, hist_mask) + u_mask = hist_mask[..., :-1] * hist_mask[..., 1:] + u_hist_uni = self.unicycle_model.inverse_dyn( + x_hist_uni[..., :-1, :], x_hist_uni[..., 1:, :], mask=u_mask + ) + u_hist_uni = torch.cat([u_hist_uni[..., :1, :], u_hist_uni], dim=-2) + + hist_v = self.unicycle_model.state2vel(x_hist_uni) + hist_acce = u_hist_uni[..., :1] + + hist_xyvsc = torch.cat([hist_xy, hist_v, hist_sc], dim=-1) + hist_xysc = torch.cat([hist_xy, hist_sc], dim=-1) + hist_h = ratan2(hist_sc[..., :1], hist_sc[..., 1:2]) + hist_r = (GeoUtils.round_2pi(hist_h[..., 1:, :] - hist_h[..., :-1, :])) / dt + hist_r = torch.cat([hist_r[..., 0:1, :], hist_r], dim=-2) + hist_feature = torch.cat( + [hist_v, hist_acce, hist_r, static_feature_tiled_h], dim=-1 + ) + lane_xyh = parsed_batch["lane_xyh"] + M = lane_xyh.size(1) + lane_xysc = torch.cat( + [ + lane_xyh[..., :2], + torch.sin(lane_xyh[..., 2:3]), + torch.cos(lane_xyh[..., 2:3]), + ], + -1, + ) + lane_adj = parsed_batch["lane_adj"].type(torch.int64) + lane_mask = parsed_batch["lane_mask"] + + raw_vars = { + TFvars.Agent_hist: hist_feature, + TFvars.Lane: lane_xysc.reshape(B, M, -1), + } + hist_aux = torch.cat([hist_xyvsc, static_feature_tiled_h], -1) + lane_aux = lane_xysc.view(B, M, -1) + + aux_xs = { + TFvars.Agent_hist: hist_aux, + TFvars.Lane: lane_aux, + } + a2l_edge = self.edge_func["a2l"]( + hist_aux.transpose(1, 2).reshape(B * Th, N, -1), + lane_aux.repeat_interleave(Th, 0), + ) + + hist_lane_agent_flag, _ = self.hist_lane_relation.categorize_lane_relation_pts( + TensorUtils.join_dimensions(hist_xysc, 0, 2), + lane_xysc.repeat_interleave(N, 0), + TensorUtils.join_dimensions(hist_mask, 0, 2), + lane_mask.repeat_interleave(N, 0), + ) + + hist_lane_agent_flag = hist_lane_agent_flag.view(B, N, M, Th, -1) + a2l_edge = torch.cat( + [ + a2l_edge, + hist_lane_agent_flag.permute(0, 3, 1, 2, 4).reshape(B * Th, N, M, -1), + ], + -1, + ) + l2l_edge = self.edge_func["l2l"](lane_aux, lane_aux) + l2l_edge = torch.cat([l2l_edge, F.one_hot(lane_adj, len(LaneSegRelation))], -1) + agent_hist_mask = parsed_batch["hist_mask"] + + # cross_masks = { + # (TFvars.Agent_hist,TFvars.Lane,(FeatureAxes.A,FeatureAxes.L)): [agent_lane_mask1,agent_lane_mask2], + # } + cross_masks = dict() + var_masks = { + TFvars.Agent_hist: agent_hist_mask.float(), + TFvars.Lane: 
lane_mask.float(), + } + enc_edges = { + (TFvars.Agent_hist, TFvars.Lane, (FeatureAxes.A, FeatureAxes.L)): a2l_edge, + (TFvars.Lane, TFvars.Lane, FeatureAxes.L): l2l_edge, + } + + frame_indices = { + TFvars.Agent_hist: torch.arange(Th, device=device)[None, None, :].repeat( + B, N, 1 + ), + } + + if ( + "agent_fut" in parsed_batch + and "fut_mask" in parsed_batch + and parsed_batch["fut_mask"].any() + and run_mode in [RunMode.TRAIN, RunMode.VALIDATE] + ): + fut_xy = parsed_batch["agent_fut"][..., :2] + fut_sc = parsed_batch["agent_fut"][..., 6:8] + fut_sc = GeoUtils.normalize_sc(fut_sc) + fut_xysc = torch.cat([fut_xy, fut_sc], -1) + fut_mask = parsed_batch["fut_mask"] + + mode_valid_flag = fut_mask.all(-1) + end_points = fut_xysc[:, :, -1] # Only look at final time for GT! + + GT_lane_mode, _ = self.fut_lane_relation.categorize_lane_relation_pts( + end_points.reshape(B * N, 1, 4), + lane_xysc.repeat_interleave(N, 0), + fut_mask.any(-1).reshape(B * N, 1), + lane_mask.repeat_interleave(N, 0), + force_select=not self.null_lane_mode, + ) + # You could have two lanes that it is both on + + GT_lane_mode = GT_lane_mode.squeeze(-2).argmax(-1).reshape(B, N, M) + if self.null_lane_mode: + GT_lane_mode = torch.cat( + [GT_lane_mode, (GT_lane_mode == 0).all(-1, keepdim=True)], -1 + ) + + angle_diff, GT_homotopy = identify_pairwise_homotopy(fut_xy, mask=fut_mask) + GT_homotopy = GT_homotopy.type(torch.int64).reshape(B, N, N) + else: + GT_lane_mode = None + GT_homotopy = None + center_from_agents = parsed_batch["center_from_agents"] + if run_mode == RunMode.INFER: + num_samples = kwargs.get("num_samples", 10) + else: + num_samples = None + vars, mode_pred = self.nets["policy"]( + raw_vars, + aux_xs, + var_masks, + cross_masks, + frame_indices, + agent_type=agent_type, + enc_edges=enc_edges, + GT_lane_mode=GT_lane_mode, + GT_homotopy=GT_homotopy, + center_from_agents=center_from_agents, + num_samples=num_samples, + ) + output_keys = [ + "trajectories", + "inputs", + "states", + "input_violation", + "jerk", + "type_mask", + ] + output = {k: v for k, v in vars.items() if k in output_keys} + if run_mode == RunMode.INFER: + B, num_samples, N, Tf = output["trajectories"].shape[:-1] + log_pis = ( + mode_pred["joint_pred"] + .view(1, B, 1, num_samples) + .expand(N, B, Tf, num_samples) + ) + if "states" in output: + state_xyhv = torch.cat( + [ + output["states"][..., :2], + output["states"][..., 3:4], + output["states"][..., 2:3], + ], + -1, + ) + else: + # calculate velocity from trajectories + pred_vel = self.unicycle_model.calculate_vel( + output["trajectories"][..., :2], + output["trajectories"][..., 2:], + hist_mask.any(-1)[:, None, :, None] + .repeat_interleave(num_samples, 1) + .repeat_interleave(Tf, -1), + ) + state_xyhv = torch.cat([output["trajectories"], pred_vel], -1) + mus_xyhv = state_xyhv.permute(2, 0, 3, 1, 4) + log_sigmas = torch.zeros_like(mus_xyhv) + corrs = torch.zeros_like(mus_xyhv[..., 0]) + pred_dist = GMM2D(log_pis, mus_xyhv, log_sigmas, corrs) + output["pred_dist"] = pred_dist + output["pred_ml"] = state_xyhv + # output["pred_single"] = torch.full([B,0,Tf,4],np.nan,device=device) + output["sample_modes"] = dict( + lane_mode=mode_pred["lane_mode_sample"], + homotopy=mode_pred["homotopy_sample"], + ) + # print(output["inputs"][...,1].abs().max()) + else: + if "trajectories" in output: + # pad trajectories to Tf in case of anomaly + if output["trajectories"].shape[-2] != Tf: + breakpoint() + dynamic_outputs = { + k: v + for k, v in output.items() + if k + in [ + "inputs", + "states", + 
"input_violation", + "jerk", + "trajectories", + ] + } + if output["trajectories"].shape[-2] < Tf: + func = ( + lambda x: x + if x.shape[-2] == Tf + else torch.cat( + [ + x, + x[..., -1:, :].repeat_interleave( + Tf - x.shape[-2], -2 + ), + ], + -2, + ) + ) + elif output["trajectories"].shape[-2] > Tf: + func = lambda x: x[..., :Tf, :] + dynamic_outputs = TensorUtils.recursive_dict_list_tuple_apply( + dynamic_outputs, + { + torch.Tensor: func, + type(None): lambda x: x, + }, + ) + output.update(dynamic_outputs) + dec_xysc = torch.cat( + [ + output["trajectories"][..., :2], + torch.sin(output["trajectories"][..., 2:3]), + torch.cos(output["trajectories"][..., 2:3]), + ], + -1, + ) + if dec_xysc.shape[-2] != Tf: + pass + else: + if dec_xysc.ndim == 4: + DS = 1 + dec_xysc = dec_xysc.unsqueeze(1) + elif dec_xysc.ndim == 5: + DS = dec_xysc.size(1) # decode sample + + ( + dec_lm, + dec_lm_margin, + ) = self.fut_lane_relation.categorize_lane_relation_pts( + dec_xysc[:, :, :, -1].view(B * DS * N, 1, 4), + lane_xysc.repeat_interleave(N * DS, 0), + fut_mask.repeat_interleave(DS, 0) + .any(-1) + .reshape(B * DS * N, 1), + lane_mask.repeat_interleave(N * DS, 0), + force_select=False, + force_unique=False, + const_override=dict(Y_near_thresh=1.0, X_rear_thresh=5.0), + ) + + y_dev, psi_dev = LaneUtils.get_ypsi_dev( + dec_xysc.view(B * DS * N, -1, 4), + lane_xysc.repeat_interleave(N * DS, 0), + ) + + dec_angle_diff, dec_homotopy = identify_pairwise_homotopy( + dec_xysc[..., :2].reshape(B * DS, N, Tf, -1), + mask=fut_mask.any(-1).repeat_interleave(DS, 0), + ) + dec_homotopy_margin = torch.stack( + [ + HOMOTOPY_THRESHOLD - dec_angle_diff.abs(), + -dec_angle_diff - HOMOTOPY_THRESHOLD, + dec_angle_diff - HOMOTOPY_THRESHOLD, + ], + -1, + ) + output["dec_lm_margin"] = dec_lm_margin.view(B, DS, N, M, -1) + output["dec_lane_y_dev"] = y_dev.view(B, DS, N, M, Tf, -1) + output["dec_lane_psi_dev"] = psi_dev.view(B, DS, N, M, Tf, -1) + output["dec_homotopy_margin"] = dec_homotopy_margin.view( + B, DS, N, N, -1 + ) + output["dec_lm"] = dec_lm.view(B, DS, N, M, -1) + output["dec_homotopy"] = dec_homotopy.view(B, DS, N, N, -1) + + output.update(mode_pred) + output["GT_lane_mode"] = GT_lane_mode + output["GT_homotopy"] = GT_homotopy + output["mode_valid_flag"] = mode_valid_flag + output["batch"] = parsed_batch + output["step_time"] = self.cfg.step_time + return output + + def compute_metrics(self, pred_batch, data_batch): + EPS = 1e-3 + metrics = dict() + if "GT_homotopy" in pred_batch: + # train/validation mode + batch = pred_batch["batch"] + agent_fut, fut_mask, pred_traj = TensorUtils.to_numpy( + (batch["agent_fut"], batch["fut_mask"], pred_batch["trajectories"]) + ) + if pred_traj.shape[-2] != agent_fut.shape[-2]: + return metrics + mode_valid_flag = pred_batch["mode_valid_flag"] + a2a_valid_flag = mode_valid_flag.unsqueeze(-1) * mode_valid_flag.unsqueeze( + -2 + ) + + metrics["joint_mode_accuracy"] = TensorUtils.to_numpy( + torch.softmax(pred_batch["joint_pred"], dim=1)[:, 0].mean() + ).item() + metrics["joint_mode_correct_rate"] = TensorUtils.to_numpy( + ( + pred_batch["joint_pred"][:, 0] + == pred_batch["joint_pred"].max(dim=1)[0] + ) + .float() + .mean() + ).item() + + lane_mode_correct_flag = pred_batch["GT_lane_mode"].argmax(-1) == ( + pred_batch["lane_mode_pred"].argmax(-1).squeeze(-1) + ) + metrics["pred_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (lane_mode_correct_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + metrics["lane_mode_accuracy"] = TensorUtils.to_numpy( + ( + ( + 
torch.softmax(pred_batch["lane_mode_pred"], dim=-1).squeeze(-2) + * pred_batch["GT_lane_mode"] + ).sum(-1) + * mode_valid_flag + ).sum() + / mode_valid_flag.sum() + ).item() + + homotopy_correct_flag = pred_batch["GT_homotopy"] == ( + pred_batch["homotopy_pred"].argmax(-1) + ) + + metrics["pred_homotopy_correct_rate"] = TensorUtils.to_numpy( + (homotopy_correct_flag * a2a_valid_flag).sum() / a2a_valid_flag.sum() + ).item() + metrics["homotopy_accuracy"] = TensorUtils.to_numpy( + ( + ( + torch.softmax(pred_batch["homotopy_pred"], dim=-1) + * F.one_hot(pred_batch["GT_homotopy"], len(HomotopyType)) + ).sum(-1) + * a2a_valid_flag + ).sum() + / a2a_valid_flag.sum() + ).item() + + B, N, Tf = agent_fut.shape[:3] + extent = batch["extent"] + traj = pred_batch["trajectories"] + DS = traj.size(1) + GT_homotopy = pred_batch["GT_homotopy"] + GT_lane_mode = pred_batch["GT_lane_mode"] + pred_xysc = torch.cat( + [traj[..., :2], torch.sin(traj[..., 2:3]), torch.cos(traj[..., 2:3])], + -1, + ).view(B, -1, N, Tf, 4) + end_points = pred_xysc[:, :, :, -1] # Only look at final time + lane_xyh = batch["lane_xyh"] + M = lane_xyh.size(1) + lane_xysc = torch.cat( + [ + lane_xyh[..., :2], + torch.sin(lane_xyh[..., 2:3]), + torch.cos(lane_xyh[..., 2:3]), + ], + -1, + ) + pred_lane_mode, _ = self.fut_lane_relation.categorize_lane_relation_pts( + end_points.reshape(B * N * DS, 1, 4), + lane_xysc.repeat_interleave(N * DS, 0), + batch["fut_mask"] + .any(-1) + .repeat_interleave(DS, 0) + .reshape(B * DS * N, 1), + batch["lane_mask"].repeat_interleave(DS * N, 0), + force_select=False, + ) + # You could have two lanes that it is both on + + pred_lane_mode = pred_lane_mode.squeeze(-2).argmax(-1).reshape(B, DS, N, M) + pred_lane_mode = torch.cat( + [pred_lane_mode, (pred_lane_mode == 0).all(-1, keepdim=True)], -1 + ) + + angle_diff, pred_homotopy = identify_pairwise_homotopy( + pred_xysc[..., :2].view(B * DS, N, Tf, 2), + mask=batch["fut_mask"].repeat_interleave(DS, 0), + ) + pred_homotopy = pred_homotopy.type(torch.int64).reshape(B, DS, N, N) + ML_homotopy_flag = (pred_homotopy[:, 1] == GT_homotopy).all(-1) + + ML_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) + metrics["ML_homotopy_correct_rate"] = TensorUtils.to_numpy( + (ML_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + all_homotopy_flag = ( + (pred_homotopy[:, 1:] == GT_homotopy[:, None]).any(1).all(-1) + ) + all_homotopy_flag.masked_fill_(torch.logical_not(mode_valid_flag), True) + metrics["all_homotopy_correct_rate"] = TensorUtils.to_numpy( + (all_homotopy_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + ML_lane_mode_flag = (pred_lane_mode[:, 1] == GT_lane_mode).all(-1) + all_lane_mode_flag = ( + (pred_lane_mode[:, 1:] == GT_lane_mode[:, None]).any(1).all(-1) + ) + metrics["ML_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (ML_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + metrics["all_lane_mode_correct_rate"] = TensorUtils.to_numpy( + (all_lane_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + ML_scene_mode_flag = (pred_homotopy[:, 1] == GT_homotopy).all(-1) & ( + pred_lane_mode[:, 1] == GT_lane_mode + ).all(-1) + metrics["ML_scene_mode_correct_rate"] = TensorUtils.to_numpy( + (ML_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + all_scene_mode_flag = ( + (pred_homotopy[:, 1:] == GT_homotopy[:, None]).all(-1) + & (pred_lane_mode[:, 1:] == GT_lane_mode[:, None]).all(-1) + ).any(1) + 
metrics["all_scene_mode_correct_rate"] = TensorUtils.to_numpy( + (all_scene_mode_flag * mode_valid_flag).sum() / mode_valid_flag.sum() + ).item() + pred_edges = self.bu.generate_edges( + batch["agent_type"].repeat_interleave(DS, 0), + extent.repeat_interleave(DS, 0), + TensorUtils.join_dimensions(traj[..., :2], 0, 2), + TensorUtils.join_dimensions(traj[..., 2:3], 0, 2), + batch_first=True, + ) + pred_edges = {k: v for k, v in pred_edges.items() if k != "PP"} + + coll_loss, dis = collision_loss(pred_edges=pred_edges, return_dis=True) + dis_padded = { + k: v.nan_to_num(1.0).view(B, DS, -1, Tf) for k, v in dis.items() + } + edge_mask = { + k: torch.logical_not(v.isnan().any(-1)).view(B, DS, -1) + for k, v in dis.items() + } + for k in dis_padded: + metrics["collision_rate_ML_" + k] = TensorUtils.to_numpy( + (dis_padded[k][:, 1] < 0).sum() + / (edge_mask[k][:, 1].sum() + EPS) + / Tf + ).item() + metrics[f"collision_rate_all_{DS}_mode_{k}"] = TensorUtils.to_numpy( + (dis_padded[k] < 0).sum() / (edge_mask[k].sum() + EPS) / Tf + ).item() + metrics["collision_rate_ML"] = TensorUtils.to_numpy( + sum([(v[:, 1] < 0).sum() for v in dis_padded.values()]) + / Tf + / (sum([edge_mask[k][:, 1].sum() for k in edge_mask]) + EPS) + ).item() + metrics[f"collision_rate_all_{DS}_mode"] = TensorUtils.to_numpy( + sum([(v < 0).sum() for v in dis_padded.values()]) + / Tf + / (sum([edge_mask[k].sum() for k in edge_mask]) + EPS) + ).item() + + confidence = np.ones([B * N, 1]) + Nmode = pred_traj.shape[1] + agent_type = TensorUtils.to_numpy(batch["agent_type"]) + vehicle_mask = ( + (agent_type == AgentType.VEHICLE) + | (agent_type == AgentType.BICYCLE) + | (agent_type == AgentType.MOTORCYCLE) + ) + pedestrian_mask = agent_type == AgentType.PEDESTRIAN + agent_mask = fut_mask.any(-1) + dt = self.cfg.step_time + for Tsecond in [3.0, 4.0, 5.0, 6.0, 8.0]: + if Tf < Tsecond / dt: + continue + Tf_bar = int(Tsecond / dt) + # oracle mode means the trajectory decoded under the GT scene mode + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 0, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + metrics[f"oracle_ADE@{Tsecond}s"] = allADE + metrics[f"oracle_FDE@{Tsecond}s"] = allFDE + metrics[f"oracle_vehicle_ADE@{Tsecond}s"] = vehADE + metrics[f"oracle_vehicle_FDE@{Tsecond}s"] = vehFDE + metrics[f"oracle_pedestrian_ADE@{Tsecond}s"] = pedADE + metrics[f"oracle_pedestrian_FDE@{Tsecond}s"] = pedFDE + + # the second mode is the most likely mode (first is GT) + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1, ..., :Tf_bar, :2].reshape(B * N, 
1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1, ..., :Tf_bar, :2].reshape(B * N, 1, Tf_bar, 2), + confidence, + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="mean", + ).reshape(B, N) + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + + metrics[f"ML_ADE@{Tsecond}s"] = allADE + metrics[f"ML_FDE@{Tsecond}s"] = allFDE + metrics[f"ML_vehicle_ADE@{Tsecond}s"] = vehADE + metrics[f"ML_vehicle_FDE@{Tsecond}s"] = vehFDE + metrics[f"ML_pedestrian_ADE@{Tsecond}s"] = pedADE + metrics[f"ML_pedestrian_FDE@{Tsecond}s"] = pedFDE + + # minADE and minFDE are calculated excluding the oracle mode + ADE = Metrics.batch_average_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1:, ..., :Tf_bar, :2] + .transpose(0, 2, 1, 3, 4) + .reshape(B * N, -1, Tf_bar, 2), + confidence.repeat(Nmode - 1, 1) / (Nmode - 1), + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="oracle", + ).reshape(B, N) + FDE = Metrics.batch_final_displacement_error( + agent_fut[..., :Tf_bar, :2].reshape(B * N, Tf_bar, 2), + pred_traj[:, 1:, ..., :Tf_bar, :2] + .transpose(0, 2, 1, 3, 4) + .reshape(B * N, -1, Tf_bar, 2), + confidence.repeat(Nmode - 1, 1) / (Nmode - 1), + fut_mask.reshape(B * N, Tf)[..., :Tf_bar], + mode="oracle", + ).reshape(B, N) + + allADE = ADE.sum() / (agent_mask.sum() + EPS) + vehADE = (ADE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedADE = (ADE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + allFDE = FDE.sum() / (agent_mask.sum() + EPS) + vehFDE = (FDE * vehicle_mask * agent_mask).sum() / ( + (vehicle_mask * agent_mask).sum() + EPS + ) + pedFDE = (FDE * pedestrian_mask * agent_mask).sum() / ( + (pedestrian_mask * agent_mask).sum() + EPS + ) + + metrics[f"min_ADE@{Tsecond}s"] = allADE + metrics[f"min_FDE@{Tsecond}s"] = allFDE + metrics[f"min_vehicle_ADE@{Tsecond}s"] = vehADE + metrics[f"min_vehicle_FDE@{Tsecond}s"] = vehFDE + metrics[f"min_pedestrian_ADE@{Tsecond}s"] = pedADE + metrics[f"min_pedestrian_FDE@{Tsecond}s"] = pedFDE + + if "dec_lm" in pred_batch: + DS = pred_batch["dec_lm"].shape[1] + M = batch["lane_xyh"].shape[-2] + if self.cfg.classify_a2l_4all_lanes: + dec_lm, dec_cond_lm, lane_mask = TensorUtils.to_numpy( + ( + pred_batch["dec_lm"], + F.one_hot( + pred_batch["dec_cond_lm"][..., :M], + len(self.fut_lane_relation), + ), + batch["lane_mask"], + ) + ) + mask = ( + (agent_mask[:, :, None] * lane_mask[:, None]) + .reshape(B, 1, N, M) + .repeat(DS, 1) + ) + lm_consistent_rate = ((dec_lm == dec_cond_lm).all(-1) * mask).sum() / ( + (mask).sum() + EPS + ) + else: + dec_lm, dec_cond_lm, lane_mask = TensorUtils.to_numpy( + ( + pred_batch["dec_lm"], + pred_batch["dec_cond_lm"], + batch["lane_mask"], + ) + ) + dec_lm = dec_lm.argmax(-1) + + if self.null_lane_mode: + # augment the null lane + dec_lm = np.concatenate( + [dec_lm, (dec_lm == 0).all(-1, 
keepdims=True)], -1 + ) + lane_mask = np.concatenate( + [lane_mask, np.ones([B, 1], dtype=bool)], -1 + ) + mask = ( + (agent_mask[:, :, None] * lane_mask[:, None]) + .reshape(B, 1, N, -1) + .repeat(DS, 1) + ) + lm_consistent_rate = ((dec_lm == dec_cond_lm) * mask).sum() / ( + (mask).sum() + EPS + ) + + metrics["lm_consistent_rate"] = lm_consistent_rate + if "dec_homotopy" in pred_batch: + dec_homotopy, dec_cond_homotopy = TensorUtils.to_numpy( + (pred_batch["dec_homotopy"], pred_batch["dec_cond_homotopy"]) + ) + DS = dec_homotopy.shape[1] + mask = ( + (agent_mask[:, :, None] * agent_mask[:, None]) + .reshape(B, 1, N, N) + .repeat(DS, 1) + ) + + homo_consistent_rate = ( + (dec_homotopy.squeeze(-1) == dec_cond_homotopy) * mask + ).sum() / ((mask).sum() + EPS) + metrics["homo_consistent_rate"] = homo_consistent_rate + + return metrics + + def compute_losses(self, pred_batch, inputs): + EPS = 1e-3 + batch = pred_batch["batch"] + lm_K = len(self.fut_lane_relation) + homo_K = len(HomotopyType) + GT_homotopy = F.one_hot(pred_batch["GT_homotopy"], homo_K).float() + GT_lane_mode = F.one_hot(pred_batch["GT_lane_mode"], lm_K).float() + mode_valid_flag = pred_batch["mode_valid_flag"] + dt = self.cfg.step_time + Tf = batch["agent_fut"].shape[-2] + if not self.cfg.classify_a2l_4all_lanes: + # We perform softmax (in cross entropy) over the lane segments + mask_noton_mode = torch.ones( + len(self.fut_lane_relation), dtype=torch.bool, device=self.device + ) + mask_noton_mode[self.fut_lane_relation.NOTON] = False + DS = pred_batch["trajectories"].size(1) + future_mask = batch["fut_mask"] + agent_mask = batch["hist_mask"].any(-1) + # mode probability loss + + B, N, Mext = pred_batch["GT_lane_mode"].shape[:3] + if self.null_lane_mode: + M = Mext - 1 + else: + M = Mext + # marginal mode prediction probability + lane_mode_pred = pred_batch["lane_mode_pred"] + homotopy_pred = pred_batch["homotopy_pred"] + # marginal probability loss, use "subtract max" trick for stable cross entropy + + B, N = agent_mask.shape + if not self.cfg.classify_a2l_4all_lanes: + # We perform softmax (in cross entropy) over the lane segments + GT_lane_mode_marginal = torch.masked_select( + GT_lane_mode, mask_noton_mode[None, None, None, :] + ).view(*GT_lane_mode.shape[:3], -1) + GT_lane_mode_marginal = GT_lane_mode_marginal.swapaxes( + -1, -2 + ) # Not onehot anymore + + lane_mode_pred = lane_mode_pred.reshape(-1, Mext) + marginal_lm_loss = self.lane_mode_loss( + lane_mode_pred - lane_mode_pred.max(dim=-1, keepdim=True)[0], + GT_lane_mode_marginal.reshape(-1, Mext), + ) + lm_mask = mode_valid_flag.flatten().repeat_interleave(lm_K - 1, 0) + else: + lane_mode_pred = lane_mode_pred.reshape(-1, lm_K) + marginal_lm_loss = self.lane_mode_loss( + lane_mode_pred - lane_mode_pred.max(dim=-1, keepdim=True)[0], + GT_lane_mode.reshape(-1, lm_K), + ) + lm_mask = mode_valid_flag.flatten().repeat_interleave(M, 0) + raise NotImplementedError + + marginal_homo_loss = self.homotopy_loss( + homotopy_pred.reshape(-1, homo_K) + - homotopy_pred.reshape(-1, homo_K).max(dim=-1, keepdim=True)[0], + GT_homotopy.reshape(-1, homo_K), + ) + + homo_mask = ( + mode_valid_flag.unsqueeze(2) * mode_valid_flag.unsqueeze(1) + ).flatten() + marginal_lm_loss = (marginal_lm_loss * lm_mask).sum() / (lm_mask.sum() + EPS) + marginal_homo_loss = (marginal_homo_loss * homo_mask).sum() / ( + homo_mask.sum() + EPS + ) + + # joint probability + joint_logpi = pred_batch["joint_pred"] + GT_joint_mode = torch.zeros_like(joint_logpi) + GT_joint_mode[ + ..., 0 + ] = 1 # FIXME: Shouldn't 
this be only if the first element is the GT? + joint_prob_loss = self.joint_mode_loss( + (joint_logpi - joint_logpi.max(dim=-1, keepdim=True)[0]).nan_to_num(0), + GT_joint_mode, + ).mean() + + # decoded trajectory consistency loss + dec_normalized_prob = pred_batch["dec_cond_prob"] / pred_batch[ + "dec_cond_prob" + ].sum(-1, keepdim=True) + + if "dec_lm_margin" in pred_batch: + mask = (agent_mask.unsqueeze(2) * batch["lane_mask"].unsqueeze(1)).view( + B, 1, N, M, 1 + ) + if self.null_lane_mode: + mask = torch.cat([mask, torch.ones_like(mask[:, :, :, :1])], -2) + noton_mask = torch.ones(len(self.fut_lane_relation), device=mask.device) + noton_mask[self.fut_lane_relation.NOTON] = 0 + mask = mask * noton_mask + if self.null_lane_mode: + # the margin for null lane is the inverse of the maximum of all lane margins + null_lm_margin = ( + -pred_batch["dec_lm_margin"].max(-2)[0] + * (1 - noton_mask)[None, None, None, :] + + -pred_batch["dec_lm_margin"].min(-2)[0] + * noton_mask[None, None, None, :] + ) + dec_lm_margin = torch.cat( + [pred_batch["dec_lm_margin"], null_lm_margin.unsqueeze(-2)], -2 + ) + else: + dec_lm_margin = pred_batch["dec_lm_margin"] - self.cfg.lm_margin_offset + + if self.weighted_consistency_loss: + lm_consistency_loss = ( + loss_clip( + F.relu( + -dec_lm_margin + * pred_batch["dec_cond_lm"].unsqueeze(-1) + * mask + ), + max_loss=4.0, + ).sum(-1) + * dec_normalized_prob[..., None, None] + ).sum() / B # FIXME: should this now be over M? (with self.cfg.classify_a2l_4all_lanes) + else: + lm_consistency_loss = ( + ( + loss_clip( + F.relu( + -dec_lm_margin + * pred_batch["dec_cond_lm"].unsqueeze(-1) + * mask + ), + max_loss=4.0, + ).sum(-1) + ).sum() + / B + / DS + ) + + else: + lm_consistency_loss = torch.tensor(0.0, device=self.device) + if "dec_lane_y_dev" in pred_batch: + mask = (agent_mask.unsqueeze(2) * batch["lane_mask"].unsqueeze(1)).view( + B, 1, N, M + ) + time_weight = torch.arange(1, Tf + 1, device=self.device) / Tf**2 + psi_dev = pred_batch["dec_lane_psi_dev"].squeeze(-1) + y_dev = pred_batch["dec_lane_y_dev"].squeeze(-1) + yaw_dev_loss = ( + psi_dev.abs() + * mask[..., :M, None] + * pred_batch["dec_cond_lm"][..., :M, None] + * time_weight[None, None, None, None, :] + ).sum() / (mask[..., :M].sum() + EPS) + y_dev_loss = ( + y_dev.abs() + * mask[..., :M, None] + * pred_batch["dec_cond_lm"][..., :M, None] + * time_weight[None, None, None, None, :] + ).sum() / (mask[..., :M].sum() + EPS) + else: + yaw_dev_loss = torch.tensor(0.0, device=self.device) + y_dev_loss = torch.tensor(0.0, device=self.device) + if "dec_homotopy_margin" in pred_batch: + mask = (agent_mask.unsqueeze(2) * agent_mask.unsqueeze(1))[ + :, None, :, :, None + ] + if self.weighted_consistency_loss: + homotopy_consistency_loss = ( + loss_clip( + F.relu( + -pred_batch["dec_homotopy_margin"] + * F.one_hot( + pred_batch["dec_cond_homotopy"], len(HomotopyType) + ) + * mask + ), + max_loss=4.0, + ).sum(-1) + * dec_normalized_prob[..., None, None] + ).sum() / B + else: + homotopy_consistency_loss = ( + ( + loss_clip( + F.relu( + -pred_batch["dec_homotopy_margin"] + * F.one_hot( + pred_batch["dec_cond_homotopy"], + len(HomotopyType), + ) + * mask + ), + max_loss=4.0, + ).sum(-1) + ).sum() + / B + / DS + ) + else: + homotopy_consistency_loss = torch.tensor(0.0, device=self.device) + + # trajectory reconstruction loss + traj = pred_batch["trajectories"] + if traj.shape[-2] != batch["agent_fut"].shape[-2]: + xy_loss = torch.tensor(0.0, device=traj.device) + yaw_loss = torch.tensor(0.0, device=traj.device) + 
coll_loss = torch.tensor(0.0, device=traj.device) + acce_reg_loss = torch.tensor(0.0, device=traj.device) + steering_reg_loss = torch.tensor(0.0, device=traj.device) + input_violation_loss = torch.tensor(0.0, device=traj.device) + jerk_loss = torch.tensor(0.0, device=traj.device) + else: + # only penalize the trajectory under GT mode + traj_GT_mode = pred_batch["trajectories"][:, 0] + GT_xy = batch["agent_fut"][..., :2] + GT_h = ratan2(batch["agent_fut"][..., 6:7], batch["agent_fut"][..., 7:8]) + + xy_loss = self.traj_loss(GT_xy, traj_GT_mode[..., :2]) + xy_loss = xy_loss.norm(dim=-1) + xy_loss = (xy_loss * future_mask).sum() / (future_mask.sum() + EPS) + yaw_loss = torch.abs( + GeoUtils.round_2pi(traj_GT_mode[..., 2:3] - GT_h) + ).squeeze(-1) + yaw_loss = (yaw_loss * future_mask).sum() / (future_mask.sum() + EPS) + + # collision loss + extent = batch["extent"] + DS = traj.size(1) + pred_edges = self.bu.generate_edges( + batch["agent_type"].repeat_interleave(DS, 0), + extent.repeat_interleave(DS, 0), + TensorUtils.join_dimensions(traj[..., :2], 0, 2), + TensorUtils.join_dimensions(traj[..., 2:3], 0, 2), + batch_first=True, + ) + + edge_weight = pred_batch["dec_cond_prob"].flatten().view(-1, 1) + + coll_loss = collision_loss( + pred_edges={k: v for k, v in pred_edges.items() if k != "PP"}, + weight=edge_weight, + ).nan_to_num(0) + # diff = torch.norm((traj[:,1,...,:2]-traj[:,2,...,:2])*agent_mask[...,None,None])+torch.norm((traj[:,2,...,:2]-traj[:,3,...,:2])*agent_mask[...,None,None])+torch.norm((traj[:,3,...,:2]-traj[:,4,...,:2])*agent_mask[...,None,None]) + # print(diff) + acce_reg_loss = torch.tensor(0.0, device=traj.device) + steering_reg_loss = torch.tensor(0.0, device=traj.device) + input_violation_loss = torch.tensor(0.0, device=traj.device) + jerk_loss = torch.tensor(0.0, device=traj.device) + + type_mask = pred_batch["type_mask"] + inputs = pred_batch["inputs"] + input_violation = ( + pred_batch["input_violation"] + if "input_violation" in pred_batch + else dict() + ) + jerk = pred_batch["jerk"] if "jerk" in pred_batch else dict() + + for k, dyn in self.nets["policy"].dyn.items(): + if type_mask[k].sum() == 0: + continue + if isinstance(dyn, Unicycle) or isinstance(dyn, Unicycle_xyvsc): + acce_reg_loss += ( + inputs[k][..., 0:1].norm(dim=-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + steering_reg_loss += ( + inputs[k][..., 1:2].norm(dim=-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + + if k in input_violation: + input_violation_loss += ( + input_violation[k].sum(-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + + if k in jerk: + jerk_loss += ( + (jerk[k] ** 2).sum(-1).mean(-1) * type_mask[k] + ).sum() / (type_mask[k].sum() + EPS) + + losses = dict( + marginal_lm_loss=marginal_lm_loss, + marginal_homo_loss=marginal_homo_loss, + l2_reg=lane_mode_pred.nan_to_num(0).norm(dim=-1).mean() + + homotopy_pred.nan_to_num(0).norm(dim=-1).mean() + + joint_logpi.nan_to_num(0).abs().mean(), + joint_prob_loss=joint_prob_loss, + lm_consistency_loss=lm_consistency_loss, + homotopy_consistency_loss=homotopy_consistency_loss, + yaw_dev_loss=yaw_dev_loss, + y_dev_loss=y_dev_loss, + xy_loss=xy_loss, + yaw_loss=yaw_loss, + coll_loss=coll_loss, + acce_reg_loss=acce_reg_loss, + steering_reg_loss=steering_reg_loss, + input_violation_loss=input_violation_loss, + jerk_loss=jerk_loss, + ) + for k, v in losses.items(): + if v.isinf() or v.isnan(): + raise ValueError(f"{k} becomes NaN") + return losses + + def log_pred_image( + self, + batch, + 
pred_batch, + batch_idx, + logger, + log_all_image=False, + savegif=False, + **kwargs, + ): + if batch_idx == 0: + return + + try: + parsed_batch = pred_batch["batch"] + indices = ( + list(range(parsed_batch["hist_mask"].shape[0])) + if log_all_image + else [parsed_batch["hist_mask"][..., -1].sum(-1).argmax().item()] + ) + world_from_agent = TensorUtils.to_numpy(parsed_batch["world_from_agent"]) + center_xy = TensorUtils.to_numpy( + parsed_batch["centered_agent_state"][:, :2] + ) + centered_agent_state = TensorUtils.to_numpy( + parsed_batch["centered_agent_state"] + ) + world_yaw = np.arctan2( + centered_agent_state[:, -2], centered_agent_state[:, -1] + ) + curr_posyaw = torch.cat( + [ + parsed_batch["agent_hist"][..., :2], + torch.arctan2( + parsed_batch["agent_hist"][..., -2], + parsed_batch["agent_hist"][..., -1], + )[..., None], + ], + -1, + ) + NS = pred_batch["trajectories"].shape[1] + traj = torch.cat( + [ + curr_posyaw[:, None, :, -1:].repeat_interleave(NS, 1), + pred_batch["trajectories"], + ], + -2, + ) + traj = TensorUtils.to_numpy(traj) + world_traj_xy = GeoUtils.batch_nd_transform_points_np( + traj[..., :2], world_from_agent[:, None, None] + ) + world_traj_yaw = traj[..., 2] + world_yaw[:, None, None, None] + world_traj = np.concatenate([world_traj_xy, world_traj_yaw[..., None]], -1) + + extent = TensorUtils.to_numpy(parsed_batch["agent_hist_extent"][:, :, -1]) + hist_mask = TensorUtils.to_numpy(parsed_batch["hist_mask"]) + # pool = mp.Pool(len(indices)) + + # def plot_scene(bi): + for bi in indices: + if isinstance(logger, str) or isinstance(logger, Path): + # directly save file + html_file_name = ( + Path(logger) / f"CTT_visualization_{batch_idx}_{bi}.html" + ) + else: + html_file_name = f"CTT_visualization_{bi}.html" + if os.path.exists(html_file_name): + os.remove(html_file_name) + + # plot agent + for mode in range(min(NS, 3)): + output_file(html_file_name) + graph = figure(title="Bokeh graph", width=900, height=900) + graph.xgrid.grid_line_color = None + graph.ygrid.grid_line_color = None + + graph.x_range = Range1d( + center_xy[bi, 0] - 40, center_xy[bi, 0] + 40 + ) + graph.y_range = Range1d( + center_xy[bi, 1] - 30, center_xy[bi, 1] + 50 + ) + plot_scene_open_loop( + graph, + world_traj[bi, mode], + extent[bi], + parsed_batch["vector_maps"][bi], + np.eye(3), + bbox=[ + center_xy[bi, 0] - 40, + center_xy[bi, 0] + 40, + center_xy[bi, 1] - 30, + center_xy[bi, 1] + 50, + ], + mask=hist_mask[bi].any(-1), + color_scheme="palette", + ) + if isinstance(logger, WandbLogger): + save(graph) + wandb_html = wandb.Html(open(html_file_name)) + logger.experiment.log( + {f"val_pred_image_mode{mode}": wandb_html} + ) + + elif isinstance(logger, str) or isinstance(logger, Path): + html_file_name = ( + Path(logger) + / f"CTT_visualization_{batch_idx}_{bi}_mode{mode}.html" + ) + output_file(html_file_name) + save(graph) + png_name = ( + str(html_file_name).removesuffix(".html") + + f"_mode{mode}.png" + ) + export_png(graph, filename=png_name) + del graph + if savegif: + graph = figure(title="Bokeh graph", width=900, height=900) + graph.xgrid.grid_line_color = None + graph.ygrid.grid_line_color = None + + graph.x_range = Range1d( + center_xy[bi, 0] - 40, center_xy[bi, 0] + 40 + ) + graph.y_range = Range1d( + center_xy[bi, 1] - 30, center_xy[bi, 1] + 50 + ) + gif_name = ( + str(html_file_name).removesuffix(".html") + + f"_mode{mode}.gif" + ) + animate_scene_open_loop( + graph, + world_traj[bi, mode], + extent[bi], + parsed_batch["vector_maps"][bi], + np.eye(3), + bbox=[ + center_xy[bi, 0] - 
40, + center_xy[bi, 0] + 40, + center_xy[bi, 1] - 30, + center_xy[bi, 1] + 50, + ], + mask=hist_mask[bi].any(-1), + color_scheme="palette", + dt=parsed_batch["dt"][0].item(), + gif_name=gif_name, + # tmp_dir=str(html_file_name).removesuffix(".html"), + ) + del graph + + # pool.map(plot_scene, indices) + except Exception as e: + print(e) + pass diff --git a/diffstack/modules/predictors/constvel_predictor.py b/diffstack/modules/predictors/constvel_predictor.py deleted file mode 100644 index c2ce24b..0000000 --- a/diffstack/modules/predictors/constvel_predictor.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import numpy as np -from typing import Dict, Optional, Union, Any - -from trajdata.data_structures.batch import AgentBatch - -from mpc import util as mpc_util - -from diffstack.modules.module import Module, DataFormat, RunMode -from diffstack.modules.dynamics_functions import ExtendedUnicycleDynamics -from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D -from diffstack.utils.pred_utils import compute_prediction_metrics - -class ConstVelPredictor(Module): - - @property - def input_format(self) -> DataFormat: - return DataFormat(["agent_batch", "loss_weights"]) - - @property - def output_format(self) -> DataFormat: - return DataFormat(["pred_dist", "pred_ml", "metrics:train", "metrics:validate"]) - - def __init__(self, model_registrar, - hyperparams, log_writer, - device): - super().__init__(model_registrar, hyperparams, log_writer, device) - - self.dyn_obj = ExtendedUnicycleDynamics(dt=self.hyperparams['dt']) - - def train(self, inputs: Dict) -> Dict: - outputs = self.infer(inputs) - outputs["loss"] = torch.zeros((), device=self.device) - return outputs - - def validate(self, inputs: Dict) -> Dict: - batch: AgentBatch = inputs["agent_batch"] - outputs = self.infer(inputs) - outputs["metrics"] = self.validation_metrics(outputs['pred_dist'], outputs['pred_ml'], batch.agent_fut) - return outputs - - def infer(self, inputs: Dict): - batch: AgentBatch = inputs["agent_batch"] - pred_ml, pred_dist = self.constvel_predictions(batch.agent_hist[:, -1], batch.agent_fut.shape[1]) - return {"pred_dist": pred_dist, "pred_ml": pred_ml} - - def constvel_predictions(self, curr_agent_state: torch.Tensor, prediction_horizon: int): - # {'position': ['x', 'y'], 'velocity': ['x', 'y'], 'acceleration': ['x', 'y'], 'heading': ['°', 'd°'], 'augment': ['ego_indicator']} - pos_xy, vel_xy, acc_xy, yaw, dyaw = torch.split_with_sizes(curr_agent_state[..., :8], (2, 2, 2, 1, 1), dim=-1) - vel = torch.linalg.norm(vel_xy, dim=-1).unsqueeze(-1) - - # state: x, y, yaw, vel; control: yaw_rate, acc - x0 = torch.cat((pos_xy, yaw, vel), dim=-1) # (b, 4) - u_zeros = torch.zeros((prediction_horizon+1, x0.shape[0], 2), dtype=x0.dtype, device=x0.device) - - x_constvel = mpc_util.get_traj(prediction_horizon+1, u_zeros, x0, self.dyn_obj) # (T+1, b, 4) - pred_constvel = x_constvel[1:, :, :2] # (T, b, 2) - - # # Simple alternative - # x, y, heading, v = torch.unbind(x0, dim=-1) - # vx = v * 0.5 * torch.cos(heading) - # vy = v * 0.5 * torch.sin(heading) - # x_const = x.unsqueeze(0) + vx.unsqueeze(0) * torch.arange(prediction_horizon1).float().to(self.device).unsqueeze(1) - # y_const = y.unsqueeze(0) + vy.unsqueeze(0) * torch.arange(prediction_horizon+1).float().to(self.device).unsqueeze(1) - # # heading_const = torch.repeat_interleave(heading.unsqueeze(0), prediction_horizon+1, dim=0) - # # v_const = torch.repeat_interleave(v.unsqueeze(0), prediction_horizon+1, dim=0) - # # traj_const = torch.stack([x_const, 
y_const, heading_const, v_const], dim=-1) - # traj_const = torch.stack([x_const, y_const], dim=-1) - # traj_const = traj_const[1:] - # assert torch.isclose(pred_constvel, traj_const).all() - - pred_constvel = pred_constvel.transpose(0, 1).unsqueeze(0) # (1, b, T, 2) - - mus = pred_constvel.unsqueeze(3) # (1, b, T, 1, 2) - log_pis = torch.zeros((1, x0.shape[0], prediction_horizon, 1), dtype=x0.dtype, device=x0.device) - # log_sigmas = torch.log(torch.tensor([0.0393, 0.4288, 1.6322, 4.1350, 8.4635, 15.1796]), dtype=x0.dtype, device=x0.device) - log_sigmas = torch.log(torch.tensor((self.hyperparams['dt']*np.arange(7))[1:]**2*2, dtype=x0.dtype, device=x0.device)) - log_sigmas = log_sigmas.reshape(1, 1, prediction_horizon, 1, 1).repeat((1, x0.shape[0], 1, 1, 2)) - corrs = 0. * torch.ones((1, x0.shape[0], prediction_horizon, 1), dtype=x0.dtype, device=x0.device) # TODO not sure what is reasonable - - y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) - return pred_constvel, y_dists - - def validation_metrics(self, pred_dist, pred_ml, agent_fut): - # Compute default metrics - metrics = compute_prediction_metrics(pred_ml, agent_fut[..., :2], y_dists=pred_dist) - return metrics diff --git a/diffstack/modules/predictors/factory.py b/diffstack/modules/predictors/factory.py new file mode 100644 index 0000000..3ceb7fb --- /dev/null +++ b/diffstack/modules/predictors/factory.py @@ -0,0 +1,59 @@ +"""Factory methods for creating planner""" +from diffstack.configs.base import AlgoConfig + +from diffstack.utils.utils import removeprefix + +from diffstack.modules.predictors.kinematic_predictor import KinematicTreeModel + + +from diffstack.modules.predictors.CTT import CTTTrafficModel + + +def predictor_factory( + model_registrar, + config: AlgoConfig, + logger, + device, + input_mappings={}, + checkpoint=None, +): + """ + A factory for creating predictor modules + + Args: + config (AlgoConfig): an AlgoConfig object, + + Returns: + predictor: predictor module + """ + algo_name = config.name + + if algo_name == "kinematic": + predictor = KinematicTreeModel( + model_registrar, config, logger, device, input_mappings=input_mappings + ) + + elif algo_name == "CTT": + predictor = CTTTrafficModel( + model_registrar, config, logger, device, input_mappings=input_mappings + ) + else: + raise NotImplementedError(f"{algo_name} is not implemented") + + if checkpoint is not None: + if "state_dict" in checkpoint: + predictor_dict = { + removeprefix(k, "components.predictor."): v + for k, v in checkpoint["state_dict"].items() + if k.startswith("components.predictor.") + } + elif "model_dict" in checkpoint: + predictor_dict = { + "model." + k: v for k, v in checkpoint["model_dict"].items() + } + # Are we loading all model parameters? 
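+            # Descriptive note: the assert below checks that every key expected by
+            # the predictor's state_dict is present in the remapped checkpoint dict
+            # before load_state_dict is called, so a partial or mismatched
+            # "model_dict"-style checkpoint fails loudly here instead of silently.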
+ assert all([k in predictor_dict for k in predictor.state_dict().keys()]) + else: + raise ValueError("Unknown checkpoint format") + predictor.load_state_dict(predictor_dict) + return predictor diff --git a/diffstack/modules/predictors/kinematic_predictor.py b/diffstack/modules/predictors/kinematic_predictor.py new file mode 100644 index 0000000..df772a1 --- /dev/null +++ b/diffstack/modules/predictors/kinematic_predictor.py @@ -0,0 +1,163 @@ +import torch +import numpy as np +from diffstack.dynamics.unicycle import Unicycle +from diffstack.modules.module import Module, DataFormat +from diffstack.utils.batch_utils import batch_utils +from trajdata.data_structures.batch import AgentBatch,SceneBatch +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D +import diffstack.utils.tensor_utils as TensorUtils +from typing import Dict, Optional, Union, Any + + +class KinematicTreeModel(Module): + + @property + def input_format(self) -> DataFormat: + if self.scene_centric: + return DataFormat(["scene_batch"]) + else: + return DataFormat(["agent_batch"]) + + @property + def output_format(self) -> DataFormat: + return DataFormat(["pred_dist", "pred_ml", "pred_single", "metrics"]) + + def __init__(self,model_registrar, config, log_writer, device, input_mappings={}): + super().__init__(model_registrar, config, log_writer, device, input_mappings) + self.config = config + self.step_time = config.step_time + if "dynamics" in config: + if config.dynamics.type=="Unicycle": + self.dyn = Unicycle(self.step_time,max_steer=config.dynamics.max_steer,max_yawvel=config.dynamics.max_yawvel,acce_bound=config.dynamics.acce_bound) + else: + raise NotImplementedError + else: + self.dyn=Unicycle(self.step_time) + + self.stage=config["stage"] + self.num_frames_per_stage = config["num_frames_per_stage"] + self.M = config["M"] + self.scene_centric = config["scene_centric"] + self.only_branch_closest = config.only_branch_closest if "only_branch_closest" in config else self.scene_centric + self.bu = batch_utils(rasterize_mode="none",parse=True) + + def train_step(self, inputs: Dict) -> Dict: + outputs = self.infer_step(inputs) + outputs["loss"] = torch.zeros((), device=self.device) + return outputs + + def validate_step(self, inputs: Dict) -> Dict: + outputs = self.infer_step(inputs) + outputs["metrics"] = {} + return outputs + + def infer_step(self, inputs: Dict): + if self.scene_centric: + batch = inputs["scene_batch"] + else: + batch = inputs["agent_batch"] + if isinstance(batch, dict) and "batch" in batch: + batch = batch ["batch"] + batch = self.bu.parse_batch(batch) + pred_ml, pred_dist = self.kinematic_predictions(batch) + + # Dummy single agent prediction. 
+ agent_fut = batch["agent_fut"] + pred_single = torch.full([agent_fut.shape[0], 0, agent_fut.shape[2], 4], dtype=agent_fut.dtype, device=agent_fut.device, fill_value=torch.nan) + + output = dict(pred_dist=pred_dist, pred_ml=pred_ml, pred_single=pred_single) + if self.scene_centric: + output["scene_batch"]=batch + else: + output["agent_batch"]=batch + return output + + def kinematic_predictions(self,batch): + if batch["agent_hist"].shape[-1] == 7: # x, y, vx, vy, ax, ay, heading + yaw = batch["agent_hist"][...,6:7] + else: + assert batch["agent_hist"].shape[-1] == 8 # x, y, vx, vy, ax, ay, sin(heading), cos(heading) + yaw = torch.atan2(batch["agent_hist"][..., [-2]], batch["agent_hist"][..., [-1]]) + pos = batch["agent_hist"][...,:2] + speed = torch.norm(batch["agent_hist"][..., 2:4], dim=-1,keepdim=True) + curr_states = torch.cat([pos, yaw, speed], -1)[...,-1,:] + + bs = curr_states.shape[0] + + device = curr_states.device + assert self.M<=4 + if self.scene_centric: + Na = batch["agent_hist"].shape[1] + inputs = torch.zeros([bs,self.M,Na,2]).to(device) + if self.only_branch_closest: + if batch["agent_hist"].shape[1] == 1: + branch_idx = 0 + else: + dis = torch.norm(batch["agent_hist"][0,1:,-1,:2]-batch["agent_hist"][0,:1,-1,:2],dim=-1) + branch_idx = dis.argmin()+1 + else: + branch_idx = 1 if Na>1 else 0 + if self.M>1: + inputs[:,1,branch_idx,1]=-0.02*curr_states[:,branch_idx,2] + if self.M>2: + inputs[:,2,branch_idx,1]=0.02*curr_states[:,branch_idx,2] + if self.M>3: + inputs[:,3,branch_idx,0]=torch.clip(-curr_states[:,branch_idx,2]*0.5,min=-3) + + else: + inputs = torch.zeros([bs,self.M,2]).to(device) + if self.M>1: + inputs[:,1,1]=-0.02*curr_states[:,2] + if self.M>2: + inputs[:,2,1]=0.02*curr_states[:,2] + if self.M>3: + inputs[:,3,0]=torch.clip(-curr_states[:,2]*0.5,min=-3) + assert self.M<=4 + + + inputs = inputs.unsqueeze(-2).repeat_interleave(self.num_frames_per_stage,-2) + + T = self.num_frames_per_stage*self.stage + pred_by_stage = list() + for stage in range(self.stage): + inputs_i = inputs.repeat_interleave(self.M**stage,0) + if stage==0: + curr_state_i = curr_states + else: + curr_state_i = pred_by_stage[stage-1][...,-1,:] + pred_i = self.dyn.forward_dynamics(curr_state_i.repeat_interleave(self.M,0),TensorUtils.join_dimensions(inputs_i,0,2)) + pred_by_stage.append(pred_i) + + for stage in range(self.stage): + if self.scene_centric: + pred_by_stage[stage]=pred_by_stage[stage].reshape([bs,-1,Na,self.num_frames_per_stage,4]).repeat_interleave(self.M**(self.stage-1-stage),1) + else: + pred_by_stage[stage]=pred_by_stage[stage].reshape([bs,-1,self.num_frames_per_stage,4]).repeat_interleave(self.M**(self.stage-1-stage),1) + pred_traj = torch.cat(pred_by_stage,-2) # b x N x T x D + if self.scene_centric: + pred_traj = pred_traj.reshape(bs,self.M**self.stage,Na,T,-1) + prob = torch.ones([bs,self.M**self.stage]).to(device) # b x N + prob = prob/prob.sum(-1,keepdim=True) + if self.scene_centric: + mus = pred_traj[...,[0,1,2,3]].permute(2,0,3,1,4) # (Na, b, T, M**stage, 4) + Tf = pred_traj.shape[-2] + log_pis = torch.log(prob)[None,:,None,:].repeat_interleave(Tf,-2).repeat_interleave(Na,0) + log_sigmas = torch.log(torch.tensor((self.config['dt']*np.arange(Tf+1))[1:]**2*2, dtype=curr_states.dtype, device=curr_states.device)) + log_sigmas = log_sigmas.reshape(1, 1, Tf, 1, 1).repeat((Na, bs, 1, self.M**self.stage, 2)) + corrs = 0. 
* torch.ones((1, bs, Tf, self.M**self.stage), dtype=curr_states.dtype, device=curr_states.device) # TODO not sure what is reasonable + y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) + else: + mus = pred_traj[...,[0,1,2,3]].transpose(1,2).unsqueeze(0) # (1, b, T, M**stage, 4) + Tf = pred_traj.shape[-2] + log_pis = torch.log(prob)[None,:,None,:].repeat_interleave(Tf,-2) + log_sigmas = torch.log(torch.tensor((self.config['dt']*np.arange(Tf+1))[1:]**2*2, dtype=curr_states.dtype, device=curr_states.device)) + log_sigmas = log_sigmas.reshape(1, 1, Tf, 1, 1).repeat((1, bs, 1, self.M**self.stage, 2)) + corrs = 0. * torch.ones((1, bs, Tf, self.M**self.stage), dtype=curr_states.dtype, device=curr_states.device) # TODO not sure what is reasonable + + y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) + return pred_traj, y_dists + + def validation_metrics(self, pred_dist, pred_ml, agent_fut): + # Compute default metrics + metrics = compute_prediction_metrics(pred_ml, agent_fut[..., :2], y_dists=pred_dist) + return metrics \ No newline at end of file diff --git a/diffstack/modules/predictors/trajectron_predictor.py b/diffstack/modules/predictors/trajectron_predictor.py deleted file mode 100644 index 11e8d33..0000000 --- a/diffstack/modules/predictors/trajectron_predictor.py +++ /dev/null @@ -1,315 +0,0 @@ -import itertools -import torch -import numpy as np -from typing import Dict, Optional, Union, Any -from contextlib import nullcontext - -from trajdata.data_structures.batch import AgentBatch -from trajdata.data_structures.agent import AgentType -from trajdata.utils.arr_utils import PadDirection -from diffstack.utils.pred_utils import compute_prediction_metrics -from diffstack.utils.utils import set_all_seeds, restore - -from diffstack.modules.module import Module, DataFormat, RunMode -from diffstack.modules.predictors.trajectron_utils.node_type import NodeTypeEnum -from diffstack.modules.predictors.trajectron_utils.model.mgcvae import MultimodalGenerativeCVAE - -from diffstack.data.cached_nusc_as_trajdata import standardized_manual_state, convert_trajdata_hist_to_manual_hist - - -class TrajectronPredictor(Module): - # TODO (pkarkus) add support for multiple predicted agent types. Add a function that separates metrics per agent type. 
- - @property - def input_format(self) -> DataFormat: - return DataFormat(["agent_batch", "loss_weights:train"]) - - @property - def output_format(self) -> DataFormat: - return DataFormat(["pred_dist", "pred_ml", "loss:train", "metrics:train", "metrics:validate"]) - - def __init__(self, model_registrar, - hyperparams, log_writer, - device, input_mappings = {}): - super().__init__(model_registrar, hyperparams, log_writer, device, input_mappings) - - # Prediction state variables - self.pred_state = self.hyperparams['pred_state'] - self.state = self.hyperparams['state'] - self.state_length = dict() - for state_type in self.state.keys(): - self.state_length[state_type] = int( - np.sum([len(entity_dims) for entity_dims in self.state[state_type].values()]) - ) - assert self.state_length['PEDESTRIAN'] in [6, 7] - - # Validate that hyperparameters are consistent - node_type_state = self.state["VEHICLE"] - if self.hyperparams["pred_ego_indicator"] == "none": - if "augment" in node_type_state and "ego_indicator" in node_type_state["augment"]: - raise ValueError("Inconsistent setting: state variables include ego_indicator but pred_ego_indicator=none") - else: - if "augment" not in node_type_state or "ego_indicator" not in node_type_state["augment"]: - raise ValueError("Inconsistent setting: state variables do not include ego_indicator but pred_ego_indicator != none") - assert np.isclose(hyperparams['dt'] / hyperparams['plan_dt'], int(hyperparams['dt'] / hyperparams['plan_dt'])), "dt must be a multiple of plan_dt" - - node_types = NodeTypeEnum(["VEHICLE", "PEDESTRIAN"]) - edge_types = list(itertools.product(node_types, repeat=2)) - # Build models for each agent type - self.node_models_dict = torch.nn.ModuleDict() - for node_type in node_types: - # Only add a Model for NodeTypes we want to predict - class EnvClass: - dt = self.hyperparams["dt"] - robot_type = "VEHICLE" - self.hyperparams["minimum_history_length"] = 1 - self.hyperparams["maximum_history_length"] = int(self.hyperparams["history_sec"] // self.hyperparams["dt"]) - self.hyperparams["prediction_horizon"] = int(self.hyperparams["prediction_sec"] // self.hyperparams["dt"]) - self.hyperparams["use_map_encoding"] = self.hyperparams["map_encoding"] - self.hyperparams["p_z_x_MLP_dims"] = None if self.hyperparams["p_z_x_MLP_dims"] == 0 else self.hyperparams["p_z_x_MLP_dims"] - self.hyperparams["q_z_xy_MLP_dims"] = None if self.hyperparams["q_z_xy_MLP_dims"] == 0 else self.hyperparams["q_z_xy_MLP_dims"] - - - if node_type.name in self.pred_state.keys(): - self.node_models_dict[node_type.name] = MultimodalGenerativeCVAE( - EnvClass(), - node_type, self.model_registrar, self.hyperparams, self.device, edge_types, - log_writer=(WrappedLogWriter(self.log_writer) if self.log_writer is not None else None) - ) - - def set_curr_iter(self, curr_iter): - super().set_curr_iter(curr_iter) - for node_str, model in self.node_models_dict.items(): - model.set_curr_iter(curr_iter) - - def set_annealing_params(self): - for node_str, model in self.node_models_dict.items(): - model.set_annealing_params() - - def step_annealers(self, node_type=None): - if node_type is None: - for node_type in self.node_models_dict: - self.node_models_dict[node_type].step_annealers() - else: - self.node_models_dict[str(node_type)].step_annealers() - - def train(self, inputs: Dict): - batch: AgentBatch = inputs["agent_batch"] - loss_weights: torch.Tensor = inputs["loss_weights"] - - # Make sure there is only one agent type - agent_types = batch.agent_types() - if len(agent_types) > 1: - 
raise NotImplementedError("Mixing agent types for prediction in a batch is not supported.") - agent_type: AgentType = agent_types[0] - - # Choose model for agent type - model: MultimodalGenerativeCVAE = self.node_models_dict[agent_type.name] - - inputs, inputs_st, first_history_indices, labels, labels_st, neighbors, neighbors_edge_value = self.parse_batch(batch) - - # Compute training loss - with nullcontext() if self.hyperparams["train_pred"] else torch.no_grad(): - loss, y_dist, (x_tensor, ) = model.train_loss( - inputs, - inputs_st, - first_history_indices, - labels, - labels_st, - neighbors, - neighbors_edge_value, - robot=None, - map=None, - prediction_horizon=self.hyperparams["prediction_horizon"], - ret_dist=True, loss_weights=loss_weights) - - return {"pred_dist": y_dist, "loss": loss} - - def validate(self, inputs: Dict): - batch: AgentBatch = inputs['agent_batch'] - pred_ml, pred_dist, extra_output = self._run_prediction(batch, prediction_horizon=batch.agent_fut.shape[1]) - metrics = self.validation_metrics(pred_dist, pred_ml, batch.agent_fut) - return {"pred_dist": pred_dist, "pred_ml": pred_ml, "metrics": metrics} - - def infer(self, inputs: Dict): - batch: AgentBatch = inputs['agent_batch'] - pred_ml, pred_dist, extra_output = self._run_prediction(batch, prediction_horizon=batch.agent_fut.shape[1]) - return {"pred_dist": pred_dist, "pred_ml": pred_ml} - - def validation_metrics(self, pred_dist, pred_ml, agent_fut): - # Compute default metrics - metrics = compute_prediction_metrics(pred_ml, agent_fut[..., :2], y_dists=pred_dist) - return metrics - - def _run_prediction(self, batch: AgentBatch, prediction_horizon: int): - - agent_types = batch.agent_types() - if len(agent_types) > 1: - raise NotImplementedError("Mixing agent types for prediction in a batch is not supported.") - node_type: AgentType = agent_types[0] - - model: MultimodalGenerativeCVAE = self.node_models_dict[node_type.name] - - inputs, inputs_st, first_history_indices, labels, labels_st, neighbors, neighbors_edge_value = self.parse_batch(batch) - - # Run forward pass, use the most likely latent mode. - predictions = model.predict( - inputs, - inputs_st, - first_history_indices, - neighbors, - neighbors_edge_value, - robot=None, - map=None, - prediction_horizon=prediction_horizon, - num_samples=1, - z_mode=True, - gmm_mode=True, - full_dist=False, - output_dists=False) - - # Run forward pass again, but this time output all modes of z. 
- y_dists, _, extra_output = model.predict( - inputs, - inputs_st, - first_history_indices, - neighbors, - neighbors_edge_value, - robot=None, - map=None, - prediction_horizon=prediction_horizon, - num_samples=1, - z_mode=False, - gmm_mode=False, - full_dist=True, - output_dists=True, - output_extra=True) - - return predictions, y_dists, extra_output - - def parse_batch(self, batch: AgentBatch): - """ Converts batch input to trajectron input """ - extended_addition_filter: torch.Tensor = torch.tensor(self.hyperparams['edge_addition_filter'], dtype=torch.float) - extended_addition_filter = torch.nn.functional.pad(extended_addition_filter, (0, batch.agent_hist.shape[1] - extended_addition_filter.shape[0]), mode='constant', value=1.0) - - if batch.history_pad_dir == PadDirection.AFTER: - bs = batch.agent_hist.shape[0] - x = convert_trajdata_hist_to_manual_hist(batch.agent_hist.cpu(), AgentType.VEHICLE, self.hyperparams["dt"]) - x = history_padding_last_to_first(x, batch.agent_hist_len.cpu()) - y = batch.agent_fut[..., :2].cpu() - - x_origin_batch = x[:, -1, :].cpu().numpy() - x_st_t = torch.stack([ - standardized_manual_state(x[bi], x_origin_batch[bi], "VEHICLE", self.hyperparams["dt"], only2d=True) - for bi in range(bs)], dim=0) - y_st_t = torch.stack([ - standardized_manual_state(y[bi], x_origin_batch[bi], "VEHICLE", self.hyperparams["dt"], only2d=True) - for bi in range(bs)], dim=0) - - first_history_index = batch.agent_hist.shape[1] - batch.agent_hist_len.cpu() - neighbors_data_st = {("VEHICLE", "VEHICLE"): [[] for _ in range(bs)], ("VEHICLE", "PEDESTRIAN"): [[] for _ in range(bs)]} - - batch_neigh_hist = batch.neigh_hist.cpu() - for bi in range(bs): - for ni in range(batch.num_neigh[bi]): - nhist = batch_neigh_hist[bi, ni] - agent_type = AgentType(int(batch.neigh_types[bi, ni].cpu())) - nhist = convert_trajdata_hist_to_manual_hist(nhist, agent_type, self.hyperparams["dt"]) - nhist = history_padding_last_to_first(nhist, batch.neigh_hist_len[bi, ni].cpu()) - nhist_st = standardized_manual_state(nhist, x_origin_batch[bi], agent_type.name, self.hyperparams["dt"], only2d=False) - if self.hyperparams["pred_ego_indicator"] != "none": - # augment with ego indicator - nhist_st = torch.nn.functional.pad(nhist_st, (0, 1), value=float(ni == 0)) - neighbors_data_st[("VEHICLE", agent_type.name)][bi].append(nhist_st) - - # Convert edge weights. They are the same for vehciles and pedestriancs. - if self.hyperparams["dynamic_edges"] == "yes": - edge_weight = [batch.extras["neigh_edge_weight"][bi][:batch.num_neigh[bi]] for bi in range(bs)] - neighbors_edge_value = {("VEHICLE", "VEHICLE"): edge_weight, ("VEHICLE", "PEDESTRIAN"): edge_weight} - else: - neighbors_edge_value = {("VEHICLE", "VEHICLE"): None, ("VEHICLE", "PEDESTRIAN"): None} - else: - assert False - - # augment with ego indicator - if self.hyperparams["pred_ego_indicator"] != "none": - x = torch.nn.functional.pad(x, (0, 1), value=0.) - x_st_t = torch.nn.functional.pad(x_st_t, (0, 1), value=0.) - - x = x.to(self.device) - y = y.to(self.device) - x_st_t = x_st_t.to(self.device) - y_st_t = y_st_t.to(self.device) - - return x, x_st_t, first_history_index, y, y_st_t, neighbors_data_st, neighbors_edge_value - - -class TrajectronPredictorWithCacheData(TrajectronPredictor): - - def parse_batch(self, batch: AgentBatch): - (first_history_index, - x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, # dict of lists. edge_type -> [batch][neighbor]: Tensor(time, statedim). 
Represetns - neighbors_edge_value, - robot_traj_st_t, - map, neighbors_future_data, plan_data) = batch.extras["manual_inputs"] - - x = x_t.to(self.device) - y = y_t.to(self.device) - x_st_t = x_st_t.to(self.device) - y_st_t = y_st_t.to(self.device) - if robot_traj_st_t is not None: - robot_traj_st_t = robot_traj_st_t.to(self.device) - if type(map) == torch.Tensor: - map = map.to(self.device) - - # Restore encodings - neighbors_data_st = restore(neighbors_data_st) - neighbors_edge_value = restore(neighbors_edge_value) - neighbors_future_data = restore(neighbors_future_data) - plan_data = restore(plan_data) - - # augment with ego indicator - if self.hyperparams["pred_ego_indicator"] != "none": - ext_neighbor_states = {} - for edge_type, node_neighbor_states in neighbors_data_st.items(): - ext_node_neighbor_states = [] - for batch_i, neighbor_state in enumerate(node_neighbor_states): - ego_ind = plan_data['most_relevant_idx'][batch_i].int() - if edge_type[1] == "VEHICLE": - ext_node_neighbor_states.append([torch.nn.functional.pad(neighbor_state[i], (0, 1), value=(ego_ind == i).float()) for i in range(len(neighbor_state))]) - else: - ext_node_neighbor_states.append([torch.nn.functional.pad(neighbor_state[i], (0, 1), value=0.) for i in range(len(neighbor_state))]) - ext_neighbor_states[edge_type] = ext_node_neighbor_states - neighbors_data_st = ext_neighbor_states - - x = torch.nn.functional.pad(x, (0, 1)) - x_st_t = torch.nn.functional.pad(x_st_t, (0, 1)) - - return x, x_st_t, first_history_index, y, y_st_t, neighbors_data_st, neighbors_edge_value - - -class WrappedLogWriter(): - def __init__(self, wandb_writer) -> None: - self.wandb_writer = wandb_writer - - def add_scalar(self, name, value, iter): - self.wandb_writer.log({name: value}, step=iter, commit=False) - - def add_histogram(self, *args): - pass - - def add_image(self, *args): - pass - - -def history_padding_last_to_first(hist_last, history_len): - if hist_last.ndim > 2: - # recursive to itself - return torch.stack([ - history_padding_last_to_first(hist_last[bi], history_len[bi]) - for bi in range(hist_last.shape[0])], dim=0) - else: - hist_first = torch.full_like(hist_last, 0.) - hist_first[hist_first.shape[0]-history_len:] = hist_last[:history_len] - return hist_first \ No newline at end of file diff --git a/diffstack/modules/predictors/trajectron_utils/LICENSE b/diffstack/modules/predictors/trajectron_utils/LICENSE deleted file mode 100644 index 1d80c56..0000000 --- a/diffstack/modules/predictors/trajectron_utils/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2020 Stanford Autonomous Systems Lab - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/diffstack/modules/predictors/trajectron_utils/README.md b/diffstack/modules/predictors/trajectron_utils/README.md deleted file mode 100644 index c98dfe7..0000000 --- a/diffstack/modules/predictors/trajectron_utils/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Trajectron utilities - -This folder contains code that is partially copied, modified, or extended from [Trajectron++](https://github.com/StanfordASL/Trajectron-plus-plus) diff --git a/diffstack/modules/predictors/trajectron_utils/environment/map.py b/diffstack/modules/predictors/trajectron_utils/environment/map.py index 3e47f3c..4b20ddf 100644 --- a/diffstack/modules/predictors/trajectron_utils/environment/map.py +++ b/diffstack/modules/predictors/trajectron_utils/environment/map.py @@ -1,6 +1,9 @@ import torch import numpy as np -from diffstack.modules.predictors.trajectron_utils.model.dataset.homography_warper import get_rotation_matrix2d, warp_affine_crop +from diffstack.modules.predictors.trajectron_utils.model.dataset.homography_warper import ( + get_rotation_matrix2d, + warp_affine_crop, +) class Map(object): @@ -12,7 +15,7 @@ def __init__(self, data, homography, description=None): def as_image(self): raise NotImplementedError - def get_cropped_maps(self, world_pts, patch_size, rotation=None, device='cpu'): + def get_cropped_maps(self, world_pts, patch_size, rotation=None, device="cpu"): raise NotImplementedError def to_map_points(self, scene_pts): @@ -27,8 +30,9 @@ class GeometricMap(Map): :param data: Numpy array of shape [layers, x, y] :param homography: Numpy array of shape [3, 3] """ + def __init__(self, data, homography, description=None): - #assert isinstance(data.dtype, np.floating), "Geometric Maps must be float values." + # assert isinstance(data.dtype, np.floating), "Geometric Maps must be float values." 
super(GeometricMap, self).__init__(data, homography, description=description) self._last_padding = None @@ -51,11 +55,18 @@ def get_padded_map(self, padding_x, padding_y, device): return self._last_padded_map else: self._last_padding = (padding_x, padding_y) - self._last_padded_map = torch.full((self.data.shape[0], - self.data.shape[1] + 2 * padding_x, - self.data.shape[2] + 2 * padding_y), - False, dtype=torch.uint8) - self._last_padded_map[..., padding_x:-padding_x, padding_y:-padding_y] = self.torch_map(device) + self._last_padded_map = torch.full( + ( + self.data.shape[0], + self.data.shape[1] + 2 * padding_x, + self.data.shape[2] + 2 * padding_y, + ), + False, + dtype=torch.uint8, + ) + self._last_padded_map[ + ..., padding_x:-padding_x, padding_y:-padding_y + ] = self.torch_map(device) return self._last_padded_map @staticmethod @@ -73,13 +84,16 @@ def batch_rotate(map_batched, centers, angles, out_height, out_width): :return: """ M = get_rotation_matrix2d(centers, angles, torch.ones_like(angles)) - rotated_map_batched = warp_affine_crop(map_batched, centers, M, - dsize=(out_height, out_width), padding_mode='zeros') + rotated_map_batched = warp_affine_crop( + map_batched, centers, M, dsize=(out_height, out_width), padding_mode="zeros" + ) return rotated_map_batched @classmethod - def get_cropped_maps_from_scene_map_batch(cls, maps, scene_pts, patch_size, rotation=None, device='cpu'): + def get_cropped_maps_from_scene_map_batch( + cls, maps, scene_pts, patch_size, rotation=None, device="cpu" + ): """ Returns rotated patches of each map around the transformed scene points. ___________________ @@ -111,39 +125,64 @@ def get_cropped_maps_from_scene_map_batch(cls, maps, scene_pts, patch_size, rota context_padding_x = int(np.ceil(np.sqrt(2) * lat_size)) context_padding_y = int(np.ceil(np.sqrt(2) * long_size)) - centers = torch.tensor([s_map.to_map_points(scene_pts[np.newaxis, i]) for i, s_map in enumerate(maps)], - dtype=torch.long, device=device).squeeze(dim=1) \ - + torch.tensor([context_padding_x, context_padding_y], device=device, dtype=torch.long) - - padded_map = [s_map.get_padded_map(context_padding_x, context_padding_y, device=device) for s_map in maps] - - padded_map_batched = torch.stack([padded_map[i][..., - centers[i, 0] - context_padding_x: centers[i, 0] + context_padding_x, - centers[i, 1] - context_padding_y: centers[i, 1] + context_padding_y] - for i in range(centers.shape[0])], dim=0) - - center_patches = torch.tensor([[context_padding_y, context_padding_x]], - dtype=torch.int, - device=device).repeat(batch_size, 1) + centers = torch.tensor( + [ + s_map.to_map_points(scene_pts[np.newaxis, i]) + for i, s_map in enumerate(maps) + ], + dtype=torch.long, + device=device, + ).squeeze(dim=1) + torch.tensor( + [context_padding_x, context_padding_y], device=device, dtype=torch.long + ) + + padded_map = [ + s_map.get_padded_map(context_padding_x, context_padding_y, device=device) + for s_map in maps + ] + + padded_map_batched = torch.stack( + [ + padded_map[i][ + ..., + centers[i, 0] + - context_padding_x : centers[i, 0] + + context_padding_x, + centers[i, 1] + - context_padding_y : centers[i, 1] + + context_padding_y, + ] + for i in range(centers.shape[0]) + ], + dim=0, + ) + + center_patches = torch.tensor( + [[context_padding_y, context_padding_x]], dtype=torch.int, device=device + ).repeat(batch_size, 1) if rotation is not None: angles = torch.Tensor(rotation) else: angles = torch.zeros(batch_size) - rotated_map_batched = cls.batch_rotate(padded_map_batched/255., - 
center_patches.float(), - angles, - long_size, - lat_size) + rotated_map_batched = cls.batch_rotate( + padded_map_batched / 255.0, + center_patches.float(), + angles, + long_size, + lat_size, + ) del padded_map_batched - return rotated_map_batched[..., - long_size_half - patch_size[1]:(long_size_half + patch_size[3]), - lat_size_half - patch_size[0]:(lat_size_half + patch_size[2])] + return rotated_map_batched[ + ..., + long_size_half - patch_size[1] : (long_size_half + patch_size[3]), + lat_size_half - patch_size[0] : (lat_size_half + patch_size[2]), + ] - def get_cropped_maps(self, scene_pts, patch_size, rotation=None, device='cpu'): + def get_cropped_maps(self, scene_pts, patch_size, rotation=None, device="cpu"): """ Returns rotated patches of the map around the transformed scene points. ___________________ @@ -163,8 +202,13 @@ def get_cropped_maps(self, scene_pts, patch_size, rotation=None, device='cpu'): :param device: Device on which the rotated tensors should be returned. :return: Rotated and cropped tensor patches. """ - return self.get_cropped_maps_from_scene_map_batch([self]*scene_pts.shape[0], scene_pts, - patch_size, rotation=rotation, device=device) + return self.get_cropped_maps_from_scene_map_batch( + [self] * scene_pts.shape[0], + scene_pts, + patch_size, + rotation=rotation, + device=device, + ) def to_map_points(self, scene_pts): org_shape = None @@ -180,6 +224,6 @@ def to_map_points(self, scene_pts): return map_points -class ImageMap(Map): # TODO Implement for image maps -> watch flipped coordinate system +class ImageMap(Map): def __init__(self): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py b/diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py index 48639e3..7d47ba0 100644 --- a/diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py +++ b/diffstack/modules/predictors/trajectron_utils/model/components/gmm2d.py @@ -51,7 +51,10 @@ def __init__(self, log_pis, mus, log_sigmas, corrs): dim=-1)], dim=-2) - self.pis_cat_dist = td.Categorical(logits=log_pis) + if log_pis.numel() == 0: + self.pis_cat_dist = None + else: + self.pis_cat_dist = td.Categorical(logits=log_pis) def set_device(self, device): self.device = device @@ -62,7 +65,8 @@ def set_device(self, device): self.one_minus_rho2 = self.one_minus_rho2.to(device) self.corrs = self.corrs.to(device) self.L = self.L.to(device) - self.pis_cat_dist = td.Categorical(logits=self.pis_cat_dist.logits.to(device)) + if self.pis_cat_dist is not None: + self.pis_cat_dist = td.Categorical(logits=self.pis_cat_dist.logits.to(device)) @classmethod def from_log_pis_mus_cov_mats(cls, log_pis, mus, cov_mats): @@ -91,7 +95,11 @@ def rsample(self, sample_shape=torch.Size()): dim=-1) ), dim=-1)) - component_cat_samples = self.pis_cat_dist.sample(sample_shape) + if self.pis_cat_dist is None: + raise NotImplementedError # not tested + component_cat_samples = torch.zeros(sample_shape) + else: + component_cat_samples = self.pis_cat_dist.sample(sample_shape) selector = torch.unsqueeze(to_one_hot(component_cat_samples, self.components), dim=-1) return torch.sum(mvn_samples*selector, dim=-2) @@ -110,7 +118,7 @@ def log_prob(self, value): """ # x: [..., 2] value = torch.unsqueeze(value, dim=-2) # [..., 1, 2] - dx = value - self.mus # [..., N, 2] + dx = value - self.mus[..., :value.shape[-1]] # [..., N, 2] exp_nominator = ((torch.sum((dx/self.sigmas)**2, dim=-1) # first and second term of exp 
nominator - 2*self.corrs*torch.prod(dx, dim=-1)/torch.prod(self.sigmas, dim=-1))) # [..., N] diff --git a/diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py b/diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py index 2302e91..c0da26c 100644 --- a/diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py +++ b/diffstack/modules/predictors/trajectron_utils/model/dataset/dataset.py @@ -9,31 +9,54 @@ except ImportError: from functools import reduce # Required in Python 3 import operator + def prod(iterable): return reduce(operator.mul, iterable, 1) + from .preprocessing import get_node_timestep_data from tqdm import tqdm -from diffstack.modules.predictors.trajectron_utils.environment import EnvironmentMetadata +from diffstack.modules.predictors.trajectron_utils.environment import ( + EnvironmentMetadata, +) from functools import partial from pathos.multiprocessing import ProcessPool as Pool class EnvironmentDataset(object): - def __init__(self, env, state, pred_state, node_freq_mult, scene_freq_mult, hyperparams, **kwargs): + def __init__( + self, + env, + state, + pred_state, + node_freq_mult, + scene_freq_mult, + hyperparams, + **kwargs, + ): self.env = env self.state = state self.pred_state = pred_state self.hyperparams = hyperparams - self.max_ht = self.hyperparams['maximum_history_length'] - self.max_ft = kwargs['min_future_timesteps'] + self.max_ht = self.hyperparams["maximum_history_length"] + self.max_ft = kwargs["min_future_timesteps"] self.node_type_datasets = list() self._augment = False for node_type in env.NodeType: - if node_type not in hyperparams['pred_state']: + if node_type not in hyperparams["pred_state"]: continue - self.node_type_datasets.append(NodeTypeDataset(env, node_type, state, pred_state, node_freq_mult, - scene_freq_mult, hyperparams, **kwargs)) + self.node_type_datasets.append( + NodeTypeDataset( + env, + node_type, + state, + pred_state, + node_freq_mult, + scene_freq_mult, + hyperparams, + **kwargs, + ) + ) @property def augment(self): @@ -49,11 +72,22 @@ def __iter__(self): return iter(self.node_type_datasets) -def parallel_process_scene(scene, env_metadata, node_type, - state, pred_state, edge_types, - max_ht, max_ft, - node_freq_mult, scene_freq_mult, - hyperparams, augment, nusc_maps, kwargs): +def parallel_process_scene( + scene, + env_metadata, + node_type, + state, + pred_state, + edge_types, + max_ht, + max_ft, + node_freq_mult, + scene_freq_mult, + hyperparams, + augment, + nusc_maps, + kwargs, +): results = list() indexing_info = list() @@ -66,90 +100,132 @@ def parallel_process_scene(scene, env_metadata, node_type, scene_aug = scene.augment() node_aug = scene.get_node_by_id(node.id) - scene_data = get_node_timestep_data(env_metadata, scene_aug, t, node_aug, state, pred_state, - edge_types, max_ht, max_ft, hyperparams, nusc_maps) + scene_data = get_node_timestep_data( + env_metadata, + scene_aug, + t, + node_aug, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps, + ) else: - scene_data = get_node_timestep_data(env_metadata, scene, t, node, state, pred_state, - edge_types, max_ht, max_ft, hyperparams, nusc_maps) + scene_data = get_node_timestep_data( + env_metadata, + scene, + t, + node, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps, + ) - results += [( - scene_data, - (scene, t, node) - )] + results += [(scene_data, (scene, t, node))] - indexing_info += [( - scene.frequency_multiplier if scene_freq_mult else 1, - 
node.frequency_multiplier if node_freq_mult else 1 - )] + indexing_info += [ + ( + scene.frequency_multiplier if scene_freq_mult else 1, + node.frequency_multiplier if node_freq_mult else 1, + ) + ] return (results, indexing_info) class NodeTypeDataset(data.Dataset): - def __init__(self, env, node_type, state, pred_state, node_freq_mult, - scene_freq_mult, hyperparams, augment=False, **kwargs): + def __init__( + self, + env, + node_type, + state, + pred_state, + node_freq_mult, + scene_freq_mult, + hyperparams, + augment=False, + **kwargs, + ): self.env = env self.env_metadata = EnvironmentMetadata(env) self.state = state self.pred_state = pred_state self.hyperparams = hyperparams - self.max_ht = self.hyperparams['maximum_history_length'] - self.max_ft = kwargs['min_future_timesteps'] + self.max_ht = self.hyperparams["maximum_history_length"] + self.max_ft = kwargs["min_future_timesteps"] self.augment = augment self.node_type = node_type - self.edge_types = [edge_type for edge_type in env.get_edge_types() if edge_type[0] is node_type] - self.index, self.data, self.data_origin = self.index_env(node_freq_mult, scene_freq_mult, **kwargs) + self.edge_types = [ + edge_type for edge_type in env.get_edge_types() if edge_type[0] is node_type + ] + self.index, self.data, self.data_origin = self.index_env( + node_freq_mult, scene_freq_mult, **kwargs + ) def index_env(self, node_freq_mult, scene_freq_mult, **kwargs): - num_cpus = kwargs['num_workers'] - del kwargs['num_workers'] + num_cpus = kwargs["num_workers"] + del kwargs["num_workers"] - rank = kwargs['rank'] - del kwargs['rank'] + rank = kwargs["rank"] + del kwargs["rank"] if num_cpus > 0: with Pool(num_cpus) as pool: indexed_scenes = list( tqdm( pool.imap( - partial(parallel_process_scene, - env_metadata=self.env_metadata, - node_type=self.node_type, - state=self.state, - pred_state=self.pred_state, - edge_types=self.edge_types, - max_ht=self.max_ht, - max_ft=self.max_ft, - node_freq_mult=node_freq_mult, - scene_freq_mult=scene_freq_mult, - hyperparams=self.hyperparams, - augment=self.augment, - nusc_maps=self.env.nusc_maps, - kwargs=kwargs), - self.env.scenes + partial( + parallel_process_scene, + env_metadata=self.env_metadata, + node_type=self.node_type, + state=self.state, + pred_state=self.pred_state, + edge_types=self.edge_types, + max_ht=self.max_ht, + max_ft=self.max_ft, + node_freq_mult=node_freq_mult, + scene_freq_mult=scene_freq_mult, + hyperparams=self.hyperparams, + augment=self.augment, + nusc_maps=self.env.nusc_maps, + kwargs=kwargs, + ), + self.env.scenes, ), - desc=f'Indexing {self.node_type}s ({num_cpus} CPUs)', + desc=f"Indexing {self.node_type}s ({num_cpus} CPUs)", total=len(self.env.scenes), - disable=(rank > 0) + disable=(rank > 0), ) ) else: - indexed_scenes = [parallel_process_scene(scene, - env_metadata=self.env_metadata, - node_type=self.node_type, - state=self.state, - pred_state=self.pred_state, - edge_types=self.edge_types, - max_ht=self.max_ht, - max_ft=self.max_ft, - node_freq_mult=node_freq_mult, - scene_freq_mult=scene_freq_mult, - hyperparams=self.hyperparams, - augment=self.augment, - nusc_maps=self.env.nusc_maps, - kwargs=kwargs) for scene in self.env.scenes] + indexed_scenes = [ + parallel_process_scene( + scene, + env_metadata=self.env_metadata, + node_type=self.node_type, + state=self.state, + pred_state=self.pred_state, + edge_types=self.edge_types, + max_ht=self.max_ht, + max_ft=self.max_ft, + node_freq_mult=node_freq_mult, + scene_freq_mult=scene_freq_mult, + hyperparams=self.hyperparams, + 
augment=self.augment, + nusc_maps=self.env.nusc_maps, + kwargs=kwargs, + ) + for scene in self.env.scenes + ] results = list() indexing_info = list() @@ -161,7 +237,7 @@ def index_env(self, node_freq_mult, scene_freq_mult, **kwargs): for i, counts in enumerate(indexing_info): total = prod(counts) - index += [i]*total + index += [i] * total data, data_origin = zip(*results) @@ -172,25 +248,46 @@ def __len__(self): def preprocess_online(self, ind): batch = self.data[ind] - (first_history_index, x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, neighbors_edge_value, robot_traj_st_t, - map_input, neighbors_future_data, plan_data) = batch - - if 'map_name' not in plan_data: + ( + first_history_index, + x_t, + y_t, + x_st_t, + y_st_t, + neighbors_data_st, + neighbors_edge_value, + robot_traj_st_t, + map_input, + neighbors_future_data, + plan_data, + ) = batch + + if "map_name" not in plan_data: scene, _, _ = self.data_origin[ind] - plan_data['map_name'] = str(scene.map_name) - plan_data['scene_offset'] = torch.Tensor([scene.x_min, scene.y_min], device='cpu').float() - if 'most_relevant_nearby_lane_tokens' not in plan_data: - plan_data['most_relevant_nearby_lane_tokens'] = None - - return (first_history_index, x_t, y_t, x_st_t, y_st_t, - neighbors_data_st, neighbors_edge_value, robot_traj_st_t, - map_input, neighbors_future_data, plan_data) + plan_data["map_name"] = str(scene.map_name) + plan_data["scene_offset"] = torch.Tensor( + [scene.x_min, scene.y_min], device="cpu" + ).float() + if "most_relevant_nearby_lane_tokens" not in plan_data: + plan_data["most_relevant_nearby_lane_tokens"] = None + + return ( + first_history_index, + x_t, + y_t, + x_st_t, + y_st_t, + neighbors_data_st, + neighbors_edge_value, + robot_traj_st_t, + map_input, + neighbors_future_data, + plan_data, + ) def __getitem__(self, i): - # TODO (pkarkus) this seems to lead to memory leak # https://pytorch.org/docs/master/data.html#torch.utils.data.distributed.DistributedSampler - # https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662 + # https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662 # return self.data[self.index[i]] return self.preprocess_online(self.index[i]) @@ -200,7 +297,18 @@ def filter(self, filter_fn, verbose=False): len_start = len(self) if filter_fn is not None: - self.index = np.fromiter((i for i in self.index if filter_fn(self.data[i])), dtype=self.index.dtype) + self.index = np.fromiter( + (i for i in self.index if filter_fn(self.data[i])), + dtype=self.index.dtype, + ) if verbose: - print ("Filter: kept %d/%d (%.1f%%) of samples. Filtering took %.1fs."%(len(self), len_start, len(self)/len_start*100., time.time() - tstart)) + print( + "Filter: kept %d/%d (%.1f%%) of samples. Filtering took %.1fs." + % ( + len(self), + len_start, + len(self) / len_start * 100.0, + time.time() - tstart, + ) + ) diff --git a/diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py b/diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py index a1626ff..a110435 100644 --- a/diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py +++ b/diffstack/modules/predictors/trajectron_utils/model/dataset/preprocessing.py @@ -9,19 +9,24 @@ container_abcs = collections.abc + # Wrapper around dict to identify batchable dict data. 
class batchable_dict(dict): pass + class batchable_list(list): pass + class batchable_nonuniform_tensor(torch.Tensor): pass + def np_unstack(a, axis=0): return np.moveaxis(a, axis, 0) + def restore(data): """ In case we dilled some structures to share between multiple process this function will restore them. @@ -41,20 +46,27 @@ def collate(batch): elem = batch[0] if elem is None: return None - elif isinstance(elem, str) or isinstance(elem, batchable_list) or isinstance(elem, batchable_nonuniform_tensor): - # TODO isinstance(elem, batchable_nonuniform_tensor) is never true, perhaps some import path comparison issue + elif ( + isinstance(elem, str) + or isinstance(elem, batchable_list) + or isinstance(elem, batchable_nonuniform_tensor) + ): return dill.dumps(batch) if torch.utils.data.get_worker_info() else batch elif isinstance(elem, container_abcs.Sequence): - if len(elem) == 4: # We assume those are the maps, map points, headings and patch_size + if ( + len(elem) == 4 + ): # We assume those are the maps, map points, headings and patch_size scene_map, scene_pts, heading_angle, patch_size = zip(*batch) if heading_angle[0] is None: heading_angle = None else: heading_angle = torch.Tensor(heading_angle) - map = scene_map[0].get_cropped_maps_from_scene_map_batch(scene_map, - scene_pts=torch.Tensor(scene_pts), - patch_size=patch_size[0], - rotation=heading_angle) + map = scene_map[0].get_cropped_maps_from_scene_map_batch( + scene_map, + scene_pts=torch.Tensor(scene_pts), + patch_size=patch_size[0], + rotation=heading_angle, + ) return map transposed = zip(*batch) return [collate(samples) for samples in transposed] @@ -62,40 +74,45 @@ def collate(batch): # We dill the dictionary for the same reason as the neighbors structure (see below). # Unlike for neighbors where we keep a list, here we collate elements recursively data_dict = {key: collate([d[key] for d in batch]) for key in elem} - return dill.dumps(data_dict) if torch.utils.data.get_worker_info() else data_dict + return ( + dill.dumps(data_dict) if torch.utils.data.get_worker_info() else data_dict + ) elif isinstance(elem, container_abcs.Mapping): # We have to dill the neighbors structures. Otherwise each tensor is put into # shared memory separately -> slow, file pointer overhead # we only do this in multiprocessing neighbor_dict = {key: [d[key] for d in batch] for key in elem} - return dill.dumps(neighbor_dict) if torch.utils.data.get_worker_info() else neighbor_dict + return ( + dill.dumps(neighbor_dict) + if torch.utils.data.get_worker_info() + else neighbor_dict + ) try: return default_collate(batch) except RuntimeError: # This happens when tensors are not of the same shape. return dill.dumps(batch) if torch.utils.data.get_worker_info() else batch -def get_relevant_lanes_np(nusc_map, scene_offset, position, yaw, vel, dt): - # TODO we should get the lanes from three anchor points: +def get_relevant_lanes_np(nusc_map, scene_offset, position, yaw, vel, dt): # - max deceleration # - [constant speed] # - max acceleration # Normally we should do this along the current lane -- but if that's hard to - # determin reliably, an alternative is to consider both ego-centric and lane centric ways. + # determin reliably, an alternative is to consider both ego-centric and lane centric ways. # Simple alternative: use goal state for lane selection. - # This is biased e.g. when stopped at a traffic light, and we will miss on the + # This is biased e.g. 
when stopped at a traffic light, and we will miss on the # lanes across the intersection when gt was stopped, and we will miss on the current - # nearby lanes when gt accelerates rapidly. + # nearby lanes when gt accelerates rapidly. interp_vel = max(abs(vel), 0.2) - x = position[0]+scene_offset[0] - y = position[1]+scene_offset[1] + x = position[0] + scene_offset[0] + y = position[1] + scene_offset[1] state = np.hstack((position, yaw)) # t = time.time() - lanes = nusc_map.get_records_in_radius(x, y, 20.0, ['lane', 'lane_connector']) - lanes = lanes['lane'] + lanes['lane_connector'] + lanes = nusc_map.get_records_in_radius(x, y, 20.0, ["lane", "lane_connector"]) + lanes = lanes["lane"] + lanes["lane_connector"] # t_near = time.time() - t # t = time.time() relevant_lanes = list() @@ -109,7 +126,7 @@ def get_relevant_lanes_np(nusc_map, scene_offset, position, yaw, vel, dt): poses = np.array(poses) poses[:, 0:2] -= scene_offset delta_x, delta_y, dpsi = batch_proj(state, poses) - if abs(dpsi[0]) < np.pi/4 and np.min(np.abs(delta_y)) < 4.5: + if abs(dpsi[0]) < np.pi / 4 and np.min(np.abs(delta_y)) < 4.5: relevant_lanes.append(poses) relevant_lane_tokens.append(str(lane)) relevant_lane_arclines.append(lane_arcline) @@ -118,15 +135,12 @@ def get_relevant_lanes_np(nusc_map, scene_offset, position, yaw, vel, dt): def get_relative_robot_traj(env, state, node_traj, robot_traj, node_type, robot_type): - # TODO: We will have to make this more generic if robot_type != node_type # Make Robot State relative to node _, std = env.get_standardize_params(state[robot_type], node_type=robot_type) std[0:2] = env.attention_radius[(node_type, robot_type)] - robot_traj_st = env.standardize(robot_traj, - state[robot_type], - node_type=robot_type, - mean=node_traj, - std=std) + robot_traj_st = env.standardize( + robot_traj, state[robot_type], node_type=robot_type, mean=node_traj, std=std + ) robot_traj_st_t = torch.tensor(robot_traj_st, dtype=torch.float) return robot_traj_st_t @@ -137,11 +151,10 @@ def pred_state_to_plan_state(pred_state): input: x, y, vx, vy, ax, ay, heading, delta_heading output: x, y, heading, v, acc, delta_heading """ - x, y, vx, vy, ax, ay, h, dh = np_unstack(pred_state, -1) + x, y, vx, vy, ax, ay, h, dh = np_unstack(pred_state, -1) v = np.linalg.norm(np.stack((vx, vy), axis=-1), axis=-1, keepdims=False) - # TODO some dataset versions after v6 set USE_KALMAN_FILTER=True which computes acceleration differently. 
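# A minimal, standalone numpy sketch of the pred-state -> plan-state conversion
# performed here: speed and acceleration are collapsed to norms, so the original
# (ax, ay) split cannot be recovered from the planning state. The numbers below
# are made up purely for illustration.
import numpy as np

pred_state = np.array([10.0, 5.0, 3.0, 4.0, 0.6, 0.8, 0.93, 0.02])  # x, y, vx, vy, ax, ay, h, dh
x, y, vx, vy, ax, ay, h, dh = pred_state
v = np.hypot(vx, vy)                        # speed norm -> 5.0
a = np.hypot(ax, ay)                        # acceleration norm -> 1.0 (direction is lost)
plan_state = np.array([x, y, h, v, a, dh])  # -> [10.0, 5.0, 0.93, 5.0, 1.0, 0.02]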
- a = np.linalg.norm(np.stack((ax, ay), axis=-1), axis=-1, keepdims=False) - # a = np.divide(ax * vx + ay * vy, v, out=np.zeros_like(ax), where=(v > 1.)) + a = np.linalg.norm(np.stack((ax, ay), axis=-1), axis=-1, keepdims=False) + # a = np.divide(ax * vx + ay * vy, v, out=np.zeros_like(ax), where=(v > 1.)) plan_state = np.stack([x, y, h, v, a, dh], axis=-1) # assert np.isclose(pred_state, plan_state_to_pred_state(plan_state)).all() # accelerations mismatch, the calculation of ax and ay cannot be recovered from a_norm and heading @@ -169,14 +182,14 @@ def get_node_closest_to_robot(scene, t, node_type=None, nodes=None): :param scene: Scene :param t: Timestep in scene """ - get_pose = lambda n: n.get(np.array([t, t]), {'position': ['x', 'y']}, padding=0.0) + get_pose = lambda n: n.get(np.array([t, t]), {"position": ["x", "y"]}, padding=0.0) node_dist = lambda a, b: np.linalg.norm(get_pose(a) - get_pose(b)) robot_node = scene.robot closest_node = None closest_dist = None - if nodes is None: + if nodes is None: nodes = scene.nodes for node in nodes: @@ -192,10 +205,22 @@ def get_node_closest_to_robot(scene, t, node_type=None, nodes=None): return closest_node -def get_node_timestep_data(env, scene, t, node, state, pred_state, - edge_types, max_ht, max_ft, hyperparams, - nusc_maps, scene_graph=None, - is_closed_loop=False, closed_loop_ego_hist=None): +def get_node_timestep_data( + env, + scene, + t, + node, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps, + scene_graph=None, + is_closed_loop=False, + closed_loop_ego_hist=None, +): """ Pre-processes the data for a single batch element: node state over time for a specific time in a specific scene as well as the neighbour data for it. @@ -220,25 +245,51 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, timestep_range_plan = np.array([t, t + max_ft]) plan_vehicle_state_dict = { - 'position': ['x', 'y'], 'heading': ['°'], 'velocity': ['norm']} - plan_vehicle_features_dict = { - 'heading': ['d°'], 'acceleration': ['norm2'], # control - 'lane': ['x', 'y', '°'], 'projected': ['x', 'y'], # lane and lane-projected states - 'control_traj_dh': ["t"+str(i) for i in range(max_ft+1)], # future fitted controls - 'control_traj_a': ["t"+str(i) for i in range(max_ft+1)], - # 'control_goal_dh': ["t"+str(i) for i in range(max_ft+1)], # future fitted controls - # 'control_goal_a': ["t"+str(i) for i in range(max_ft+1)], + "position": ["x", "y"], + "heading": ["°"], + "velocity": ["norm"], + } + if hyperparams["dataset_version"] == "v2": + plan_vehicle_features_dict = { + "heading": ["d°"], + "acceleration": ["norm2"], # control + "lane": ["x", "y", "°"], + "projected": ["x", "y"], # lane and lane-projected states + "control_dh": [ + "t" + str(i) for i in range(max_ft + 1) + ], # future fitted controls + "control_a": ["t" + str(i) for i in range(max_ft + 1)], + } + else: + plan_vehicle_features_dict = { + "heading": ["d°"], + "acceleration": ["norm2"], # control + "lane": ["x", "y", "°"], + "projected": ["x", "y"], # lane and lane-projected states + "control_traj_dh": [ + "t" + str(i) for i in range(max_ft + 1) + ], # future fitted controls + "control_traj_a": ["t" + str(i) for i in range(max_ft + 1)], + # 'control_goal_dh': ["t"+str(i) for i in range(max_ft+1)], # future fitted controls + # 'control_goal_a': ["t"+str(i) for i in range(max_ft+1)], } - plan_vehicle_features_dict_old = { - 'heading': ['d°'], 'acceleration': ['norm2'], # control - 'lane': ['x', 'y', '°'], 'projected': ['x', 'y'], # lane and 
lane-projected states - 'control_dh': [i for i in range(max_ft+1)], # future fitted controls - 'control_a': [i for i in range(max_ft+1)], - } - plan_pedestrian_state_dict = {'position': ['x', 'y'],} + plan_vehicle_features_dict_old = { + "heading": ["d°"], + "acceleration": ["norm2"], # control + "lane": ["x", "y", "°"], + "projected": ["x", "y"], # lane and lane-projected states + "control_dh": [i for i in range(max_ft + 1)], # future fitted controls + "control_a": [i for i in range(max_ft + 1)], + } + plan_pedestrian_state_dict = { + "position": ["x", "y"], + } # Filter fields not in data - state = {nk: {k: v for k, v in ndict.items() if k != "augment"} for nk, ndict in state.items()} + state = { + nk: {k: v for k, v in ndict.items() if k != "augment"} + for nk, ndict in state.items() + } x = node.get(timestep_range_x, state[node.type]) y = node.get(timestep_range_y, pred_state[node.type]) @@ -252,7 +303,9 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, rel_state = np.zeros_like(x[0]) rel_state[0:2] = x_origin[0:2] x_st = env.standardize(x, state[node.type], node.type, mean=rel_state, std=std) - if list(pred_state[node.type].keys())[0] == 'position': # If we predict position we do it relative to current pos + if ( + list(pred_state[node.type].keys())[0] == "position" + ): # If we predict position we do it relative to current pos y_st = env.standardize(y, pred_state[node.type], node.type, mean=rel_state[0:2]) else: y_st = env.standardize(y, pred_state[node.type], node.type) @@ -269,12 +322,18 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, neighbors_edge_value = None neighbors_future_data = None plan_data = None - if hyperparams['edge_encoding']: + if hyperparams["edge_encoding"]: # Scene Graph - scene_graph = scene.get_scene_graph(t, - env.attention_radius, - hyperparams['edge_addition_filter'], - hyperparams['edge_removal_filter']) if scene_graph is None else scene_graph + scene_graph = ( + scene.get_scene_graph( + t, + env.attention_radius, + hyperparams["edge_addition_filter"], + hyperparams["edge_removal_filter"], + ) + if scene_graph is None + else scene_graph + ) neighbors_data_not_st = dict() # closed loop logged_robot_data = None # closed loop @@ -292,76 +351,123 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, # We get all nodes which are connected to the current node for the current timestep connected_nodes = scene_graph.get_neighbors(node, edge_type[1]) - if hyperparams['dynamic_edges'] == 'yes': + if hyperparams["dynamic_edges"] == "yes": # We get the edge masks for the current node at the current timestep - edge_masks = torch.tensor(scene_graph.get_edge_scaling(node), dtype=torch.float) + edge_masks = torch.tensor( + scene_graph.get_edge_scaling(node), dtype=torch.float + ) neighbors_edge_value[edge_type] = edge_masks for n_i, connected_node in enumerate(connected_nodes): - neighbor_state_np = connected_node.get(timestep_range_x, - state[connected_node.type], - padding=0.0) + neighbor_state_np = connected_node.get( + timestep_range_x, state[connected_node.type], padding=0.0 + ) # Closed loop, replace ego trajectory if is_closed_loop and connected_node.is_robot: # assert logged_robot_data is None, "We can only replace robot trajectory once, and we already did that." 
# We expect to get here twice, once for (VEHICLE, ROBOT) and once for (PEDESTRIAN, ROBOT) - logged_robot_data = torch.tensor(neighbor_state_np, dtype=torch.float) + logged_robot_data = torch.tensor( + neighbor_state_np, dtype=torch.float + ) if closed_loop_ego_hist is not None: - assert neighbor_state_np.shape[0] <= closed_loop_ego_hist.shape[0] - assert neighbor_state_np.shape[1] == closed_loop_ego_hist.shape[1] - neighbor_state_np = closed_loop_ego_hist.copy() - + assert ( + neighbor_state_np.shape[0] <= closed_loop_ego_hist.shape[0] + ) + assert ( + neighbor_state_np.shape[1] == closed_loop_ego_hist.shape[1] + ) + neighbor_state_np = closed_loop_ego_hist.copy() + # Make State relative to node where neighbor and node have same state - _, std = env.get_standardize_params(state[connected_node.type], node_type=connected_node.type) + _, std = env.get_standardize_params( + state[connected_node.type], node_type=connected_node.type + ) std[0:2] = env.attention_radius[edge_type] equal_dims = np.min((neighbor_state_np.shape[-1], x.shape[-1])) rel_state = np.zeros_like(neighbor_state_np) rel_state[:, ..., :equal_dims] = x_origin[..., :equal_dims] - neighbor_state_np_st = env.standardize(neighbor_state_np, - state[connected_node.type], - node_type=connected_node.type, - mean=rel_state, - std=std) + neighbor_state_np_st = env.standardize( + neighbor_state_np, + state[connected_node.type], + node_type=connected_node.type, + mean=rel_state, + std=std, + ) neighbor_state = torch.tensor(neighbor_state_np_st, dtype=torch.float) neighbors_data_st[edge_type].append(neighbor_state) if is_closed_loop: - neighbors_data_not_st[edge_type].append(torch.tensor(neighbor_state_np, dtype=torch.float)) + neighbors_data_not_st[edge_type].append( + torch.tensor(neighbor_state_np, dtype=torch.float) + ) # Add future states for all neighbors. Standardize with same origin and std. 
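# For reference, a hedged sketch of the per-neighbor vehicle future-feature layout
# assembled in the block below (column counts are inferred from the state/feature
# dicts defined above; the exact tail depends on the dataset version, e.g. whether
# lane reference points are appended):
ph = 6                              # hypothetical prediction horizon (max_ft)
state_cols = 4                      # x, y, heading, speed
control_cols = 2                    # d_heading, acceleration norm
lane_cols = 3                       # lane x, y, heading
proj_cols = 2                       # lane-projected x, y
control_traj_cols = 2 * (ph + 1)    # fitted dh and a controls for t0..t_ph
nan_checked_cols = state_cols + control_cols + lane_cols  # == 9, the [:, :(4 + 2 + 3)] slice used further down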
- if edge_type[1] == "VEHICLE": - assert connected_node.type == "VEHICLE" - - try: - neighbor_future_features_np = np.concatenate(( - # x, y, orient, vel - connected_node.get(timestep_range_plan, - plan_vehicle_state_dict, - padding=np.nan), - # d_orient, acc_norm | lane x y heading | projected x y | future_controls steer*ph + acc*ph - connected_node.get(timestep_range_plan, - plan_vehicle_features_dict, - padding=np.nan), - ), axis=-1) + if edge_type[1] == "VEHICLE": + assert connected_node.type == "VEHICLE" + + try: + neighbor_future_features_np = np.concatenate( + ( + # x, y, orient, vel + connected_node.get( + timestep_range_plan, + plan_vehicle_state_dict, + padding=np.nan, + ), + # d_orient, acc_norm | lane x y heading | projected x y | future_controls steer*ph + acc*ph + connected_node.get( + timestep_range_plan, + plan_vehicle_features_dict, + padding=np.nan, + ), + ), + axis=-1, + ) except KeyError: - neighbor_future_features_np = np.concatenate(( - # x, y, orient, vel - connected_node.get(timestep_range_plan, - plan_vehicle_state_dict, - padding=np.nan), - # d_orient, acc_norm | lane x y heading | projected x y | future_controls steer*ph + acc*ph - connected_node.get(timestep_range_plan, - plan_vehicle_features_dict_old, - padding=np.nan), - ), axis=-1) - - if is_closed_loop and closed_loop_ego_hist is not None and connected_node.is_robot: - neighbor_future_features_np[0, :6] = pred_state_to_plan_state(closed_loop_ego_hist[-1]) - + neighbor_future_features_np = np.concatenate( + ( + # x, y, orient, vel + connected_node.get( + timestep_range_plan, + plan_vehicle_state_dict, + padding=np.nan, + ), + # d_orient, acc_norm | lane x y heading | projected x y | future_controls steer*ph + acc*ph + connected_node.get( + timestep_range_plan, + plan_vehicle_features_dict_old, + padding=np.nan, + ), + ), + axis=-1, + ) + + if ( + is_closed_loop + and closed_loop_ego_hist is not None + and connected_node.is_robot + ): + neighbor_future_features_np[0, :6] = pred_state_to_plan_state( + closed_loop_ego_hist[-1] + ) + # Add lane points - lane_ref_points = connected_node.get_lane_points(timestep_range_plan, padding=np.nan, num_lane_points=16) - neighbor_future_features_np = np.concatenate((neighbor_future_features_np, lane_ref_points.reshape((lane_ref_points.shape[0], 16*3))), axis=1) + if hyperparams["dataset_version"] == "v2": + pass + else: + lane_ref_points = connected_node.get_lane_points( + timestep_range_plan, padding=np.nan, num_lane_points=16 + ) + neighbor_future_features_np = np.concatenate( + ( + neighbor_future_features_np, + lane_ref_points.reshape( + (lane_ref_points.shape[0], 16 * 3) + ), + ), + axis=1, + ) # # Assert accelearation norm is correct # acc = connected_node.get(timestep_range_plan, {'acceleration': ['x', 'y']}) @@ -382,7 +488,7 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, # if np.any(np.nan_to_num(lane_dist, 0.) 
>= 50.): # print (lane_dist) # pass - + # is_robot_vect = np.ones((neighbor_future_features_np.shape[0], 1)) * float(connected_node.is_robot) # neighbor_future_features_np = np.concatenate((neighbor_future_features_np, is_robot_vect), axis=-1) # is_robot = torch.tensor([float(connected_node.is_robot)], dtype=torch.float) @@ -390,52 +496,75 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, robot_neighbor = torch.Tensor([n_i]).int() # Check if car is parked, i.e., it hasnt moved more more than 1m - xy_first = connected_node.data[0, {'position': ['x', 'y']}].astype(np.float32) - xy_last = connected_node.data[connected_node.last_timestep-connected_node.first_timestep, {'position': ['x', 'y']}].astype(np.float32) - is_neighbor_parked[edge_type].append(np.square(xy_first - xy_last).sum(0) < 1.**2) + xy_first = connected_node.data[0, {"position": ["x", "y"]}].astype( + np.float32 + ) + xy_last = connected_node.data[ + connected_node.last_timestep - connected_node.first_timestep, + {"position": ["x", "y"]}, + ].astype(np.float32) + is_neighbor_parked[edge_type].append( + np.square(xy_first - xy_last).sum(0) < 1.0**2 + ) elif edge_type[1] == "PEDESTRIAN": - neighbor_future_features_np = connected_node.get(timestep_range_plan, - plan_pedestrian_state_dict, - padding=0.0) # x, y | lane x y heading | projected x y + neighbor_future_features_np = connected_node.get( + timestep_range_plan, plan_pedestrian_state_dict, padding=0.0 + ) # x, y | lane x y heading | projected x y else: - raise ValueError("Unknown type {}".format(edge_type[1])) - neighbor_future_features = torch.tensor(neighbor_future_features_np, dtype=torch.float) + raise ValueError("Unknown type {}".format(edge_type[1])) + neighbor_future_features = torch.tensor( + neighbor_future_features_np, dtype=torch.float + ) neighbors_future_data[edge_type].append(neighbor_future_features) - assert hyperparams['dt'] == scene.dt # we will rely on this hyperparam for delta_t, make sure its correct - - # Find closest neighbor for planning + assert ( + hyperparams["dt"] == scene.dt + ) # we will rely on this hyperparam for delta_t, make sure its correct + + # Find closest neighbor for planning # if 'planner' in hyperparams and hyperparams['planner'] and edge_type[1] == 'VEHICLE': - if edge_type[1] == 'VEHICLE': - vehicle_future_f = neighbors_future_data[(node.type, 'VEHICLE')] - is_parked = is_neighbor_parked[(node.type, 'VEHICLE')] - + if edge_type[1] == "VEHICLE": + vehicle_future_f = neighbors_future_data[(node.type, "VEHICLE")] + is_parked = is_neighbor_parked[(node.type, "VEHICLE")] + # Find most relevant agent dists = [] inds = [] for n_i in range(len(vehicle_future_f)): # Filter incomplete future states, controls or missing lanes - if torch.isnan(vehicle_future_f[n_i][:, :(4+2+3)]).any(): + if torch.isnan(vehicle_future_f[n_i][:, : (4 + 2 + 3)]).any(): continue # Filter parked cars for v7 only. 
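# A simplified, self-contained sketch of the "most relevant neighbor" selection
# performed below: among vehicle neighbors with a complete future (and, for most
# dataset versions, not parked), pick the one whose future positions come closest
# to the node's ground-truth future at any timestep. The tensors here are toy
# values, not real batch data.
import torch

def most_relevant_neighbor(y_t, neighbor_futures, is_parked):
    # y_t: (T, 2) ground-truth future of the predicted node
    # neighbor_futures: list of (T + 1, 2) future positions per vehicle neighbor
    best_i, best_d = -1, float("inf")
    for i, fut in enumerate(neighbor_futures):
        if torch.isnan(fut).any() or is_parked[i]:
            continue  # skip incomplete futures and parked vehicles
        d = torch.square(y_t - fut[1:, :2]).sum(dim=-1).amin()  # closest approach over time
        if d < best_d:
            best_i, best_d = i, float(d)
    return best_i  # -1 when no usable neighbor exists

y_t = torch.zeros(6, 2)
futures = [torch.ones(7, 2), torch.full((7, 2), 5.0)]
print(most_relevant_neighbor(y_t, futures, is_parked=[False, False]))  # -> 0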
- if is_parked[n_i]: - continue - - dist = torch.square(y_t - vehicle_future_f[n_i][1:, :2]) # [1:] exclude current state for vehicle future - dist = dist.sum(dim=-1).amin(dim=-1) # sum over states, min over time + if hyperparams["dataset_version"] not in ["v7"]: + if is_parked[n_i]: + continue + + dist = torch.square( + y_t - vehicle_future_f[n_i][1:, :2] + ) # [1:] exclude current state for vehicle future + dist = dist.sum(dim=-1).amin( + dim=-1 + ) # sum over states, min over time inds.append(n_i) dists.append(dist) - + if dists: - plan_i = inds[torch.argmin(torch.stack(dists))] # neighbor that gets closest to current node + plan_i = inds[ + torch.argmin(torch.stack(dists)) + ] # neighbor that gets closest to current node else: # No neighbors or all futures are incomplete plan_i = -1 # Robot index, filter incomplete futures, controls or missing lanes - if robot_neighbor >= 0 and torch.isnan(vehicle_future_f[robot_neighbor][:, :(4+2+3)]).any(): + if ( + robot_neighbor >= 0 + and torch.isnan( + vehicle_future_f[robot_neighbor][:, : (4 + 2 + 3)] + ).any() + ): robot_i = -1 else: robot_i = robot_neighbor @@ -445,13 +574,26 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, plan_i = robot_i # Get nearby lanes for most_relevant neighbor (used for trajectroy fan planner) - if plan_i >= 0: + if plan_i >= 0: nusc_map = nusc_maps[scene.map_name] # Relevant lanes that are near the goal state x_goal_np = vehicle_future_f[plan_i][-1, :4].numpy() # x,y,h,v - pos_xy_goal, yaw_goal, vel_goal = np.split(x_goal_np, (2, 3), axis=-1) - relevant_lanes, relevant_lane_tokens, relevant_lane_arclines = get_relevant_lanes_np(nusc_map, scene_offset_np, pos_xy_goal, yaw_goal, vel_goal, hyperparams['dt']) + pos_xy_goal, yaw_goal, vel_goal = np.split( + x_goal_np, (2, 3), axis=-1 + ) + ( + relevant_lanes, + relevant_lane_tokens, + relevant_lane_arclines, + ) = get_relevant_lanes_np( + nusc_map, + scene_offset_np, + pos_xy_goal, + yaw_goal, + vel_goal, + hyperparams["dt"], + ) else: relevant_lanes = [] relevant_lane_tokens = [] @@ -460,19 +602,23 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, plan_data = batchable_dict( most_relevant_idx=torch.Tensor([int(plan_i)]).int().squeeze(0), robot_idx=torch.Tensor([int(robot_i)]).int().squeeze(0), - most_relevant_nearby_lanes=batchable_list([torch.from_numpy(pts).float() for pts in relevant_lanes]), - most_relevant_nearby_lane_tokens=batchable_list(relevant_lane_tokens), + most_relevant_nearby_lanes=batchable_list( + [torch.from_numpy(pts).float() for pts in relevant_lanes] + ), + most_relevant_nearby_lane_tokens=batchable_list( + relevant_lane_tokens + ), map_name=str(scene.map_name), scene_offset=torch.from_numpy(scene_offset_np), ) else: - assert hyperparams['planner'] in ["", "none"] + assert hyperparams["planner"] in ["", "none"] plan_data = None # Robot robot_traj_st_t = None timestep_range_r = np.array([t, t + max_ft]) - if hyperparams['incl_robot_node']: + if hyperparams["incl_robot_node"]: x_node = node.get(timestep_range_r, state[node.type]) if scene.non_aug_scene is not None: robot = scene.get_node_by_id(scene.non_aug_scene.robot.id) @@ -480,22 +626,31 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, robot = scene.robot robot_type = robot.type robot_traj = robot.get(timestep_range_r, state[robot_type], padding=np.nan) - robot_traj_st_t = get_relative_robot_traj(env, state, x_node, robot_traj, node.type, robot_type) + robot_traj_st_t = get_relative_robot_traj( + env, state, x_node, robot_traj, node.type, 
robot_type + ) robot_traj_st_t[torch.isnan(robot_traj_st_t)] = 0.0 # Map map_tuple = None - if hyperparams['use_map_encoding']: - if node.type in hyperparams['map_encoder']: + if hyperparams["use_map_encoding"]: + if node.type in hyperparams["map_encoder"]: if node.non_aug_node is not None: x = node.non_aug_node.get(np.array([t]), state[node.type]) - me_hyp = hyperparams['map_encoder'][node.type] - if 'heading_state_index' in me_hyp: - heading_state_index = me_hyp['heading_state_index'] + me_hyp = hyperparams["map_encoder"][node.type] + if "heading_state_index" in me_hyp: + heading_state_index = me_hyp["heading_state_index"] # We have to rotate the map in the opposit direction of the agent to match them - if type(heading_state_index) is list: # infer from velocity or heading vector - heading_angle = -np.arctan2(x[-1, heading_state_index[1]], - x[-1, heading_state_index[0]]) * 180 / np.pi + if ( + type(heading_state_index) is list + ): # infer from velocity or heading vector + heading_angle = ( + -np.arctan2( + x[-1, heading_state_index[1]], x[-1, heading_state_index[0]] + ) + * 180 + / np.pi + ) else: heading_angle = -x[-1, heading_state_index] * 180 / np.pi else: @@ -504,18 +659,41 @@ def get_node_timestep_data(env, scene, t, node, state, pred_state, scene_map = scene.map[node.type] map_point = x[-1, :2] - patch_size = hyperparams['map_encoder'][node.type]['patch_size'] + patch_size = hyperparams["map_encoder"][node.type]["patch_size"] map_tuple = (scene_map, map_point, heading_angle, patch_size) - data_tuple = (first_history_index, x_t, y_t, x_st_t, y_st_t, neighbors_data_st, - neighbors_edge_value, robot_traj_st_t, map_tuple, neighbors_future_data, plan_data) + data_tuple = ( + first_history_index, + x_t, + y_t, + x_st_t, + y_st_t, + neighbors_data_st, + neighbors_edge_value, + robot_traj_st_t, + map_tuple, + neighbors_future_data, + plan_data, + ) if is_closed_loop: return (data_tuple, (neighbors_data_not_st, logged_robot_data, robot_i)) return data_tuple -def get_timesteps_data(env, scene, t, node_type, state, pred_state, - edge_types, min_ht, max_ht, min_ft, max_ft, hyperparams): +def get_timesteps_data( + env, + scene, + t, + node_type, + state, + pred_state, + edge_types, + min_ht, + max_ht, + min_ft, + max_ft, + hyperparams, +): """ Puts together the inputs for ALL nodes in a given scene and timestep in it. 
@@ -531,30 +709,49 @@ def get_timesteps_data(env, scene, t, node_type, state, pred_state, :param hyperparams: Model hyperparameters :return: """ - nodes_per_ts = scene.present_nodes(t, - type=node_type, - min_history_timesteps=min_ht, - min_future_timesteps=max_ft, - return_robot=not hyperparams['incl_robot_node']) + nodes_per_ts = scene.present_nodes( + t, + type=node_type, + min_history_timesteps=min_ht, + min_future_timesteps=max_ft, + return_robot=not hyperparams["incl_robot_node"], + ) # Filter fields not in data - state = {nk: {k: v for k, v in ndict.items() if k != "augment"} for nk, ndict in state.items()} + state = { + nk: {k: v for k, v in ndict.items() if k != "augment"} + for nk, ndict in state.items() + } batch = list() nodes = list() out_timesteps = list() for timestep in nodes_per_ts.keys(): - scene_graph = scene.get_scene_graph(timestep, - env.attention_radius, - hyperparams['edge_addition_filter'], - hyperparams['edge_removal_filter']) - present_nodes = nodes_per_ts[timestep] - for node in present_nodes: - nodes.append(node) - out_timesteps.append(timestep) - batch.append(get_node_timestep_data(env, scene, timestep, node, state, pred_state, - edge_types, max_ht, max_ft, hyperparams, - nusc_maps=env.nusc_maps, - scene_graph=scene_graph)) + scene_graph = scene.get_scene_graph( + timestep, + env.attention_radius, + hyperparams["edge_addition_filter"], + hyperparams["edge_removal_filter"], + ) + present_nodes = nodes_per_ts[timestep] + for node in present_nodes: + nodes.append(node) + out_timesteps.append(timestep) + batch.append( + get_node_timestep_data( + env, + scene, + timestep, + node, + state, + pred_state, + edge_types, + max_ht, + max_ft, + hyperparams, + nusc_maps=env.nusc_maps, + scene_graph=scene_graph, + ) + ) if len(out_timesteps) == 0: return None return collate(batch), nodes, out_timesteps diff --git a/diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py b/diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py index 2b53026..c7a366a 100644 --- a/diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py +++ b/diffstack/modules/predictors/trajectron_utils/model/dynamics/unicycle.py @@ -8,16 +8,19 @@ class Unicycle(Dynamic): def init_constants(self): self.F_s = torch.eye(4, device=self.device, dtype=torch.float32) - self.F_s[0:2, 2:] = torch.eye(2, device=self.device, dtype=torch.float32) * self.dt + self.F_s[0:2, 2:] = ( + torch.eye(2, device=self.device, dtype=torch.float32) * self.dt + ) self.F_s_t = self.F_s.transpose(-2, -1) def create_graph(self, xz_size): model_if_absent = nn.Linear(xz_size + 1, 1) - self.p0_model = self.model_registrar.get_model(f"{self.node_type}/unicycle_initializer", model_if_absent) + self.p0_model = self.model_registrar.get_model( + f"{self.node_type}/unicycle_initializer", model_if_absent + ) def dynamic(self, x, u): - r""" - TODO: Boris: Add docstring + """ :param x: :param u: :return: @@ -36,32 +39,44 @@ def dynamic(self, x, u): dsin_domega = (torch.sin(phi_p_omega_dt) - torch.sin(phi)) / dphi dcos_domega = (torch.cos(phi_p_omega_dt) - torch.cos(phi)) / dphi - d1 = torch.stack([(x_p - + (a / dphi) * dcos_domega - + v * dsin_domega - + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt), - (y_p - - v * dcos_domega - + (a / dphi) * dsin_domega - - (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt), - phi + dphi * self.dt, - v + a * self.dt], dim=0) - d2 = torch.stack([x_p + v * torch.cos(phi) * self.dt + (a / 2) * torch.cos(phi) * self.dt ** 2, - y_p + v * torch.sin(phi) * self.dt + 
(a / 2) * torch.sin(phi) * self.dt ** 2, - phi * torch.ones_like(a), - v + a * self.dt], dim=0) + d1 = torch.stack( + [ + ( + x_p + + (a / dphi) * dcos_domega + + v * dsin_domega + + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt + ), + ( + y_p + - v * dcos_domega + + (a / dphi) * dsin_domega + - (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt + ), + phi + dphi * self.dt, + v + a * self.dt, + ], + dim=0, + ) + d2 = torch.stack( + [ + x_p + + v * torch.cos(phi) * self.dt + + (a / 2) * torch.cos(phi) * self.dt**2, + y_p + + v * torch.sin(phi) * self.dt + + (a / 2) * torch.sin(phi) * self.dt**2, + phi * torch.ones_like(a), + v + a * self.dt, + ], + dim=0, + ) return torch.where(~mask, d1, d2) def integrate_samples(self, control_samples, x=None): - r""" - TODO: Boris: Add docstring - :param x: - :param u: - :return: - """ ph = control_samples.shape[-2] - p_0 = self.initial_conditions['pos'].unsqueeze(1) - v_0 = self.initial_conditions['vel'].unsqueeze(1) + p_0 = self.initial_conditions["pos"].unsqueeze(1) + v_0 = self.initial_conditions["vel"].unsqueeze(1) # In case the input is batched because of the robot in online use we repeat this to match the batch size of x. if p_0.size()[0] != x.size()[0]: @@ -73,26 +88,24 @@ def integrate_samples(self, control_samples, x=None): phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1))) u = torch.stack([control_samples[..., 0], control_samples[..., 1]], dim=0) - x = torch.stack([p_0[..., 0], p_0[..., 1], phi_0, torch.norm(v_0, dim=-1)], dim = 0).squeeze(dim=-1) + x = torch.stack( + [p_0[..., 0], p_0[..., 1], phi_0, torch.norm(v_0, dim=-1)], dim=0 + ).squeeze(dim=-1) mus_list = [] for t in range(ph): x = self.dynamic(x, u[..., t]) - mus_list.append(torch.stack((x[0], x[1]), dim=-1)) + mus_list.append(torch.stack((x[0], x[1], x[2], x[3]), dim=-1)) pos_mus = torch.stack(mus_list, dim=2) return pos_mus def compute_control_jacobian(self, sample_batch_dim, components, x, u): - r""" - TODO: Boris: Add docstring - :param x: - :param u: - :return: - """ - F = torch.zeros(sample_batch_dim + [components, 4, 2], - device=self.device, - dtype=torch.float32) + F = torch.zeros( + sample_batch_dim + [components, 4, 2], + device=self.device, + dtype=torch.float32, + ) phi = x[2] v = x[3] @@ -106,47 +119,53 @@ def compute_control_jacobian(self, sample_batch_dim, components, x, u): dsin_domega = (torch.sin(phi_p_omega_dt) - torch.sin(phi)) / dphi dcos_domega = (torch.cos(phi_p_omega_dt) - torch.cos(phi)) / dphi - F[..., 0, 0] = ((v / dphi) * torch.cos(phi_p_omega_dt) * self.dt - - (v / dphi) * dsin_domega - - (2 * a / dphi ** 2) * torch.sin(phi_p_omega_dt) * self.dt - - (2 * a / dphi ** 2) * dcos_domega - + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt ** 2) - F[..., 0, 1] = (1 / dphi) * dcos_domega + (1 / dphi) * torch.sin(phi_p_omega_dt) * self.dt - - F[..., 1, 0] = ((v / dphi) * dcos_domega - - (2 * a / dphi ** 2) * dsin_domega - + (2 * a / dphi ** 2) * torch.cos(phi_p_omega_dt) * self.dt - + (v / dphi) * torch.sin(phi_p_omega_dt) * self.dt - + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt ** 2) - F[..., 1, 1] = (1 / dphi) * dsin_domega - (1 / dphi) * torch.cos(phi_p_omega_dt) * self.dt + F[..., 0, 0] = ( + (v / dphi) * torch.cos(phi_p_omega_dt) * self.dt + - (v / dphi) * dsin_domega + - (2 * a / dphi**2) * torch.sin(phi_p_omega_dt) * self.dt + - (2 * a / dphi**2) * dcos_domega + + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt**2 + ) + F[..., 0, 1] = (1 / dphi) * dcos_domega + (1 / dphi) * torch.sin( + phi_p_omega_dt + ) * self.dt + + F[..., 
1, 0] = ( + (v / dphi) * dcos_domega + - (2 * a / dphi**2) * dsin_domega + + (2 * a / dphi**2) * torch.cos(phi_p_omega_dt) * self.dt + + (v / dphi) * torch.sin(phi_p_omega_dt) * self.dt + + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt**2 + ) + F[..., 1, 1] = (1 / dphi) * dsin_domega - (1 / dphi) * torch.cos( + phi_p_omega_dt + ) * self.dt F[..., 2, 0] = self.dt F[..., 3, 1] = self.dt - F_sm = torch.zeros(sample_batch_dim + [components, 4, 2], - device=self.device, - dtype=torch.float32) + F_sm = torch.zeros( + sample_batch_dim + [components, 4, 2], + device=self.device, + dtype=torch.float32, + ) - F_sm[..., 0, 1] = (torch.cos(phi) * self.dt ** 2) / 2 + F_sm[..., 0, 1] = (torch.cos(phi) * self.dt**2) / 2 - F_sm[..., 1, 1] = (torch.sin(phi) * self.dt ** 2) / 2 + F_sm[..., 1, 1] = (torch.sin(phi) * self.dt**2) / 2 F_sm[..., 3, 1] = self.dt return torch.where(~mask.unsqueeze(-1).unsqueeze(-1), F, F_sm) def compute_jacobian(self, sample_batch_dim, components, x, u): - r""" - TODO: Boris: Add docstring - :param x: - :param u: - :return: - """ one = torch.tensor(1) - F = torch.zeros(sample_batch_dim + [components, 4, 4], - device=self.device, - dtype=torch.float32) + F = torch.zeros( + sample_batch_dim + [components, 4, 4], + device=self.device, + dtype=torch.float32, + ) phi = x[2] v = x[3] @@ -165,40 +184,48 @@ def compute_jacobian(self, sample_batch_dim, components, x, u): F[..., 2, 2] = one F[..., 3, 3] = one - F[..., 0, 2] = v * dcos_domega - (a / dphi) * dsin_domega + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt + F[..., 0, 2] = ( + v * dcos_domega + - (a / dphi) * dsin_domega + + (a / dphi) * torch.cos(phi_p_omega_dt) * self.dt + ) F[..., 0, 3] = dsin_domega - F[..., 1, 2] = v * dsin_domega + (a / dphi) * dcos_domega + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt + F[..., 1, 2] = ( + v * dsin_domega + + (a / dphi) * dcos_domega + + (a / dphi) * torch.sin(phi_p_omega_dt) * self.dt + ) F[..., 1, 3] = -dcos_domega - F_sm = torch.zeros(sample_batch_dim + [components, 4, 4], - device=self.device, - dtype=torch.float32) + F_sm = torch.zeros( + sample_batch_dim + [components, 4, 4], + device=self.device, + dtype=torch.float32, + ) F_sm[..., 0, 0] = one F_sm[..., 1, 1] = one F_sm[..., 2, 2] = one F_sm[..., 3, 3] = one - F_sm[..., 0, 2] = -v * torch.sin(phi) * self.dt - (a * torch.sin(phi) * self.dt ** 2) / 2 + F_sm[..., 0, 2] = ( + -v * torch.sin(phi) * self.dt - (a * torch.sin(phi) * self.dt**2) / 2 + ) F_sm[..., 0, 3] = torch.cos(phi) * self.dt - F_sm[..., 1, 2] = v * torch.cos(phi) * self.dt + (a * torch.cos(phi) * self.dt ** 2) / 2 + F_sm[..., 1, 2] = ( + v * torch.cos(phi) * self.dt + (a * torch.cos(phi) * self.dt**2) / 2 + ) F_sm[..., 1, 3] = torch.sin(phi) * self.dt return torch.where(~mask.unsqueeze(-1).unsqueeze(-1), F, F_sm) def integrate_distribution(self, control_dist_dphi_a, x): - r""" - TODO: Boris: Add docstring - :param x: - :param u: - :return: - """ sample_batch_dim = list(control_dist_dphi_a.mus.shape[0:2]) ph = control_dist_dphi_a.mus.shape[-3] - p_0 = self.initial_conditions['pos'].unsqueeze(1) - v_0 = self.initial_conditions['vel'].unsqueeze(1) + p_0 = self.initial_conditions["pos"].unsqueeze(1) + v_0 = self.initial_conditions["vel"].unsqueeze(1) # In case the input is batched because of the robot in online use we repeat this to match the batch size of x. 
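# A scalar, numpy-only sketch of the one-step unicycle update that self.dynamic
# applies inside the integration loops in this file (the real implementation is
# batched torch and uses a small-turn-rate mask; the threshold and the example
# values below are illustrative only).
import numpy as np

def unicycle_step(x, y, phi, v, dphi, a, dt):
    if abs(dphi) < 1e-6:  # near-zero turn rate: constant-heading fallback
        return (x + v * np.cos(phi) * dt + 0.5 * a * np.cos(phi) * dt**2,
                y + v * np.sin(phi) * dt + 0.5 * a * np.sin(phi) * dt**2,
                phi,
                v + a * dt)
    dsin = (np.sin(phi + dphi * dt) - np.sin(phi)) / dphi
    dcos = (np.cos(phi + dphi * dt) - np.cos(phi)) / dphi
    x_new = x + (a / dphi) * dcos + v * dsin + (a / dphi) * np.sin(phi + dphi * dt) * dt
    y_new = y - v * dcos + (a / dphi) * dsin - (a / dphi) * np.cos(phi + dphi * dt) * dt
    return x_new, y_new, phi + dphi * dt, v + a * dt

print(unicycle_step(0.0, 0.0, 0.0, 5.0, 0.2, 1.0, 0.5))  # one 0.5 s step from 5 m/s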
if p_0.size()[0] != x.size()[0]: @@ -210,25 +237,38 @@ def integrate_distribution(self, control_dist_dphi_a, x): phi_0 = phi_0 + torch.tanh(self.p0_model(torch.cat((x, phi_0), dim=-1))) dist_sigma_matrix = control_dist_dphi_a.get_covariance_matrix() - pos_dist_sigma_matrix_t = torch.zeros(sample_batch_dim + [control_dist_dphi_a.components, 4, 4], - device=self.device) - - u = torch.stack([control_dist_dphi_a.mus[..., 0], control_dist_dphi_a.mus[..., 1]], dim=0) - x = torch.stack([p_0[..., 0], p_0[..., 1], phi_0, torch.norm(v_0, dim=-1)], dim=0) + pos_dist_sigma_matrix_t = torch.zeros( + sample_batch_dim + [control_dist_dphi_a.components, 4, 4], + device=self.device, + ) + + u = torch.stack( + [control_dist_dphi_a.mus[..., 0], control_dist_dphi_a.mus[..., 1]], dim=0 + ) + x = torch.stack( + [p_0[..., 0], p_0[..., 1], phi_0, torch.norm(v_0, dim=-1)], dim=0 + ) pos_dist_sigma_matrix_list = [] mus_list = [] for t in range(ph): - F_t = self.compute_jacobian(sample_batch_dim, control_dist_dphi_a.components, x, u[:, :, :, t]) - G_t = self.compute_control_jacobian(sample_batch_dim, control_dist_dphi_a.components, x, u[:, :, :, t]) + F_t = self.compute_jacobian( + sample_batch_dim, control_dist_dphi_a.components, x, u[:, :, :, t] + ) + G_t = self.compute_control_jacobian( + sample_batch_dim, control_dist_dphi_a.components, x, u[:, :, :, t] + ) dist_sigma_matrix_t = dist_sigma_matrix[:, :, t] - pos_dist_sigma_matrix_t = (F_t.matmul(pos_dist_sigma_matrix_t.matmul(F_t.transpose(-2, -1))) - + G_t.matmul(dist_sigma_matrix_t.matmul(G_t.transpose(-2, -1)))) + pos_dist_sigma_matrix_t = F_t.matmul( + pos_dist_sigma_matrix_t.matmul(F_t.transpose(-2, -1)) + ) + G_t.matmul(dist_sigma_matrix_t.matmul(G_t.transpose(-2, -1))) pos_dist_sigma_matrix_list.append(pos_dist_sigma_matrix_t[..., :2, :2]) x = self.dynamic(x, u[:, :, :, t]) - mus_list.append(torch.stack((x[0], x[1]), dim=-1)) + mus_list.append(torch.stack((x[0], x[1], x[2]), dim=-1)) pos_dist_sigma_matrix = torch.stack(pos_dist_sigma_matrix_list, dim=2) pos_mus = torch.stack(mus_list, dim=2) - return GMM2D.from_log_pis_mus_cov_mats(control_dist_dphi_a.log_pis, pos_mus, pos_dist_sigma_matrix) + return GMM2D.from_log_pis_mus_cov_mats( + control_dist_dphi_a.log_pis, pos_mus, pos_dist_sigma_matrix + ) diff --git a/diffstack/modules/predictors/trajectron_utils/model/mgcvae.py b/diffstack/modules/predictors/trajectron_utils/model/mgcvae.py index 2c78c72..fcb5cf6 100644 --- a/diffstack/modules/predictors/trajectron_utils/model/mgcvae.py +++ b/diffstack/modules/predictors/trajectron_utils/model/mgcvae.py @@ -902,7 +902,7 @@ def encoder(self, mode, x, y_e, num_samples=None): if mode == ModeKeys.TRAIN: sample_ct = self.hyperparams['k'] elif mode == ModeKeys.EVAL: - sample_ct = self.hyperparams['k'] + sample_ct = self.hyperparams['k_eval'] elif mode == ModeKeys.PREDICT: sample_ct = num_samples if num_samples is None: @@ -1087,7 +1087,7 @@ def eval_loss(self, num_components = self.hyperparams['N'] * self.hyperparams['K'] ### Importance sampled NLL estimate - z, _ = self.encoder(mode, x, y_e) # [k, nbs, N*K] + z, _ = self.encoder(mode, x, y_e) # [k_eval, nbs, N*K] z = self.latent.sample_p(1, mode, full_dist=True) y_dist, _ = self.p_y_xz(ModeKeys.PREDICT, x, x_nr_t, y_r, n_s_t0, z, prediction_horizon, num_samples=1, num_components=num_components) diff --git a/diffstack/scripts/generate_config_templates.py b/diffstack/scripts/generate_config_templates.py new file mode 100644 index 0000000..bf927b6 --- /dev/null +++ b/diffstack/scripts/generate_config_templates.py @@ -0,0 
+1,19 @@ +""" +Helpful script to generate example config files for each algorithm. These should be re-generated +when new config options are added, or when default settings in the config classes are modified. +""" +import os + +import diffstack +from diffstack.configs.registry import EXP_CONFIG_REGISTRY + + +def main(): + # store template config jsons in this directory + target_dir = os.path.join(diffstack.__path__[0], "../config/templates/") + for name, cfg in EXP_CONFIG_REGISTRY.items(): + cfg.dump(filename=os.path.join(target_dir, name + ".json")) + + +if __name__ == "__main__": + main() diff --git a/diffstack/scripts/train_pl.py b/diffstack/scripts/train_pl.py new file mode 100644 index 0000000..9a77eaa --- /dev/null +++ b/diffstack/scripts/train_pl.py @@ -0,0 +1,582 @@ +import argparse +import sys +import os +import socket +import torch + +import wandb +import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.profilers import PyTorchProfiler + +from diffstack.utils.log_utils import PrintLogger +from diffstack.utils.experiment_utils import get_checkpoint +import diffstack.utils.train_utils as TrainUtils +from diffstack.data.trajdata_datamodules import UnifiedDataModule +from diffstack.configs.registry import get_registered_experiment_config +from diffstack.utils.config_utils import ( + get_experiment_config_from_file, + recursive_update_flat, + translate_trajdata_cfg, +) +from diffstack.stacks.stack_factory import stack_factory +from pathlib import Path +import json +import multiprocessing as mp +from diffstack.configs.config import Dict + +test = False + + +def main(cfg, auto_remove_exp_dir=False, debug=False, evaluate=False, **extra_kwargs): + pl.seed_everything(cfg.seed) + # Dataset + trajdata_config = translate_trajdata_cfg(cfg) + + model = stack_factory( + cfg=cfg, + ) + # if test: + # model.set_eval() + # import pickle + # with open("homo_test_case.pkl","rb") as f: + # batch = pickle.load(f) + # from diffstack.modules.module import Module, DataFormat, RunMode + # with torch.no_grad(): + # res = model.components["predictor"]._run_forward(dict(parsed_batch=batch),RunMode.VALIDATE) + # model.components["predictor"].log_pred_image(batch,res,501,"curated_result") + + datamodule = UnifiedDataModule(data_config=trajdata_config, train_config=cfg.train) + + datamodule.setup() + + if not evaluate: + print("\n============= New Training Run with Config =============") + print(cfg) + print("") + root_dir, log_dir, ckpt_dir, video_dir, version_key = TrainUtils.get_exp_dir( + exp_name=cfg.name, + output_dir=cfg.root_dir, + save_checkpoints=cfg.train.save.enabled, + auto_remove_exp_dir=auto_remove_exp_dir, + ) + + # Save experiment config to the training dir + cfg.dump(os.path.join(root_dir, version_key, "config.json")) + + # if cfg.train.logging.terminal_output_to_txt and not debug: + # # log stdout and stderr to a text file + # logger = PrintLogger(os.path.join(log_dir, "log.txt")) + # sys.stdout = logger + # sys.stderr = logger + + train_callbacks = [] + + # Training Parallelism + assert cfg.train.parallel_strategy in [ + "ddp", + "ddp_spawn", + # "ddp", # TODO for ddp we need to look at NODE_RANK and disable logging in config + None, + ] # TODO: look into other strategies + if not cfg.devices.num_gpus > 1: + # Override strategy when training on a single GPU + with cfg.train.unlocked(): + cfg.train.parallel_strategy = None + if cfg.train.parallel_strategy in ["ddp_spawn", 
"ddp"]: + with cfg.train.training.unlocked(): + cfg.train.training.batch_size = int( + cfg.train.training.batch_size / cfg.devices.num_gpus + ) + with cfg.train.validation.unlocked(): + cfg.train.validation.batch_size = int( + cfg.train.validation.batch_size / cfg.devices.num_gpus + ) + + # # Environment for close-loop evaluation + # if cfg.train.rollout.enabled: + # # Run rollout at regular intervals + # rollout_callback = RolloutCallback( + # exp_config=cfg, + # every_n_steps=cfg.train.rollout.every_n_steps, + # warm_start_n_steps=cfg.train.rollout.warm_start_n_steps, + # verbose=True, + # save_video=cfg.train.rollout.save_video, + # video_dir=video_dir + # ) + # train_callbacks.append(rollout_callback) + + # Model + + # Checkpointing + if cfg.train.validation.enabled and cfg.train.save.save_best_validation: + assert ( + cfg.train.save.every_n_steps > cfg.train.validation.every_n_steps + ), "checkpointing frequency needs to be greater than validation frequency" + for metric_name, metric_key in model.checkpoint_monitor_keys.items(): + print( + "Monitoring metrics {} under alias {}".format( + metric_key, metric_name + ) + ) + ckpt_valid_callback = pl.callbacks.ModelCheckpoint( + dirpath=ckpt_dir, + filename="iter{step}_ep{epoch}_%s{%s:.2f}" + % (metric_name, metric_key), + # explicitly spell out metric names, otherwise PL parses '/' in metric names to directories + auto_insert_metric_name=False, + save_top_k=cfg.train.save.best_k, # save the best k models + monitor=metric_key, + mode="min", + every_n_train_steps=cfg.train.save.every_n_steps, + verbose=True, + ) + train_callbacks.append(ckpt_valid_callback) + + if cfg.train.rollout.enabled and cfg.train.save.save_best_rollout: + assert ( + cfg.train.save.every_n_steps > cfg.train.rollout.every_n_steps + ), "checkpointing frequency needs to be greater than rollout frequency" + ckpt_rollout_callback = pl.callbacks.ModelCheckpoint( + dirpath=ckpt_dir, + filename="iter{step}_ep{epoch}_simADE{rollout/metrics_ego_ADE:.2f}", + # explicitly spell out metric names, otherwise PL parses '/' in metric names to directories + auto_insert_metric_name=False, + save_top_k=cfg.train.save.best_k, # save the best k models + monitor="rollout/metrics_ego_ADE", + mode="min", + every_n_train_steps=cfg.train.save.every_n_steps, + verbose=True, + ) + train_callbacks.append(ckpt_rollout_callback) + + # a ckpt monitor to save at fixed interval + ckpt_fixed_callback = pl.callbacks.ModelCheckpoint( + dirpath=ckpt_dir, + filename="iter{step}", + auto_insert_metric_name=False, + save_top_k=-1, + monitor=None, + every_n_train_steps=10000, + verbose=True, + ) + train_callbacks.append(ckpt_fixed_callback) + + # def wandb_login(i,return_dict): + # apikey = os.environ["WANDB_APIKEY"] + # wandb.login(key=apikey,host="https://api.wandb.ai") + # logger = WandbLogger( + # name=cfg.name, project=cfg.train.logging.wandb_project_name, + # ) + # return_dict[i] = logger + # manager = mp.Manager() + + logger = None + if debug: + print("Debugging mode, suppress logging.") + elif cfg.train.logging.log_tb: + logger = TensorBoardLogger( + save_dir=root_dir, version=version_key, name=None, sub_dir="logs/" + ) + print("Tensorboard event will be saved at {}".format(logger.log_dir)) + elif cfg.train.logging.log_wandb: + assert ( + "WANDB_APIKEY" in os.environ + ), "Set api key by `export WANDB_APIKEY=`" + try: + apikey = os.environ["WANDB_APIKEY"] + wandb.login(key=apikey, host="https://api.wandb.ai") + logger = WandbLogger( + name=cfg.name, + project=cfg.train.logging.wandb_project_name, + ) + 
except: + logger = None + # return_dict = manager.dict() + # p1 = mp.Process(target=wandb_login,args=(0,return_dict)) + # p1.start() + # p1.join(timeout=30) + # p1.terminate() + # if 0 in return_dict: + # logger = return_dict[0] + + else: + print("WARNING: not logging training stats") + + # Train + kwargs = dict( + default_root_dir=root_dir, + # checkpointing + enable_checkpointing=cfg.train.save.enabled, + # logging + logger=logger, + # flush_logs_every_n_steps=cfg.train.logging.flush_every_n_steps, + log_every_n_steps=cfg.train.logging.log_every_n_steps, + # training + max_steps=cfg.train.training.num_steps, + # validation + val_check_interval=cfg.train.validation.every_n_steps, + limit_val_batches=cfg.train.validation.num_steps_per_epoch, + # all callbacks + callbacks=train_callbacks, + # device & distributed training setup + # accelerator='cpu', + # accelerator=('gpu' if cfg.devices.num_gpus > 0 else 'cpu'), + devices=(cfg.devices.num_gpus if cfg.devices.num_gpus > 0 else None), + # strategy=cfg.train.parallel_strategy if cfg.train.parallel_strategy is not None else DDPStrategy(find_unused_parameters=True), + strategy="ddp_find_unused_parameters_true", + accelerator="gpu", + gradient_clip_val=cfg.train.gradient_clip_val, + # detect_anomaly=True, + # setting for overfit debugging + # limit_val_batches=0, + # overfit_batches=2 + ) + # if debug: + # profiler = PyTorchProfiler( + # output_filename=None, + # enabled=True, + # use_cuda=cfg.devices.num_gpus > 0, + # record_shapes=False, + # profile_memory=True, + # group_by_input_shapes=False, + # with_stack=False, + # use_kineto=False, + # use_cpu=cfg.devices.num_gpus == 0, + # emit_nvtx=False, + # export_to_chrome=False, + # path_to_export_trace=None, + # row_limit=200, + # sort_by_key=None, + # profiled_functions=None, + # local_rank=None, + # ) + # kwargs["profiler"] = profiler + # kwargs["max_steps"] = extra_kwargs["profile_steps"] + + if cfg.train.get("amp", False): + kwargs["precision"] = 16 + # kwargs["amp_backend"] = "apex" + # kwargs["amp_level"] = "O2" + # if cfg.train.get("auto_batch_size",False): + # kwargs["auto_scale_batch_size"] = "binsearch" + + if cfg.train.get("auto_batch_size", False): + from diffstack.utils.train_utils import trajdata_auto_set_batch_size + + kwargs_tune = kwargs.copy() + kwargs_tune["max_steps"] = 3 + trial_trainer = pl.Trainer(**kwargs_tune) + trial_model = stack_factory( + cfg=cfg, + ) + bs_max = cfg.train.get("max_batch_size", None) + batch_size = trajdata_auto_set_batch_size( + trial_trainer, + trial_model, + datamodule, + bs_min=cfg.train.training.batch_size, + bs_max=bs_max, + ) + # batch_size = trajdata_auto_set_batch_size(trial_trainer,trial_model,datamodule,bs_min = 50,bs_max = 58) + datamodule.train_batch_size = batch_size + datamodule.val_batch_size = batch_size + del trial_trainer, trial_model + if cfg.devices.num_gpus > 0: + torch.cuda.empty_cache() + torch.set_float32_matmul_precision("medium") + + trainer = pl.Trainer(**kwargs) + # Logging + assert not (cfg.train.logging.log_tb and cfg.train.logging.log_wandb) + + if isinstance(logger, WandbLogger): + # record the entire config on wandb + if trainer.global_rank == 0: + logger.experiment.config.update(cfg.to_dict()) + logger.watch(model=model) + # kwargs_tune = kwargs.copy() + # kwargs_tune["max_steps"] = 30 + # trial_trainer = pl.Trainer(**kwargs_tune) + # trial_trainer.fit(model=model, datamodule=datamodule) + trainer.fit(model=model, datamodule=datamodule) + + else: + kwargs = dict( + devices=(1 if cfg.devices.num_gpus > 0 else None), + # 
strategy=cfg.train.parallel_strategy if cfg.train.parallel_strategy is not None else DDPStrategy(find_unused_parameters=True), + accelerator="auto", + ) + trainer = pl.Trainer(**kwargs) + # Evaluation + + model.set_eval() + + metrics = trainer.test(model, datamodule=datamodule) + if len(metrics) == 1: + flattened_metrics = metrics[0] + else: + flattened_metrics = dict() + for k, v in metrics[0].items(): + flattened_metrics[k] = sum([m[k] for m in metrics]) / len(metrics) + result_path = extra_kwargs.get( + "eval_output_dir", os.path.join(os.getcwd(), "eval_result") + ) + file_name = f"{cfg.registered_name}_{cfg.train.trajdata_test_source_root}_{cfg.train.trajdata_source_test}_eval_result.json" + + with open(os.path.join(result_path, file_name), "w") as f: + json.dump(flattened_metrics, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + # External config file that overwrites default config + parser.add_argument( + "--config_file", + type=str, + default=None, + help="(optional) path to a config json that will be used to override the default settings. \ + If omitted, default settings are used. This is the preferred way to run experiments.", + ) + + parser.add_argument( + "--config_name", + type=str, + default=None, + help="(optional) create experiment config from a preregistered name (see configs/registry.py)", + ) + # Experiment Name (for tensorboard, saving models, etc.) + parser.add_argument( + "--name", + type=str, + default=None, + help="(optional) if provided, override the experiment name defined in the config", + ) + + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help="(optional) if provided, override the wandb project name defined in the config", + ) + + parser.add_argument( + "--dataset_path", + type=str, + default=None, + help="(optional) if provided, override the dataset root path", + ) + + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Root directory of training output (checkpoints, visualization, tensorboard log, etc.)", + ) + + parser.add_argument( + "--remove_exp_dir", + action="store_true", + help="Whether to automatically remove existing experiment directory of the same name (remember to set this to " + "True to avoid unexpected stall when launching cloud experiments).", + ) + + parser.add_argument( + "--on_ngc", + action="store_true", + help="whether running the script on ngc (this will change some behaviors like avoid writing into dataset)", + ) + + parser.add_argument( + "--debug", action="store_true", help="Debug mode, suppress wandb logging, etc." + ) + + parser.add_argument( + "--profile_steps", type=int, default=500, help="number of steps to run profiler" + ) + # evaluation mode + parser.add_argument( + "--evaluate", action="store_true", help="Evaluate mode, suppress training." 
+ ) + parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="path to the checkpoint to be evaluated", + ) + parser.add_argument( + "--eval_output_dir", + type=str, + default=None, + help="path to store evaluation result", + ) + parser.add_argument( + "--ckpt_root_dir", type=str, default=None, help="path of ngc checkpoint folder" + ) + parser.add_argument("--ngc_job_id", type=str, default=None, help="ngc job id") + parser.add_argument("--ckpt_key", type=str, default=None, help="ngc checkpoint key") + parser.add_argument( + "--test_data", type=str, default=None, help="trajdata_source_test in config" + ) + parser.add_argument( + "--test_data_root", + type=str, + default=None, + help="trajdata_test_source_root in config", + ) + parser.add_argument( + "--test_batch_size", type=int, default=None, help="batch size of test run" + ) + parser.add_argument( + "--log_image_frequency", type=int, default=None, help="image logging frequency" + ) + + parser.add_argument( + "--log_all_image", + action="store_true", + help="log images for all scenes instead of only the first one for every batch", + ) + + parser.add_argument( + "--remove_parked", + action="store_true", + help="remove parked agents from the scene", + ) + + args = parser.parse_args() + + if args.config_name is not None: + default_config = get_registered_experiment_config(args.config_name) + elif args.config_file is not None: + # Update default config with external json file + default_config = get_experiment_config_from_file(args.config_file, locked=False) + elif args.evaluate: + ckpt_path = None + if args.ckpt_path is not None: + config_path = str(Path(args.ckpt_path).parents[1] / "config.json") + ckpt_path = args.ckpt_path + elif ( + args.ngc_job_id is not None + and args.ckpt_key is not None + and args.ckpt_root_dir is not None + ): + ckpt_path, config_path = get_checkpoint( + ngc_job_id=args.ngc_job_id, + ckpt_key=args.ckpt_key, + ckpt_root_dir=args.ckpt_root_dir, + ) + default_config = get_experiment_config_from_file(config_path) + + if default_config.stack.stack_type == "pred": + default_config.stack.predictor.load_checkpoint = ckpt_path + elif default_config.stack.stack_type == "plan": + default_config.stack.planner.load_checkpoint = ckpt_path + else: + raise NotImplementedError + + if "test" not in default_config.train: + default_config.train.test = Dict( + enabled=True, + batch_size=32, + num_data_workers=6, + every_n_steps=500, + num_steps_per_epoch=50, + ) + if args.remove_parked: + default_config.env.remove_parked = True + if args.test_data is not None: + default_config.train.trajdata_source_test = args.test_data + if args.test_data_root is not None: + default_config.train.trajdata_test_source_root = args.test_data_root + # modify dataset path to reduce loading time + if default_config.train.trajdata_source_test is None: + default_config.train.trajdata_source_test = ( + default_config.train.trajdata_source_valid + ) + if default_config.train.trajdata_test_source_root is None: + default_config.train.trajdata_test_source_root = ( + default_config.train.trajdata_val_source_root + ) + + default_config.train.trajdata_val_source_root = None + default_config.train.trajdata_source_valid = ( + default_config.train.trajdata_source_test + ) + default_config.train.trajdata_source_root = ( + default_config.train.trajdata_test_source_root + ) + default_config.train.trajdata_source_train = ( + default_config.train.trajdata_source_test + ) + if args.test_batch_size is not None: + default_config.train.test["batch_size"] = 
args.test_batch_size
+    if (
+        "predictor" in default_config.stack
+        and default_config.stack.predictor.name == "CTT"
+    ):
+        default_config.env.max_num_lanes = 64
+        default_config.stack.predictor.decoder.decode_num_modes = 8
+        default_config.stack.predictor.LR_sample_hack = False
+        # default_config.env.remove_single_successor = False
+        default_config.eval.results_dir = args.eval_output_dir
+        if args.ngc_job_id is not None:
+            default_config.registered_name = (
+                default_config.registered_name + "_" + args.ngc_job_id
+            )
+        default_config.eval.log_image_frequency = args.log_image_frequency
+        if args.log_all_image:
+            default_config.eval.log_all_image = True
+        else:
+            default_config.eval.log_all_image = False
+
+    else:
+        raise Exception(
+            "Need either a config name or a json file to create experiment config"
+        )
+
+    if args.name is not None:
+        default_config.name = args.name
+
+    if args.dataset_path is not None:
+        default_config.train.dataset_path = args.dataset_path
+
+    if args.output_dir is not None:
+        default_config.root_dir = os.path.abspath(args.output_dir)
+
+    if args.wandb_project_name is not None:
+        default_config.train.logging.wandb_project_name = args.wandb_project_name
+
+    if args.on_ngc:
+        ngc_job_id = socket.gethostname()
+        default_config.name = default_config.name + "_" + ngc_job_id
+
+    default_config.train.on_ngc = args.on_ngc
+    args_dict = vars(args)
+    args_dict = {
+        k: v
+        for k, v in args_dict.items()
+        if k
+        not in ["name", "dataset_path", "output_dir", "wandb_project_name", "on_ngc"]
+    }
+    default_config, leftover = recursive_update_flat(default_config, args_dict)
+    if len(leftover) > 0:
+        print(f"Warning: Arguments {list(leftover.keys())} are not found in the config")
+
+    if args.debug:
+        # Test policy rollout
+        default_config.train.rollout.every_n_steps = 10
+        default_config.train.rollout.num_episodes = 1
+
+    # make rollout evaluation config consistent with the rest of the config
+
+    default_config.lock()  # Make config read-only
+    main(
+        default_config,
+        auto_remove_exp_dir=args.remove_exp_dir,
+        debug=args.debug,
+        profile_steps=args.profile_steps,
+        evaluate=args.evaluate,
+        eval_output_dir=args.eval_output_dir,
+    )
diff --git a/diffstack/stacks/base.py b/diffstack/stacks/base.py
new file mode 100644
index 0000000..8c93cb4
--- /dev/null
+++ b/diffstack/stacks/base.py
@@ -0,0 +1,347 @@
+from diffstack.modules.module import ModuleSequence, DataFormat
+from trajdata.data_structures.batch import SceneBatch, AgentBatch
+import pytorch_lightning as pl
+import diffstack.utils.tensor_utils as TensorUtils
+import diffstack.utils.geometry_utils as GeoUtils
+from diffstack.utils.utils import removeprefix
+import torch
+import torch.nn as nn
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
+import wandb
+
+
+class AVStack(pl.LightningModule):
+    def __init__(self, modules: ModuleSequence, cfg, batch_size=None, **kwargs):
+        super(AVStack, self).__init__()
+        self.moduleseq = modules
+        self.components = nn.ModuleDict()
+        for k, v in self.moduleseq.components.items():
+            self.components[k] = v
+        self.cfg = cfg
+        if "monitor_key" in kwargs:
+            self.monitor_key = kwargs["monitor_key"]
+        else:
+            self.monitor_key = {"valLoss": "val/losses_predictor_prediction_loss"}
+
+        self.batch_size = batch_size
+        self.validation_step_outputs = []
+
+    @property
+    def input_format(self) -> DataFormat:
+        return self.moduleseq.input_format
+
+    @property
+    def output_format(self) -> DataFormat:
+        return self.moduleseq.output_format
+
+    @property
+
def checkpoint_monitor_keys(self): + return self.monitor_key + + def forward(self, inputs, **kwargs): + return self.moduleseq(inputs, **kwargs) + + def infer_step(self, inputs, **kwargs): + with torch.no_grad(): + return TensorUtils.detach(self.moduleseq.infer_step(inputs, **kwargs)) + + def _compute_losses(self, pout, inputs): + loss_by_component = dict() + for comp_name, component in self.moduleseq.components.items(): + comp_out = { + removeprefix(k, comp_name + "."): v + for k, v in pout.items() + if k.startswith(comp_name + ".") + } + loss_by_component[comp_name] = component.compute_losses(comp_out, inputs) + + return loss_by_component + + def training_step(self, batch, batch_idx): + """ + Training on a single batch of data. + + Args: + batch (dict): dictionary with torch.Tensors sampled + from a data loader and filtered by @process_batch_for_training + + batch_idx (int): training step number - required by some Algos that need + to perform staged training and early stopping + + Returns: + total_loss(torch.Tensor): total weighted loss + """ + + pout = self(batch, batch_idx=batch_idx) + losses = self._compute_losses(pout, batch) + + total_loss = 0.0 + for comp_name in self.moduleseq.components: + comp_loss = losses[comp_name] + for lk, l in comp_loss.items(): + loss = l * self.cfg.stack[comp_name].loss_weights[lk] + self.log("train/losses_" + comp_name + "_" + lk, loss, sync_dist=True) + total_loss += loss + if batch_idx % 100 == 0: + metrics = self._compute_metrics(pout, batch) + for mk, m in metrics.items(): + if isinstance(m, np.ndarray) or isinstance(m, torch.Tensor): + m = m.mean() + self.log("train/metrics_" + mk, m, sync_dist=True) + + return total_loss + + def validation_step(self, batch, batch_idx): + with torch.no_grad(): + pout = TensorUtils.detach( + self.moduleseq.validate_step(batch, batch_idx=batch_idx) + ) + losses = self._compute_losses(pout, batch) + + if self.logger is not None and batch_idx == 1: + if "predictor" in self.components and hasattr( + self.components["predictor"], "log_pred_image" + ): + self.components["predictor"].log_pred_image( + batch, pout, batch_idx, self.logger + ) + else: + # default + self.log_pred_image(batch, pout, batch_idx, self.logger) + print("image logged") + + metrics = self._compute_metrics(pout, batch) + pred = {"losses": losses, "metrics": metrics} + self.validation_step_outputs.append(pred) + return pred + + def test_step(self, batch, batch_idx): + # if batch_idx<380 or batch_idx>381: + # return {} + # if "log_image_frequency" in self.cfg.eval and self.cfg.eval.log_image_frequency is not None: + # if batch_idx%self.cfg.eval.log_image_frequency!=0: + # return {} + with torch.no_grad(): + pout = TensorUtils.detach( + self.moduleseq.validate_step(batch, batch_idx=batch_idx) + ) + losses = self._compute_losses(pout, batch) + metrics = self._compute_metrics(pout, batch) + pred = {"losses": losses, "metrics": metrics} + flattened_pred = TensorUtils.flatten_dict(pred) + for k, v in flattened_pred.items(): + if isinstance(v, torch.Tensor): + flattened_pred[k] = ( + v.cpu().numpy().item() + if v.numel() == 1 + else v.cpu().numpy().mean().item() + ) + elif isinstance(v, np.ndarray): + flattened_pred[k] = v.item() if v.size == 1 else v.mean().item() + + self.log_dict(flattened_pred) + + if ( + "log_image_frequency" in self.cfg.eval + and self.cfg.eval.log_image_frequency is not None + ): + if batch_idx % self.cfg.eval.log_image_frequency == 0: + for comp_name, component in self.moduleseq.components.items(): + if hasattr(component, 
"log_pred_image"): + component.log_pred_image( + batch, + pout, + batch_idx, + self.cfg.eval.results_dir, + log_all_image=self.cfg.eval.log_all_image, + savegif=self.cfg.eval.get("savegif", True), + ) + + return flattened_pred + + def on_validation_epoch_end(self) -> None: + outputs = self.validation_step_outputs + for comp_name in self.moduleseq.components: + for k in outputs[0]["losses"][comp_name]: + m = torch.stack([o["losses"][comp_name][k] for o in outputs]).mean() + self.log("val/losses_" + comp_name + "_" + k, m, sync_dist=True) + + for k in outputs[0]["metrics"]: + m = np.stack([o["metrics"][k] for o in outputs]).mean() + self.log("val/metrics_" + k, m, sync_dist=True) + self.validation_step_outputs = [] + + def configure_optimizers(self): + pass + + def _compute_metrics(self, pred_batch, data_batch): + pass + + def log_pred_image( + self, + batch, + pred, + batch_idx, + logger, + **kwargs, + ): + if "pred" in pred: + pred = pred["pred"] + if "image" in pred: + N = pred["image"].shape[0] + h = int(np.ceil(np.sqrt(N))) + w = int(np.ceil(N / h)) + fig, ax = plt.subplots(h, w, figsize=(20 * h, 20 * w)) + canvas = FigureCanvas(fig) + image = pred["image"].detach().cpu().numpy().transpose(0, 2, 3, 1) + for i in range(N): + if h > 1 and w > 1: + hi = int(i / w) + wi = i - w * hi + ax[hi, wi].imshow(image[i]) + + elif h == 1 and w == 1: + ax.imshow(image[i]) + else: + ax[i].imshow(image[i]) + canvas.draw() # draw the canvas, cache the renderer + + image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") + image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + logger.experiment.log( + {"val_pred_image": [wandb.Image(image, caption="val_pred_image")]} + ) + del fig, ax, image + else: + if "data_batch" in pred: + batch = pred["data_batch"] + else: + if "scene_batch" in batch: + batch = batch["scene_batch"] + elif "agent_batch" in batch: + batch = batch["agent_batch"] + if isinstance(batch, SceneBatch) or isinstance(batch, AgentBatch): + batch = batch.__dict__ + + if "trajectories" not in pred or "image" not in batch: + return + + keys = ["trajectories", "agent_avail", "cond_traj"] + pred = {k: v for k, v in pred.items() if k in keys} + pred = TensorUtils.to_numpy(pred) + + if batch["image"].ndim == 5: + map = batch["image"][0, 0, -3:].cpu().numpy().transpose(1, 2, 0) + else: + map = batch["image"][0, -3:].cpu().numpy().transpose(1, 2, 0) + traj = pred["trajectories"][..., :2] + + avail = pred["agent_avail"].astype(float) + avail1 = avail.copy() + avail1[avail == 0] = np.nan + raster_from_agent = batch["raster_from_agent"].cpu().numpy() + if raster_from_agent.ndim == 2: + raster_from_agent = raster_from_agent[np.newaxis, :] + if traj.ndim == 6: + bs, Ne, numMode, Na, T = traj.shape[:5] + traj = traj * avail1[:, None, None, :, None, None] + else: + bs, numMode, Na, T = traj.shape[:4] + traj = traj * avail1[:, None, :, None, None] + Ne = 1 + raster_traj = GeoUtils.batch_nd_transform_points_np( + traj.reshape(bs, -1, 2), raster_from_agent + ).reshape(bs, Ne, -1, 2) + + cond_traj = pred["cond_traj"] if "cond_traj" in pred else None + + if cond_traj is None: + fig, ax = plt.subplots(figsize=(20, 20)) + canvas = FigureCanvas(fig) + ax.imshow(map) + ax.scatter( + raster_traj[0, 0, ..., 0], + raster_traj[0, 0, ..., 1], + color="c", + s=2, + marker="D", + ) + + else: + raster_cond_traj = GeoUtils.batch_nd_transform_points_np( + cond_traj[..., :2].reshape(bs, -1, 2), raster_from_agent + ).reshape(bs, Ne, -1, 2) + h = int(np.ceil(np.sqrt(Ne))) + w = int(np.ceil(Ne / h)) + fig, ax = 
plt.subplots(h, w, figsize=(20 * h, 20 * w)) + canvas = FigureCanvas(fig) + for i in range(Ne): + if h > 1 and w > 1: + hi = int(i / w) + wi = i - w * hi + ax[hi, wi].imshow(map) + ax[hi, wi].scatter( + raster_traj[0, i, ..., 0], + raster_traj[0, i, ..., 1], + color="c", + s=1, + marker="D", + ) + ax[hi, wi].scatter( + raster_cond_traj[0, i, :, 0], + raster_cond_traj[0, i, :, 1], + color="m", + s=1, + marker="D", + ) + elif h == 1 and w == 1: + ax.imshow(map) + ax.scatter( + raster_traj[0, i, ..., 0], + raster_traj[0, i, ..., 1], + color="c", + s=1, + marker="D", + ) + ax.scatter( + raster_cond_traj[0, i, :, 0], + raster_cond_traj[0, i, :, 1], + color="m", + s=1, + marker="D", + ) + else: + ax[i].imshow(map) + ax[i].scatter( + raster_traj[0, i, ..., 0], + raster_traj[0, i, ..., 1], + color="c", + s=1, + marker="D", + ) + ax[i].scatter( + raster_cond_traj[0, i, :, 0], + raster_cond_traj[0, i, :, 1], + color="m", + s=1, + marker="D", + ) + canvas.draw() # draw the canvas, cache the renderer + + image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8") + image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + logger.experiment.log( + {"val_pred_image": [wandb.Image(image, caption="val_pred_image")]} + ) + del fig, ax, image + + def set_eval(self): + self.moduleseq.set_eval() + + def set_train(self): + self.moduleseq.set_train() + + def reset(self): + self.moduleseq.reset() diff --git a/diffstack/stacks/pred_stack.py b/diffstack/stacks/pred_stack.py new file mode 100644 index 0000000..f023130 --- /dev/null +++ b/diffstack/stacks/pred_stack.py @@ -0,0 +1,44 @@ +from diffstack.modules.module import Module, ModuleSequence, DataFormat +from trajdata import SceneBatch +import pytorch_lightning as pl +import diffstack.utils.tensor_utils as TensorUtils +import torch +import torch.nn as nn +import numpy as np +import torch.optim as optim +from diffstack.stacks.base import AVStack + + +class PredStack(AVStack): + def __init__(self, modules: ModuleSequence, cfg, batch_size=None, **kwargs): + super().__init__( + modules, cfg, batch_size=cfg.train.training.batch_size, **kwargs + ) + + def configure_optimizers(self): + optim_params = self.cfg.stack["predictor"].optim_params["policy"] + return optim.Adam( + params=self.parameters(), + lr=optim_params["learning_rate"]["initial"], + weight_decay=optim_params["regularization"]["L2"], + ) + + def _compute_metrics(self, pred_batch, data_batch): + metrics_dict = dict() + if hasattr(self.components["predictor"], "compute_metrics"): + metrics_dict.update( + self.components["predictor"].compute_metrics(pred_batch, data_batch) + ) + if ( + "trajectories" in pred_batch + and pred_batch["trajectories"].ndim == 5 + and pred_batch["trajectories"].size(1) > 0 + ): + metrics_dict["mode_diversity"] = ( + (pred_batch["trajectories"][:, 0] - pred_batch["trajectories"][:, 1]) + .norm() + .detach() + .cpu() + ) + + return metrics_dict diff --git a/diffstack/stacks/stack_factory.py b/diffstack/stacks/stack_factory.py new file mode 100644 index 0000000..7927497 --- /dev/null +++ b/diffstack/stacks/stack_factory.py @@ -0,0 +1,82 @@ +"""Factory methods for creating models""" +from diffstack.modules.module import ModuleSequence +from diffstack.modules.predictors.factory import ( + predictor_factory, +) +from diffstack.configs.config import Dict +import torch +from collections import OrderedDict, defaultdict +from pathlib import Path +from typing import List + +from diffstack.stacks.pred_stack import PredStack +from omegaconf import DictConfig, OmegaConf, open_dict + + 
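For orientation, a minimal usage sketch of the factory defined below, assuming a config template on disk (the filename here is a placeholder; the calls mirror how train_pl.py above drives the same path from config file to LightningModule):

    from diffstack.utils.config_utils import get_experiment_config_from_file
    from diffstack.stacks.stack_factory import stack_factory

    # Load an experiment config (placeholder path, e.g. one dumped by generate_config_templates.py).
    cfg = get_experiment_config_from_file("config/templates/example_pred_stack.json", locked=False)
    cfg.lock()                      # freeze the config, as train_pl.py does before calling main()
    stack = stack_factory(cfg=cfg)  # a PredStack (pl.LightningModule) when cfg.stack.stack_type == "pred"
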
+def get_checkpoint_dict(stack_cfg: DictConfig, module_names: List, device: str): + checkpoint_dict = defaultdict(lambda: None) + for module_name in module_names: + if ( + "load_checkpoint" in stack_cfg[module_name] + and stack_cfg[module_name].load_checkpoint + ): + ckpt_path = Path(stack_cfg[module_name].load_checkpoint).expanduser() + checkpoint_dict[module_name] = torch.load(ckpt_path, map_location=device) + return checkpoint_dict + + +def stack_factory(cfg: DictConfig, model_registrar=None, log_writer=None, device=None): + """ + A factory for creating training stacks + + Args: + cfg (ExperimentConfig): an ExperimentConfig object, + Returns: + stack: pl.LightningModule + """ + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + if cfg.stack.stack_type == "pred": + predictor_config = cfg.stack["predictor"] + + checkpoint_dict = get_checkpoint_dict(cfg.stack, ["predictor"], device) + if "env" in cfg: + if isinstance(predictor_config, Dict): + predictor_config.unlock() + predictor_config.env = cfg.env + predictor_config.lock() + elif isinstance(predictor_config, DictConfig): + OmegaConf.set_struct(predictor_config, True) + with open_dict(predictor_config): + predictor_config.env = cfg.env + predictor = predictor_factory( + model_registrar, + predictor_config, + log_writer, + device, + checkpoint=checkpoint_dict["predictor"], + ) + + modules = ModuleSequence( + OrderedDict(predictor=predictor), + model_registrar, + predictor_config, + log_writer, + device, + ) + if ( + "monitor_key" in predictor_config + and predictor_config["monitor_key"] is not None + ): + avstack = PredStack( + modules, cfg, monitor_key=predictor_config["monitor_key"] + ) + else: + avstack = PredStack( + modules, cfg, monitor_key=predictor.checkpoint_monitor_keys + ) + return avstack + + else: + raise NotImplementedError("The type of stack structure is not recognized!") diff --git a/diffstack/train.py b/diffstack/train.py deleted file mode 100644 index 8f779f5..0000000 --- a/diffstack/train.py +++ /dev/null @@ -1,588 +0,0 @@ -import sys - -import torch -import numpy as np -import os -import time -import json -import pathlib -import wandb - -from collections import defaultdict -from tqdm import tqdm, trange -from torch import nn, optim -from typing import Dict -from random import random - -# Dataset related -from diffstack.data.trajdata_interface import prepare_avdata -from diffstack.data.cached_nusc_as_trajdata import prepare_cache_to_avdata -from diffstack.utils.model_registrar import ModelRegistrar - -# Model related -from diffstack.modules.diffstack import DiffStack -from diffstack.argument_parser import args, get_hyperparams, print_hyperparams_summary -# from diffstack.closed_loop_eval import simulate_scenarios_in_scene - -from diffstack.utils.utils import initialize_torch_distributed, prepeare_torch_env, set_all_seeds, all_gather -from torch.nn.parallel import DistributedDataParallel as DDP - -import matplotlib.pyplot as plt - -# Visualization -from diffstack.utils import visualization as plan_vis - -from trajdata import AgentBatch, AgentType - - -def train(rank, args): - hyperparams = get_hyperparams(args) - # del args - - prepeare_torch_env(rank, hyperparams) - log_writer, model_dir = prepare_logging(rank, hyperparams) - - if rank == 0: - print_hyperparams_summary(hyperparams) - - device = hyperparams["device"] - - ################################# - # PREPARE MODEL AND DATA # - ################################# - - # Create model first. 
We need this before caching training data with gt plan. - model_registrar = ModelRegistrar(model_dir, device) - if hyperparams["load"]: - # Directly adding to pythonpath for legacy reasons. - from diffstack.modules.predictors import trajectron_utils - sys.path.append(trajectron_utils.__path__[0]) - model_registrar.load_model_from_file(hyperparams["load"], except_contains=["planner_cost"]) - - diffstack = DiffStack(model_registrar, hyperparams, log_writer, device) - print(f'Rank {rank}: Created Training Model.') - - if hyperparams["data_source"] == "trajdata": - train_dataloader, train_sampler, train_dataset, eval_dataloader, eval_sampler, eval_dataset, input_wrapper = prepare_avdata(rank, hyperparams, scene_centric=False) - elif hyperparams["data_source"] == "trajdata-scene": - train_dataloader, train_sampler, train_dataset, eval_dataloader, eval_sampler, eval_dataset, input_wrapper = prepare_avdata(rank, hyperparams, scene_centric=True) - elif hyperparams["data_source"] == "cache": - # Load original cached data and convert it to trajdata - train_dataloader, train_sampler, train_dataset, eval_dataloader, eval_sampler, eval_dataset, input_wrapper = prepare_cache_to_avdata(rank, hyperparams, args, diffstack) - else: - raise ValueError(f"Unknown data_source {hyperparams['data_source']}") - - if torch.cuda.is_available() and device != 'cpu': - diffstack = DDP(diffstack, - device_ids=[rank], - output_device=rank, - find_unused_parameters=True) - diffstack_module = diffstack.module - # pkarkus: DDP moves tensors to GPU (except for dilled dicts when using more than 1 worker) - # using this function we can replicate the same for eval. - # input_wrapper = lambda inputs, **kwargs: trajectron.to_kwargs(inputs, kwargs, trajectron.device_ids[0]) - else: - diffstack_module = diffstack - # input_wrapper = lambda inputs, **kwargs: (inputs, kwargs) - - # Initialize optimizer - lr_scheduler = None - step_scheduler = None - plan_cost_lr = hyperparams['cost_grad_scaler'] * hyperparams['learning_rate'] - optimizer = optim.Adam([{'params': model_registrar.get_all_but_name_match('map_encoder').parameters()}, - {'params': model_registrar.get_name_match('map_encoder').parameters(), - 'lr': hyperparams['map_enc_learning_rate']}, - {'params': model_registrar.get_name_match('planner_cost').parameters(), - 'lr': plan_cost_lr} - ], - lr=hyperparams['learning_rate']) - # Set Learning Rate - if hyperparams['learning_rate_style'] == 'const': - lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0) - elif hyperparams['learning_rate_style'] == 'exp': - lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, - gamma=hyperparams['learning_decay_rate']) - - if hyperparams['lr_step'] != 0: - step_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=hyperparams['lr_step'], gamma=0.1) - - - ################################## - # VALIDATE FUNCTIONS # - ################################## - def log_batch_errors(batch_errors, eval_metrics, - log_writer, namespace, epoch, curr_iter, - bar_plot=[], box_plot=[]): - for node_type in batch_errors.keys(): - if eval_metrics is None: - # Log all metrics if eval_metrics is not provided - eval_metrics = batch_errors[node_type].keys() - for metric in eval_metrics: - metric_batch_error: np.ndarray = np.concatenate(batch_errors[node_type][metric]) - - if len(metric_batch_error) > 0 and log_writer is not None: - log_writer.log({ - # f"{node_type.name}/{namespace}/{metric}_hist": wandb.Histogram(metric_batch_error), - # f"{node_type.name}/{namespace}/{metric}_min": 
np.min(metric_batch_error), - f"{node_type.name}/{namespace}/{metric}_mean": np.mean(metric_batch_error), - # f"{node_type.name}/{namespace}/{metric}_median": np.median(metric_batch_error), - # f"{node_type.name}/{namespace}/{metric}_max": np.max(metric_batch_error), - f"{node_type.name}/epoch": epoch, # TODO(pkarkus) this is only for compatibility, remove it - }, step=curr_iter, commit=False) - - if log_writer is not None: - log_writer.log({"epoch": epoch, "global_step": curr_iter}, step=curr_iter, commit=True) - - def run_validation(epoch, curr_iter): - with torch.no_grad(): - # Calculate evaluation loss - epoch_metrics: Dict[str, list] = defaultdict(list) - - # Compute metrics over validation set - batch: AgentBatch - for batch in tqdm(eval_dataloader, ncols=80, unit_scale=hyperparams["world_size"], - disable=(rank > 0), desc=f'Epoch {epoch} Eval'): - - with torch.no_grad(): - outputs = diffstack_module.validate(input_wrapper(batch)) - metrics: Dict[str, torch.Tensor] = outputs["metrics"] - - for k, v in metrics.items(): - epoch_metrics[k].append(v.cpu().numpy()) - - # Gather results from GPUs - if hyperparams["world_size"] > 1: - gathered_values = all_gather(epoch_metrics) - - if rank == 0: - epoch_metrics = defaultdict(list) - for partial_epoch_metrics in gathered_values: - for k, v in partial_epoch_metrics.items(): - epoch_metrics[k].extend(v) - - # Log - if rank == 0: - log_batch_errors({AgentType.UNKNOWN: epoch_metrics}, - None, - log_writer, - 'eval', - epoch, - curr_iter) - eval_loss_sorted = {k: np.sort(np.concatenate(vals)) for k, vals in epoch_metrics.items()} - print (f"Eval epoch {epoch}:") - for topn in [100]: # , 20, 10]: - print (f"Top {topn}%:") - for k, vals in eval_loss_sorted.items(): - start_i = (100-topn)*len(vals)//100 - topn_vals = vals[start_i:] - print (f" {k} {topn}: {topn_vals.mean():.4f}") - - pass - - - def run_closed_loop_eval(epoch, curr_iter, all_scenarios_in_scene=False): - if hyperparams["cl_trajlen"] <= 0: - print ("No closed loop evaluation.") - return - - raise NotImplementedError("Need to access scenes in environment dataset and run custom preprocessing") - - # Run closed loop replanning at different frequencies - replan_every_ns = [1] #, 4, 5, 6] - - with torch.no_grad(): - # Calculate evaluation loss - for node_type, data_loader in eval_data_loader.items(): - if rank == 0: - print(f"Starting closed loop evaluation @ epoch {epoch} for node type: {node_type}") - - # TODO support more workers - env = eval_dataset.env - scenes_for_worker = env.scenes - if rank == 0: - for scene in tqdm(scenes_for_worker, disable=(rank > 0), desc=f'Scene '): - eval_loss = simulate_scenarios_in_scene(diffstack_module, nusc_maps, env, scene, node_type, hyperparams, replan_every_ns=replan_every_ns, all_scenarios_in_scene=False) - else: - eval_loss = defaultdict(list) - - - if torch.distributed.get_world_size() > 1: - gathered_values = all_gather(eval_perf) - if rank == 0: - eval_perf = [] - for eval_dicts in gathered_values: - eval_perf.extend(eval_dicts) - - if rank == 0: - log_batch_errors(eval_loss, - None, - log_writer, - 'eval', - epoch, - curr_iter) - eval_loss_sorted = {k: np.sort(np.concatenate(vals)) for k, vals in eval_loss.items()} - print (f"Eval epoch {epoch}:") - for topn in [100]: - print (f"Top {topn}%%:") - for k, vals in eval_loss_sorted.items(): - start_i = (100-topn)*len(vals)//100 - topn_vals = vals[start_i:] - print (f" {k} {topn}: {topn_vals.mean():.6f}") - - pass - - ################################# - # VISUALIZATION # - 
################################# - - def run_visualization(epoch, curr_iter, dataset_name, dataset = None, dataloader = None, num_plots=10): - with torch.no_grad(): - if dataset is not None: - batch_idxs = random.sample(range(len(dataset)), num_plots) - batch: AgentBatch = dataset.get_collate_fn(pad_format="right")( - [dataset[i] for i in batch_idxs] - ) - elif dataloader is not None: - for batch in dataloader: - break - else: - raise ValueError("Need to specifiy dataset or data_loader") - - outputs = diffstack_module.validate(input_wrapper(batch)) - plan_xu = outputs['plan.plan_xu'].cpu().numpy() - - # Only keep elements in batch that are valid for planning - batch = batch.filter_batch(outputs['plan.valid']) - - # TODO this is not fixed, wrongly indexing plan output - images = list() - for batch_idx in trange(min(batch.agent_fut.shape[0], num_plots), desc="Visualizing Random Predictions"): - plan_x = plan_xu[:, batch_idx, :4] - - fig, ax = plt.subplots() - if 'plan.fan.candidate_xu' in outputs: - plan_candidates_x = outputs['plan.fan.candidate_xu'][batch_idx][..., :2].cpu().numpy() # N, T+1, 6 - plan_vis.plot_plan_candidates(plan_candidates_x, batch_idx=batch_idx, ax=ax) - - plan_vis.plot_plan_input_batch(batch, batch_idx=batch_idx, ax=ax, legend=False, show=False, close=False) - plan_vis.plot_plan_result(plan_x, ax=ax) - - # Legend - plan_vis.legend_unique_labels(ax, loc="best", frameon=True) - - images.append(wandb.Image( - fig, - caption=f"Batch_idx: {batch_idx}" # " Pred agent: {batch.agent_name[batch_idx]}" - )) - - if log_writer: - log_writer.log({f"{dataset_name}/predictions_viz": images}, step=curr_iter) - - if hyperparams["debug"]: - print ("Breakpoint here for plot interaction") - plt.close("all") - - - - ################################# - # TRAINING # - ################################# - - # Start with eval when loading pretrained model. 
- if hyperparams["eval_every"] is not None and hyperparams["eval_every"] > 0: - # if rank == 0 and hyperparams["vis_every"] is not None and hyperparams["vis_every"] > 0: - # # run_visualization(0, 0, "eval", eval_dataset) - # run_visualization(0, 0, "eval", dataloader=eval_dataloader) - run_closed_loop_eval(0, 0) - run_validation(0, 0) - - print (diffstack_module.get_params_summary_text()) - - curr_iter: int = 0 - for epoch in range(1, hyperparams['train_epochs'] + 1): - train_sampler.set_epoch(epoch) - pbar = tqdm(train_dataloader, ncols=80, unit_scale=hyperparams["world_size"], disable=(rank > 0)) - - # prof = torch.profiler.profile( - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), - # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler/tpp_unified'), - # record_shapes=True, - # profile_memory=True, - # with_stack=True - # ) - # prof.start() - - # initialize the timer for the 1st iteration - step_timer_start = time.time() - - plan_loss_epoch = [] - pred_loss_epoch = [] - plan_metrics_epoch = defaultdict(list) - - batch: AgentBatch - for batch_idx, batch in enumerate(pbar): - diffstack_module.set_curr_iter(curr_iter) - diffstack_module.step_annealers() - - # Manually fix seed - # set_all_seeds(100) - - optimizer.zero_grad(set_to_none=True) - - outputs = diffstack_module(input_wrapper(batch)) - - train_loss = outputs["loss"] - pred_loss = outputs["pred.loss"] - plan_loss = outputs["plan.loss"] - plan_metrics = outputs["plan.metrics"] - - pbar.set_description(f"Epoch {epoch} L: {train_loss.detach().item():.2f}") - - train_loss.backward() - - # Scale down gradients for planning cost - if hyperparams['train_plan_cost']: - for param in model_registrar.get_name_match('planner_cost').parameters(): - if param.grad is None: - continue - param.grad = param.grad * 0.01 # TODO this has no effect !! but keep it for paper push - - # Clipping gradients. - if hyperparams['grad_clip'] is not None: - nn.utils.clip_grad_value_(model_registrar.get_all_but_name_match(['planner_cost']).parameters(), hyperparams['grad_clip']) - - # # Debug gradients - # if batch_idx == 0: - # print ("First batch:") - # torch.set_printoptions(precision=10, linewidth=160) - # (man_first_history_index, - # man_x, man_y, man_x_st_t, man_y_st_t, - # man_neighbors_data_st, # dict of lists. edge_type -> [batch][neighbor]: Tensor(time, statedim). 
Represetns - # man_neighbors_edge_value, - # man_robot_traj_st_t, - # man_map, neighbors_future_data, plan_data) = batch.extras["manual_inputs"] - # print (man_x_st_t.nan_to_num(0.001).sum().cpu()) - # print (train_loss) - - # grads = [] - # for param in model_registrar.parameters(): - # if param.grad is None: - # continue - # grads.append(param.grad.sum()) - # gradsum = torch.stack(grads).sum() - # print ("Grad: ", gradsum) - # print (grads) - - # # Check gradients for nans - # is_grad_nan = False - # for param in model_registrar.parameters(): - # if param.grad is None: - # continue - # is_grad_nan = is_grad_nan or bool(torch.isnan(param.grad).any()) - # if is_grad_nan: - # print (batch) - # print ("IsNAN:") - # print (bool(torch.isnan(train_loss).any())) - # print (bool(torch.isnan(plan_loss).any())) - # print (bool(torch.isnan(pred_loss).any())) - - # # Validate inputs - # print (batch.agent_fut.isnan().any()) - # print (any([batch.agent_hist[i, :batch.agent_hist_len[i]].isnan().any() for i in range(256)])) - # print (any([batch.agent_hist[i, :batch.agent_hist_len[i]].isnan().any() for i in range(256)])) - # print (any([any([batch.neigh_fut[i][j][:batch.neigh_fut_len[i, j]].isnan().any() for j in range(batch.num_neigh[i]) ]) for i in range(256)])) - # print (any([any([batch.neigh_hist[i][j][:batch.neigh_hist_len[i, j]].isnan().any() for j in range(batch.num_neigh[i]) ]) for i in range(256)])) - - # # Run TPP - # node_type = AgentType.VEHICLE - # model = trajectron_module.pred_obj.node_models_dict[node_type.name] - # from model.model_utils import ModeKeys - # mode = ModeKeys.TRAIN - - # # encoder - # x, x_nr_t, y_e, y_r, y, n_s_t0, dt = model.obtain_encoded_tensors(mode, batch) - # print (any(tensor.isnan().any() for tensor in [x, x_nr_t, y_e, y_r, y, n_s_t0, dt] if tensor is not None)) - - # z, kl = model.encoder(mode, x, y_e) - # print (any(tensor.isnan().any() for tensor in [z, kl] if tensor is not None)) - - # log_p_y_xz, y_dist = model.decoder(mode, x, None, y, None, n_s_t0, z, dt, hyperparams['k'], ret_dist=True) - # print (any(tensor.isnan().any() for tensor in [log_p_y_xz] if tensor is not None)) - - # # Call forward pass again for an opportunity to debug - # for _ in range(20): - # train_loss2, plan_loss2, pred_loss2, _, plan_metrics2 = trajectron_module(batch, return_debug=True) - - optimizer.step() - - # Stepping forward the learning rate scheduler and annealers. 
- lr_scheduler.step() - if rank == 0 and not hyperparams['debug']: - step_timer_stop = time.time() - elapsed = step_timer_stop - step_timer_start - - log_writer.log({ - "train/learning_rate": lr_scheduler.get_last_lr()[0], - "train/loss": train_loss.detach().item(), - "steps_per_sec": 1 / elapsed, - "epoch": epoch, - "batch": batch_idx, - "global_step": curr_iter, # TODO kept this for compatibility with old tensorboard based logging - }, step=curr_iter, commit=True) - - # Accumulate metrics - # TODO (pkarkus) remove losses and handle everything inside metrics - # TODO (pkarkus) remove losses and handle everything inside metrics - # TODO (pkarkus) remove losses and handle everything inside metrics - plan_loss_epoch.append(plan_loss.detach().cpu()) - pred_loss_epoch.append(pred_loss.detach().cpu()) - for k, v in plan_metrics.items(): - plan_metrics_epoch[k].append(v.detach().cpu()) - - curr_iter += 1 - - # initialize the timer for the following iteration - step_timer_start = time.time() - - # prof.step() - # Log batch - # TODO (pkarkus) simplify this - # TODO filter by node type - - # Accumulate metrics over epoch - pred_loss_epoch = torch.stack(pred_loss_epoch, dim=0) - plan_loss_epoch = torch.stack(plan_loss_epoch, dim=0) - plan_metrics_epoch = {k: torch.cat(v, dim=0) for k, v in plan_metrics_epoch.items()} - if torch.distributed.is_initialized() and torch.distributed.get_world_size() > 1: # torch.cuda.is_available() and - all_plan_loss_batch = all_gather(plan_loss_epoch) - plan_loss_epoch = torch.cat(all_plan_loss_batch, dim=0) - all_pred_loss_batch = all_gather(pred_loss_epoch) - pred_loss_epoch = torch.cat(all_pred_loss_batch, dim=0) - all_plan_metrics_batch = all_gather(plan_metrics_epoch) - plan_metrics_epoch = defaultdict(list) - for plan_metrics_batch in all_plan_metrics_batch: - for k, v in plan_metrics_batch.items(): - plan_metrics_epoch[k].append(v) - plan_metrics_epoch = {k: torch.cat(v, dim=0) for k, v in plan_metrics_epoch.items()} - - # Log epoch stats - if rank == 0: - pred_loss_epoch = pred_loss_epoch.mean() - plan_loss_epoch = plan_loss_epoch.mean() - - if log_writer: - log_writer.log({ - f"train/epoch_pred_loss": pred_loss_epoch.detach().item() - }, step=curr_iter, commit=False) - log_writer.log({ - f"train/epoch_plan_loss": plan_loss_epoch.mean().detach().item() - }, step=curr_iter, commit=False) - - print (f"Epoch {epoch} pred_loss {pred_loss_epoch} plan_loss {plan_loss_epoch}.") - - if "fan_valid" in plan_metrics_epoch: - valid_filter = plan_metrics_epoch["fan_valid"] - for k, v in plan_metrics_epoch.items(): - v_mean = v.float().mean().detach().item() - v_mean_valid = v[valid_filter].float().mean().detach().item() - - if log_writer: - log_writer.log({ - f"train/epoch_{k}": v_mean, - f"train/epoch_{k}_valid": v_mean_valid, - }, step=curr_iter, commit=False) - - print (f"{k}: {v_mean} {v_mean_valid}") - - print (diffstack_module.get_params_summary_text()) - - del plan_loss_epoch - del pred_loss_epoch - del plan_metrics_epoch - - # prof.stop() - # raise - if hyperparams['lr_step'] != 0: - step_scheduler.step() - - - ################################# - # VALIDATION # - ################################# - if hyperparams["eval_every"] is not None and hyperparams["eval_every"] > 0 and epoch % hyperparams["eval_every"] == 0 and epoch > 0: - run_validation(epoch, curr_iter) - - if rank == 0 and (hyperparams["save_every"] is not None and not hyperparams["debug"] and epoch % hyperparams["save_every"] == 0): - model_registrar.save_models(epoch) - - 
################################# - # VISUALIZATION # - ################################# - if rank == 0 and (hyperparams["planner"] not in ["", "none"]) and (hyperparams["vis_every"] is not None and hyperparams["vis_every"] > 0 and epoch % hyperparams["vis_every"] == 0 and epoch > 0): - # run_visualization(epoch, curr_iter, "eval", eval_dataset) - run_visualization(epoch, curr_iter, "eval", dataloader=eval_dataloader) - - # Waiting for process 0 to be done its evaluation and visualization. - if torch.distributed.is_initialized() and torch.cuda.is_available(): - torch.distributed.barrier() - - return model_dir - - -def prepare_logging(rank, hyperparams): - # Logging - log_writer = None - model_dir = None - if not hyperparams["debug"]: - # Create the log and model directory if they're not present. - model_dir = os.path.join(hyperparams["log_dir"], - hyperparams["experiment"] + time.strftime('-%d_%b_%Y_%H_%M_%S', time.localtime())) - hyperparams["logdir"] = model_dir - - if rank == 0: - pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True) - - # Save config to model directory - with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json: - json.dump(hyperparams, conf_json) - - # wandb.tensorboard.patch(root_logdir=model_dir, pytorch=True) - - # WandB init. Put it in a loop because it can fail on ngc. - for _ in range(10): - try: - log_writer = wandb.init( - project="debug" if hyperparams["debug"] else hyperparams["experiment"], name=f"{hyperparams['experiment']}", - config=hyperparams, mode="offline" if hyperparams["debug"] else "online") # sync_tensorboard=True, - except: - continue - break - else: - raise ValueError("Could not connect to wandb") - - artifact = wandb.Artifact('logdir', type='path') - artifact.add_dir(model_dir) - wandb.log_artifact(artifact) - - # log_writer = SummaryWriter(log_dir=model_dir) - print (f"Log path: {model_dir}") - return log_writer, model_dir - - -if __name__ == '__main__': - - if torch.distributed.is_torchelastic_launched(): - local_rank = int(os.environ["LOCAL_RANK"]) - initialize_torch_distributed(local_rank) - - print( - f"[{os.getpid()}]: world_size = {torch.distributed.get_world_size()}, " - + f"rank = {torch.distributed.get_rank()}, backend={torch.distributed.get_backend()}, " - + f"port = {os.environ['MASTER_PORT']} \n", end='' - ) - else: - local_rank = 0 - - log_dir = train(local_rank, args) - - wandb.finish() diff --git a/diffstack/utils/algo_utils.py b/diffstack/utils/algo_utils.py new file mode 100644 index 0000000..2d77661 --- /dev/null +++ b/diffstack/utils/algo_utils.py @@ -0,0 +1,283 @@ +import torch +from torch import optim as optim + +import diffstack.utils.tensor_utils as TensorUtils +from diffstack import dynamics as dynamics +from diffstack.utils.batch_utils import batch_utils +from diffstack.utils.geometry_utils import transform_points_tensor, calc_distance_map +from diffstack.utils.l5_utils import get_last_available_index +from diffstack.utils.loss_utils import goal_reaching_loss, trajectory_loss, collision_loss + + +def generate_proxy_mask(orig_loc, radius,mode="L1"): + """ mask out area near the existing samples to boost diversity + + Args: + orig_loc (torch.tensor): original sample location, 1 for sample, 0 for background + radius (int): radius in pixel space + mode (str, optional): Defaults to "L1". 
+ + Returns: + torch.tensor[dtype=torch.bool]: mask for generating new samples + """ + dis_map = calc_distance_map(orig_loc,max_dis = radius+1,mode=mode) + return dis_map<=radius + + +def decode_spatial_prediction(prob_map, residual_yaw_map, num_samples=None, clearance=None): + """ + Decode spatial predictions (e.g., UNet output) to a list of locations + Args: + prob_map (torch.Tensor): probability of each spatial location [B, H, W] + residual_yaw_map (torch.Tensor): location residual (d_x, d_yy) and yaw of each location [B, 3, H, W] + num_samples (int): (optional) if specified, take # of samples according to the discrete prob_map distribution. + default is None, which is to take the max + Returns: + pixel_loc (torch.Tensor): coordinates of each predicted location before applying residual [B, N, 2] + residual_pred (torch.Tensor): residual of each predicted location [B, N, 2] + yaw_pred (torch.Tensor): yaw of each predicted location [B, N, 1] + pixel_prob (torch.Tensor): probability of each sampled prediction [B, N] + """ + # decode map as predictions + b, h, w = prob_map.shape + flat_prob_map = prob_map.flatten(start_dim=1) + if num_samples is None: + # if num_samples is not specified, take the maximum-probability location + pixel_prob, pixel_loc_flat = torch.max(flat_prob_map, dim=1) + pixel_prob = pixel_prob.unsqueeze(1) + pixel_loc_flat = pixel_loc_flat.unsqueeze(1) + else: + # otherwise, use the probability map as a discrete distribution of location predictions + dist = torch.distributions.Categorical(probs=flat_prob_map) + if clearance is None: + pixel_loc_flat = dist.sample((num_samples,)).permute( + 1, 0) # [n_sample, batch] -> [batch, n_sample] + else: + proximity_map = torch.ones_like(prob_map,requires_grad=False) + pixel_loc_flat = list() + for i in range(num_samples): + dist = torch.distributions.Categorical(probs=flat_prob_map*proximity_map.flatten(start_dim=1)) + sample_i = dist.sample() + sample_mask = torch.zeros_like(flat_prob_map,dtype=torch.bool) + sample_mask[torch.arange(b),sample_i]=True + proxy_mask = generate_proxy_mask(sample_mask.reshape(-1,h,w),clearance) + proximity_map = torch.logical_or(proximity_map,torch.logical_not(proxy_mask)) + pixel_loc_flat.append(sample_i) + + pixel_loc_flat = torch.stack(pixel_loc_flat,dim=1) + pixel_prob = torch.gather(flat_prob_map, dim=1, index=pixel_loc_flat) + + local_pred = torch.gather( + input=torch.flatten(residual_yaw_map, 2), # [B, C, H * W] + dim=2, + index=TensorUtils.unsqueeze_expand_at( + pixel_loc_flat, size=3, dim=1) # [B, C, num_samples] + ).permute(0, 2, 1) # [B, C, N] -> [B, N, C] + + residual_pred = local_pred[:, :, 0:2] + yaw_pred = local_pred[:, :, 2:3] + + pixel_loc_x = torch.remainder(pixel_loc_flat, w).float() + pixel_loc_y = torch.floor(pixel_loc_flat.float() / float(w)).float() + pixel_loc = torch.stack((pixel_loc_x, pixel_loc_y), dim=-1) # [B, N, 2] + + return pixel_loc, residual_pred, yaw_pred, pixel_prob + + +def get_spatial_goal_supervision(data_batch): + """Get supervision for training the spatial goal network.""" + b, _, h, w = data_batch["image"].shape # [B, C, H, W] + + # use last available step as goal location + goal_index = get_last_available_index( + data_batch["fut_mask"])[:, None, None] + + # gather by goal index + goal_pos_agent = torch.gather( + data_batch["fut_pos"], # [B, T, 2] + dim=1, + index=goal_index.expand(-1, 1, + data_batch["fut_pos"].shape[-1]) + ) # [B, 1, 2] + + goal_yaw_agent = torch.gather( + data_batch["fut_yaw"], # [B, T, 1] + dim=1, + index=goal_index.expand(-1, 1, 
data_batch["fut_yaw"].shape[-1]) + ) # [B, 1, 1] + + # create spatial supervisions + goal_pos_raster = transform_points_tensor( + goal_pos_agent, + data_batch["raster_from_agent"].float() + ).squeeze(1) # [B, 2] + # make sure all pixels are within the raster image + goal_pos_raster[:, 0] = goal_pos_raster[:, 0].clip(0, w - 1e-5) + goal_pos_raster[:, 1] = goal_pos_raster[:, 1].clip(0, h - 1e-5) + + goal_pos_pixel = torch.floor(goal_pos_raster).float() # round down pixels + # compute rounding residuals (range 0-1) + goal_pos_residual = goal_pos_raster - goal_pos_pixel + # compute flattened pixel location + goal_pos_pixel_flat = goal_pos_pixel[:, 1] * w + goal_pos_pixel[:, 0] + raster_sup_flat = TensorUtils.to_one_hot( + goal_pos_pixel_flat.long(), num_class=h * w) + raster_sup = raster_sup_flat.reshape(b, h, w) + return { + "goal_position_residual": goal_pos_residual, # [B, 2] + "goal_spatial_map": raster_sup, # [B, H, W] + "goal_position_pixel": goal_pos_pixel, # [B, 2] + "goal_position_pixel_flat": goal_pos_pixel_flat, # [B] + "goal_position": goal_pos_agent.squeeze(1), # [B, 2] + "goal_yaw": goal_yaw_agent.squeeze(1), # [B, 1] + "goal_index": goal_index.reshape(b) # [B] + } + + +def get_spatial_trajectory_supervision(data_batch): + """Get supervision for training the learned occupancy metric.""" + b, _, h, w = data_batch["image"].shape # [B, C, H, W] + t = data_batch["fut_pos"].shape[-2] + # create spatial supervisions + pos_raster = transform_points_tensor( + data_batch["fut_pos"], + data_batch["raster_from_agent"].float() + ) # [B, T, 2] + # make sure all pixels are within the raster image + pos_raster[..., 0] = pos_raster[..., 0].clip(0, w - 1e-5) + pos_raster[..., 1] = pos_raster[..., 1].clip(0, h - 1e-5) + + pos_pixel = torch.floor(pos_raster).float() # round down pixels + + # compute flattened pixel location + pos_pixel_flat = pos_pixel[..., 1] * w + pos_pixel[..., 0] + raster_sup_flat = TensorUtils.to_one_hot( + pos_pixel_flat.long(), num_class=h * w) + raster_sup = raster_sup_flat.reshape(b, t, h, w) + return { + "traj_spatial_map": raster_sup, # [B, T, H, W] + "traj_position_pixel": pos_pixel, # [B, T, 2] + "traj_position_pixel_flat": pos_pixel_flat # [B, T] + } + + +def optimize_trajectories( + init_u, + init_x, + target_trajs, + target_avails, + dynamics_model, + step_time: float, + data_batch=None, + goal_loss_weight=1.0, + traj_loss_weight=0.0, + coll_loss_weight=0.0, + num_optim_iterations: int = 50 +): + """An optimization-based trajectory generator""" + curr_u = init_u.detach().clone() + curr_u.requires_grad = True + action_optim = optim.LBFGS( + [curr_u], max_iter=20, lr=1.0, line_search_fn='strong_wolfe') + + for oidx in range(num_optim_iterations): + def closure(): + action_optim.zero_grad() + + # get trajectory with current params + x = dynamics_model.forward_dynamics( + x0=init_x, + u=curr_u, + ) + pos = dynamics_model.state2pos(x) + yaw = dynamics_model.state2yaw(x) + curr_trajs = torch.cat((pos, yaw), dim=-1) + # compute trajectory optimization losses + losses = dict() + losses["goal_loss"] = goal_reaching_loss( + predictions=curr_trajs, + targets=target_trajs, + availabilities=target_avails + ) * goal_loss_weight + losses["traj_loss"] = trajectory_loss( + predictions=curr_trajs, + targets=target_trajs, + availabilities=target_avails + ) * traj_loss_weight + if coll_loss_weight > 0: + assert data_batch is not None + coll_edges = batch_utils().get_edges_from_batch( + data_batch, + ego_predictions=dict(positions=pos, yaws=yaw) + ) + for c in coll_edges: + 
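                    # Edges can extend past the optimization horizon; clip each edge
                    # tensor to the target trajectory length before the collision penalty.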
coll_edges[c] = coll_edges[c][:, :target_trajs.shape[-2]] + vv_edges = dict(VV=coll_edges["VV"]) + if vv_edges["VV"].shape[0] > 0: + losses["coll_loss"] = collision_loss( + vv_edges) * coll_loss_weight + + total_loss = torch.hstack(list(losses.values())).sum() + + # backprop + total_loss.backward() + return total_loss + action_optim.step(closure) + + final_raw_trajs = dynamics_model.forward_dynamics( + x0=init_x, + u=curr_u, + ) + final_pos = dynamics_model.state2pos(final_raw_trajs) + final_yaw = dynamics_model.state2yaw(final_raw_trajs) + final_trajs = torch.cat((final_pos, final_yaw), dim=-1) + losses = dict() + losses["goal_loss"] = goal_reaching_loss( + predictions=final_trajs, + targets=target_trajs, + availabilities=target_avails + ) + losses["traj_loss"] = trajectory_loss( + predictions=final_trajs, + targets=target_trajs, + availabilities=target_avails + ) + + return dict(positions=final_pos, yaws=final_yaw), final_raw_trajs, curr_u, losses + + +def combine_ego_agent_data(batch, ego_keys, agent_keys, mask=None): + assert len(ego_keys) == len(agent_keys) + combined_batch = dict() + for ego_key, agent_key in zip(ego_keys, agent_keys): + if mask is None: + size_dim0 = batch[agent_key].shape[0]*batch[agent_key].shape[1] + combined_batch[ego_key] = torch.cat((batch[ego_key], batch[agent_key].reshape( + size_dim0, *batch[agent_key].shape[2:])), dim=0) + else: + size_dim0 = mask.sum() + combined_batch[ego_key] = torch.cat((batch[ego_key], batch[agent_key][mask].reshape( + size_dim0, *batch[agent_key].shape[2:])), dim=0) + return combined_batch + + +def yaw_from_pos(pos: torch.Tensor, dt, yaw_correction_speed=0.): + """ + Compute yaws from position sequences. Optionally suppress yaws computed from low-velocity steps + + Args: + pos (torch.Tensor): sequence of positions [..., T, 2] + dt (float): delta timestep to compute speed + yaw_correction_speed (float): zero out yaw change when the speed is below this threshold (noisy heading) + + Returns: + accum_yaw (torch.Tensor): sequence of yaws [..., T-1, 1] + """ + + pos_diff = pos[..., 1:, :] - pos[..., :-1, :] + yaw = torch.atan2(pos_diff[..., 1], pos_diff[..., 0]) + delta_yaw = torch.cat((yaw[..., [0]], yaw[..., 1:] - yaw[..., :-1]), dim=-1) + speed = torch.norm(pos_diff, dim=-1) / dt + delta_yaw[speed < yaw_correction_speed] = 0. 
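    # Re-integrate the per-step yaw changes: low-speed steps were zeroed above, so
    # noisy headings from near-stationary frames do not accumulate into the output.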
+ accum_yaw = torch.cumsum(delta_yaw, dim=-1) + return accum_yaw[..., None] diff --git a/diffstack/utils/batch_utils.py b/diffstack/utils/batch_utils.py new file mode 100644 index 0000000..ec5747a --- /dev/null +++ b/diffstack/utils/batch_utils.py @@ -0,0 +1,299 @@ +import torch + +import diffstack.utils.trajdata_utils as av_utils +from diffstack import dynamics as dynamics +from diffstack.configs.base import ExperimentConfig + + +global BATCH_TYPE + +BATCH_TYPE = "trajdata" +# def set_global_batch_type(batch_type): +# global BATCH_TYPE +# assert batch_type in ["trajdata", "l5kit"] +# BATCH_TYPE = batch_type + + +def batch_utils(**kwargs): + if BATCH_TYPE == "trajdata": + return trajdataBatchUtils(**kwargs) + else: + raise NotImplementedError( + "Please set BATCH_TYPE in batch_utils.py to {trajdata, l5kit}" + ) + + +class BatchUtils(object): + """A base class for processing environment-independent batches""" + + def __init__(self, **kwargs): + if "parse" in kwargs: + self.parse = kwargs["parse"] + else: + self.parse = True + if "rasterize_mode" in kwargs: + self.rasterize_mode = kwargs["rasterize_mode"] + else: + self.rasterize_mode = "point" + + @staticmethod + def get_last_available_index(avails): + """ + Args: + avails (torch.Tensor): target availabilities [B, (A), T] + + Returns: + last_indices (torch.Tensor): index of the last available frame + """ + num_frames = avails.shape[-1] + inds = torch.arange(0, num_frames).to(avails.device) # [T] + inds = ( + avails > 0 + ).float() * inds # [B, (A), T] arange indices with unavailable indices set to 0 + last_inds = inds.max(dim=-1)[ + 1 + ] # [B, (A)] calculate the index of the last availale frame + return last_inds + + @staticmethod + def get_current_states(batch: dict, dyn_type: dynamics.DynType) -> torch.Tensor: + """Get the dynamic states of the current timestep""" + bs = batch["curr_speed"].shape[0] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, 6).to( + batch["curr_speed"].device + ) # [x, y, yaw, vel, dh, veh_len] + current_states[:, 3] = batch["curr_speed"].abs() + current_states[:, [4]] = ( + batch["hist_yaw"][:, 0] - batch["hist_yaw"][:, 1] + ).abs() + current_states[:, 5] = batch["extent"][:, 0] # [veh_len] + else: + current_states = torch.zeros(bs, 4).to( + batch["curr_speed"].device + ) # [x, y, vel, yaw] + current_states[:, 2] = batch["curr_speed"] + return current_states + + @classmethod + def get_current_states_all_agents( + cls, batch: dict, step_time, dyn_type: dynamics.DynType + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def parse_batch(data_batch): + raise NotImplementedError + + @staticmethod + def batch_to_raw_all_agents(data_batch, step_time): + raise NotImplementedError + + @staticmethod + def batch_to_target_all_agents(data_batch): + raise NotImplementedError + + @staticmethod + def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): + raise NotImplementedError + + @staticmethod + def generate_edges(raw_type, extents, pos_pred, yaw_pred): + raise NotImplementedError + + @staticmethod + def gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ): + raise NotImplementedError + + @staticmethod + def gen_EC_edges( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + mask=None, + ): + raise NotImplementedError + + @staticmethod + def get_drivable_region_map(rasterized_map): + raise NotImplementedError + + @staticmethod + def get_modality_shapes(cfg: ExperimentConfig): 
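        # Implemented by environment-specific subclasses (see trajdataBatchUtils below),
        # which map the experiment config to the input modality shapes expected by the models.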
+ raise NotImplementedError + + +class trajdataBatchUtils(BatchUtils): + """Batch utils for trajdata""" + + def parse_batch(self, data_batch): + if self.parse: + return av_utils.parse_trajdata_batch(data_batch, self.rasterize_mode) + else: + return data_batch + + @staticmethod + def batch_to_raw_all_agents(data_batch, step_time): + raw_type = torch.cat( + (data_batch["agent_type"].unsqueeze(1), data_batch["neigh_types"]), + dim=1, + ).type(torch.int64) + if data_batch["num_neigh"].max() > 0: + src_pos = torch.cat( + ( + data_batch["hist_pos"].unsqueeze(1), + data_batch["neigh_hist_pos"], + ), + dim=1, + ) + src_yaw = torch.cat( + ( + data_batch["hist_yaw"].unsqueeze(1), + data_batch["neigh_hist_yaw"], + ), + dim=1, + ) + src_mask = torch.cat( + ( + data_batch["hist_mask"].unsqueeze(1), + data_batch["neigh_hist_mask"], + ), + dim=1, + ).bool() + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + data_batch["neigh_extents"][..., :2], + ), + dim=1, + ) + + curr_speed = torch.cat( + (data_batch["curr_speed"].unsqueeze(1), data_batch["neigh_curr_speed"]), + dim=1, + ) + else: + src_pos = data_batch["hist_pos"].unsqueeze(1) + src_yaw = data_batch["hist_yaw"].unsqueeze(1) + src_mask = data_batch["hist_mask"].unsqueeze(1) + extents = data_batch["extent"][..., :2].unsqueeze(1) + curr_speed = data_batch["curr_speed"].unsqueeze(1) + + return { + "hist_pos": src_pos, + "hist_yaw": src_yaw, + "curr_speed": curr_speed, + "raw_types": raw_type, + "hist_mask": src_mask, + "extents": extents, + } + + @staticmethod + def batch_to_target_all_agents(data_batch): + pos = torch.cat( + ( + data_batch["fut_pos"].unsqueeze(1), + data_batch["neigh_fut_pos"], + ), + dim=1, + ) + yaw = torch.cat( + ( + data_batch["fut_yaw"].unsqueeze(1), + data_batch["neigh_fut_yaw"], + ), + dim=1, + ) + avails = torch.cat( + ( + data_batch["fut_mask"].unsqueeze(1), + data_batch["neigh_fut_mask"], + ), + dim=1, + ) + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + data_batch["neigh_extents"][..., :2], + ), + dim=1, + ) + + return {"fut_pos": pos, "fut_yaw": yaw, "fut_mask": avails, "extents": extents} + + @staticmethod + def get_current_states_all_agents( + batch: dict, step_time, dyn_type: dynamics.DynType + ) -> torch.Tensor: + if batch["hist_pos"].ndim == 3: + state_all = trajdataBatchUtils.batch_to_raw_all_agents(batch, step_time) + else: + state_all = batch + bs, na = state_all["curr_speed"].shape[:2] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, na, 6).to( + state_all["curr_speed"].device + ) # [x, y, yaw, vel, dh, veh_len] + current_states[:, :, :2] = state_all["hist_pos"][:, :, -1] + current_states[:, :, 3] = state_all["curr_speed"].abs() + current_states[:, :, [4]] = ( + state_all["hist_yaw"][:, :, -1] - state_all["hist_yaw"][:, :, 1] + ).abs() + current_states[:, :, 5] = state_all["extent"][:, :, -1] # [veh_len] + else: + current_states = torch.zeros(bs, na, 4).to( + state_all["curr_speed"].device + ) # [x, y, vel, yaw] + current_states[:, :, :2] = state_all["hist_pos"][:, :, -1] + current_states[:, :, 2] = state_all["curr_speed"] + current_states[:, :, 3:] = state_all["hist_yaw"][:, :, -1] + return current_states + + @staticmethod + def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): + raise NotImplementedError + + @staticmethod + def generate_edges(raw_type, extents, pos_pred, yaw_pred, batch_first=False): + return av_utils.generate_edges( + raw_type, extents, pos_pred, yaw_pred, batch_first=batch_first + ) + + 
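    # A minimal usage sketch (kept as a comment, not part of the class): composing the
    # helpers above to build agent-agent edges from predictions. `data_batch`, `pos_pred`
    # and `yaw_pred` are assumed to come from a prediction module and are placeholders here.
    #
    #   utils = batch_utils(rasterize_mode="point")
    #   parsed = utils.parse_batch(data_batch)
    #   raw = utils.batch_to_raw_all_agents(parsed, step_time=0.1)
    #   edges = utils.generate_edges(
    #       raw["raw_types"], raw["extents"], pos_pred, yaw_pred
    #   )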
@staticmethod + def gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ): + return av_utils.gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ) + + @staticmethod + def gen_EC_edges( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + mask=None, + ): + return av_utils.gen_EC_edges( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + mask, + ) + + @staticmethod + def get_drivable_region_map(rasterized_map): + return av_utils.get_drivable_region_map(rasterized_map) + + def get_modality_shapes(self, cfg: ExperimentConfig): + return av_utils.get_modality_shapes(cfg, rasterize_mode=self.rasterize_mode) diff --git a/diffstack/utils/bezier_utils.py b/diffstack/utils/bezier_utils.py new file mode 100644 index 0000000..cae9ddd --- /dev/null +++ b/diffstack/utils/bezier_utils.py @@ -0,0 +1,97 @@ +import numpy as np +import math + +from scipy.special import comb, gamma + +def bezier_base(n,i,s): + b=math.factorial(n)/math.factorial(i)/math.factorial(n-i)*s**i*(1-s)**(n-i) + return b + + +def bezier_integral(alpha): + n=len(alpha)-1 + m=0 + for i in range(n+1): + m = m + alpha(1+i) * comb(n,i)*gamma(1+i)*gamma(n+1-i)/gamma(n+2) + return m + +def symbolic_dbezier(n): +# alpha_df=A*alpha_f + A = np.zeros([n,n+1]) + for i in range(0,n): + A[i,i]=-n + A[i,i+1]=n + return A + +class bezier_regressor(object): + def __init__(self, N=6, Nmin=8, Nmax=200,Gamma=1e-2): + self.N = N + self.dB = symbolic_dbezier(N) + self.ddB = symbolic_dbezier(N-1) + self.Nmin = Nmin + self.Nmax = Nmax + self.Bez_matr = dict() + self.bez_reg = dict() + self.Bez_matr_der = dict() + self.Bez_matr_dder = dict() + for i in range(Nmin,Nmax+1): + t=np.linspace(0,1,i) + self.Bez_matr[i]=np.zeros([i,self.N+1]) + self.Bez_matr_der[i]=np.zeros([i,self.N]) + self.Bez_matr_dder[i]=np.zeros([i,self.N-1]) + for j in range(i): + for k in range(self.N+1): + self.Bez_matr[i][j,k]=comb(self.N,k)*(1-t[j])**(self.N-k)*t[j]**k + if k {module} ({name})") - return StockUnpickler.find_class(self, module, name) - - def __init__(self, *args, replace_dict={}, **kwds): - settings = Pickler.settings - _ignore = kwds.pop('ignore', None) - StockUnpickler.__init__(self, *args, **kwds) - self._main = _main_module - self._ignore = settings['ignore'] if _ignore is None else _ignore - self._replace_dict = replace_dict - - def load(self): #NOTE: if settings change, need to update attributes - obj = StockUnpickler.load(self) - if type(obj).__module__ == getattr(_main_module, '__name__', '__main__'): - if not self._ignore: - # point obj class to main - try: obj.__class__ = getattr(self._main, type(obj).__name__) - except (AttributeError,TypeError): pass # defined in a file - #_main_module.__dict__.update(obj.__dict__) #XXX: should update globals ? - return obj - load.__doc__ = StockUnpickler.load.__doc__ - pass - - -if __name__ == '__main__': - - # Needed to find model - sys.path.append('./trajectron/trajectron') - sys.path.append('./') - - args = sys.argv[1:] - - replace_dict = { - "trajectron.trajectron": "diffstack.modules.predictors.trajectron_utils", - "environment": "diffstack.modules.predictors.trajectron_utils.environment", - "model": "diffstack.modules.predictors.trajectron_utils.model", - } - - for filename in args: - new_filename = filename + ".new" - print (f"{filename} -> {new_filename}") - - # Import error can happen here is trying to load checkpoint with old cost_function object. 
- # The old pred_metrics folder needs to be copied with environment/nuScenes_data/cost_functions.py - with open(filename, 'rb') as f: - train_dataset = custom_load(f, replace_dict=replace_dict) - - print ("Loaded") - - with open(new_filename, 'wb') as f: - dill.dump(train_dataset, f) - - print ("Saved") - - print ("done") diff --git a/diffstack/utils/cleanup_checkpoint.py b/diffstack/utils/cleanup_checkpoint.py deleted file mode 100644 index 967e58e..0000000 --- a/diffstack/utils/cleanup_checkpoint.py +++ /dev/null @@ -1,25 +0,0 @@ -import sys -import torch - -# Needed to find model -sys.path.append('./trajectron/trajectron') - -args = sys.argv[1:] - -for filename in args: - new_filename = filename + ".new" - print (f"{filename} -> {new_filename}") - - # Import error can happen here is trying to load checkpoint with old cost_function object. - # The old pred_metrics folder needs to be copied with environment/nuScenes_data/cost_functions.py - model_dict = torch.load(filename, map_location=torch.device('cpu')) - - print (model_dict) - - if "planner_cost" in model_dict: - print ("Removing planner_cost") - del model_dict["planner_cost"] - - torch.save(model_dict, new_filename) - -print ("done") \ No newline at end of file diff --git a/diffstack/utils/config.py b/diffstack/utils/config.py new file mode 100644 index 0000000..cd57411 --- /dev/null +++ b/diffstack/utils/config.py @@ -0,0 +1,59 @@ +"""Basic config file""" +import yaml +from easydict import EasyDict +from typing import Dict + +class Config: + """Basic config class""" + + def __init__(self, cfg_path): + """Load config file""" + + with open(cfg_path, "r", encoding="utf-8") as file_handle: + self.yml_dict = EasyDict(yaml.safe_load(file_handle)) + + # format the config for print + with open(cfg_path, "r", encoding="utf-8") as file_handle: + self.format_str = file_handle.read().splitlines() + + def __getattribute__(self, name): + """Retrieve a value from the config""" + + yml_dict = super().__getattribute__("yml_dict") + if name in yml_dict: + return yml_dict[name] + + return super().__getattribute__(name) + + def __setattr__(self, name, value): + """Set a value from the config""" + + try: + yml_dict = super().__getattribute__("yml_dict") + except AttributeError: + return super().__setattr__(name, value) + + if name in yml_dict: + yml_dict[name] = value + return None + + return super().__setattr__(name, value) + + def get(self, name, default=None): + """Retrieve a value from the config""" + + if hasattr(self, name): + return getattr(self, name) + + return default + + +class DictConfig(Config): + + def __init__(self, cfg_dict: Dict): + """Load config file""" + + self.yml_dict = EasyDict(cfg_dict) + + # format the config for print + self.format_str = "" diff --git a/diffstack/utils/config_utils.py b/diffstack/utils/config_utils.py new file mode 100644 index 0000000..738c2d9 --- /dev/null +++ b/diffstack/utils/config_utils.py @@ -0,0 +1,132 @@ +import json +import hydra +import os + +import diffstack +from diffstack.configs.registry import get_registered_experiment_config +from diffstack.configs.base import AlgoConfig, ExperimentConfig +from diffstack.configs.config import Dict +from omegaconf import DictConfig +from pathlib import Path + + +def load_hydra_config(file_name: str) -> DictConfig: + fullpath = Path(file_name) + config_path = Path(diffstack.__path__[0]).parent / "config" + config_name = ( + fullpath.absolute().relative_to(config_path.absolute()).with_suffix("") + ) + # Need it to be relative to this file, not the working 
directory. + rel_config_path = os.path.relpath(config_path, Path(__file__).parent) + + # Initialize configuration management system + hydra.core.global_hydra.GlobalHydra.instance().clear() # reinitialize hydra if already initialized + hydra.initialize(config_path=str(rel_config_path)) + cfg = hydra.compose(config_name=str(config_name), overrides=[]) + + return cfg + + +####### Functions below are related to tbsim style config classes + + +def recursive_update(config: Dict, external_dict: dict): + leftover = list() + for k, v in external_dict.items(): + if k in config and isinstance(config[k], dict): + config[k] = recursive_update(config[k], v) + else: + leftover.append(k) + + leftover_dict = {k: external_dict[k] for k in leftover} + config.update(**leftover_dict) + return config + + +def recursive_update_flat(config: Dict, external_dict: dict): + left_over = dict() + for k, v in external_dict.items(): + assert not isinstance(v, dict) + if k in config: + assert not isinstance(config[k], dict) + config[k] = v + else: + left_over[k] = v + if len(left_over) > 0: + for k, v in config.items(): + if isinstance(v, dict): + config[k], leftover = recursive_update_flat(v, left_over) + return config, left_over + + +def get_experiment_config_from_file(file_path, locked=False): + ext_cfg = json.load(open(file_path, "r")) + cfg = get_registered_experiment_config(ext_cfg["registered_name"]) + cfg = recursive_update(cfg, ext_cfg) + cfg.lock(locked) + return cfg + + +def translate_trajdata_cfg(cfg: ExperimentConfig): + rcfg = Dict() + # assert cfg.stack.step_time == 0.5 # TODO: support interpolation + + if ( + "predictor" in cfg.stack + and "scene_centric" in cfg.stack["predictor"] + and cfg.stack["predictor"].scene_centric + ): + rcfg.centric = "scene" + else: + rcfg.centric = "agent" + if "standardize_data" in cfg.env.data_generation_params: + rcfg.standardize_data = cfg.env.data_generation_params.standardize_data + else: + rcfg.standardize_data = True + if "predictor" in cfg.stack: + rcfg.step_time = cfg.stack["predictor"].step_time + rcfg.history_num_frames = cfg.stack["predictor"].history_num_frames + rcfg.future_num_frames = cfg.stack["predictor"].future_num_frames + elif "planner" in cfg.stack: + rcfg.step_time = cfg.stack["planner"].step_time + rcfg.history_num_frames = cfg.stack["planner"].history_num_frames + rcfg.future_num_frames = cfg.stack["planner"].future_num_frames + if "remove_parked" in cfg.env: + rcfg.remove_parked = cfg.env.remove_parked + rcfg.trajdata_source_root = cfg.train.trajdata_source_root + rcfg.trajdata_val_source_root = cfg.train.trajdata_val_source_root + rcfg.trajdata_source_train = cfg.train.trajdata_source_train + rcfg.trajdata_source_valid = cfg.train.trajdata_source_valid + rcfg.trajdata_source_test = cfg.train.trajdata_source_test + rcfg.trajdata_test_source_root = cfg.train.trajdata_test_source_root + rcfg.dataset_path = cfg.train.dataset_path + + rcfg.max_agents_distance = cfg.env.data_generation_params.max_agents_distance + rcfg.num_other_agents = cfg.env.data_generation_params.other_agents_num + rcfg.max_agents_distance_simulation = cfg.env.simulation.distance_th_close + rcfg.pixel_size = cfg.env.rasterizer.pixel_size + rcfg.raster_size = int(cfg.env.rasterizer.raster_size) + rcfg.raster_center = cfg.env.rasterizer.ego_center + rcfg.yaw_correction_speed = cfg.env.data_generation_params.yaw_correction_speed + rcfg.incl_neighbor_map = cfg.env.incl_neighbor_map + rcfg.incl_vector_map = cfg.env.incl_vector_map + rcfg.incl_raster_map = cfg.env.incl_raster_map + 
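    # The remaining lane-graph options are read with cfg.env.get(...) defaults, so
    # experiment configs that omit these keys still translate without errors.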
rcfg.calc_lane_graph = cfg.env.calc_lane_graph + rcfg.other_agents_num = cfg.env.data_generation_params.other_agents_num + rcfg.max_num_lanes = cfg.env.get("max_num_lanes", 15) + rcfg.remove_single_successor = cfg.env.get("remove_single_successor", False) + rcfg.num_lane_pts = cfg.env.get("num_lane_pts", 20) + if "vectorize_lane" in cfg.env.data_generation_params: + rcfg.vectorize_lane = cfg.env.data_generation_params.vectorize_lane + + else: + rcfg.vectorize_lane = "None" + + rcfg.lock() + return rcfg + + +def boolean_string(s): + if s not in {"False", "True", "0", "1", "false", "true", "on", "off"}: + raise ValueError("Not a valid boolean string") + return s in ["True", "1", "true", "on"] diff --git a/diffstack/utils/diffusion_utils.py b/diffstack/utils/diffusion_utils.py new file mode 100644 index 0000000..bfa70e1 --- /dev/null +++ b/diffstack/utils/diffusion_utils.py @@ -0,0 +1,256 @@ +""" +Various utilities for neural networks. +""" + +import math + +import torch as torch +import torch.nn as nn +import numpy as np + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor, mask=None): + """ + Take the mean over all non-batch dimensions. + """ + if mask is not None: + tensor = tensor * mask + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(8, channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. 
+ :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def mean_flat(tensor, mask=None): + """ + Take the mean over all non-batch dimensions. + """ + if mask is not None: + tensor = tensor * mask + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +def mean_sum(tensor, mask): + """ + Take the mean over all non-batch dimensions. + """ + if mask is not None: + tensor = tensor * mask + return (tensor*mask).sum()/mask.sum().clip(min=1) + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. 
+ """ + return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = torch.where( + x < -0.999, + log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/diffstack/utils/dist_utils.py b/diffstack/utils/dist_utils.py new file mode 100644 index 0000000..b660507 --- /dev/null +++ b/diffstack/utils/dist_utils.py @@ -0,0 +1,680 @@ +import torch +from torch.nn.functional import one_hot, gumbel_softmax +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.dynamics.base import Dynamics +from diffstack.dynamics.unicycle import Unicycle +import abc +import numpy as np +import diffstack.utils.geometry_utils as GeoUtils + + +def categorical_sample(pi, num_sample): + bs, M = pi.shape + pi0 = torch.cat([torch.zeros_like(pi[:, :1]), pi[:, :-1]], -1) + pi_cumsum = pi0.cumsum(-1) + rand_num = torch.rand([bs, num_sample], device=pi.device) + flag = rand_num.unsqueeze(-1) > pi_cumsum.unsqueeze(-2) + idx = torch.argmax((flag * torch.arange(0, M, device=pi.device)[None, None]), -1) + return idx + + +def categorical_psample_wor( + logpi, num_sample, num_factor=None, factor_mask=None, relevance_score=None +): + """pseodo-sample (pick the most probable modes) from a categorical distribution without replacement + the distribution is assumed to be a factorized categorical distribution + + Args: + logpi (probability): [B,M,D], M is the number of factors, D is the number of categories + num_sample (_type_): number of samples + num_factor (int): If not None, pick the top num_factor factors with the highest probability variation + + """ + B, M, D = logpi.shape + if num_factor is None: + num_factor = M + # assert D**num_factor>=num_sample + if num_factor != M: + if relevance_score is None: + # default to use the maximum mode probability as measure + relevance_score = -logpi.max(-1)[0] + if factor_mask is not None: + relevance_score.masked_fill_( + torch.logical_not(factor_mask), relevance_score.min() - 1 + ) + idx = torch.topk(relevance_score, num_factor, dim=1)[1] + else: + idx = torch.arange(M, device=logpi.device)[None, :].repeat_interleave(B, 0) + factor_logpi = torch.gather( + logpi, 1, idx.unsqueeze(-1).repeat_interleave(D, -1) + ) # Factors are chosen + # calculate the product log probability of the factors + num_factor = factor_logpi.shape[1] + + prod_logpi = sum( + [ + factor_logpi[:, i].reshape(B, *[1] * i, D, *[1] * (num_factor - i - 1)) + for i in range(num_factor) + ] + ) # D**num_factor + + prod_logpi_flat = prod_logpi.view(B, -1) + factor_samples = 
torch.topk(prod_logpi_flat, num_sample, 1)[1] + factor_sample_idx = list() + for i in range(num_factor): + factor_sample_idx.append(factor_samples % D) + factor_samples = torch.div(factor_samples, D, rounding_mode="floor") + factor_sample_idx = torch.stack(factor_sample_idx, -1).flip(-1) + # for unselected factors, pick the maximum likelihood mode + samples = logpi.argmax(-1).unsqueeze(1).repeat_interleave(num_sample, 1) + samples.scatter_( + 2, idx.unsqueeze(1).repeat_interleave(num_sample, 1), factor_sample_idx + ) + + return samples, idx + + # nonfactor_pi = torch.gather(pi,1,nonidx.unsqueeze(-1).repeat_interleave(D,-1)) + + +class BaseDist(abc.ABC): + @abc.abstractmethod + def rsample(self, num_sample): + pass + + @abc.abstractmethod + def get_dist(self): + pass + + @abc.abstractmethod + def detach_(self): + pass + + @abc.abstractmethod + def index_select_(self, idx): + pass + + +class MAGaussian(BaseDist): + def __init__(self, mu, var, K, delta_x_clamp=1.0, min_var=1e-4): + """multiagent gaussian distribution + + Args: + mu (torch.Tensor): [B,N,D] mean + var (torch.Tensor): [B,N,D] variance + K (torch.Tensor): [B,N,D,L] coefficient of shared scene variance + var = var + K @ K.T + """ + self.mu = mu + self.var = var + self.K = K + self.L = K.shape[-1] + self.min_var = min_var + self.delta_x_clamp = delta_x_clamp + + def rsample(self, num_sample): + """sample from the distribution + + Args: + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,D] sample + """ + B, N, D = self.mu.shape + eps = torch.randn(B, num_sample, N, D, device=self.mu.device) + scene_eps = torch.randn(B, num_sample, 1, self.L, device=self.mu.device) + return ( + self.mu.unsqueeze(-3) + + eps * self.var.sqrt().unsqueeze(-3) + + (self.K.unsqueeze(1) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + + def get_dist(self, mask): + B, N, D = self.mu.shape + + var_diag = TensorUtils.block_diag_from_cat( + torch.diag_embed(self.var + self.min_var) + ) + var_inv_diag = torch.linalg.pinv(var_diag) + C = torch.eye(self.K.shape[-1], device=self.K.device)[None].repeat_interleave( + B, 0 + ) + K = self.K.reshape(B, N * D, self.L) + K_T = K.transpose(-1, -2) + var = var_diag + K @ K_T + if self.L > N * D: + var_inv = ( + var_inv_diag + - var_inv_diag + @ K + @ torch.linalg.pinv(C + K_T @ var_inv_diag @ K) + @ K_T + @ var_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return self.mu, var, var_inv + + def get_log_likelihood(self, xp, mask): + # not tested + B = self.mu.shape[0] + mu, var, var_inv = self.get_dist(mask) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu - self.delta_x_clamp), mu + self.delta_x_clamp + ) + delta_x = xp - mu + delta_x *= mask + delta_x = delta_x.reshape(B, -1) + mask_var = torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.mu.shape[-1], -1) + ) + log_prob = 0.5 * ( + torch.logdet(var_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_inv @ delta_x.unsqueeze(-1)).flatten() + + np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + return log_prob + + def detach_(self): + self.mu = self.mu.detach() + self.var = self.var.detach() + self.K = self.K.detach() + + def index_select_(self, idx): + self.mu = self.mu[idx] + self.var = self.varCategorical[idx] + self.K = self.K[idx] + + +class MADynGaussian(BaseDist): + def __init__(self, mu_u, var_u, var_x, K, dyn, delta_x_clamp=1.0, min_var=1e-4): + """multiagent gaussian distribution with dynamics, input follows MAGaussian + + Args: + mu_u 
(torch.Tensor): [B,N,udim] mean + var_u (torch.Tensor): [B,N,udim] independent variance of u + var_x (torch.Tensor): [B,N,xdim] independent variance of x + K (torch.Tensor): [B,N,D,L] coefficient of shared scene variance + dyn (Dynamics): dynamics object + var = var + K @ K.T + """ + self.mu_u = mu_u + self.var_u = var_u + self.var_x = var_x + self.K = K + self.L = K.shape[-1] + self.dyn = dyn + self.min_var = min_var + self.delta_x_clamp = delta_x_clamp + + def rsample(self, x0, num_sample): + """sample from the distribution + + Args: + x0 (torch.Tensor): [B,N,xdim] initial state + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,xdim] sample + """ + B, N, D = self.mu_u.shape + eps = torch.randn(B, num_sample, N, D, device=self.mu_u.device) + scene_eps = torch.randn(B, num_sample, 1, self.L, device=self.mu_u.device) + u_sample = ( + self.mu_u.unsqueeze(1) + + eps * self.var_u.sqrt().unsqueeze(1) + + (self.K.unsqueeze(1) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + x_sample = self.dyn.step( + x0.unsqueeze(1).repeat_interleave(num_sample, 1), u_sample + ) + x_sample += self.var_x.sqrt().unsqueeze(1) * torch.randn_like(x_sample) + return x_sample + + def get_dist(self, x0, mask): + B, N, D = self.mu_u.shape + + mu_x, _, jacu = self.dyn.step(x0, self.mu_u, bound=False, return_jacobian=True) + if mask is not None: + jacu *= mask.unsqueeze(-1) + # var_x_total = J@var_u@J.T+var_x + var_x = jacu @ torch.diag_embed(self.var_u) @ jacu.transpose( + -1, -2 + ) + torch.diag_embed(self.var_x + self.min_var) + + var_x_inv = torch.linalg.pinv(var_x) + var_x_diag = TensorUtils.block_diag_from_cat(var_x) + var_x_inv_diag = TensorUtils.block_diag_from_cat(var_x_inv) + blk_jacu = TensorUtils.block_diag_from_cat(jacu) + K = self.K.reshape(B, N * D, self.L) + + JK = blk_jacu @ K + JK_T = JK.transpose(-1, -2) + C = torch.eye(self.K.shape[-1], device=self.K.device)[None].repeat_interleave( + B, 0 + ) + + var = var_x_diag + JK @ JK_T + + if self.L > N * D: + # Woodbury matrix identity + var_inv = ( + var_x_inv_diag + - var_x_inv_diag + @ JK + @ torch.linalg.pinv(C + JK_T @ var_x_inv_diag @ JK) + @ JK_T + @ var_x_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return mu_x, var, var_inv + + def get_log_likelihood(self, x0, xp, mask): + B, N = self.mu_u.shape[:2] + mu_x, total_var_x, var_x_inv = self.get_dist(x0, mask) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu_x - self.delta_x_clamp), mu_x + self.delta_x_clamp + ) + # # ground truth + delta_x = xp - mu_x + delta_x *= mask + mask_var = torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.dyn.xdim, -1) + ) + # hack: write delta function for each dynamics + delta_x[..., 3] = GeoUtils.round_2pi(delta_x[..., 3]) + delta_x = delta_x.reshape(B, -1) + log_prob = 0.5 * ( + torch.logdet(var_x_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_x_inv @ delta_x.unsqueeze(-1)).flatten() + + np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + return log_prob + + def detach_(self): + self.mu_u = self.mu_u.detach() + self.var_u = self.var_u.detach() + self.var_x = self.var_x.detach() + self.K = self.K.detach() + + def index_select_(self, idx): + self.mu_u = self.mu_u[idx] + self.var_u = self.var_u[idx] + self.var_x = self.var_x[idx] + self.K = self.K[idx] + + +class MAGMM(BaseDist): + def __init__(self, mu, var, K, pi, delta_x_clamp=1.0, min_var=1e-4): + """multiagent gaussian distribution + + Args: + mu (torch.Tensor): [B,M,N,D] mean + var 
(torch.Tensor): [B,M,N,D] variance + K (torch.Tensor): [B,M,N,D,L] coefficient of shared scene variance + var = var + K @ K.T + pi: [B,M] mixture weights + """ + self.mu = mu + self.var = var + self.K = K + self.L = K.shape[-1] + self.M = mu.shape[1] + self.pi = pi + self.logits = torch.log(pi) + self.pi_sum = pi.cumsum(-1) + self.delta_x_clamp = delta_x_clamp + self.min_var = min_var + + def rsample(self, num_sample, tau=None, infer=False): + """sample from the distribution + + Args: + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,D] sample + """ + B, M, N, D = self.mu.shape + if tau is not None: + mode = gumbel_softmax( + self.logits[:, None].repeat_interleave(num_sample, 1), tau, hard=infer + ) + else: + mode = categorical_sample(self.pi, num_sample) + mode = one_hot(mode, self.M) + eps = torch.randn(B, M, num_sample, N, D, device=self.mu.device) + scene_eps = torch.randn(B, M, num_sample, 1, self.L, device=self.mu.device) + sample = ( + self.mu.unsqueeze(-3) + + eps * self.var.sqrt().unsqueeze(-3) + + (self.K.unsqueeze(2) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + sample = (sample * mode.transpose(1, 2).view(B, M, num_sample, 1, 1)).sum( + 1 + ) # B x num_sample x N x D + return sample + + def get_dist(self, mask): + B, M, N, D = self.mu.shape + var_tiled, mu_tiled, K_tiled = TensorUtils.join_dimensions( + (self.var, self.mu, self.K), 0, 2 + ) + var_diag = TensorUtils.block_diag_from_cat( + torch.diag_embed(var_tiled + self.min_var) + ) + var_inv_diag = torch.linalg.pinv(var_diag) + C = torch.eye(K_tiled.shape[-1], device=self.K.device)[None].repeat_interleave( + B * M, 0 + ) + K = K_tiled.reshape(B * M, N * D, self.L) + K_T = K.transpose(-1, -2) + var = var_diag + K @ K_T + if self.L > N * D: + var_inv = ( + var_inv_diag + - var_inv_diag + @ K + @ torch.linalg.pinv(C + K_T @ var_inv_diag @ K) + @ K_T + @ var_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return ( + self.mu, + var.reshape(B, M, N * D, N * D), + var_inv.reshape(B, M, N * D, N * D), + self.pi, + ) + + def get_log_likelihood(self, xp, mask): + # not tested + B, M = self.mu.shape[:2] + mu, var, var_inv, pi = self.get_dist(mask) + xp = xp[:, None].repeat_interleave(M, 1) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu - self.delta_x_clamp), mu + self.delta_x_clamp + ) + delta_x = xp - mu + + delta_x *= mask.unsqueeze(1) + delta_x = delta_x.reshape(B, M, -1) + mask_var = ( + torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.mu.shape[-1], -1) + ) + .unsqueeze(1) + .repeat_interleave(self.M, 1) + ) + log_prob_mode = 0.5 * ( + torch.logdet(var_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_inv @ delta_x.unsqueeze(-1)).reshape(B, M) + + np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + log_prob = torch.log((torch.exp(log_prob_mode) * pi).sum(-1)) + return log_prob + + def detach_(self): + self.mu = self.mu.detach() + self.var = self.var.detach() + self.K = self.K.detach() + self.pi = self.pi.detach() + + def index_select_(self, idx): + self.mu = self.mu[idx] + self.var = self.var[idx] + self.K = self.K[idx] + self.pi = self.pi[idx] + + +class MADynGMM(BaseDist): + def __init__(self, mu_u, var_u, var_x, K, pi, dyn, delta_x_clamp=1.0, min_var=1e-6): + """multiagent gaussian distribution with dynamics, input follows MAGaussian + + Args: + mu_u (torch.Tensor): [B,M,N,udim] mean + var_u (torch.Tensor): [B,M,N,udim] independent variance of u + var_x (torch.Tensor): [B,M,N,xdim] 
independent variance of x + K (torch.Tensor): [B,M,N,D,L] coefficient of shared scene variance + pi (torch.Tensor): [B,M] mixture weights + dyn (Dynamics): dynamics object + var = var + K @ K.T + """ + self.mu_u = mu_u + self.var_u = var_u + self.var_x = var_x + self.K = K + self.L = K.shape[-1] + self.M = mu_u.shape[1] + self.pi = pi + self.logits = torch.log(pi) + self.dyn = dyn + self.delta_x_clamp = delta_x_clamp + self.min_var = min_var + + def rsample(self, x0, num_sample, tau=None, infer=False): + """sample from the distribution + + Args: + sample_shape (torch.Size, optional): [description]. Defaults to torch.Size(). + + Returns: + torch.Tensor: [B,N,D] sample + """ + B, M, N, udim = self.mu_u.shape + if tau is not None: + mode = gumbel_softmax( + self.logits[:, None].repeat_interleave(num_sample, 1), tau, hard=infer + ) + else: + mode = categorical_sample(self.pi, num_sample) + mode = one_hot(mode, self.M) + eps = torch.randn(B, M, num_sample, N, udim, device=self.mu_u.device) + scene_eps = torch.randn(B, M, num_sample, 1, self.L, device=self.mu_u.device) + u_sample = ( + self.mu_u.unsqueeze(2) + + eps * self.var_u.sqrt().unsqueeze(2) + + (self.K.unsqueeze(2) @ scene_eps.unsqueeze(-1)).squeeze(-1) + ) + x_sample = self.dyn.step( + x0[:, None, None] + .repeat_interleave(self.M, 1) + .repeat_interleave(num_sample, 2), + u_sample, + ) + x_sample += self.var_x.sqrt().unsqueeze(2) * torch.randn_like(x_sample) + x_sample = (x_sample * (mode.transpose(1, 2).view(B, M, num_sample, 1, 1))).sum( + 1 + ) + return x_sample + + def get_dist(self, x0, mask): + B, M, N, udim = self.mu_u.shape + var_x_tiled, var_u_tiled, mu_u_tiled, K_tiled = TensorUtils.join_dimensions( + (self.var_x, self.var_u, self.mu_u, self.K), 0, 2 + ) + x0_tiled = x0.repeat_interleave(M, 0) + mu_x, _, jacu = self.dyn.step( + x0_tiled, mu_u_tiled, bound=False, return_jacobian=True + ) + # var_x_total = J@var_u@J.T+var_x + var_x = jacu @ torch.diag_embed(var_u_tiled) @ jacu.transpose( + -1, -2 + ) + torch.diag_embed(var_x_tiled + self.min_var) + + var_x_inv = torch.linalg.pinv(var_x) + var_x_diag = TensorUtils.block_diag_from_cat(var_x) + var_x_inv_diag = TensorUtils.block_diag_from_cat(var_x_inv) + blk_jacu = TensorUtils.block_diag_from_cat(jacu) + K = K_tiled.reshape(B * M, N * udim, self.L) + JK = blk_jacu @ K + JK_T = JK.transpose(-1, -2) + C = torch.eye(K_tiled.shape[-1], device=self.K.device)[None].repeat_interleave( + B * M, 0 + ) + + var = var_x_diag + JK @ JK_T + if self.L > N * udim: + # Woodbury matrix identity + var_inv = ( + var_x_inv_diag + - var_x_inv_diag + @ JK + @ torch.linalg.pinv(C + JK_T @ var_x_inv_diag @ JK) + @ JK_T + @ var_x_inv_diag + ) + else: + var_inv = torch.linalg.pinv(var) + return ( + *TensorUtils.reshape_dimensions((mu_x, var, var_inv), 0, 1, (B, M)), + self.pi, + ) + + def get_log_likelihood(self, x0, xp, mask): + B, M, N = self.mu_u.shape[:3] + mu_x, total_var_x, var_x_inv, pi = self.get_dist(x0, mask) + + xp = xp[:, None].repeat_interleave(M, 1) + if self.delta_x_clamp is not None: + xp = torch.minimum( + torch.maximum(xp, mu_x - self.delta_x_clamp), mu_x + self.delta_x_clamp + ) + # # ground truth + delta_x = xp - mu_x + + delta_x *= mask.unsqueeze(1) + delta_x = delta_x.reshape(B, M, -1) + mask_var = ( + torch.diag_embed( + (1 - mask).squeeze(-1).repeat_interleave(self.dyn.xdim, -1) + ) + .unsqueeze(1) + .repeat_interleave(self.M, 1) + ) + log_prob_mode = 0.5 * ( + torch.logdet(var_x_inv + mask_var) + - (delta_x.unsqueeze(-2) @ var_x_inv @ delta_x.unsqueeze(-1)).reshape(B, M) + + 
np.log(2 * np.pi) + ).nan_to_num(nan=0.0, posinf=0.0, neginf=0.0) + # log_prob = torch.log((torch.exp(log_prob_mode)*pi).sum(-1)) + # Apply EM: + log_prob = (log_prob_mode * pi).sum(-1) + return log_prob + + def detach_(self): + self.mu_u = self.mu_u.detach() + self.var_u = self.var_u.detach() + self.var_x = self.var_x.detach() + self.K = self.K.detach() + + def index_select_(self, idx): + self.mu_u = self.mu_u[idx] + self.var_u = self.var_u[idx] + self.var_x = self.var_x[idx] + self.K = self.K[idx] + + +# class MADynEBM(BaseDist): +# """Multiagent Energy-based model with dynamics. + +# """ +# def __init__(self,dyn) + + +def test_ma(): + bs = 5 + N = 3 + d = 2 + L = 10 + mu = torch.randn(bs, N, d) + var = torch.randn(bs, N, d) ** 2 + K = torch.randn(bs, N, d, L) + dist = MAGaussian(mu, var, K) + sample = dist.rsample(15) + var_inv = dist.get_dist() + print("done") + + +def test_ma_dyn(): + bs = 5 + N = 3 + udim = 2 + xdim = 4 + L = 10 + dt = 0.1 + mu_u = torch.randn(bs, N, udim) + var_u = torch.randn(bs, N, udim) ** 2 + var_x = torch.randn(bs, N, xdim) ** 2 + K = torch.randn(bs, N, udim, L) + dyn = Unicycle(dt) + dist = MADynGaussian(mu_u, var_u, var_x, K, dyn) + x0 = torch.randn([bs, N, dyn.xdim]) + sample = dist.rsample(x0, 15) + var_inv = dist.get_dist(x0) + + +def test_categorical(): + pi = torch.tensor([0.1, 0.3, 0.6])[None].repeat_interleave(6, 0) + categorical_sample(pi, 10) + + +def test_GMM(): + D = 2 + B = 10 + M = 3 + N = 5 + L = 16 + + mu = torch.randn([B, M, N, D]) + var = torch.randn([B, M, N, D]) ** 2 + K = torch.randn([B, M, N, D, L]) + pi = torch.rand([B, M]) + pi = pi / pi.sum(-1, keepdim=True) + dist = MAGMM(mu, var, K, pi) + sample = dist.rsample(10) + _ = dist.get_dist() + x = torch.randn([B, N, D]) + log_prob = dist.get_log_likelihood(x) + + +def test_dyn_GMM(): + udim = 2 + xdim = 4 + dt = 0.1 + B = 10 + M = 3 + N = 5 + L = 16 + dyn = Unicycle(dt) + + mu_u = torch.randn(B, M, N, udim) + var_u = torch.randn(B, M, N, udim) ** 2 + var_x = torch.randn(B, M, N, xdim) ** 2 + K = torch.randn([B, M, N, udim, L]) + pi = torch.rand([B, M]) + pi = pi / pi.sum(-1, keepdim=True) + dist = MADynGMM(mu_u, var_u, var_x, K, pi, dyn) + + x = torch.randn([B, N, xdim]) + xp = torch.randn([B, N, xdim]) + _ = dist.get_dist(x) + sample = dist.rsample(x, 10, tau=0.5) + log_prob = dist.get_log_likelihood(x, xp) + print("done") + + +def test_sample_wor(): + pi = torch.randn([5, 6, 7]) + pi = pi**2 + pi = pi / pi.sum(-1, keepdim=True) + sample = categorical_sample_wor(pi, 20, 3) + + +if __name__ == "__main__": + test_sample_wor() diff --git a/diffstack/utils/env_utils.py b/diffstack/utils/env_utils.py new file mode 100644 index 0000000..90f2a35 --- /dev/null +++ b/diffstack/utils/env_utils.py @@ -0,0 +1,478 @@ +from typing import OrderedDict +import numpy as np +import pytorch_lightning as pl +import torch +import importlib +import os +from imageio import get_writer + +from diffstack.sim.TBSIM.base import BatchedEnv, BaseEnv +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.timer import Timers +from diffstack.sim.env_builders import EnvNuscBuilder +from diffstack.policies.wrappers import RolloutWrapper, Pos2YawWrapper +import diffstack.utils.geometry_utils as GeoUtils +from diffstack.utils.trajdata_utils import parse_trajdata_batch +from diffstack.utils.geometry_utils import VEH_VEH_collision +from trajdata.simulation import SimulationScene +import random + + +def collision_check(agents_posyaw, new_posyaw, agents_extent, new_extent): + new_posyaw_tiled = 
new_posyaw[np.newaxis, :].repeat(agents_posyaw.shape[0], 0) + new_extent_tiled = new_extent[np.newaxis, :].repeat(agents_posyaw.shape[0], 0) + dis = VEH_VEH_collision( + new_posyaw_tiled, agents_posyaw, new_extent_tiled, agents_extent + ) + return dis + + +def random_placing_neighbors(simscene, num_neighbors, coll_check=True): + init_modes = [0, 1, 2, 3, 4] + random.shuffle(init_modes) + init_modes = init_modes[:num_neighbors] + offset_x = 18.0 + offset_y = 5.0 + T = 10 + v_sigma = 0.3 + + dt = simscene.scene.dt + obs = simscene.get_obs() + if isinstance(simscene, SimulationScene): + obs = parse_trajdata_batch(obs) + + obs = TensorUtils.to_numpy(obs, ignore_if_unspecified=True) + num_new_agent = 0 + agent_names = [agent.name for agent in simscene.agents] + while "agent" + str(num_new_agent) in agent_names: + num_new_agent += 1 + ego_vel = obs["curr_speed"][0] + agent_plan = list() + for i in range(num_neighbors): + newagent_name = "agent" + str(num_new_agent + i) + newagent_type = 1 + newagent_extent = np.array([4.0, 2.5, 2.0]) + if init_modes[i] == 0: + # in front of the ego vehicle + newagent_state = np.array([[offset_x, 0, 0]]).repeat(T, 0) + vel = np.clip(ego_vel - 2.0 + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 1: + # behind of the ego vehicle + newagent_state = np.array([[-offset_x, 0, 0]]).repeat(T, 0) + vel = np.clip(ego_vel + 2.0 + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 2: + # left of the ego vehicle + newagent_state = np.array([[0, -offset_y, 0]]).repeat(T, 0) + vel = np.clip(ego_vel + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 3: + # right of the ego vehicle + newagent_state = np.array([[0, offset_y, 0]]).repeat(T, 0) + vel = np.clip(ego_vel + np.random.randn() * v_sigma, 0.0, 40.0) + elif init_modes[i] == 4: + # two vehicle length ahead of the ego vehicle + newagent_state = np.array([[2 * offset_x, 0, 0]]).repeat(T, 0) + vel = np.clip(ego_vel - 4.0 + np.random.randn() * v_sigma, 0.0, 40.0) + + newagent_state[:, 0] += np.arange(-T + 1, 1) * dt * vel + + add_flag = True + new_pos_global = GeoUtils.batch_nd_transform_points_np( + newagent_state[:, :2], obs["world_from_agent"][0] + ) + new_yaw_global = newagent_state[:, 2:] + obs["world_yaw"][0] + newagent_state_global = np.hstack((new_pos_global, new_yaw_global)) + if coll_check: + if "centroid" in obs: + agents_pos_global = obs["centroid"] + else: + agents_pos_global = GeoUtils.batch_nd_transform_points_np( + obs["history_positions"][:, -1], obs["world_from_agent"] + ) + agents_yaw_global = obs["world_yaw"] + agents_posyaw = np.hstack( + (agents_pos_global, agents_yaw_global[:, np.newaxis]) + ) + new_posyaw = np.hstack((new_pos_global[0], new_yaw_global[0])) + + dis = collision_check( + agents_posyaw, new_posyaw, obs["extent"], newagent_extent + ) + if dis.min() < 2.0: + add_flag = False + + if add_flag: + agent_plan.append( + dict( + name=newagent_name, + agent_state=newagent_state_global.tolist(), + initial_timestep=simscene.scene_ts - T + 1, + agent_type=newagent_type, + extent=newagent_extent.tolist(), + executed=False, + ) + ) + return agent_plan + + +def random_initial_adjust_plan(env, adjust_recipe): + adjust_plan = dict() + for simscene in env._current_scenes: + adjust_plan[simscene.scene.name] = dict( + remove_existing_neighbors=dict( + flag=adjust_recipe["remove_existing_neighbors"], executed=False + ), + agents=random_placing_neighbors( + simscene, adjust_recipe["initial_num_neighbors"] + ), + ) + + return adjust_plan + + +def rollout_episodes( + env, 
+ policy, + num_episodes, + skip_first_n=1, + n_step_action=1, + render=False, + scene_indices=None, + start_frame_index_each_episode=None, + device=None, + obs_to_torch=True, + adjust_plan_recipe=None, + horizon=None, + seed_each_episode=None, +): + """ + Rollout an environment for a number of episodes + Args: + env (BaseEnv): a base simulation environment (gym-like) + policy (RolloutWrapper): a policy that controls agents in the environment + num_episodes (int): number of episodes to rollout for + skip_first_n (int): number of steps to skip at the begining + n_step_action (int): number of steps to take between querying models + render (bool): if True, return a sequence of rendered frames + scene_indices (tuple, list): (Optional) scenes indices to rollout with + start_frame_index_each_episode (List): (Optional) which frame to start each simulation episode from, + device: device to cast observation to + obs_to_torch: whether to cast observation to torch + adjust_plan_recipe (dict): (Optional) initialization condition, either a fixed plan or a recipe for random generation + horizon (int): (Optional) override horizon of the simulation + seed_each_episode (List): (Optional) a list of seeds, one for each episode + + Returns: + stats (dict): A dictionary of rollout stats for each episode (metrics, rewards, etc.) + info (dict): A dictionary of environment info for each episode + renderings (list): A list of rendered frames in the form of np.ndarray, one for each episode + """ + stats = {} + info = {} + renderings = [] + is_batched_env = isinstance(env, BatchedEnv) + timers = Timers() + adjust_plans = list() + if seed_each_episode is not None: + assert len(seed_each_episode) == num_episodes + if start_frame_index_each_episode is not None: + assert len(start_frame_index_each_episode) == num_episodes + + ego_policy = policy.unwrap()["Rollout.ego_policy"] + trace = list() + for ei in range(num_episodes): + if start_frame_index_each_episode is not None: + start_frame_index = start_frame_index_each_episode[ei] + else: + start_frame_index = None + + env.reset(scene_indices=scene_indices, start_frame_index=start_frame_index) + if adjust_plan_recipe is not None: + if "random_init_plan" in adjust_plan_recipe: + # recipe provided + if adjust_plan_recipe["random_init_plan"]: + adjust_recipe = adjust_plan_recipe + adjust_plan = random_initial_adjust_plan(env, adjust_recipe) + + else: + adjust_plan = None + adjust_recipe = None + + else: + # explicit plan provided + adjust_plan = adjust_plan_recipe + else: + adjust_plan = None + adjust_recipe = None + if adjust_plan is not None: + env.adjust_scene(adjust_plan) + + if seed_each_episode is not None: + env.update_random_seed(seed_each_episode[ei]) + + done = env.is_done() + counter = 0 + step_since_last_update = 0 + frames = list() + while not done: + if adjust_recipe is not None: + if step_since_last_update > adjust_recipe["num_frame_per_new_agent"]: + for simscene in env._current_scenes: + if simscene.scene_ts < simscene.scene.length_timesteps - 10: + extra_agent = random_placing_neighbors(simscene, 1) + adjust_plan[simscene.scene.name]["agents"] += extra_agent + env.adjust_scene(adjust_plan) + step_since_last_update = 0 + timers.tic("step") + with timers.timed("obs"): + obs = env.get_observation() + with timers.timed("to_torch"): + if obs_to_torch: + device = policy.device if device is None else device + obs_torch = TensorUtils.to_torch( + obs, device=device, ignore_if_unspecified=True + ) + else: + obs_torch = obs + + with timers.timed("network"): + action = 
policy.get_action(obs_torch, step_index=counter) + + if counter < skip_first_n: + # use GT action for the first N steps to warm up environment state (velocity, etc.) + gt_action = env.get_gt_action(obs) + action.ego = gt_action.ego + action.agents = gt_action.agents + env.step(action, num_steps_to_take=1, render=False) + counter += 1 + step_since_last_update += 1 + else: + with timers.timed("env_step"): + ims = env.step( + action, num_steps_to_take=n_step_action, render=render + ) # List of [num_scene, h, w, 3] + if render: + frames.extend(ims) + counter += n_step_action + step_since_last_update += n_step_action + timers.toc("step") + # print(timers) + + done = env.is_done() + + if horizon is not None and counter >= horizon: + break + metrics = env.get_metrics() + if hasattr(ego_policy, "savetrace") and ego_policy.savetrace: + trace.append(ego_policy.trace.copy()) + + for k, v in metrics.items(): + if k not in stats: + stats[k] = [] + if is_batched_env: # concatenate by scene + stats[k] = np.concatenate([stats[k], v], axis=0) + else: + stats[k].append(v) + + env_info = env.get_info() + for k, v in env_info.items(): + if k not in info: + if isinstance(v, dict): + info[k] = dict() + else: + info[k] = list() + + if is_batched_env: + if isinstance(v, dict): + info[k].update(v) + else: + info[k].extend(v) + else: + info[k].append(v) + del env_info + if hasattr(ego_policy, "reset"): + ego_policy.reset() + if render: + frames = np.stack(frames) + if is_batched_env: + # [step, scene] -> [scene, step] + frames = frames.transpose((1, 0, 2, 3, 4)) + renderings.append(frames) + if adjust_plan is not None: + adjust_plans.append(adjust_plan) + + multi_episodes_metrics = env.get_multi_episode_metrics() + stats.update(multi_episodes_metrics) + env.reset_multi_episodes_metrics() + + return stats, info, renderings, adjust_plans, trace + + +class RolloutCallback(pl.Callback): + """A pytorch-lightning callback function that runs rollouts during training""" + + def __init__( + self, + exp_config, + every_n_steps=100, + warm_start_n_steps=1, + verbose=False, + save_video=False, + video_dir=None, + ): + self._every_n_steps = every_n_steps + self._warm_start_n_steps = warm_start_n_steps + self._verbose = verbose + self._exp_cfg = exp_config.clone() + self._save_video = save_video + self._video_dir = video_dir + self.env = None + self.policy = None + self._eval_cfg = self._exp_cfg.eval + + def print_if_verbose(self, msg): + if self._verbose: + print(msg) + + def _get_env(self, device): + if self.env is not None: + return self.env + if self._eval_cfg.env == "nusc": + env_builder = EnvNuscBuilder( + eval_config=self._eval_cfg, exp_config=self._exp_cfg, device=device + ) + + else: + raise NotImplementedError( + "{} is not a valid env".format(self._eval_cfg.env) + ) + + env = env_builder.get_env() + self.env = env + return self.env + + def _get_policy(self, pl_module: pl.LightningModule): + if self.policy is not None: + return self.policy + policy_composers = importlib.import_module("tbsim.evaluation.policy_composers") + + composer_class = getattr(policy_composers, self._eval_cfg.eval_class) + composer = composer_class( + self._eval_cfg, pl_module.device, ckpt_root_dir=self._eval_cfg.ckpt_root_dir + ) + print("Building composer {}".format(self._eval_cfg.eval_class)) + + if self._exp_cfg.algo.name == "ma_rasterized": + policy, _ = composer.get_policy(predictor=pl_module) + else: + policy, _ = composer.get_policy(policy=pl_module) + + if self._eval_cfg.policy.pos_to_yaw: + policy = Pos2YawWrapper( + policy, + 
dt=self._exp_cfg.algo.step_time, + yaw_correction_speed=self._eval_cfg.policy.yaw_correction_speed, + ) + + if self._eval_cfg.env == "nusc": + rollout_policy = RolloutWrapper(agents_policy=policy) + elif self._eval_cfg.ego_only: + rollout_policy = RolloutWrapper(ego_policy=policy) + else: + rollout_policy = RolloutWrapper(ego_policy=policy, agents_policy=policy) + + self.policy = rollout_policy + return self.policy + + def _run_rollout(self, pl_module: pl.LightningModule, global_step: int): + rollout_policy = self._get_policy(pl_module) + env = self._get_env(pl_module.device) + + scene_i = 0 + eval_scenes = self._eval_cfg.eval_scenes + + result_stats = None + + while scene_i < len(eval_scenes): + scene_indices = eval_scenes[ + scene_i : scene_i + self._eval_cfg.num_scenes_per_batch + ] + scene_i += self._eval_cfg.num_scenes_per_batch + stats, info, renderings, _, _ = rollout_episodes( + env, + rollout_policy, + num_episodes=self._eval_cfg.num_episode_repeats, + n_step_action=self._eval_cfg.n_step_action, + render=self._save_video, + skip_first_n=self._eval_cfg.skip_first_n, + scene_indices=scene_indices, + start_frame_index_each_episode=self._eval_cfg.start_frame_index_each_episode, + seed_each_episode=self._eval_cfg.seed_each_episode, + horizon=self._eval_cfg.num_simulation_steps, + ) + + if result_stats is None: + result_stats = stats + result_stats["scene_index"] = np.array(info["scene_index"]) + else: + for k in stats: + result_stats[k] = np.concatenate( + [result_stats[k], stats[k]], axis=0 + ) + result_stats["scene_index"] = np.concatenate( + [result_stats["scene_index"], np.array(info["scene_index"])] + ) + + if self._save_video: + for ei, episode_rendering in enumerate(renderings): + for i, scene_images in enumerate(episode_rendering): + video_fn = "{}_{}_{}.mp4".format( + global_step, info["scene_index"][i], ei + ) + + writer = get_writer( + os.path.join(self._video_dir, video_fn), fps=10 + ) + print( + "video to {}".format( + os.path.join(self._video_dir, video_fn) + ) + ) + for im in scene_images: + writer.append_data(im) + writer.close() + return result_stats + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs, + batch, + batch_idx, + unused=0, + ) -> None: + should_run = ( + trainer.global_step >= self._warm_start_n_steps + and trainer.global_step % self._every_n_steps == 0 + ) + if should_run: + try: + self.print_if_verbose( + "\nStep %i rollout (%i episodes): " + % (trainer.global_step, len(self._eval_cfg.eval_scenes)) + ) + + stats = self._run_rollout(pl_module, trainer.global_step) + for k, v in stats.items(): + if "ttf" in k: # avoid cluttering the plot + continue + # Set on_step=True and on_epoch=False to force the logger to log stats at the step + # See https://github.com/PyTorchLightning/pytorch-lightning/issues/9772 for explanation + pl_module.log( + "rollout/" + k, np.mean(v), on_step=True, on_epoch=False + ) + self.print_if_verbose(("rollout/" + k, np.mean(v))) + self.print_if_verbose("\n") + except Exception as e: + print("Rollout failed because:") + print(e) diff --git a/diffstack/utils/experiment_utils.py b/diffstack/utils/experiment_utils.py new file mode 100644 index 0000000..5b8b266 --- /dev/null +++ b/diffstack/utils/experiment_utils.py @@ -0,0 +1,393 @@ +import json +import os +import itertools +from collections import namedtuple +from typing import List +from glob import glob +import subprocess +import shutil +from pathlib import Path + +import diffstack +from diffstack.configs.registry import 
get_registered_experiment_config +from diffstack.configs.config import Dict +from diffstack.configs.eval_config import EvaluationConfig +from diffstack.configs.base import ExperimentConfig + + +class Param(namedtuple("Param", "config_var alias value")): + pass + + +class ParamRange(namedtuple("Param", "config_var alias range")): + def linearize(self): + return [Param(self.config_var, self.alias, v) for v in self.range] + + def __len__(self): + return len(self.range) + + +class ParamConfig(object): + def __init__(self, params: List[Param] = None): + self.params = [] + self.aliases = [] + self.config_vars = [] + print(params) + if params is not None: + for p in params: + self.add(p) + + def add(self, param: Param): + assert param.config_var not in self.config_vars + assert param.alias not in self.aliases + self.config_vars.append(param.config_var) + self.aliases.append(param.alias) + self.params.append(param) + + def __str__(self): + char_to_remove = [" ", "(", ")", ";", "[", "]"] + name = [] + for p in self.params: + v_str = str(p.value) + for c in char_to_remove: + v_str = v_str.replace(c, "") + name.append(p.alias + v_str) + + return "_".join(name) + + def generate_config(self, base_cfg: Dict): + cfg = base_cfg.clone() + for p in self.params: + var_list = p.config_var.split(".") + c = cfg + # traverse the indexing list + for v in var_list[:-1]: + assert v in c, "{} is not a valid config variable".format(p.config_var) + c = c[v] + assert var_list[-1] in c, "{} is not a valid config variable".format( + p.config_var + ) + c[var_list[-1]] = p.value + cfg.name = str(self) + return cfg + + +class ParamSearchPlan(object): + def __init__(self): + self.param_configs = [] + self.const_params = [] + + def add_const_param(self, param: Param): + self.const_params.append(param) + + def add(self, param_config: ParamConfig): + for c in self.const_params: + param_config.add(c) + self.param_configs.append(param_config) + + def extend(self, param_configs: List[ParamConfig]): + for pc in param_configs: + self.add(pc) + + @staticmethod + def compose_concate(param_ranges: List[ParamRange]): + pcs = [] + for pr in param_ranges: + for p in pr.linearize(): + pcs.append(ParamConfig([p])) + return pcs + + @staticmethod + def compose_cartesian(param_ranges: List[ParamRange]): + """Cartesian product among parameters""" + prs = [pr.linearize() for pr in param_ranges] + return [ParamConfig(pr) for pr in itertools.product(*prs)] + + @staticmethod + def compose_zip(param_ranges: List[ParamRange]): + l = len(param_ranges[0]) + assert all( + len(pr) == l for pr in param_ranges + ), "All param_range must be the same length" + prs = [pr.linearize() for pr in param_ranges] + return [ParamConfig(prz) for prz in zip(*prs)] + + def generate_configs(self, base_cfg: Dict): + """ + Generate configs from the parameter search plan, also rename the experiment by generating the correct alias. 
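+
+        Example (illustrative sketch only; the ``config_var`` strings below are
+        hypothetical and must correspond to fields that exist in ``base_cfg``):
+
+            plan = ParamSearchPlan()
+            plan.add_const_param(Param("train.training.num_steps", "steps", 100000))
+            plan.extend(
+                ParamSearchPlan.compose_cartesian(
+                    [
+                        ParamRange("train.training.lr", "lr", [1e-4, 1e-3]),
+                        ParamRange("train.training.batch_size", "bs", [32, 64]),
+                    ]
+                )
+            )
+            cfgs = plan.generate_configs(base_cfg)  # 4 configs, one per (lr, bs) pair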
+ """ + if len(self.param_configs) > 0: + return [pc.generate_config(base_cfg) for pc in self.param_configs] + else: + # constant-only + const_cfg = ParamConfig(self.const_params) + return [const_cfg.generate_config(base_cfg)] + + +def create_configs( + configs_to_search_fn, + config_name, + config_file, + config_dir, + prefix, + delete_config_dir=True, +): + if config_name is not None: + cfg = get_registered_experiment_config(config_name) + print("Generating configs for {}".format(config_name)) + elif config_file is not None: + # Update default config with external json file + ext_cfg = json.load(open(config_file, "r")) + cfg = get_registered_experiment_config(ext_cfg["registered_name"]) + cfg.update(**ext_cfg) + print("Generating configs with {} as template".format(config_file)) + else: + raise FileNotFoundError("No base config is provided") + + configs = configs_to_search_fn(base_cfg=cfg) + for c in configs: + pfx = "{}_".format(prefix) if prefix is not None else "" + c.name = pfx + c.name + config_fns = [] + + if delete_config_dir and os.path.exists(config_dir): + shutil.rmtree(config_dir) + os.makedirs(config_dir, exist_ok=True) + for c in configs: + fn = os.path.join(config_dir, "{}.json".format(c.name)) + config_fns.append(fn) + print("Saving config to {}".format(fn)) + c.dump(fn) + + return configs, config_fns + + +def read_configs(config_dir): + configs = [] + config_fns = [] + for cfn in glob(config_dir + "/*.json"): + print(cfn) + config_fns.append(cfn) + ext_cfg = json.load(open(cfn, "r")) + c = get_registered_experiment_config(ext_cfg["registered_name"]) + c.update(**ext_cfg) + configs.append(c) + return configs, config_fns + + +def create_evaluation_configs( + configs_to_search_fn, + config_dir, + cfg, + prefix=None, + delete_config_dir=True, +): + configs = configs_to_search_fn(base_cfg=cfg) + for c in configs: + if prefix is not None: + c.name = prefix + "_" + c.name + + config_fns = [] + + if delete_config_dir and os.path.exists(config_dir): + shutil.rmtree(config_dir) + os.makedirs(config_dir, exist_ok=True) + for c in configs: + fn = os.path.join(config_dir, "{}.json".format(c.name)) + config_fns.append(fn) + print("Saving config to {}".format(fn)) + c.dump(fn) + + return configs, config_fns + + +def read_evaluation_configs(config_dir): + configs = [] + config_fns = [] + for cfn in glob(config_dir + "/*.json"): + print(cfn) + config_fns.append(cfn) + c = EvaluationConfig() + ext_cfg = json.load(open(cfn, "r")) + c.update(**ext_cfg) + configs.append(c) + return configs, config_fns + + +def upload_codebase_to_ngc_workspace(ngc_config): + """ + Upload local codebase to NGC workspace + Args: + ngc_config (dict): NGC config + + """ + ngc_path = os.path.join(ngc_config["workspace_mounting_point_local"], "diffstack/") + local_path = Path(diffstack.__path__[0]).parent + assert os.path.exists(ngc_path), "please mount NGC path first" + dir_list = ["scripts/", "diffstack/"] + for d in dir_list: + print("uploading {}".format(d)) + shutil.copytree( + os.path.join(local_path, d), os.path.join(ngc_path, d), dirs_exist_ok=True + ) + file_list = ["setup.py"] + for f in file_list: + print("uploading {}".format(f)) + shutil.copy(os.path.join(local_path, f), os.path.join(ngc_path, f)) + + +def launch_experiments_local(script_path, cfgs, cfg_paths, extra_args=[]): + for cfg, cpath in zip(cfgs, cfg_paths): + cmd = ["python", script_path, "--config_file", cpath] + extra_args + subprocess.run(cmd) + + +def get_results_info_ngc(ngc_job_id): + cmd = ["ngc", "result", "info", str(ngc_job_id), 
"--files"] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + outs, errs = process.communicate() + if len(errs) > 0: + print(str(errs)) + return None + outs = str(outs).split("\\n") + ckpt_paths = [l.strip(" ") for l in outs if l.endswith(".ckpt")] + cfg_path = [l.strip(" ") for l in outs if l.endswith(".json")] + assert len(cfg_path) == 1 + cfg_path = cfg_path[0] + job_name = cfg_path.split("/")[1] + return ckpt_paths, cfg_path, job_name + + +def _download_from_ngc(ngc_job_id, paths_to_download, target_dir, tmp_dir="/tmp"): + cmd = ["ngc", "result", "download", str(ngc_job_id)] + print("Downloading: ") + for fp in paths_to_download: + print(fp) + cmd.extend(["--file", fp]) + + cmd.extend(["--dest", tmp_dir]) + if os.path.exists(os.path.join(tmp_dir, ngc_job_id + "/")): + print("tmp folder with ngc job ID exists, removing ...") + shutil.rmtree( + os.path.join(tmp_dir, ngc_job_id + "/") + ) # otherwise ngc renames the downloaded folder + subprocess.run(cmd) + + os.makedirs(target_dir, exist_ok=True) + + for fp in paths_to_download: + src_path = os.path.join(tmp_dir, ngc_job_id + fp) + shutil.move(src_path, target_dir) + + shutil.rmtree(os.path.join(tmp_dir, ngc_job_id + "/")) + + +def download_checkpoints_from_ngc( + ngc_job_id, ckpt_root_dir, ckpt_path_func=None, tmp_dir="/tmp" +): + assert os.path.exists(ckpt_root_dir) + ckpt_paths, cfg_path, job_name = get_results_info_ngc(ngc_job_id) + + if ckpt_path_func is None: + + def ckpt_path_func(x): + return x + + to_download = ckpt_path_func(ckpt_paths) + to_download.append(cfg_path) + ckpt_target_dir = os.path.join(ckpt_root_dir, "{}_{}".format(job_name, ngc_job_id)) + + _download_from_ngc(ngc_job_id, to_download, ckpt_target_dir, tmp_dir=tmp_dir) + return ckpt_target_dir + + +def get_local_ngc_checkpoint_dir(ngc_job_id, ckpt_root_dir): + for p in glob(ckpt_root_dir + "/*"): + if str(ngc_job_id) == p.split("_")[-1] or str(ngc_job_id) == p.split("/")[-1]: + return p + return None + + +def get_checkpoint( + ckpt_key, + ngc_job_id=None, + ckpt_dir=None, + ckpt_root_dir="checkpoints/", + download_tmp_dir="/tmp", +): + """ + Get checkpoint and config path given either a ngc job ID or a local dir. + + If a @ngc_job_id is specified, the function will first look for a directory that ends with @ngc_job_id inside + @ckpt_root_dir. E.g., if @ngc_job_id == `1234567`, and @ckpt_root_dir == "checkpoints/", the function will look for + a directory "checkpoints/*_1234567", and within the directory a `.ckpt` file that contains @ckpt_key. + If no such directory or checkpoint file exists, it will try to download the checkpoint from + NGC under the result directory of a job and save it locally such that it will not need to download it again the + next time that this function is invoked. + + If a @ckpt_dir is specified, the function will look for the directory locally and return the ckpt that contains + @ckpt_key, as well as its config.json. + + Args: + ckpt_key (str): a string that uniquely identifies a checkpoint file with a directory, e.g., `iter50000.ckpt` + ngc_job_id (str): (Optional) ngc job ID of the checkpoint if the training was done on NGC. + ckpt_dir (str): (Optional) a local directory that contains the specified checkpoint + ckpt_root_dir (str): (Optional) a directory that the function will look for checkpoints downloaded from NGC + download_tmp_dir (str): a temporary storage for the checkpoint. 
+ + Returns: + ckpt_path (str): path to a checkpoint file + cfg_path (str): path to a config.json file + """ + + def ckpt_path_func(paths): + return [p for p in paths if str(ckpt_key) in p] + + if ngc_job_id is None or len(str(ngc_job_id)) == 0: + local_dir = ckpt_dir + assert ckpt_dir is not None + else: + local_dir = get_local_ngc_checkpoint_dir(ngc_job_id, ckpt_root_dir) + + if local_dir is None: + print("checkpoint does not exist, downloading ...") + ckpt_dir = download_checkpoints_from_ngc( + ngc_job_id=str(ngc_job_id), + ckpt_root_dir=ckpt_root_dir, + ckpt_path_func=ckpt_path_func, + tmp_dir=download_tmp_dir, + ) + else: + ckpt_paths = glob(local_dir + "/**/*.ckpt", recursive=True) + if len(ckpt_path_func(ckpt_paths)) == 0: + if ngc_job_id is not None: + print("checkpoint does not exist, downloading ...") + ckpt_dir = download_checkpoints_from_ngc( + ngc_job_id=str(ngc_job_id), + ckpt_root_dir=ckpt_root_dir, + ckpt_path_func=ckpt_path_func, + tmp_dir=download_tmp_dir, + ) + else: + raise FileNotFoundError( + "Cannot find checkpoint in {} with key {}".format( + local_dir, ckpt_key + ) + ) + else: + ckpt_dir = local_dir + + ckpt_paths = ckpt_path_func(glob(ckpt_dir + "/**/*.ckpt", recursive=True)) + assert len(ckpt_paths) > 0, "Could not find a checkpoint that has key {}".format( + ckpt_key + ) + assert len(ckpt_paths) == 1, "More than one checkpoint found {}".format(ckpt_paths) + cfg_path = glob(ckpt_dir + "/**/config.json", recursive=True)[0] + print("Checkpoint path: {}".format(ckpt_paths[0])) + print("Config path: {}".format(cfg_path)) + return ckpt_paths[0], cfg_path + + +if __name__ == "__main__": + # print(get_checkpoint("2546043", ckpt_key="iter87999_")) + pass diff --git a/diffstack/utils/fp16_util.py b/diffstack/utils/fp16_util.py new file mode 100644 index 0000000..23e0418 --- /dev/null +++ b/diffstack/utils/fp16_util.py @@ -0,0 +1,76 @@ +""" +Helpers to train with 16-bit precision. +""" + +import torch.nn as nn +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + l.bias.data = l.bias.data.float() + + +def make_master_params(model_params): + """ + Copy model parameters into a (differently-shaped) list of full-precision + parameters. + """ + master_params = _flatten_dense_tensors( + [param.detach().float() for param in model_params] + ) + master_params = nn.Parameter(master_params) + master_params.requires_grad = True + return [master_params] + + +def model_grads_to_master_grads(model_params, master_params): + """ + Copy the gradients from the model parameters into the master parameters + from make_master_params(). + """ + master_params[0].grad = _flatten_dense_tensors( + [param.grad.data.detach().float() for param in model_params] + ) + + +def master_params_to_model_params(model_params, master_params): + """ + Copy the master parameter data back into the model parameters. + """ + # Without copying to a list, if a generator is passed, this will + # silently not copy any parameters. 
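+    #
+    # For reference, a rough sketch of how these helpers are typically ordered in an
+    # fp16 training step (illustrative only, not mandated by this module):
+    #   master = make_master_params(model_params)            # once, at setup
+    #   <forward / backward pass producing fp16 grads>
+    #   model_grads_to_master_grads(model_params, master)    # copy grads to fp32 master
+    #   <optimizer step on `master`>
+    #   master_params_to_model_params(model_params, master)  # sync back (this function)
+    #   zero_grad(model_params)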
+ model_params = list(model_params) + + for param, master_param in zip( + model_params, unflatten_master_params(model_params, master_params) + ): + param.detach().copy_(master_param) + + +def unflatten_master_params(model_params, master_params): + """ + Unflatten the master parameters to look like model_params. + """ + return _unflatten_dense_tensors(master_params[0].detach(), model_params) + + +def zero_grad(model_params): + for param in model_params: + # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group + if param.grad is not None: + param.grad.detach_() + param.grad.zero_() diff --git a/diffstack/utils/geometry_utils.py b/diffstack/utils/geometry_utils.py new file mode 100644 index 0000000..d2b980d --- /dev/null +++ b/diffstack/utils/geometry_utils.py @@ -0,0 +1,567 @@ +import numpy as np + +import torch +from diffstack.utils.tensor_utils import round_2pi +from enum import IntEnum + +# Expose some functions from trajdata +from trajdata.utils.arr_utils import ( + transform_matrices, + batch_nd_transform_points_pt, + batch_nd_transform_points_np, +) + + +def get_box_agent_coords(pos, yaw, extent): + corners = (torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5).to( + pos.device + ) * (extent.unsqueeze(-2)) + s = torch.sin(yaw).unsqueeze(-1) + c = torch.cos(yaw).unsqueeze(-1) + rotM = torch.cat((torch.cat((c, s), dim=-1), torch.cat((-s, c), dim=-1)), dim=-2) + rotated_corners = (corners + pos.unsqueeze(-2)) @ rotM + return rotated_corners + + +def get_box_world_coords(pos, yaw, extent): + corners = (torch.tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5).to( + pos.device + ) * (extent.unsqueeze(-2)) + s = torch.sin(yaw).unsqueeze(-1) + c = torch.cos(yaw).unsqueeze(-1) + rotM = torch.cat((torch.cat((c, s), dim=-1), torch.cat((-s, c), dim=-1)), dim=-2) + rotated_corners = corners @ rotM + pos.unsqueeze(-2) + return rotated_corners + + +def get_box_agent_coords_np(pos, yaw, extent): + corners = (np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5) * ( + extent[..., None, :] + ) + s = np.sin(yaw)[..., None] + c = np.cos(yaw)[..., None] + rotM = np.concatenate( + (np.concatenate((c, s), axis=-1), np.concatenate((-s, c), axis=-1)), axis=-2 + ) + rotated_corners = (corners + pos[..., None, :]) @ rotM + return rotated_corners + + +def extents_to_corners(extents_lw: torch.Tensor, xyh: torch.Tensor) -> torch.Tensor: + """ + Args: + extents_lw: (..., 2) (length, width) + xyh: (..., 3) + Returns: + box_points_xy (..., 4, 2) + + """ + assert extents_lw.shape[-1] == 2 and xyh.shape[-1] == 3 + assert extents_lw.ndim == xyh.ndim + extents_lw = extents_lw.float() + + rel_points = torch.tensor( + [[0.5, 0.5], [-0.5, 0.5], [-0.5, -0.5], [0.5, -0.5]], + dtype=torch.float, + device=extents_lw.device, + ) + rel_points = extents_lw.unsqueeze(-2) * rel_points.view( + ([1] * (extents_lw.ndim - 1)) + [4, 2] + ) + + xy, h = torch.split(xyh, (2, 1), dim=-1) + tf = transform_matrices(h.squeeze(-1).double(), xy.double()) + box_points_xy = batch_nd_transform_points_pt(rel_points.double(), tf) + return box_points_xy.type_as(xyh) + + +def get_box_world_coords_np(pos, yaw, extent): + corners = (np.array([[-1, -1], [-1, 1], [1, 1], [1, -1]]) * 0.5) * ( + extent[..., None, :] + ) + s = np.sin(yaw)[..., None] + c = np.cos(yaw)[..., None] + rotM = np.concatenate( + (np.concatenate((c, s), axis=-1), np.concatenate((-s, c), axis=-1)), axis=-2 + ) + rotated_corners = corners @ rotM + pos[..., None, :] + return rotated_corners + + +def get_upright_box(pos, extent): + 
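    """Axis-aligned (upright) box for zero-yaw agents.
+
+    Given centers ``pos`` (..., 2) and extents ``extent`` (..., 2), returns the
+    (min, max) corner pair of each box, with shape (..., 2, 2).
+    """
+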
yaws = torch.zeros(*pos.shape[:-1], 1).to(pos.device) + boxes = get_box_world_coords(pos, yaws, extent) + upright_boxes = boxes[..., [0, 2], :] + return upright_boxes + + +def batch_nd_transform_points(points, Mat): + ndim = Mat.shape[-1] - 1 + Mat = Mat.transpose(-1, -2) + return (points.unsqueeze(-2) @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ + ..., -1:, :ndim + ].squeeze(-2) + + +def transform_points_tensor( + points: torch.Tensor, transf_matrix: torch.Tensor +) -> torch.Tensor: + """ + Transform a set of 2D/3D points using the given transformation matrix. + Assumes row major ordering of the input points. The transform function has 3 modes: + - points (N, F), transf_matrix (F+1, F+1) + all points are transformed using the matrix and the output points have shape (N, F). + - points (B, N, F), transf_matrix (F+1, F+1) + all sequences of points are transformed using the same matrix and the output points have shape (B, N, F). + transf_matrix is broadcasted. + - points (B, N, F), transf_matrix (B, F+1, F+1) + each sequence of points is transformed using its own matrix and the output points have shape (B, N, F). + Note this function assumes points.shape[-1] == matrix.shape[-1] - 1, which means that last + rows in the matrices do not influence the final results. + For 2D points only the first 2x3 parts of the matrices will be used. + + :param points: Input points of shape (N, F) or (B, N, F) + with F = 2 or 3 depending on input points are 2D or 3D points. + :param transf_matrix: Transformation matrix of shape (F+1, F+1) or (B, F+1, F+1) with F = 2 or 3. + :return: Transformed points of shape (N, F) or (B, N, F) depending on the dimensions of the input points. + """ + points_log = f" received points with shape {points.shape} " + matrix_log = f" received matrices with shape {transf_matrix.shape} " + + assert points.ndim in [2, 3], f"points should have ndim in [2,3],{points_log}" + assert transf_matrix.ndim in [ + 2, + 3, + ], f"matrix should have ndim in [2,3],{matrix_log}" + assert ( + points.ndim >= transf_matrix.ndim + ), f"points ndim should be >= than matrix,{points_log},{matrix_log}" + + points_feat = points.shape[-1] + assert points_feat in [2, 3], f"last points dimension must be 2 or 3,{points_log}" + assert ( + transf_matrix.shape[-1] == transf_matrix.shape[-2] + ), f"matrix should be a square matrix,{matrix_log}" + + matrix_feat = transf_matrix.shape[-1] + assert matrix_feat in [3, 4], f"last matrix dimension must be 3 or 4,{matrix_log}" + assert ( + points_feat == matrix_feat - 1 + ), f"points last dim should be one less than matrix,{points_log},{matrix_log}" + + def _transform(points: torch.Tensor, transf_matrix: torch.Tensor) -> torch.Tensor: + num_dims = transf_matrix.shape[-1] - 1 + transf_matrix = torch.permute(transf_matrix, (0, 2, 1)) + return ( + points @ transf_matrix[:, :num_dims, :num_dims] + + transf_matrix[:, -1:, :num_dims] + ) + + if points.ndim == transf_matrix.ndim == 2: + points = torch.unsqueeze(points, 0) + transf_matrix = torch.unsqueeze(transf_matrix, 0) + return _transform(points, transf_matrix)[0] + + elif points.ndim == transf_matrix.ndim == 3: + return _transform(points, transf_matrix) + + elif points.ndim == 3 and transf_matrix.ndim == 2: + transf_matrix = torch.unsqueeze(transf_matrix, 0) + return _transform(points, transf_matrix) + else: + raise NotImplementedError(f"unsupported case!{points_log},{matrix_log}") + + +def PED_PED_collision(p1, p2, S1, S2): + if isinstance(p1, torch.Tensor): + return ( + torch.linalg.norm(p1[..., 0:2] - p2[..., 0:2], dim=-1) 
+            - (S1[..., 0] + S2[..., 0]) / 2
+        )
+
+    elif isinstance(p1, np.ndarray):
+        return (
+            np.linalg.norm(p1[..., 0:2] - p2[..., 0:2], axis=-1)
+            - (S1[..., 0] + S2[..., 0]) / 2
+        )
+    else:
+        raise NotImplementedError
+
+
+def transform_xy_coordinate_frame(traj_xy, anchor_xyh):
+    if isinstance(traj_xy, np.ndarray):
+        delta_x, delta_y, delta_h = np.split(anchor_xyh, 3, axis=-1)
+        x, y = np.split(traj_xy, 2, axis=-1)
+        # Move poses to origin
+        x = x - delta_x
+        y = y - delta_y
+        # Rotate around new origin with -delta_h
+        traj_xy = batch_rotate_2D(np.concatenate((x, y), -1), -delta_h)
+    elif isinstance(traj_xy, torch.Tensor):
+        delta_x, delta_y, delta_h = torch.split(anchor_xyh, 1, dim=-1)
+        x, y = torch.split(traj_xy, 1, dim=-1)
+        # Move poses to origin
+        x = x - delta_x
+        y = y - delta_y
+        # Rotate around new origin with -delta_h
+        traj_xy = batch_rotate_2D(torch.cat((x, y), -1), -delta_h)
+    return traj_xy
+
+
+def transform_xyh_coordinate_frame(traj_xyh, anchor_xyh):
+    if isinstance(traj_xyh, np.ndarray):
+        delta_x, delta_y, delta_h = np.split(anchor_xyh, 3, axis=-1)
+        xy, h = np.split(traj_xyh, (2,), axis=-1)
+        xy = transform_xy_coordinate_frame(xy, anchor_xyh)
+        h = h - delta_h
+        traj_xyh = np.concatenate((xy, h), axis=-1)
+    elif isinstance(traj_xyh, torch.Tensor):
+        delta_x, delta_y, delta_h = torch.split(anchor_xyh, 1, dim=-1)
+        xy, h = torch.split(traj_xyh, (2, 1), dim=-1)
+        xy = transform_xy_coordinate_frame(xy, anchor_xyh)
+        h = h - delta_h
+        traj_xyh = torch.cat((xy, h), dim=-1)
+    return traj_xyh
+
+
+def transform_xyhvv_coordinate_frame(traj_xyhvv, anchor_xyh):
+    if isinstance(traj_xyhvv, np.ndarray):
+        delta_x, delta_y, delta_h = np.split(anchor_xyh, 3, axis=-1)
+        xy, h, vx, vy = np.split(traj_xyhvv, (2, 3, 4), axis=-1)
+        xy = transform_xy_coordinate_frame(xy, anchor_xyh)
+        h = h - delta_h
+        vxy = batch_rotate_2D(np.concatenate((vx, vy), -1), -delta_h)
+        traj_xyhvv = np.concatenate((xy, h, vxy), axis=-1)
+    elif isinstance(traj_xyhvv, torch.Tensor):
+        delta_x, delta_y, delta_h = torch.split(anchor_xyh, 1, dim=-1)
+        xy, h, vx, vy = torch.split(traj_xyhvv, (2, 1, 1, 1), dim=-1)
+        xy = transform_xy_coordinate_frame(xy, anchor_xyh)
+        h = h - delta_h
+        vxy = batch_rotate_2D(torch.cat((vx, vy), -1), -delta_h)
+        traj_xyhvv = torch.cat((xy, h, vxy), dim=-1)
+    return traj_xyhvv
+
+
+def batch_rotate_2D(xy, theta):
+    if isinstance(xy, torch.Tensor):
+        x1 = xy[..., 0] * torch.cos(theta) - xy[..., 1] * torch.sin(theta)
+        y1 = xy[..., 1] * torch.cos(theta) + xy[..., 0] * torch.sin(theta)
+        return torch.stack([x1, y1], dim=-1)
+    elif isinstance(xy, np.ndarray):
+        x1 = xy[..., 0] * np.cos(theta) - xy[..., 1] * np.sin(theta)
+        y1 = xy[..., 1] * np.cos(theta) + xy[..., 0] * np.sin(theta)
+        return np.stack((x1, y1), axis=-1)
+
+
+def VEH_VEH_collision(p1, p2, S1, S2, offsetX=0.0, offsetY=0.0):
+    # Signed-distance proxy between two oriented boxes: the corners of box 1 are
+    # rotated into box 2's frame and compared against box 2's half extents.
+    if isinstance(p1, torch.Tensor):
+        cornersX = torch.kron(
+            S1[..., 0] + offsetX, torch.tensor([0.5, 0.5, -0.5, -0.5], device=p1.device)
+        )
+        cornersY = torch.kron(
+            S1[..., 1] + offsetY, torch.tensor([0.5, -0.5, 0.5, -0.5], device=p1.device)
+        )
+        corners = torch.stack([cornersX, cornersY], dim=-1)
+        theta1 = p1[..., 2]
+        theta2 = p2[..., 2]
+        dx = (p1[..., 0:2] - p2[..., 0:2]).repeat_interleave(4, dim=-2)
+        delta_x1 = batch_rotate_2D(corners, theta1.repeat_interleave(4, dim=-1)) + dx
+        delta_x2 = batch_rotate_2D(delta_x1, -theta2.repeat_interleave(4, dim=-1))
+        dis = torch.maximum(
+            torch.abs(delta_x2[..., 0]) - 0.5 * S2[..., 0].repeat_interleave(4, dim=-1),
+
torch.abs(delta_x2[..., 1]) - 0.5 * S2[..., 1].repeat_interleave(4, dim=-1), + ).view(*S1.shape[:-1], 4) + min_dis, _ = torch.min(dis, dim=-1) + + return min_dis + + elif isinstance(p1, np.ndarray): + cornersX = np.kron(S1[..., 0] + offsetX, np.array([0.5, 0.5, -0.5, -0.5])) + cornersY = np.kron(S1[..., 1] + offsetY, np.array([0.5, -0.5, 0.5, -0.5])) + corners = np.concatenate((cornersX, cornersY), axis=-1) + theta1 = p1[..., 2] + theta2 = p2[..., 2] + dx = (p1[..., 0:2] - p2[..., 0:2]).repeat(4, axis=-2) + delta_x1 = batch_rotate_2D(corners, theta1.repeat(4, axis=-1)) + dx + delta_x2 = batch_rotate_2D(delta_x1, -theta2.repeat(4, axis=-1)) + dis = np.maximum( + np.abs(delta_x2[..., 0]) - 0.5 * S2[..., 0].repeat(4, axis=-1), + np.abs(delta_x2[..., 1]) - 0.5 * S2[..., 1].repeat(4, axis=-1), + ).reshape(*S1.shape[:-1], 4) + min_dis = np.min(dis, axis=-1) + return min_dis + else: + raise NotImplementedError + + +def VEH_PED_collision(p1, p2, S1, S2): + if isinstance(p1, torch.Tensor): + mask = torch.logical_or( + torch.abs(p1[..., 2]) > 0.1, torch.linalg.norm(p2[..., 2:4], dim=-1) > 0.1 + ).detach() + theta = p1[..., 2] + dx = batch_rotate_2D(p2[..., 0:2] - p1[..., 0:2], -theta) + + return torch.maximum( + torch.abs(dx[..., 0]) - S1[..., 0] / 2 - S2[..., 0] / 2, + torch.abs(dx[..., 1]) - S1[..., 1] / 2 - S2[..., 0] / 2, + ) + elif isinstance(p1, np.ndarray): + theta = p1[..., 2] + dx = batch_rotate_2D(p2[..., 0:2] - p1[..., 0:2], -theta) + return np.maximum( + np.abs(dx[..., 0]) - S1[..., 0] / 2 - S2[..., 0] / 2, + np.abs(dx[..., 1]) - S1[..., 1] / 2 - S2[..., 0] / 2, + ) + else: + raise NotImplementedError + + +def PED_VEH_collision(p1, p2, S1, S2): + return VEH_PED_collision(p2, p1, S2, S1) + + +def batch_proj(x, line): + # x:[batch,3], line:[batch,N,3] + line_length = line.shape[-2] + batch_dim = x.ndim - 1 + if isinstance(x, torch.Tensor): + delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat( + *([1] * batch_dim), line_length, 1 + ) + dis = torch.linalg.norm(delta, axis=-1) + idx0 = torch.argmin(dis, dim=-1) + idx = idx0.view(*line.shape[:-2], 1, 1).repeat( + *([1] * (batch_dim + 1)), line.shape[-1] + ) + line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * torch.sin(line_min[..., None, 2]) + dy * torch.cos( + line_min[..., None, 2] + ) + delta_x = dx * torch.cos(line_min[..., None, 2]) + dy * torch.sin( + line_min[..., None, 2] + ) + # ref_pts = torch.stack( + # [ + # line_min[..., 0] + delta_x * torch.cos(line_min[..., 2]), + # line_min[..., 1] + delta_x * torch.sin(line_min[..., 2]), + # line_min[..., 2], + # ], + # dim=-1, + # ) + delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) + + return ( + delta_x, + delta_y, + torch.unsqueeze(delta_psi, dim=-1), + ) + elif isinstance(x, np.ndarray): + delta = line[..., 0:2] - np.repeat( + x[..., np.newaxis, 0:2], line_length, axis=-2 + ) + dis = np.linalg.norm(delta, axis=-1) + idx0 = np.argmin(dis, axis=-1) + idx = idx0.reshape(*line.shape[:-2], 1, 1).repeat(line.shape[-1], axis=-1) + line_min = np.squeeze(np.take_along_axis(line, idx, axis=-2), axis=-2) + dx = x[..., None, 0] - line[..., 0] + dy = x[..., None, 1] - line[..., 1] + delta_y = -dx * np.sin(line_min[..., None, 2]) + dy * np.cos( + line_min[..., None, 2] + ) + delta_x = dx * np.cos(line_min[..., None, 2]) + dy * np.sin( + line_min[..., None, 2] + ) + # line_min[..., 0] += delta_x * np.cos(line_min[..., 2]) + # line_min[..., 1] += delta_x * np.sin(line_min[..., 
2])
+        delta_psi = round_2pi(x[..., 2] - line_min[..., 2])
+        return (
+            delta_x,
+            delta_y,
+            np.expand_dims(delta_psi, axis=-1),
+        )
+
+
+def batch_proj_xysc(x, line):
+    # x:[batch,4](x,y,s,c), line:[batch,N,4](x,y,s,c)
+    # normalize s and c
+    x = normalize_xysc(x)
+    line = normalize_xysc(line)
+
+    line_length = line.shape[-2]
+    batch_dim = x.ndim - 1
+    if isinstance(x, torch.Tensor):
+        delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat(
+            *([1] * batch_dim), line_length, 1
+        )
+        dis = torch.linalg.norm(delta, dim=-1)
+        idx0 = torch.argmin(dis, dim=-1)
+        idx = idx0.view(*line.shape[:-2], 1, 1).repeat(
+            *([1] * (batch_dim + 1)), line.shape[-1]
+        )
+        line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2)
+        dx = x[..., None, 0] - line[..., 0]
+        dy = x[..., None, 1] - line[..., 1]
+
+        # longitudinal / lateral offsets in the frame of the closest line point
+        delta_x = dx * line_min[..., None, 3] + dy * line_min[..., None, 2]
+        delta_y = -dx * line_min[..., None, 2] + dy * line_min[..., None, 3]
+
+        # sin / cos of the heading difference between x and each line point
+        delta_s = x[..., None, 2] * line[..., 3] - x[..., None, 3] * line[..., 2]
+        delta_c = x[..., None, 2] * line[..., 2] + x[..., None, 3] * line[..., 3]
+
+        return torch.stack((delta_x, delta_y, delta_s, delta_c), -1)
+    else:
+        raise NotImplementedError
+
+
+def normalize_sc(x):
+    x_zero_flag = (x == 0).all(-1).type(x.dtype)
+    x = x + x_zero_flag[..., None] * torch.stack(
+        [torch.zeros_like(x_zero_flag), torch.ones_like(x_zero_flag)], -1
+    )
+    x = x / torch.norm(x, dim=-1, keepdim=True)
+    return x
+
+
+def normalize_xysc(x):
+    xsc = normalize_sc(x[..., 2:4])
+    return torch.cat([x[..., :2], xsc], -1)
+
+
+class CollisionType(IntEnum):
+    """This enum defines the three types of collisions: front, rear and side."""
+
+    FRONT = 0
+    REAR = 1
+    SIDE = 2
+
+
+def detect_collision(
+    ego_pos: np.ndarray,
+    ego_yaw: np.ndarray,
+    ego_extent: np.ndarray,
+    other_pos: np.ndarray,
+    other_yaw: np.ndarray,
+    other_extent: np.ndarray,
+):
+    """
+    Computes whether a collision occurred between the ego and any other agent.
+    Also computes the type of collision: rear, front, or side.
+    For this, we compute the intersection of ego's four sides with a target
+    agent and measure the length of this intersection. A collision is assigned
+    the class whose intersection length is maximal, i.e. a front collision
+    exhibits the longest intersection with the ego's front edge.
+
+    .. note:: this function stops upon finding the first collision, so it
+        returns only the first collision found rather than all of them.
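+
+    Example (illustrative; shows how the return value is typically unpacked):
+
+        result = detect_collision(
+            ego_pos, ego_yaw, ego_extent, other_pos, other_yaw, other_extent
+        )
+        if result is not None:
+            collision_type, agent_idx = result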
+ + :param ego_pos: predicted centroid + :param ego_yaw: predicted yaw + :param ego_extent: predicted extent + :param other_pos: target agents + :return: None if not collision was found, and a tuple with the + collision type and the agent track_id + """ + from l5kit.planning import utils + + ego_bbox = utils._get_bounding_box(centroid=ego_pos, yaw=ego_yaw, extent=ego_extent) + + # within_range_mask = utils.within_range(ego_pos, ego_extent, other_pos, other_extent) + for i in range(other_pos.shape[0]): + agent_bbox = utils._get_bounding_box( + other_pos[i], other_yaw[i], other_extent[i] + ) + if ego_bbox.intersects(agent_bbox): + front_side, rear_side, left_side, right_side = utils._get_sides(ego_bbox) + + intersection_length_per_side = np.asarray( + [ + agent_bbox.intersection(front_side).length, + agent_bbox.intersection(rear_side).length, + agent_bbox.intersection(left_side).length, + agent_bbox.intersection(right_side).length, + ] + ) + argmax_side = np.argmax(intersection_length_per_side) + + # Remap here is needed because there are two sides that are + # mapped to the same collision type CollisionType.SIDE + max_collision_types = max(CollisionType).value + remap_argmax = min(argmax_side, max_collision_types) + collision_type = CollisionType(remap_argmax) + return collision_type, i + return None + + +def calc_distance_map(road_flag, max_dis=10, mode="L1"): + """mark the image with manhattan distance to the drivable area + + Args: + road_flag (torch.Tensor[B,W,H]): an image with 1 channel, 1 for drivable area, 0 for non-drivable area + max_dis (int, optional): maximum distance that the result saturates to. Defaults to 10. + """ + out = torch.zeros_like(road_flag, dtype=torch.float) + out[road_flag == 0] = max_dis + out[road_flag == 1] = 0 + if mode == "L1": + for i in range(max_dis - 1): + out[..., 1:, :] = torch.min(out[..., 1:, :], out[..., :-1, :] + 1) + out[..., :-1, :] = torch.min(out[..., :-1, :], out[..., 1:, :] + 1) + out[..., :, 1:] = torch.min(out[..., :, 1:], out[..., :, :-1] + 1) + out[..., :, :-1] = torch.min(out[..., :, :-1], out[..., :, 1:] + 1) + elif mode == "Linf": + for i in range(max_dis - 1): + out[..., 1:, :] = torch.min(out[..., 1:, :], out[..., :-1, :] + 1) + out[..., :-1, :] = torch.min(out[..., :-1, :], out[..., 1:, :] + 1) + out[..., :, 1:] = torch.min(out[..., :, 1:], out[..., :, :-1] + 1) + out[..., :, :-1] = torch.min(out[..., :, :-1], out[..., :, 1:] + 1) + out[..., 1:, 1:] = torch.min(out[..., 1:, 1:], out[..., :-1, :-1] + 1) + out[..., 1:, :-1] = torch.min(out[..., 1:, :-1], out[..., :-1, 1:] + 1) + out[..., :-1, :-1] = torch.min(out[..., :-1, :-1], out[..., 1:, 1:] + 1) + out[..., :-1, 1:] = torch.min(out[..., :-1, 1:], out[..., 1:, :-1] + 1) + + return out + + +def rel_xysc(p1, p2): + """calculate the relative position of p2 to p1 + + Args: + p1 (torch.Tensor[B,4]): [x,y,s,c] + p2 (torch.Tensor[B,4]): [x,y,s,c] + + Returns: + torch.Tensor[B,4]: [dx,dy,ds,dc] + """ + # normalizing s and c + p1 = normalize_xysc(p1) + p2 = normalize_xysc(p2) + + dx = p2[..., 0] - p1[..., 0] + dy = p2[..., 1] - p1[..., 1] + dxl = dx * p1[..., 3] + dy * p1[..., 2] + dyl = -dx * p1[..., 2] + dy * p1[..., 3] + ds = p2[..., 2] * p1[..., 3] - p2[..., 3] * p1[..., 2] + dc = p2[..., 3] * p1[..., 3] + p2[..., 2] * p1[..., 2] + return torch.stack((dxl, dyl, ds, dc), -1) + + +def ratan2(s, c, eps=1e-4): + # robust arctan2 for pytorch + sign = (c >= 0).float() * 2 - 1 + eps = eps * (c.abs() < eps).type(c.dtype) * sign + return torch.arctan2(s, c + eps) diff --git 
a/diffstack/utils/homotopy.py b/diffstack/utils/homotopy.py
new file mode 100644
index 0000000..6f3b665
--- /dev/null
+++ b/diffstack/utils/homotopy.py
@@ -0,0 +1,113 @@
+from Pplan.Sampling.spline_planner import SplinePlanner
+from Pplan.Sampling.trajectory_tree import TrajTree
+
+import torch
+import numpy as np
+import diffstack.utils.geometry_utils as GeoUtils
+from diffstack.utils.geometry_utils import ratan2
+from diffstack.utils.tree import Tree
+from typing import List
+from enum import IntEnum
+
+
+HOMOTOPY_THRESHOLD = np.pi / 6
+
+
+class HomotopyType(IntEnum):
+    """
+    Homotopy class between two paths
+    STATIC: relatively small wrapping angle
+    CW: clockwise
+    CCW: counter-clockwise
+    """
+
+    STATIC = 0
+    CW = 1
+    CCW = 2
+
+    @staticmethod
+    def enforce_symmetry(x, mode="U"):
+        assert x.shape[-2] == x.shape[-1]
+        xT = x.transpose(-1, -2).clone()
+
+        diag_mask = torch.eye(x.shape[-1], device=x.device).bool().expand(*x.shape)
+        x.masked_fill_(diag_mask, HomotopyType.STATIC)
+
+        if mode == "U":
+            # force symmetry based on upper triangular matrix
+            triangle = torch.tril(torch.ones_like(x), diagonal=-1)
+            x = x * (1 - triangle) + xT * triangle
+        elif mode == "L":
+            # force symmetry based on lower triangular matrix
+            triangle = torch.triu(torch.ones_like(x), diagonal=1)
+            x = x * (1 - triangle) + xT * triangle
+        return x
+
+
+def mag_integral(path0, path1, mask=None):
+    # accumulated winding angle of the vector from path1 to path0 along the horizon
+    if isinstance(path0, torch.Tensor):
+        delta_path = path0 - path1
+        close_flag = torch.norm(delta_path, dim=-1) < 1e-3
+        angle = ratan2(delta_path[..., 1], delta_path[..., 0]).masked_fill(
+            close_flag, 0
+        )
+        delta_angle = GeoUtils.round_2pi(angle[..., 1:] - angle[..., :-1])
+        if mask is not None:
+            if mask.ndim == delta_angle.ndim - 1:
+                delta_angle = delta_angle * mask[..., None]
+            elif mask.ndim == delta_angle.ndim:
+                diff_mask = mask[..., 1:] * mask[..., :-1]
+                delta_angle = delta_angle * diff_mask
+        angle_diff = torch.sum(delta_angle, dim=-1)
+
+    elif isinstance(path0, np.ndarray):
+        delta_path = path0 - path1
+        close_flag = (np.linalg.norm(delta_path, axis=-1) < 1e-3).astype(float)
+        angle = np.arctan2(delta_path[..., 1], delta_path[..., 0]) * (1 - close_flag)
+        delta_angle = GeoUtils.round_2pi(angle[..., 1:] - angle[..., :-1])
+        if mask is not None:
+            if mask.ndim == delta_angle.ndim - 1:
+                delta_angle = delta_angle * mask[..., None]
+            elif mask.ndim == delta_angle.ndim:
+                diff_mask = mask[..., 1:] * mask[..., :-1]
+                delta_angle = delta_angle * diff_mask
+        angle_diff = np.sum(delta_angle, axis=-1)
+    return angle_diff
+
+
+def identify_homotopy(
+    ego_path: torch.Tensor, obj_paths: torch.Tensor, threshold=HOMOTOPY_THRESHOLD
+):
+    """Identify homotopy classes between the ego path and candidate object paths
+
+    Args:
+        ego_path (torch.Tensor): B x T x 2
+        obj_paths (torch.Tensor): B x M x N x T x 2
+    """
+    b, M, N = obj_paths.shape[:3]
+    angle_diff = mag_integral(ego_path[:, None, None], obj_paths)
+    homotopy = torch.zeros([b, M, N], device=ego_path.device)
+    homotopy[angle_diff >= threshold] = HomotopyType.CCW
+    homotopy[angle_diff <= -threshold] = HomotopyType.CW
+    homotopy[(angle_diff > -threshold) & (angle_diff < threshold)] = HomotopyType.STATIC
+
+    return angle_diff, homotopy
+
+
+def identify_pairwise_homotopy(
+    path: torch.Tensor, threshold=HOMOTOPY_THRESHOLD, mask=None
+):
+    """
+    Args:
+        path (torch.Tensor): B x N x T x 2
+    """
+    b, N, T = path.shape[:3]
+    if mask is not None:
+        mask = mask[:, None] * mask[:, :, None]
+    angle_diff = mag_integral(path[:, :, None], path[:, None], mask=mask)
+    homotopy = torch.zeros([b, N, N], device=path.device)
+
homotopy[angle_diff >= threshold] = HomotopyType.CCW + homotopy[angle_diff <= -threshold] = HomotopyType.CW + homotopy[(angle_diff > -threshold) & (angle_diff < threshold)] = HomotopyType.STATIC + + return angle_diff, homotopy diff --git a/diffstack/utils/kalman_filter.py b/diffstack/utils/kalman_filter.py new file mode 100644 index 0000000..85e6866 --- /dev/null +++ b/diffstack/utils/kalman_filter.py @@ -0,0 +1,120 @@ +import numpy as np + +class NonlinearKinematicBicycle: + """ + Nonlinear Kalman Filter for a kinematic bicycle model, assuming constant longitudinal speed + and constant heading array + """ + + def __init__(self, dt, sPos=None, sHeading=None, sVel=None, sMeasurement=None): + self.dt = dt + + # measurement matrix + self.C = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + + # default noise covariance + if (sPos is None) and (sHeading is None) and (sVel is None): + # TODO need to further check + # sPos = 0.5 * 8.8 * dt ** 2 # assume 8.8m/s2 as maximum acceleration + # sHeading = 0.5 * dt # assume 0.5rad/s as maximum turn rate + # sVel = 8.8 * dt # assume 8.8m/s2 as maximum acceleration + # sMeasurement = 1.0 + sPos = 12 * self.dt # assume 6m/s2 as maximum acceleration + sHeading = 0.5 * self.dt # assume 0.5rad/s as maximum turn rate + sVel = 6 * self.dt # assume 6m/s2 as maximum acceleration + if sMeasurement is None: + sMeasurement = 5.0 + # state transition noise + self.Q = np.diag([sPos ** 2, sPos ** 2, sHeading ** 2, sVel ** 2]) + # measurement noise + self.R = np.diag([sMeasurement ** 2, sMeasurement ** 2, sMeasurement ** 2, sMeasurement ** 2]) + + def predict_and_update(self, x_vec_est, u_vec, P_matrix, z_new): + """ + for background please refer to wikipedia: https://en.wikipedia.org/wiki/Extended_Kalman_filter + :param x_vec_est: + :param u_vec: + :param P_matrix: + :param z_new: + :return: + """ + + ## Prediction Step + # predicted state estimate + x_pred = self._kinematic_bicycle_model_rearCG(x_vec_est, u_vec) + # Compute Jacobian to obtain the state transition matrix + A = self._cal_state_Jacobian(x_vec_est, u_vec) + # predicted error covariance + P_pred = A.dot(P_matrix.dot(A.transpose())) + self.Q + + ## Update Step + # innovation or measurement pre-fit residual + y_telda = z_new - self.C.dot(x_pred) + # innovation covariance + S = self.C.dot(P_pred.dot(self.C.transpose())) + self.R + # near-optimal Kalman gain + K = P_pred.dot(self.C.transpose().dot(np.linalg.inv(S))) + # updated (a posteriori) state estimate + x_vec_est_new = x_pred + K.dot(y_telda) + # updated (a posteriori) estimate covariance + P_matrix_new = np.dot((np.identity(4) - K.dot(self.C)), P_pred) + + return x_vec_est_new, P_matrix_new + + def _kinematic_bicycle_model_rearCG(self, x_old, u): + """ + :param x: vehicle state vector = [x position, y position, heading, velocity] + :param u: control vector = [acceleration, steering array] + :param dt: + :return: + """ + + acc = u[0] + delta = u[1] + + x = x_old[0] + y = x_old[1] + psi = x_old[2] + vel = x_old[3] + + x_new = np.array([[0.], [0.], [0.], [0.]]) + + x_new[0] = x + self.dt * vel * np.cos(psi + delta) + x_new[1] = y + self.dt * vel * np.sin(psi + delta) + x_new[2] = psi + self.dt * delta + #x_new[2] = _heading_angle_correction(x_new[2]) + x_new[3] = vel + self.dt * acc + + return x_new + + def _cal_state_Jacobian(self, x_vec, u_vec): + acc = u_vec[0] + delta = u_vec[1] + + x = x_vec[0] + y = x_vec[1] + psi = x_vec[2] + vel = x_vec[3] + + a13 = -self.dt * vel * np.sin(psi + delta) + a14 = self.dt * np.cos(psi + delta) + a23 
= self.dt * vel * np.cos(psi + delta) + a24 = self.dt * np.sin(psi + delta) + a34 = self.dt * delta + + JA = np.array([[1.0, 0.0, a13[0], a14[0]], + [0.0, 1.0, a23[0], a24[0]], + [0.0, 0.0, 1.0, a34[0]], + [0.0, 0.0, 0.0, 1.0]]) + + return JA + + +def _heading_angle_correction(theta): + """ + correct heading array so that it always remains in [-pi, pi] + :param theta: + :return: + """ + theta_corrected = (theta + np.pi) % (2.0 * np.pi) - np.pi + return theta_corrected \ No newline at end of file diff --git a/diffstack/utils/l5_utils.py b/diffstack/utils/l5_utils.py new file mode 100644 index 0000000..5a03401 --- /dev/null +++ b/diffstack/utils/l5_utils.py @@ -0,0 +1,726 @@ +import torch +import torch.nn.functional as F + +import diffstack.dynamics as dynamics +import diffstack.utils.tensor_utils as TensorUtils +from diffstack import dynamics as dynamics +from diffstack.configs.base import ExperimentConfig + + +def get_agent_masks(raw_type): + """ + PERCEPTION_LABELS = [ + "PERCEPTION_LABEL_NOT_SET", + "PERCEPTION_LABEL_UNKNOWN", + "PERCEPTION_LABEL_DONTCARE", + "PERCEPTION_LABEL_CAR", + "PERCEPTION_LABEL_VAN", + "PERCEPTION_LABEL_TRAM", + "PERCEPTION_LABEL_BUS", + "PERCEPTION_LABEL_TRUCK", + "PERCEPTION_LABEL_EMERGENCY_VEHICLE", + "PERCEPTION_LABEL_OTHER_VEHICLE", + "PERCEPTION_LABEL_BICYCLE", + "PERCEPTION_LABEL_MOTORCYCLE", + "PERCEPTION_LABEL_CYCLIST", + "PERCEPTION_LABEL_MOTORCYCLIST", + "PERCEPTION_LABEL_PEDESTRIAN", + "PERCEPTION_LABEL_ANIMAL", + "AVRESEARCH_LABEL_DONTCARE", + ] + """ + veh_mask = (raw_type >= 3) & (raw_type <= 13) + ped_mask = (raw_type == 14) | (raw_type == 15) + # veh_mask = veh_mask | ped_mask + # ped_mask = ped_mask * 0 + return veh_mask, ped_mask + + +def get_dynamics_types(veh_mask, ped_mask): + dyn_type = torch.zeros_like(veh_mask) + dyn_type += dynamics.DynType.UNICYCLE * veh_mask + dyn_type += dynamics.DynType.DI * ped_mask + return dyn_type + + +def raw_to_features(batch_raw): + """ map raw src into features of dim 21 """ + raw_type = batch_raw["raw_types"] + pos = batch_raw["hist_pos"] + vel = batch_raw["history_velocities"] + yaw = batch_raw["hist_yaw"] + mask = batch_raw["hist_mask"] + + veh_mask, ped_mask = get_agent_masks(raw_type) + + # all vehicles, cyclists, and motorcyclists + feature_veh = torch.cat((pos, vel, torch.cos(yaw), torch.sin(yaw)), dim=-1) + + # pedestrians and animals + ped_feature = torch.cat( + (pos, vel, vel * torch.sin(yaw), vel * torch.cos(yaw)), dim=-1 + ) + + feature = feature_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + ped_feature * ped_mask.view([*raw_type.shape, 1, 1]) + + type_embedding = F.one_hot(raw_type, 16) + + feature = torch.cat( + (feature, type_embedding.unsqueeze(-2).repeat(1, 1, feature.size(2), 1)), + dim=-1, + ) + feature = feature * mask.unsqueeze(-1) + + return feature + + +def raw_to_states(batch_raw): + raw_type = batch_raw["raw_types"] + pos = batch_raw["hist_pos"] + vel = batch_raw["history_velocities"] + yaw = batch_raw["hist_yaw"] + avail_mask = batch_raw["hist_mask"] + + veh_mask, ped_mask = get_agent_masks(raw_type) # [B, (A)] + + # all vehicles, cyclists, and motorcyclists + state_veh = torch.cat((pos, vel, yaw), dim=-1) # [B, (A), T, S] + # pedestrians and animals + state_ped = torch.cat((pos, vel * torch.cos(yaw), vel * torch.sin(yaw)), dim=-1) # [B, (A), T, S] + + state = state_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + state_ped * ped_mask.view([*raw_type.shape, 1, 1]) # [B, (A), T, S] + + # Get the current state of the agents + num = torch.arange(0, 
avail_mask.shape[-1]).view(1, 1, -1).to(avail_mask.device) + nummask = num * avail_mask + last_idx, _ = torch.max(nummask, dim=2) + curr_state = torch.gather( + state, 2, last_idx[..., None, None].repeat(1, 1, 1, 4) + ) + return state, curr_state + + +def batch_to_raw_ego(data_batch, step_time): + batch_size = data_batch["hist_pos"].shape[0] + raw_type = torch.ones(batch_size).type(torch.int64).to(data_batch["hist_pos"].device) # [B, T] + raw_type = raw_type * 3 # index for type PERCEPTION_LABEL_CAR + + src_pos = torch.flip(data_batch["hist_pos"], dims=[-2]) + src_yaw = torch.flip(data_batch["hist_yaw"], dims=[-2]) + src_mask = torch.flip(data_batch["hist_mask"], dims=[-1]).bool() + + src_vel = dynamics.Unicycle(step_time).calculate_vel(pos=src_pos, yaw=src_yaw, mask=src_mask) + src_vel[:, -1] = data_batch["curr_speed"].unsqueeze(-1) + + raw = { + "hist_pos": src_pos, + "history_velocities": src_vel, + "hist_yaw": src_yaw, + "raw_types": raw_type, + "hist_mask": src_mask, + "extents": data_batch["extents"] + } + + raw = TensorUtils.unsqueeze(raw, dim=1) # Add the agent dimension + return raw + + +def raw2feature(pos, vel, yaw, raw_type, mask, lanes=None, add_noise=False): + "map raw src into features of dim 21+lane dim" + + """ + PERCEPTION_LABELS = [ + "PERCEPTION_LABEL_NOT_SET", + "PERCEPTION_LABEL_UNKNOWN", + "PERCEPTION_LABEL_DONTCARE", + "PERCEPTION_LABEL_CAR", + "PERCEPTION_LABEL_VAN", + "PERCEPTION_LABEL_TRAM", + "PERCEPTION_LABEL_BUS", + "PERCEPTION_LABEL_TRUCK", + "PERCEPTION_LABEL_EMERGENCY_VEHICLE", + "PERCEPTION_LABEL_OTHER_VEHICLE", + "PERCEPTION_LABEL_BICYCLE", + "PERCEPTION_LABEL_MOTORCYCLE", + "PERCEPTION_LABEL_CYCLIST", + "PERCEPTION_LABEL_MOTORCYCLIST", + "PERCEPTION_LABEL_PEDESTRIAN", + "PERCEPTION_LABEL_ANIMAL", + "AVRESEARCH_LABEL_DONTCARE", + ] + """ + dyn_type = torch.zeros_like(raw_type) + veh_mask = (raw_type >= 3) & (raw_type <= 13) + ped_mask = (raw_type == 14) | (raw_type == 15) + veh_mask = veh_mask | ped_mask + ped_mask = ped_mask * 0 + dyn_type += dynamics.DynType.UNICYCLE * veh_mask + # all vehicles, cyclists, and motorcyclists + if add_noise: + pos_noise = torch.randn(pos.size(0), 1, 1, 2).to(pos.device) * 0.5 + yaw_noise = torch.randn(pos.size(0), 1, 1, 1).to(pos.device) * 0.1 + if pos.ndim == 5: + pos_noise = pos_noise.unsqueeze(1) + yaw_noise = yaw_noise.unsqueeze(1) + feature_veh = torch.cat( + ( + pos + pos_noise, + vel, + torch.cos(yaw + yaw_noise), + torch.sin(yaw + yaw_noise), + ), + dim=-1, + ) + else: + feature_veh = torch.cat((pos, vel, torch.cos(yaw), torch.sin(yaw)), dim=-1) + + state_veh = torch.cat((pos, vel, yaw), dim=-1) + + # pedestrians and animals + if add_noise: + pos_noise = torch.randn(pos.size(0), 1, 1, 2).to(pos.device) * 0.5 + yaw_noise = torch.randn(pos.size(0), 1, 1, 1).to(pos.device) * 0.1 + if pos.ndim == 5: + pos_noise = pos_noise.unsqueeze(1) + yaw_noise = yaw_noise.unsqueeze(1) + ped_feature = torch.cat( + ( + pos + pos_noise, + vel, + vel * torch.sin(yaw + yaw_noise), + vel * torch.cos(yaw + yaw_noise), + ), + dim=-1, + ) + else: + ped_feature = torch.cat( + (pos, vel, vel * torch.sin(yaw), vel * torch.cos(yaw)), dim=-1 + ) + state_ped = torch.cat((pos, vel * torch.cos(yaw), vel * torch.sin(yaw)), dim=-1) + state = state_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + state_ped * ped_mask.view([*raw_type.shape, 1, 1]) + dyn_type += dynamics.DynType.DI * ped_mask + + feature = feature_veh * veh_mask.view( + [*raw_type.shape, 1, 1] + ) + ped_feature * ped_mask.view([*raw_type.shape, 1, 1]) + + type_embedding = 
F.one_hot(raw_type, 16) + + if pos.ndim == 4: + if lanes is not None: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, feature.size(2), 1), + lanes[:, :, None, :].repeat(1, 1, feature.size(2), 1), + ), + dim=-1, + ) + else: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, feature.size(2), 1), + ), + dim=-1, + ) + + elif pos.ndim == 5: + if lanes is not None: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, 1, feature.size(-2), 1), + lanes[:, :, None, None, :].repeat( + 1, feature.size(1), 1, feature.size(2), 1 + ), + ), + dim=-1, + ) + else: + feature = torch.cat( + ( + feature, + type_embedding.unsqueeze(-2).repeat(1, 1, 1, feature.size(-2), 1), + ), + dim=-1, + ) + feature = feature * mask.unsqueeze(-1) + return feature, dyn_type, state + + +def batch_to_vectorized_feature(data_batch, dyn_list, step_time, algo_config): + device = data_batch["hist_pos"].device + raw_type = torch.cat( + (data_batch["type"].unsqueeze(1), data_batch["all_other_agents_types"]), + dim=1, + ).type(torch.int64) + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + torch.max(data_batch["all_other_agents_history_extents"], dim=-2)[0], + ), + dim=1, + ) + + src_pos = torch.cat( + ( + data_batch["hist_pos"].unsqueeze(1), + data_batch["all_other_agents_hist_pos"], + ), + dim=1, + ) + "history position and yaw need to be flipped so that they go from past to recent" + src_pos = torch.flip(src_pos, dims=[-2]) + src_yaw = torch.cat( + ( + data_batch["hist_yaw"].unsqueeze(1), + data_batch["all_other_agents_hist_yaw"], + ), + dim=1, + ) + src_yaw = torch.flip(src_yaw, dims=[-2]) + src_world_yaw = src_yaw + ( + data_batch["yaw"] + .view(-1, 1, 1, 1) + .repeat(1, src_yaw.size(1), src_yaw.size(2), 1) + ).type(torch.float) + src_mask = torch.cat( + ( + data_batch["hist_mask"].unsqueeze(1), + data_batch["all_other_agents_history_availability"], + ), + dim=1, + ).bool() + + src_mask = torch.flip(src_mask, dims=[-1]) + # estimate velocity + src_vel = dyn_list[dynamics.DynType.UNICYCLE].calculate_vel( + src_pos, src_yaw, src_mask + ) + + src_vel[:, 0, -1] = torch.clip( + data_batch["curr_speed"].unsqueeze(-1), + min=algo_config.vmin, + max=algo_config.vmax, + ) + if algo_config.vectorize_lane: + src_lanes = torch.cat( + ( + data_batch["ego_lanes"].unsqueeze(1), + data_batch["all_other_agents_lanes"], + ), + dim=1, + ).type(torch.float) + src_lanes = torch.cat(( + src_lanes[...,0:2], + torch.cos(src_lanes[...,2:3]), + torch.sin(src_lanes[...,2:3]), + src_lanes[...,-1:], + ),dim=-1) + src_lanes = src_lanes.view(*src_lanes.shape[:2], -1) + else: + src_lanes = None + src, dyn_type, src_state = raw2feature( + src_pos, src_vel, src_yaw, raw_type, src_mask, src_lanes + ) + tgt_mask = torch.cat( + ( + data_batch["fut_mask"].unsqueeze(1), + data_batch["all_other_agents_future_availability"], + ), + dim=1, + ).bool() + num = torch.arange(0, src_mask.shape[2]).view(1, 1, -1).to(src_mask.device) + nummask = num * src_mask + last_idx, _ = torch.max(nummask, dim=2) + curr_state = torch.gather( + src_state, 2, last_idx[..., None, None].repeat(1, 1, 1, 4) + ) + + tgt_pos = torch.cat( + ( + data_batch["fut_pos"].unsqueeze(1), + data_batch["all_other_agents_future_positions"], + ), + dim=1, + ) + tgt_yaw = torch.cat( + ( + data_batch["fut_yaw"].unsqueeze(1), + data_batch["all_other_agents_future_yaws"], + ), + dim=1, + ) + tgt_pos_yaw = torch.cat((tgt_pos, tgt_yaw), dim=-1) + + + # curr_pos_yaw = torch.cat((curr_state[..., 0:2], 
curr_yaw), dim=-1) + + # tgt = tgt - curr_pos_yaw.repeat(1, 1, tgt.size(2), 1) * tgt_mask.unsqueeze(-1) + + + return ( + src, + dyn_type, + src_state, + src_pos, + src_yaw, + src_world_yaw, + src_vel, + raw_type, + src_mask, + src_lanes, + extents, + tgt_pos_yaw, + tgt_mask, + curr_state, + ) + +def obtain_goal_state(tgt_pos_yaw,tgt_mask): + num = torch.arange(0, tgt_mask.shape[2]).view(1, 1, -1).to(tgt_mask.device) + nummask = num * tgt_mask + last_idx, _ = torch.max(nummask, dim=2, keepdim=True) + last_mask = nummask.ge(last_idx) + + goal_mask = tgt_mask*last_mask + goal_pos_yaw = tgt_pos_yaw*goal_mask.unsqueeze(-1) + return goal_pos_yaw[...,:2], goal_pos_yaw[...,2:], goal_mask + + +def batch_to_raw_all_agents(data_batch, step_time): + raw_type = torch.cat( + (data_batch["type"].unsqueeze(1), data_batch["all_other_agents_types"]), + dim=1, + ).type(torch.int64) + + src_pos = torch.cat( + ( + data_batch["hist_pos"].unsqueeze(1), + data_batch["all_other_agents_hist_pos"], + ), + dim=1, + ) + # history position and yaw need to be flipped so that they go from past to recent + src_pos = torch.flip(src_pos, dims=[-2]) + src_yaw = torch.cat( + ( + data_batch["hist_yaw"].unsqueeze(1), + data_batch["all_other_agents_hist_yaw"], + ), + dim=1, + ) + src_yaw = torch.flip(src_yaw, dims=[-2]) + src_mask = torch.cat( + ( + data_batch["hist_mask"].unsqueeze(1), + data_batch["all_other_agents_history_availability"], + ), + dim=1, + ).bool() + + src_mask = torch.flip(src_mask, dims=[-1]) + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + torch.max(data_batch["all_other_agents_history_extents"], dim=-2)[0], + ), + dim=1, + ) + + # estimate velocity + src_vel = dynamics.Unicycle(step_time).calculate_vel(src_pos, src_yaw, src_mask) + src_vel[:, 0, -1] = data_batch["curr_speed"].unsqueeze(-1) + + return { + "hist_pos": src_pos, + "hist_yaw": src_yaw, + "curr_speed": src_vel[:, :, -1, 0], + "raw_types": raw_type, + "hist_mask": src_mask, + "extents": extents, + } + + +def batch_to_target_all_agents(data_batch): + pos = torch.cat( + ( + data_batch["fut_pos"].unsqueeze(1), + data_batch["all_other_agents_future_positions"], + ), + dim=1, + ) + yaw = torch.cat( + ( + data_batch["fut_yaw"].unsqueeze(1), + data_batch["all_other_agents_future_yaws"], + ), + dim=1, + ) + avails = torch.cat( + ( + data_batch["fut_mask"].unsqueeze(1), + data_batch["all_other_agents_future_availability"], + ), + dim=1, + ) + + extents = torch.cat( + ( + data_batch["extent"][..., :2].unsqueeze(1), + torch.max(data_batch["all_other_agents_history_extents"], dim=-2)[0], + ), + dim=1, + ) + + return { + "fut_pos": pos, + "fut_yaw": yaw, + "fut_mask": avails, + "extents": extents + } + + +def generate_edges( + raw_type, + extents, + pos_pred, + yaw_pred, +): + veh_mask = (raw_type >= 3) & (raw_type <= 13) + ped_mask = (raw_type == 14) | (raw_type == 15) + + agent_mask = veh_mask | ped_mask + edge_types = ["VV", "VP", "PV", "PP"] + edges = {et: list() for et in edge_types} + for i in range(agent_mask.shape[0]): + agent_idx = torch.where(agent_mask[i] != 0)[0] + edge_idx = torch.combinations(agent_idx, r=2) + VV_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + VP_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + PV_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + PP_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + if pos_pred.ndim == 4: + edges_of_all_types = 
torch.cat( + ( + pos_pred[i, edge_idx[:, 0], :], + yaw_pred[i, edge_idx[:, 0], :], + pos_pred[i, edge_idx[:, 1], :], + yaw_pred[i, edge_idx[:, 1], :], + extents[i, edge_idx[:, 0]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + extents[i, edge_idx[:, 1]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[VV_idx]) + edges["VP"].append(edges_of_all_types[VP_idx]) + edges["PV"].append(edges_of_all_types[PV_idx]) + edges["PP"].append(edges_of_all_types[PP_idx]) + elif pos_pred.ndim == 5: + + edges_of_all_types = torch.cat( + ( + pos_pred[i, :, edge_idx[:, 0], :], + yaw_pred[i, :, edge_idx[:, 0], :], + pos_pred[i, :, edge_idx[:, 1], :], + yaw_pred[i, :, edge_idx[:, 1], :], + extents[i, None, edge_idx[:, 0], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + extents[i, None, edge_idx[:, 1], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[:, VV_idx]) + edges["VP"].append(edges_of_all_types[:, VP_idx]) + edges["PV"].append(edges_of_all_types[:, PV_idx]) + edges["PP"].append(edges_of_all_types[:, PP_idx]) + if pos_pred.ndim == 4: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=0) + elif pos_pred.ndim == 5: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=1) + return edges + + +def gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,N,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] or [B,N,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + B,N,T = ego_trajectories.shape[:3] + A = agent_trajectories.shape[-3] + + veh_mask = (raw_types >= 3) & (raw_types <= 13) + ped_mask = (raw_types == 14) | (raw_types == 15) + + edges = torch.zeros([B,N,A,T,10]).to(ego_trajectories.device) + edges[...,:3] = ego_trajectories.unsqueeze(2).repeat(1,1,A,1,1) + if agent_trajectories.ndim==4: + edges[...,3:6] = agent_trajectories.unsqueeze(1).repeat(1,N,1,1,1) + else: + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,1,2).repeat(1,N,A,T,1) + edges[...,8:] = agent_extents.reshape(B,1,A,1,2).repeat(1,N,1,T,1) + type_mask = {"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + +def gen_EC_edges(ego_trajectories,agent_trajectories,ego_extents, agent_extents, raw_types,mask=None): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,A,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + mask (optional, torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + + B,A = ego_trajectories.shape[:2] + T = ego_trajectories.shape[-2] + + veh_mask = (raw_types >= 3) & (raw_types <= 13) + ped_mask = (raw_types == 14) | (raw_types == 15) + + + if ego_trajectories.ndim==4: + edges = torch.zeros([B,A,T,10]).to(ego_trajectories.device) + edges[...,:3] = ego_trajectories + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,1,T,1) + elif ego_trajectories.ndim==5: + + K = ego_trajectories.shape[2] + edges 
= torch.zeros([B,A*K,T,10]).to(ego_trajectories.device) + edges[...,:3] = TensorUtils.join_dimensions(ego_trajectories,1,3) + edges[...,3:6] = agent_trajectories.repeat(1,K,1,1) + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A*K,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,K,T,1) + veh_mask = veh_mask.tile(1,K) + ped_mask = ped_mask.tile(1,K) + if mask is not None: + veh_mask = veh_mask*mask + ped_mask = ped_mask*mask + type_mask = {"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + +def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): + raw_type = torch.cat( + (data_batch["type"].unsqueeze(1), data_batch["all_other_agents_types"]), + dim=1, + ).type(torch.int64) + + # Use predicted ego position to compute future box edges + + targets_all = batch_to_target_all_agents(data_batch) + if ego_predictions is not None: + targets_all["fut_pos"] [:, 0, :, :] = ego_predictions["positions"] + targets_all["fut_yaw"][:, 0, :, :] = ego_predictions["yaws"] + elif all_predictions is not None: + targets_all["fut_pos"] = all_predictions["positions"] + targets_all["fut_yaw"] = all_predictions["yaws"] + else: + raise ValueError("Please specify either ego prediction or all predictions") + + pred_edges = generate_edges( + raw_type, targets_all["extents"], + pos_pred=targets_all["fut_pos"], + yaw_pred=targets_all["fut_yaw"] + ) + return pred_edges + + +def get_last_available_index(avails): + """ + Args: + avails (torch.Tensor): target availabilities [B, (A), T] + + Returns: + last_indices (torch.Tensor): index of the last available frame + """ + num_frames = avails.shape[-1] + inds = torch.arange(0, num_frames).to(avails.device) # [T] + inds = (avails > 0).float() * inds # [B, (A), T] arange indices with unavailable indices set to 0 + last_inds = inds.max(dim=-1)[1] # [B, (A)] calculate the index of the last availale frame + return last_inds + + +def get_current_states(batch: dict, dyn_type: dynamics.DynType) -> torch.Tensor: + bs = batch["curr_speed"].shape[0] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, 6).to(batch["curr_speed"].device) # [x, y, yaw, vel, dh, veh_len] + current_states[:, 3] = batch["curr_speed"].abs() + current_states[:, [4]] = (batch["hist_yaw"][:, 0] - batch["hist_yaw"][:, 1]).abs() + current_states[:, 5] = batch["extent"][:, 0] # [veh_len] + else: + current_states = torch.zeros(bs, 4).to(batch["curr_speed"].device) # [x, y, vel, yaw] + current_states[:, 2] = batch["curr_speed"] + return current_states + + +def get_current_states_all_agents(batch: dict, step_time, dyn_type: dynamics.DynType) -> torch.Tensor: + if batch["hist_pos"].ndim==3: + state_all = batch_to_raw_all_agents(batch, step_time) + else: + state_all = batch + bs, na = state_all["curr_speed"].shape[:2] + if dyn_type == dynamics.DynType.BICYCLE: + current_states = torch.zeros(bs, na, 6).to(state_all["curr_speed"].device) # [x, y, yaw, vel, dh, veh_len] + current_states[:, :, :2] = state_all["hist_pos"][:, :, 0] + current_states[:, :, 3] = state_all["curr_speed"].abs() + current_states[:, :, [4]] = (state_all["hist_yaw"][:, :, 0] - state_all["hist_yaw"][:, :, 1]).abs() + current_states[:, :, 5] = state_all["extent"][:, :, 0] # [veh_len] + else: + current_states = torch.zeros(bs, na, 4).to(state_all["curr_speed"].device) # [x, y, vel, yaw] + current_states[:, :, :2] = state_all["hist_pos"][:, :, 0] + current_states[:, :, 2] = state_all["curr_speed"] + current_states[:,:,3:] = state_all["hist_yaw"][:,:,0] + return current_states + + +def 
get_drivable_region_map(rasterized_map): + return rasterized_map[..., -3, :, :] < 1. + + +def get_modality_shapes(cfg: ExperimentConfig): + assert cfg.env.rasterizer.map_type == "py_semantic" + num_channels = (cfg.stack.history_num_frames + 1) * 2 + 3 + h, w = cfg.env.rasterizer.raster_size + return dict(image=(num_channels, h, w)) \ No newline at end of file diff --git a/diffstack/utils/lane_utils.py b/diffstack/utils/lane_utils.py new file mode 100644 index 0000000..a7db498 --- /dev/null +++ b/diffstack/utils/lane_utils.py @@ -0,0 +1,553 @@ +import numpy as np +import torch +import diffstack.utils.geometry_utils as GeoUtils +from diffstack.utils.geometry_utils import ratan2 +import enum +from dataclasses import dataclass +import scipy.interpolate as spint + + +class LaneModeConst: + X_ahead_thresh = 5.0 + X_rear_thresh = 0.0 + Y_near_thresh = 1.8 + Y_far_thresh = 5.0 + psi_thresh = np.pi / 4 + longitudinal_scale = 30 + lateral_scale = 1 + heading_scale = 0.5 + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +def get_edge(lane, dir, W=2.0, num_pts=None): + if dir == "L": + if lane.left_edge is not None: + if num_pts is not None: + lane.left_edge = lane.left_edge.interpolate(num_pts) + xy = lane.left_edge.xy + if lane.left_edge.has_heading: + h = lane.left_edge.h + else: + # check if the points are reversed + edge_angle = np.arctan2(xy[-1, 1] - xy[0, 1], xy[-1, 0] - xy[0, 0]) + center_angle = np.arctan2( + lane.center.xy[-1, 1] - lane.center.xy[0, 1], + lane.center.xy[-1, 0] - lane.center.xy[0, 0], + ) + if np.abs(GeoUtils.round_2pi(edge_angle - center_angle)) > np.pi / 2: + xy = np.flip(xy, 0) + dxy = xy[1:] - xy[:-1] + h = GeoUtils.round_2pi(np.arctan2(dxy[:, 1], dxy[:, 0])) + h = np.hstack((h, h[-1])) + else: + if num_pts is not None: + lane.center = lane.center.interpolate(num_pts) + angle = lane.center.h + np.pi / 2 + offset = np.stack([W * np.cos(angle), W * np.sin(angle)], -1) + xy = lane.center.xy + offset + h = lane.center.h + elif dir == "R": + if lane.right_edge is not None: + if num_pts is not None: + lane.right_edge = lane.right_edge.interpolate(num_pts) + xy = lane.right_edge.xy + if lane.right_edge.has_heading: + h = lane.right_edge.h + else: + # check if the points are reversed + edge_angle = np.arctan2(xy[-1, 1] - xy[0, 1], xy[-1, 0] - xy[0, 0]) + center_angle = np.arctan2( + lane.center.xy[-1, 1] - lane.center.xy[0, 1], + lane.center.xy[-1, 0] - lane.center.xy[0, 0], + ) + if np.abs(GeoUtils.round_2pi(edge_angle - center_angle)) > np.pi / 2: + xy = np.flip(xy, 0) + dxy = xy[1:] - xy[:-1] + h = GeoUtils.round_2pi(np.arctan2(dxy[:, 1], dxy[:, 0])) + h = np.hstack((h, h[-1])) + else: + if num_pts is not None: + lane.center = lane.center.interpolate(num_pts) + angle = lane.center.h - np.pi / 2 + offset = np.stack([W * np.cos(angle), W * np.sin(angle)], -1) + xy = lane.center.xy + offset + h = lane.center.h + elif dir == "C": + if num_pts is not None: + lane.center = lane.center.interpolate(num_pts) + xy = lane.center.xy + if lane.center.has_heading: + h = lane.center.h + else: + dxy = xy[1:] - xy[:-1] + h = GeoUtils.round_2pi(np.arctan2(dxy[:, 1], dxy[:, 0])) + h = np.hstack((h, h[-1])) + return xy, h + + +def get_bdry_xyh(lane1, lane2=None, dir="L", W=3.6, num_pts=25): + if lane2 is None: + xy, h = get_edge(lane1, dir, W, num_pts) + else: + xy1, h1 = get_edge(lane1, dir, W, num_pts) + xy2, h2 = get_edge(lane2, dir, W, num_pts) + xy = np.concatenate((xy1, xy2), 0) + h = np.concatenate((h1, h2), 0) + return xy, h + + +def 
LaneRelationFromCfg(lane_relation): + if lane_relation == "SimpleLaneRelation": + return SimpleLaneRelation + elif lane_relation == "LRLaneRelation": + return LRLaneRelation + elif lane_relation == "LaneRelation": + return LaneRelation + else: + raise ValueError("Invalid lane relation type") + + +class SimpleLaneRelation(enum.IntEnum): + """ + Categorical token describing the relationship between an agent and a Lane, unitary lane mode that only considers which lane the agent is one + """ + + NOTON = 0 # (0, 2, 3, 4, 5, 6) + ON = 1 + + @staticmethod + def get_all_margins(agent_xysc, lane_xysc, t_range=None, const_override={}): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + return torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ) + + @staticmethod + def categorize_lane_relation_pts( + agent_xysc, + lane_xysc, + agent_mask=None, + lane_mask=None, + t_range=None, + force_select=True, + force_unique=True, + const_override={}, + return_all_margins=False, + ): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + margin = torch.zeros( + *x_ahead_margin.shape, len(SimpleLaneRelation), device=x_ahead_margin.device + ) # margin > 0, then mode is active + margin[..., SimpleLaneRelation.ON] = torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin = fill_margin( + margin, + noton_idx=SimpleLaneRelation.NOTON, + agent_mask=agent_mask, + lane_mask=lane_mask, + ) + if force_select: + # offset the margin to make sure that at least one is positive + margin_max = margin[..., SimpleLaneRelation.ON].max(dim=1)[0] + margin_offset = -margin_max.clip(max=0).detach() + 1e-6 + margin[..., SimpleLaneRelation.ON] = ( + margin[..., SimpleLaneRelation.ON] + margin_offset[:, None, :] + ) + + if force_unique: + second_largest_margin = margin[..., 1].topk(2, 1, sorted=True)[0][:, 1] + margin_offset = -second_largest_margin.clip(min=0).detach() + margin[..., 1] = margin[..., 1] + margin_offset[:, None, :] + margin[..., SimpleLaneRelation.NOTON] = -margin[..., SimpleLaneRelation.ON] + + flag = get_flag(margin) + return flag, margin + + +class LRLaneRelation(enum.IntEnum): + """Unitary lane mode that considers which lane the agent is on, which lane the ego is on the left of, and which lane the ego is on the right of""" + + NOTON = 0 # (0, 2, 3, 4) + ON = 1 + LEFTOF = 2 # (5) + RIGHTOF = 3 # (6) + + @staticmethod + def get_all_margins(agent_xysc, lane_xysc, t_range=None, const_override={}): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + return torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ) + + @staticmethod + def categorize_lane_relation_pts( + agent_xysc, + lane_xysc, + agent_mask=None, + lane_mask=None, + t_range=None, + force_select=True, + force_unique=True, + const_override={}, + ): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + margin = 
torch.zeros( + *x_ahead_margin.shape, len(LRLaneRelation), device=x_ahead_margin.device + ) # margin > 0, then mode is active + margin[..., LRLaneRelation.ON] = torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LRLaneRelation.LEFTOF] = torch.stack( + [x_ahead_margin, x_behind_margin, -y_left_near, y_left_far, psi_margin], -1 + ).min(-1)[ + 0 + ] # further than left near, closer than left far + margin[..., LRLaneRelation.RIGHTOF] = torch.stack( + [x_ahead_margin, x_behind_margin, -y_right_near, y_right_far, psi_margin], + -1, + ).min(-1)[0] + + margin = fill_margin( + margin, + noton_idx=LRLaneRelation.NOTON, + agent_mask=agent_mask, + lane_mask=lane_mask, + ) + if force_select: + # offset the margin to make sure that at least one is positive + margin_max = margin[..., LRLaneRelation.ON].max(dim=1)[0] + margin_offset = -margin_max.clip(max=0).detach() + 1e-6 + margin[..., LRLaneRelation.ON] = ( + margin[..., LRLaneRelation.ON] + margin_offset[:, None, :] + ) + + if force_unique: + second_largest_margin = margin[..., LRLaneRelation.ON].topk( + 2, 1, sorted=True + )[0][:, 1] + margin_offset = -second_largest_margin.clip(min=0).detach() + margin[..., LRLaneRelation.ON] = ( + margin[..., LRLaneRelation.ON] + margin_offset[:, None, :] + ) + idx_excl_noton = torch.arange(margin.shape[-1]) + idx_excl_noton = idx_excl_noton[idx_excl_noton != LRLaneRelation.NOTON] + margin[..., LRLaneRelation.NOTON] = -margin[..., idx_excl_noton].max(-1)[0] + flag = get_flag(margin) + return flag, margin + + +class LaneRelation(enum.IntEnum): + """ + pairwise lane mode that gives each agent-lane pair a categorical token describing the relationship between the agent and the lane + """ + + NOTON = 0 + ON = 1 + AHEAD = 2 + BEHIND = 3 + MISALIGN = 4 + LEFTOF = 5 + RIGHTOF = 6 + + @staticmethod + def get_all_margins(agent_xysc, lane_xysc, t_range=None, const_override={}): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + return torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ) + + @staticmethod + def categorize_lane_relation_pts( + agent_xysc, + lane_xysc, + agent_mask=None, + lane_mask=None, + t_range=None, + force_select=True, + force_unique=True, + const_override={}, + ): + ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) = get_l2a_geometry( + agent_xysc, lane_xysc, t_range, const_override=const_override + ) + margin = torch.zeros( + *x_ahead_margin.shape, len(LaneRelation), device=x_ahead_margin.device + ) # margin > 0, then mode is active + margin[..., LaneRelation.ON] = torch.stack( + [x_ahead_margin, x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.AHEAD] = torch.stack( + [-x_ahead_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.BEHIND] = torch.stack( + [-x_behind_margin, y_left_near, y_right_near, psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.MISALIGN] = torch.stack( + [y_left_near, y_right_near, -psi_margin], -1 + ).min(-1)[0] + margin[..., LaneRelation.LEFTOF] = torch.stack( + [x_ahead_margin, x_behind_margin, -y_left_near, y_left_far, psi_margin], -1 + ).min(-1)[ + 0 + ] # further than left near, closer than left far + margin[..., LaneRelation.RIGHTOF] = 
torch.stack( + [x_ahead_margin, x_behind_margin, -y_right_near, y_right_far, psi_margin], + -1, + ).min(-1)[0] + + margin = fill_margin( + margin, + noton_idx=LaneRelation.NOTON, + agent_mask=agent_mask, + lane_mask=lane_mask, + ) + flag = get_flag(margin) + return flag, margin + + +def get_l2a_geometry(agent_xysc, lane_xysc, t_range=None, const_override={}): + const = LaneModeConst(**const_override) + # agent_xysc:[B,T,4], lane_xysc:[B,M,L,4] + # lane_mask: [B,M], agent_mask: [B,T] + B, T = agent_xysc.shape[:2] + M, L = lane_xysc.shape[1:3] + + # idx1 = max(int(T*0.3),1) + # idx2 = min(T-idx1,T-1) + + dx = GeoUtils.batch_proj_xysc( + agent_xysc.repeat_interleave(M, 0).reshape(-1, 4), + lane_xysc.repeat_interleave(T, 1).reshape(-1, L, 4), + ).reshape( + B, M, T, L, -1 + ) # [B,M,T,L,xdim] + close_idx = ( + dx[..., 0].abs().argmin(-1) + ) # Take first element (x-pos), and find closest index (of L) within each lane segment + proj_pts = dx.gather(-2, close_idx[..., None, None].repeat(1, 1, 1, 1, 4)).squeeze( + -2 + ) # Get projection points using the closest point for each lane seg [B,M,T,4] + psi = ratan2(proj_pts[..., 2], proj_pts[..., 3]).detach() + y_dev = proj_pts[..., 1] + # Hausdorff-like distance + x_ahead_margin = ( + const.X_ahead_thresh + dx[..., 0].max(-1)[0] + ) / const.longitudinal_scale # We only have to check the minimal value + x_behind_margin = ( + -dx[..., 0].min(-1)[0] + const.X_rear_thresh + ) / const.longitudinal_scale + y_left_near = (const.Y_near_thresh + y_dev) / const.lateral_scale + y_left_far = (const.Y_far_thresh + y_dev) / const.lateral_scale + y_right_near = (const.Y_near_thresh - y_dev) / const.lateral_scale + y_right_far = (const.Y_far_thresh - y_dev) / const.lateral_scale + psi_margin = (const.psi_thresh - psi.abs()) / const.heading_scale + if t_range is not None: + t0, t1 = t_range + x_ahead_margin = x_ahead_margin[:, :, t0:t1].mean(dim=2) + x_behind_margin = x_behind_margin[:, :, t0:t1].mean(dim=2) + y_left_near = y_left_near[:, :, t0:t1].mean(dim=2) + y_left_far = y_left_far[:, :, t0:t1].mean(dim=2) + y_right_near = y_right_near[:, :, t0:t1].mean(dim=2) + y_right_far = y_right_far[:, :, t0:t1].mean(dim=2) + psi_margin = psi_margin[:, :, t0:t1].mean(dim=2) + return ( + x_ahead_margin, + x_behind_margin, + y_left_near, + y_left_far, + y_right_near, + y_right_far, + psi_margin, + ) + + +def get_ypsi_dev(agent_xysc, lane_xysc): + # agent_xysc:[B,T,4], lane_xysc:[B,M,L,4] + # lane_mask: [B,M], agent_mask: [B,T] + B, T = agent_xysc.shape[:2] + M, L = lane_xysc.shape[1:3] + + # idx1 = max(int(T*0.3),1) + # idx2 = min(T-idx1,T-1) + + dx = GeoUtils.batch_proj_xysc( + agent_xysc.repeat_interleave(M, 0).reshape(-1, 4), + lane_xysc.repeat_interleave(T, 1).reshape(-1, L, 4), + ).reshape( + B, M, T, L, -1 + ) # [B,M,T,L,xdim] + close_idx = ( + dx[..., 0].abs().argmin(-1) + ) # Take first element (x-pos), and find closest index (of L) within each lane segment + proj_pts = dx.gather(-2, close_idx[..., None, None].repeat(1, 1, 1, 1, 4)).squeeze( + -2 + ) # Get projection points using the closest point for each lane seg [B,M,T,4] + psi = ratan2(proj_pts[..., 2], proj_pts[..., 3]).detach() + y_dev = proj_pts[..., 1] + return y_dev, psi + + +def fill_margin(margin, noton_idx, agent_mask=None, lane_mask=None): + idx_excl_noton = torch.arange(margin.shape[-1]) + idx_excl_noton = idx_excl_noton[idx_excl_noton != noton_idx] + # put anything that does not belong to all the classes above to NOTON + margin[..., noton_idx] = -margin[..., idx_excl_noton].max(-1)[ + 0 + ] # Negation of 
the max of the rest + + # Put anything that is masked out to NOTON + margin[..., noton_idx] = margin[..., noton_idx].masked_fill( + torch.logical_not(agent_mask).unsqueeze(1), 10 + ) # agents we're not considering set to noton + margin[..., idx_excl_noton] = margin[..., idx_excl_noton].masked_fill( + torch.logical_not(agent_mask)[:, None, :, None], -10 + ) + margin[..., noton_idx] = margin[..., noton_idx].masked_fill( + torch.logical_not(lane_mask).unsqueeze(2), 10 + ) + margin[..., idx_excl_noton] = margin[..., idx_excl_noton].masked_fill( + torch.logical_not(lane_mask)[:, :, None, None], -10 + ) + return margin + + +def get_flag(margin): + flag = (margin >= 0).float() + # HACK: suppress multiple on flags + # if flag.shape[-1] > 2: + # flag[..., 2:] = flag[..., 2:].masked_fill( + # (flag[..., 1:2] > 0).repeat_interleave(flag.shape[-1] - 2, -1), 0 + # ) + # assert (flag.sum(-1) == 1).all() + return flag + + +def get_ref_traj(agent_xyh, lane_xyh, des_vel, dt, T): + """calculate reference trajectory along the lane center + + Args: + agent_xyvsc (np.ndarray): B,3 + lane_xysc (np.ndarray): B,L,3 + des_vel (np.ndarray): B + """ + B = agent_xyh.shape[0] + delta_x, _, _ = GeoUtils.batch_proj(agent_xyh, lane_xyh) + indices = np.abs(delta_x).argmin(-1) + xrefs = list() + for agent, lane, idx, delta_x_i, vel in zip( + agent_xyh, lane_xyh, indices, delta_x, des_vel + ): + idx = min(idx, lane.shape[0] - 2) + s = np.linalg.norm( + lane[idx + 1 :, :2] - lane[idx:-1, :2], axis=-1, keepdims=True + ).cumsum() + s = np.insert(s, 0, 0.0) - delta_x_i[idx] + f = spint.interp1d( + s, + lane[idx:], + axis=0, + assume_sorted=True, + bounds_error=False, + fill_value="extrapolate", + ) + xref = f(vel * np.arange(1, T + 1) * dt) + if np.isnan(xref).any(): + xref = np.zeros((T, 3)) + xrefs.append(xref) + return np.stack(xrefs, 0) + + +def get_closest_lane_pts(xyh, lane): + delta_x, _, _ = GeoUtils.batch_proj(xyh, lane.center.xyh) + if isinstance(xyh, np.ndarray): + return np.abs(delta_x).argmin() + elif isinstance(xyh, torch.Tensor): + return delta_x.abs().argmin() + + +def test_edge(): + import pickle + import torch + + with open("sf_test.pkl", "rb") as f: + data = pickle.load(f) + lane_xyh = data["lane_xyh"] + lane_feat = torch.cat( + [ + lane_xyh[..., :2], + torch.sin(lane_xyh[..., 2:3]), + torch.cos(lane_xyh[..., 2:3]), + ], + -1, + ) + ego_xycs = data["agent_hist"][:, 0, :, [0, 1, 6, 7]] + + get_l2a_geometry(ego_xycs, lane_feat, [0, 4]) + + print("123") + + +if __name__ == "__main__": + test_edge() diff --git a/diffstack/utils/log_utils.py b/diffstack/utils/log_utils.py new file mode 100644 index 0000000..dc755ab --- /dev/null +++ b/diffstack/utils/log_utils.py @@ -0,0 +1,84 @@ +""" +This file contains utility classes and functions for logging to stdout, stderr, +and to tensorboard. +""" +import sys +import os +import time +import pathlib +import json +import wandb +from omegaconf import OmegaConf, DictConfig + + +def prepare_logging(cfg: DictConfig, use_logging=True, use_wandb=None, verbose=True, wandb_init_retries=10): + # Logging + log_writer = None + model_dir = None + if not use_logging: + return log_writer, model_dir + if use_wandb is None: + use_wandb = cfg.run.wandb_logging + + # Create the log and model directory if they're not present. 
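As a side note on `get_ref_traj` in lane_utils.py above: it builds a cumulative arc-length coordinate along the lane centerline and interpolates it at the distances an agent travelling at `des_vel` would cover. A minimal standalone sketch of that idea on a toy straight lane (all values illustrative, not the module's API):

import numpy as np
import scipy.interpolate as spint

dt, T, des_vel = 0.1, 10, 5.0                      # step time [s], horizon, desired speed [m/s]
lane = np.stack([np.linspace(0.0, 50.0, 26),       # toy centerline: x
                 np.zeros(26),                     # y
                 np.zeros(26)], axis=-1)           # heading

s = np.linalg.norm(np.diff(lane[:, :2], axis=0), axis=-1).cumsum()
s = np.insert(s, 0, 0.0)                           # cumulative arc length, starting at 0

f = spint.interp1d(s, lane, axis=0, assume_sorted=True,
                   bounds_error=False, fill_value="extrapolate")
xref = f(des_vel * np.arange(1, T + 1) * dt)       # waypoints at the travelled arc lengths
print(xref.shape)                                  # (10, 3): x, y, heading per future step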
+ model_dir = os.path.join(cfg.run.log_dir, + cfg.run.log_tag + time.strftime('-%d_%b_%Y_%H_%M_%S', time.localtime())) + cfg.logdir = model_dir + + pathlib.Path(model_dir).mkdir(parents=True, exist_ok=True) + + # Save config to model directory + cfg_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False) + with open(os.path.join(model_dir, 'config.json'), 'w') as conf_json: + json.dump(cfg_dict, conf_json) + + # wandb.tensorboard.patch(root_logdir=model_dir, pytorch=True) + + # WandB init. Put it in a loop because it can fail on ngc. + if use_wandb: + log_writer = init_wandb(cfg, wandb_init_retries=wandb_init_retries) + + if verbose: + print (f"Log path: {model_dir}") + + return log_writer, model_dir + + +def init_wandb(cfg, wandb_init_retries=10): + cfg_dict = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False) + for _ in range(wandb_init_retries): + try: + log_writer = wandb.init( + project=cfg.run.wandb_project, + name=f"{cfg.run['log_tag']}", + config=cfg_dict, + mode="offline" if cfg.run["interactive"] else "online", # sync_tensorboard=True, + settings=wandb.Settings(start_method="thread"), + ) + except: + continue + break + else: + raise ValueError("Could not connect to wandb") + return log_writer + + +class PrintLogger(object): + """ + This class redirects print statements to both console and a file. + """ + def __init__(self, log_file): + self.terminal = sys.stdout + print('STDOUT will be forked to %s' % log_file) + self.log_file = open(log_file, "a") + + def write(self, message): + self.terminal.write(message) + self.log_file.write(message) + self.log_file.flush() + + def flush(self): + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass diff --git a/diffstack/utils/loss_utils.py b/diffstack/utils/loss_utils.py new file mode 100644 index 0000000..bf13f88 --- /dev/null +++ b/diffstack/utils/loss_utils.py @@ -0,0 +1,858 @@ +""" +This file contains a collection of useful loss functions for use with torch tensors. +Partially borrowed from https://github.com/ARISE-Initiative/robomimic/blob/master/robomimic/utils/loss_utils.py +""" +from typing import Dict + +import numpy as np +import torch +import torch.nn as nn +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.batch_utils import batch_utils +from diffstack.utils.geometry_utils import ( + VEH_VEH_collision, + VEH_PED_collision, + PED_VEH_collision, + PED_PED_collision, +) +import torch.nn.functional as F + + +def cosine_loss(preds, labels): + """ + Cosine loss between two tensors. + Args: + preds (torch.Tensor): torch tensor + labels (torch.Tensor): torch tensor + Returns: + loss (torch.Tensor): cosine loss + """ + sim = torch.nn.CosineSimilarity(dim=len(preds.shape) - 1)(preds, labels) + return -torch.mean(sim - 1.0) + + +def KLD_0_1_loss(mu, logvar): + """ + KL divergence loss. Computes D_KL( N(mu, sigma) || N(0, 1) ). Note that + this function averages across the batch dimension, but sums across dimension. + Args: + mu (torch.Tensor): mean tensor of shape (B, D) + logvar (torch.Tensor): logvar tensor of shape (B, D) + Returns: + loss (torch.Tensor): KL divergence loss between the input gaussian distribution + and N(0, 1) + """ + return -0.5 * (1.0 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1).mean() + + +def KLD_gaussian_loss(mu_1, logvar_1, mu_2, logvar_2): + """ + KL divergence loss between two Gaussian distributions. 
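For the `PrintLogger` class added above, a short usage sketch (the log path is made up): once `sys.stdout` is replaced, every `print` goes to both the console and the file.

import sys
from diffstack.utils.log_utils import PrintLogger

sys.stdout = PrintLogger("train_stdout.log")       # hypothetical file name
print("forked to the console and to train_stdout.log")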
This function + computes the average loss across the batch. + Args: + mu_1 (torch.Tensor): first means tensor of shape (B, D) + logvar_1 (torch.Tensor): first logvars tensor of shape (B, D) + mu_2 (torch.Tensor): second means tensor of shape (B, D) + logvar_2 (torch.Tensor): second logvars tensor of shape (B, D) + Returns: + loss (torch.Tensor): KL divergence loss between the two gaussian distributions + """ + return ( + -0.5 + * ( + 1.0 + + logvar_1 + - logvar_2 + - ((mu_2 - mu_1).pow(2) / logvar_2.exp()) + - (logvar_1.exp() / logvar_2.exp()) + ) + .sum(dim=1) + .mean() + ) + + +def KLD_discrete(logp, logq): + """KL divergence loss between two discrete distributions. This function + computes the average loss across the batch. + + Args: + logp (torch.Tensor): log probability of first discrete distribution (B,D) + logq (torch.Tensor): log probability of second discrete distribution (B,D) + """ + return (torch.exp(logp) * (logp - logq)).sum(dim=1) + + +def KLD_discrete_with_zero(p, q, logmin=None, logmax=None): + flag = p != 0 + # breakpoint() + p = p.clip(min=1e-8) + q = q.clip(min=1e-8) + logp = (torch.log(p) * flag).nan_to_num(0) + logq = (torch.log(q) * flag).nan_to_num(0) + logp = logp.clip(min=logmin, max=logmax) + logq = logq.clip(min=logmin, max=logmax) + return (p * (logp - logq) * flag).sum(dim=1) + + +def log_normal(x, m, v, avails=None): + """ + Log probability of tensor x under diagonal multivariate normal with + mean m and variance v. The last dimension of the tensors is treated + as the dimension of the Gaussian distribution - all other dimensions + are treated as independent Gaussians. Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): tensor with shape (B, ..., D) + m (torch.Tensor): means tensor with shape (B, ..., D) or (1, ..., D) + v (torch.Tensor): variances tensor with shape (B, ..., D) or (1, ..., D) + avails (torch.Tensor): availability of x and m + Returns: + log_prob (torch.Tensor): log probabilities of shape (B, ...) + """ + if avails is None: + element_wise = -0.5 * (torch.log(v) + (x - m).pow(2) / v + np.log(2 * np.pi)) + else: + element_wise = -0.5 * ( + torch.log(v) + ((x - m) * avails).pow(2) / v + np.log(2 * np.pi) + ) + log_prob = element_wise.sum(-1) + return log_prob + + +def log_normal_mixture(x, m, v, w=None, log_w=None): + """ + Log probability of tensor x under a uniform mixture of Gaussians. + Adapted from CS 236 at Stanford. 
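The closed form in `KLD_gaussian_loss` above can be sanity-checked against `torch.distributions`; a minimal check (not part of the diff) that the per-dimension expression matches the analytic KL between diagonal Gaussians:

import torch
import torch.distributions as D

B, dim = 4, 3
mu1, mu2 = torch.randn(B, dim), torch.randn(B, dim)
logvar1, logvar2 = torch.randn(B, dim), torch.randn(B, dim)

kl_closed = -0.5 * (1.0 + logvar1 - logvar2
                    - (mu2 - mu1).pow(2) / logvar2.exp()
                    - logvar1.exp() / logvar2.exp()).sum(dim=1).mean()

p = D.Normal(mu1, (0.5 * logvar1).exp())           # std = exp(logvar / 2)
q = D.Normal(mu2, (0.5 * logvar2).exp())
kl_ref = D.kl_divergence(p, q).sum(dim=1).mean()
assert torch.allclose(kl_closed, kl_ref, atol=1e-5)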
+ Args: + x (torch.Tensor): tensor with shape (B, D) + m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where + M is number of mixture components + v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where + M is number of mixture components + w (torch.Tensor): weights tensor - if provided, should be + shape (B, M) or (1, M) + log_w (torch.Tensor): log-weights tensor - if provided, should be + shape (B, M) or (1, M) + Returns: + log_prob (torch.Tensor): log probabilities of shape (B,) + """ + + # (B , D) -> (B , 1, D) + x = x.unsqueeze(1) + # (B, 1, D) -> (B, M, D) -> (B, M) + log_prob = log_normal(x, m, v) + if w is not None or log_w is not None: + # this weights the log probabilities by the mixture weights so we have log(w_i * N(x | m_i, v_i)) + if w is not None: + assert log_w is None + log_w = torch.log(w) + log_prob += log_w + # then compute log sum_i exp [log(w_i * N(x | m_i, v_i))] + # (B, M) -> (B,) + log_prob = log_sum_exp(log_prob, dim=1) + else: + # (B, M) -> (B,) + log_prob = log_mean_exp(log_prob, dim=1) # mean accounts for uniform weights + return log_prob + + +def NLL_GMM_loss(x, m, v, pi, avails=None, detach=True, mode="sum"): + """ + Log probability of tensor x under a uniform mixture of Gaussians. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): tensor with shape (B, D) + m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where + M is number of mixture components + v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where + M is number of mixture components + logpi (torch.Tensor): log probability of the modes (B,M) + detach (bool): option whether to detach all modes but the best one + mode (string): mode of loss, sum or max + + Returns: + -log_prob (torch.Tensor): log probabilities of shape (B,) + """ + if v is None: + v = torch.ones_like(m) + + # (B , D) -> (B , 1, D) + x = x.unsqueeze(1) + # (B, 1, D) -> (B, M, D) -> (B, M) + if avails is not None: + avails = avails.unsqueeze(1) + log_prob = log_normal(x, m, v, avails=avails) + if mode == "sum": + if detach: + max_flag = log_prob == log_prob.max(dim=1, keepdim=True)[0] + nonmax_flag = torch.logical_not(max_flag) + log_prob_detach = log_prob.detach() + NLL_loss = (-pi * log_prob * max_flag).sum(1).mean() + ( + -pi * log_prob_detach * nonmax_flag + ).sum(1).mean() + else: + NLL_loss = (-pi * log_prob).sum(1).mean() + elif mode == "max": + max_flag = log_prob == log_prob.max(dim=1, keepdim=True)[0] + NLL_loss = (-pi * log_prob * max_flag).sum(1).mean() + return NLL_loss + + +def log_mean_exp(x, dim): + """ + Compute the log(mean(exp(x), dim)) in a numerically stable manner. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): a tensor + dim (int): dimension along which mean is computed + Returns: + y (torch.Tensor): log(mean(exp(x), dim)) + """ + return log_sum_exp(x, dim) - np.log(x.size(dim)) + + +def log_sum_exp(x, dim=0): + """ + Compute the log(sum(exp(x), dim)) in a numerically stable manner. + Adapted from CS 236 at Stanford. + Args: + x (torch.Tensor): a tensor + dim (int): dimension along which sum is computed + Returns: + y (torch.Tensor): log(sum(exp(x), dim)) + """ + + max_x = torch.max(x, dim)[0] + new_x = x - max_x.unsqueeze(dim).expand_as(x) + + return max_x + (new_x.exp().sum(dim)).log() + + +def project_values_onto_atoms(values, probabilities, atoms): + """ + Project the categorical distribution given by @probabilities on the + grid of values given by @values onto a grid of values given by @atoms. 
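`log_sum_exp` above uses the standard max-shift trick; a quick standalone check (not part of the diff) that it stays finite where a naive exp-sum-log overflows and matches `torch.logsumexp`:

import torch

x = torch.tensor([[1000.0, 999.0, 998.0],
                  [-1000.0, -999.0, -998.0]])

max_x = torch.max(x, dim=1)[0]
stable = max_x + (x - max_x.unsqueeze(1).expand_as(x)).exp().sum(dim=1).log()

print(x.exp().sum(dim=1).log())                    # naive version: inf / -inf
print(stable)                                      # finite, ~[1000.41, -997.59]
assert torch.allclose(stable, torch.logsumexp(x, dim=1))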
+ This is useful when computing a bellman backup where the backed up + values from the original grid will not be in the original support, + requiring L2 projection. + Each value in @values has a corresponding probability in @probabilities - + this probability mass is shifted to the closest neighboring grid points in + @atoms in proportion. For example, if the value in question is 0.2, and the + neighboring atoms are 0 and 1, then 0.8 of the probability weight goes to + atom 0 and 0.2 of the probability weight will go to 1. + Adapted from https://github.com/deepmind/acme/blob/master/acme/tf/losses/distributional.py#L42 + + Args: + values: value grid to project, of shape (batch_size, n_atoms) + probabilities: probabilities for categorical distribution on @values, shape (batch_size, n_atoms) + atoms: value grid to project onto, of shape (n_atoms,) or (1, n_atoms) + Returns: + new probability vectors that correspond to the L2 projection of the categorical distribution + onto @atoms + """ + + # make sure @atoms is shape (n_atoms,) + if len(atoms.shape) > 1: + atoms = atoms.squeeze(0) + + # helper tensors from @atoms + vmin, vmax = atoms[0], atoms[1] + d_pos = torch.cat([atoms, vmin[None]], dim=0)[1:] + d_neg = torch.cat([vmax[None], atoms], dim=0)[:-1] + + # ensure that @values grid is within the support of @atoms + clipped_values = values.clamp(min=vmin, max=vmax)[ + :, None, : + ] # (batch_size, 1, n_atoms) + clipped_atoms = atoms[None, :, None] # (1, n_atoms, 1) + + # distance between atom values in support + d_pos = (d_pos - atoms)[ + None, :, None + ] # atoms[i + 1] - atoms[i], shape (1, n_atoms, 1) + d_neg = (atoms - d_neg)[ + None, :, None + ] # atoms[i] - atoms[i - 1], shape (1, n_atoms, 1) + + # distances between all pairs of grid values + deltas = clipped_values - clipped_atoms # (batch_size, n_atoms, n_atoms) + + # computes eqn (7) in distributional RL paper by doing the following - for each + # output atom in @atoms, consider values that are close enough, and weight their + # probability mass contribution by the normalized distance in [0, 1] given + # by (1. - (z_j - z_i) / (delta_z)). 
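A tiny numeric illustration of the proportional mass-splitting rule described in the `project_values_onto_atoms` docstring above (its own 0.2 example), not the full vectorized routine:

import torch

atoms = torch.tensor([0.0, 1.0, 2.0])              # target support
value, mass = 0.2, 1.0                             # backed-up value and its probability mass

lo, hi = 0, 1                                      # neighbouring atoms around `value`
w_hi = (value - atoms[lo]) / (atoms[hi] - atoms[lo])
projected = torch.zeros_like(atoms)
projected[lo], projected[hi] = mass * (1 - w_hi), mass * w_hi
print(projected)                                   # tensor([0.8000, 0.2000, 0.0000])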
+ d_sign = (deltas >= 0.0).float() + delta_hat = (d_sign * deltas / d_pos) - ((1.0 - d_sign) * deltas / d_neg) + delta_hat = (1.0 - delta_hat).clamp(min=0.0, max=1.0) + probabilities = probabilities[:, None, :] + return (delta_hat * probabilities).sum(dim=2) + + +def trajectory_loss( + predictions, + targets, + availabilities, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (A), T] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + assert availabilities.shape == predictions.shape[:-1] + assert predictions.shape == targets.shape + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = availabilities.unsqueeze(-1) * weights_scaling + loss = torch.mean(crit(predictions, targets) * target_weights) + return loss + + +def MultiModal_trajectory_loss( + predictions, + targets, + availabilities, + prob, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), + calc_goal_reach=False, + gamma=None, + detach_nonopt=True, +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, M, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (M), (A), T] + prob (torch.Tensor): [B, (A), M] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + gamma (float): risk level for CVAR + detach_nonopt (Bool): if detaching the loss for the non-optimal mode + + Returns: + loss (torch.Tensor) + """ + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + bs, M, Na = predictions.shape[:3] + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = availabilities.unsqueeze(-1) * weights_scaling + if availabilities.ndim < targets.ndim: + target_weights = target_weights.unsqueeze(1) + loss_v = ( + crit(predictions, targets.unsqueeze(1).repeat_interleave(M, 1)) * target_weights + ) + loss_agg_agent = loss_v.reshape(bs, M, Na, -1).sum(-1).transpose(1, 2) + loss_agg = loss_v.reshape(bs, M, -1).sum(-1) + + if gamma is not None: + if prob.ndim == 2: + p_cvar = list() + for i in range(bs): + p_cvar.append( + CVaR_weight(loss_agg[i], prob[i], gamma, sign=1, end_idx=None) + ) + p_cvar = torch.stack(p_cvar, 0) + elif prob.ndim == 3: + p_cvar = list() + for i in range(bs): + p_cvar_i = list() + for j in range(Na): + p_cvar_i.append( + CVaR_weight( + loss_agg_agent[i, j], + prob[i, j], + gamma, + sign=1, + end_idx=None, + ) + ) + p_cvar_i = torch.stack(p_cvar_i, 0) + + p_cvar.append(p_cvar_i) + p_cvar = torch.stack(p_cvar, 0) + prob = p_cvar + + if detach_nonopt: + if prob.ndim == 2: + loss_agg_detached = loss_agg.detach() + min_flag = loss_agg == loss_agg.min(dim=1, keepdim=True)[0] + nonmin_flag = torch.logical_not(min_flag) + loss = ( + (loss_agg * min_flag * prob).sum() + + (loss_agg_detached * nonmin_flag * prob).sum() + ) / (availabilities.sum() + 0.01) + elif prob.ndim == 3: + loss_agg_detached = loss_agg_agent.detach() + min_flag = loss_agg_agent == loss_agg_agent.min(dim=2, keepdim=True)[0] + nonmin_flag = torch.logical_not(min_flag) + loss = ( + (loss_agg_agent * min_flag * 
prob).sum() + + (loss_agg_detached * nonmin_flag * prob).sum() + ) / (availabilities.sum() + 0.01) + else: + if prob.ndim == 2: + loss = (loss_agg * prob).sum() / (availabilities.sum() + 0.01) + elif prob.ndim == 3: + loss = (loss_agg_agent * prob).sum() / (availabilities.sum() + 0.01) + if calc_goal_reach: + last_inds = batch_utils().get_last_available_index(availabilities) # [B, (A)] + num_frames = availabilities.shape[-1] + goal_mask = TensorUtils.to_one_hot(last_inds, num_class=num_frames) + goal_mask = goal_mask.unsqueeze(1).unsqueeze(-1) + if detach_nonopt: + if prob.ndim == 2: + goal_loss = ( + ( + loss_v * (min_flag * prob)[..., None, None, None] * goal_mask + ).sum() + + ( + loss_v.detach() + * (nonmin_flag * prob)[..., None, None, None] + * goal_mask + ).sum() + ) / (goal_mask.sum() + 0.01) + elif prob.ndim == 3: + goal_mask = goal_mask.transpose(1, 2) + goal_loss = ( + ( + loss_v.transpose(1, 2) + * (min_flag * prob)[..., None, None] + * goal_mask + ).sum() + + ( + loss_v.transpose(1, 2).detach() + * (nonmin_flag * prob)[..., None, None, None] + * goal_mask + ).sum() + ) / (goal_mask.sum() + 0.01) + else: + if prob.ndim == 2: + goal_loss = (loss_v * prob[..., None, None, None] * goal_mask).sum() / ( + goal_mask.sum() + 0.01 + ) + elif prob.ndim == 3: + goal_mask = goal_mask.transpose(1, 2) + goal_loss = ( + loss_v.transpose(1, 2) * prob[..., None, None] * goal_mask + ).sum() / (goal_mask.sum() + 0.01) + return loss, goal_loss + else: + return loss, None + + +def goal_reaching_loss( + predictions, + targets, + availabilities, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Final step loss between gt and predicted trajectories (normally used in conjunction with a forward dynamics model) + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (A), T] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + # compute loss mask by finding the last available target + num_frames = availabilities.shape[-1] + last_inds = batch_utils().get_last_available_index(availabilities) # [B, (A)] + goal_mask = TensorUtils.to_one_hot( + last_inds, num_class=num_frames + ) # [B, (A), T] with the last frame set to 1 + # filter out samples that do not have available frames + available_samples_mask = availabilities.sum(-1) > 0 # [B, (A)] + goal_mask = goal_mask * available_samples_mask.unsqueeze(-1).float() # [B, (A), T] + goal_loss = trajectory_loss( + predictions, + targets, + availabilities=goal_mask, + weights_scaling=weights_scaling, + crit=crit, + ) + return goal_loss + + +def lane_regulation_loss(lane_flag, agent_mask): + return (lane_flag.mean(-1) * agent_mask).sum() / agent_mask.sum() + + +def weighted_trajectory_loss( + predictions, + targets, + target_weights, + total_count, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + weights (torch.Tensor): [B, (A), T] + total_count (float) + weight_scaling (torch.Tensor): [D], Defaults to None. 
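The `detach_nonopt` branch of `MultiModal_trajectory_loss` above keeps the weighted sum's value but lets gradients flow only through each sample's best mode; a minimal sketch of that pattern with made-up shapes:

import torch

B, M = 4, 6
per_mode_loss = torch.rand(B, M, requires_grad=True)   # e.g. summed trajectory error per mode
prob = torch.softmax(torch.randn(B, M), dim=-1)         # mode probabilities

min_flag = per_mode_loss == per_mode_loss.min(dim=1, keepdim=True)[0]
loss = (per_mode_loss * min_flag * prob).sum() \
     + (per_mode_loss.detach() * ~min_flag * prob).sum()
loss.backward()
print((per_mode_loss.grad != 0).sum(dim=1))             # one non-zero gradient per sample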
+ crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + assert target_weights.shape == predictions.shape[:-1] + assert predictions.shape == targets.shape + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = target_weights.unsqueeze(-1) * weights_scaling + loss = torch.sum(crit(predictions, targets) * target_weights) / total_count + return loss + + +def CVaR_weight(val, p, gamma, sign=-1, end_idx=None): + q = torch.clamp(p / gamma, max=1.0) + if end_idx is None: + end_idx = p.shape[0] + assert (p[end_idx:] == 0).all() + remain = 1.0 + if sign == 1: + idx = torch.argsort(val[0:end_idx]) + else: + idx = torch.argsort(val[0:end_idx], descending=True) + i = 0 + for i in range(end_idx): + if q[idx[i]] > remain: + q[idx[i]] = remain + remain = 0.0 + else: + remain -= q[idx[i]] + return q + + +def weighted_multimodal_trajectory_loss( + predictions, + targets, + target_weights, + probability, + total_count, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), + gamma=None, +): + """ + Aggregated per-step loss between gt and predicted trajectories + Args: + predictions (torch.Tensor): predicted trajectory [B, M, A, T, D] + targets (torch.Tensor): target trajectory [B, A, T, D] + target_weights (torch.Tensor): [B, A, T] + probability (torch.Tensor): [B,M] + total_count (float) + weight_scaling (torch.Tensor): [D], Defaults to None. + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + bs, M, A, T = predictions.shape[:4] + assert target_weights.shape == predictions.shape[:-1] + assert predictions.shape == targets.shape + if weights_scaling is None: + weights_scaling = torch.ones(targets.shape[-1], device=targets.device) + assert weights_scaling.shape[-1] == targets.shape[-1] + target_weights = target_weights.unsqueeze(-1) * weights_scaling + err = ( + crit( + targets.unsqueeze(1).repeat(1, predictions.size(1), 1, 1, 1), + predictions, + ) + * target_weights.unsqueeze(1) + * probability[:, :, None, None, None] + ) + if gamma is None: + max_idx = torch.max(probability, dim=-1)[1] + max_mask = torch.zeros( + [*err.shape[:2], 1, 1, 1], dtype=torch.bool, device=err.device + ) + max_mask[torch.arange(0, err.size(0)), max_idx] = True + nonmax_mask = ~max_mask + loss = ( + torch.sum((err * max_mask)) + torch.sum((err * nonmax_mask).detach()) + ) / total_count + else: + loss_agg = err.reshape(bs, M, -1).sum(-1) + + p_cvar = list() + for i in range(bs): + p_cvar.append( + CVaR_weight(loss_agg[i], probability[i], gamma, sign=1, end_idx=None) + ) + p_cvar = torch.stack(p_cvar, 0) + loss = (loss_agg * p_cvar).sum() / total_count + return loss + + +def likelihood_loss(likelihood): + return 1.0 - torch.mean(likelihood) + + +def lane_regularization_loss(lane_flags, weights, total_count, probability=None): + """penalizing the vehicle for exiting drivable area + + Args: + lane_flags (torch.Tensor): 1 for in the lane, 0 for out of the lane, [B, (M), (A), T, 1] + weights (torch.Tensor): [B, (A), T] + total_count (float): + probability (torch.Tensor, optional): [B,M]. Defaults to None. 
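`CVaR_weight` above caps each mode's weight at `p_i / gamma` and then spends a total budget of 1 over the modes in loss-sorted order (ascending for the `sign=1` calls used in this file); a small numeric trace of that greedy fill with made-up values:

import torch

loss = torch.tensor([1.0, 5.0, 3.0, 0.5])          # per-mode losses
p = torch.tensor([0.4, 0.2, 0.1, 0.3])             # mode probabilities
gamma = 0.5

q = torch.clamp(p / gamma, max=1.0)                # caps: [0.8, 0.4, 0.2, 0.6]
remain = 1.0
for i in torch.argsort(loss):                      # ascending loss, as with sign=1
    take = min(q[i].item(), remain)
    q[i], remain = take, remain - take
print(q, q.sum())                                  # tensor([0.4, 0., 0., 0.6]), total 1.0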
+ + Returns: + [type]: [description] + """ + if probability is None: + loss = torch.sum(weights.unsqueeze(-1) * (1.0 - lane_flags)) / total_count + else: + if lane_flags.ndim == 4: + probability = probability[:, :, None, None] + elif lane_flags.ndim == 5: + probability = probability[:, :, None, None, None] + loss = ( + torch.sum( + weights.unsqueeze(-1).unsqueeze(1) * (1.0 - lane_flags) * probability + ) + / total_count + ) + return loss + return loss + + +def goal_reaching_loss( + predictions, + targets, + availabilities, + weights_scaling=None, + crit=nn.MSELoss(reduction="none"), +): + """ + Final step loss between gt and predicted trajectories (normally used in conjunction with a forward dynamics model) + Args: + predictions (torch.Tensor): predicted trajectory [B, (A), T, D] + targets (torch.Tensor): target trajectory [B, (A), T, D] + availabilities (torch.Tensor): [B, (A), T] + weights_scaling (torch.Tensor): [D] + crit (nn.Module): loss function + + Returns: + loss (torch.Tensor) + """ + # compute loss mask by finding the last available target + num_frames = availabilities.shape[-1] + last_inds = batch_utils().get_last_available_index(availabilities) # [B, (A)] + goal_mask = TensorUtils.to_one_hot( + last_inds, num_class=num_frames + ) # [B, (A), T] with the last frame set to 1 + # filter out samples that do not have available frames + available_samples_mask = availabilities.sum(-1) > 0 # [B, (A)] + goal_mask = goal_mask * available_samples_mask.unsqueeze(-1).float() # [B, (A), T] + goal_loss = trajectory_loss( + predictions, + targets, + availabilities=goal_mask, + weights_scaling=weights_scaling, + crit=crit, + ) + return goal_loss + + +def collision_loss( + pred_edges: Dict[str, torch.Tensor], + weight=None, + col_funcs=None, + keepdim=False, + return_dis=False, +): + """ + Calculate collision loss among predicted edges along a batch of trajectories + Args: + pred_edges (dict): A dict that maps collision types to box locations + col_funcs (dict): A dict of collision functions (implemented in diffstack.utils.geometric_utils) + + Returns: + collision loss (torch.Tensor) + """ + if col_funcs is None: + col_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + + coll_loss = 0.0 + dis_by_type = dict() + for et, fun in col_funcs.items(): + if et not in pred_edges: + continue + edges = pred_edges[et] + if edges.shape[0] == 0: + continue + dis = fun(edges[..., 0:3], edges[..., 3:6], edges[..., 6:8], edges[..., 8:]) + dis_by_type[et] = dis + coll_loss_tensor = ( + torch.sigmoid(-dis * 4).nan_to_num(0.0).sum(dim=-1) + ) # smooth collision loss + if weight is not None: + if keepdim: + coll_loss += coll_loss_tensor * weight / (weight.sum() + 1e-5) + else: + coll_loss += torch.sum(coll_loss_tensor * weight) / ( + weight.sum() + 1e-5 + ) + else: + if keepdim: + coll_loss += coll_loss_tensor + else: + coll_loss += coll_loss_tensor.sum(-1).mean() + if return_dis: + return coll_loss, dis_by_type + else: + return coll_loss + + +def collision_loss_masked(edges, type_mask, weight=None, col_funcs=None, keepdim=False): + if col_funcs is None: + col_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + + coll_loss = 0.0 + for k, v in type_mask.items(): + if edges.shape[0] == 0: + continue + dis = col_funcs[k]( + edges[..., 0:3], + edges[..., 3:6], + edges[..., 6:8], + edges[..., 8:], + ).min(dim=-1)[0] + coll_loss_tensor = torch.sigmoid(-dis * 4) * v + if weight is 
not None: + if keepdim: + coll_loss += coll_loss_tensor * weight / (weight.sum() + 1e-5) + else: + coll_loss += torch.sum(coll_loss_tensor * weight) / ( + weight.sum() + 1e-5 + ) + else: + if keepdim: + coll_loss += coll_loss_tensor + else: + coll_loss += coll_loss_tensor.sum(-1).mean() + return coll_loss + + +def discriminator_loss(likelihood_pred, likelihood_GT): + label = torch.cat( + (torch.zeros_like(likelihood_pred), torch.ones_like(likelihood_GT)), 0 + ) + return F.binary_cross_entropy(torch.cat((likelihood_pred, likelihood_GT)), label) + + +def compute_pred_loss( + recon_loss_type, pred_batch, target_traj, avails, prob, weights_scaling=None +): + if "z" in pred_batch: + z1 = torch.argmax(pred_batch["z"], dim=-1) + else: + z1 = None + if recon_loss_type == "NLL": + assert "logvar" in pred_batch["x_recons"] + bs, M, T, D = pred_batch["trajectories"].shape + var = ( + torch.exp(pred_batch["x_recons"]["logvar"]) + + torch.ones_like(pred_batch["x_recons"]["logvar"]) * 1e-4 + ).reshape(bs, M, -1) + if z1 is not None: + var = torch.gather(var, 1, z1.unsqueeze(-1).repeat(1, 1, var.size(-1))) + avails = ( + avails.unsqueeze(-1).repeat(1, 1, target_traj.shape[-1]).reshape(bs, -1) + ) + pred_loss = NLL_GMM_loss( + x=target_traj.reshape(bs, -1), + m=pred_batch["trajectories"].reshape(bs, M, -1), + v=var, + pi=prob, + avails=avails, + ) + pred_loss = pred_loss.mean() + + elif recon_loss_type == "MSE": + pred_loss, _ = MultiModal_trajectory_loss( + predictions=pred_batch["trajectories"], + targets=target_traj, + availabilities=avails, + prob=prob, + weights_scaling=weights_scaling, + ) + else: + raise NotImplementedError("{} is not implemented".format(recon_loss_type)) + return pred_loss + + +def diversity_score( + pred: torch.tensor, + avail: torch.tensor, + mode: str = "mean", + max_clip=0.4, +) -> torch.tensor: + """ + Compute the distance among trajectory samples at the last timestep + Args: + pred (torch.tensor): array of shape B x M x A x T x 2 + avail (torch.tensor): array of availability B x + mode (str): calculation mode: option are "mean" (average distance) and "max" (distance between + the two most distinctive samples). 
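`collision_loss` above maps a signed box-to-box distance to a differentiable penalty via `sigmoid(-dis * 4)`; a standalone sketch of that shaping with made-up distances (the real distances come from the geometry utilities):

import torch

# signed distances over a horizon: > 0 means clear, < 0 means overlapping
dis = torch.tensor([[2.0, 0.5, -0.3, -1.0],
                    [4.0, 3.0, 2.5, 2.0]])

penalty = torch.sigmoid(-dis * 4)                  # ~0 when far apart, -> 1 when deep in collision
coll_loss = penalty.sum(dim=-1).mean()             # sum over time, average over edges
print(penalty)
print(coll_loss)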
+ Returns: + torch.tensor: average displacement error (ADE) of the batch, an array of float numbers + """ + # compute pairwise distances at the last time step + traj_norm = torch.linalg.norm(pred[..., -1, :] - pred[..., 0, :], dim=-1).clip( + min=0.1 + ) + traj_norm = traj_norm[:, None, :, :, None].detach() + avail = avail.any(-1)[:, None, None, :] + pred = pred[..., -1, :] + error = torch.linalg.norm( + (pred[:, None, :] - pred[:, :, None]) * avail.unsqueeze(-1) / traj_norm, dim=-1 + ) + error = (error.clip(max=max_clip) * avail).sum(-1) / avail.sum(-1).clip(min=1) + # [B, M, M] + error = error.reshape(error.shape[0], -1) # [B, M * M] + if mode == "max": + error = error.max(-1) + elif mode == "mean": + error = error.mean(-1) + else: + raise ValueError(f"mode: {mode} not valid") + return error.mean() + + +def loss_clip(loss, max_loss=2.0): + loss_offset = F.relu(loss - max_loss).detach() + return loss - loss_offset diff --git a/diffstack/utils/math_utils.py b/diffstack/utils/math_utils.py new file mode 100644 index 0000000..ecd7dff --- /dev/null +++ b/diffstack/utils/math_utils.py @@ -0,0 +1,59 @@ +import torch +import numpy as np + +def soft_min(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((y-x)/2) + exp2 = expfun((x-y)/2) + return (exp1*x+exp2*y)/(exp1+exp2) + +def soft_max(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((x-y)/2) + exp2 = expfun((y-x)/2) + return (exp1*x+exp2*y)/(exp1+exp2) + +def soft_sat(x,x_min=None,x_max=None,gamma=5): + if x_min is None and x_max is None: + return x + elif x_min is None and x_max is not None: + return soft_min(x,x_max,gamma) + elif x_min is not None and x_max is None: + return soft_max(x,x_min,gamma) + else: + if isinstance(x_min,torch.Tensor) or isinstance(x_min,np.ndarray): + assert (x_max>x_min).all() + else: + assert x_max>x_min + xc = x - (x_min+x_max)/2 + if isinstance(x,torch.Tensor): + return xc/(torch.pow(1+torch.pow(torch.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + elif isinstance(x,np.ndarray): + return xc/(np.power(1+np.power(np.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + else: + raise Exception("data type not supported") + + +def Gaussian_importance_sampling(mu0,var0,mu1,var1,num_samples=1): + """ perform importance sampling between two Gaussian distributions + + Args: + mu0 (torch.Tensor): [B,D]: mean of the target Gaussian distribution + var0 (torch.Tensor): [B,D]: variance of the target Gaussian distribution + mu1 (torch.Tensor): [B,D]: mean of the proposal Gaussian distribution + var1 (torch.Tensor): [B,D]: variance of the proposal Gaussian distribution + num_samples (int, optional): number of samples. Defaults to 1. 
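`soft_sat` in math_utils.py above is a smooth alternative to clamping: it recenters `x` on the interval midpoint and divides by `(1 + |2*xc/(x_max - x_min)|^gamma)^(1/gamma)`, which approaches a hard clamp as `gamma` grows. A quick torch-only comparison with illustrative values:

import torch

def soft_sat_demo(x, x_min, x_max, gamma=5):
    xc = x - (x_min + x_max) / 2
    den = torch.pow(1 + torch.pow(torch.abs(xc * 2 / (x_max - x_min)), gamma), 1 / gamma)
    return xc / den + (x_min + x_max) / 2

x = torch.linspace(-10.0, 10.0, 5)
print(soft_sat_demo(x, -2.0, 2.0))                 # smooth saturation toward [-2, 2]
print(soft_sat_demo(x, -2.0, 2.0, gamma=50))       # nearly identical to a hard clamp
print(torch.clamp(x, -2.0, 2.0))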
+ + Returns: + samples: [B,num_samples,D]: samples from the proposal Gaussian distribution + log_weights: [B,num_samples]: log weights of the samples + """ + samples = torch.randn([*mu1.shape[:-1],num_samples,mu1.shape[-1]],device=mu1.device)*torch.sqrt(var1).unsqueeze(-2)+mu1.unsqueeze(-2) + log_weights = -0.5*torch.log(2*np.pi*var0.unsqueeze(-2))-0.5*torch.pow(samples-mu0.unsqueeze(-2),2)/var0.unsqueeze(-2)+0.5*torch.log(2*np.pi*var1.unsqueeze(-2))+0.5*torch.pow(samples-mu1.unsqueeze(-2),2)/var1.unsqueeze(-2) + return samples,log_weights diff --git a/diffstack/utils/metrics.py b/diffstack/utils/metrics.py new file mode 100644 index 0000000..a58f810 --- /dev/null +++ b/diffstack/utils/metrics.py @@ -0,0 +1,848 @@ +""" +Adapted from https://github.com/lyft/l5kit/blob/master/l5kit/l5kit/evaluation/metrics.py +""" + + +from typing import Callable, Dict, List, Union +import torch +import numpy as np + +from diffstack.utils.geometry_utils import ( + VEH_VEH_collision, + VEH_PED_collision, + PED_VEH_collision, + PED_PED_collision, + get_box_world_coords, +) +from diffstack.utils.loss_utils import log_normal +from diffstack.modules.predictors.trajectron_utils.model.components.gmm2d import GMM2D + +from trajdata import SceneBatch + + +metric_signature = Callable[ + [np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray +] + + +def _assert_shapes( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> None: + """ + Check the shapes of args required by metrics + Args: + ground_truth (np.ndarray): array of shape (batch)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(timesteps) with the availability for each gt timestep + Returns: + """ + assert ( + len(pred.shape) == 4 + ), f"expected 4D (BxMxTxC) array for pred, got {pred.shape}" + batch_size, num_modes, future_len, num_coords = pred.shape + + assert ground_truth.shape == ( + batch_size, + future_len, + num_coords, + ), f"expected 3D (Batch x Time x Coords) array for gt, got {ground_truth.shape}" + assert confidences.shape == ( + batch_size, + num_modes, + ), f"expected 2D (Batch x Modes) array for confidences, got {confidences.shape}" + + assert np.allclose(np.sum(confidences, axis=1), 1), "confidences should sum to 1" + assert avails.shape == ( + batch_size, + future_len, + ), f"expected 2D (Batch x Time) array for avails, got {avails.shape}" + # assert all data are valid + assert np.isfinite(pred).all(), "invalid value found in pred" + assert np.isfinite(ground_truth).all(), "invalid value found in gt" + assert np.isfinite(confidences).all(), "invalid value found in confidences" + assert np.isfinite(avails).all(), "invalid value found in avails" + + +def batch_neg_multi_log_likelihood( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Compute a negative log-likelihood for the multi-modal scenario.
+ log-sum-exp trick is used here to avoid underflow and overflow, For more information about it see: + https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations + https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/ + https://leimao.github.io/blog/LogSumExp/ + For more details about used loss function and reformulation, please see + https://github.com/lyft/l5kit/blob/master/competition.md. + Args: + ground_truth (np.ndarray): array of shape (batchsize)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batchsize)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape ((batchsize)xmodes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batchsize)x(timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: negative log-likelihood for this batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + + with np.errstate( + divide="ignore" + ): # when confidence is 0 log goes to -inf, but we're fine with it + error = np.log(confidences) - 0.5 * np.sum(error, axis=-1) # reduce timesteps + + # use max aggregator on modes for numerical stability + max_value = np.max( + error, axis=-1, keepdims=True + ) # error are negative at this point, so max() gives the minimum one + error = ( + -np.log(np.sum(np.exp(error - max_value), axis=-1)) - max_value + ) # reduce modes + return error + + +def batch_rmse( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Return the root mean squared error, computed using the stable nll + Args: + ground_truth (np.ndarray): array of shape (batch)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: negative log-likelihood for this batch, an array of float numbers + """ + nll = batch_neg_multi_log_likelihood(ground_truth, pred, confidences, avails) + _, _, future_len, _ = pred.shape + + return np.sqrt(2 * nll / future_len) + + +def batch_prob_true_mode( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Return the probability of the true mode + Args: + ground_truth (np.ndarray): array of shape (timesteps)x(2D coords) + pred (np.ndarray): array of shape (modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: a (modes) numpy array + """ + _assert_shapes(ground_truth, pred, confidences, avails) + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + + with np.errstate( + divide="ignore" + ): # when confidence is 0 log goes to -inf, but we're fine with it + error = 
np.log(confidences) - 0.5 * np.sum(error, axis=-1) # reduce timesteps + + # use max aggregator on modes for numerical stability + max_value = np.max( + error, axis=-1, keepdims=True + ) # error are negative at this point, so max() gives the minimum one + + error = np.exp(error - max_value) / np.sum(np.exp(error - max_value)) + return error + + +def batch_time_displace( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, +) -> np.ndarray: + """ + Return the displacement at timesteps T + Args: + ground_truth (np.ndarray): array of shape (batch)x(timesteps)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(timesteps)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(timesteps) with the availability for each gt timesteps + Returns: + np.ndarray: a (batch)x(timesteps) numpy array + """ + true_mode_error = batch_prob_true_mode(ground_truth, pred, confidences, avails) + true_mode_error = true_mode_error[:, :, np.newaxis] # add timesteps axis + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + return np.sum(true_mode_error * np.sqrt(error), axis=1) # reduce modes + + +def batch_average_displacement_error( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "mean", +) -> np.ndarray: + """ + Returns the average displacement error (ADE), which is the average displacement over all timesteps. + During calculation, confidences are ignored, and two modes are available: + - oracle: only consider the best hypothesis + - mean: average over all hypotheses + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode - options are 'mean' (average over hypotheses) and 'oracle' (use best hypotheses) + Returns: + np.ndarray: average displacement error (ADE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + error = error ** 0.5 # calculate root of error (= L2 norm) + error = np.mean(error, axis=-1) # average over timesteps + if mode == "oracle": + error = np.min(error, axis=1) # use best hypothesis + elif mode == "mean": + error = np.mean(error*confidences, axis=1) # average over hypotheses + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def batch_final_displacement_error( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "mean", +) -> np.ndarray: + """ + Returns the final displacement error (FDE), which is the displacement calculated at the last timestep. 
+ During calculation, confidences are ignored, and two modes are available: + - oracle: only consider the best hypothesis + - mean: average over all hypotheses + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode - options are 'mean' (average over hypotheses) and 'oracle' (use best hypotheses) + Returns: + np.ndarray: final displacement error (FDE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + inds = np.arange(0, pred.shape[2]) + inds = (avails > 0) * inds # [B, (A), T] arange indices with unavailable indices set to 0 + last_inds = inds.max(axis=-1) + last_inds = np.tile(last_inds[:, np.newaxis, np.newaxis],(1,pred.shape[1],1)) + ground_truth = np.expand_dims(ground_truth, 1) # add modes + avails = avails[:, np.newaxis, :, np.newaxis] # add modes and cords + + + error = np.sum( + ((ground_truth - pred) * avails) ** 2, axis=-1 + ) # reduce coords and use availability + error = error ** 0.5 # calculate root of error (= L2 norm) + + # error = error[:, :, -1] # use last timestep + error = np.take_along_axis(error,last_inds,axis=2).squeeze(-1) + if mode == "oracle": + error = np.min(error, axis=-1) # use best hypothesis + elif mode == "mean": + error = np.mean(error, axis=-1) # average over hypotheses + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def batch_average_diversity( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "max", +) -> np.ndarray: + """ + Compute the distance among trajectory samples averaged across time steps + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode: option are "mean" (average distance) and "max" (distance between + the two most distinctive samples). 
+ Returns: + np.ndarray: average displacement error (ADE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + # compute pairwise distances + error = np.linalg.norm( + pred[:, np.newaxis, :] - pred[:, :, np.newaxis], axis=-1 + ) # [B, M, M, T] + error = np.mean(error, axis=-1) # average over timesteps + error = error.reshape([error.shape[0], -1]) # [B, M * M] + if mode == "max": + error = np.max(error, axis=-1) + elif mode == "mean": + error = np.mean(error, axis=-1) + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def batch_final_diversity( + ground_truth: np.ndarray, + pred: np.ndarray, + confidences: np.ndarray, + avails: np.ndarray, + mode: str = "max", +) -> np.ndarray: + """ + Compute the distance among trajectory samples at the last timestep + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(modes)x(time)x(2D coords) + confidences (np.ndarray): array of shape (batch)x(modes) with a confidence for each mode in each sample + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): calculation mode: option are "mean" (average distance) and "max" (distance between + the two most distinctive samples). + Returns: + np.ndarray: average displacement error (ADE) of the batch, an array of float numbers + """ + _assert_shapes(ground_truth, pred, confidences, avails) + # compute pairwise distances at the last time step + pred = pred[..., -1] + error = np.linalg.norm( + pred[:, np.newaxis, :] - pred[:, :, np.newaxis], axis=-1 + ) # [B, M, M] + error = error.reshape([error.shape[0], -1]) # [B, M * M] + if mode == "max": + error = np.max(error, axis=-1) + elif mode == "mean": + error = np.mean(error, axis=-1) + else: + raise ValueError(f"mode: {mode} not valid") + + return error + + +def single_mode_metrics( + metrics_func, ground_truth: np.ndarray, pred: np.ndarray, avails: np.ndarray +): + """ + Run a metrics with single mode by inserting a mode dimension + + Args: + ground_truth (np.ndarray): array of shape (batch)x(time)x(2D coords) + pred (np.ndarray): array of shape (batch)x(time)x(2D coords) + avails (np.ndarray): array of shape (batch)x(time) with the availability for each gt timestep + mode (str): Optional, set to None when not applicable + calculation mode - options are 'mean' (average over hypotheses) and 'oracle' (use best hypotheses) + Returns: + np.ndarray: metrics values + """ + pred = pred[:, None] + conf = np.ones((pred.shape[0], 1)) + kwargs = dict(ground_truth=ground_truth, pred=pred, confidences=conf, avails=avails) + return metrics_func(**kwargs) + + +def batch_pairwise_collision_rate(agent_edges, collision_funcs=None): + """ + Count number of collisions among edge pairs in a batch + Args: + agent_edges (dict): A dict that maps collision types to box locations + collision_funcs (dict): A dict of collision functions (implemented in diffstack.utils.geometric_utils) + + Returns: + collision loss (torch.Tensor) + """ + if collision_funcs is None: + collision_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + + coll_rates = {} + for et, fun in collision_funcs.items(): + edges = agent_edges[et] + dis = fun( + edges[..., 0:3], + edges[..., 3:6], + edges[..., 6:8], + edges[..., 8:], + ) + dis = dis.min(-1)[0] # reduction over time + if isinstance(dis, np.ndarray): + coll_rates[et] = np.sum(dis <= 0) / 
float(dis.shape[0]) + else: + coll_rates[et] = torch.sum(dis <= 0) / float(dis.shape[0]) + return coll_rates + + +def batch_pairwise_collision_rate_masked(agent_edges, type_mask,collision_funcs=None): + """ + Count number of collisions among edge pairs in a batch + Args: + agent_edges (dict): A dict that maps collision types to box locations + collision_funcs (dict): A dict of collision functions (implemented in diffstack.utils.geometric_utils) + + Returns: + collision loss (torch.Tensor) + """ + if collision_funcs is None: + collision_funcs = { + "VV": VEH_VEH_collision, + "VP": VEH_PED_collision, + "PV": PED_VEH_collision, + "PP": PED_PED_collision, + } + coll_rates = {} + for et, fun in collision_funcs.items(): + if et in type_mask and type_mask[et].sum()>0: + dis = fun( + agent_edges[..., 0:3], + agent_edges[..., 3:6], + agent_edges[..., 6:8], + agent_edges[..., 8:], + ) + dis = dis.min(-1)[0] # reduction over time + if isinstance(dis, np.ndarray): + coll_rates[et] = np.sum((dis <= 0)*type_mask[et]) / type_mask[et].sum() + else: + coll_rates[et] = torch.sum((dis <= 0)*type_mask[et]) / type_mask[et].sum() + return coll_rates + + +def batch_detect_off_road(positions, drivable_region_map): + """ + Compute whether the given positions are out of drivable region + Args: + positions (torch.Tensor): a position (x, y) in rasterized frame [B, ..., 2] + drivable_region_map (torch.Tensor): binary drivable region maps [B, H, W] + + Returns: + off_road (torch.Tensor): whether each given position is off-road [B, ...] + """ + assert positions.shape[0] == drivable_region_map.shape[0] + assert drivable_region_map.ndim == 3 + b, h, w = drivable_region_map.shape + positions_flat = positions[..., 1].long() * w + positions[..., 0].long() + if positions_flat.ndim == 1: + positions_flat = positions_flat[:, None] + drivable = torch.gather( + drivable_region_map.flatten(start_dim=1), # [B, H * W] + dim=1, + index=positions_flat.long().flatten(start_dim=1), # [B, (all trailing dim flattened)] + ).reshape(*positions.shape[:-1]) + return 1 - drivable.float() + + +def batch_detect_off_road_boxes(positions, yaws, extents, drivable_region_map): + """ + Compute whether boxes specified by (@positions, @yaws, and @extents) are out of drivable region. + A box is considered off-road if at least one of its corners are out of drivable region + Args: + positions (torch.Tensor): agent centroid (x, y) in rasterized frame [B, ..., 2] + yaws (torch.Tensor): agent yaws in rasterized frame [B, ..., 1] + extents (torch.Tensor): agent extents [B, ..., 2] + drivable_region_map (torch.Tensor): binary drivable region maps [B, H, W] + + Returns: + box_off_road (torch.Tensor): whether each given box is off-road [B, ...] + """ + boxes = get_box_world_coords(positions, yaws, extents) # [B, ..., 4, 2] + off_road = batch_detect_off_road(boxes, drivable_region_map) # [B, ..., 4] + box_off_road = off_road.sum(dim=-1) > 0.5 + return box_off_road.float() + + +def GMM_loglikelihood(x, m, v, pi, avails=None, mode="mean"): + """ + Log probability of tensor x under a uniform mixture of Gaussians. + Adapted from CS 236 at Stanford. 
+ Args: + x (torch.Tensor): tensor with shape (B, D) + m (torch.Tensor): means tensor with shape (B, M, D) or (1, M, D), where + M is number of mixture components + v (torch.Tensor): variances tensor with shape (B, M, D) or (1, M, D) where + M is number of mixture components + logpi (torch.Tensor): log probability of the modes (B,M) + detach (bool): option whether to detach all modes but the best one + mode (string): mode of loss, sum or max + + Returns: + -log_prob (torch.Tensor): log probabilities of shape (B,) + """ + + if v is None: + v = torch.ones_like(m) + + # (B , D) -> (B , 1, D) + x = x.unsqueeze(1) + # (B, 1, D) -> (B, M, D) -> (B, M) + if avails is not None: + avails = avails.unsqueeze(1) + log_prob = log_normal(x, m, v, avails=avails) + if mode == "sum": + loglikelihood = (pi*log_prob).sum(1) + elif mode == "mean": + loglikelihood = (pi*log_prob).mean(1) + elif mode == "max": + loglikelihood = (pi*log_prob).max(1) + return loglikelihood + + +class DistanceBuffer(object): + """ class that stores the distance given x,y location + """ + def __init__(self): + self._buffer = dict() + + def __getitem__(self,key): + if key in self._buffer: + return self._buffer[key] + else: + return self.update(key) + + def update(self,key): + dis = np.linalg.norm(key) + self._buffer[key] = dis + self._buffer[-key] = dis + return dis + + +class RandomPerturbation(object): + """ + Add Gaussian noise to the trajectory + """ + def __init__(self, std: np.ndarray): + assert std.shape == (3,) and np.all(std >= 0) + self.std = std + + def perturb(self, obs): + """Given the observation object, add Gaussian noise to positions and yaws + + Args: + obs(Dict[torch.tensor]): observation dict + + Returns: + obs(Dict[torch.tensor]): perturbed observation + """ + obs = dict(obs) + target_traj = np.concatenate((obs["fut_pos"], obs["fut_yaw"]), axis=-1) + std = np.ones_like(target_traj) * self.std[None, :] + noise = np.random.normal(np.zeros_like(target_traj), std) + target_traj += noise + obs["fut_pos"] = target_traj[..., :2] + obs["fut_yaw"] = target_traj[..., :1] + return obs + + +class OrnsteinUhlenbeckPerturbation(object): + """ + Add Ornstein-Uhlenbeck noise to the trajectory + """ + def __init__(self,theta,sigma): + """ + Args: + theta (np.ndarray): converging factor of the OU process + sigma (np.ndarray): magnitude of the Gaussian noise added at each step + """ + assert theta.shape == (3,) and sigma.shape == (3,) + self.theta = theta + self.sigma = sigma + self.buffers = dict() + + def perturb(self,obs): + """Given the observation object, add Gaussian noise to positions and yaws + + Args: + obs(Dict[torch.tensor]): observation dict + + Returns: + obs(Dict[torch.tensor]): perturbed observation + """ + if isinstance(obs["fut_pos"],np.ndarray): + target_traj = np.concatenate((obs["fut_pos"], obs["fut_yaw"]), axis=-1) + bs = target_traj.shape[0] + T = target_traj.shape[-2] + if bs in self.buffers: + buffer = self.buffers[bs] + else: + buffer = [np.zeros([bs,3])] + self.buffers[bs]=buffer + while len(buffer) Dict[str, torch.Tensor]: + """ + Returns a dict of ade and fde metrics out of top-k most likely predictions. 
+ Args: + trajs: tensor, (b, n, s, k, t, state) + log_probs: tensor, (b, n, {s,1}, k, {t,1}) + gt_trajs: tensor, (b, n, t, state) + k: int or list of ints for top k predictions + Returns: + dict with keys ['ade_top_k', 'fde_top_k'] and values of tensor, (b, n, state) + """ + assert trajs.ndim == 6 and log_probs.ndim == 5 and gt_trajs.ndim == 4 + sorted_idxs = torch.argsort(log_probs[..., [0], :], dim=-1, descending=True) + sorted_means = torch.gather(trajs, dim=-3, index=sorted_idxs.unsqueeze(-1).expand( + (-1, trajs.shape[1], -1, -1, trajs.shape[4], trajs.shape[5]))) + + errors = torch.linalg.norm(sorted_means - gt_trajs[:, :, None, None], dim=-1) # b, n, s, k, t + + if isinstance(k, int): + k = [k] + res = {} + for ki in k: + top_k_errors = errors[..., :ki, :] + + ade_k = torch.mean(top_k_errors, dim=-1) # b, n, s, k + min_ade_k = torch.min(ade_k, dim=-1).values # b, n, s + + fde_k = top_k_errors[..., -1] # b, n, s, k + min_fde_k = torch.min(fde_k, dim=-1).values # b, n, s + + res[f"ade_top_{ki}"] = min_ade_k + res[f"fde_top_{ki}"] = min_fde_k + + return res + + +## The next set of functions are replicates from the trajectron repo. + +def compute_ade_pt(predicted_trajs, gt_traj): + error = torch.linalg.norm(predicted_trajs - gt_traj, dim=-1) + ade = torch.mean(error, axis=-1) + return ade.flatten() + + +def compute_fde_pt(predicted_trajs, gt_traj): + final_error = torch.linalg.norm(predicted_trajs[..., -1, :] - gt_traj[..., -1, :], dim=-1) + return final_error.flatten() + + +def compute_nll_pt(predicted_dist, gt_traj): + log_p_yt_xz = torch.clamp(predicted_dist.log_prob(gt_traj), min=-20.) + log_p_y_xz_final = log_p_yt_xz[..., -1] + log_p_y_xz = log_p_yt_xz.mean(dim=-1) + return -log_p_y_xz.sum(dim=0), -log_p_y_xz_final.sum(dim=0) + + +def compute_min_afde_k_pt(predicted_dist, gt_traj, k): + means = predicted_dist.mus[..., :gt_traj.shape[-1]] + probs = predicted_dist.pis_cat_dist.probs + sorted_idxs = torch.argsort(probs[..., [0], :], dim=-1, descending=True) + sorted_means = torch.gather(means, dim=-2, index=sorted_idxs.unsqueeze(-1).expand((-1, -1, means.shape[2], -1, means.shape[4]))) + + errors = torch.linalg.norm(sorted_means - gt_traj.unsqueeze(-2), dim=-1) + top_k_errors = errors[..., :k] + + ade_k = torch.mean(top_k_errors, dim=-2) + min_ade_k = torch.min(ade_k, dim=-1).values + + fde_k = top_k_errors[..., -1, :] + min_fde_k = torch.min(fde_k, dim=-1).values + + return min_ade_k.flatten(), min_fde_k.flatten() + + +def compute_nll(predicted_dist, gt_traj): + log_p_yt_xz = torch.clamp(predicted_dist.log_prob(torch.as_tensor(gt_traj)), min=-20.) 
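+ # floor the per-timestep log-densities at -20 so that near-zero predicted densities do not produce -inf NLL; the lines below then report the mean over time and the final-timestep value separately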
+ log_p_y_xz_final = log_p_yt_xz[..., -1] + log_p_y_xz = log_p_yt_xz.mean(dim=-1) + return -log_p_y_xz[0].numpy(), -log_p_y_xz_final[0].numpy() + + +def compute_prediction_metrics(futures, + prediction_output_dict=None, + y_dists=None): + eval_ret = dict() + if prediction_output_dict is not None: + ade_errors = compute_ade_pt(prediction_output_dict, futures) + fde_errors = compute_fde_pt(prediction_output_dict, futures) + + eval_ret['ml_ade'] = ade_errors + eval_ret['ml_fde'] = fde_errors + + if y_dists is not None: + nll_means, nll_finals = compute_nll_pt(y_dists, futures) + min_ade_5, min_fde_5 = compute_min_afde_k_pt(y_dists, futures, k=5) + min_ade_10, min_fde_10 = compute_min_afde_k_pt(y_dists, futures, k=10) + + eval_ret['min_ade_5'] = min_ade_5 + eval_ret['min_fde_5'] = min_fde_5 + eval_ret['min_ade_10'] = min_ade_10 + eval_ret['min_fde_10'] = min_fde_10 + eval_ret['nll_mean'] = nll_means + eval_ret['nll_final'] = nll_finals + + return eval_ret + + +def scene_centric_prediction_metrics(scene_batch: SceneBatch, pred_outputs: Dict) -> Dict[str, torch.Tensor]: + + futures = scene_batch.agent_fut[:, :, :, :2].transpose(0, 1) # n, b, t, 2 + if "pred_dist_with_ego" not in pred_outputs or pred_outputs["pred_dist_with_ego"] is None: + return {} + pred_dist = pred_outputs["pred_dist_with_ego"] + + # Metrics for all agents + metrics_dict = compute_prediction_metrics(futures, prediction_output_dict=None, y_dists=pred_dist) + # Average over agents + metrics_dict = {k: v.mean(dim=0) for k, v in metrics_dict.items()} + + # Ego metrics + ego_pred_dist = GMM2D(pred_dist.log_pis[:1], pred_dist.mus[:1], pred_dist.log_sigmas[:1], pred_dist.corrs[:1]) + ego_metrics_dict = compute_prediction_metrics(futures[0], prediction_output_dict=None, y_dists=ego_pred_dist) + metrics_dict.update({"ego_"+k: v for k, v in ego_metrics_dict.items()}) + + # Move to cpu + metrics_dict = {k: v.detach().cpu() for k, v in metrics_dict.items()} + + return metrics_dict diff --git a/diffstack/utils/model_registrar.py b/diffstack/utils/model_registrar.py index 52a110f..1e5855d 100644 --- a/diffstack/utils/model_registrar.py +++ b/diffstack/utils/model_registrar.py @@ -1,6 +1,7 @@ import os import torch import torch.nn as nn +from pathlib import Path def get_model_device(model): @@ -65,6 +66,7 @@ def load_models(self, iter_num): self.load_model_from_file(save_path) def load_model_from_file(self, file_path, except_contains=()): + file_path = str(Path(file_path).expanduser()) print('\nLoading from ' + file_path) # Import error can happen here is trying to load checkpoint with old planner_cost object. @@ -72,7 +74,6 @@ def load_model_from_file(self, file_path, except_contains=()): # Alternatively, the old pred_metrics folder needs to be copied with environment/nuScenes_data/cost_functions.py # Same can happend with `model` which reqires the original trajectron++ code to be in the path. 
# sys.path.append('./trajectron/trajectron') - file_path = os.path.expanduser(file_path) new_model_dict = torch.load(file_path, map_location=self.device) # Selectively update parameters diff --git a/diffstack/utils/model_utils.py b/diffstack/utils/model_utils.py new file mode 100644 index 0000000..32f11e9 --- /dev/null +++ b/diffstack/utils/model_utils.py @@ -0,0 +1,772 @@ +import torch +import torch.nn.functional as F +import numpy as np +from diffstack.models.base_models import MLP +import torch.nn as nn +from diffstack.utils.geometry_utils import round_2pi, batch_proj_xysc, rel_xysc +import community as community_louvain +import networkx as nx +from scipy.sparse.csgraph import connected_components +from scipy.sparse import csr_matrix +import torch.distributions as td +import diffstack.utils.tensor_utils as TensorUtils +import math +from diffstack.models.TypeTransformer import CrossAttention + +class PED_PED_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(PED_PED_encode, self).__init__() + self.FC = MLP(10, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + deltax = x2[..., 0:2] - x1[..., 0:2] + input = torch.cat((deltax, x1[..., 2:4], x2[..., 2:4], size1, size2), dim=-1) + return self.FC(input) + + +class PED_VEH_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(PED_VEH_encode, self).__init__() + self.FC = MLP(10, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + deltax = x2[..., 0:2] - x1[..., 0:2] + veh_vel = torch.cat( + ( + torch.unsqueeze(x2[..., 2] * torch.cos(x2[..., 3]), dim=-1), + torch.unsqueeze(x2[..., 2] * torch.sin(x2[..., 3]), dim=-1), + ), + dim=-1, + ) + input = torch.cat((deltax, x1[..., 2:4], veh_vel, size1, size2), dim=-1) + return self.FC(input) + + +class VEH_PED_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(VEH_PED_encode, self).__init__() + self.FC = MLP(9, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + dx0 = x2[..., 0:2] - x1[..., 0:2] + theta = x1[..., 3] + dx = torch.cat( + ( + torch.unsqueeze( + dx0[..., 0] * torch.cos(theta) + torch.sin(theta) * dx0[..., 1], + dim=-1, + ), + torch.unsqueeze( + dx0[..., 1] * torch.cos(theta) - torch.sin(theta) * dx0[..., 0], + dim=-1, + ), + ), + dim=-1, + ) + dv = torch.cat( + ( + torch.unsqueeze( + x2[..., 2] * torch.cos(theta) + + torch.sin(theta) * x2[..., 3] + - x1[..., 2], + dim=-1, + ), + torch.unsqueeze( + x2[..., 3] * torch.cos(theta) - torch.sin(theta) * x2[..., 2], + dim=-1, + ), + ), + dim=-1, + ) + input = torch.cat( + (dx, torch.unsqueeze(x1[..., 2], dim=-1), dv, size1, size2), dim=-1 + ) + return self.FC(input) + + +class VEH_VEH_encode(nn.Module): + def __init__(self, obs_enc_dim, hidden_dim=[64]): + super(VEH_VEH_encode, self).__init__() + self.FC = MLP(11, obs_enc_dim, hidden_dim) + + def forward(self, x1, x2, size1, size2): + dx0 = x2[..., 0:2] - x1[..., 0:2] + theta = x1[..., 3] + dx = torch.cat( + ( + torch.unsqueeze( + dx0[..., 0] * torch.cos(theta) + torch.sin(theta) * dx0[..., 1], + dim=-1, + ), + torch.unsqueeze( + dx0[..., 1] * torch.cos(theta) - torch.sin(theta) * dx0[..., 0], + dim=-1, + ), + ), + dim=-1, + ) + dtheta = x2[..., 3] - x1[..., 3] + dv = torch.cat( + ( + torch.unsqueeze(x2[..., 2] * torch.cos(dtheta) - x1[..., 2], dim=-1), + torch.unsqueeze(torch.sin(dtheta) * x2[..., 2], dim=-1), + ), + dim=-1, + ) + input = torch.cat( + ( + dx, + torch.unsqueeze(x1[..., 2], dim=-1), + dv, + torch.unsqueeze(torch.cos(dtheta), dim=-1), + 
torch.unsqueeze(torch.sin(dtheta), dim=-1), + size1, + size2, + ), + dim=-1, + ) + return self.FC(input) + + +def PED_rel_state(x, x0): + rel_x = torch.clone(x) + rel_x[..., 0:2] -= x0[..., 0:2] + return rel_x + + +def VEH_rel_state(x, x0): + rel_XY = x[..., 0:2] - x0[..., 0:2] + theta = x0[..., 3] + rel_x = torch.stack( + [ + rel_XY[..., 0] * torch.cos(theta) + rel_XY[..., 1] * torch.sin(theta), + rel_XY[..., 1] * torch.cos(theta) - rel_XY[..., 0] * torch.sin(theta), + x[..., 2], + x[..., 3] - x0[..., 3], + ], + dim=-1, + ) + rel_x[..., 3] = round_2pi(rel_x[..., 3]) + return rel_x + + +class PED_pre_encode(nn.Module): + def __init__(self, enc_dim, hidden_dim=[64], use_lane_info=False): + super(PED_pre_encode, self).__init__() + self.FC = MLP(4, enc_dim, hidden_dim) + + def forward(self, x): + return self.FC(x) + + +class VEH_pre_encode(nn.Module): + def __init__(self, enc_dim, hidden_dim=[64], use_lane_info=False): + super(VEH_pre_encode, self).__init__() + self.use_lane_info = use_lane_info + if use_lane_info: + self.FC = MLP(8, enc_dim, hidden_dim) + else: + self.FC = MLP(5, enc_dim, hidden_dim) + + def forward(self, x): + if self.use_lane_info: + input = torch.cat( + ( + x[..., 0:3], + torch.cos(x[..., 3:4]), + torch.sin(x[..., 3:4]), + x[..., 4:5], + torch.cos(x[..., 5:6]), + torch.sin(x[..., 5:6]), + ), + dim=-1, + ) + else: + input = torch.cat( + (x[..., 0:3], torch.cos(x[..., 3:]), torch.sin(x[..., 3:])), dim=-1 + ) + return self.FC(input) + +def break_graph(M, resol=1.0): + if isinstance(M, np.ndarray): + resol = resol * np.max(M) + G = nx.Graph() + for i in range(M.shape[0]): + G.add_node(i) + for i in range(M.shape[0]): + for j in range(i + 1, M.shape[0]): + if M[i, j] > 0: + G.add_edge(i, j, weight=M[i, j]) + partition = community_louvain.best_partition(G, resolution=resol) + elif isinstance(M, nx.classes.graph.Graph): + G = M + partition = community_louvain.best_partition(G, resolution=resol) + + while max(partition.values()) == 0 and resol >= 0.1: + resol = resol * 0.9 + partition = community_louvain.best_partition(G, resolution=resol) + return partition + + +def break_graph_recur(M, max_num): + n_components, labels = connected_components( + csgraph=csr_matrix(M), directed=False, return_labels=True + ) + idx = 0 + + while idx < n_components: + subset = np.where(labels == idx)[0] + if subset.shape[0] <= max_num: + idx += 1 + else: + partition = break_graph(M[np.ix_(subset, subset)]) + added_partition = 0 + for i in range(subset.shape[0]): + if partition[i] > 0: + labels[subset[i]] = n_components + partition[i] - 1 + added_partition = max(added_partition, partition[i]) + + n_components += added_partition + if added_partition == 0: + idx += 1 + + return n_components, labels + +def unpack_RNN_state(state_tuple): + # PyTorch returned LSTM states have 3 dims: + # (num_layers * num_directions, batch, hidden_size) + + state = torch.cat(state_tuple, dim=0).permute(1, 0, 2) + # Now state is (batch, 2 * num_layers * num_directions, hidden_size) + + state_size = state.size() + return torch.reshape(state, (-1, state_size[1] * state_size[2])) + +class Normal: + + def __init__(self, mu=None, logvar=None, params=None): + super().__init__() + if params is not None: + self.mu, self.logvar = torch.chunk(params, chunks=2, dim=-1) + else: + assert mu is not None + assert logvar is not None + self.mu = mu + self.logvar = logvar + self.sigma = torch.exp(0.5 * self.logvar) + + def rsample(self): + eps = torch.randn_like(self.sigma) + return self.mu + eps * self.sigma + + def sample(self): + return 
self.rsample() + + def kl(self, p=None): + """ compute KL(q||p) """ + if p is None: + kl = -0.5 * (1 + self.logvar - self.mu.pow(2) - self.logvar.exp()) + else: + term1 = (self.mu - p.mu) / (p.sigma + 1e-8) + term2 = self.sigma / (p.sigma + 1e-8) + kl = 0.5 * (term1 * term1 + term2 * term2) - 0.5 - torch.log(term2) + return kl + + def mode(self): + return self.mu + def pseudo_sample(self,n_sample): + sigma_points = torch.stack((self.mu,self.mu-self.sigma,self.mu+self.sigma),1) + if n_sample<=1: + return sigma_points[:,:n_sample] + else: + remain_n = n_sample-3 + sigma_tiled = self.sigma.unsqueeze(1).repeat_interleave(remain_n,1) + mu_tiled = self.mu.unsqueeze(1).repeat_interleave(remain_n,1) + sample = torch.randn_like(sigma_tiled)*sigma_tiled+mu_tiled + return torch.cat([sigma_points,sample],1) + +class Categorical: + + def __init__(self, probs=None, logits=None, temp=0.01): + super().__init__() + self.logits = logits + self.temp = temp + if probs is not None: + self.probs = probs + else: + assert logits is not None + self.probs = torch.softmax(logits, dim=-1) + self.dist = td.OneHotCategorical(self.probs) + + def rsample(self,n_sample=1): + relatex_dist = td.RelaxedOneHotCategorical(self.temp, self.probs) + return relatex_dist.rsample((n_sample,)).transpose(0,-2) + + def sample(self): + return self.dist.sample() + + def pseudo_sample(self,n_sample): + D = self.probs.shape[-1] + idx = self.probs.argsort(-1,descending=True) + assert n_sample<=D + return TensorUtils.to_one_hot(idx[...,:n_sample],num_class=D) + + + + def kl(self, p=None): + """ compute KL(q||p) """ + if p is None: + p = Categorical(logits=torch.zeros_like(self.probs)) + kl = td.kl_divergence(self.dist, p.dist) + return kl + + def mode(self): + argmax = self.probs.argmax(dim=-1) + one_hot = torch.zeros_like(self.probs) + one_hot.scatter_(1, argmax.unsqueeze(1), 1) + return one_hot + + + +def initialize_weights(modules): + for m in modules: + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + if m.bias is not None: nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + if m.bias is not None: nn.init.constant_(m.bias, 0) + +class AFMLP(nn.Module): + + def __init__(self, input_dim, hidden_dims=(128, 128), activation='tanh'): + super().__init__() + if activation == 'tanh': + self.activation = torch.tanh + elif activation == 'relu': + self.activation = torch.relu + elif activation == 'sigmoid': + self.activation = torch.sigmoid + + self.out_dim = hidden_dims[-1] + self.affine_layers = nn.ModuleList() + last_dim = input_dim + for nh in hidden_dims: + self.affine_layers.append(nn.Linear(last_dim, nh)) + last_dim = nh + + initialize_weights(self.affine_layers.modules()) + + def forward(self, x): + for affine in self.affine_layers: + x = self.activation(affine(x)) + return x + +def rotation_2d_torch(x, theta, origin=None): + if origin is None: + origin = torch.zeros(2).to(x.device).to(x.dtype) + norm_x = x - origin + norm_rot_x = torch.zeros_like(x) + norm_rot_x[..., 0] = norm_x[..., 0] * torch.cos(theta) - norm_x[..., 1] * torch.sin(theta) + norm_rot_x[..., 1] = norm_x[..., 0] * torch.sin(theta) + norm_x[..., 1] * torch.cos(theta) + rot_x = norm_rot_x + origin + return rot_x, norm_rot_x + + +class ExpParamAnnealer(nn.Module): + + def __init__(self, start, finish, rate, cur_epoch=0): + super().__init__() + 
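# keep the annealing schedule state in registered buffers so it is saved/restored with the module's state_dict and moved with .to(device), without becoming a trainable parameter +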
self.register_buffer('start', torch.tensor(start)) + self.register_buffer('finish', torch.tensor(finish)) + self.register_buffer('rate', torch.tensor(rate)) + self.register_buffer('cur_epoch', torch.tensor(cur_epoch)) + + def step(self): + self.cur_epoch += 1 + + def set_epoch(self, epoch): + self.cur_epoch.fill_(epoch) + + def val(self): + return self.finish - (self.finish - self.start) * (self.rate ** self.cur_epoch) + +class IntegerParamAnnealer(nn.Module): + + def __init__(self, start, finish, length, cur_epoch=0): + super().__init__() + self.register_buffer('start', torch.tensor(start)) + self.register_buffer('finish', torch.tensor(finish)) + self.register_buffer('length', torch.tensor(length)) + self.register_buffer('cur_epoch', torch.tensor(cur_epoch)) + + def step(self): + self.cur_epoch += 1 + + def set_epoch(self, epoch): + self.cur_epoch.fill_(epoch) + + def val(self): + return self.finish if self.cur_epoch>=self.length else self.start+int((self.finish-self.start)*self.cur_epoch/self.length) + + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(8, channels) + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
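+ # view_as(x) produces non-leaf views that share storage with the detached inputs, so in-place ops inside run_function are applied to the views rather than to leaf tensors that require grad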
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + +def Gaussian_RBF_conv(sigma,radius,device=None): + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + X = torch.range(-radius,radius) + Y = torch.range(-radius,radius) + XY = torch.meshgrid(X,Y) + dis_sq = XY[0]**2+XY[1]**2 + RBF = torch.exp(-dis_sq/(2*sigma**2)).to(device) + RBF = torch.nn.Parameter(RBF[None,None]/RBF.sum(),requires_grad=False) + net = nn.Conv2d(1,1,2*radius+1,bias=False,padding=radius,device=device) + net.weight = RBF + return net + +def agent2agent_edge(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # assumming x1 and x2 are state of the form [x,y,v,s,c,w,l,type_one_hot] + N1 = x1.shape[1] + N2 = x2.shape[1] + x1_flag = torch.logical_not((x1==0).all(-1)).type(x1.dtype) + x2_flag = torch.logical_not((x2==0).all(-1)).type(x2.dtype) + x1 = x1.unsqueeze(2).repeat_interleave(N2,2) + x2 = x2.unsqueeze(1).repeat_interleave(N1,1) + x1_xysc = x1[...,[0,1,3,4]] + x2_xysc = x2[...,[0,1,3,4]] + dx_xysc = rel_xysc(x1_xysc,x2_xysc)*(x1_flag.unsqueeze(2)*x2_flag.unsqueeze(1)).unsqueeze(-1) + dx_xy = (dx_xysc[...,:2]/scale).clip(min=clip[0],max=clip[1]) + + dx_xysc = torch.cat([dx_xy,dx_xysc[...,2:]],-1) + + v1 = x1[...,2:3] + v2 = x2[...,2:3] + edge = torch.cat([dx_xysc,v2*dx_xysc[...,2:3],v2*dx_xysc[...,3:4],x2[...,5:]],-1) + if padto is not None: + edge = torch.cat((edge,torch.zeros(*edge.shape[:-1],padto-edge.shape[-1],device=edge.device)),-1) + return edge + +def agent2lane_edge_proj(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # x1: [B,N,d], x2:[B,M,L*4] + x2 = x2.reshape(*x2.shape[:-1],-1,4) + N = x1.shape[1] + M = x2.shape[1] + x1_flag = torch.logical_not((x1==0).all(-1)).type(x1.dtype) + x2_flag = torch.logical_not((x2==0).all(-1)).type(x2.dtype) + dx = batch_proj_xysc(x1[:,:,None,[0,1,3,4]].repeat_interleave(M,2),x2[:,None].repeat_interleave(N,1)) + dx_xy = (dx[...,:2]/scale).clip(min=clip[0],max=clip[1]) + dx = torch.cat([dx_xy,dx[...,2:]],-1) + dx = (dx * x1_flag[:,:,None,None,None]*x2_flag[:,None,:,:,None]) + + min_idx = dx[...,0].abs().argmin(3) + min_pts = torch.gather(dx,3,min_idx[...,None,None].repeat(1,1,1,1,4)).squeeze(-2) + edge = torch.cat([min_pts,dx[:,:,:,0],dx[:,:,:,-1]],-1) + if padto is not None: + edge = torch.cat((edge,torch.zeros(*edge.shape[:-1],padto-edge.shape[-1],device=edge.device)),-1) + return edge + +def agent2lane_edge_per_pts(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # x1: [B,N,4], x2:[B,M,4] + B,N = x1.shape[:2] + M = x2.shape[1] + rel_coord = rel_xysc(x1.repeat_interleave(M,1),x2.repeat_interleave(N,0).reshape(B,N*M,4)).reshape(B,N,M,4) + + x1_flag = torch.logical_not((x1==0).all(-1)).type(x1.dtype) + x2_flag = torch.logical_not((x2==0).all(-1)).type(x2.dtype) + rel_xy = (rel_coord[...,:2]/scale).clip(min=clip[0],max=clip[1]) + rel_coord = torch.cat([rel_xy,rel_coord[...,2:]],-1) + rel_coord = rel_coord * x1_flag[:,:,None,None]*x2_flag[:,None,:,None] + + return rel_coord + +def lane2lane_edge(x1,x2,padto=None,scale=1,clip=[-np.inf,np.inf]): + # x1: [B,M,L*4], x2:[B,M,L*4] + x1 = x1.reshape(*x1.shape[:-1],-1,4).unsqueeze(2).repeat_interleave(x2.shape[1],2) + x2 = 
x2.reshape(*x2.shape[:-1],-1,4).unsqueeze(1).repeat_interleave(x1.shape[1],1) + + x1s = x1[...,0,:] + x1e = x1[...,-1,:] + x2s = x2[...,0,:] + x2e = x2[...,-1,:] + dx1 = rel_xysc(x1s,x2s) + dx2 = rel_xysc(x1s,x2e) + dx3 = rel_xysc(x1e,x2s) + dx4 = rel_xysc(x1e,x2e) + dx1 = torch.cat([(dx1[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx1[...,2:]],-1) + dx2 = torch.cat([(dx2[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx2[...,2:]],-1) + dx3 = torch.cat([(dx3[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx3[...,2:]],-1) + dx4 = torch.cat([(dx4[...,:2]/scale).clip(min=clip[0],max=clip[1]),dx4[...,2:]],-1) + edge = torch.cat([dx1,dx2,dx3,dx4],-1) + if padto is not None: + edge = torch.cat((edge,torch.zeros(*edge.shape[:-1],padto-edge.shape[-1],device=edge.device)),-1) + return edge + +def edge_as_aux1(x1,x2): + # when the edge encoding is passed in via x1 + # x1: [B,N1,N2*d], x2:[B,N2,_] + B,N1 = x1.shape[:2] + N2 = x2.shape[1] + return x1.reshape(B,N1,N2,-1) + +def edge_as_aux2(x1,x2): + # when the edge encoding is passed in via x1 + # x1: [B,N1,_], x2:[B,N2,N1*d] + B,N1 = x1.shape[:2] + N2 = x2.shape[1] + return x2.reshape(B,N2,N1,-1).transpose(1,2) + +class Agent_emb(nn.Module): + def __init__(self, raw_dim,n_embd): + super(Agent_emb, self).__init__() + self.raw_dim = raw_dim + self.n_embd = n_embd + self.FC = nn.Linear(raw_dim, n_embd) + + def forward(self, input): + if input.size(-1) 0: + col_loss += torch.sum(torch.sigmoid(-dis*4) * scale*type_mask[et][:,None,:,None], dim=[2,3])/T + + return col_loss + + +def get_drivable_area_loss( + ego_trajectories, raster_from_agent, dis_map, ego_extents +): + """Cost for road departure.""" + with torch.no_grad(): + + lane_flags = rasterized_ROI_align( + dis_map, + ego_trajectories[..., :2], + ego_trajectories[..., 2:], + raster_from_agent, + torch.ones(*ego_trajectories.shape[:3] + ).to(ego_trajectories.device), + ego_extents.unsqueeze(1).repeat(1, ego_trajectories.shape[1], 1), + 1, + ).squeeze(-1) + return lane_flags.max(dim=-1)[0] + +def get_lane_loss_simple(ego_trajectories, raster_from_agent, dis_map): + h,w = dis_map.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points(ego_trajectories[...,:2],raster_from_agent) + raster_xy[...,0] = raster_xy[...,0].clip(0,w-1e-5) + raster_xy[...,1] = raster_xy[...,1].clip(0,h-1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = (raster_xy[...,1]*w+raster_xy[...,0]) + raster_xy_flat = raster_xy_flat.flatten() + lane_loss = (dis_map.flatten()[raster_xy_flat]).reshape(*raster_xy.shape[:2]) + return lane_loss.max(dim=-1)[0] + +def get_terminal_likelihood_reward( + ego_trajectories, raster_from_agent, log_likelihood +): + """Cost for road departure.""" + + log_likelihood = (log_likelihood-log_likelihood.mean())/log_likelihood.std() + h,w = log_likelihood.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points(ego_trajectories[...,-1,:2],raster_from_agent) + raster_xy[...,0] = raster_xy[...,0].clip(0,w-1e-5) + raster_xy[...,1] = raster_xy[...,1].clip(0,h-1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = (raster_xy[...,1]*w+raster_xy[...,0]) + + ll_reward = log_likelihood.flatten()[raster_xy_flat] + return ll_reward + +def get_progress_reward(ego_trajectories,d_sat = 10): + dis = torch.linalg.norm(ego_trajectories[...,-1,:2]-ego_trajectories[...,0,:2],dim=-1) + return 2/np.pi*torch.atan(dis/d_sat) + + +def get_total_distance(ego_trajectories): + """Reward that incentivizes progress.""" + # Assume format [..., T, 3] + assert ego_trajectories.shape[-1] == 3 + diff = ego_trajectories[..., 1:, :] 
- ego_trajectories[..., :-1, :] + dist = torch.norm(diff[..., :2], dim=-1) + total_dist = torch.sum(dist, dim=-1) + return total_dist + + +def ego_sample_planning( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + raster_from_agent, + dis_map, + weights, + log_likelihood=None, + col_funcs=None, +): + """A basic cost function for prediction-and-planning""" + col_loss = get_collision_loss( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + col_funcs, + ) + lane_loss = get_drivable_area_loss( + ego_trajectories, raster_from_agent, dis_map, ego_extents + ) + progress = get_total_distance(ego_trajectories) + + log_likelihood = 0 if log_likelihood is None else log_likelihood + if log_likelihood.ndim==3: + log_likelihood = get_terminal_likelihood_reward(ego_trajectories, raster_from_agent, log_likelihood) + + total_score = ( + + weights["likelihood_weight"] * log_likelihood + + weights["progress_weight"] * progress + - weights["collision_weight"] * col_loss + - weights["lane_weight"] * lane_loss + ) + + + return torch.argmax(total_score, dim=1) + + +class TreeMotionPolicy(object): + """ A trajectory tree policy as the result of contingency planning + + """ + def __init__(self,stage,num_frames_per_stage,ego_root,scenario_root,cost_to_go,leaf_idx,curr_node): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_root = ego_root + self.scenario_root = scenario_root + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node= curr_node + + def identify_branch(self,ego_node,scene_traj): + + assert scene_traj.shape[-2]=scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + ego_leaf_index = self.leaf_idx[ego_node] + while remain_traj.shape[1]>0: + seg_length = min(remain_traj.shape[-2],self.num_frames_per_stage) + dis = [torch.linalg.norm(child.traj[ego_leaf_index,:,:seg_length,:2]-remain_traj[:,:seg_length,:2],dim=-1).sum().item()\ + for child in curr_scenario_node.children] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[...,seg_length:,:] + remain_num_frames = curr_scenario_node.traj.shape[-2]-seg_length + if curr_scenario_node.stage>=self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self,scene_traj,horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0]-1-T + assert remain_num_frames>-self.num_frames_per_stage + if remain_num_frames<=0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node,scene_traj) + Q = [self.cost_to_go[(child,curr_scenario_node)] for child in self.curr_node.children] + idx = torch.argmin(torch.tensor(Q)).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames+=self.curr_node.traj.shape[0] + traj = self.curr_node.traj[-remain_num_frames:,TRAJ_INDEX] + if not self.curr_node.isleaf(): + traj = torch.cat((traj,self.curr_node.children[0].traj[:,TRAJ_INDEX]),-2) + if traj.shape[0]>=horizon: + return traj[:horizon] + else: + traj_patched = torch.cat((traj,traj[-1].tile(horizon-traj.shape[0],1))) + return traj_patched + +class VectorizedTreeMotionPolicy(TreeMotionPolicy): + """ A vectorized trajectory tree policy as the result of contingency planning + + """ + def 
__init__(self,stage,num_frames_per_stage,ego_tree,children_indices,scenario_tree,cost_to_go,leaf_idx,curr_node): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_tree = ego_tree + self.ego_root = ego_tree[0][0] + self.children_indices = children_indices + self.scenario_tree = scenario_tree + self.scenario_root = scenario_tree[0][0] + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node= curr_node + + def identify_branch(self,ego_node,scene_traj): + + assert scene_traj.shape[-2]=scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + + stage = ego_node.depth + ego_stage_index = self.ego_tree[stage].index(ego_node) + ego_leaf_index = self.leaf_idx[stage][ego_stage_index].item() + while remain_traj.shape[1]>0: + seg_length = min(remain_traj.shape[-2],self.num_frames_per_stage) + dis = [torch.linalg.norm(child.traj[ego_leaf_index,:,:seg_length,:2]-remain_traj[:,:seg_length,:2],dim=-1).sum().item()\ + for child in curr_scenario_node.children] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[...,seg_length:,:] + remain_num_frames = curr_scenario_node.traj.shape[-2]-seg_length + if curr_scenario_node.stage>=self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self,scene_traj,horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0]-1-T + assert remain_num_frames>-self.num_frames_per_stage + if remain_num_frames<=0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node,scene_traj) + assert curr_scenario_node.depth==self.curr_node.depth + stage = self.curr_node.depth + scene_node_idx = self.scenario_tree[stage].index(curr_scenario_node) + curr_node_idx = self.ego_tree[stage].index(self.curr_node) + Q = self.cost_to_go[stage][scene_node_idx,self.children_indices[stage][curr_node_idx]] + idx = torch.argmin(Q).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames+=self.curr_node.traj.shape[0] + + traj = self.curr_node.traj[-remain_num_frames:,TRAJ_INDEX] + if not self.curr_node.isleaf(): + traj = torch.cat((traj,self.curr_node.children[0].traj[:,TRAJ_INDEX]),-2) + if traj.shape[0]>=horizon: + return traj[:horizon] + else: + traj_patched = torch.cat((traj,traj[-1].tile(horizon-traj.shape[0],1))) + return traj_patched + + +def tiled_to_tree(total_traj,prob,num_stage,num_frames_per_stage,M): + """Turning a trajectory tree in tiled form to a tree data structure + + Args: + total_traj (torch.tensor or np.ndarray): tiled trajectory tree + prob (torch.tensor or np.ndarray): probability of the modes + num_stage (int): number of layers of the tree + num_frames_per_stage (int): number of time frames per layer + M (int): branching factor + + Returns: + nodes (dict[int:List(AgentTrajTree)]): all branches of the trajectory tree nodes indexed by layer + """ + + # total_traj = TensorUtils.reshape_dimensions_single(total_traj,2,3,[M]*num_stage) + x0 = AgentTrajTree(None, None, 0) + nodes = defaultdict(lambda:list()) + nodes[0].append(x0) + for t in range(num_stage): + interval = M**(num_stage-t-1) + tiled_traj = total_traj[...,::interval,:,t*num_frames_per_stage:(t+1)*num_frames_per_stage,:] + for i in range(M**(t+1)): + parent_idx = int(i/M) + p = prob[:,i*interval:(i+1)*interval].sum(-1) + node = AgentTrajTree(tiled_traj[:,i], 
nodes[t][parent_idx], t + 1, prob=p) + nodes[t+1].append(node) + return nodes + + +def contingency_planning(ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std = None): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (Dict[int:List[TrajTree]]): ego trajectory tree + ego_extents (Tensor): [2] + agent_traj (Tensor): EC x M^s x Na x (s*Ts) x 3 scenario tree predicted + mode_prob (Tensor): EC x M^s + agent_extents (Tensor): Na x 2 + agent_types (Tensor): Na + raster_from_agent (Tensor): 3 x 3 + dis_map (Tensor): 224 x 224 (same as rasterized feature dimension) distance map + weights (Dict): weights of various costs + num_frames_per_stage (int): Ts + M (int): branching factor + col_funcs (function handle, optional): in case custom collision function is used. Defaults to None. + + Returns: + TreeMotionPolicy: optimal motion policy + """ + + num_stage = len(ego_tree)-1 + ego_root = ego_tree[0][0] + device = agent_traj.device + leaf_idx = defaultdict(lambda:list()) + for stage in range(num_stage,-1,-1): + for node in ego_tree[stage]: + if node.isleaf(): + leaf_idx[node]=[ego_tree[stage].index(node)] + else: + leaf_idx[node] = [] + for child in node.children: + leaf_idx[node] = leaf_idx[node]+leaf_idx[child] + + + V = dict() + L = dict() + Q = dict() + scenario_tree = tiled_to_tree(agent_traj,mode_prob,num_stage,num_frames_per_stage,M) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0,2] + d_sat = v0.clip(min=2.0)*num_frames_per_stage*dt + for stage in range(num_stage,0,-1): + if stage==0: + total_loss = torch.zeros([1,1],device=device) + else: + ego_nodes = ego_tree[stage] + indices = [leaf_idx[node][0] for node in ego_nodes] + ego_traj = [node.traj[:,TRAJ_INDEX] for node in ego_nodes] + ego_traj = torch.stack(ego_traj,0) + agent_nodes = scenario_tree[stage] + agent_traj = [node.traj[indices] for node in agent_nodes] + agent_traj = torch.stack(agent_traj,0) + ego_traj_tiled = ego_traj.unsqueeze(0).repeat(len(agent_nodes),1,1,1) + col_loss = get_collision_loss(ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes),1), + agent_extents.tile(len(agent_nodes),1,1), + agent_types.tile(len(agent_nodes),1), + col_funcs, + ) + + + lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + # lane_loss = get_lane_loss_simple(ego_traj,raster_from_agent,dis_map).unsqueeze(0) + + progress_reward = get_progress_reward(ego_traj,d_sat=d_sat) + + total_loss = weights["collision_weight"]*col_loss+weights["lane_weight"]*lane_loss-weights["progress_weight"]*progress_reward.unsqueeze(0) + if pert_std is not None: + total_loss +=torch.randn(total_loss.shape[1],device=device).unsqueeze(0)*pert_std + if log_likelihood is not None and stage==num_stage: + ll_reward = get_terminal_likelihood_reward(ego_traj, raster_from_agent, log_likelihood) + total_loss = total_loss-weights["likelihood_weight"]*ll_reward + + + for i in range(len(ego_nodes)): + for j in range(len(agent_nodes)): + L[(ego_nodes[i],agent_nodes[j])] = total_loss[j,i] + if stage==num_stage: + V[(ego_nodes[i],agent_nodes[j])] = float(total_loss[j,i]) + else: + children_cost_to_go = [Q[(child,agent_nodes[j])] for child in ego_nodes[i].children] + V[(ego_nodes[i],agent_nodes[j])] = float(total_loss[j,i])+min(children_cost_to_go) + + if stage>0: + for agent_node in 
scenario_tree[stage-1]: + cost_i = [] + prob_i = [] + for child in agent_node.children: + cost_i.append(V[ego_nodes[i],child]) + prob_i.append(child.prob[leaf_idx[ego_nodes[i]]].sum()) + cost_i = torch.tensor(cost_i,device=device) + prob_i = torch.stack(prob_i) + Q[(ego_nodes[i],agent_node)] = float((cost_i*prob_i).sum()/prob_i.sum()) + Q_root = [Q[(child,scenario_root)] for child in ego_root.children] + idx = torch.argmin(torch.tensor(Q_root)).item() + optimal_node = ego_root.children[idx] + motion_policy = TreeMotionPolicy(num_stage, + num_frames_per_stage, + ego_root, + scenario_root, + Q, + leaf_idx, + optimal_node) + motion_policy.get_plan(None,num_stage*num_frames_per_stage) + return motion_policy + +def contingency_planning_parallel(ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std = None): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (Dict[int:List[TrajTree]]): ego trajectory tree + ego_extents (Tensor): [2] + agent_traj (Tensor): EC x M^s x Na x (s*Ts) x 3 scenario tree predicted + mode_prob (Tensor): EC x M^s + agent_extents (Tensor): Na x 2 + agent_types (Tensor): Na + raster_from_agent (Tensor): 3 x 3 + dis_map (Tensor): 224 x 224 (same as rasterized feature dimension) distance map + weights (Dict): weights of various costs + num_frames_per_stage (int): Ts + M (int): branching factor + col_funcs (function handle, optional): in case custom collision function is used. Defaults to None. + + Returns: + TreeMotionPolicy: optimal motion policy + """ + device=agent_traj.device + children_indices = TrajTree.get_children_index_torch(ego_tree) + num_stage = len(ego_tree)-1 + ego_root = ego_tree[0][0] + + leaf_idx = {num_stage:torch.arange(len(ego_tree[num_stage]),device=device)} + stage_prob = {num_stage:mode_prob.T} + for stage in range(num_stage-1,-1,-1): + leaf_idx[stage] = leaf_idx[stage+1][children_indices[stage][:,0]] + prob_next = stage_prob[stage+1] + stage_prob[stage] = prob_next.reshape(-1,M,prob_next.shape[-1])[:,:,children_indices[stage][:,0]].sum(1) + + V = dict() + L = dict() + Q = dict() + + + scenario_tree = tiled_to_tree(agent_traj,mode_prob,num_stage,num_frames_per_stage,M) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0,2] + d_sat = v0.clip(min=2.0)*num_frames_per_stage*dt + + + + for stage in range(num_stage,-1,-1): + if stage==0: + total_loss = torch.zeros([1,1],device=device) + else: + #calculate stage cost + ego_nodes = ego_tree[stage] + ego_traj = [node.traj[:,TRAJ_INDEX] for node in ego_nodes] + ego_traj = torch.stack(ego_traj,0) + agent_nodes = scenario_tree[stage] + + agent_traj = [node.traj[leaf_idx[stage]] for node in agent_nodes] + + agent_traj = torch.stack(agent_traj,0) + ego_traj_tiled = ego_traj.unsqueeze(0).repeat(len(agent_nodes),1,1,1) + col_loss = get_collision_loss(ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes),1), + agent_extents.tile(len(agent_nodes),1,1), + agent_types.tile(len(agent_nodes),1), + col_funcs, + ) + + + lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + # lane_loss = get_lane_loss_simple(ego_traj,raster_from_agent,dis_map).unsqueeze(0) + + progress_reward = get_progress_reward(ego_traj,d_sat=d_sat) + + total_loss = 
weights["collision_weight"]*col_loss+weights["lane_weight"]*lane_loss-weights["progress_weight"]*progress_reward.unsqueeze(0) + if pert_std is not None: + total_loss +=torch.randn(total_loss.shape[1],device=device).unsqueeze(0)*pert_std + if log_likelihood is not None and stage==num_stage: + ll_reward = get_terminal_likelihood_reward(ego_traj, raster_from_agent, log_likelihood) + total_loss = total_loss-weights["likelihood_weight"]*ll_reward + + L[stage] = total_loss + if stage==num_stage: + V[stage] = total_loss + else: + children_idx = children_indices[stage] + # add the last Q value as inf since empty children index are padded with -1 + Q_prime = torch.cat((Q[stage],torch.full([Q[stage].shape[0],1],np.inf,device=device)),1) + Q_by_node = Q_prime[:,children_idx] + V[stage] = total_loss+Q_by_node.min(dim=-1)[0] + + if stage>0: + children_V = V[stage] + children_V = children_V.reshape(-1,M,children_V.shape[-1]) + prob = stage_prob[stage] + prob = prob.reshape(-1,M,prob.shape[-1]) + prob_normalized = prob/prob.sum(dim=1,keepdim=True) + Q[stage-1] = (children_V*prob_normalized).sum(dim=1) + + idx = Q[0].argmin().item() + + motion_policy = VectorizedTreeMotionPolicy(num_stage, + num_frames_per_stage, + ego_tree, + children_indices, + scenario_tree, + Q, + leaf_idx, + ego_root.children[idx]) + return motion_policy + +def one_shot_planning(ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std = None, + strategy="all"): + + """Alternative of contingency planning, try to avoid all predicted trajectories + + Args: + ego_tree (Dict[int:List[TrajTree]]): ego trajectory tree + ego_extents (Tensor): [2] + agent_traj (Tensor): EC x M^s x Na x (s*Ts) x 3 scenario tree predicted + mode_prob (Tensor): EC x M^s + agent_extents (Tensor): Na x 2 + agent_types (Tensor): Na + raster_from_agent (Tensor): 3 x 3 + dis_map (Tensor): 224 x 224 (same as rasterized feature dimension) distance map + weights (Dict): weights of various costs + num_frames_per_stage (int): Ts + M (int): branching factor + col_funcs (function handle, optional): in case custom collision function is used. Defaults to None. 
+ + Returns: + TreeMotionPolicy: optimal motion policy + """ + assert strategy=="all" or strategy=="maximum" + num_stage = len(ego_tree)-1 + ego_root = ego_tree[0][0] + + ego_traj = [node.total_traj[1:,TRAJ_INDEX] for node in ego_tree[num_stage]] + ego_traj = torch.stack(ego_traj,0) + ego_traj_tiled = ego_traj.unsqueeze(1).repeat_interleave(agent_traj.shape[1],1) + Ne = ego_traj.shape[0] + if strategy=="maximum": + idx = mode_prob.argmax(dim=1) + idx = idx.reshape(Ne,*[1]*(agent_traj.ndim-1)) + agent_traj = agent_traj.take_along_dim(idx,1) + col_loss = get_collision_loss(ego_traj_tiled,agent_traj,ego_extents.tile(Ne,1),agent_extents.tile(Ne,1,1),agent_types.tile(Ne,1),col_funcs) + col_loss = col_loss.max(dim=1)[0] + lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)).squeeze(0) + v0 = ego_root.traj[0,2] + d_sat = v0.clip(min=2.0)*num_frames_per_stage*dt + progress_reward = get_progress_reward(ego_traj,d_sat=d_sat) + total_loss = weights["collision_weight"]*col_loss+weights["lane_weight"]*lane_loss-weights["progress_weight"]*progress_reward + if pert_std is not None: + total_loss +=torch.randn(total_loss.shape[0],device=total_loss.device)*pert_std + if log_likelihood is not None: + ll_reward = get_terminal_likelihood_reward(ego_traj, raster_from_agent, log_likelihood) + total_loss = total_loss-weights["likelihood_weight"]*ll_reward + + idx = total_loss.argmin() + return ego_traj[idx] + + +def obtain_ref(line, x, v, N, dt): + """obtain desired trajectory for the MPC controller + + Args: + line (np.ndarray): centerline of the lane [n, 3] + x (np.ndarray): position of the vehicle + v (np.ndarray): desired velocity + N (int): number of time steps + dt (float): time step + + Returns: + refx (np.ndarray): desired trajectory [N,3] + """ + line_length = line.shape[0] + delta_x = line[..., 0:2] - np.repeat(x[..., np.newaxis, 0:2], line_length, axis=-2) + dis = np.linalg.norm(delta_x, axis=-1) + idx = np.argmin(dis, axis=-1) + line_min = line[idx] + dx = x[0] - line_min[0] + dy = x[1] - line_min[1] + delta_y = -dx * np.sin(line_min[2]) + dy * np.cos(line_min[2]) + delta_x = dx * np.cos(line_min[2]) + dy * np.sin(line_min[2]) + refx0 = np.array( + [ + line_min[0] + delta_x * np.cos(line_min[2]), + line_min[1] + delta_x * np.sin(line_min[2]), + line_min[2], + ] + ) + s = [np.linalg.norm(line[idx + 1, 0:2] - refx0[0:2])] + for i in range(idx + 2, line_length): + s.append(s[-1] + np.linalg.norm(line[i, 0:2] - line[i - 1, 0:2])) + f = interp1d( + np.array(s), + line[idx + 1 :], + kind="linear", + axis=0, + copy=True, + bounds_error=False, + fill_value="extrapolate", + assume_sorted=True, + ) + s1 = v * np.arange(1, N + 1) * dt + refx = f(s1) + + return refx + +def Unicycle_braking_traj(x0,dt,T,brake_acc): + """generate braking trajectory for the unicycle model + + Args: + x0 (np.ndarray): initial state + dt (float): time step + T (int): total time step + brake_acc (float): deceleration + + Returns: + np.ndarray: braking trajectory + """ + if isinstance(x0,np.ndarray): + x = np.zeros((T+1,4)) + x[0] = x0 + x[:,3]=x0[3] + x[:,2] = np.clip(x0[2]+dt*np.arange(T+1)*brake_acc,a_min=0,a_max=np.inf) + x[1:,:2] = x[0:1,:2] + np.cumsum(x[:-1,2])[:,None]*dt*np.array([[np.cos(x0[3]),np.sin(x0[3])]]) + + elif isinstance(x0,torch.Tensor): + x = torch.zeros((T+1,4)) + x[0] = x0 + x[:,3]=x0[3] + x[:,2] = torch.clip(x0[2]+dt*torch.arange(T+1)*brake_acc,min=0,max=np.inf) + x[1:,:2] = x[0:1,:2] + 
torch.cumsum(x[:-1,2],0)[:,None]*dt*torch.tensor([[torch.cos(x0[3]),torch.sin(x0[3])]]) + return x[1:] + + +def test(): + x0=torch.tensor([4,5,7.0,0]) + dt=0.1 + T=30 + brake_acc=-5.0 + x=Unicycle_braking_traj(x0,dt,T,brake_acc) + print(x) + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/diffstack/utils/pred_utils.py b/diffstack/utils/pred_utils.py index e77deb9..5a38a5d 100644 --- a/diffstack/utils/pred_utils.py +++ b/diffstack/utils/pred_utils.py @@ -7,7 +7,8 @@ from scipy.optimize import minimize from collections import defaultdict from typing import Dict, Union, Tuple, Any, Optional, Iterable - +from trajdata.data_structures.batch import PadDirection, SceneBatch +from diffstack.utils.utils import batch_select def compute_ade_pt(predicted_trajs, gt_traj): @@ -66,3 +67,39 @@ def compute_prediction_metrics(prediction_output_dict, 'nll_mean': nll_means, 'nll_final': nll_finals} + +def split_predicted_agent_extents( + batch: SceneBatch, + num_dist_agents: int = 1, + num_single_agents: Optional[int] = None, + max_num_agents: Optional[int] = None +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Get extents from scene batch. + + Returns three tensors: + - ego_extents + - pred_agent_extents + - ml_agent_extents + """ + # TODO it should be the responsibility of the predictor to associate agent to its role. + # Here we just assume a fixed order + robot_ind = batch.extras["robot_ind"] + pred_ind = batch.extras["pred_agent_ind"] + assert (robot_ind == 0).all() and (pred_ind == 1).all(), "Agent roles are assumed to be hardcoded" + assert batch.history_pad_dir == PadDirection.AFTER + if num_single_agents is None: + if max_num_agents is None: + raise ValueError("Must specify either num_ml_agents or max_num_agents") + num_single_agents = max_num_agents + 1 - num_dist_agents + + agent_extent = batch_select(batch.agent_hist_extent, index=batch.agent_hist_len-1, batch_dims=2) # b, N, t, (length, width) + ego_extent = agent_extent[:, 0] + dist_extents = agent_extent[:, 1:1+num_dist_agents] + dist_extents = torch.nn.functional.pad(dist_extents, (0, 0, 0, num_dist_agents-dist_extents.shape[1]), 'constant', torch.nan) + + single_extents = agent_extent[:, 1+num_dist_agents:(1+num_dist_agents+num_single_agents)] + single_extents = torch.nn.functional.pad(single_extents, (0, 0, 0, num_single_agents-single_extents.shape[1]), 'constant', torch.nan) + # ml_extents = agent_extent[:, 2:(2+MAX_PLAN_NEIGHBORS+1)] + # ml_extents = torch.nn.functional.pad(ml_extents, (0, 0, 0, MAX_PLAN_NEIGHBORS+1-ml_extents.shape[1]), 'constant', torch.nan) + + return ego_extent, dist_extents, single_extents diff --git a/diffstack/utils/rollout_logger.py b/diffstack/utils/rollout_logger.py new file mode 100644 index 0000000..c2665b5 --- /dev/null +++ b/diffstack/utils/rollout_logger.py @@ -0,0 +1,226 @@ +from collections import defaultdict +import numpy as np +from copy import deepcopy + +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.policies.common import RolloutAction + +from torch.nn.utils.rnn import pad_sequence +class RolloutLogger(object): + """Log trajectories and other essential info during rollout for visualization and evaluation""" + def __init__(self, obs_keys=None): + if obs_keys is None: + obs_keys = dict() + self._obs_keys = obs_keys + self._scene_indices = None + self._agent_id_per_scene = dict() + self._agent_data_by_scene = dict() + self._scene_ts = defaultdict(lambda:0) + + def _combine_obs(self, obs): + combined = dict() + excluded_keys = ["extras"] + if 
"ego" in obs and obs["ego"] is not None: + combined.update(obs["ego"]) + if "agents" in obs and obs["agents"] is not None: + for k in obs["agents"].keys(): + if k in combined and k not in excluded_keys: + if obs["agents"][k] is not None: + if combined[k] is not None: + combined[k] = np.concatenate((combined[k], obs["agents"][k]), axis=0) + else: + combined[k] = obs["agents"][k] + else: + combined[k] = obs["agents"][k] + return combined + + def _combine_action(self, action: RolloutAction): + combined = dict(action=dict()) + if action.has_ego and not action.has_agents: + combined["action"] = action.ego.to_dict() + if action.ego_info is not None and "action_samples" in action.ego_info: + combined["action_samples"] = action.ego_info["action_samples"] + return combined + + elif action.has_agents and not action.has_ego: + combined["action"] = action.agents.to_dict() + if action.agents_info is not None and "action_samples" in action.agents_info: + combined["action_samples"] = action.agents_info["action_samples"] + return combined + elif action.has_agents and action.has_ego: + Nego = action.ego.positions.shape[0] + Nagents = action.agents.positions.shape[0] + combined["action"] = dict() + agents_action = action.agents.to_dict() + ego_action = action.ego.to_dict() + for k in agents_action: + if k in ego_action: + combined["action"][k] = np.concatenate((ego_action[k], agents_action[k]), axis=0) + if action.agents_info is not None and action.ego_info is not None: + if "action_samples" in action.ego_info: + ego_samples = action.ego_info["action_samples"] + else: + ego_samples = None + if "action_samples" in action.agents_info: + agents_samples = action.agents_info["action_samples"] + else: + agents_samples = None + if ego_samples is not None and agents_samples is None: + combined["action_samples"] = dict() + for k in ego_samples: + pad_k = np.zeros([Nagents,*ego_samples[k].shape[1:]]) + combined["action_samples"][k]=np.concatenate((ego_samples[k],pad_k),0) + elif ego_samples is None and agents_samples is not None: + combined["action_samples"] = dict() + for k in agents_samples: + pad_k = np.zeros([Nego,*agents_samples[k].shape[1:]]) + combined["action_samples"][k]=np.concatenate((pad_k,agents_samples[k]),0) + elif ego_samples is not None and agents_samples is not None: + combined["action_samples"] = dict() + for k in ego_samples: + if k in agents_samples: + if ego_samples[k].shape[1]>agents_samples[k].shape[1]: + pad_k = np.zeros([Nagents,ego_samples[k].shape[1]-agents_samples[k].shape[1],*agents_samples[k].shape[2:]]) + agents_samples[k]=np.concatenate((agents_samples[k],pad_k),1) + elif ego_samples[k].shape[1]0: + default_val = list(self._agent_data_by_scene[si][k][ti].values())[0] + ti_k = list() + for ts in range(self._scene_ts[si]): + ti_k.append(self._agent_data_by_scene[si][k][ti][ts] if ts in self._agent_data_by_scene[si][k][ti] else np.ones_like(default_val)*np.nan) + default_val = ti_k[-1] + if not all(elem.shape==ti_k[0].shape for elem in ti_k): + # requires padding + if np.issubdtype(ti_k[0].dtype,np.floating): + padding_value = np.nan + else: + padding_value = 0 + ti_k = [x[0] for x in ti_k] + ti_k_torch = TensorUtils.to_tensor(ti_k,ignore_if_unspecified=True) + + ti_k_padded = pad_sequence(ti_k_torch,padding_value=padding_value,batch_first=True) + serialized[si][k].append(TensorUtils.to_numpy(ti_k_padded)[np.newaxis,:]) + else: + if ti_k[0].ndim==0: + serialized[si][k].append(np.array(ti_k)[np.newaxis,:]) + else: + serialized[si][k].append(np.concatenate(ti_k,axis=0)[np.newaxis,:]) + 
else: + serialized[si][k].append(np.zeros_like(serialized[si][k][-1])) + if not all(elem.shape==serialized[si][k][0].shape for elem in serialized[si][k]): + # requires padding + if np.issubdtype(serialized[si][k][0][0].dtype,np.floating): + padding_value = np.nan + else: + padding_value = 0 + axes=[1,0]+np.arange(2,serialized[si][k][0].ndim-1).tolist() + mk_transpose = [np.transpose(x[0],axes) for x in serialized[si][k]] + mk_torch = TensorUtils.to_tensor(mk_transpose,ignore_if_unspecified=True) + mk_padded = pad_sequence(mk_torch,padding_value=padding_value) + mk = TensorUtils.to_numpy(mk_padded) + axes=[1,2,0]+np.arange(3,mk.ndim).tolist() + serialized[si][k]=np.transpose(mk,axes) + else: + serialized[si][k] = np.concatenate(serialized[si][k],axis=0) + + + + self._serialized_scene_buffer = serialized + return deepcopy(self._serialized_scene_buffer) + + def get_trajectory(self): + """Get per-scene rollout trajectory in the world coordinate system""" + buffer = self.get_serialized_scene_buffer() + traj = dict() + for si in buffer: + traj[si] = dict( + positions=buffer[si]["centroid"], + yaws=buffer[si]["world_yaw"] + ) + return traj + + def get_track_id(self): + return deepcopy(self._agent_id_per_scene) + + def get_stats(self): + # TODO + raise NotImplementedError() + + def log_step(self, obs, action: RolloutAction): + combined_obs = self._combine_obs(obs) + combined_action = self._combine_action(action) + assert combined_obs["scene_index"].shape[0] == combined_action["action"]["positions"].shape[0] + self._maybe_initialize(combined_obs) + self._append_buffer(combined_obs, combined_action) + for si in np.unique(combined_obs["scene_index"]): + self._scene_ts[si]+=1 + del combined_obs diff --git a/diffstack/utils/sys_utils.py b/diffstack/utils/sys_utils.py new file mode 100644 index 0000000..370fa6a --- /dev/null +++ b/diffstack/utils/sys_utils.py @@ -0,0 +1,12 @@ +import os + + +def delete_files_in_directory(directory_path): + try: + files = os.listdir(directory_path) + for file in files: + file_path = os.path.join(directory_path, file) + if os.path.isfile(file_path): + os.remove(file_path) + except OSError: + print("Error occurred while deleting files.") diff --git a/diffstack/utils/tensor_utils.py b/diffstack/utils/tensor_utils.py new file mode 100644 index 0000000..7ea4bff --- /dev/null +++ b/diffstack/utils/tensor_utils.py @@ -0,0 +1,1189 @@ +""" +A collection of utilities for working with nested tensor structures consisting +of numpy arrays and torch tensors. +""" +import collections +from tracemalloc import start +import numpy as np +import torch +import torch.nn as nn +import sys + + +def recursive_dict_list_tuple_apply(x, type_func_dict, ignore_if_unspecified=False): + """ + Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of + {data_type: function_to_apply}. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + type_func_dict (dict): a mapping from data types to the functions to be + applied for each data type. 
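+            For example, {torch.Tensor: lambda t: t.cpu(), type(None): lambda x: x}
+            moves every tensor in the structure to the CPU and passes None values through.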
+ ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + assert list not in type_func_dict + assert tuple not in type_func_dict + assert dict not in type_func_dict + assert nn.ParameterDict not in type_func_dict + assert nn.ParameterList not in type_func_dict + + if isinstance(x, (dict, collections.OrderedDict, nn.ParameterDict)): + new_x = ( + collections.OrderedDict() + if isinstance(x, collections.OrderedDict) + else dict() + ) + for k, v in x.items(): + new_x[k] = recursive_dict_list_tuple_apply( + v, type_func_dict, ignore_if_unspecified + ) + return new_x + elif isinstance(x, (list, tuple, nn.ParameterList)): + ret = [ + recursive_dict_list_tuple_apply(v, type_func_dict, ignore_if_unspecified) + for v in x + ] + if isinstance(x, tuple): + ret = tuple(ret) + return ret + else: + for t, f in type_func_dict.items(): + if isinstance(x, t): + return f(x) + else: + if ignore_if_unspecified: + return x + else: + raise NotImplementedError("Cannot handle data type %s" % str(type(x))) + + +def map_tensor(x, func): + """ + Apply function @func to torch.Tensor objects in a nested dictionary or + list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each tensor + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: func, + type(None): lambda x: x, + }, + ) + + +def map_ndarray(x, func): + """ + Apply function @func to np.ndarray objects in a nested dictionary or + list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + np.ndarray: func, + str: lambda x: x, + type(None): lambda x: x, + }, + ) + + +def map_tensor_ndarray(x, tensor_func, ndarray_func): + """ + Apply function @tensor_func to torch.Tensor objects and @ndarray_func to + np.ndarray objects in a nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + tensor_func (function): function to apply to each tensor + ndarray_Func (function): function to apply to each array + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: tensor_func, + np.ndarray: ndarray_func, + type(None): lambda x: x, + }, + ) + + +def clone(x): + """ + Clones all torch tensors and numpy arrays in nested dictionary or list + or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.clone(), + np.ndarray: lambda x: x.copy(), + type(None): lambda x: x, + }, + ) + + +def detach(x): + """ + Detaches all torch tensors in nested dictionary or list + or tuple and returns a new nested structure. 
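+    Leaf values of unhandled types (e.g. strings or Python scalars) are returned
+    unchanged because ignore_if_unspecified is enabled.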
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.detach(), + np.ndarray: lambda x: x, + type(None): lambda x: x, + }, + ignore_if_unspecified=True, + ) + + +def to_batch(x): + """ + Introduces a leading batch dimension of 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[None, ...], + np.ndarray: lambda x: x[None, ...], + type(None): lambda x: x, + }, + ) + + +def to_sequence(x): + """ + Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy + arrays in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, None, ...], + np.ndarray: lambda x: x[:, None, ...], + type(None): lambda x: x, + }, + ) + + +def index_at_time(x, ind): + """ + Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in + nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ind (int): index + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x[:, ind, ...], + np.ndarray: lambda x: x[:, ind, ...], + type(None): lambda x: x, + }, + ) + + +def unsqueeze(x, dim): + """ + Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays + in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + dim (int): dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.unsqueeze(dim=dim), + np.ndarray: lambda x: np.expand_dims(x, axis=dim), + type(None): lambda x: x, + }, + ) + + +def squeeze(x, dim): + """ + Reduce dimension of size 1 at dimension @dim in all torch tensors and numpy arrays + in nested dictionary or list or tuple and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + dim (int): dimension + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.squeeze(dim=dim), + np.ndarray: lambda x: np.squeeze(x, axis=dim), + type(None): lambda x: x, + }, + ) + + +def contiguous(x): + """ + Makes all torch tensors and numpy arrays contiguous in nested dictionary or + list or tuple and returns a new nested structure. 
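+    Useful before operations that require contiguous memory, e.g. torch.Tensor.view().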
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.contiguous(), + np.ndarray: lambda x: np.ascontiguousarray(x), + type(None): lambda x: x, + }, + ) + + +def to_device(x, device, ignore_if_unspecified=False): + """ + Sends all torch tensors in nested dictionary or list or tuple to device + @device, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, d=device: x.to(d), + str: lambda x: x, + type(None): lambda x: x, + }, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_tensor(x, ignore_if_unspecified=False): + """ + Converts all numpy arrays in nested dictionary or list or tuple to + torch tensors (and leaves existing torch Tensors as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x, + str: lambda x: x, + np.ndarray: lambda x: torch.from_numpy(x), + type(None): lambda x: x, + }, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_numpy(x, ignore_if_unspecified=False): + """ + Converts all torch tensors in nested dictionary or list or tuple to + numpy (and leaves existing numpy arrays as-is), and returns + a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy() + else: + return tensor.detach().numpy() + + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x, + type(None): lambda x: x, + str: lambda x: x, + }, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_list(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to a list, and returns a new nested structure. Useful for + json encoding. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + + def f(tensor): + if tensor.is_cuda: + return tensor.detach().cpu().numpy().tolist() + else: + return tensor.detach().numpy().tolist() + + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: f, + np.ndarray: lambda x: x.tolist(), + type(None): lambda x: x, + }, + ) + + +def to_float(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to float type entries, and returns a new nested structure. 
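+    Torch tensors are cast with .float() (torch.float32) and numpy arrays with .astype(np.float32).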
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.float(), + np.ndarray: lambda x: x.astype(np.float32), + type(None): lambda x: x, + }, + ) + + +def to_uint8(x): + """ + Converts all torch tensors and numpy arrays in nested dictionary or list + or tuple to uint8 type entries, and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.byte(), + np.ndarray: lambda x: x.astype(np.uint8), + type(None): lambda x: x, + }, + ) + + +def to_torch(x, device, ignore_if_unspecified=False): + """ + Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to + torch tensors on device @device and returns a new nested structure. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + device (torch.Device): device to send tensors to + ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return to_device( + to_tensor(x, ignore_if_unspecified=ignore_if_unspecified), + device, + ignore_if_unspecified=ignore_if_unspecified, + ) + + +def to_one_hot_single(tensor, num_class): + """ + Convert tensor to one-hot representation, assuming a certain number of total class labels. + + Args: + tensor (torch.Tensor): tensor containing integer labels + num_class (int): number of classes + + Returns: + x (torch.Tensor): tensor containing one-hot representation of labels + """ + x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device) + x.scatter_(-1, tensor.unsqueeze(-1), 1) + return x + + +def to_one_hot(tensor, num_class): + """ + Convert all tensors in nested dictionary or list or tuple to one-hot representation, + assuming a certain number of total class labels. + + Args: + tensor (dict or list or tuple): a possibly nested dictionary or list or tuple + num_class (int): number of classes + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc)) + + +def flatten_single(x, begin_axis=1): + """ + Flatten a tensor in all dimensions from @begin_axis onwards. + + Args: + x (torch.Tensor): tensor to flatten + begin_axis (int): which axis to flatten from + + Returns: + y (torch.Tensor): flattened tensor + """ + fixed_size = x.size()[:begin_axis] + _s = list(fixed_size) + [-1] + return x.reshape(*_s) + + +def flatten(x, begin_axis=1): + """ + Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): which axis to flatten from + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b), + }, + ) + + +def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions in a tensor to a target dimension. 
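+    For example, a tensor of shape (B, M*N, D) can be unflattened via
+    reshape_dimensions_single(x, begin_axis=1, end_axis=2, target_dims=(M, N)),
+    yielding shape (B, M, N, D).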
+ + Args: + x (torch.Tensor): tensor to reshape + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (torch.Tensor): reshaped tensor + """ + assert begin_axis < end_axis + assert begin_axis >= 0 + assert end_axis <= len(x.shape) + assert isinstance(target_dims, (tuple, list)) + s = x.shape + final_s = [] + for i in range(len(s)): + if i == begin_axis: + final_s.extend(target_dims) + elif i < begin_axis or i >= end_axis: + final_s.append(s[i]) + return x.reshape(*final_s) + + +def reshape_dimensions(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions for all tensors in nested dictionary or list or tuple + to a target dimension. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension (excluding the dimension) + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t + ), + np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=t + ), + type(None): lambda x: x, + }, + ) + + +def join_dimensions(x, begin_axis, end_axis): + """ + Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for + all tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + begin_axis (int): begin dimension + end_axis (int): end dimension (excluding the dimension) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1] + ), + np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( + x, begin_axis=b, end_axis=e, target_dims=[-1] + ), + type(None): lambda x: x, + }, + ) + + +def expand_at_single(x, size, dim): + """ + Expand a tensor at a single dimension @dim by @size + + Args: + x (torch.Tensor): input tensor + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (torch.Tensor): expanded tensor + """ + assert dim < x.ndimension() + assert x.shape[dim] == 1 + expand_dims = [-1] * x.ndimension() + expand_dims[dim] = size + return x.expand(*expand_dims) + + +def expand_at(x, size, dim): + """ + Expand all tensors in nested dictionary or list or tuple at a single + dimension @dim by @size. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) + + +def unsqueeze_expand_at(x, size, dim): + """ + Unsqueeze and expand a tensor at a dimension @dim by @size. 
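+    For example, unsqueeze_expand_at(x, size=K, dim=1) turns a tensor of shape (B, D)
+    into an expanded view of shape (B, K, D) without copying the underlying data.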
+ + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size to expand + dim (int): dimension to unsqueeze and expand + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze(x, dim) + return expand_at(x, size, dim) + + +def repeat_by_expand_at(x, repeats, dim): + """ + Repeat a dimension by combining expand and reshape operations. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + repeats (int): number of times to repeat the target dimension + dim (int): dimension to repeat on + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + x = unsqueeze_expand_at(x, repeats, dim + 1) + return join_dimensions(x, dim, dim + 2) + + +def named_reduce_single(x, reduction, dim): + """ + Reduce tensor at a dimension by named reduction functions. + + Args: + x (torch.Tensor): tensor to be reduced + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (torch.Tensor): reduced tensor + """ + assert x.ndimension() > dim + assert reduction in ["sum", "max", "mean", "flatten"] + if reduction == "flatten": + x = flatten(x, begin_axis=dim) + elif reduction == "max": + x = torch.max(x, dim=dim)[0] # [B, D] + elif reduction == "sum": + x = torch.sum(x, dim=dim) + else: + x = torch.mean(x, dim=dim) + return x + + +def named_reduce(x, reduction, dim): + """ + Reduces all tensors in nested dictionary or list or tuple at a dimension + using a named reduction function. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + reduction (str): one of ["sum", "max", "mean", "flatten"] + dim (int): dimension to be reduced (or begin axis for flatten) + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor( + x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d) + ) + + +def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): + """ + This function indexes out a target dimension of a tensor in a structured way, + by allowing a different value to be selected for each member of a flat index + tensor (@indices) corresponding to a source dimension. This can be interpreted + as moving along the source dimension, using the corresponding index value + in @indices to select values for all other dimensions outside of the + source and target dimensions. A common use case is to gather values + in target dimension 1 for each batch member (target dimension 0). 
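+    For example, with x of shape [B, T, D], indices of shape [B], target_dim=1 and
+    source_dim=0, the result has shape [B, D] and equals x[b, indices[b]] for each b.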
+ + Args: + x (torch.Tensor): tensor to gather values for + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out + """ + assert len(indices.shape) == 1 + assert x.shape[source_dim] == indices.shape[0] + + # unsqueeze in all dimensions except the source dimension + new_shape = [1] * x.ndimension() + new_shape[source_dim] = -1 + indices = indices.reshape(*new_shape) + + # repeat in all dimensions - but preserve shape of source dimension, + # and make sure target_dimension has singleton dimension + expand_shape = list(x.shape) + expand_shape[source_dim] = -1 + expand_shape[target_dim] = 1 + indices = indices.expand(*expand_shape) + + out = x.gather(dim=target_dim, index=indices) + return out.squeeze(target_dim) + + +def gather_from_start_single(tensor, indices): + if tensor.ndim < indices.ndim: + return tensor + gather_dim = indices.ndim - 1 + + desired_shape = list(indices.shape) + ([1] * (tensor.ndim - gather_dim - 1)) + repeats = [1] * indices.ndim + list(tensor.shape[indices.ndim :]) + indices = indices.reshape(desired_shape).repeat(repeats) + if isinstance(tensor, torch.Tensor): + return torch.gather(tensor, gather_dim, indices) + elif isinstance(tensor, np.ndarray): + return np.take(tensor, indices, axis=gather_dim) + + +def gather_from_start(tensor, indices): + return recursive_dict_list_tuple_apply( + tensor, + { + torch.Tensor: lambda x: gather_from_start_single(x, indices), + np.ndarray: lambda x: gather_from_start_single(x, indices), + type(None): lambda x: x, + }, + ) + + +def gather_along_dim_with_dim(x, target_dim, source_dim, indices): + """ + Apply @gather_along_dim_with_dim_single to all tensors in a nested + dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + target_dim (int): dimension to gather values along + source_dim (int): dimension to hold constant and use for gathering values + from the other dimensions + indices (torch.Tensor): flat index tensor with same shape as tensor @x along + @source_dim + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + return map_tensor( + x, + lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single( + y, t, s, i + ), + ) + + +def gather_sequence_single(seq, indices): + """ + Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in + the batch given an index for each sequence. + + Args: + seq (torch.Tensor): tensor with leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Return: + y (torch.Tensor): indexed tensor of shape [B, ....] + """ + return gather_along_dim_with_dim_single( + seq, target_dim=1, source_dim=0, indices=indices + ) + + +def gather_sequence(seq, indices): + """ + Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch + for tensors with leading dimensions [B, T, ...]. + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + indices (torch.Tensor): tensor indices of shape [B] + + Returns: + y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] 
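+        For example, passing the per-sequence lengths minus one as @indices selects the
+        last valid element of each padded sequence.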
+ """ + return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices) + + +def slice_tensor_single(x, dim, start_idx, end_idx): + """select a slice of the tensor + + Args: + x (torch.Tensor or np.array): the tensor + dim (int): dimension to select + start_idx (int): starting index + end_idx (int): ending index + """ + assert start_idx >= 0 and start_idx <= end_idx and end_idx <= x.shape[dim] + if isinstance(x, np.ndarray): + return x.take(np.arange(start_idx, end_idx), dim) + elif isinstance(x, torch.Tensor): + return torch.index_select(x, dim, torch.arange(start_idx, end_idx).to(x.device)) + + +def slice_tensor(tensor, dim, start_idx, end_idx): + """recursively select a slice of the tensor or its field if tensor is a dict + + Args: + tensor (torch.Tensor or dict): the tensor + dim (int): dimension to select + start_idx (int): starting index + end_idx (int): ending index + """ + + return recursive_dict_list_tuple_apply( + tensor, + { + torch.Tensor: lambda x: slice_tensor_single(x, dim, start_idx, end_idx), + np.ndarray: lambda x: slice_tensor_single(x, dim, start_idx, end_idx), + type(None): lambda x: x, + }, + ) + + +def pad_sequence_single(seq, padding, batched=False, pad_same=False, pad_values=0.0): + """ + Pad input tensor or array @seq in the time dimension (dimension 1). + + Args: + seq (np.ndarray or torch.Tensor): sequence to be padded + padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (np.ndarray or torch.Tensor) + """ + assert isinstance(seq, (np.ndarray, torch.Tensor)) + assert pad_same or (pad_values is not None) + if pad_values is not None: + assert isinstance(pad_values, float) + repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave + concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat + ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like + seq_dim = 1 if batched else 0 + + begin_pad = [] + end_pad = [] + + if padding[0] > 0: + if batched: + pad = seq[:, [0]] if pad_same else ones_like_func(seq[:, [0]]) * pad_values + else: + pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values + begin_pad.append(repeat_func(pad, padding[0], seq_dim)) + if padding[1] > 0: + if batched: + pad = ( + seq[:, [-1]] if pad_same else ones_like_func(seq[:, [-1]]) * pad_values + ) + else: + pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values + end_pad.append(repeat_func(pad, padding[1], seq_dim)) + + return concat_func(begin_pad + [seq] + end_pad, seq_dim) + + +def pad_sequence(seq, padding, batched=False, pad_same=False, pad_values=0.0): + """ + Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). + + Args: + seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + padding (tuple): begin and end padding, e.g. 
[1, 1] pads both begin and end of the sequence by 1 + batched (bool): if sequence has the batch dimension + pad_same (bool): if pad by duplicating + pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same + + Returns: + padded sequence (dict or list or tuple) + """ + return recursive_dict_list_tuple_apply( + seq, + { + torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( + x, p, b, ps, pv + ), + np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( + x, p, b, ps, pv + ), + type(None): lambda x: x, + }, + ) + + +def left_right_average(seq): + """Add 1 entry to the seq by averaging the left appended seq and right appended seq + + Args: + seq (np.ndarray or torch.Tensor): + """ + concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat + seq_left = concat_func((seq[:1], seq), 0) + seq_right = concat_func((seq, seq[-1:]), 0) + return 0.5 * (seq_left + seq_right) + + +def assert_size_at_dim_single(x, size, dim, msg): + """ + Ensure that array or tensor @x has size @size in dim @dim. + + Args: + x (np.ndarray or torch.Tensor): input array or tensor + size (int): size that tensors should have at @dim + dim (int): dimension to check + msg (str): text to display if assertion fails + """ + assert x.shape[dim] == size, msg + + +def assert_size_at_dim(x, size, dim, msg): + """ + Ensure that arrays and tensors in nested dictionary or list or tuple have + size @size in dim @dim. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + size (int): size that tensors should have at @dim + dim (int): dimension to check + """ + map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) + + +def get_shape(x): + """ + Get all shapes of arrays and tensors in nested dictionary or list or tuple. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + + Returns: + y (dict or list or tuple): new nested dict-list-tuple that contains each array or + tensor's shape + """ + return recursive_dict_list_tuple_apply( + x, + { + torch.Tensor: lambda x: x.shape, + np.ndarray: lambda x: x.shape, + type(None): lambda x: x, + }, + ) + + +def list_of_flat_dict_to_dict_of_list(list_of_dict): + """ + Helper function to go from a list of flat dictionaries to a dictionary of lists. + By "flat" we mean that none of the values are dictionaries, but are numpy arrays, + floats, etc. + + Args: + list_of_dict (list): list of flat dictionaries + + Returns: + dict_of_list (dict): dictionary of lists + """ + assert isinstance(list_of_dict, list) + dic = collections.OrderedDict() + for i in range(len(list_of_dict)): + for k in list_of_dict[i]: + if k not in dic: + dic[k] = [] + dic[k].append(list_of_dict[i][k]) + return dic + + +def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""): + """ + Flatten a nested dict or list to a list. 
+ + For example, given a dict + { + a: 1 + b: { + c: 2 + } + c: 3 + } + + the function would return [(a, 1), (b_c, 2), (c, 3)] + + Args: + d (dict, list): a nested dict or list to be flattened + parent_key (str): recursion helper + sep (str): separator for nesting keys + item_key (str): recursion helper + Returns: + list: a list of (key, value) tuples + """ + items = [] + if isinstance(d, (tuple, list)): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for i, v in enumerate(d): + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) + return items + elif isinstance(d, dict): + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + for k, v in d.items(): + assert isinstance(k, str) + items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) + return items + else: + new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key + return [(new_key, d)] + + +def time_distributed( + inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs +): + """ + Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the + batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. + Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping + outputs to [B, T, ...]. + + Args: + inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors + of leading dimensions [B, T, ...] + op: a layer op that accepts inputs + activation: activation to apply at the output + inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op + inputs_as_args (bool) whether to feed input as a args list to the op + kwargs (dict): other kwargs to supply to the op + + Returns: + outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. 
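+    For example, an op that maps a [B * T, D] tensor to a [B * T, E] tensor can be applied
+    to inputs of shape [B, T, D] with time_distributed(inputs, op); the outputs are
+    reshaped back to [B, T, E].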
+ """ + batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2] + inputs = join_dimensions(inputs, 0, 2) + if inputs_as_kwargs: + outputs = op(**inputs, **kwargs) + elif inputs_as_args: + outputs = op(*inputs, **kwargs) + else: + outputs = op(inputs, **kwargs) + + if activation is not None: + outputs = map_tensor(outputs, activation) + outputs = reshape_dimensions( + outputs, begin_axis=0, end_axis=1, target_dims=(batch_size, seq_len) + ) + return outputs + + +def round_2pi(x): + return (x + np.pi) % (2 * np.pi) - np.pi + + +def cat_list_of_dict(x, dim): + """combining a list of dictionaries to a single dictionary by concatenating the values + + Args: + x (List[Dict[tensor.Torch]]): _description_ + """ + combined_dict = dict() + for k, v in x[0].items(): + if isinstance(v, torch.Tensor): + if v.ndim >= dim: + combined_dict[k] = torch.cat([xi[k] for xi in x], dim=dim) + + elif isinstance(v, dict): + combined_dict[k] = cat_list_of_dict([xi[k] for xi in x], dim) + return combined_dict + + +def block_diag_from_cat(x): + """convert a concatenated array to block diagonal + + Args: + x (Union[torch.Tensor,np.ndarray]): [B,M,n,n] + Returns: + mat: [B,Mxn,Mxn] + """ + if x.ndim == 3: + M = x.shape[1] + Id = torch.eye(M, device=x.device) + slices = [torch.kron(Id[i], x[:, i]) for i in range(M)] + mat = torch.cat(slices, 1) + return mat + elif x.ndim == 4: + bs, M, n, m = x.shape + Id = torch.eye(M, device=x.device) + slices = [ + torch.kron(Id[i], x[:, i].reshape(-1, m)).reshape(bs, n, m * M) + for i in range(M) + ] + mat = torch.cat(slices, 1) + return mat + + +def recursive_mean(x): + # x is a list of dict, potentially multi-level, with torch.Tensor as values, we want to compute the mean. + output = dict() + for k, v in x[0].items(): + if isinstance(v, dict): + output[k] = recursive_mean([xi[k] for xi in x]) + elif isinstance(v, torch.Tensor): + output[k] = torch.stack([xi[k] for xi in x], dim=0).mean(dim=0) + else: + raise ValueError("value must be torch.Tensor or dict") + return output + + +def flatten_dict(x): + res = dict() + for k, v in x.items(): + if isinstance(v, dict): + res.update(flatten_dict(v)) + else: + res[k] = v + return res diff --git a/diffstack/utils/timer.py b/diffstack/utils/timer.py new file mode 100644 index 0000000..c0d299e --- /dev/null +++ b/diffstack/utils/timer.py @@ -0,0 +1,65 @@ + +import time +import numpy as np +from contextlib import contextmanager + + +class Timer(object): + """A simple timer.""" + def __init__(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. 
+ self.times = [] + + def recent_average_time(self, latest_n): + return np.mean(np.array(self.times)[-latest_n:]) + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.times.append(self.diff) + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + return self.average_time + else: + return self.diff + + @contextmanager + def timed(self): + self.tic() + yield + self.toc() + + +class Timers(object): + def __init__(self): + self._timers = {} + + def tic(self, key): + if key not in self._timers: + self._timers[key] = Timer() + self._timers[key].tic() + + def toc(self, key): + self._timers[key].toc() + + @contextmanager + def timed(self, key): + self.tic(key) + yield + self.toc(key) + + def __str__(self): + msg = [] + for k, v in self._timers.items(): + msg.append('%s: %f' % (k, v.average_time)) + return ', '.join(msg) \ No newline at end of file diff --git a/diffstack/utils/torch_utils.py b/diffstack/utils/torch_utils.py new file mode 100644 index 0000000..677d5af --- /dev/null +++ b/diffstack/utils/torch_utils.py @@ -0,0 +1,316 @@ +""" +This file contains some PyTorch utilities. +""" +import numpy as np +import pytorch_lightning as pl +import torch +import torch.optim as optim +import functools +from tqdm.auto import tqdm +import time +from typing import Optional, Union + + +def soft_update(source, target, tau): + """ + Soft update from the parameters of a @source torch module to a @target torch module + with strength @tau. The update follows target = target * (1 - tau) + source * tau. + + Args: + source (torch.nn.Module): source network to push target network parameters towards + target (torch.nn.Module): target network to update + """ + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.copy_(target_param * (1.0 - tau) + param * tau) + + +def hard_update(source, target): + """ + Hard update @target parameters to match @source. + + Args: + source (torch.nn.Module): source network to provide parameters + target (torch.nn.Module): target network to update parameters for + """ + for target_param, param in zip(target.parameters(), source.parameters()): + target_param.copy_(param) + + +def get_torch_device(try_to_use_cuda): + """ + Return torch device. If using cuda (GPU), will also set cudnn.benchmark to True + to optimize CNNs. + + Args: + try_to_use_cuda (bool): if True and cuda is available, will use GPU + + Returns: + device (torch.Device): device to use for models + """ + if try_to_use_cuda and torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + return device + + +def reparameterize(mu, logvar): + """ + Reparameterize for the backpropagation of z instead of q. + This makes it so that we can backpropagate through the sampling of z from + our encoder when feeding the sampled variable to the decoder. 
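+    Concretely, z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I), so gradients can
+    flow through mu and logvar while the stochasticity is isolated in eps.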
+ + (See "The reparameterization trick" section of https://arxiv.org/abs/1312.6114) + + Args: + mu (torch.Tensor): batch of means from the encoder distribution + logvar (torch.Tensor): batch of log variances from the encoder distribution + + Returns: + z (torch.Tensor): batch of sampled latents from the encoder distribution that + support backpropagation + """ + # logvar = \log(\sigma^2) = 2 * \log(\sigma) + # \sigma = \exp(0.5 * logvar) + + # clamped for numerical stability + logstd = (0.5 * logvar).clamp(-4, 15) + std = torch.exp(logstd) + + # Sample \epsilon from normal distribution + # use std to create a new tensor, so we don't have to care + # about running on GPU or not + eps = std.new(std.size()).normal_() + + # Then multiply with the standard deviation and add the mean + z = eps.mul(std).add_(mu) + + return z + + +def optimizer_from_optim_params(net_optim_params, net): + """ + Helper function to return a torch Optimizer from the optim_params + section of the config for a particular network. + + Args: + optim_params (Config): optim_params part of algo_config corresponding + to @net. This determines the optimizer that is created. + + net (torch.nn.Module): module whose parameters this optimizer will be + responsible + + Returns: + optimizer (torch.optim.Optimizer): optimizer + """ + return optim.Adam( + params=net.parameters(), + lr=net_optim_params["learning_rate"]["initial"], + weight_decay=net_optim_params["regularization"]["L2"], + ) + + +def lr_scheduler_from_optim_params(net_optim_params, net, optimizer): + """ + Helper function to return a LRScheduler from the optim_params + section of the config for a particular network. Returns None + if a scheduler is not needed. + + Args: + optim_params (Config): optim_params part of algo_config corresponding + to @net. This determines whether a learning rate scheduler is created. + + net (torch.nn.Module): module whose parameters this optimizer will be + responsible + + optimizer (torch.optim.Optimizer): optimizer for this net + + Returns: + lr_scheduler (torch.optim.lr_scheduler or None): learning rate scheduler + """ + lr_scheduler = None + if len(net_optim_params["learning_rate"]["epoch_schedule"]) > 0: + # decay LR according to the epoch schedule + lr_scheduler = optim.lr_scheduler.MultiStepLR( + optimizer=optimizer, + milestones=net_optim_params["learning_rate"]["epoch_schedule"], + gamma=net_optim_params["learning_rate"]["decay_factor"], + ) + return lr_scheduler + + +def backprop_for_loss(net, optim, loss, max_grad_norm=None, retain_graph=False): + """ + Backpropagate loss and update parameters for network with + name @name. 
+ + Args: + net (torch.nn.Module): network to update + + optim (torch.optim.Optimizer): optimizer to use + + loss (torch.Tensor): loss to use for backpropagation + + max_grad_norm (float): if provided, used to clip gradients + + retain_graph (bool): if True, graph is not freed after backward call + + Returns: + grad_norms (float): average gradient norms from backpropagation + """ + + # backprop + optim.zero_grad() + loss.backward(retain_graph=retain_graph) + + # gradient clipping + if max_grad_norm is not None: + torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm) + + # compute grad norms + grad_norms = 0.0 + for p in net.parameters(): + # only clip gradients for parameters for which requires_grad is True + if p.grad is not None: + grad_norms += p.grad.data.norm(2).pow(2).item() + + # step + optim.step() + + return grad_norms + + +class dummy_context_mgr: + """ + A dummy context manager - useful for having conditional scopes (such + as @maybe_no_grad). Nothing happens in this scope. + """ + + def __enter__(self): + return None + + def __exit__(self, exc_type, exc_value, traceback): + return False + + +def maybe_no_grad(no_grad): + """ + Args: + no_grad (bool): if True, the returned context will be torch.no_grad(), otherwise + it will be a dummy context + """ + return torch.no_grad() if no_grad else dummy_context_mgr() + + +def rgetattr(obj, attr, *args): + "recursively get attributes" + + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def rsetattr(obj, attr, val): + "recursively set attributes" + pre, _, post = attr.rpartition(".") + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +class ProgressBar(pl.Callback): + def __init__( + self, global_progress: bool = True, leave_global_progress: bool = True + ): + super().__init__() + + self.global_progress = global_progress + self.global_desc = "Epoch: {epoch}/{max_epoch}" + self.leave_global_progress = leave_global_progress + self.global_pb = None + + def on_fit_start(self, trainer, pl_module): + desc = self.global_desc.format( + epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs + ) + + self.global_pb = tqdm( + desc=desc, + total=trainer.max_epochs, + initial=trainer.current_epoch, + leave=self.leave_global_progress, + disable=not self.global_progress, + ) + + def on_fit_end(self, trainer, pl_module): + self.global_pb.close() + self.global_pb = None + + def on_epoch_end(self, trainer, pl_module): + + # Set description + desc = self.global_desc.format( + epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs + ) + self.global_pb.set_description(desc) + + # Set logs and metrics + # logs = pl_module.logs + # for k, v in logs.items(): + # if isinstance(v, torch.Tensor): + # logs[k] = v.squeeze().item() + # self.global_pb.set_postfix(logs) + + # Update progress + self.global_pb.update(1) + +def tic(timer: bool = True) -> Union[None, float]: + """Use to compute time for time-consuming process, call it before .toc()""" + + start_time = None + if timer: + torch.cuda.synchronize() + start_time: float = time.time() + + return start_time + + +def toc( + start_time: float, name: str = "", timer: bool = True, log=None +) -> Optional[float]: + """Use to compute time for time-consuming process, call it after .tic()""" + + if timer: + torch.cuda.synchronize() + end_time: float = time.time() + elapsed_ms: float = (end_time - start_time) * 1000 + print_str: str = f"{name:30s} EP: {elapsed_ms:.2f} ms" + + if log is not None: + 
print_log(print_str, log=log, display=False) + else: + print(print_str) + + return elapsed_ms + + return None + + +def print_log(print_str, log, same_line=False, display=True): + ''' + print a string to a log file + + parameters: + print_str: a string to print + log: a opened file to save the log + same_line: True if we want to print the string without a new next line + display: False if we want to disable to print the string onto the terminal + ''' + if display: + if same_line: print('{}'.format(print_str), end='') + else: print('{}'.format(print_str)) + + if same_line: log.write('{}'.format(print_str)) + else: log.write('{}\n'.format(print_str)) + log.flush() + \ No newline at end of file diff --git a/diffstack/utils/tpp_utils.py b/diffstack/utils/tpp_utils.py new file mode 100644 index 0000000..4744c33 --- /dev/null +++ b/diffstack/utils/tpp_utils.py @@ -0,0 +1,994 @@ +from collections import defaultdict +import numpy as np +from scipy.interpolate import interp1d +import torch + +# from diffstack.models.cnn_roi_encoder import rasterized_ROI_align +TRAJ_INDEX = [0, 1, 4] +STATE_INDEX = [0, 1, 4, 2] +INPUT_INDEX = [3, 5] +from Pplan.Sampling.tree import Tree +from Pplan.Sampling.trajectory_tree import TrajTree + +import diffstack.utils.geometry_utils as GeoUtils +import diffstack.utils.tensor_utils as TensorUtils +from diffstack.utils.planning_utils import get_drivable_area_loss +from diffstack.modules.cost_functions.tpp_internal_costs import TPPInternalCost + + +class AgentTrajTree(Tree): + def __init__(self, traj, parent, depth, prob=None): + self.traj = traj + self.children = list() + self.parent = parent + if parent is not None: + parent.expand(self) + self.depth = depth + self.prob = prob + self.attribute = dict() + + +# The state in Pplan contains more higher order derivatives, TRAJ_INDEX selects x,y, and heading +# out of the longer state vector + + +def gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types +): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,N,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] or [B,N,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + B, N, T = ego_trajectories.shape[:3] + A = agent_trajectories.shape[-3] + + # veh_mask = (raw_types >= 3) & (raw_types <= 13) + # ped_mask = (raw_types == 14) | (raw_types == 15) + veh_mask = raw_types == 1 + ped_mask = raw_types == 2 + + edges = torch.zeros([B, N, A, T, 10]).to(ego_trajectories.device) + edges[..., :3] = ego_trajectories.unsqueeze(2).repeat(1, 1, A, 1, 1) + if agent_trajectories.ndim == 4: + edges[..., 3:6] = agent_trajectories.unsqueeze(1).repeat(1, N, 1, 1, 1) + else: + edges[..., 3:6] = agent_trajectories + edges[..., 6:8] = ego_extents.reshape(B, 1, 1, 1, 2).repeat(1, N, A, T, 1) + edges[..., 8:] = agent_extents.reshape(B, 1, A, 1, 2).repeat(1, N, 1, T, 1) + type_mask = {"VV": veh_mask, "VP": ped_mask} + return edges, type_mask + + +def get_collision_loss( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + prob=None, + col_funcs=None, +): + """Get veh-veh and veh-ped collision loss.""" + # with torch.no_grad(): + ego_edges, type_mask = gen_ego_edges( + ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types + ) + if col_funcs is None: + col_funcs = { + "VV": GeoUtils.VEH_VEH_collision, + 
"VP": GeoUtils.VEH_PED_collision, + } + B, N, T = ego_trajectories.shape[:3] + col_loss = torch.zeros([B, N]).to(ego_trajectories.device) + for et, func in col_funcs.items(): + dis = func( + ego_edges[..., 0:3], + ego_edges[..., 3:6], + ego_edges[..., 6:8], + ego_edges[..., 8:], + ) + if dis.nelement() > 0: + col_loss += ( + torch.sum( + torch.sigmoid(-dis * 4) * type_mask[et][:, None, :, None], + dim=[2, 3], + ) + / T + ) + + return col_loss + + +# def get_drivable_area_loss( +# ego_trajectories, raster_from_agent, dis_map, ego_extents +# ): +# """Cost for road departure.""" +# with torch.no_grad(): + +# lane_flags = rasterized_ROI_align( +# dis_map, +# ego_trajectories[..., :2], +# ego_trajectories[..., 2:], +# raster_from_agent, +# torch.ones(*ego_trajectories.shape[:3] +# ).to(ego_trajectories.device), +# ego_extents.unsqueeze(1).repeat(1, ego_trajectories.shape[1], 1), +# 1, +# ).squeeze(-1) +# return lane_flags.max(dim=-1)[0] + + +def get_lane_loss_simple(ego_trajectories, raster_from_agent, dis_map): + h, w = dis_map.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points( + ego_trajectories[..., :2], raster_from_agent + ) + raster_xy[..., 0] = raster_xy[..., 0].clip(0, w - 1e-5) + raster_xy[..., 1] = raster_xy[..., 1].clip(0, h - 1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = raster_xy[..., 1] * w + raster_xy[..., 0] + raster_xy_flat = raster_xy_flat.flatten() + lane_loss = (dis_map.flatten()[raster_xy_flat]).reshape(*raster_xy.shape[:2]) + return lane_loss.max(dim=-1)[0] + + +def get_lane_loss_vectorized(ego_trajectories, lane_info, ego_extents): + Ne, T = ego_trajectories.shape[:2] + cost = torch.zeros(Ne, device=ego_trajectories.device) + + if "leftbdry" in lane_info: + delta_x, delta_y, _ = GeoUtils.batch_proj( + ego_trajectories.reshape(-1, 3), + TensorUtils.to_torch(lane_info["leftbdry"], device=ego_extents.device)[ + None + ].repeat_interleave(Ne * T, 0), + ) + idx = delta_x.abs().argmin(1) + leftmargin = ( + -delta_y.gather(1, idx.reshape(-1, 1)).reshape(Ne, T) - ego_extents[1] / 2 + ) + cost += -(leftmargin.min(1)[0]).clamp(max=0) + if "rightbdry" in lane_info: + delta_x, delta_y, _ = GeoUtils.batch_proj( + ego_trajectories.reshape(-1, 3), + TensorUtils.to_torch(lane_info["rightbdry"], device=ego_extents.device)[ + None + ].repeat_interleave(Ne * T, 0), + ) + idx = delta_x.abs().argmin(1) + rightmargin = ( + delta_y.gather(1, idx.reshape(-1, 1)).reshape(Ne, T) - ego_extents[1] / 2 + ) + cost += -(rightmargin.min(1)[0]).clamp(max=0) + return cost + + +def get_terminal_likelihood_reward(ego_trajectories, raster_from_agent, log_likelihood): + """Cost for road departure.""" + + log_likelihood = (log_likelihood - log_likelihood.mean()) / log_likelihood.std() + h, w = log_likelihood.shape[-2:] + + raster_xy = GeoUtils.batch_nd_transform_points( + ego_trajectories[..., -1, :2], raster_from_agent + ) + raster_xy[..., 0] = raster_xy[..., 0].clip(0, w - 1e-5) + raster_xy[..., 1] = raster_xy[..., 1].clip(0, h - 1e-5) + raster_xy = raster_xy.long() + raster_xy_flat = raster_xy[..., 1] * w + raster_xy[..., 0] + + ll_reward = log_likelihood.flatten()[raster_xy_flat] + return ll_reward + + +def get_progress_reward(ego_trajectories, d_sat=10): + dis = torch.linalg.norm( + ego_trajectories[..., -1, :2] - ego_trajectories[..., 0, :2], dim=-1 + ) + return 2 / np.pi * torch.atan(dis / d_sat) + + +def get_total_distance(ego_trajectories): + """Reward that incentivizes progress.""" + # Assume format [..., T, 3] + assert ego_trajectories.shape[-1] == 3 + diff = 
ego_trajectories[..., 1:, :] - ego_trajectories[..., :-1, :] + dist = torch.norm(diff[..., :2], dim=-1) + total_dist = torch.sum(dist, dim=-1) + return total_dist + + +def ego_sample_planning( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + raster_from_agent, + dis_map, + weights, + log_likelihood=None, + col_funcs=None, +): + """A basic cost function for prediction-and-planning""" + col_loss = get_collision_loss( + ego_trajectories, + agent_trajectories, + ego_extents, + agent_extents, + raw_types, + col_funcs, + ) + # lane_loss = get_drivable_area_loss( + # ego_trajectories, raster_from_agent, dis_map, ego_extents + # ) + progress = get_total_distance(ego_trajectories) + + log_likelihood = 0 if log_likelihood is None else log_likelihood + if log_likelihood.ndim == 3: + log_likelihood = get_terminal_likelihood_reward( + ego_trajectories, raster_from_agent, log_likelihood + ) + + total_score = ( + +weights["likelihood_weight"] * log_likelihood + + weights["progress_weight"] * progress + - weights["collision_weight"] * col_loss + # - weights["lane_weight"] * lane_loss + ) + + return torch.argmax(total_score, dim=1) + + +class TreeMotionPolicy(object): + """A trajectory tree policy as the result of contingency planning""" + + def __init__( + self, + stage, + num_frames_per_stage, + ego_root, + scenario_root, + cost_to_go, + leaf_idx, + curr_node, + ): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_root = ego_root + self.scenario_root = scenario_root + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node = curr_node + + def identify_branch(self, ego_node, scene_traj): + assert scene_traj.shape[-2] < self.stage * self.num_frames_per_stage + assert ego_node.total_traj.shape[0] - 1 >= scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + ego_leaf_index = self.leaf_idx[ego_node] + while remain_traj.shape[1] > 0: + seg_length = min(remain_traj.shape[-2], self.num_frames_per_stage) + dis = [ + torch.linalg.norm( + child.traj[ego_leaf_index, :, :seg_length, :2] + - remain_traj[:, :seg_length, :2], + dim=-1, + ) + .sum() + .item() + for child in curr_scenario_node.children + ] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[..., seg_length:, :] + remain_num_frames = curr_scenario_node.traj.shape[-2] - seg_length + if curr_scenario_node.stage >= self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self, scene_traj, horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0] - 1 - T + assert remain_num_frames > -self.num_frames_per_stage + if remain_num_frames <= 0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node, scene_traj) + Q = [ + self.cost_to_go[(child, curr_scenario_node)] + for child in self.curr_node.children + ] + idx = torch.argmin(torch.tensor(Q)).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames += self.curr_node.traj.shape[0] + state = self.curr_node.traj[-remain_num_frames:, STATE_INDEX] + action = self.curr_node.traj[-remain_num_frames:, INPUT_INDEX] + if not self.curr_node.isleaf(): + state = torch.cat( + (state, self.curr_node.children[0].traj[:, STATE_INDEX]), -2 + ) + action = torch.cat( + (action, self.curr_node.children[0].traj[:, INPUT_INDEX]), 
-2 + ) + + if state.shape[0] >= horizon: + return state[:horizon], action[:horizon] + else: + state_patched = torch.cat( + (state, state[-1].tile(horizon - state.shape[0], 1)) + ) + action_patched = torch.cat( + ( + action, + torch.zeros_like(action[-1]).tile(horizon - action.shape[0], 1), + ) + ) + return state_patched, action_patched + + +class VectorizedTreeMotionPolicy(TreeMotionPolicy): + """A vectorized trajectory tree policy as the result of contingency planning""" + + def __init__( + self, + stage, + num_frames_per_stage, + ego_tree, + children_indices, + scenario_tree, + cost_to_go, + leaf_idx, + curr_node, + ): + self.stage = stage + self.num_frames_per_stage = num_frames_per_stage + self.ego_tree = ego_tree + self.ego_root = ego_tree[0][0] + self.children_indices = children_indices + self.scenario_tree = scenario_tree + self.scenario_root = scenario_tree[0][0] + self.cost_to_go = cost_to_go + self.leaf_idx = leaf_idx + self.curr_node = curr_node + + def identify_branch(self, ego_node, scene_traj): + assert scene_traj.shape[-2] < self.stage * self.num_frames_per_stage + assert ego_node.total_traj.shape[0] - 1 >= scene_traj.shape[-2] + + remain_traj = scene_traj + curr_scenario_node = self.scenario_root + + stage = ego_node.depth + ego_stage_index = self.ego_tree[stage].index(ego_node) + ego_leaf_index = self.leaf_idx[stage][ego_stage_index].item() + while remain_traj.shape[1] > 0: + seg_length = min(remain_traj.shape[-2], self.num_frames_per_stage) + dis = [ + torch.linalg.norm( + child.traj[ego_leaf_index, :, :seg_length, :2] + - remain_traj[:, :seg_length, :2], + dim=-1, + ) + .sum() + .item() + for child in curr_scenario_node.children + ] + idx = torch.argmin(torch.tensor(dis)).item() + + curr_scenario_node = curr_scenario_node.children[idx] + + remain_traj = remain_traj[..., seg_length:, :] + remain_num_frames = curr_scenario_node.traj.shape[-2] - seg_length + if curr_scenario_node.stage >= self.curr_node.stage: + break + return curr_scenario_node + + def get_plan(self, scene_traj, horizon): + if scene_traj is None: + T = 0 + remain_num_frames = self.num_frames_per_stage + else: + T = scene_traj.shape[-2] + remain_num_frames = self.curr_node.total_traj.shape[0] - 1 - T + assert remain_num_frames > -self.num_frames_per_stage + if remain_num_frames <= 0: + assert not self.curr_node.isleaf() + curr_scenario_node = self.identify_branch(self.curr_node, scene_traj) + assert curr_scenario_node.depth == self.curr_node.depth + stage = self.curr_node.depth + scene_node_idx = self.scenario_tree[stage].index(curr_scenario_node) + curr_node_idx = self.ego_tree[stage].index(self.curr_node) + Q = self.cost_to_go[stage][ + scene_node_idx, self.children_indices[stage][curr_node_idx] + ] + idx = torch.argmin(Q).item() + self.curr_node = self.curr_node.children[idx] + remain_num_frames += self.curr_node.traj.shape[0] + + state = self.curr_node.traj[-remain_num_frames:, STATE_INDEX] + action = self.curr_node.traj[-remain_num_frames:, INPUT_INDEX] + if not self.curr_node.isleaf(): + state = torch.cat( + (state, self.curr_node.children[0].traj[:, STATE_INDEX]), -2 + ) + action = torch.cat( + (action, self.curr_node.children[0].traj[:, INPUT_INDEX]), -2 + ) + if state.shape[0] >= horizon: + return state[:horizon], action[:horizon] + else: + state_patched = torch.cat( + (state, state[-1].tile(horizon - state.shape[0], 1)) + ) + action_patched = torch.cat( + ( + action, + torch.zeros_like(action[-1]).tile(horizon - action.shape[0], 1), + ) + ) + return state_patched, action_patched + + def 
get_traj_array(self): + xu_batch = list() + root_xu = torch.cat( + (self.ego_root.traj[:, STATE_INDEX], self.ego_root.traj[:, INPUT_INDEX]), -1 + ) + for branch in self.ego_tree[1]: + xu = [root_xu] + while True: + xu.append( + torch.cat( + (branch.traj[:, STATE_INDEX], branch.traj[:, INPUT_INDEX]), -1 + ) + ) + if branch.isleaf(): + break + else: + branch = branch.children[0] + xu = torch.cat(xu, 0) + xu_batch.append(xu) + return torch.stack(xu_batch, 0) + + +def tiled_to_tree(total_traj, prob, num_stage, num_frames_per_stage, M): + """Turning a trajectory tree in tiled form to a tree data structure + + Args: + total_traj (torch.tensor or np.ndarray): tiled trajectory tree + prob (torch.tensor or np.ndarray): probability of the modes + num_stage (int): number of layers of the tree + num_frames_per_stage (int): number of time frames per layer + M (int): branching factor + + Returns: + nodes (dict[int:List(AgentTrajTree)]): all branches of the trajectory tree nodes indexed by layer + """ + + # total_traj = TensorUtils.reshape_dimensions_single(total_traj,2,3,[M]*num_stage) + x0 = AgentTrajTree(None, None, 0) + nodes = defaultdict(lambda: list()) + nodes[0].append(x0) + for t in range(num_stage): + interval = M ** (num_stage - t - 1) + tiled_traj = total_traj[ + ..., + ::interval, + :, + t * num_frames_per_stage : (t + 1) * num_frames_per_stage, + :, + ] + for i in range(M ** (t + 1)): + parent_idx = int(i / M) + p = prob[:, i * interval : (i + 1) * interval].sum(-1) + node = AgentTrajTree(tiled_traj[:, i], nodes[t][parent_idx], t + 1, prob=p) + nodes[t + 1].append(node) + return nodes + + +def contingency_planning( + ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std=None, +): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (_type_): _description_ + ego_extents (_type_): _description_ + agent_traj (_type_): _description_ + mode_prob (_type_): _description_ + agent_extents (_type_): _description_ + agent_types (_type_): _description_ + raster_from_agent (_type_): _description_ + dis_map (_type_): _description_ + weights (_type_): _description_ + num_frames_per_stage (_type_): _description_ + M (_type_): _description_ + col_funcs (_type_, optional): _description_. Defaults to None. 
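# Shape sketch for tiled_to_tree above (illustrative only; importing tpp_utils
# assumes its dependencies such as Pplan are installed): total_traj is expected
# as [B, M**num_stage, A, num_stage * num_frames_per_stage, D] with mode
# probabilities prob of shape [B, M**num_stage]; layer t of the returned tree
# then holds M**t branch nodes.
import torch

from diffstack.utils.tpp_utils import tiled_to_tree

B, A, D = 1, 2, 3
num_stage, num_frames_per_stage, M = 2, 5, 2
total_traj = torch.zeros(B, M**num_stage, A, num_stage * num_frames_per_stage, D)
prob = torch.full((B, M**num_stage), 1.0 / M**num_stage)

nodes = tiled_to_tree(total_traj, prob, num_stage, num_frames_per_stage, M)
assert len(nodes[1]) == M and len(nodes[2]) == M**2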
+ + Returns: + _type_: _description_ + """ + + num_stage = len(ego_tree) - 1 + ego_root = ego_tree[0][0] + device = agent_traj.device + leaf_idx = defaultdict(lambda: list()) + for stage in range(num_stage, -1, -1): + for node in ego_tree[stage]: + if node.isleaf(): + leaf_idx[node] = [ego_tree[stage].index(node)] + else: + leaf_idx[node] = [] + for child in node.children: + leaf_idx[node] = leaf_idx[node] + leaf_idx[child] + + V = dict() + L = dict() + Q = dict() + scenario_tree = tiled_to_tree( + agent_traj, mode_prob, num_stage, num_frames_per_stage, M + ) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0, 2] + d_sat = v0.clip(min=2.0) * num_frames_per_stage * dt + for stage in range(num_stage, 0, -1): + if stage == 0: + total_loss = torch.zeros([1, 1], device=device) + else: + ego_nodes = ego_tree[stage] + indices = [leaf_idx[node][0] for node in ego_nodes] + ego_traj = [node.traj[:, TRAJ_INDEX] for node in ego_nodes] + ego_traj = torch.stack(ego_traj, 0) + agent_nodes = scenario_tree[stage] + agent_traj = [node.traj[indices] for node in agent_nodes] + agent_traj = torch.stack(agent_traj, 0) + ego_traj_tiled = ego_traj.unsqueeze(0).repeat(len(agent_nodes), 1, 1, 1) + col_loss = get_collision_loss( + ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes), 1), + agent_extents.tile(len(agent_nodes), 1, 1), + agent_types.tile(len(agent_nodes), 1), + col_funcs, + ) + + # lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + # lane_loss = get_lane_loss_simple(ego_traj,raster_from_agent,dis_map).unsqueeze(0) + + progress_reward = get_progress_reward(ego_traj, d_sat=d_sat) + + total_loss = weights["collision_weight"] * col_loss - weights[ + "progress_weight" + ] * progress_reward.unsqueeze(0) + if pert_std is not None: + total_loss += ( + torch.randn(total_loss.shape[1], device=device).unsqueeze(0) + * pert_std + ) + if log_likelihood is not None and stage == num_stage: + ll_reward = get_terminal_likelihood_reward( + ego_traj, raster_from_agent, log_likelihood + ) + total_loss = total_loss - weights["likelihood_weight"] * ll_reward + + for i in range(len(ego_nodes)): + for j in range(len(agent_nodes)): + L[(ego_nodes[i], agent_nodes[j])] = total_loss[j, i] + if stage == num_stage: + V[(ego_nodes[i], agent_nodes[j])] = float(total_loss[j, i]) + else: + children_cost_to_go = [ + Q[(child, agent_nodes[j])] for child in ego_nodes[i].children + ] + V[(ego_nodes[i], agent_nodes[j])] = float(total_loss[j, i]) + min( + children_cost_to_go + ) + + if stage > 0: + for agent_node in scenario_tree[stage - 1]: + cost_i = [] + prob_i = [] + for child in agent_node.children: + cost_i.append(V[ego_nodes[i], child]) + prob_i.append(child.prob[leaf_idx[ego_nodes[i]]].sum()) + cost_i = torch.tensor(cost_i, device=device) + prob_i = torch.stack(prob_i) + Q[(ego_nodes[i], agent_node)] = float( + (cost_i * prob_i).sum() / prob_i.sum() + ) + Q_root = [Q[(child, scenario_root)] for child in ego_root.children] + idx = torch.argmin(torch.tensor(Q_root)).item() + optimal_node = ego_root.children[idx] + motion_policy = TreeMotionPolicy( + num_stage, + num_frames_per_stage, + ego_root, + scenario_root, + Q, + leaf_idx, + optimal_node, + ) + motion_policy.get_plan(None, num_stage * num_frames_per_stage) + return motion_policy + + +def get_cost_for_trajs(xu_batch, agent_traj, cost_obj, goal, lanes): + bs = len(xu_batch) + numMode = agent_traj.shape[0] + if agent_traj.nelement() == 0: + dummy_shape = 
list(agent_traj.shape) + dummy_shape[2] = 1 + agent_traj = torch.ones(dummy_shape, device=agent_traj.device) * 1e3 + + # Tile batch for each prediction mode, and input multi-modal predictions as pred_singles + ego_xu_tiled = xu_batch.repeat_interleave(numMode, 0) + pred_singles = TensorUtils.join_dimensions(agent_traj[..., :2], 0, 2) + pred_mus = torch.zeros( + (bs * numMode, 0, pred_singles.shape[-2], 1, 2), device=xu_batch.device + ) # b, N, T, K, 2 + pred_probs = torch.zeros( + [bs * numMode, 1, pred_singles.shape[-2]], device=xu_batch.device + ) + goal = ( + goal.unsqueeze(0).repeat_interleave(numMode * bs, 0) + if goal is not None + else None + ) + lanes = ( + lanes.unsqueeze(0).repeat_interleave(numMode * bs, 0) + if lanes is not None + else None + ) + + cost_inputs = (pred_mus, pred_probs, pred_singles, goal, lanes) + cost_inputs = TensorUtils.to_device(cost_inputs, ego_xu_tiled.device) + traj_cost = cost_obj(ego_xu_tiled, cost_inputs) # b, T + + # sum over time + traj_cost = traj_cost.sum(1) + # recover batch and prediction modes + traj_cost = traj_cost.reshape(bs, numMode) + + return traj_cost + + +def contingency_planning_parallel( + ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + lane_info, + weights, + num_frames_per_stage, + M, + dt, + cost_obj=None, + lanes=None, + goal=None, + col_funcs=None, + log_likelihood=None, + lane_type="rasterized", + pert_std=None, +): + """A sampling-based contingency planning algorithm + + Args: + ego_tree (_type_): _description_ + ego_extents (_type_): _description_ + agent_traj (_type_): _description_ + mode_prob (_type_): _description_ + agent_extents (_type_): _description_ + agent_types (_type_): _description_ + raster_from_agent (_type_): _description_ + dis_map (_type_): _description_ + weights (_type_): _description_ + num_frames_per_stage (_type_): _description_ + M (_type_): _description_ + col_funcs (_type_, optional): _description_. Defaults to None. 
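# Sketch of the planner cost-weight dictionary (illustrative values, not tuned
# defaults): these four keys are the ones read by ego_sample_planning,
# contingency_planning and contingency_planning_parallel above.
weights = {
    "collision_weight": 10.0,   # scales the soft (sigmoid) collision penalty
    "progress_weight": 1.0,     # scales the saturated progress reward
    "lane_weight": 5.0,         # scales the lane / drivable-area penalty
    "likelihood_weight": 0.1,   # scales the terminal log-likelihood reward
}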
+ + Returns: + _type_: _description_ + """ + device = agent_traj.device + children_indices = TrajTree.get_children_index_torch(ego_tree) + num_stage = len(ego_tree) - 1 + ego_root = ego_tree[0][0] + + leaf_idx = {num_stage: torch.arange(len(ego_tree[num_stage]), device=device)} + stage_prob = {num_stage: mode_prob.T} + for stage in range(num_stage - 1, -1, -1): + leaf_idx[stage] = leaf_idx[stage + 1][children_indices[stage][:, 0]] + prob_next = stage_prob[stage + 1] + stage_prob[stage] = prob_next.reshape(-1, M, prob_next.shape[-1])[ + :, :, children_indices[stage][:, 0] + ].sum(1) + + V = dict() + L = dict() + Q = dict() + + scenario_tree = tiled_to_tree( + agent_traj, mode_prob, num_stage, num_frames_per_stage, M + ) + scenario_root = scenario_tree[0][0] + v0 = ego_root.traj[0, 2] + d_sat = v0.clip(min=2.0) * num_frames_per_stage * dt + + for stage in range(num_stage, -1, -1): + if stage == 0: + total_loss = torch.zeros([1, 1], device=device) + else: + # calculate stage cost + ego_nodes = ego_tree[stage] + ego_traj = [node.traj[:, TRAJ_INDEX] for node in ego_nodes] + + ego_traj = torch.stack(ego_traj, 0) + agent_nodes = scenario_tree[stage] + + agent_traj = [node.traj[leaf_idx[stage]] for node in agent_nodes] + + agent_traj = torch.stack(agent_traj, 0) + if cost_obj is None or isinstance(cost_obj, TPPInternalCost): + # use internal TPP cost + if agent_traj.nelement() == 0: + col_loss = torch.zeros( + [*agent_traj.shape[:2]], device=ego_traj.device + ) + else: + ego_traj_tiled = ego_traj.unsqueeze(0).repeat( + len(agent_nodes), 1, 1, 1 + ) + col_loss = get_collision_loss( + ego_traj_tiled, + agent_traj, + ego_extents.tile(len(agent_nodes), 1), + agent_extents.tile(len(agent_nodes), 1, 1), + agent_types.tile(len(agent_nodes), 1), + col_funcs, + ) + + # lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)) + if lane_type == "rasterized": + lane_loss = get_lane_loss_simple( + ego_traj, raster_from_agent, lane_info + ).unsqueeze(0) + elif lane_type == "vectorized": + lane_loss = get_lane_loss_vectorized( + ego_traj, lane_info, ego_extents + ).unsqueeze(0) + progress_reward = get_progress_reward(ego_traj, d_sat=d_sat) + + total_loss = ( + weights["collision_weight"] * col_loss + - weights["progress_weight"] * progress_reward.unsqueeze(0) + + weights["lane_weight"] * lane_loss + ) + if pert_std is not None: + total_loss += ( + torch.randn(total_loss.shape[1], device=device).unsqueeze(0) + * pert_std + ) + if log_likelihood is not None and stage == num_stage: + ll_reward = get_terminal_likelihood_reward( + ego_traj, raster_from_agent, log_likelihood + ) + total_loss = total_loss - weights["likelihood_weight"] * ll_reward + else: + # use diffstack cost + ego_x = torch.stack( + [node.traj[:, STATE_INDEX] for node in ego_nodes], 0 + ) + ego_u = torch.stack( + [node.traj[:, INPUT_INDEX] for node in ego_nodes], 0 + ) + ego_x_pre = torch.stack( + [node.parent.traj[-1:, STATE_INDEX] for node in ego_nodes], 0 + ) + ego_u_pre = torch.stack( + [node.parent.traj[-1:, INPUT_INDEX] for node in ego_nodes], 0 + ) + ego_xu = torch.cat((ego_x, ego_u), -1) + ego_xu_pre = torch.cat((ego_x_pre, ego_u_pre), -1) + ego_xu = torch.cat((ego_xu_pre, ego_xu), 1) + lane_seg = lanes[ + (stage - 1) * num_frames_per_stage : stage * num_frames_per_stage + + 1 + ] + + total_loss = get_cost_for_trajs( + ego_xu, agent_traj, cost_obj, goal, lane_seg + ).transpose(0, 1) + + L[stage] = total_loss + if stage == num_stage: + V[stage] = total_loss + 
else: + children_idx = children_indices[stage] + # add the last Q value as inf since empty children index are padded with -1 + Q_prime = torch.cat( + (Q[stage], torch.full([Q[stage].shape[0], 1], np.inf, device=device)), 1 + ) + Q_by_node = Q_prime[:, children_idx] + V[stage] = total_loss + Q_by_node.min(dim=-1)[0] + + if stage > 0: + children_V = V[stage] + children_V = children_V.reshape(-1, M, children_V.shape[-1]) + prob = stage_prob[stage] + prob = prob.reshape(-1, M, prob.shape[-1]) + prob_normalized = prob / prob.sum(dim=1, keepdim=True) + Q[stage - 1] = (children_V * prob_normalized).sum(dim=1) + + idx = Q[0].argmin().item() + + motion_policy = VectorizedTreeMotionPolicy( + num_stage, + num_frames_per_stage, + ego_tree, + children_indices, + scenario_tree, + Q, + leaf_idx, + ego_root.children[idx], + ) + return motion_policy + + +def one_shot_planning( + ego_tree, + ego_extents, + agent_traj, + mode_prob, + agent_extents, + agent_types, + raster_from_agent, + dis_map, + weights, + num_frames_per_stage, + M, + dt, + col_funcs=None, + log_likelihood=None, + pert_std=None, + strategy="all", +): + """Alternative of contingency planning, try to avoid all predicted trajectories + + Args: + ego_tree (_type_): _description_ + ego_extents (_type_): _description_ + agent_traj (_type_): _description_ + mode_prob (_type_): _description_ + agent_extents (_type_): _description_ + agent_types (_type_): _description_ + raster_from_agent (_type_): _description_ + dis_map (_type_): _description_ + weights (_type_): _description_ + num_frames_per_stage (_type_): _description_ + M (_type_): _description_ + col_funcs (_type_, optional): _description_. Defaults to None. + + Returns: + _type_: _description_ + """ + assert strategy == "all" or strategy == "maximum" + num_stage = len(ego_tree) - 1 + ego_root = ego_tree[0][0] + + ego_traj = [node.total_traj[1:, TRAJ_INDEX] for node in ego_tree[num_stage]] + ego_traj = torch.stack(ego_traj, 0) + ego_traj_tiled = ego_traj.unsqueeze(1).repeat_interleave(agent_traj.shape[1], 1) + Ne = ego_traj.shape[0] + if strategy == "maximum": + idx = mode_prob.argmax(dim=1) + idx = idx.reshape(Ne, *[1] * (agent_traj.ndim - 1)) + agent_traj = agent_traj.take_along_dim(idx, 1) + col_loss = get_collision_loss( + ego_traj_tiled, + agent_traj, + ego_extents.tile(Ne, 1), + agent_extents.tile(Ne, 1, 1), + agent_types.tile(Ne, 1), + col_funcs, + ) + col_loss = col_loss.max(dim=1)[0] + # lane_loss = get_drivable_area_loss(ego_traj.unsqueeze(0), raster_from_agent.unsqueeze(0), dis_map.unsqueeze(0), ego_extents.unsqueeze(0)).squeeze(0) + v0 = ego_root.traj[0, 2] + d_sat = v0.clip(min=2.0) * num_frames_per_stage * dt + progress_reward = get_progress_reward(ego_traj, d_sat=d_sat) + total_loss = ( + weights["collision_weight"] * col_loss + - weights["progress_weight"] * progress_reward + ) + if pert_std is not None: + total_loss += ( + torch.randn(total_loss.shape[0], device=total_loss.device) * pert_std + ) + if log_likelihood is not None: + ll_reward = get_terminal_likelihood_reward( + ego_traj, raster_from_agent, log_likelihood + ) + total_loss = total_loss - weights["likelihood_weight"] * ll_reward + + idx = total_loss.argmin() + return ego_traj[idx] + + +def obtain_ref(line, x, v, N, dt): + """obtain desired trajectory for the MPC controller + + Args: + line (np.ndarray): centerline of the lane [n, 3] + x (np.ndarray): position of the vehicle + v (np.ndarray): desired velocity + N (int): number of time steps + dt (float): time step + + Returns: + refx (np.ndarray): desired 
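# Toy illustration, not part of the patch, of the backward recursion the
# contingency planners above perform: at the terminal stage V equals the stage
# cost L, and one stage earlier each ego branch's cost-to-go Q is the
# probability-weighted average of V over the scenario modes, normalized by the
# total mode probability; the policy then commits to the branch with minimal Q.
import torch

L_leaf = torch.tensor([[1.0, 3.0],   # cost of ego branch 0 under scenario modes 0 / 1
                       [2.0, 0.5]])  # cost of ego branch 1 under scenario modes 0 / 1
mode_prob = torch.tensor([0.7, 0.3])

V_leaf = L_leaf                                             # terminal stage: V = L
Q_root = (V_leaf * mode_prob).sum(dim=1) / mode_prob.sum()  # expected cost-to-go per branch
best_branch = int(torch.argmin(Q_root))                     # branch committed to first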
trajectory [N,3] + """ + line_length = line.shape[0] + delta_x = line[..., 0:2] - np.repeat(x[..., np.newaxis, 0:2], line_length, axis=-2) + dis = np.linalg.norm(delta_x, axis=-1) + idx = np.argmin(dis, axis=-1) + line_min = line[idx] + dx = x[0] - line_min[0] + dy = x[1] - line_min[1] + delta_y = -dx * np.sin(line_min[2]) + dy * np.cos(line_min[2]) + delta_x = dx * np.cos(line_min[2]) + dy * np.sin(line_min[2]) + refx0 = np.array( + [ + line_min[0] + delta_x * np.cos(line_min[2]), + line_min[1] + delta_x * np.sin(line_min[2]), + line_min[2], + ] + ) + s = [np.linalg.norm(line[idx + 1, 0:2] - refx0[0:2])] + for i in range(idx + 2, line_length): + s.append(s[-1] + np.linalg.norm(line[i, 0:2] - line[i - 1, 0:2])) + f = interp1d( + np.array(s), + line[idx + 1 :], + kind="linear", + axis=0, + copy=True, + bounds_error=False, + fill_value="extrapolate", + assume_sorted=True, + ) + s1 = v * np.arange(1, N + 1) * dt + refx = f(s1) + + return refx diff --git a/diffstack/utils/train_utils.py b/diffstack/utils/train_utils.py new file mode 100644 index 0000000..9349d78 --- /dev/null +++ b/diffstack/utils/train_utils.py @@ -0,0 +1,151 @@ +""" +This file contains several utility functions used to define the main training loop. It +mainly consists of functions to assist with logging, rollouts, and the @run_epoch function, +which is the core training logic for models in this repository. +""" +import os +import socket +import shutil +import pytorch_lightning as pl +from pytorch_lightning.loops.utilities import _reset_progress + +def infinite_iter(data_loader): + """ + Get an infinite generator + Args: + data_loader (DataLoader): data loader to iterate through + + """ + c_iter = iter(data_loader) + while True: + try: + data = next(c_iter) + except StopIteration: + c_iter = iter(data_loader) + data = next(c_iter) + yield data + + +def get_exp_dir(exp_name, output_dir, save_checkpoints=True, auto_remove_exp_dir=False): + """ + Create experiment directory from config. If an identical experiment directory + exists and @auto_remove_exp_dir is False (default), the function will prompt + the user on whether to remove and replace it, or keep the existing one and + add a new subdirectory with the new timestamp for the current run. + + Args: + exp_name (str): name of the experiment + output_dir (str): output directory of the experiment + save_checkpoints (bool): if save checkpoints + auto_remove_exp_dir (bool): if True, automatically remove the existing experiment + folder if it exists at the same path. + + Returns: + log_dir (str): path to created log directory (sub-folder in experiment directory) + output_dir (str): path to created models directory (sub-folder in experiment directory) + to store model checkpoints + video_dir (str): path to video directory (sub-folder in experiment directory) + to store rollout videos + """ + + # create directory for where to dump model parameters, tensorboard logs, and videos + base_output_dir = output_dir + if not os.path.isabs(base_output_dir): + base_output_dir = os.path.abspath(base_output_dir) + base_output_dir = os.path.join(base_output_dir, exp_name) + if os.path.exists(base_output_dir): + # if not auto_remove_exp_dir: + # ans = input( + # "WARNING: model directory ({}) already exists! \noverwrite? 
(y/n)\n".format( + # base_output_dir + # ) + # ) + # else: + # ans = "y" + # if ans == "y": + # print("REMOVING") + # shutil.rmtree(base_output_dir) + if auto_remove_exp_dir: + print(f"REMOVING {base_output_dir}") + shutil.rmtree(base_output_dir) + os.makedirs(base_output_dir, exist_ok=True) + + # version the run + existing_runs = [ + a + for a in os.listdir(base_output_dir) + if os.path.isdir(os.path.join(base_output_dir, a)) + ] + run_counts = [-1] + for ep in existing_runs: + m = ep.split("run") + if len(m) == 2 and m[0] == "": + run_counts.append(int(m[1])) + version_str = "run{}".format(max(run_counts) + 1) + + # only make model directory if model saving is enabled + ckpt_dir = None + if save_checkpoints: + ckpt_dir = os.path.join(base_output_dir, version_str, "checkpoints") + os.makedirs(ckpt_dir) + + # tensorboard directory + log_dir = os.path.join(base_output_dir, version_str, "logs") + os.makedirs(log_dir) + + # video directory + video_dir = os.path.join(base_output_dir, version_str, "videos") + os.makedirs(video_dir) + return base_output_dir, log_dir, ckpt_dir, video_dir, version_str + + +def trajdata_auto_set_batch_size(trainer: "pl.Trainer", model: "pl.LightningModule",datamodule:"pl.LightningDataModule",bs_min = 2,bs_max = None,conservative_reduction=True) -> int: + if bs_max == None: + # power search to find the bs_max + bs_trial = bs_min + + while True: + + datamodule.train_batch_size = bs_trial + datamodule.val_batch_size = bs_trial + try: + trainer.fit(model=model,datamodule=datamodule) + _reset_progress(trainer.fit_loop) + print(f"batch size {bs_trial} succeeded, trying {bs_trial*2}") + bs_min = bs_trial + bs_trial *= 2 + + + except: + print(f"batch size {bs_trial} failed, setting max batch size to {bs_trial}") + bs_max = bs_trial + break + + # the maximum batch size is dataset size divided by validation interval (there needs to be at least 1 validation per epoch) + if bs_trial >=len(datamodule.train_dataset)/getattr(trainer,"val_check_interval",100): + bs_max = int(len(datamodule.train_dataset)/getattr(trainer,"val_check_interval",100))-1 + break + else: + bs_max = min(bs_max,int(len(datamodule.train_dataset)/getattr(trainer,"val_check_interval",100))-1) + # binary search to find the optimal batch size + print(f" starting binary search with minimum batch size {bs_min}, maximum batch size {bs_max}") + while bs_max - bs_min > 1: + bs_trial = (bs_min + bs_max) // 2 + print(f"trying batch size {bs_trial}") + datamodule.train_batch_size = bs_trial + datamodule.val_batch_size = bs_trial + try: + trainer.fit(model=model,datamodule=datamodule) + _reset_progress(trainer.fit_loop) + print(f"batch size {bs_trial} succeeded") + bs_min = bs_trial + except: + bs_max = bs_trial + print(f"batch size {bs_trial} failed") + if bs_max-bs_min torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + + b, a, t, _ = agent_hist_pos.shape + _, _, _, h, w = maps.shape + maps = maps.clone() + agent_hist_pos = TensorUtils.unsqueeze_expand_at(agent_hist_pos,a,1) + agent_mask_tiled = TensorUtils.unsqueeze_expand_at(agent_mask,a,1)*TensorUtils.unsqueeze_expand_at(agent_mask,a,2) + raster_hist_pos = transform_points_tensor(agent_hist_pos.reshape(b*a,-1,2), raster_from_agent.reshape(b*a,3,3)).reshape(b,a,a,t,2) + raster_hist_pos = raster_hist_pos * agent_mask_tiled.unsqueeze(-1) # Set invalid positions to 0.0 Will correct below + + raster_hist_pos[..., 0].clip_(0, (w - 1)) + raster_hist_pos[..., 1].clip_(0, (h - 1)) + raster_hist_pos = torch.round(raster_hist_pos).long() # round 
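# Illustrative call, not part of the patch: get_exp_dir above versions each run
# as "runN" under <output_dir>/<exp_name> and returns the directories it
# creates. The experiment name and output root here are hypothetical.
from diffstack.utils.train_utils import get_exp_dir

base_dir, log_dir, ckpt_dir, video_dir, version = get_exp_dir(
    exp_name="ctt_debug",
    output_dir="experiments",
    save_checkpoints=True,
    auto_remove_exp_dir=False,
)
# version is "run0" for a fresh experiment directory, "run1" on the next launch, ...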
pixels [B, A, A, T, 2] + raster_hist_pos = raster_hist_pos.transpose(2,3) + raster_hist_pos_flat = raster_hist_pos[..., 1] * w + raster_hist_pos[..., 0] # [B, A, T, A] + hist_image = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + + ego_mask = torch.zeros_like(raster_hist_pos_flat,dtype=torch.bool) + ego_mask[:,range(a),:,range(a)]=1 + agent_mask = torch.logical_not(ego_mask) + + + hist_image.scatter_(dim=3, index=raster_hist_pos_flat*agent_mask, src=torch.ones_like(hist_image) * -1) # mark other agents with -1 + hist_image.scatter_(dim=3, index=raster_hist_pos_flat*ego_mask, src=torch.ones_like(hist_image)) # mark ego with 1. + hist_image[..., 0] = 0 # correct the 0th index from invalid positions + hist_image[..., -1] = 0 # correct the maximum index caused by out of bound locations + + hist_image = hist_image.reshape(b, a, t, h, w) + + maps = torch.cat((hist_image, maps), dim=2) # treat time as extra channels + return maps + + +def rasterize_agents_sc( + maps: torch.Tensor, + agent_pos: torch.Tensor, + agent_yaw: torch.Tensor, + agent_speed: torch.Tensor, + agent_mask: torch.Tensor, + raster_from_agent: torch.Tensor, +) -> torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + + b, a, t, _ = agent_pos.shape + _, _, _, h, w = maps.shape + + # take the first agent as the center agent + raster_pos = transform_points_tensor(agent_pos.reshape(b,-1,2), raster_from_agent[:,0]).reshape(b,a,t,2) + raster_pos = raster_pos * agent_mask.unsqueeze(-1) # Set invalid positions to 0.0 Will correct below + + raster_pos[..., 0].clip_(0, (w - 1)) + raster_pos[..., 1].clip_(0, (h - 1)) + raster_pos_round = torch.round(raster_pos).long() # round pixels [B, A, T, 2] + raster_dxy = raster_pos - raster_pos_round.float() + + raster_pos_flat = raster_pos_round[..., 1] * w + raster_pos_round[..., 0] # [B, A, T] + prob = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # dx = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # dy = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # heading = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + # vel = torch.zeros(b, a, t, h * w, dtype=maps.dtype, device=maps.device) # [B, A, T, H * W] + + raster_pos_flat = raster_pos_flat.unsqueeze(-1) + prob.scatter_(dim=3, index=raster_pos_flat, src=torch.ones_like(prob)) + # dx.scatter_(dim=3, index=raster_pos_flat, src=raster_dxy[...,0:1].clone().repeat_interleave(h*w,-1)) + # dy.scatter_(dim=3, index=raster_pos_flat, src=raster_dxy[...,1:2].clone().repeat_interleave(h*w,-1)) + # heading.scatter_(dim=3, index=raster_pos_flat, src=agent_yaw.repeat_interleave(h*w,-1)) + # vel.scatter_(dim=3, index=raster_pos_flat, src=agent_speed.unsqueeze(-1).repeat_interleave(h*w,-1)) + + # feature = torch.stack((prob, dx, dy, heading), dim=3) # [B, A, T, 5, H * W] + feature = prob + feature[..., 0] = 0 # correct the 0th index from invalid positions + feature[..., -1] = 0 # correct the maximum index caused by out of bound locations + + feature = feature.reshape(b, a, t, -1, h, w) + + return feature + + + +def rasterize_agents( + maps: torch.Tensor, + agent_hist_pos: torch.Tensor, + agent_hist_yaw: torch.Tensor, + agent_extent: torch.Tensor, + agent_mask: torch.Tensor, + raster_from_agent: torch.Tensor, + map_res: torch.Tensor, + cat=True, + filter=None, +) -> torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + b, 
a, t, _ = agent_hist_pos.shape + _, _, h, w = maps.shape + + + agent_hist_pos = agent_hist_pos.reshape(b, a * t, 2) + raster_hist_pos = transform_points_tensor(agent_hist_pos, raster_from_agent) + raster_hist_pos[~agent_mask.reshape(b, a * t)] = 0.0 # Set invalid positions to 0.0 Will correct below + raster_hist_pos = raster_hist_pos.reshape(b, a, t, 2).permute(0, 2, 1, 3) # [B, T, A, 2] + raster_hist_pos[..., 0].clip_(0, (w - 1)) + raster_hist_pos[..., 1].clip_(0, (h - 1)) + raster_hist_pos = torch.round(raster_hist_pos).long() # round pixels + + raster_hist_pos_flat = raster_hist_pos[..., 1] * w + raster_hist_pos[..., 0] # [B, T, A] + + hist_image = torch.zeros(b, t, h * w, dtype=maps.dtype, device=maps.device) # [B, T, H * W] + + hist_image.scatter_(dim=2, index=raster_hist_pos_flat[:, :, 1:], src=torch.ones_like(hist_image) * -1) # mark other agents with -1 + hist_image.scatter_(dim=2, index=raster_hist_pos_flat[:, :, [0]], src=torch.ones_like(hist_image)) # mark ego with 1. + hist_image[:, :, 0] = 0 # correct the 0th index from invalid positions + hist_image[:, :, -1] = 0 # correct the maximum index caused by out of bound locations + + hist_image = hist_image.reshape(b, t, h, w) + if filter=="0.5-1-0.5": + kernel = torch.tensor([[0.5, 0.5, 0.5], + [0.5, 1., 0.5], + [0.5, 0.5, 0.5]]).to(hist_image.device) + + kernel = kernel.view(1, 1, 3, 3).repeat(t, t, 1, 1) + hist_image = F.conv2d(hist_image, kernel,padding=1) + if cat: + maps = maps.clone() + maps = torch.cat((hist_image, maps), dim=1) # treat time as extra channels + return maps + else: + return hist_image + +def rasterize_agents_rec( + maps: torch.Tensor, + agent_hist_pos: torch.Tensor, + agent_hist_yaw: torch.Tensor, + agent_extent: torch.Tensor, + agent_mask: torch.Tensor, + raster_from_agent: torch.Tensor, + map_res: torch.Tensor, + cat=True, + ego_neg = False, + parallel_raster=True, +) -> torch.Tensor: + """Paint agent histories onto an agent-centric map image""" + with torch.no_grad(): + b, a, t, _ = agent_hist_pos.shape + _, _, h, w = maps.shape + + coord_tensor = torch.cat((torch.arange(w).view(w,1,1).repeat_interleave(h,1), + torch.arange(h).view(1,h,1).repeat_interleave(w,0),),-1).to(maps.device) + + agent_hist_pos = agent_hist_pos.reshape(b, a * t, 2) + raster_hist_pos = transform_points_tensor(agent_hist_pos, raster_from_agent) + + + raster_hist_pos[~agent_mask.reshape(b, a * t)] = 0.0 # Set invalid positions to 0.0 Will correct below + + raster_hist_pos = raster_hist_pos.reshape(b, a, t, 2).permute(0, 2, 1, 3) # [B, T, A, 2] + + raster_hist_pos_yx = torch.cat((raster_hist_pos[...,1:],raster_hist_pos[...,0:1]),-1) + + if parallel_raster: + # vectorized version, uses much more memory + coord_tensor_tiled = coord_tensor.view(1,1,1,h,w,-1).repeat(b,t,a,1,1,1) + dyx = raster_hist_pos_yx[...,None,None,:]-coord_tensor_tiled + cos_yaw = torch.cos(-agent_hist_yaw) + sin_yaw = torch.sin(-agent_hist_yaw) + + rotM = torch.stack( + [ + torch.stack([cos_yaw, sin_yaw], dim=-1), + torch.stack([-sin_yaw, cos_yaw], dim=-1), + ],dim=-2, + ) + rotM = rotM.transpose(1,2) + rel_yx = torch.matmul(rotM.unsqueeze(-3).repeat(1,1,1,h,w,1,1),dyx.unsqueeze(-1)).squeeze(-1) + agent_extent_yx = torch.cat((agent_extent[...,1:2],agent_extent[...,0:1]),-1) + extent_tiled = agent_extent_yx[:,None,:,None,None] + + flag = (torch.abs(rel_yx)1: + # aggregate along the agent dimension + hist_img = hist_img[:,:,0] + hist_img[:,:,1:].max(2)[0]*(hist_img[:,:,0]==0) + else: + hist_img = hist_img.squeeze(2) + else: + + # loop through all agents, slow but memory 
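# Minimal sketch of the scatter-based rasterization trick used by the
# rasterize_agents* helpers above: pixel coordinates are flattened to
# row * W + col, marked in an (H * W) buffer with scatter_, and cells 0 and
# H*W - 1 are cleared afterwards because invalid / out-of-bounds points were
# clipped onto exactly those two cells.
import torch

H, W = 4, 6
pix = torch.tensor([[1, 2], [5, 3], [0, 0]])  # (col, row) pairs; the last one stands in for an invalid point
flat = pix[:, 1] * W + pix[:, 0]              # row * W + col
img = torch.zeros(H * W)
img.scatter_(0, flat, torch.ones_like(flat, dtype=img.dtype))
img[0] = 0     # clear the cell that invalid points were mapped to
img[-1] = 0    # clear the cell that out-of-bounds points were clipped to
img = img.reshape(H, W)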
efficient + coord_tensor_tiled = coord_tensor.view(1,1,h,w,-1).repeat(b,t,1,1,1) + agent_extent_yx = torch.cat((agent_extent[...,1:2],agent_extent[...,0:1]),-1) + hist_img_ego = torch.zeros([b,t,h,w],device=maps.device) + hist_img_nb = torch.zeros([b,t,h,w],device=maps.device) + for i in range(raster_hist_pos_yx.shape[-2]): + dyx = raster_hist_pos_yx[...,i,None,None,:]-coord_tensor_tiled + yaw_i = agent_hist_yaw[:,i] + cos_yaw = torch.cos(-yaw_i) + sin_yaw = torch.sin(-yaw_i) + + rotM = torch.stack( + [ + torch.stack([cos_yaw, sin_yaw], dim=-1), + torch.stack([-sin_yaw, cos_yaw], dim=-1), + ],dim=-2, + ) + + rel_yx = torch.matmul(rotM.unsqueeze(-3).repeat(1,1,h,w,1,1),dyx.unsqueeze(-1)).squeeze(-1) + extent_tiled = agent_extent_yx[:,None,i,None,None] + + flag = (torch.abs(rel_yx)1: + hist_img = hist_img_ego + hist_img_nb*(hist_img_ego==0) + else: + hist_img = hist_img_ego + + if cat: + maps = maps.clone() + maps = torch.cat((hist_img, maps), dim=1) # treat time as extra channels + return maps + else: + return hist_img + + + +def get_drivable_region_map(maps): + if isinstance(maps, torch.Tensor): + if maps.shape[-3]>=7: + drivable = torch.amax(maps[..., -7:-4, :, :], dim=-3).bool() + else: + drivable = torch.amax(maps[..., -3:, :, :], dim=-3).bool() + else: + if maps.shape[-3]>=7: + drivable = np.amax(maps[..., -7:-4, :, :], axis=-3).astype(bool) + else: + drivable = np.amax(maps[..., -3:, :, :], dim=-3).astype(bool) + return drivable + + +def maybe_pad_neighbor(batch): + """Pad neighboring agent's history to the same length as that of the ego using NaNs""" + hist_len = batch["agent_hist"].shape[1] + fut_len = batch["agent_fut"].shape[1] + b, a, neigh_len, _ = batch["neigh_hist"].shape + empty_neighbor = a == 0 + if empty_neighbor: + batch["neigh_hist"] = torch.ones(b, 1, hist_len, batch["neigh_hist"].shape[-1],device=batch["agent_hist"].device) * torch.nan + batch["neigh_fut"] = torch.ones(b, 1, fut_len, batch["neigh_fut"].shape[-1],device=batch["agent_hist"].device) * torch.nan + batch["neigh_types"] = torch.zeros(b, 1,device=batch["agent_hist"].device) + batch["neigh_hist_extents"] = torch.zeros(b, 1, hist_len, batch["neigh_hist_extents"].shape[-1],device=batch["agent_hist"].device) + batch["neigh_fut_extents"] = torch.zeros(b, 1, fut_len, batch["neigh_hist_extents"].shape[-1],device=batch["agent_hist"].device) + elif neigh_len < hist_len: + hist_pad = torch.ones(b, a, hist_len - neigh_len, batch["neigh_hist"].shape[-1],device=batch["agent_hist"].device) * torch.nan + batch["neigh_hist"] = torch.cat((hist_pad, batch["neigh_hist"]), dim=2) + hist_pad = torch.zeros(b, a, hist_len - neigh_len, batch["neigh_hist_extents"].shape[-1],device=batch["agent_hist"].device) + batch["neigh_hist_extents"] = torch.cat((hist_pad, batch["neigh_hist_extents"]), dim=2) + +def parse_scene_centric(batch: dict, rasterize_mode:str): + fut_pos, fut_yaw, fut_speed, fut_mask = trajdata2posyawspeed(batch["agent_fut"]) + hist_pos, hist_yaw, hist_speed, hist_mask = trajdata2posyawspeed(batch["agent_hist"]) + + curr_pos = hist_pos[:,:,-1] + curr_yaw = hist_yaw[:,:,-1] + if batch["centered_agent_state"].shape[-1]==7: + world_yaw = batch["centered_agent_state"][...,6] + else: + assert batch["centered_agent_state"].shape[-1]==8 + world_yaw = torch.atan2(batch["centered_agent_state"][...,6],batch["centered_agent_state"][...,7]) + curr_speed = hist_speed[..., -1] + centered_state = batch["centered_agent_state"] + centered_yaw = centered_state[:, -1] + centered_pos = centered_state[:, :2] + old_type = batch["agent_type"] + 
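# Small sketch of the heading convention handled just above: an 8-column
# centered_agent_state is assumed to carry (sin(yaw), cos(yaw)) in its last two
# columns, so yaw is recovered with atan2; a 7-column state stores yaw directly.
import torch

yaw = torch.tensor(2.5)
sin_cos = torch.stack([torch.sin(yaw), torch.cos(yaw)])  # stand-ins for columns 6 and 7
recovered = torch.atan2(sin_cos[0], sin_cos[1])
assert torch.isclose(recovered, yaw)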
agent_type = torch.zeros_like(old_type) + agent_type[old_type < 0] = 0 + agent_type[old_type ==[3,4]] = 2 + agent_type[old_type ==1] = 3 + agent_type[old_type ==2] = 2 + agent_type[old_type ==5] = 4 + + # mask out invalid extents + agent_hist_extent = batch["agent_hist_extent"] + agent_hist_extent[torch.isnan(agent_hist_extent)] = 0. + + if not batch["centered_agent_from_world_tf"].isnan().any(): + centered_world_from_agent = torch.inverse(batch["centered_agent_from_world_tf"]) + else: + centered_world_from_agent = None + b,a = curr_yaw.shape[:2] + agents_from_center = (GeoUtils.transform_matrices(-curr_yaw.flatten(),torch.zeros(b*a,2,device=curr_yaw.device)) + @GeoUtils.transform_matrices(torch.zeros(b*a,device=curr_yaw.device),-curr_pos.reshape(-1,2))).reshape(*curr_yaw.shape[:2],3,3) + center_from_agents = GeoUtils.transform_matrices(curr_yaw.flatten(),curr_pos.reshape(-1,2)).reshape(*curr_yaw.shape[:2],3,3) + # map-related + if batch["maps"] is not None: + map_res = batch["maps_resolution"][0,0] + h, w = batch["maps"].shape[-2:] + # TODO: pass env configs to here + + centered_raster_from_agent = torch.Tensor([ + [map_res, 0, 0.125 * w], + [0, map_res, 0.5 * h], + [0, 0, 1] + ]).type_as(agents_from_center) + + centered_agent_from_raster,_ = torch.linalg.inv_ex(centered_raster_from_agent) + + raster_from_center = centered_raster_from_agent @ agents_from_center + center_from_raster = center_from_agents @ centered_agent_from_raster + + raster_from_world = batch["rasters_from_world_tf"] + world_from_raster,_ = torch.linalg.inv_ex(raster_from_world) + raster_from_world[torch.isnan(raster_from_world)] = 0. + world_from_raster[torch.isnan(world_from_raster)] = 0. + + if rasterize_mode=="none": + maps = batch["maps"] + elif rasterize_mode=="point": + maps = rasterize_agents_scene( + batch["maps"], + hist_pos, + hist_yaw, + None, + hist_mask, + raster_from_center, + map_res + ) + elif rasterize_mode=="square": + #TODO: add the square rasterization function for scene-centric data + raise NotImplementedError + elif rasterize_mode=="point_sc": + hist_hm = rasterize_agents_sc( + batch["maps"], + hist_pos, + hist_yaw, + hist_speed, + hist_mask, + raster_from_center, + ) + fut_hm = rasterize_agents_sc( + batch["maps"], + fut_pos, + fut_yaw, + fut_speed, + fut_mask, + raster_from_center, + ) + maps = batch["maps"] + drivable_map = get_drivable_region_map(batch["maps"]) + else: + maps = None + drivable_map = None + raster_from_center = None + center_from_raster = None + raster_from_world = None + centered_agent_from_raster = None + centered_raster_from_agent = None + + extent_scale = 1.0 + + + d = dict( + image=maps, + drivable_map=drivable_map, + fut_pos=fut_pos, + fut_yaw=fut_yaw, + fut_mask=fut_mask, + hist_pos=hist_pos, + hist_yaw=hist_yaw, + hist_mask=hist_mask, + curr_speed=curr_speed, + centroid=curr_pos, + world_yaw=world_yaw, + type=agent_type, + extent=agent_hist_extent.max(dim=-2)[0] * extent_scale, + raster_from_agent=centered_raster_from_agent, + agent_from_raster=centered_agent_from_raster, + raster_from_center=raster_from_center, + center_from_raster=center_from_raster, + agents_from_center = agents_from_center, + center_from_agents = center_from_agents, + raster_from_world=raster_from_world, + agent_from_world=batch["centered_agent_from_world_tf"], + world_from_agent=centered_world_from_agent, + ) + if rasterize_mode=="point_sc": + d["hist_hm"] = hist_hm + d["fut_hm"] = fut_hm + + return d + +def parse_node_centric(batch: dict,rasterize_mode:str,): + maybe_pad_neighbor(batch) + fut_pos, 
fut_yaw, _, fut_mask = trajdata2posyawspeed(batch["agent_fut"]) + hist_pos, hist_yaw, hist_speed, hist_mask = trajdata2posyawspeed(batch["agent_hist"]) + curr_speed = hist_speed[..., -1] + curr_state = batch["curr_agent_state"] + curr_yaw = curr_state[:, -1] + curr_pos = curr_state[:, :2] + + # convert nuscenes types to l5kit types + # old_type = batch["agent_type"] + # agent_type = torch.zeros_like(old_type) + # agent_type[old_type < 0] = 0 + # agent_type[old_type ==[3,4]] = 2 + # agent_type[old_type ==1] = 3 + # agent_type[old_type ==2] = 2 + # agent_type[old_type ==5] = 4 + # mask out invalid extents + agent_hist_extent = batch["agent_hist_extent"] + agent_hist_extent[torch.isnan(agent_hist_extent)] = 0. + + neigh_hist_pos, neigh_hist_yaw, neigh_hist_speed, neigh_hist_mask = trajdata2posyawspeed(batch["neigh_hist"]) + neigh_fut_pos, neigh_fut_yaw, _, neigh_fut_mask = trajdata2posyawspeed(batch["neigh_fut"]) + if neigh_hist_speed.nelement() > 0: + neigh_curr_speed = neigh_hist_speed[..., -1] + else: + neigh_curr_speed = neigh_hist_speed.unsqueeze(-1) + # old_neigh_types = batch["neigh_types"] + # # convert nuscenes types to l5kit types + # neigh_types = torch.zeros_like(old_neigh_types) + # # neigh_types = torch.zeros_like(old_type) + # neigh_types[old_neigh_types < 0] = 0 + # neigh_types[old_neigh_types ==[3,4]] = 2 + # neigh_types[old_neigh_types ==1] = 3 + # neigh_types[old_neigh_types ==2] = 2 + # neigh_types[old_neigh_types ==5] = 4 + + # mask out invalid extents + neigh_hist_extents = batch["neigh_hist_extents"] + neigh_hist_extents[torch.isnan(neigh_hist_extents)] = 0. + + world_from_agents = torch.inverse(batch["agents_from_world_tf"]) + if batch["curr_agent_state"].shape[-1]==7: + world_yaw = batch["curr_agent_state"][...,6] + else: + assert batch["curr_agent_state"].shape[-1]==8 + world_yaw = torch.atan2(batch["curr_agent_state"][...,6],batch["curr_agent_state"][...,7]) + + # map-related + if batch["maps"] is not None: + map_res = batch["maps_resolution"][0] + h, w = batch["maps"].shape[-2:] + # TODO: pass env configs to here + raster_from_agent = torch.tensor([ + [map_res, 0, 0.125 * w], + [0, map_res, 0.5 * h], + [0, 0, 1] + ],device=curr_pos.device) + agent_from_raster = torch.inverse(raster_from_agent) + raster_from_agent = TensorUtils.unsqueeze_expand_at(raster_from_agent, size=batch["maps"].shape[0], dim=0) + agent_from_raster = TensorUtils.unsqueeze_expand_at(agent_from_raster, size=batch["maps"].shape[0], dim=0) + raster_from_world = torch.bmm(raster_from_agent, batch["agents_from_world_tf"]) + if neigh_hist_pos.nelement()>0: + all_hist_pos = torch.cat((hist_pos[:, None], neigh_hist_pos), dim=1) + all_hist_yaw = torch.cat((hist_yaw[:, None], neigh_hist_yaw), dim=1) + + all_extents = torch.cat((batch["agent_hist_extent"].unsqueeze(1),batch["neigh_hist_extents"]),1).max(dim=2)[0][...,:2] + all_hist_mask = torch.cat((hist_mask[:, None], neigh_hist_mask), dim=1) + else: + all_hist_pos = hist_pos[:, None] + all_hist_yaw = hist_yaw[:, None] + all_extents = batch["agent_hist_extent"].unsqueeze(1)[...,:2] + all_hist_mask = hist_mask[:, None] + if rasterize_mode=="none": + maps = batch["maps"] + elif rasterize_mode=="point": + maps = rasterize_agents( + batch["maps"], + all_hist_pos, + all_hist_yaw, + all_extents, + all_hist_mask, + raster_from_agent, + map_res + ) + elif rasterize_mode=="square": + maps = rasterize_agents_rec( + batch["maps"], + all_hist_pos, + all_hist_yaw, + all_extents, + all_hist_mask, + raster_from_agent, + map_res + ) + else: + raise Exception("unknown 
rasterization mode") + drivable_map = get_drivable_region_map(batch["maps"]) + else: + maps = None + drivable_map = None + raster_from_agent = None + agent_from_raster = None + raster_from_world = None + + extent_scale = 1.0 + d = dict( + image=maps, + drivable_map=drivable_map, + fut_pos=fut_pos, + fut_yaw=fut_yaw, + fut_mask=fut_mask, + hist_pos=hist_pos, + hist_yaw=hist_yaw, + hist_mask=hist_mask, + curr_speed=curr_speed, + centroid=curr_pos, + world_yaw=world_yaw, + type=batch["agent_type"], + extent=agent_hist_extent.max(dim=-2)[0] * extent_scale, + raster_from_agent=raster_from_agent, + agent_from_raster=agent_from_raster, + raster_from_world=raster_from_world, + agent_from_world=batch["agents_from_world_tf"], + world_from_agent=world_from_agents, + neigh_hist_pos=neigh_hist_pos, + neigh_hist_yaw=neigh_hist_yaw, + neigh_hist_mask=neigh_hist_mask, # dump hack to agree with l5kit's typo ... + neigh_curr_speed=neigh_curr_speed, + neigh_fut_pos=neigh_fut_pos, + neigh_fut_yaw=neigh_fut_yaw, + neigh_fut_mask=neigh_fut_mask, + neigh_types=batch["neigh_types"], + neigh_extents=neigh_hist_extents.max(dim=-2)[0] * extent_scale if neigh_hist_extents.nelement()>0 else None, + + ) + # if "agent_lanes" in batch: + # d["ego_lanes"] = batch["agent_lanes"] + + return d + +@torch.no_grad() +def parse_trajdata_batch(batch, rasterize_mode="point"): + if isinstance(batch,AgentBatch): + # Be careful here, without making a copy of vars(batch) we would modify the fields of AgentBatch. + batch = dict(vars(batch)) + d = parse_node_centric(batch,rasterize_mode) + elif isinstance(batch,SceneBatch): + batch = dict(vars(batch)) + d = parse_scene_centric(batch,rasterize_mode) + elif isinstance(batch,dict): + batch = dict(batch) + if "num_agents" in batch: + # scene centric + d = parse_scene_centric(batch,rasterize_mode) + + else: + # agent centric + d = parse_node_centric(batch,rasterize_mode) + + batch.update(d) + for k,v in batch.items(): + if isinstance(v,torch.Tensor): + batch[k]=v.nan_to_num(0) + batch.pop("agent_name", None) + batch.pop("robot_fut", None) + return batch + + +def get_modality_shapes(cfg: ExperimentConfig, rasterize_mode: str = "point"): + h = cfg.env.rasterizer.raster_size + if rasterize_mode=="none": + return dict(static=(3,h,h),dynamic=(0,h,h),image=(3,h,h)) + else: + num_channels = (cfg.history_num_frames + 1) + 3 + return dict(static=(3,h,h),dynamic=(cfg.history_num_frames + 1,h,h),image=(num_channels, h, h)) + + +def gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,N,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] or [B,N,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + B,N,T = ego_trajectories.shape[:3] + A = agent_trajectories.shape[-3] + + veh_mask = raw_types == int(AgentType["VEHICLE"]) + ped_mask = raw_types == int(AgentType["PEDESTRIAN"]) + + edges = torch.zeros([B,N,A,T,10],device=ego_trajectories.device) + edges[...,:3] = ego_trajectories.unsqueeze(2).repeat(1,1,A,1,1) + if agent_trajectories.ndim==4: + edges[...,3:6] = agent_trajectories.unsqueeze(1).repeat(1,N,1,1,1) + else: + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,1,2).repeat(1,N,A,T,1) + edges[...,8:] = agent_extents.reshape(B,1,A,1,2).repeat(1,N,1,T,1) + type_mask = 
{"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + + + print("abc") + +def gen_EC_edges(ego_trajectories,agent_trajectories,ego_extents, agent_extents, raw_types,mask=None): + """generate edges between ego trajectory samples and agent trajectories + + Args: + ego_trajectories (torch.Tensor): [B,A,T,3] + agent_trajectories (torch.Tensor): [B,A,T,3] + ego_extents (torch.Tensor): [B,2] + agent_extents (torch.Tensor): [B,A,2] + raw_types (torch.Tensor): [B,A] + mask (optional, torch.Tensor): [B,A] + Returns: + edges (torch.Tensor): [B,N,A,T,10] + type_mask (dict) + """ + + B,A = ego_trajectories.shape[:2] + T = ego_trajectories.shape[-2] + + veh_mask = raw_types == int(AgentType["VEHICLE"]) + ped_mask = raw_types == int(AgentType["PEDESTRIAN"]) + + + if ego_trajectories.ndim==4: + edges = torch.zeros([B,A,T,10],device=ego_trajectories.device) + edges[...,:3] = ego_trajectories + edges[...,3:6] = agent_trajectories + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,1,T,1) + elif ego_trajectories.ndim==5: + + K = ego_trajectories.shape[2] + edges = torch.zeros([B,A*K,T,10],device=ego_trajectories.device) + edges[...,:3] = TensorUtils.join_dimensions(ego_trajectories,1,3) + edges[...,3:6] = agent_trajectories.repeat(1,K,1,1) + edges[...,6:8] = ego_extents.reshape(B,1,1,2).repeat(1,A*K,T,1) + edges[...,8:] = agent_extents.unsqueeze(2).repeat(1,K,T,1) + veh_mask = veh_mask.tile(1,K) + ped_mask = ped_mask.tile(1,K) + if mask is not None: + veh_mask = veh_mask*mask + ped_mask = ped_mask*mask + type_mask = {"VV":veh_mask,"VP":ped_mask} + return edges,type_mask + + +def generate_edges( + raw_type, + extents, + pos_pred, + yaw_pred, + batch_first = False, +): + veh_mask = raw_type == int(AgentType["VEHICLE"]) + ped_mask = raw_type == int(AgentType["PEDESTRIAN"]) + + agent_mask = veh_mask | ped_mask + edge_types = ["VV", "VP", "PV", "PP"] + edges = {et: list() for et in edge_types} + for i in range(agent_mask.shape[0]): + agent_idx = torch.where(agent_mask[i] != 0)[0] + edge_idx = torch.combinations(agent_idx, r=2) + VV_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + VP_idx = torch.where( + veh_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + PV_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & veh_mask[i, edge_idx[:, 1]] + )[0] + PP_idx = torch.where( + ped_mask[i, edge_idx[:, 0]] & ped_mask[i, edge_idx[:, 1]] + )[0] + if pos_pred.ndim == 4: + edges_of_all_types = torch.cat( + ( + pos_pred[i, edge_idx[:, 0], :], + yaw_pred[i, edge_idx[:, 0], :], + pos_pred[i, edge_idx[:, 1], :], + yaw_pred[i, edge_idx[:, 1], :], + extents[i, edge_idx[:, 0]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + extents[i, edge_idx[:, 1]] + .unsqueeze(-2) + .repeat(1, pos_pred.size(-2), 1), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[VV_idx]) + edges["VP"].append(edges_of_all_types[VP_idx]) + edges["PV"].append(edges_of_all_types[PV_idx]) + edges["PP"].append(edges_of_all_types[PP_idx]) + elif pos_pred.ndim == 5: + + edges_of_all_types = torch.cat( + ( + pos_pred[i, :, edge_idx[:, 0], :], + yaw_pred[i, :, edge_idx[:, 0], :], + pos_pred[i, :, edge_idx[:, 1], :], + yaw_pred[i, :, edge_idx[:, 1], :], + extents[i, None, edge_idx[:, 0], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + extents[i, None, edge_idx[:, 1], None, :].repeat( + pos_pred.size(1), 1, pos_pred.size(-2), 1 + ), + ), + dim=-1, + ) + edges["VV"].append(edges_of_all_types[:, VV_idx]) + 
edges["VP"].append(edges_of_all_types[:, VP_idx]) + edges["PV"].append(edges_of_all_types[:, PV_idx]) + edges["PP"].append(edges_of_all_types[:, PP_idx]) + if batch_first: + for et, v in edges.items(): + edges[et] = pad_sequence(v, batch_first=True,padding_value=torch.nan) + else: + if pos_pred.ndim == 4: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=0) + elif pos_pred.ndim == 5: + for et, v in edges.items(): + edges[et] = torch.cat(v, dim=1) + return edges + + +def merge_scene_batches(scene_batches: List[SceneBatch], dt: float) -> SceneBatch: + assert scene_batches[0].history_pad_dir == PadDirection.BEFORE + assert all([b.agent_hist.shape[0] == 1 for b in scene_batches]), "only batch_size=1 is supported" + + # Convert everything to world coordinates + scene_batches = [b.apply_transform(b.centered_world_from_agent_tf) for b in scene_batches] + state_format = scene_batches[0].agent_hist._format + + # Not all agent might be present at all time steps, so we match them by name. + # Get unique names, use np.unique return_index and sort to preserve ordering. + agent_names = [np.array(b.agent_names[0]) for b in scene_batches] + all_agent_names = np.concatenate(agent_names) + _, idx = np.unique(all_agent_names, return_index=True) + unique_agent_names = all_agent_names[np.sort(idx)] + + num_agents = len(unique_agent_names) + hist_len = len(scene_batches) + fut_len = scene_batches[-1].agent_fut.shape[-2] + + # Create full history with nans, then replace them for each time step + agent_hist = torch.full((1, num_agents, hist_len, scene_batches[0].agent_hist.shape[-1]), dtype=scene_batches[0].agent_hist.dtype, fill_value=torch.nan) + agent_hist_extent = torch.full((1, num_agents, hist_len, 2), dtype=scene_batches[0].agent_hist_extent.dtype, fill_value=torch.nan) + agent_type = torch.full((1, num_agents), dtype=scene_batches[0].agent_type.dtype, fill_value=-1) + + for t, scene_batch in enumerate(scene_batches): + match_inds = np.argwhere(np.array(scene_batch.agent_names[0])[:, None] == unique_agent_names[None, :]) # n_current, n_all -> n_current, 2 + assert match_inds.shape[0] == len(scene_batch.agent_names[0]), "there should be only 1 unique match" + agent_hist[0, match_inds[:, 1], t, :] = scene_batch.agent_hist[0, match_inds[:, 0], -1, :] + agent_hist_extent[0, match_inds[:, 1], t, :] = scene_batch.agent_hist_extent[0, match_inds[:, 0], -1, :] + agent_type[0, match_inds[:, 1]] = scene_batch.agent_type[0, match_inds[:, 0]] + + # Dummy future, repeat last state + agent_fut = agent_hist[:, :, -1:, :].repeat_interleave(fut_len, dim=-2) + agent_fut_extent = agent_hist_extent[:, :, -1:, :].repeat_interleave(fut_len, dim=-2) + + # Create trajdata batch + merged_batch = SceneBatch( + data_idx=torch.tensor([torch.nan]), + scene_ts=scene_batches[0].scene_ts, + scene_ids=scene_batches[0].scene_ids, + dt=torch.tensor([dt]), + num_agents=torch.tensor([num_agents]), + agent_type=agent_type, + centered_agent_state=scene_batches[0].centered_agent_state, + agent_names=[list(unique_agent_names)], + agent_hist=StateTensor.from_array(agent_hist, state_format), + agent_hist_extent=agent_hist_extent, + agent_hist_len=torch.tensor([[hist_len] * num_agents]), # len includes current state + agent_fut=StateTensor.from_array(agent_fut, state_format), + agent_fut_extent=agent_fut_extent, + agent_fut_len=torch.from_numpy(np.array([[fut_len] * num_agents])), + robot_fut=None, + robot_fut_len=None, + map_names=scene_batches[0].map_names, + maps=scene_batches[0].map_names, + 
maps_resolution=scene_batches[0].maps_resolution, + vector_maps=scene_batches[0].vector_maps, + rasters_from_world_tf=scene_batches[0].rasters_from_world_tf, + centered_agent_from_world_tf=scene_batches[0].centered_agent_from_world_tf, + centered_world_from_agent_tf=scene_batches[0].centered_world_from_agent_tf, + history_pad_dir=PadDirection.BEFORE, + extras=dict(scene_batches[0].extras), + ) + + return merged_batch diff --git a/diffstack/utils/tree.py b/diffstack/utils/tree.py new file mode 100644 index 0000000..845fdde --- /dev/null +++ b/diffstack/utils/tree.py @@ -0,0 +1,102 @@ +import itertools +from collections import defaultdict +import networkx as nx + +class Tree(object): + + def __init__(self, content, parent, depth): + self.content = content + self.children = list() + self.parent = parent + if parent is not None: + parent.expand(self) + self.depth = depth + self.attribute = dict() + + def expand(self, child): + self.children.append(child) + + def expand_set(self, children): + self.children += children + + def isroot(self): + return self.parent is None + + def isleaf(self): + return len(self.children) == 0 + + def get_subseq_trajs(self): + return [child.traj for child in self.children] + + + def get_all_leaves(self,leaf_set=[]): + if self.isleaf(): + leaf_set.append(self) + else: + for child in self.children: + leaf_set = child.get_all_leaves(leaf_set) + return leaf_set + def get_label(self): + raise NotImplementedError + + @staticmethod + def get_nodes_by_level(obj,depth,nodes=None,trim_short_branch=True): + assert obj.depth<=depth + if nodes is None: + nodes = defaultdict(lambda: list()) + if obj.depth==depth: + nodes[depth].append(obj) + return nodes, True + else: + if obj.isleaf(): + return nodes, False + + else: + flag = False + children_flags = dict() + for child in obj.children: + nodes, child_flag = Tree.get_nodes_by_level(child,depth,nodes) + children_flags[child] = child_flag + flag = flag or child_flag + if trim_short_branch: + obj.children = [child for child in obj.children if children_flags[child]] + if flag: + nodes[obj.depth].append(obj) + return nodes, flag + + @staticmethod + def get_children(obj): + if isinstance(obj, Tree): + return obj.children + elif isinstance(obj, list): + children = [node.children for node in obj] + children = list(itertools.chain.from_iterable(children)) + return children + else: + raise TypeError("obj must be a TrajTree or a list") + + def as_network(self): + G = nx.Graph() + G.add_node(self.get_label()) + for child in self.children: + G = nx.union(G,child.as_network()) + G.add_edge(self.get_label(),child.get_label()) + return G + + def plot(self): + G = self.as_network() + + pos = nx.nx_agraph.pygraphviz_layout(G, prog="dot") + nx.draw(G, pos,with_labels = True) + # nx.draw(G, with_labels = True) + + + + +def depth_first_traverse(tree:Tree,func, visited:dict, result): + result = func(tree,result) + visited[tree] = True + for child in tree.children: + if not (child in visited and visited[child]): + result, visited = depth_first_traverse(child, func, visited, result) + return result, visited \ No newline at end of file diff --git a/diffstack/utils/utils.py b/diffstack/utils/utils.py index 4006610..7e4a524 100644 --- a/diffstack/utils/utils.py +++ b/diffstack/utils/utils.py @@ -1,23 +1,25 @@ import os import random import torch -import torch.distributed as dist import time import numpy as np import pickle import dill import collections.abc -import datetime -from collections import defaultdict -from typing import Dict, Union, Tuple, Any, 
Optional, Iterable +from datetime import timedelta +from collections import defaultdict, OrderedDict +from scipy.interpolate import interp1d +from typing import Dict, Union, Tuple, Any, Optional, Iterable, List from torch.utils.data._utils.collate import default_collate from nuscenes.map_expansion import arcline_path_utils from nuscenes.map_expansion.map_api import NuScenesMap +from diffstack.utils.geometry_utils import batch_rotate_2D # Expose a couple of util functions defined in different submodules. -from trajdata.utils.arr_utils import angle_wrap +from trajdata.utils.arr_utils import angle_wrap, batch_select +from diffstack.modules.predictors.trajectron_utils.model.components import GMM2D container_abcs = collections.abc @@ -44,7 +46,7 @@ def initialize_torch_distributed(local_rank: int): init_method='env://', # default timeout torch.distributed.default_pg_timeout=1800 (sec, =30mins) # increase timeout for datacaching where workload for different gpus can be very different - timeout=datetime.timedelta(hours=10)) # 10h + timeout=timedelta(hours=10)) # 10h def set_all_seeds(seed): @@ -56,14 +58,14 @@ def set_all_seeds(seed): def prepeare_torch_env(rank, hyperparams): - if torch.cuda.is_available() and hyperparams["device"] != 'cpu': - hyperparams["device"] = f'cuda:{rank}' + if torch.cuda.is_available() and hyperparams.run["device"] != 'cpu': + hyperparams.run["device"] = f'cuda:{rank}' torch.cuda.set_device(rank) else: - hyperparams["device"] = f'cpu' + hyperparams.run["device"] = f'cpu' - if hyperparams["seed"] is not None: - set_all_seeds(hyperparams["seed"]) + if hyperparams.run["seed"] is not None: + set_all_seeds(hyperparams.run["seed"]) class CudaTimer(object): def __init__(self, enabled=True): @@ -119,21 +121,114 @@ def batch_derivative_of(states, dt = 1.): return diff / dt -def subsample_traj(x, predh, planh): - assert x.shape[0] == planh + 1 - if planh != predh: - assert planh % predh == 0, f"planning horizon ({predh}) needs to be a multiple of prediction horizon ({predh})" - subsample_gap = planh // predh - subsample_inds = list(range(0, planh+1, subsample_gap)) # for gap=2 [0,1,2,3,4] --> [0,2,4] - assert len(subsample_inds) == predh+1 - return x[subsample_inds] +def subsample_future(x: torch.Tensor, new_horizon: int, current_horizon: int): + return subsample_traj(x, new_horizon, current_horizon, is_future=True) + + +def subsample_history(x: torch.Tensor, new_horizon: int, current_horizon: int): + return subsample_traj(x, new_horizon, current_horizon, is_future=False) + + +def subsample_traj(x: torch.Tensor, new_horizon: int, current_horizon: int, is_future: bool = True): + """x: [..., planh+1, :]""" + assert x.shape[-2] == current_horizon + 1 + if current_horizon != new_horizon: + if new_horizon == 0: + subsample_inds = [0] + else: + assert current_horizon % new_horizon == 0, f"planning horizon ({new_horizon}) needs to be a multiple of prediction horizon ({current_horizon})" + subsample_gap = current_horizon // new_horizon + subsample_inds = list(range(0, current_horizon+1, subsample_gap)) # for gap=2 [0,1,2,3,4,5] --> [0,2,4] + if not is_future: + # for gap=2 [0,1,2,3,4,5] --> [0,2,4] --> [-1, -3, -5] --> [-5, -3, -1] (equivalent --> [1, 3, 5]) + subsample_inds = [-ind-1 for ind in subsample_inds] + subsample_inds.reverse() + + assert len(subsample_inds) == new_horizon+1 + if x.ndim == 2: + return x[subsample_inds] + else: + return x.transpose(-2, 0)[subsample_inds].transpose(-2, 0) else: return x + def normalize_angle(h): return (h + np.pi) % (2.0 * np.pi) - np.pi +def 
traj_xyhvv_to_pred(traj, dt): + # Input: [x, y, h, vx, vy] + # Output prediction state: ['x', 'y', 'vx', 'vy', 'ax', 'ay', 'sintheta', 'costheta'] + x, y, h, vx, vy = np.split(traj, 5, axis=-1) + ax = batch_derivative_of(vx, dt=dt) + ay = batch_derivative_of(vy, dt=dt) + pred_state = np.concatenate(( + x, y, vx, vy, ax, ay, np.sin(h), np.cos(h) + ), axis=-1) + return pred_state + + +def traj_pred_to_xyhvv(traj): + # Input prediction state: ['x', 'y', 'vx', 'vy', 'ax', 'ay', 'sintheta', 'costheta'] + # Output: [x, y, h, vx, vy] + x, y, vx, vy, ax, ay, sinh, cosh = np.split(traj, 8, axis=-1) + h = np.arctan2(sinh, cosh) + return np.concatenate((x, y, h, vx, vy), axis=-1) + + +def traj_xy_to_xyh(traj: Union[np.ndarray, torch.Tensor]): + xy = traj + dxy = xy[..., 1:, :2] - xy[..., :-1, :2] + if isinstance(traj, torch.Tensor): + h = torch.atan2(dxy[..., 1], dxy[..., 0])[..., None] # TODO invalid for near-zero velocity + h = torch.concat((h, h[..., -1:, :]), dim=-2) # extend time + return torch.concat((xy, h), dim=-1) + else: + h = np.arctan2(dxy[..., 1], dxy[..., 0])[..., None] # TODO invalid for near-zero velocity + h = np.concatenate((h, h[..., -1:, :]), axis=-2) # extend time + return np.concatenate((xy, h), axis=-1) + + +def traj_xyh_to_xyhv(traj: Union[np.ndarray, torch.Tensor], dt: float): + xyhvv = traj_xyh_to_xyhvv(traj, dt) + return traj_xyhvv_to_xyhv(xyhvv) + + +def traj_xyh_to_xyhvv(traj: Union[np.ndarray, torch.Tensor], dt: float): + if isinstance(traj, torch.Tensor): + x, y, h = torch.split(traj, 1, dim=-1) + vx = batch_derivative_of(x, dt) + vy = batch_derivative_of(y, dt) + return torch.concat((x, y, h, vx, vy), dim=-1) + else: + x, y, h = np.split(traj, 3, axis=-1) + vx = batch_derivative_of(x, dt) + vy = batch_derivative_of(y, dt) + return np.concatenate((x, y, h, vx, vy), axis=-1) + + +def traj_xyhv_to_xyhvv(traj): + x, y, h, v = np.split(traj, 4, axis=-1) + vx = v * np.cos(h) + vy = v * np.sin(h) + return np.concatenate((x, y, h, vx, vy), axis=-1) + + +def traj_xyhvv_to_xyhv(traj: Union[np.ndarray, torch.Tensor]): + # Use only the forward velocity component and ignore sideway velocity. + if isinstance(traj, torch.Tensor): + x, y, h, vx, vy = torch.split(traj, 1, dim=-1) + v_xy = batch_rotate_2D(torch.stack((vx, vy),-1), -h) # forward and sideway velocity + v = v_xy[...,0] + return torch.concat((x, y, h, v), dim=-1) + else: + x, y, h, vx, vy = np.split(traj, 5, axis=-1) + v_xy = batch_rotate_2D(np.stack((vx, vy),-1), -h) # forward and sideway velocity + v = v_xy[...,0] + return np.concatenate((x, y, h, v), axis=-1) + + def closest_lane_state(global_state: np.ndarray, nusc_map: NuScenesMap): nearest_lane = nusc_map.get_closest_lane(x=global_state[0], y=global_state[1]) lane_rec = nusc_map.get_arcline_path(nearest_lane) @@ -190,7 +285,88 @@ def closest_lane_np(ego_xy: np.ndarray, lane_points_list: Iterable[np.ndarray]): state_to_lane_dist2 = [ np.square(lane_points[..., :2] - ego_xy[np.newaxis,..., :2]).sum(-1).min(-1) for lane_points in lane_points_list] - return np.argmin(state_to_lane_dist2) + return np.argmin(state_to_lane_dist2) if len(state_to_lane_dist2)>0 else None + + +def get_pointgoal_from_onroute_lanes(ego_state_xyhv: np.ndarray, lanes_xyh: List[np.ndarray], dt: float, future_len: int, max_vel: float = 29.0, target_acc: float = 1.5) -> np.ndarray: + # TODO we should move this logic to the planner + + assert ego_state_xyhv.ndim == 1 and ego_state_xyhv.shape[-1] == 4 + ego_x, ego_y, ego_h, ego_v = ego_state_xyhv + + if len(lanes_xyh) == 0: + # No lanes. 
Set goal as current location. + goal_point_xyh = np.stack((ego_x, ego_y, ego_h), axis=-1) + ref_traj_xyh = np.repeat(goal_point_xyh[None], future_len+1, axis=0) + print ("WARNING: no lane for inferring goal") + return ref_traj_xyh, goal_point_xyh + + # Target final velocity, based on accelearting but remaining under max_vel limit + target_v = np.minimum(ego_v + target_acc * future_len * dt, max_vel) + avg_v = 0.5 * (target_v + ego_v) + target_pathlen = avg_v * future_len * dt + + closest_lane_ind = closest_lane_np(ego_state_xyhv[:3], lanes_xyh) + closest_lane = lanes_xyh[closest_lane_ind] + + state_to_lane_dist2 = np.square(closest_lane[:, :2] - ego_state_xyhv[None, :2]).sum(-1) + closest_point_ind = np.argmin(state_to_lane_dist2) + + # Assume polyline is ordered by the lane direction. + if closest_point_ind < closest_lane.shape[0]-1: + future_lane = closest_lane[closest_point_ind:] + # Get distance along lane. + # TODO compute distance from current state, not first lane point + step_len = np.concatenate([[0.], np.linalg.norm(future_lane[1:] - future_lane[:-1], axis=-1)]) + else: + # TODO again we should compute distance from current state, negative for first state, positive for second + future_lane = closest_lane[closest_point_ind-1:] + step_len = np.concatenate([[0.], np.linalg.norm(future_lane[1:] - future_lane[:-1], axis=-1)]) + + dist_along_lane = np.cumsum(step_len) + + # Find connecting lane + del closest_lane + while dist_along_lane[-1] < target_pathlen: + # TODO manually find continuation of lane, i.e. the lane with first point closest to last point of our lane. + first_lane_points = np.array([lane[0] for lane in lanes_xyh]) + d = np.linalg.norm(first_lane_points[:, :2] - future_lane[-1, None, :2], axis=-1) + closest_lane_ind = np.argmin(d) + + if d[closest_lane_ind] > 3: # 3m maximum to treat them as connected lanes + # No more lanes, we need to stop at the end of current lane. + break + + future_lane = np.concatenate((future_lane, lanes_xyh[closest_lane_ind]), axis=0) + # TODO compute distance from current state, not first lane point + step_len = np.concatenate([[0.], np.linalg.norm(future_lane[1:] - future_lane[:-1], axis=-1)]) + dist_along_lane = np.cumsum(step_len) + + del closest_lane_ind + + # Find lane point closest to our target distance. + # TODO this could be done more efficiently, use lane util functions + target_lane_point_ind = np.argmin(np.abs(dist_along_lane - target_pathlen)) + goal_point_xyh = future_lane[target_lane_point_ind] + + # Recomput target pathlen based on goal point. + target_pathlen = dist_along_lane[target_lane_point_ind] + avg_v = target_pathlen / (future_len * dt) + target_v = 2 * avg_v - ego_v + a = (target_v - ego_v) / (future_len * dt) + + # For each future time t, what is the intended length of traversed path. + delta_t = np.arange(future_len+1) * dt + delta_pathlen_t = (delta_t * a * 0.5 + ego_v) * delta_t + + # Interpolate lane at these delta lenth points. + # TODO we need to unwrap angles, do interpolation, and then wrap them again. 
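+ # Note on the sampling step below: dist_along_lane parameterizes the lane polyline by
+ # cumulative arc length, so interpolating it at delta_pathlen_t yields one (x, y, h)
+ # reference waypoint per future timestep under the constant-acceleration speed profile
+ # computed above. With bounds_error=False and no fill_value, arc lengths beyond the end
+ # of the lane come back as NaN, as the inline "nan for extrapolation" comment indicates.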
+ interp_fn = interp1d(dist_along_lane, future_lane, bounds_error=False, assume_sorted=True, axis=0) # nan for extrapolation + # interp_fn = interp1d(dist_along_lane, future_lane, fill_value="extrapolate", assume_sorted=True, axis=0) + + ref_traj_xyh = interp_fn(delta_pathlen_t) + + return ref_traj_xyh, goal_point_xyh def lat_long_distances(x: torch.Tensor, y: torch.Tensor, vect_x: torch.Tensor, vect_y: torch.Tensor, vect_h: torch.Tensor): @@ -256,6 +432,13 @@ def pt_rbf(input: torch.Tensor, center: Union[torch.Tensor, float] = 0.0, scale: return torch.exp(-0.5*torch.square(input - center).sum(-1)/scale) +def wrap(angles: Union[torch.Tensor, np.ndarray]): + return (angles + np.pi) % (2 * np.pi) - np.pi + +def angle_wrap(angles: Union[torch.Tensor, np.ndarray]): + return wrap(angles) + + # Custom torch functions that support jit.trace class _tracable_exp_fn(torch.autograd.Function): def forward(ctx, x: torch.Tensor): @@ -292,6 +475,26 @@ def ensure_length_nd(x, u, extra_info: Optional[Dict[str, torch.Tensor]] = None) return x, u +def prediction_diversity(y_dists): + # Variance of predictions at the final time step + mus = y_dists.mus # 1, b, T, N, 2 + log_pis = y_dists.log_pis # 1, b, T, N + + xy = mus.squeeze(0)[:, -1] # b, N, 2 + probs = torch.exp(log_pis.squeeze(0)[:, -1]) # b, N + # weighted mean final xy + xy_mean = (xy * probs.unsqueeze(-1)).sum(dim=1) # b, 2 + dist_from_mean = torch.linalg.norm(xy_mean.unsqueeze(1) - xy, dim=-1) # b, N + # variance of eucledian distances from the (weighted) mean prediction + # TODO this metric doesnt have a proper statistical interpretation + # To make it interpretable follow something like + # similarity measure in DPP https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=9387598 + # diversity energy https://openaccess.thecvf.com/content/ICCV2021/papers/Cui_LookOut_Diverse_Multi-Future_Prediction_and_Planning_for_Self-Driving_ICCV_2021_paper.pdf + var = (probs * torch.square(dist_from_mean)).sum(dim=-1) + + return var + + def convert_state_pred2plan(x_pred: Union[torch.Tensor, np.ndarray]): """ Transform @@ -315,9 +518,185 @@ def convert_state_pred2plan(x_pred: Union[torch.Tensor, np.ndarray]): return x_plan +def gmms_from_single_futures(x_xyhv: torch.Tensor, dt: float): + assert x_xyhv.ndim == 4 # N, b, T, 4 + assert x_xyhv.shape[-1] == 4 # xyhv + ph = x_xyhv.shape[-2] + + mus = x_xyhv.unsqueeze(3) # (N, b, T, 1, 4) + log_pis = torch.zeros(mus.shape[:-1], dtype=mus.dtype, device=mus.device) + log_sigmas = torch.log((torch.arange(1, ph+1, dtype=mus.dtype, device=mus.device) * dt)**2*2) + log_sigmas = log_sigmas.reshape(1, 1, ph, 1, 1).repeat((x_xyhv.shape[0], x_xyhv.shape[1], 1, 1, 2)) + corrs = 0. 
* torch.ones(mus.shape[:-1], dtype=mus.dtype, device=mus.device) # TODO not sure what is reasonable + + y_dists = GMM2D(log_pis, mus, log_sigmas, corrs) + return y_dists + + +def gmm_concat_as_modes(gmms: Iterable[GMM2D], probs: Iterable[float]): + gmm_joint = GMM2D( + log_pis=torch.concat([gmm.log_pis + np.log(p) for gmm, p in zip(gmms, probs)], dim=-1), + mus=torch.concat([gmm.mus for gmm in gmms], dim=-2), + log_sigmas=torch.concat([gmm.sigmas for gmm in gmms], dim=-2), + corrs=torch.concat([gmm.corrs for gmm in gmms], dim=-1), + ) + return gmm_joint + + +def gmm_concat_as_agents(gmms: Iterable[GMM2D]): + gmm_joint = GMM2D( + log_pis=torch.concat([gmm.log_pis for gmm in gmms], dim=0), + mus=torch.concat([gmm.mus for gmm in gmms], dim=0), + log_sigmas=torch.concat([gmm.log_sigmas for gmm in gmms], dim=0), + corrs=torch.concat([gmm.corrs for gmm in gmms], dim=0), + ) + return gmm_joint + + +def gmm_extend(gmm: GMM2D, num_modes: int): + gmm = GMM2D( + log_pis=torch.nn.functional.pad(gmm.log_pis, (0, num_modes), mode="constant", value=-np.inf), + mus=torch.nn.functional.pad(gmm.mus, (0, 0, 0, num_modes), mode="constant", value=np.nan), + log_sigmas=torch.nn.functional.pad(gmm.log_sigmas, (0, 0, 0, num_modes), mode="constant", value=np.nan), + corrs=torch.nn.functional.pad(gmm.corrs, (0, num_modes), mode="constant", value=0.), + ) + return gmm + + def move_list_element_to_front(a: list, i: int) -> list: a = [a[i]] + [a[j] for j in range(len(a)) if j != i] return a + +def soft_min(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((y-x)/2) + exp2 = expfun((x-y)/2) + return (exp1*x+exp2*y)/(exp1+exp2) + +def soft_max(x,y,gamma=5): + if isinstance(x,torch.Tensor): + expfun = torch.exp + elif isinstance(x,np.ndarray): + expfun = np.exp + exp1 = expfun((x-y)/2) + exp2 = expfun((y-x)/2) + return (exp1*x+exp2*y)/(exp1+exp2) +def soft_sat(x,x_min=None,x_max=None,gamma=5): + if x_min is None and x_max is None: + return x + elif x_min is None and x_max is not None: + return soft_min(x,x_max,gamma) + elif x_min is not None and x_max is None: + return soft_max(x,x_min,gamma) + else: + if isinstance(x_min,torch.Tensor) or isinstance(x_min,np.ndarray): + assert (x_max>x_min).all() + else: + assert x_max>x_min + xc = x - (x_min+x_max)/2 + if isinstance(x,torch.Tensor): + return xc/(torch.pow(1+torch.pow(torch.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + elif isinstance(x,np.ndarray): + return xc/(np.power(1+np.power(np.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 + else: + raise Exception("data type not supported") + + +def recursive_dict_list_tuple_apply(x, type_func_dict, ignore_if_unspecified=False): + """ + Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of + {data_type: function_to_apply}. + + Args: + x (dict or list or tuple): a possibly nested dictionary or list or tuple + type_func_dict (dict): a mapping from data types to the functions to be + applied for each data type. 
+ ignore_if_unspecified (bool): ignore an item if its type is unspecified by the type_func_dict + + Returns: + y (dict or list or tuple): new nested dict-list-tuple + """ + assert list not in type_func_dict + assert tuple not in type_func_dict + assert dict not in type_func_dict + assert torch.nn.ParameterDict not in type_func_dict + assert torch.nn.ParameterList not in type_func_dict + + if isinstance(x, (dict, OrderedDict, torch.nn.ParameterDict)): + new_x = ( + OrderedDict() + if isinstance(x, OrderedDict) + else dict() + ) + for k, v in x.items(): + new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict, ignore_if_unspecified) + return new_x + elif isinstance(x, (list, tuple, torch.nn.ParameterList)): + ret = [recursive_dict_list_tuple_apply(v, type_func_dict, ignore_if_unspecified) for v in x] + if isinstance(x, tuple): + ret = tuple(ret) + return ret + else: + for t, f in type_func_dict.items(): + if isinstance(x, t): + return f(x) + else: + if ignore_if_unspecified: + return x + else: + raise NotImplementedError("Cannot handle data type %s" % str(type(x))) + +def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): + """ + Reshape selected dimensions in a tensor to a target dimension. + + Args: + x (torch.Tensor): tensor to reshape + begin_axis (int): begin dimension + end_axis (int): end dimension + target_dims (tuple or list): target shape for the range of dimensions + (@begin_axis, @end_axis) + + Returns: + y (torch.Tensor): reshaped tensor + """ + assert begin_axis < end_axis + assert begin_axis >= 0 + assert end_axis <= len(x.shape) + assert isinstance(target_dims, (tuple, list)) + s = x.shape + final_s = [] + for i in range(len(s)): + if i == begin_axis: + final_s.extend(target_dims) + elif i < begin_axis or i >= end_axis: + final_s.append(s[i]) + return x.reshape(*final_s) + + +def yaw_from_pos(pos: torch.Tensor, dt, yaw_correction_speed=0.): + """ + Compute yaws from position sequences. Optionally suppress yaws computed from low-velocity steps + + Args: + pos (torch.Tensor): sequence of positions [..., T, 2] + dt (float): delta timestep to compute speed + yaw_correction_speed (float): zero out yaw change when the speed is below this threshold (noisy heading) + + Returns: + accum_yaw (torch.Tensor): sequence of yaws [..., T-1, 1] + """ + + pos_diff = pos[..., 1:, :] - pos[..., :-1, :] + yaw = torch.atan2(pos_diff[..., 1], pos_diff[..., 0]) + delta_yaw = torch.cat((yaw[..., [0]], yaw[..., 1:] - yaw[..., :-1]), dim=-1) + speed = torch.norm(pos_diff, dim=-1) / dt + delta_yaw[speed < yaw_correction_speed] = 0. 
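+ # Accumulating the per-step increments below recovers absolute yaws; increments zeroed
+ # out at low speed therefore carry the previous heading forward instead of injecting
+ # noisy heading estimates from near-stationary position differences.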
+ accum_yaw = torch.cumsum(delta_yaw, dim=-1) + return accum_yaw[..., None] def all_gather(data): @@ -332,7 +711,7 @@ def all_gather(data): list[data] List of data gathered from each rank """ - world_size = dist.get_world_size() + world_size = torch.distributed.get_world_size() if world_size == 1: return [data] @@ -345,7 +724,7 @@ def all_gather(data): # Obtain Tensor size of each rank local_size = torch.IntTensor([tensor.numel()]).to("cuda") size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] - dist.all_gather(size_list, local_size) + torch.distributed.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) @@ -357,7 +736,7 @@ def all_gather(data): if local_size != max_size: padding = torch.ByteTensor(size=(max_size - local_size, )).to("cuda") tensor = torch.cat((tensor, padding), dim=0) - dist.all_gather(tensor_list, tensor) + torch.distributed.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): @@ -461,3 +840,7 @@ def block_diag(m): eye = attach_dim(torch.eye(n, device=m.device).unsqueeze(-2), d - 3, 1) return (m2 * eye).reshape(siz0 + torch.Size(torch.tensor(siz1) * n)) +def removeprefix(line, prefix): + if line.startswith(prefix): + line_new = line[len(prefix):] + return line_new \ No newline at end of file diff --git a/diffstack/utils/vis_utils.py b/diffstack/utils/vis_utils.py new file mode 100644 index 0000000..b448b25 --- /dev/null +++ b/diffstack/utils/vis_utils.py @@ -0,0 +1,394 @@ +import numpy as np +from PIL import Image, ImageDraw +from collections import defaultdict +from typing import List, Optional, Tuple, Dict +from trajdata.maps.vec_map import VectorMap + +from l5kit.geometry import transform_points +from l5kit.rasterization.render_context import RenderContext +from l5kit.configs.config import load_metadata +from trajdata.maps import RasterizedMap + +from diffstack.utils.tensor_utils import map_ndarray +from diffstack.utils.geometry_utils import get_box_world_coords_np +import diffstack.utils.tensor_utils as TensorUtils +import os +import glob +from bokeh.models import ColumnDataSource, GlyphRenderer +from bokeh.plotting import figure, curdoc +import bokeh +from bokeh.io import export_png +from trajdata.utils.arr_utils import ( + transform_coords_2d_np, + batch_nd_transform_points_pt, + batch_nd_transform_points_np, +) +from trajdata.utils.vis_utils import draw_map_elems + +COLORS = { + "agent_contour": "#247BA0", + "agent_fill": "#56B1D8", + "ego_contour": "#911A12", + "ego_fill": "#FE5F55", +} + + +def agent_to_raster_np(pt_tensor, trans_mat): + pos_raster = transform_points(pt_tensor[None], trans_mat)[0] + return pos_raster + + +def draw_actions( + state_image, + trans_mat, + pred_action=None, + pred_plan=None, + pred_plan_info=None, + ego_action_samples=None, + plan_samples=None, + action_marker_size=3, + plan_marker_size=8, +): + im = Image.fromarray((state_image * 255).astype(np.uint8)) + draw = ImageDraw.Draw(im) + + if pred_action is not None: + raster_traj = agent_to_raster_np( + pred_action["positions"].reshape(-1, 2), trans_mat + ) + for point in raster_traj: + circle = np.hstack([point - action_marker_size, point + action_marker_size]) + draw.ellipse(circle.tolist(), fill="#FE5F55", outline="#911A12") + if ego_action_samples is not None: + raster_traj = agent_to_raster_np( + ego_action_samples["positions"].reshape(-1, 2), trans_mat + ) + for point in raster_traj: + circle = np.hstack([point - action_marker_size, point + 
action_marker_size]) + draw.ellipse(circle.tolist(), fill="#808080", outline="#911A12") + + if pred_plan is not None: + pos_raster = agent_to_raster_np(pred_plan["positions"][:, -1], trans_mat) + for pos in pos_raster: + circle = np.hstack([pos - plan_marker_size, pos + plan_marker_size]) + draw.ellipse(circle.tolist(), fill="#FF6B35") + + if plan_samples is not None: + pos_raster = agent_to_raster_np(plan_samples["positions"][0, :, -1], trans_mat) + for pos in pos_raster: + circle = np.hstack([pos - plan_marker_size, pos + plan_marker_size]) + draw.ellipse(circle.tolist(), fill="#FF6B35") + + im = np.asarray(im) + # visualize plan heat map + if pred_plan_info is not None and "location_map" in pred_plan_info: + import matplotlib.pyplot as plt + + cm = plt.get_cmap("jet") + heatmap = pred_plan_info["location_map"][0] + heatmap = heatmap - heatmap.min() + heatmap = heatmap / heatmap.max() + heatmap = cm(heatmap) + + heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)) + heatmap = heatmap.resize(size=(im.shape[1], im.shape[0])) + heatmap = np.asarray(heatmap)[..., :3] + padding = np.ones((im.shape[0], 200, 3), dtype=np.uint8) * 255 + + composite = heatmap.astype(np.float32) * 0.3 + im.astype(np.float32) * 0.7 + composite = composite.astype(np.uint8) + im = np.concatenate((im, padding, heatmap, padding, composite), axis=1) + + return im + + +def draw_agent_boxes( + image, pos, yaw, extent, raster_from_agent, outline_color, fill_color +): + boxes = get_box_world_coords_np(pos, yaw, extent) + boxes_raster = transform_points(boxes, raster_from_agent) + boxes_raster = boxes_raster.reshape((-1, 4, 2)).astype(np.int) + + im = Image.fromarray((image * 255).astype(np.uint8)) + im_draw = ImageDraw.Draw(im) + for b in boxes_raster: + im_draw.polygon( + xy=b.reshape(-1).tolist(), outline=outline_color, fill=fill_color + ) + + im = np.asarray(im).astype(np.float32) / 255.0 + return im + + +def render_state_trajdata( + batch: dict, + batch_idx: int, + action, +) -> np.ndarray: + pos = batch["hist_pos"][batch_idx, -1] + yaw = batch["hist_yaw"][batch_idx, -1] + extent = batch["extent"][batch_idx, :2] + + image = RasterizedMap.to_img( + TensorUtils.to_tensor(batch["maps"][batch_idx]), + [[0, 1, 2], [3, 4], [5, 6]], + ) + + image = draw_agent_boxes( + image, + pos=pos[None, :], + yaw=yaw[None, :], + extent=extent[None, :], + raster_from_agent=batch["raster_from_agent"][batch_idx], + outline_color=COLORS["ego_contour"], + fill_color=COLORS["ego_fill"], + ) + + scene_index = batch["scene_index"][batch_idx] + agent_scene_index = scene_index == batch["scene_index"] + agent_scene_index[batch_idx] = 0 # don't plot ego + + neigh_pos = batch["centroid"][agent_scene_index] + neigh_yaw = batch["world_yaw"][agent_scene_index] + neigh_extent = batch["extent"][agent_scene_index, :2] + + if neigh_pos.shape[0] > 0: + image = draw_agent_boxes( + image, + pos=neigh_pos, + yaw=neigh_yaw[:, None], + extent=neigh_extent, + raster_from_agent=batch["raster_from_world"][batch_idx], + outline_color=COLORS["agent_contour"], + fill_color=COLORS["agent_fill"], + ) + + plan_info = None + plan_samples = None + action_samples = None + if "plan_info" in action.agents_info: + plan_info = TensorUtils.map_ndarray( + action.agents_info["plan_info"], lambda x: x[[batch_idx]] + ) + if "plan_samples" in action.agents_info: + plan_samples = TensorUtils.map_ndarray( + action.agents_info["plan_samples"], lambda x: x[[batch_idx]] + ) + if "action_samples" in action.agents_info: + action_samples = TensorUtils.map_ndarray( + 
action.agents_info["action_samples"], lambda x: x[[batch_idx]] + ) + + vis_action = TensorUtils.map_ndarray( + action.agents.to_dict(), lambda x: x[batch_idx] + ) + image = draw_actions( + image, + trans_mat=batch["raster_from_agent"][batch_idx], + pred_action=vis_action, + pred_plan_info=plan_info, + ego_action_samples=action_samples, + plan_samples=plan_samples, + action_marker_size=2, + plan_marker_size=3, + ) + return image + + +def get_state_image_with_boxes_l5kit(ego_obs, agents_obs, rasterizer): + yaw = ego_obs["world_yaw"] # set to 0 to fix the video + state_im = rasterizer.rasterize(ego_obs["centroid"], yaw) + + raster_from_world = rasterizer.render_context.raster_from_world( + ego_obs["centroid"], yaw + ) + raster_from_agent = raster_from_world @ ego_obs["world_from_agent"] + + state_im = draw_agent_boxes( + state_im, + agents_obs["centroid"], + agents_obs["world_yaw"][:, None], + agents_obs["extent"][:, :2], + raster_from_world, + outline_color=COLORS["agent_contour"], + fill_color=COLORS["agent_fill"], + ) + + state_im = draw_agent_boxes( + state_im, + ego_obs["centroid"][None], + ego_obs["world_yaw"][None, None], + ego_obs["extent"][None, :2], + raster_from_world, + outline_color=COLORS["ego_contour"], + fill_color=COLORS["ego_fill"], + ) + + return state_im, raster_from_agent, raster_from_world + + +def get_agent_edge(xy, h, extent): + edges = ( + np.array([[0.5, 0.5], [0.5, -0.5], [-0.5, -0.5], [-0.5, 0.5]]) + * extent[np.newaxis, :2] + ) + rotM = np.array([[np.cos(h), -np.sin(h)], [np.sin(h), np.cos(h)]]) + edges = (rotM @ edges[..., np.newaxis]).squeeze(-1) + xy[np.newaxis, :] + return edges + + +def plot_scene_open_loop( + fig: figure, + traj: np.ndarray, + extent: np.ndarray, + vec_map: VectorMap, + map_from_world_tf: np.ndarray, + bbox: Optional[Tuple[float, float, float, float]] = None, + color_scheme="blue_red", + mask=None, +): + Na = traj.shape[0] + static_glyphs = draw_map_elems(fig, vec_map, map_from_world_tf, bbox) + agent_edge = np.stack( + [ + get_agent_edge(traj[i, 0, :2], traj[i, 0, 2], extent[i, :2]) + for i in range(Na) + ], + 0, + ) + agent_edge = batch_nd_transform_points_np(agent_edge, map_from_world_tf[None]) + agent_patches = defaultdict(lambda: None) + traj_lines = defaultdict(lambda: None) + traj_xy = batch_nd_transform_points_np(traj[..., :2], map_from_world_tf[None]) + + if color_scheme == "blue_red": + agent_color = ["red"] + ["blue"] * (Na - 1) + elif color_scheme == "palette": + palette = bokeh.palettes.Category20[20] + agent_color = ["blueviolet"] + [palette[i % 20] for i in range(Na - 1)] + for i in range(Na): + if mask is None or mask[i]: + agent_patches[i] = fig.patch( + x=agent_edge[i, :, 0], + y=agent_edge[i, :, 1], + color=agent_color[i], + ) + traj_lines[i] = fig.line( + x=traj_xy[i, :, 0], + y=traj_xy[i, :, 1], + color=agent_color[i], + line_width=2, + ) + return static_glyphs, agent_patches, traj_lines + + +def delete_files_in_directory(directory_path): + try: + files = os.listdir(directory_path) + for file in files: + file_path = os.path.join(directory_path, file) + if os.path.isfile(file_path): + os.remove(file_path) + except OSError: + print("Error occurred while deleting files.") + + +def make_gif(frame_folder, gif_name, duration=100, loop=0): + frames = [ + Image.open(image) + for image in sorted(glob.glob(f"{frame_folder}/*.png"), key=os.path.getmtime) + ] + frame_one = frames[0] + frame_one.save( + gif_name, + format="GIF", + append_images=frames, + save_all=True, + duration=duration, + loop=loop, + ) + + +def 
animate_scene_open_loop( + fig: figure, + traj: np.ndarray, + extent: np.ndarray, + vec_map: VectorMap, + map_from_world_tf: np.ndarray, + # rel_bbox: Optional[Tuple[float, float, float, float]] = None, + bbox: Optional[Tuple[float, float, float, float]] = None, + color_scheme="blue_red", + mask=None, + dt=0.1, + tmp_dir="tmp", + gif_name="diffstack_anim.gif", +): + Na, T = traj.shape[:2] + # traj_xy = batch_nd_transform_points_np(traj[..., :2], map_from_world_tf[None]) + # bbox = ( + # [ + # rel_bbox[0] + traj_xy[0, 0, 0], + # rel_bbox[1] + traj_xy[0, 0, 0], + # rel_bbox[2] + traj_xy[0, 0, 1], + # rel_bbox[3] + traj_xy[0, 0, 1], + # ] + # if rel_bbox is not None + # else None + # ) + static_glyphs = draw_map_elems(fig, vec_map, map_from_world_tf, bbox) + agent_edge = np.stack( + [ + get_agent_edge(traj[i, 0, :2], traj[i, 0, 2], extent[i, :2]) + for i in range(Na) + ], + 0, + ) + agent_edge = batch_nd_transform_points_np(agent_edge, map_from_world_tf[None]) + agent_patches = defaultdict(lambda: None) + traj_lines = defaultdict(lambda: None) + + agent_xy_source = defaultdict(lambda: None) + if color_scheme == "blue_red": + agent_color = ["red"] + ["blue"] * (Na - 1) + elif color_scheme == "palette": + palette = bokeh.palettes.Category20[20] + agent_color = ["blueviolet"] + [palette[i % 20] for i in range(Na - 1)] + for i in range(Na): + if mask is None or mask[i]: + agent_xy_source[i] = ColumnDataSource( + data=dict(x=agent_edge[i, :, 0], y=agent_edge[i, :, 1]) + ) + + agent_patches[i] = fig.patch( + x="x", + y="y", + color=agent_color[i], + source=agent_xy_source[i], + name=f"patch_{i}", + ) + if os.path.exists(tmp_dir): + if os.path.isfile(tmp_dir): + os.remove(tmp_dir) + else: + os.mkdir(tmp_dir) + delete_files_in_directory(tmp_dir) + + for t in range(T): + agent_edge = np.stack( + [ + get_agent_edge(traj[i, t, :2], traj[i, t, 2], extent[i, :2]) + for i in range(Na) + ], + 0, + ) + agent_edge = batch_nd_transform_points_np(agent_edge, map_from_world_tf[None]) + for i in range(Na): + if mask is None or mask[i]: + new_source_data = dict(x=agent_edge[i, :, 0], y=agent_edge[i, :, 1]) + patch = fig.select_one({"name": f"patch_{i}"}) + patch.data_source.data = new_source_data + export_png(fig, filename=tmp_dir + "/plot_" + str(t) + ".png") + + make_gif(tmp_dir, gif_name, duration=dt * 1000) + delete_files_in_directory(tmp_dir) + return static_glyphs, agent_patches, traj_lines diff --git a/diffstack/utils/visualization.py b/diffstack/utils/visualization.py index bdbb8f2..2946890 100644 --- a/diffstack/utils/visualization.py +++ b/diffstack/utils/visualization.py @@ -1,4 +1,3 @@ -from enum import unique import numpy as np import torch @@ -12,8 +11,8 @@ from trajdata.data_structures.batch import AgentBatch, SceneBatch from trajdata.maps import RasterizedMap -from diffstack.utils.utils import subsample_traj -from diffstack.utils.pred_utils import compute_ade_pt +from diffstack.utils.utils import subsample_future +from diffstack.utils.metrics import compute_ade_pt import matplotlib.pyplot as plt from matplotlib.figure import Figure @@ -190,18 +189,19 @@ def plot_plan_input_batch( def plot_plan_result( plan_x: np.ndarray, + control_x: np.ndarray = None, ax: Optional[Axes] = None, ) -> None: if ax is None: _, ax = plt.subplots() ax.plot(plan_x[..., 0], plan_x[..., 1], c="blue", ls="-", label="Plan") - + if control_x is not None: + ax.plot(control_x[..., 0], control_x[..., 1], c="purple", ls="-", label="Control") return ax def plot_plan_candidates( plan_candidates_x: np.ndarray, - batch_idx: int, ax: 
Optional[Axes] = None, ) -> None: if ax is None: @@ -563,7 +563,7 @@ def animate(i): label_hcost_xy = traj_xy[plan_iters_i['label_hcost']] plan_gt_xy = plan_iters_i['x_gt'][:, :2] plan_gt_xy = np.concatenate([fan_plan_xy[:1, :2], plan_gt_xy], axis=0) # append t0 - fan_mse = np.square(subsample_traj(fan_plan_xy, ph, planh)[1:] - plan_gt_xy[1:, :2]).sum(axis=-1).mean() + fan_mse = np.square(subsample_future(fan_plan_xy, ph, planh)[1:] - plan_gt_xy[1:, :2]).sum(axis=-1).mean() mpc_plan_cost = plan_iters['cost'][-1][batch_i].cpu().numpy() / plan_gt_xy.shape[0] # from sum over time to mean over time # Filter @@ -683,7 +683,7 @@ def animate(i): label_hcost_xy = traj_xy[plan_iters_i['label_hcost']] plan_gt_xy = plan_iters_i['x_gt'][:, :2] plan_gt_xy = np.concatenate([fan_plan_xy[:1, :2], plan_gt_xy], axis=0) # append t0 - fan_mse = np.square(subsample_traj(fan_plan_xy, ph, planh)[1:] - plan_gt_xy[1:, :2]).sum(axis=-1).mean() + fan_mse = np.square(subsample_future(fan_plan_xy, ph, planh)[1:] - plan_gt_xy[1:, :2]).sum(axis=-1).mean() mpc_plan_cost = plan_iters['cost'][-1][batch_i].cpu().numpy() / plan_gt_xy.shape[0] # from sum over time to mean over time # Filter diff --git a/diffstack_modules.png b/diffstack_modules.png deleted file mode 100644 index 3c99579..0000000 Binary files a/diffstack_modules.png and /dev/null differ diff --git a/patches/trajdata_vectorize.patch b/patches/trajdata_vectorize.patch new file mode 100644 index 0000000..529585b --- /dev/null +++ b/patches/trajdata_vectorize.patch @@ -0,0 +1,7097 @@ +diff --git a/.gitignore b/.gitignore +index 1e55dd9..1ad9436 100644 +--- a/.gitignore ++++ b/.gitignore +@@ -157,3 +157,4 @@ cython_debug/ + # option (not recommended) you can uncomment the following to ignore the entire idea folder. + #.idea/ + ++tests/Drivesim_scene_generation.html +diff --git a/CITATION.cff b/CITATION.cff +index e793eb6..dbf28a0 100644 +--- a/CITATION.cff ++++ b/CITATION.cff +@@ -5,23 +5,7 @@ authors: + given-names: "Boris" + orcid: "https://orcid.org/0000-0002-8698-202X" + title: "trajdata: A unified interface to many trajectory forecasting datasets" +-version: 1.3.3 ++version: 1.0.3 + doi: 10.5281/zenodo.6671548 +-date-released: 2023-08-22 +-url: "https://github.com/nvr-avg/trajdata" +-preferred-citation: +- type: conference-paper +- authors: +- - family-names: "Ivanovic" +- given-names: "Boris" +- orcid: "https://orcid.org/0000-0002-8698-202X" +- - family-names: "Song" +- given-names: "Guanyu" +- - family-names: "Gilitschenski" +- given-names: "Igor" +- - family-names: "Pavone" +- given-names: "Marco" +- journal: "Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks" +- month: 12 +- title: "trajdata: A Unified Interface to Multiple Human Trajectory Datasets" +- year: 2023 +\ No newline at end of file ++date-released: 2022-06-20 ++url: "https://github.com/nvr-avg/trajdata" +\ No newline at end of file +diff --git a/README.md b/README.md +index 774d6d6..a715ebc 100644 +--- a/README.md ++++ b/README.md +@@ -1,4 +1,4 @@ +-# trajdata: A Unified Interface to Multiple Human Trajectory Datasets ++# Unified Trajectory Data Loader + + [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) +@@ -6,10 +6,6 @@ + [![DOI](https://zenodo.org/badge/488789438.svg)](https://zenodo.org/badge/latestdoi/488789438) + [![PyPI 
version](https://badge.fury.io/py/trajdata.svg)](https://badge.fury.io/py/trajdata) + +-### Announcements +- +-**Sept 2023**: [Our paper about trajdata](https://arxiv.org/abs/2307.13924) has been accepted to the NeurIPS 2023 Datasets and Benchmarks Track! +- + ## Installation + + The easiest way to install trajdata is through PyPI with +@@ -211,20 +207,6 @@ for t in range(1, sim_scene.scene.length_timesteps): + + `examples/sim_example.py` contains a more comprehensive example which initializes a simulation from a scene in the nuScenes mini dataset, steps through it by replaying agents' GT motions, and computes metrics based on scene statistics (e.g., displacement error from the original GT data, velocity/acceleration/jerk histograms). + +-## Citation +- +-If you use this software, please cite it as follows: +-``` +-@Inproceedings{ivanovic2023trajdata, +- author = {Ivanovic, Boris and Song, Guanyu and Gilitschenski, Igor and Pavone, Marco}, +- title = {{trajdata}: A Unified Interface to Multiple Human Trajectory Datasets}, +- booktitle = {{Proceedings of the Neural Information Processing Systems (NeurIPS) Track on Datasets and Benchmarks}}, +- month = dec, +- year = {2023}, +- address = {New Orleans, USA}, +- url = {https://arxiv.org/abs/2307.13924} +-} +-``` +- + ## TODO + - Create a method like finalize() which writes all the batch information to a TFRecord/WebDataset/some other format which is (very) fast to read from for higher epoch training. ++- Add more examples to the README. +diff --git a/copy_to_public.sh b/copy_to_public.sh +new file mode 100755 +index 0000000..dba3133 +--- /dev/null ++++ b/copy_to_public.sh +@@ -0,0 +1 @@ ++rsync -ravh --progress --exclude ".gitlab*" --exclude "public/" --exclude "opendrive" --exclude "copy_to_public.sh" --exclude "*.pyc" --exclude "ngc/" --exclude "*.egg-info/" --exclude "__pycache__/" ./* public/ +\ No newline at end of file +diff --git a/examples/map_api_example.py b/examples/map_api_example.py +index 759b2e0..75077ce 100644 +--- a/examples/map_api_example.py ++++ b/examples/map_api_example.py +@@ -4,6 +4,7 @@ from typing import Dict, List, Optional + + import matplotlib.pyplot as plt + import numpy as np ++import os + + from trajdata import MapAPI, VectorMap + from trajdata.caching.df_cache import DataFrameCache +@@ -23,7 +24,7 @@ def load_random_scene(cache_path: Path, env_name: str, scene_dt: float) -> Scene + + + def main(): +- cache_path = Path("~/.unified_data_cache").expanduser() ++ cache_path = Path(os.environ["TRAJDATA_CACHE_DIR"]).expanduser() + map_api = MapAPI(cache_path) + + ### Loading random scene and initializing VectorMap. +diff --git a/ngc/Dockerfile b/ngc/Dockerfile +new file mode 100644 +index 0000000..c5841bc +--- /dev/null ++++ b/ngc/Dockerfile +@@ -0,0 +1,27 @@ ++FROM nvcr.io/nvidia/pytorch:21.03-py3 ++ARG PYTHON_VERSION=3.8 ++ ++RUN apt-get update ++RUN apt-get install htop -y ++RUN apt-get install screen -y ++RUN apt-get install psmisc -y ++ ++RUN pip install --upgrade pip ++ ++RUN pip install \ ++numpy==1.19 \ ++tqdm==4.62 \ ++matplotlib==3.5 \ ++dill==0.3.4 \ ++pandas==1.4.1 \ ++pyarrow==7.0.0 \ ++nuscenes-devkit==1.1.9 \ ++l5kit==1.5.0 \ ++black==22.1.0 \ ++isort==5.10.1 \ ++pytest==7.1.1 \ ++pytest-xdist==2.5.0 \ ++zarr==2.11.0 \ ++kornia==0.6.4 ++ ++WORKDIR /workspace/trajdata +diff --git a/ngc/copy_to_ngc.sh b/ngc/copy_to_ngc.sh +new file mode 100755 +index 0000000..c42315f +--- /dev/null ++++ b/ngc/copy_to_ngc.sh +@@ -0,0 +1,4 @@ ++#!/bin/bash ++ ++cd ../.. 
++rsync -ravh --progress --exclude="*.pyc" --exclude=".git" --exclude="__pycache__" --exclude="*.egg-info" --exclude=".pytest_cache" trajdata/ ngc-trajdata/ +diff --git a/ngc/jupyter_lab.json b/ngc/jupyter_lab.json +new file mode 100644 +index 0000000..7d43c0c +--- /dev/null ++++ b/ngc/jupyter_lab.json +@@ -0,0 +1,49 @@ ++{ ++ "userLabels":[ ++ ++ ], ++ "aceId":257, ++ "aceInstance":"dgx1v.16g.8.norm", ++ "dockerImageName":"nvidian/nvr-av/trajdata:1.0", ++ "aceName":"nv-us-west-2", ++ "systemLabels":[ ++ ++ ], ++ "datasetMounts":[ ++ { ++ "containerMountPoint":"/workspace/lyft", ++ "id":90893 ++ }, ++ { ++ "containerMountPoint":"/workspace/nuScenes", ++ "id":78251 ++ } ++ ], ++ "workspaceMounts":[ ++ { ++ "containerMountPoint":"/workspace/trajdata", ++ "id":"aWUsWZ_5SA6caFr30uWmmw", ++ "mountMode":"RW" ++ }, ++ { ++ "containerMountPoint":"/workspace/trajdata_cache", ++ "id":"BgHSNdSaR0ywsc-UY-xm3w", ++ "mountMode":"RW" ++ } ++ ], ++ "replicaCount":1, ++ "publishedContainerPorts":[ ++ 8888 ++ ], ++ "reservedLabels":[ ++ ++ ], ++ "name":"ml-model.notamodel-trajdata-processing", ++ "command":"pip install -e .; jupyter lab --ip=0.0.0.0 --allow-root --no-browser --NotebookApp.token='' --notebook-dir=/ --NotebookApp.allow_origin='*' & date; nvidia-smi; sleep 1d", ++ "runPolicy":{ ++ "minTimesliceSeconds":1, ++ "totalRuntimeSeconds":86400, ++ "preemptClass":"RUNONCE" ++ }, ++ "resultContainerMountPoint":"/results" ++} +diff --git a/ngc/ngc_test.py b/ngc/ngc_test.py +new file mode 100644 +index 0000000..c63312e +--- /dev/null ++++ b/ngc/ngc_test.py +@@ -0,0 +1,57 @@ ++import os ++from collections import defaultdict ++ ++from torch.utils.data import DataLoader ++from tqdm import tqdm ++ ++from trajdata import AgentBatch, AgentType, UnifiedDataset ++from trajdata.augmentation import NoiseHistories ++from trajdata.visualization.vis import plot_agent_batch ++ ++ ++# @profile ++def main(): ++ noise_hists = NoiseHistories() ++ ++ dataset = UnifiedDataset( ++ desired_data=["nusc"], ++ centric="agent", ++ # desired_dt=0.1, ++ history_sec=(0.1, 1.5), ++ future_sec=(0.1, 5.0), ++ only_types=[AgentType.VEHICLE, AgentType.PEDESTRIAN], ++ agent_interaction_distances=defaultdict(lambda: 40.0), ++ incl_robot_future=True, ++ incl_raster_map=True, ++ raster_map_params={ ++ "px_per_m": 2, ++ "map_size_px": 224, ++ "offset_frac_xy": (-0.5, 0.0), ++ }, ++ # augmentations=[noise_hists], ++ data_dirs={ ++ "nusc": "/workspace/datasets/nuScenes", ++ }, ++ cache_location="/workspace/unified_data_cache", ++ num_workers=os.cpu_count(), ++ # verbose=True, ++ ) ++ ++ print(f"# Data Samples: {len(dataset):,}") ++ ++ dataloader = DataLoader( ++ dataset, ++ batch_size=64, ++ shuffle=True, ++ collate_fn=dataset.get_collate_fn(), ++ num_workers=os.cpu_count(), ++ ) ++ ++ batch: AgentBatch ++ for batch in tqdm(dataloader): ++ pass ++ # plot_agent_batch(batch, batch_idx=0) ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/ngc/preprocess_data_ngc.py b/ngc/preprocess_data_ngc.py +new file mode 100644 +index 0000000..920f439 +--- /dev/null ++++ b/ngc/preprocess_data_ngc.py +@@ -0,0 +1,33 @@ ++from trajdata import UnifiedDataset ++ ++ ++def main(): ++ dataset = UnifiedDataset( ++ desired_data=[ ++ "nusc_trainval", ++ "nusc_mini", ++ "lyft_sample", ++ "lyft_train", ++ # "lyft_train_full", ++ "lyft_val", ++ ], ++ data_dirs={ ++ "nusc_trainval": "/workspace/nuScenes", ++ "nusc_mini": "/workspace/nuScenes", ++ "lyft_sample": "/workspace/lyft/lyft_prediction/scenes/sample.zarr", ++ "lyft_train": 
"/workspace/lyft/lyft_prediction/scenes/train.zarr", ++ # "lyft_train_full": "/workspace/lyft/lyft_prediction/scenes/train_full.zarr", ++ "lyft_val": "/workspace/lyft/lyft_prediction/scenes/validate.zarr", ++ }, ++ cache_location="/workspace/trajdata_cache", ++ rebuild_cache=True, ++ rebuild_maps=True, ++ num_workers=64, ++ verbose=True, ++ ) ++ ++ print(f"Total Data Samples: {len(dataset):,}") ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/requirements.txt b/requirements.txt +new file mode 100644 +index 0000000..97d46f9 +--- /dev/null ++++ b/requirements.txt +@@ -0,0 +1,26 @@ ++numpy ++tqdm ++matplotlib==3.3.4 ++dill ++pandas ++seaborn ++pyarrow ++torch ++zarr ++kornia ++intervaltree ++ ++# nuScenes devkit ++nuscenes-devkit==1.1.9 ++ ++# Lyft Level 5 devkit ++protobuf==3.19.4 ++l5kit==1.5.0 ++ ++# Development ++black ++isort ++pytest ++pytest-xdist ++twine ++build +diff --git a/setup.cfg b/setup.cfg +new file mode 100644 +index 0000000..01016fc +--- /dev/null ++++ b/setup.cfg +@@ -0,0 +1,50 @@ ++[metadata] ++name = trajdata ++version = 1.1.0 ++author = Boris Ivanovic ++author_email = bivanovic@nvidia.com ++description = A unified interface to many trajectory forecasting datasets. ++long_description = file: README.md ++long_description_content_type = text/markdown ++# license = Apache License 2.0 ++url = https://github.com/nvr-avg/trajdata ++classifiers = ++ Development Status :: 3 - Alpha ++ Intended Audience :: Developers ++ Programming Language :: Python :: 3.8 ++ License :: OSI Approved :: Apache Software License ++ ++[options] ++package_dir = ++ = src ++packages = find: ++python_requires = >=3.8 ++install_requires = ++ numpy>=1.19 ++ tqdm>=4.62 ++ matplotlib>=3.5 ++ dill>=0.3.4 ++ pandas>=1.4.1 ++ pyarrow>=7.0.0 ++ torch>=1.10.2 ++ zarr>=2.11.0 ++ kornia>=0.6.4 ++ seaborn>=0.12 ++ intervaltree ++ ++[options.packages.find] ++where = src ++ ++[options.extras_require] ++dev = ++ black ++ isort ++ pytest ++ pytest-xdist ++ twine ++ build ++nusc = ++ nuscenes-devkit==1.1.9 ++lyft = ++ protobuf==3.20.3 ++ l5kit==1.5.0 +diff --git a/src/trajdata/caching/df_cache.py b/src/trajdata/caching/df_cache.py +index eb0efef..654977c 100644 +--- a/src/trajdata/caching/df_cache.py ++++ b/src/trajdata/caching/df_cache.py +@@ -30,6 +30,7 @@ from trajdata.data_structures.scene_metadata import Scene + from trajdata.data_structures.state import NP_STATE_TYPES, StateArray + from trajdata.maps.traffic_light_status import TrafficLightStatus + from trajdata.utils import arr_utils, df_utils, raster_utils, state_utils ++from trajdata.utils.scene_utils import is_integer_robust + + STATE_COLS: Final[List[str]] = ["x", "y", "z", "vx", "vy", "ax", "ay"] + EXTENT_COLS: Final[List[str]] = ["length", "width", "height"] +@@ -320,7 +321,7 @@ class DataFrameCache(SceneCache): + def _upsample_data( + self, new_index: pd.MultiIndex, upsample_dt_ratio: float, method: str + ) -> pd.DataFrame: +- upsample_dt_factor: int = int(upsample_dt_ratio) ++ upsample_dt_factor: int = int(round(upsample_dt_ratio)) + + interpolated_df: pd.DataFrame = pd.DataFrame( + index=new_index, columns=self.scene_data_df.columns +@@ -353,7 +354,7 @@ class DataFrameCache(SceneCache): + def _downsample_data( + self, new_index: pd.MultiIndex, downsample_dt_ratio: float + ) -> pd.DataFrame: +- downsample_dt_factor: int = int(downsample_dt_ratio) ++ downsample_dt_factor: int = int(round(downsample_dt_ratio)) + + subsample_index: pd.MultiIndex = new_index.set_levels( + new_index.levels[1] * downsample_dt_factor, level=1 +@@ -368,7 +369,8 @@ class 
DataFrameCache(SceneCache): + def interpolate_data(self, desired_dt: float, method: str = "linear") -> None: + upsample_dt_ratio: float = self.scene.env_metadata.dt / desired_dt + downsample_dt_ratio: float = desired_dt / self.scene.env_metadata.dt +- if not upsample_dt_ratio.is_integer() and not downsample_dt_ratio.is_integer(): ++ ++ if not is_integer_robust(upsample_dt_ratio) and not is_integer_robust(downsample_dt_ratio): + raise ValueError( + f"{str(self.scene)}'s dt of {self.scene.dt}s " + f"is not integer divisible by the desired dt {desired_dt}s." +diff --git a/src/trajdata/data_structures/agent.py b/src/trajdata/data_structures/agent.py +index 7d21217..0687d8e 100644 +--- a/src/trajdata/data_structures/agent.py ++++ b/src/trajdata/data_structures/agent.py +@@ -11,6 +11,7 @@ class AgentType(IntEnum): + PEDESTRIAN = 2 + BICYCLE = 3 + MOTORCYCLE = 4 ++ STATIC = 5 + + + class Extent: +diff --git a/src/trajdata/data_structures/batch.py b/src/trajdata/data_structures/batch.py +index c93fb8b..1903557 100644 +--- a/src/trajdata/data_structures/batch.py ++++ b/src/trajdata/data_structures/batch.py +@@ -1,6 +1,6 @@ + from __future__ import annotations + +-from dataclasses import dataclass ++from dataclasses import dataclass, replace + from typing import Dict, List, Optional, Union + + import torch +@@ -9,7 +9,7 @@ from torch import Tensor + from trajdata.data_structures.agent import AgentType + from trajdata.data_structures.state import StateTensor + from trajdata.maps import VectorMap +-from trajdata.utils.arr_utils import PadDirection ++from trajdata.utils.arr_utils import PadDirection, batch_nd_transform_xyvvaahh_pt, roll_with_tensor, transform_xyh_torch + + + @dataclass +@@ -39,6 +39,10 @@ class AgentBatch: + map_names: Optional[List[str]] + maps: Optional[Tensor] + maps_resolution: Optional[Tensor] ++ lane_xyh: Optional[Tensor] ++ lane_adj: Optional[Tensor] ++ lane_ids: Optional[List[List[str]]] ++ lane_mask: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] + rasters_from_world_tf: Optional[Tensor] + agents_from_world_tf: Tensor +@@ -141,12 +145,16 @@ class AgentBatch: + maps_resolution=_filter(self.maps_resolution) + if self.maps_resolution is not None + else None, +- vector_maps=_filter(self.vector_maps) ++ vector_maps=_filter_tensor_or_list(self.vector_maps) + if self.vector_maps is not None + else None, + rasters_from_world_tf=_filter(self.rasters_from_world_tf) + if self.rasters_from_world_tf is not None + else None, ++ lane_xyh=_filter(self.lane_xyh) if self.lane_xyh is not None else None, ++ lane_adj=_filter(self.lane_adj) if self.lane_adj is not None else None, ++ lane_ids=self.lane_ids, ++ lane_mask=_filter(self.lane_mask) if self.lane_mask is not None else None, + agents_from_world_tf=_filter(self.agents_from_world_tf), + scene_ids=_filter_tensor_or_list(self.scene_ids), + history_pad_dir=self.history_pad_dir, +@@ -155,6 +163,48 @@ class AgentBatch: + }, + ) + ++ def to_scene_batch(self, agent_ind: int) -> SceneBatch: ++ """ ++ Converts AgentBatch to SeceneBatch by combining neighbors and agent. ++ ++ The agent of AgentBatch will be treated as if it was the last neighbor. ++ self.extras will be simply copied over, any custom conversion must be ++ implemented externally. 
++ """ ++ ++ batch_size = self.neigh_hist.shape[0] ++ num_neigh = self.neigh_hist.shape[1] ++ ++ combine = lambda neigh, agent: torch.cat((neigh, agent.unsqueeze(0)), dim=0) ++ combine_list = lambda neigh, agent: neigh + [agent] ++ ++ return SceneBatch( ++ data_idx=self.data_idx, ++ scene_ts=self.scene_ts, ++ dt=self.dt, ++ num_agents=self.num_neigh + 1, ++ agent_type=combine(self.neigh_types, self.agent_type), ++ centered_agent_state=self.curr_agent_state, # TODO this is not actually the agent but the `global` coordinate frame ++ agent_names=combine_list(["UNKNOWN" for _ in range(num_neigh)], self.agent_name), ++ agent_hist=combine(self.neigh_hist, self.agent_hist), ++ agent_hist_extent=combine(self.neigh_hist_extents, self.agent_hist_extent), ++ agent_hist_len=combine(self.neigh_hist_len, self.agent_hist_len), ++ agent_fut=combine(self.neigh_fut, self.agent_fut), ++ agent_fut_extent=combine(self.neigh_fut_extents, self.agent_fut_extent), ++ agent_fut_len=combine(self.neigh_fut_len, self.agent_fut_len), ++ robot_fut=self.robot_fut, ++ robot_fut_len=self.robot_fut_len, ++ map_names=self.map_names, # TODO ++ maps=self.maps, ++ maps_resolution=self.maps_resolution, ++ vector_maps=self.vector_maps, ++ rasters_from_world_tf=self.rasters_from_world_tf, ++ centered_agent_from_world_tf=self.agents_from_world_tf, ++ centered_world_from_agent_tf=torch.linalg.inv(self.agents_from_world_tf), ++ scene_ids=self.scene_ids, ++ history_pad_dir=self.history_pad_dir, ++ extras=self.extras, ++ ) + + @dataclass + class SceneBatch: +@@ -164,7 +214,8 @@ class SceneBatch: + num_agents: Tensor + agent_type: Tensor + centered_agent_state: StateTensor +- agent_names: List[str] ++ agent_names: List[List[str]] ++ track_ids:Optional[List[List[str]]] + agent_hist: StateTensor + agent_hist_extent: Tensor + agent_hist_len: Tensor +@@ -177,6 +228,10 @@ class SceneBatch: + maps: Optional[Tensor] + maps_resolution: Optional[Tensor] + vector_maps: Optional[List[VectorMap]] ++ lane_xyh: Optional[Tensor] ++ lane_adj: Optional[Tensor] ++ lane_ids: Optional(List[List[str]]) ++ lane_mask: Optional[Tensor] + rasters_from_world_tf: Optional[Tensor] + centered_agent_from_world_tf: Tensor + centered_world_from_agent_tf: Tensor +@@ -186,12 +241,19 @@ class SceneBatch: + + def to(self, device) -> None: + excl_vals = { ++ "num_agents", + "agent_names", ++ "track_ids", ++ "agent_type", ++ "agent_hist_len", ++ "agent_fut_len", ++ "robot_fut_len", + "map_names", + "vector_maps", + "history_pad_dir", + "scene_ids", + "extras", ++ "lane_ids", + } + + for val in vars(self).keys(): +@@ -205,6 +267,31 @@ class SceneBatch: + self.extras[key] = val.__to__(device, non_blocking=True) + else: + self.extras[key] = val.to(device, non_blocking=True) ++ return self ++ ++ def astype(self, dtype) -> None: ++ new_obj = replace(self) ++ excl_vals = { ++ "num_agents", ++ "agent_names", ++ "track_ids", ++ "agent_type", ++ "agent_hist_len", ++ "agent_fut_len", ++ "robot_fut_len", ++ "map_names", ++ "vector_maps", ++ "history_pad_dir", ++ "scene_ids", ++ "extras", ++ "lane_ids", ++ } ++ ++ for val in vars(self).keys(): ++ tensor_val = getattr(self, val) ++ if val not in excl_vals and tensor_val is not None: ++ setattr(new_obj, val, tensor_val.type(dtype)) ++ return new_obj + + def agent_types(self) -> List[AgentType]: + unique_types: Tensor = torch.unique(self.agent_type) +@@ -214,13 +301,31 @@ class SceneBatch: + if unique_type >= 0 + ] + +- def for_agent_type(self, agent_type: AgentType) -> SceneBatch: +- match_type = self.agent_type == agent_type +- return 
self.filter_batch(match_type) ++ def copy(self): ++ # Shallow copy ++ return replace(self) ++ ++ def convert_pad_direction(self, pad_dir: PadDirection) -> SceneBatch: ++ if self.history_pad_dir == pad_dir: ++ return self ++ batch: SceneBatch = self.copy() ++ if self.history_pad_dir == PadDirection.BEFORE: ++ # n, n, -2 , -1, 0 --> -2, -1, 0, n, n ++ shifts = batch.agent_hist_len ++ else: ++ # -2, -1, 0, n, n --> n, n, -2 , -1, 0 ++ shifts = -batch.agent_hist_len ++ batch.agent_hist = roll_with_tensor(batch.agent_hist, shifts, dim=-2) ++ batch.agent_hist_extent = roll_with_tensor(batch.agent_hist_extent, shifts, dim=-2) ++ batch.history_pad_dir = pad_dir ++ return batch + +- def filter_batch(self, filter_mask: torch.tensor) -> SceneBatch: ++ def filter_batch(self, filter_mask: torch.Tensor) -> SceneBatch: + """Build a new batch with elements for which filter_mask[i] == True.""" + ++ if filter_mask.ndim != 1: ++ raise ValueError("Expected 1d filter mask.") ++ + # Some of the tensors might be on different devices, so we define some convenience functions + # to make sure the filter_mask is always on the same device as the tensor we are indexing. + filter_mask_dict = {} +@@ -229,7 +334,8 @@ class SceneBatch: + self.agent_hist.device + ) + +- _filter = lambda tensor: tensor[filter_mask_dict[str(tensor.device)]] ++ # Use tensor.__class__ to keep TensorState. This might ++ _filter = lambda tensor: tensor.__class__(tensor[filter_mask_dict[str(tensor.device)]]) + _filter_tensor_or_list = lambda tensor_or_list: ( + _filter(tensor_or_list) + if isinstance(tensor_or_list, torch.Tensor) +@@ -248,6 +354,8 @@ class SceneBatch: + dt=_filter(self.dt), + num_agents=_filter(self.num_agents), + agent_type=_filter(self.agent_type), ++ agent_names=_filter_tensor_or_list(self.agent_names), ++ track_ids=_filter_tensor_or_list(self.track_ids), + centered_agent_state=_filter(self.centered_agent_state), + agent_hist=_filter(self.agent_hist), + agent_hist_extent=_filter(self.agent_hist_extent), +@@ -266,9 +374,13 @@ class SceneBatch: + maps_resolution=_filter(self.maps_resolution) + if self.maps_resolution is not None + else None, +- vector_maps=_filter(self.vector_maps) ++ vector_maps=_filter_tensor_or_list(self.vector_maps) + if self.vector_maps is not None + else None, ++ lane_xyh=_filter(self.lane_xyh) if self.lane_xyh is not None else None, ++ lane_adj=_filter(self.lane_adj) if self.lane_adj is not None else None, ++ lane_ids = self.lane_ids, ++ lane_mask = _filter(self.lane_mask) if self.lane_mask is not None else None, + rasters_from_world_tf=_filter(self.rasters_from_world_tf) + if self.rasters_from_world_tf is not None + else None, +@@ -277,7 +389,7 @@ class SceneBatch: + scene_ids=_filter_tensor_or_list(self.scene_ids), + history_pad_dir=self.history_pad_dir, + extras={ +- key: _filter_tensor_or_list(val, filter_mask) ++ key: _filter_tensor_or_list(val) + for key, val in self.extras.items() + }, + ) +@@ -321,31 +433,68 @@ class SceneBatch: + scene_ts=self.scene_ts, + dt=self.dt, + agent_name=index_agent_list(self.agent_names), ++ track_ids=index_agent_list(self.track_ids), + agent_type=index_agent(self.agent_type), + curr_agent_state=self.centered_agent_state, # TODO this is not actually the agent but the `global` coordinate frame +- agent_hist=index_agent(self.agent_hist), ++ agent_hist=StateTensor.from_array(index_agent(self.agent_hist), self.agent_hist._format), + agent_hist_extent=index_agent(self.agent_hist_extent), + agent_hist_len=index_agent(self.agent_hist_len), +- 
agent_fut=index_agent(self.agent_fut), ++ agent_fut=StateTensor.from_array(index_agent(self.agent_fut), self.agent_fut._format), + agent_fut_extent=index_agent(self.agent_fut_extent), + agent_fut_len=index_agent(self.agent_fut_len), + num_neigh=self.num_agents - 1, + neigh_types=index_neighbors(self.agent_type), +- neigh_hist=index_neighbors(self.agent_hist), ++ neigh_hist=StateTensor.from_array(index_neighbors(self.agent_hist), self.agent_hist._format), + neigh_hist_extents=index_neighbors(self.agent_hist_extent), + neigh_hist_len=index_neighbors(self.agent_hist_len), +- neigh_fut=index_neighbors(self.agent_fut), ++ neigh_fut=StateTensor.from_array(index_neighbors(self.agent_fut), self.agent_fut._format), + neigh_fut_extents=index_neighbors(self.agent_fut_extent), + neigh_fut_len=index_neighbors(self.agent_fut_len), + robot_fut=self.robot_fut, + robot_fut_len=self.robot_fut_len, +- map_names=index_agent_list(self.map_names), +- maps=index_agent(self.maps), +- vector_maps=index_agent(self.vector_maps), +- maps_resolution=index_agent(self.maps_resolution), +- rasters_from_world_tf=index_agent(self.rasters_from_world_tf), ++ map_names=self.map_names, ++ maps=self.maps, ++ vector_maps=self.vector_maps, ++ maps_resolution=self.maps_resolution, ++ rasters_from_world_tf=self.rasters_from_world_tf, + agents_from_world_tf=self.centered_agent_from_world_tf, + scene_ids=self.scene_ids, + history_pad_dir=self.history_pad_dir, + extras=self.extras, + ) ++ ++ def apply_transform(self, tf: torch.Tensor, dtype: Optional[torch.dtype] = None) -> SceneBatch: ++ """ ++ Applies a transformation matrix to all coordinates stored in the SceneBatch. ++ ++ Returns a shallow copy, only coordinate fields are replaced. ++ self.extras will be simply copied over (shallow copy), any custom conversion must be ++ implemented externally. 
++ """ ++ assert tf.ndim == 3 # b, 3, 3 ++ assert tf.shape[-1] == 3 and tf.shape[-1] == 3 ++ assert tf.dtype == torch.double # tf should be double precision, otherwise we have large numerical errors ++ if dtype is None: ++ dtype = self.agent_hist.dtype ++ ++ # Shallow copy ++ batch: SceneBatch = replace(self) ++ ++ # TODO (pkarkus) support generic format ++ assert batch.agent_hist._format == "x,y,xd,yd,xdd,ydd,s,c" ++ assert batch.agent_fut._format == "x,y,xd,yd,xdd,ydd,s,c" ++ state_class = batch.agent_hist.__class__ ++ ++ # Transforms ++ batch.agent_hist = state_class(batch_nd_transform_xyvvaahh_pt(batch.agent_hist.double(), tf).type(dtype)) ++ batch.agent_fut = state_class(batch_nd_transform_xyvvaahh_pt(batch.agent_fut.double(), tf).type(dtype)) ++ batch.rasters_from_world_tf = tf.unsqueeze(1) @ batch.rasters_from_world_tf if batch.rasters_from_world_tf is not None else None ++ batch.centered_agent_from_world_tf = tf @ batch.centered_agent_from_world_tf ++ centered_world_from_agent_tf = torch.linalg.inv(batch.centered_agent_from_world_tf) ++ if batch.lane_xyh is not None: ++ batch.lane_xyh = transform_xyh_torch(batch.lane_xyh.double(), tf).type(dtype) ++ # sanity check ++ assert torch.isclose(batch.centered_world_from_agent_tf @ torch.linalg.inv(tf), centered_world_from_agent_tf, atol=1e-5).all() ++ batch.centered_world_from_agent_tf = centered_world_from_agent_tf ++ ++ return batch +diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py +index f18e34d..012f3b8 100644 +--- a/src/trajdata/data_structures/batch_element.py ++++ b/src/trajdata/data_structures/batch_element.py +@@ -10,6 +10,11 @@ from trajdata.data_structures.scene import SceneTime, SceneTimeAgent + from trajdata.data_structures.state import StateArray + from trajdata.maps import MapAPI, RasterizedMapPatch, VectorMap + from trajdata.utils.state_utils import convert_to_frame_state, transform_from_frame ++from trajdata.utils.arr_utils import transform_xyh_np, get_close_lanes ++ ++from trajdata.utils.map_utils import LaneSegRelation ++ ++ + + + class AgentBatchElement: +@@ -158,6 +163,20 @@ class AgentBatchElement: + else None, + **vector_map_params if vector_map_params is not None else None, + ) ++ if vector_map_params.get("calc_lane_graph", False): ++ # not tested ++ ego_xyh = np.concatenate([self.curr_agent_state_np.position, self.curr_agent_state_np.heading]) ++ num_pts = vector_map_params.get("num_lane_pts", 30) ++ max_num_lanes = vector_map_params.get("max_num_lanes",20) ++ remove_single_successor = vector_map_params.get("remove_single_successor",False) ++ radius = vector_map_params.get("radius", 100) ++ self.num_lanes,self.lane_xyh,self.lane_adj,self.lane_ids = gen_lane_graph(self.vec_map,ego_xyh,self.agent_from_world_tf,num_pts,max_num_lanes,radius,remove_single_successor) ++ ++ else: ++ self.lane_xyh = None ++ self.lane_adj = None ++ self.lane_ids=list() ++ self.num_lanes = 0 + + self.scene_id = scene_time_agent.scene.name + +@@ -404,7 +423,7 @@ class SceneBatchElement: + nearby_agents, self.agent_types_np = self.get_nearby_agents( + scene_time, self.centered_agent, distance_limit + ) +- ++ self.agents = nearby_agents + self.num_agents = len(nearby_agents) + self.agent_names = [agent.name for agent in nearby_agents] + ( +@@ -440,7 +459,22 @@ class SceneBatchElement: + self.cache if self.cache.is_traffic_light_data_cached() else None, + **vector_map_params if vector_map_params is not None else None, + ) +- ++ if vector_map_params.get("calc_lane_graph", False): ++ # not 
tested ++ ego_xyh = np.concatenate([self.centered_agent_state_np.position, self.centered_agent_state_np.heading]) ++ num_pts = vector_map_params.get("num_lane_pts", 30) ++ max_num_lanes = vector_map_params.get("max_num_lanes",20) ++ self.num_lanes,self.lane_xyh,self.lane_adj,self.lane_ids = gen_lane_graph(self.vec_map,ego_xyh,self.centered_agent_from_world_tf,num_pts,max_num_lanes) ++ ++ else: ++ self.lane_xyh = None ++ self.lane_adj = None ++ self.lane_ids = list() ++ self.num_lanes = 0 ++ ++ ++ ++ + self.scene_id = scene_time.scene.name + + ### ROBOT DATA ### +@@ -585,6 +619,73 @@ class SceneBatchElement: + ).view(self.cache.obs_type) + return robot_curr_and_fut_np + ++ ++def gen_lane_graph(vec_map,ego_xyh,agent_from_world,num_pts=20,max_num_lanes=15,radius=150,remove_single_successor=True): ++ close_lanes,dis = get_close_lanes(radius,ego_xyh,vec_map,num_pts) ++ lanes_by_id = {lane.id:lane for lane in close_lanes} ++ dis_by_id = {lane.id:dis[i] for i,lane in enumerate(close_lanes)} ++ if remove_single_successor: ++ for lane in close_lanes: ++ ++ while len(lane.next_lanes) == 1: ++ # if there are more than one succeeding lanes, then we abort the merging ++ next_id = list(lane.next_lanes)[0] ++ ++ if next_id in lanes_by_id: ++ next_lane = lanes_by_id[next_id] ++ shared_next = False ++ for id in next_lane.prev_lanes: ++ if id != lane.id and id in lanes_by_id: ++ shared_next = True ++ break ++ if shared_next: ++ # if the next lane shares two prev lanes in the close_lanes, then we abort the merging ++ break ++ lane.combine_next(lanes_by_id[next_id]) ++ lanes_by_id.pop(next_id) ++ else: ++ break ++ close_lanes = list(lanes_by_id.values()) ++ dis = np.array([dis_by_id[lane.id] for lane in close_lanes]) ++ num_lanes = len(close_lanes) ++ if num_lanes > max_num_lanes: ++ idx = dis.argsort()[:max_num_lanes] ++ close_lanes = [lane for i,lane in enumerate(close_lanes) if i in idx] ++ num_lanes = max_num_lanes ++ ++ if num_lanes >0: ++ lane_xyh = list() ++ lane_adj = np.zeros([len(close_lanes), len(close_lanes)],dtype=np.int32) ++ lane_ids = [lane.id for lane in close_lanes] ++ ++ for i,lane in enumerate(close_lanes): ++ center = lane.center.interpolate(num_pts).points[:,[0,1,3]] ++ center_local = transform_xyh_np(center, agent_from_world[None]) ++ lane_xyh.append(center_local) ++ # construct lane adjacency matrix ++ for adj_lane_id in lane.next_lanes: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.NEXT.value ++ ++ for adj_lane_id in lane.prev_lanes: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.PREV.value ++ ++ for adj_lane_id in lane.adj_lanes_left: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.LEFT.value ++ ++ for adj_lane_id in lane.adj_lanes_right: ++ if adj_lane_id in lane_ids: ++ lane_adj[i,lane_ids.index(adj_lane_id)] = LaneSegRelation.RIGHT.value ++ lane_xyh = np.stack(lane_xyh, axis=0) ++ lane_xyh = lane_xyh ++ lane_adj = lane_adj ++ else: ++ lane_xyh = np.zeros([0,num_pts,3]) ++ lane_adj = np.zeros([0,0]) ++ lane_ids = list() ++ return num_lanes,lane_xyh,lane_adj,lane_ids + + def is_agent_stationary(cache: SceneCache, agent_info: AgentMetadata) -> bool: + # Agent is considered stationary if it moves less than 1m between the first and last valid timestep. 
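The lane-graph machinery above is driven entirely by `vector_map_params`: the batch elements call `gen_lane_graph()` when `calc_lane_graph` is set, and `_collate_lane_graph()` (added to `collation.py` below) pads the per-element results into the new `lane_xyh`, `lane_adj`, `lane_ids`, and `lane_mask` batch fields. A minimal usage sketch, not part of this diff — the dataset name, paths, and batch size are placeholders; the `vector_map_params` keys and `save_index` are the options this diff actually reads:

```python
from collections import defaultdict

from torch.utils.data import DataLoader

from trajdata import AgentType, UnifiedDataset

dataset = UnifiedDataset(
    desired_data=["nusc_mini"],                      # placeholder dataset
    centric="agent",
    history_sec=(0.1, 1.5),
    future_sec=(0.1, 5.0),
    only_types=[AgentType.VEHICLE],
    agent_interaction_distances=defaultdict(lambda: 40.0),
    incl_vector_map=True,
    vector_map_params={
        "calc_lane_graph": True,           # enables gen_lane_graph() above
        "num_lane_pts": 30,                # points interpolated per lane centerline
        "max_num_lanes": 20,               # keep only the closest lanes
        "radius": 100,                     # lane search radius around the agent [m]
        "remove_single_successor": False,  # optionally merge single-successor lane chains
        "keep_in_memory": True,            # memoize loaded maps (see dataset.py below)
    },
    save_index=True,                       # new flag below: cache the computed data index
    data_dirs={"nusc_mini": "~/datasets/nuScenes"},  # placeholder path
    cache_location="~/.unified_data_cache",
)

dataloader = DataLoader(
    dataset, batch_size=4, shuffle=True, collate_fn=dataset.get_collate_fn()
)
batch = next(iter(dataloader))

# After _collate_lane_graph():
#   batch.lane_xyh  -> (B, M, num_lane_pts, 3): x, y, heading in the agent frame
#   batch.lane_adj  -> (B, M, M): LaneSegRelation values (NEXT/PREV/LEFT/RIGHT)
#   batch.lane_mask -> (B, M): which of the M padded lane slots are real
#   batch.lane_ids  -> per-element lists of the source lane IDs
print(batch.lane_xyh.shape, batch.lane_adj.shape, batch.lane_mask.shape)
```

Note that only the agent-centric path above passes `radius` and `remove_single_successor` through to `gen_lane_graph()`; the scene-centric path currently relies on that function's defaults.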
+diff --git a/src/trajdata/data_structures/collation.py b/src/trajdata/data_structures/collation.py +index 1b4d2be..2b911d6 100644 +--- a/src/trajdata/data_structures/collation.py ++++ b/src/trajdata/data_structures/collation.py +@@ -11,8 +11,9 @@ from torch.nn.utils.rnn import pad_sequence + from trajdata.augmentation import BatchAugmentation + from trajdata.data_structures.batch import AgentBatch, SceneBatch + from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement ++from trajdata.maps import VectorMap, RasterizedMapPatch ++from trajdata.utils.map_utils import batch_transform_raster_maps + from trajdata.data_structures.state import TORCH_STATE_TYPES +-from trajdata.maps import VectorMap + from trajdata.utils import arr_utils + + +@@ -34,10 +35,26 @@ def _collate_data(elems): + else: + return torch.as_tensor(np.stack(elems)) + ++def _collate_lane_graph(elems): ++ num_lanes = [elem.num_lanes for elem in elems] ++ bs = len(elems) ++ M = max(num_lanes) ++ lane_xyh = np.zeros([bs,M,*elems[0].lane_xyh.shape[-2:]]) ++ lane_adj = np.zeros([bs,M,M],dtype=int) ++ lane_ids = list() ++ lane_mask = np.zeros([bs,M],dtype=int) ++ for i,elem in enumerate(elems): ++ lane_xyh[i,:num_lanes[i]] = elem.lane_xyh ++ lane_adj[i,:num_lanes[i],:num_lanes[i]] = elem.lane_adj ++ lane_ids.append(elem.lane_ids) ++ lane_mask[i,:num_lanes[i]] = 1 ++ return torch.as_tensor(lane_xyh),torch.as_tensor(lane_adj), torch.as_tensor(lane_mask), lane_ids + + def raster_map_collate_fn_agent( + batch_elems: List[AgentBatchElement], + ): ++ # TODO(pkarkus) refactor with batch_rotate_raster_maps ++ + if batch_elems[0].map_patch is None: + return None, None, None, None + +@@ -181,89 +198,48 @@ def raster_map_collate_fn_scene( + batch_elems: List[SceneBatchElement], + max_agent_num: Optional[int] = None, + pad_value: Any = np.nan, +-) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: ++) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + + if batch_elems[0].map_patches is None: + return None, None, None, None + +- patch_size: int = batch_elems[0].map_patches[0].crop_size +- assert all( +- batch_elem.map_patches[0].crop_size == patch_size for batch_elem in batch_elems +- ) +- ++ # Collect map patches for all elements and agents into a flat list + map_names: List[str] = list() + num_agents: List[int] = list() +- agents_rasters_from_world_tfs: List[np.ndarray] = list() +- agents_patches: List[np.ndarray] = list() +- agents_rot_angles_list: List[float] = list() +- agents_res_list: List[float] = list() ++ map_patches: List[RasterizedMapPatch] = list() + + for elem in batch_elems: + map_names.append(elem.map_name) + num_agents.append(min(elem.num_agents, max_agent_num)) +- agents_rasters_from_world_tfs += [ +- x.raster_from_world_tf for x in elem.map_patches[:max_agent_num] +- ] +- agents_patches += [x.data for x in elem.map_patches[:max_agent_num]] +- agents_rot_angles_list += [ +- x.rot_angle for x in elem.map_patches[:max_agent_num] +- ] +- agents_res_list += [x.resolution for x in elem.map_patches[:max_agent_num]] ++ map_patches += elem.map_patches[:max_agent_num] + +- patch_data: Tensor = torch.as_tensor(np.stack(agents_patches), dtype=torch.float) +- agents_rot_angles: Tensor = torch.as_tensor( +- np.stack(agents_rot_angles_list), dtype=torch.float +- ) +- agents_rasters_from_world_tf: Tensor = torch.as_tensor( +- np.stack(agents_rasters_from_world_tfs), dtype=torch.float +- ) +- agents_resolution: Tensor = torch.as_tensor( +- np.stack(agents_res_list), 
dtype=torch.float ++ # Batch transform map patches and pad ++ ( ++ rot_crop_patches, ++ agents_resolution, ++ agents_rasters_from_world_tf ++ ) = batch_rotate_raster_maps_for_agents_in_scene( ++ map_patches, num_agents, max_agent_num, pad_value + ) + +- patch_size_y, patch_size_x = patch_data.shape[-2:] +- center_y: int = patch_size_y // 2 +- center_x: int = patch_size_x // 2 +- half_extent: int = patch_size // 2 +- +- if torch.count_nonzero(agents_rot_angles) == 0: +- agents_rasters_from_world_tf = torch.bmm( +- torch.tensor( +- [ +- [ +- [1.0, 0.0, half_extent], +- [0.0, 1.0, half_extent], +- [0.0, 0.0, 1.0], +- ] +- ], +- dtype=agents_rasters_from_world_tf.dtype, +- device=agents_rasters_from_world_tf.device, +- ).expand((agents_rasters_from_world_tf.shape[0], -1, -1)), +- agents_rasters_from_world_tf, +- ) ++ return map_names, rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + +- rot_crop_patches = patch_data +- else: +- agents_rasters_from_world_tf = torch.bmm( +- arr_utils.transform_matrices( +- -agents_rot_angles, +- torch.tensor([[half_extent, half_extent]]).expand( +- (agents_rot_angles.shape[0], -1) +- ), +- ), +- agents_rasters_from_world_tf, +- ) + +- # Batch rotating patches by rot_angles. +- rot_patches: Tensor = rotate(patch_data, torch.rad2deg(agents_rot_angles)) ++def batch_rotate_raster_maps_for_agents_in_scene( ++ map_patches: List[RasterizedMapPatch], ++ num_agents: List[int], ++ max_agent_num: Optional[int] = None, ++ pad_value: Any = np.nan, ++) -> Tuple[Tensor, Tensor, Tensor]: + +- # Center cropping via slicing. +- rot_crop_patches = rot_patches[ +- ..., +- center_y - half_extent : center_y + half_extent, +- center_x - half_extent : center_x + half_extent, +- ] ++ # Batch transform map patches ++ ( ++ rot_crop_patches, ++ agents_resolution, ++ agents_rasters_from_world_tf ++ ) = batch_transform_raster_maps(map_patches) + ++ # Separate batch and agents + rot_crop_patches = split_pad_crop( + rot_crop_patches, num_agents, pad_value=pad_value, desired_size=max_agent_num + ) +@@ -278,7 +254,7 @@ def raster_map_collate_fn_scene( + agents_resolution, num_agents, pad_value=0, desired_size=max_agent_num + ) + +- return map_names, rot_crop_patches, agents_resolution, agents_rasters_from_world_tf ++ return rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + + + def agent_collate_fn( +@@ -673,6 +649,10 @@ def agent_collate_fn( + vector_maps: Optional[List[VectorMap]] = None + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] ++ ++ lane_xyh,lane_adj,lane_mask,lane_ids = None,None,None,None ++ if hasattr(batch_elems[0],"lane_xyh") and batch_elems[0].lane_xyh is not None: ++ lane_xyh,lane_adj,lane_mask, lane_ids = _collate_lane_graph(batch_elems) + + agents_from_world_tf = torch.as_tensor( + np.stack([batch_elem.agent_from_world_tf for batch_elem in batch_elems]), +@@ -712,6 +692,10 @@ def agent_collate_fn( + robot_fut_len=robot_future_len, + map_names=map_names, + maps=map_patches, ++ lane_xyh=lane_xyh, ++ lane_adj=lane_adj, ++ lane_mask=lane_mask, ++ lane_ids = lane_ids, + maps_resolution=maps_resolution, + vector_maps=vector_maps, + rasters_from_world_tf=rasters_from_world_tf, +@@ -781,6 +765,9 @@ def scene_collate_fn( + return_dict: bool, + pad_format: str, + batch_augments: Optional[List[BatchAugmentation]] = None, ++ desired_num_agents = None, ++ desired_hist_len=None, ++ desired_fut_len=None, + ) -> SceneBatch: + batch_size: int = len(batch_elems) + history_pad_dir: 
arr_utils.PadDirection = ( +@@ -800,6 +787,9 @@ def scene_collate_fn( + AgentObsTensor = TORCH_STATE_TYPES[obs_format] + + max_agent_num: int = max(elem.num_agents for elem in batch_elems) ++ if desired_num_agents is not None: ++ max_agent_num = max(max_agent_num,desired_num_agents) ++ + + centered_agent_state: List[AgentStateTensor] = list() + agents_types: List[Tensor] = list() +@@ -820,6 +810,10 @@ def scene_collate_fn( + + max_history_len: int = max(elem.agent_history_lens_np.max() for elem in batch_elems) + max_future_len: int = max(elem.agent_future_lens_np.max() for elem in batch_elems) ++ if desired_hist_len is not None: ++ max_history_len = max(max_history_len,desired_hist_len) ++ if desired_fut_len is not None: ++ max_future_len = max(max_future_len,desired_fut_len) + + robot_future: List[AgentObsTensor] = list() + robot_future_len: Tensor = torch.zeros((batch_size,), dtype=torch.long) +@@ -951,6 +945,10 @@ def scene_collate_fn( + if batch_elems[0].vec_map is not None: + vector_maps = [batch_elem.vec_map for batch_elem in batch_elems] + ++ lane_xyh,lane_adj,lane_mask, lane_ids = None,None,None,None ++ if hasattr(batch_elems[0],"lane_xyh") and batch_elems[0].lane_xyh is not None: ++ lane_xyh,lane_adj,lane_mask, lane_ids = _collate_lane_graph(batch_elems) ++ + centered_agent_from_world_tf = torch.as_tensor( + np.stack( + [batch_elem.centered_agent_from_world_tf for batch_elem in batch_elems] +@@ -990,6 +988,7 @@ def scene_collate_fn( + agent_type=agents_types_t, + centered_agent_state=centered_agent_state_t, + agent_names=agent_names, ++ track_ids = None, + agent_hist=agents_histories_t, + agent_hist_extent=agents_history_extents_t, + agent_hist_len=agents_history_len, +@@ -1000,6 +999,10 @@ def scene_collate_fn( + robot_fut_len=robot_future_len, + map_names=map_names, + maps=map_patches, ++ lane_xyh = lane_xyh, ++ lane_adj = lane_adj, ++ lane_mask = lane_mask, ++ lane_ids = lane_ids, + maps_resolution=maps_resolution, + vector_maps=vector_maps, + rasters_from_world_tf=rasters_from_world_tf, +diff --git a/src/trajdata/dataset.py b/src/trajdata/dataset.py +index 64f80eb..609e25e 100644 +--- a/src/trajdata/dataset.py ++++ b/src/trajdata/dataset.py +@@ -1,4 +1,5 @@ + import gc ++import json + import random + import time + from collections import defaultdict +@@ -49,7 +50,7 @@ from trajdata.data_structures import ( + from trajdata.dataset_specific import RawDataset + from trajdata.maps.map_api import MapAPI + from trajdata.parallel import ParallelDatasetPreprocessor, scene_paths_collate_fn +-from trajdata.utils import agent_utils, env_utils, scene_utils, string_utils ++from trajdata.utils import agent_utils, env_utils, py_utils, scene_utils, string_utils + from trajdata.utils.parallel_utils import parallel_iapply + + # TODO(bivanovic): Move this to a better place in the codebase. +@@ -111,6 +112,7 @@ class UnifiedDataset(Dataset): + cache_location: str = "~/.unified_data_cache", + rebuild_cache: bool = False, + rebuild_maps: bool = False, ++ save_index: bool = False, + num_workers: int = 0, + verbose: bool = False, + extras: Dict[str, Callable[..., np.ndarray]] = dict(), +@@ -152,12 +154,17 @@ class UnifiedDataset(Dataset): + cache_location (str, optional): Where to store and load preprocessed, cached data. Defaults to "~/.unified_data_cache". + rebuild_cache (bool, optional): If True, process and cache trajectory data even if it is already cached. Defaults to False. + rebuild_maps (bool, optional): If True, process and cache maps even if they are already cached. Defaults to False. 
++ save_index (bool, optional): If True, save the resulting agent (or scene) data index after it is computed (speeding up subsequent initializations with the same argument values). + num_workers (int, optional): Number of parallel workers to use for dataset preprocessing and loading. Defaults to 0. + verbose (bool, optional): If True, print internal data loading information. Defaults to False. + extras (Dict[str, Callable[..., np.ndarray]], optional): Adds extra data to each batch element. Each Callable must take as input a filled {Agent,Scene}BatchElement and return an ndarray which will subsequently be added to the batch element's `extra` dict. + transforms (Iterable[Callable], optional): Allows for custom modifications of batch elements. Each Callable must take in a filled {Agent,Scene}BatchElement and return a {Agent,Scene}BatchElement. + rank (int, optional): Proccess rank when using torch DistributedDataParallel for multi-GPU training. Only the rank 0 process will be used for caching. + """ ++ self.desired_data: List[str] = desired_data ++ self.scene_description_contains: Optional[ ++ List[str] ++ ] = scene_description_contains + self.centric: str = centric + self.desired_dt: float = desired_dt + +@@ -204,6 +211,11 @@ class UnifiedDataset(Dataset): + # Collation can be quite slow if vector maps are included, + # so we do not unless the user requests it. + "collate": False, ++ # Whether loaded maps should be stored in memory (memoized) for later re-use. ++ # For datasets which provide full maps ahead-of-time (i.e., all except Waymo), ++ # this should be True. However, for Waymo it should be False because maps ++ # are already partitioned geographically and keeping them around significantly grows memory. ++ "keep_in_memory": True, + } + ) + if self.desired_dt is not None: +@@ -233,13 +245,15 @@ class UnifiedDataset(Dataset): + self.torch_obs_type = TORCH_STATE_TYPES[obs_format] + + # Ensuring scene description queries are all lowercase +- if scene_description_contains is not None: +- scene_description_contains = [s.lower() for s in scene_description_contains] ++ if self.scene_description_contains is not None: ++ self.scene_description_contains = [ ++ s.lower() for s in self.scene_description_contains ++ ] + + self.envs: List[RawDataset] = env_utils.get_raw_datasets(data_dirs) + self.envs_dict: Dict[str, RawDataset] = {env.name: env for env in self.envs} + +- matching_datasets: List[SceneTag] = self.get_matching_scene_tags(desired_data) ++ matching_datasets: List[SceneTag] = self._get_matching_scene_tags(desired_data) + if self.verbose: + print( + "Loading data for matched scene tags:", +@@ -249,7 +263,7 @@ class UnifiedDataset(Dataset): + + self._map_api: Optional[MapAPI] = None + if self.incl_vector_map: +- self._map_api = MapAPI(self.cache_path) ++ self._map_api = MapAPI(self.cache_path, keep_in_memory=vector_map_params.get("keep_in_memory", True)) + + all_scenes_list: Union[List[SceneMetadata], List[Scene]] = list() + for env in self.envs: +@@ -257,8 +271,8 @@ class UnifiedDataset(Dataset): + all_data_cached: bool = False + all_maps_cached: bool = not env.has_maps or not require_map_cache + if self.env_cache.env_is_cached(env.name) and not self.rebuild_cache: +- scenes_list: List[Scene] = self.get_desired_scenes_from_env( +- matching_datasets, scene_description_contains, env ++ scenes_list: List[Scene] = self._get_desired_scenes_from_env( ++ matching_datasets, env + ) + + all_data_cached: bool = all( +@@ -314,9 +328,9 @@ class UnifiedDataset(Dataset): + ): + distributed.barrier() + 
+- scenes_list: List[SceneMetadata] = self.get_desired_scenes_from_env( +- matching_datasets, scene_description_contains, env +- ) ++ scenes_list: List[ ++ SceneMetadata ++ ] = self._get_desired_scenes_from_env(matching_datasets, env) + + if self.incl_vector_map and env.metadata.map_locations is not None: + # env.metadata.map_locations can be none for map-containing +@@ -330,7 +344,7 @@ class UnifiedDataset(Dataset): + all_scenes_list += scenes_list + + # List of cached scene paths. +- scene_paths: List[Path] = self.preprocess_scene_data( ++ scene_paths: List[Path] = self._preprocess_scene_data( + all_scenes_list, num_workers + ) + if self.verbose: +@@ -343,7 +357,11 @@ class UnifiedDataset(Dataset): + data_index: Union[ + List[Tuple[str, int, np.ndarray]], + List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], +- ] = self.get_data_index(num_workers, scene_paths) ++ ] ++ if self._index_cache_path().exists(): ++ data_index = self._load_data_index() ++ else: ++ data_index = self._get_data_index(num_workers, scene_paths) + + # Done with this list. Cutting memory usage because + # of multiprocessing later on. +@@ -360,8 +378,99 @@ class UnifiedDataset(Dataset): + ) + self._data_len: int = len(self._data_index) + ++ # Use only rank 0 process for caching when using multi-GPU torch training. ++ if save_index and rank == 0: ++ if self._index_cache_path().exists(): ++ print( ++ "WARNING: Overwriting already-cached data index (since save_index is True).", ++ flush=True, ++ ) ++ ++ self._cache_data_index(data_index) ++ if ( ++ distributed.is_initialized() ++ and distributed.get_world_size() > 1 ++ ): ++ distributed.barrier() + self._cached_batch_elements = None + ++ def _index_cache_path( ++ self, ret_args: bool = False ++ ) -> Union[Path, Tuple[Path, Dict[str, Any]]]: ++ # Whichever UnifiedDataset arguments affect data indexing are captured ++ # and hashed together here. ++ impactful_args: Dict[str, Any] = { ++ "desired_data": tuple(self.desired_data), ++ "scene_description_contains": tuple(self.scene_description_contains) ++ if self.scene_description_contains is not None ++ else None, ++ "centric": self.centric, ++ "desired_dt": self.desired_dt, ++ "history_sec": self.history_sec, ++ "future_sec": self.future_sec, ++ "incl_robot_future": self.incl_robot_future, ++ "only_types": tuple(t.name for t in self.only_types) ++ if self.only_types is not None ++ else None, ++ "only_predict": tuple(t.name for t in self.only_predict) ++ if self.only_predict is not None ++ else None, ++ "no_types": tuple(t.name for t in self.no_types) ++ if self.no_types is not None ++ else None, ++ "ego_only": self.ego_only, ++ } ++ index_hash: str = py_utils.hash_dict(impactful_args) ++ index_cache_path: Path = self.cache_path / "data_indexes" / index_hash ++ ++ if ret_args: ++ return index_cache_path, impactful_args ++ else: ++ return index_cache_path ++ ++ def _cache_data_index( ++ self, ++ data_index: Union[ ++ List[Tuple[str, int, np.ndarray]], ++ List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], ++ ], ++ ) -> None: ++ index_cache_dir, index_args = self._index_cache_path(ret_args=True) ++ ++ # Create it if it doesn't exist yet. 
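++        # (It will contain data_index.dill plus index_args.json, which records
++        # the exact argument values that produced this directory's hash.)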
++ index_cache_dir.mkdir(parents=True, exist_ok=True) ++ ++ index_cache_file: Path = index_cache_dir / "data_index.dill" ++ with open(index_cache_file, "wb") as f: ++ dill.dump(data_index, f) ++ ++ args_file: Path = index_cache_dir / "index_args.json" ++ with open(args_file, "w") as f: ++ json.dump(index_args, f, indent=4) ++ ++ print( ++ f"Cached data index to {str(index_cache_file)}", ++ flush=True, ++ ) ++ ++ def _load_data_index( ++ self, ++ ) -> Union[ ++ List[Tuple[str, int, np.ndarray]], ++ List[Tuple[str, int, List[Tuple[str, np.ndarray]]]], ++ ]: ++ index_cache_file: Path = self._index_cache_path() / "data_index.dill" ++ with open(index_cache_file, "rb") as f: ++ data_index = dill.load(f) ++ ++ if self.verbose: ++ print( ++ f"Loaded data index from {str(index_cache_file)}", ++ flush=True, ++ ) ++ ++ return data_index ++ + def load_or_create_cache( + self, cache_path: str, num_workers=0, filter_fn=None + ) -> None: +@@ -494,7 +603,7 @@ class UnifiedDataset(Dataset): + f"Kept {self._data_len}/{old_len} elements, {self._data_len/old_len*100.0:.2f}%." + ) + +- def get_data_index( ++ def _get_data_index( + self, num_workers: int, scene_paths: List[Path] + ) -> Union[ + List[Tuple[str, int, np.ndarray]], +@@ -681,13 +790,16 @@ class UnifiedDataset(Dataset): + return_dict=return_dict, + pad_format=pad_format, + batch_augments=batch_augments, ++ desired_num_agents = self.max_agent_num, ++ desired_hist_len = int(self.history_sec[1]/self.desired_dt), ++ desired_fut_len = int(self.future_sec[1]/self.desired_dt), + ) + else: + raise ValueError(f"{self.centric}-centric data batches are not supported.") + + return collate_fn + +- def get_matching_scene_tags(self, queries: List[str]) -> List[SceneTag]: ++ def _get_matching_scene_tags(self, queries: List[str]) -> List[SceneTag]: + # if queries is None: + # return list(chain.from_iterable(env.components for env in self.envs)) + +@@ -708,10 +820,9 @@ class UnifiedDataset(Dataset): + + return matching_scene_tags + +- def get_desired_scenes_from_env( ++ def _get_desired_scenes_from_env( + self, + scene_tags: List[SceneTag], +- scene_description_contains: Optional[List[str]], + env: RawDataset, + ) -> Union[List[Scene], List[SceneMetadata]]: + scenes_list: Union[List[Scene], List[SceneMetadata]] = list() +@@ -721,14 +832,14 @@ class UnifiedDataset(Dataset): + if env.name in scene_tag: + scenes_list += env.get_matching_scenes( + scene_tag, +- scene_description_contains, ++ self.scene_description_contains, + self.env_cache, + self.rebuild_cache, + ) + + return scenes_list + +- def preprocess_scene_data( ++ def _preprocess_scene_data( + self, + scenes_list: Union[List[SceneMetadata], List[Scene]], + num_workers: int, +diff --git a/src/trajdata/dataset_specific/carla/__init__.py b/src/trajdata/dataset_specific/carla/__init__.py +new file mode 100644 +index 0000000..8854e8a +--- /dev/null ++++ b/src/trajdata/dataset_specific/carla/__init__.py +@@ -0,0 +1 @@ ++from .carla_dataset import CarlaDataset +\ No newline at end of file +diff --git a/src/trajdata/dataset_specific/carla/carla_dataset.py b/src/trajdata/dataset_specific/carla/carla_dataset.py +new file mode 100644 +index 0000000..2692e26 +--- /dev/null ++++ b/src/trajdata/dataset_specific/carla/carla_dataset.py +@@ -0,0 +1,415 @@ ++import warnings ++from copy import deepcopy ++from pathlib import Path ++from typing import Any, Dict, List, Optional, Tuple, Type, Union ++ ++import pandas as pd ++from nuscenes.eval.prediction.splits import NUM_IN_TRAIN_VAL ++from nuscenes.map_expansion.map_api import 
NuScenesMap, locations ++from nuscenes.nuscenes import NuScenes ++from nuscenes.utils.splits import create_splits_scenes ++from tqdm import tqdm ++ ++from trajdata.caching import EnvCache, SceneCache ++from trajdata.data_structures.agent import ( ++ Agent, ++ AgentMetadata, ++ AgentType, ++ FixedExtent, ++ VariableExtent, ++) ++from trajdata.data_structures.environment import EnvMetadata ++from trajdata.data_structures.scene_metadata import Scene, SceneMetadata ++from trajdata.data_structures.scene_tag import SceneTag ++from trajdata.dataset_specific.nusc import nusc_utils ++from trajdata.dataset_specific.raw_dataset import RawDataset ++from trajdata.dataset_specific.scene_records import CarlaSceneRecord ++from trajdata.maps import VectorMap ++from trajdata.maps.map_api import MapAPI ++ ++from pdb import set_trace as st ++import glob, os ++import pickle ++from collections import defaultdict ++import re ++ ++import torch ++import numpy as np ++ ++carla_to_trajdata_object_type = { ++ 0: AgentType.VEHICLE, ++ 1: AgentType.BICYCLE, # Motorcycle ++ 2: AgentType.BICYCLE, ++ 3: AgentType.PEDESTRIAN, ++ 4: AgentType.UNKNOWN, ++ # ?: AgentType.STATIC, ++} ++ ++def create_splits_scenes(data_dir:str) -> Dict[str, List[str]]: ++ all_scenes = {} ++ all_scenes['train'] = [scene_path.split('/')[-1] for scene_path in glob.glob(data_dir+'/train/route*')] ++ all_scenes['val'] = [scene_path.split('/')[-1] for scene_path in glob.glob(data_dir+'/val/route*')] ++ return all_scenes ++ ++# TODO: (Yulong) format in object class ++def tracklet_to_pred(tracklet_mem,ego=False): ++ if ego: ++ x, y, z = np.split(tracklet_mem['location'],3,axis=-1) ++ hx, hy, hz = np.split(np.deg2rad(tracklet_mem['rotation']),3,axis=-1) ++ vx, vy, _ = np.split(tracklet_mem['velocity'],3,axis=-1) ++ ax, ay, _ = np.split(tracklet_mem['acceleration'],3,axis=-1) ++ else: ++ x, y, z = np.split(tracklet_mem['location'][0,:,-1,:],3,axis=-1) ++ hx, hy, hz = np.split(np.deg2rad(tracklet_mem['rotation'])[0,:,-1,:],3,axis=-1) ++ vx, vy, _ = np.split(tracklet_mem['velocity'][0,:,-1,:],3,axis=-1) ++ ax, ay, _ = np.split(tracklet_mem['acceleration'][0,:,-1,:],3,axis=-1) ++ pred_state = np.concatenate([ ++ x, -y, z, vx, -vy, ax, -ay, -hy ++ ],axis=-1) ++ return pred_state ++ ++def CarlaTracking(dataroot): ++ dataset_obj = defaultdict(lambda: defaultdict(dict)) ++ frames = list(dataroot.glob('*/route*/metadata/tracking/*.pkl')) ++ for frame in frames: ++ frame = str(frame) ++ with open(frame, 'rb') as f: ++ track_mem = pickle.load(f) ++ frame_idx = frame.split('/')[-1].split('.')[0] ++ scene = frame.split('/')[-4] ++ dataset_obj[scene]["all"][frame_idx] = track_mem ++ ++ frames = list(dataroot.glob('*/route*/metadata/ego/*.pkl')) ++ for frame in frames: ++ frame = str(frame) ++ with open(frame, 'rb') as f: ++ track_mem = pickle.load(f) ++ frame_idx = frame.split('/')[-1].split('.')[0] ++ scene = frame.split('/')[-4] ++ dataset_obj[scene]["ego"][frame_idx] = track_mem ++ ++ return dataset_obj ++ ++def agg_ego_data(all_frames): ++ agent_data = [] ++ agent_frame = [] ++ for frame_idx, frame_info in enumerate(all_frames): ++ pred_state = tracklet_to_pred(frame_info, ego=True) ++ agent_data.append(pred_state) ++ agent_frame.append(frame_idx) ++ ++ agent_data_df = pd.DataFrame( ++ agent_data, ++ columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], ++ index=pd.MultiIndex.from_tuples( ++ [ ++ ("ego", idx) ++ for idx in agent_frame ++ ], ++ names=["agent_id", "scene_ts"], ++ ), ++ ) ++ ++ agent_metadata = AgentMetadata( ++ name="ego", ++ 
agent_type=AgentType.VEHICLE, ++ first_timestep=agent_frame[0], ++ last_timestep=agent_frame[-1], ++ extent=FixedExtent( ++ length=all_frames[-1]["size"][1], width=all_frames[-1]["size"][0], height=all_frames[-1]["size"][2] ++ ), ++ ) ++ return Agent( ++ metadata=agent_metadata, ++ data=agent_data_df, ++ ) ++ ++def agg_agent_data(all_frames, agent_info, frame_idx): ++ agent_data = [] ++ agent_frame = [] ++ Agent_list = [] ++ for frame_idx, frame_info in enumerate(all_frames): ++ for idx in range(frame_info['id'].shape[1]): ++ if frame_info["id"][0,idx,0] == agent_info["id"]: ++ # two segments of tracking ++ if len(agent_frame) > 0 and frame_idx != agent_frame[-1] + 1: ++ agent_data_df = pd.DataFrame( ++ agent_data, ++ columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], ++ index=pd.MultiIndex.from_tuples( ++ [ ++ (f'{int(agent_info["id"])}_{len(Agent_list)}', idx) ++ for idx in agent_frame ++ ], ++ names=["agent_id", "scene_ts"], ++ ), ++ ) ++ ++ agent_metadata = AgentMetadata( ++ name=f'{int(agent_info["id"])}_{len(Agent_list)}', ++ agent_type=carla_to_trajdata_object_type[agent_info["cls"]], ++ first_timestep=agent_frame[0], ++ last_timestep=agent_frame[-1], ++ extent=FixedExtent( ++ length=agent_info["size"][1], width=agent_info["size"][0], height=agent_info["size"][2] ++ ), ++ ) ++ Agent_list.append(Agent( ++ metadata=agent_metadata, ++ data=agent_data_df, ++ )) ++ ++ agent_data = [] ++ agent_frame = [] ++ ++ ++ pred_state = tracklet_to_pred(frame_info)[idx] ++ agent_data.append(pred_state) ++ agent_frame.append(frame_idx) ++ ++ agent_data_df = pd.DataFrame( ++ agent_data, ++ columns=["x", "y", "z", "vx", "vy", "ax", "ay", "heading"], ++ index=pd.MultiIndex.from_tuples( ++ [ ++ (str(int(agent_info["id"])), idx) ++ for idx in agent_frame ++ ], ++ names=["agent_id", "scene_ts"], ++ ), ++ ) ++ ++ agent_metadata = AgentMetadata( ++ name=str(int(agent_info["id"])), ++ agent_type=carla_to_trajdata_object_type[agent_info["cls"]], ++ first_timestep=agent_frame[0], ++ last_timestep=agent_frame[-1], ++ extent=FixedExtent( ++ length=agent_info["size"][1], width=agent_info["size"][0], height=agent_info["size"][2] ++ ), ++ ) ++ ++ Agent_list.append(Agent( ++ metadata=agent_metadata, ++ data=agent_data_df, ++ )) ++ return Agent_list ++ ++ ++ ++class CarlaDataset(RawDataset): ++ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ++ # We're using the nuScenes prediction challenge split here. ++ # See https://github.com/nutonomy/nuscenes-devkit/blob/master/python-sdk/nuscenes/eval/prediction/splits.py ++ # for full details on how the splits are obtained below. ++ all_scene_splits: Dict[str, List[str]] = create_splits_scenes(data_dir) ++ ++ train_scenes: List[str] = deepcopy(all_scene_splits["train"]) ++ NUM_IN_TRAIN_VAL = round(len(train_scenes)*0.25) ++ all_scene_splits["train"] = train_scenes[NUM_IN_TRAIN_VAL:] ++ all_scene_splits["train_val"] = train_scenes[:NUM_IN_TRAIN_VAL] ++ ++ if env_name == 'carla': ++ carla_scene_splits: Dict[str, List[str]] = { ++ k: all_scene_splits[k] for k in ["train", "train_val", "val"] ++ } ++ ++ # nuScenes possibilities are the Cartesian product of these ++ dataset_parts: List[Tuple[str, ...]] = [ ++ ("train", "train_val", "val"), ++ ] ++ else: ++ raise ValueError(f"Unknown nuScenes environment name: {env_name}") ++ ++ self.scene_splits = carla_scene_splits ++ ++ # Inverting the dict from above, associating every scene with its data split. 
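++        # e.g. {"train": ["route_0_Town01", ...], "val": ["route_5_Town02", ...]} becomes
++        # {"route_0_Town01": "train", ..., "route_5_Town02": "val"} (route names illustrative).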
++ carla_scene_split_map: Dict[str, str] = { ++ v_elem: k for k, v in carla_scene_splits.items() for v_elem in v ++ } ++ return EnvMetadata( ++ name=env_name, ++ data_dir=data_dir, ++ dt=nusc_utils.NUSC_DT, ++ parts=dataset_parts, ++ scene_split_map=carla_scene_split_map, ++ # The location names should match the map names used in ++ # the unified data cache. ++ map_locations=tuple([]), ++ ) ++ ++ def load_dataset_obj(self, verbose: bool = False) -> None: ++ if verbose: ++ print(f"Loading {self.name} dataset...", flush=True) ++ ++ self.dataset_obj = CarlaTracking( ++ dataroot=self.metadata.data_dir ++ ) ++ ++ def _get_matching_scenes_from_obj( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[SceneMetadata]: ++ all_scenes_list: List[CarlaSceneRecord] = list() ++ ++ scenes_list: List[SceneMetadata] = list() ++ for idx, scene_record in enumerate(self.dataset_obj): ++ scene_name: str = scene_record ++ scene_location: str = re.match('.*(Town\d+)', scene_record)[1] ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ scene_length: int = len(self.dataset_obj[scene_record]) ++ ++ # Saving all scene records for later caching. ++ all_scenes_list.append( ++ CarlaSceneRecord( ++ scene_name, scene_location, scene_length, idx ++ ) ++ ) ++ ++ if scene_split in scene_tag: ++ ++ scene_metadata = SceneMetadata( ++ env_name=self.metadata.name, ++ name=scene_name, ++ dt=self.metadata.dt, ++ raw_data_idx=idx, ++ ) ++ scenes_list.append(scene_metadata) ++ ++ self.cache_all_scenes_list(env_cache, all_scenes_list) ++ return scenes_list ++ ++ def _get_matching_scenes_from_cache( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[Scene]: ++ all_scenes_list: List[CarlaSceneRecord] = env_cache.load_env_scenes_list( ++ self.name ++ ) ++ ++ scenes_list: List[SceneMetadata] = list() ++ for scene_record in all_scenes_list: ++ ( ++ scene_name, ++ scene_location, ++ scene_length, ++ data_idx, ++ ) = scene_record ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ ++ if scene_split in scene_tag: ++ scene_metadata = Scene( ++ self.metadata, ++ scene_name, ++ scene_location, ++ scene_split, ++ scene_length, ++ data_idx, ++ None, # This isn't used if everything is already cached. 
++ ) ++ scenes_list.append(scene_metadata) ++ ++ return scenes_list ++ ++ def get_scene(self, scene_info: SceneMetadata) -> Scene: ++ _, route_name, _, data_idx = scene_info ++ ++ scene_record = sorted(self.dataset_obj[route_name]) ++ scene_name: str = route_name ++ scene_location: str = re.match('.*(Town\d+)', route_name)[1] ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ scene_length: int = len(self.dataset_obj[route_name]['ego']) ++ ++ return Scene( ++ self.metadata, ++ scene_name, ++ scene_location, ++ scene_split, ++ scene_length, ++ data_idx, ++ scene_record, ++ ) ++ ++ def get_agent_info( ++ self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ++ ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: ++ ego_agent_info: AgentMetadata = AgentMetadata( ++ name="ego", ++ agent_type=AgentType.VEHICLE, ++ first_timestep=0, ++ last_timestep=scene.length_timesteps - 1, ++ extent=FixedExtent(length=4.084, width=1.730, height=1.562), #TODO: (yulong) replace with carla ego ++ ) ++ ++ agent_presence: List[List[AgentMetadata]] = [ ++ [ego_agent_info] for _ in range(scene.length_timesteps) ++ ] ++ ++ agent_data_list: List[pd.DataFrame] = list() ++ existing_agents: Dict[str, AgentMetadata] = dict() ++ ++ all_frames = [self.dataset_obj[scene.name]["all"][key] for key in sorted(self.dataset_obj[scene.name]["all"])] ++ ++ # frame_idx_dict = { ++ # frame_dict: idx for idx, frame_dict in enumerate(all_frames) ++ # } ++ for frame_idx, frame_info in enumerate(all_frames): ++ for idx in range(frame_info['id'].shape[1]): ++ if str(int(frame_info["id"][0,idx,0])) in existing_agents: ++ continue ++ ++ agent_info = {"id": frame_info["id"][0,idx,0], ++ "cls": frame_info["cls"][0,idx,:].argmax(), ++ "size": frame_info["size"][0,idx,:] } ++ # if not agent_info["next"]: ++ # # There are some agents with only a single detection to them, we don't care about these. 
++ # continue ++ ++ agent_list: List[Agent] = agg_agent_data( ++ all_frames, agent_info, frame_idx ++ ) ++ for agent in agent_list: ++ for scene_ts in range( ++ agent.metadata.first_timestep, agent.metadata.last_timestep + 1 ++ ): ++ agent_presence[scene_ts].append(agent.metadata) ++ ++ existing_agents[agent.name] = agent.metadata ++ ++ agent_data_list.append(agent.data) ++ ++ ego_all_frames = [self.dataset_obj[scene.name]["ego"][key] for key in sorted(self.dataset_obj[scene.name]["ego"])] ++ ++ ego_agent: Agent = agg_ego_data(ego_all_frames) ++ agent_data_list.append(ego_agent.data) ++ ++ agent_list: List[AgentMetadata] = [ego_agent_info] + list( ++ existing_agents.values() ++ ) ++ ++ cache_class.save_agent_data(pd.concat(agent_data_list), cache_path, scene) ++ ++ return agent_list, agent_presence ++ ++ ++ def cache_maps( ++ self, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++ ) -> None: ++ ++ map_api = MapAPI(cache_path) ++ for carla_town in [f"Town0{x}" for x in range(1, 8)] + ["Town10", "Town10HD"]: # ["main"]: ++ vec_map = map_api.get_map( ++ f"drivesim:main" if carla_town == "main" else f"carla:{carla_town}", ++ incl_road_lanes=True, ++ incl_road_areas=True, ++ incl_ped_crosswalks=True, ++ incl_ped_walkways=True, ++ ) ++ map_cache_class.finalize_and_cache_map(cache_path, vec_map, map_params) +diff --git a/src/trajdata/dataset_specific/drivesim/__init__.py b/src/trajdata/dataset_specific/drivesim/__init__.py +new file mode 100644 +index 0000000..0167f39 +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/__init__.py +@@ -0,0 +1 @@ ++from .drivesim_dataset import DrivesimDataset +diff --git a/src/trajdata/dataset_specific/drivesim/drivesim_dataset.py b/src/trajdata/dataset_specific/drivesim/drivesim_dataset.py +new file mode 100644 +index 0000000..865e451 +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/drivesim_dataset.py +@@ -0,0 +1,116 @@ ++import warnings ++from copy import deepcopy ++from pathlib import Path ++from typing import Any, Dict, List, Optional, Tuple, Type, Union ++from collections import defaultdict ++ ++import pandas as pd ++ ++from tqdm import tqdm ++ ++from trajdata.caching import EnvCache, SceneCache ++from trajdata.data_structures.agent import ( ++ Agent, ++ AgentMetadata, ++ AgentType, ++ FixedExtent, ++ VariableExtent, ++) ++from trajdata.data_structures.environment import EnvMetadata ++from trajdata.data_structures.scene_metadata import Scene, SceneMetadata ++from trajdata.data_structures.scene_tag import SceneTag ++from trajdata.dataset_specific.raw_dataset import RawDataset ++from trajdata.dataset_specific.scene_records import DrivesimSceneRecord ++from trajdata.maps import VectorMap ++ ++DRIVESIM_DT = 0.1 ++ ++class DrivesimDataset(RawDataset): ++ def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata: ++ ++ return EnvMetadata( ++ name="drivesim", ++ data_dir=data_dir, ++ dt=DRIVESIM_DT, ++ parts=[("train",),("main",)], ++ scene_split_map=defaultdict(lambda: "train"), ++ # The location names should match the map names used in ++ # the unified data cache. 
++ map_locations=("main",), ++ ) ++ ++ def load_dataset_obj(self, verbose: bool = False) -> None: ++ pass ++ ++ def _get_matching_scenes_from_obj( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[SceneMetadata]: ++ raise NotImplementedError() ++ ++ def _get_matching_scenes_from_cache( ++ self, ++ scene_tag: SceneTag, ++ scene_desc_contains: Optional[List[str]], ++ env_cache: EnvCache, ++ ) -> List[Scene]: ++ all_scenes_list: List[DrivesimSceneRecord] = env_cache.load_env_scenes_list( ++ self.name ++ ) ++ scenes_list: List[SceneMetadata] = list() ++ for scene_record in all_scenes_list: ++ ( ++ scene_name, ++ scene_location, ++ scene_length, ++ scene_desc, ++ data_idx, ++ ) = scene_record ++ scene_split: str = self.metadata.scene_split_map[scene_name] ++ ++ if scene_location.split("-")[0] in scene_tag and scene_split in scene_tag: ++ if scene_desc_contains is not None and not any( ++ desc_query in scene_desc for desc_query in scene_desc_contains ++ ): ++ continue ++ ++ scene_metadata = Scene( ++ self.metadata, ++ scene_name, ++ scene_location, ++ scene_split, ++ scene_length, ++ data_idx, ++ None, # This isn't used if everything is already cached. ++ scene_desc, ++ ) ++ scenes_list.append(scene_metadata) ++ ++ return scenes_list ++ ++ def get_scene(self, scene_info: SceneMetadata) -> Scene: ++ raise NotImplementedError() ++ ++ def get_agent_info( ++ self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache] ++ ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]: ++ raise NotImplementedError() ++ ++ def cache_map( ++ self, ++ map_name: str, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++ ) -> None: ++ raise NotImplementedError() ++ ++ def cache_maps( ++ self, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++ ) -> None: ++ raise NotImplementedError() +diff --git a/src/trajdata/dataset_specific/drivesim/nvmap_utils.py b/src/trajdata/dataset_specific/drivesim/nvmap_utils.py +new file mode 100644 +index 0000000..0f41cc2 +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/nvmap_utils.py +@@ -0,0 +1,1792 @@ ++ ++import os, shutil, copy ++import glob ++import json ++from collections import OrderedDict, deque ++ ++import numpy as np ++import cv2 ++ ++import matplotlib as mpl ++# mpl.use('Agg') ++import matplotlib.pyplot as plt ++from matplotlib.patches import Circle, Polygon, PathPatch, Ellipse ++from matplotlib.path import Path ++ ++import sys ++ ++# NOTE: NVMap documentation https://nvmapspec.drive.nvda.io/map_layer_format/index.html ++ ++# transformation matrix of the base pose of the global cooridinate system (defined in ECEF) ++# will be used to transform each road segment to global frame. 
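++# (road segment poses arrive in ECEF; parse_road_segments below maps them into this frame via np.linalg.inv(GLOBAL_BASE_POSE) @ ecef_pose)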
++# rivermark is used for autolabels ++GLOBAL_BASE_POSE_RIVERMARK = np.array([[ 0.7661314567952979, -0.47780431092164843, -0.42982046411659275, -2684537.8381108316], ++ [ 0.06759529959471867, 0.7249870917843614, -0.6854375188292178, -4305029.705512633], ++ [ 0.6391192896333319, 0.49608140179888094, 0.5877327423309361, 3852303.4349504006], ++ [ 0.0, 0.0, 0.0, 1.0]]) ++# endeavor can be used for map if desired ++GLOBAL_BASE_POSE_ENDEAVOR = np.array([[0.892451945618358, 0.11765375230214345, -0.43553084773783024, -2686874.9549495797], ++ [-0.40233607907375113, 0.644307257120398, -0.6503797643665964, -4305568.426410184], ++ [0.2040960661981692, 0.7556624596942837, 0.6223496145826857, 3850100.0530229895], ++ [ 0.0, 0.0, 0.0, 1.0]]) ++# this is the one that will be used for the map and trajectories ++GLOBAL_BASE_POSE = GLOBAL_BASE_POSE_ENDEAVOR ++# GLOBAL_BASE_POSE = GLOBAL_BASE_POSE_RIVERMARK ++ ++MAX_ENDEAVOR_FRAME = 3613 ++ ++RIVERMARK_VIZ = False ++# 0 : drivable only ++# 1 : + ego ++# 2 : + lines ++# 3 : + other non-extra cars ++# 4 : + all cars ++RIVERMARK_VIZ_LAYER = 4 ++NVCOLOR = (118.0/255, 185.0/255, 0.0/255) ++NVCOLOR2 = (55.0/255, 225.0/255, 0.0/255) ++# EXTRA_DYN_LWH = [5.087, 2.307, 1.856] ++EXTRA_DYN_LWH = [4.387, 1.907, 1.656] ++ ++NUSC_EXPAND_LEN = 0.0 #2.0 # meters to expand drivable area outward ++ ++AUTOLABEL_DT = 0.1 # sec ++ ++NVMAP_LAYERS = {'lane_dividers_v1', 'lanes_v1', 'lane_channels_v1'} ++SEGMENT_GRAPH_LAYER = 'segment_graph_v1' ++ ++# maps track ids (which occur first in time) to other tracks that are ++# actually the same object ++# NOTE: each association list must be sorted in temporal order ++TRACK_ASSOC = { ++ # frame 3000 - 3545 sequence ++ 2486 : [2584, 2669], ++ 2546 : [2618], ++ 2489 : [2558], ++ 2496 : [2553], ++ 2515 : [2578], ++ 2525 : [2615], ++ 2555 : [2609], ++ 2546 : [2618, 2673], ++ # 2491 : [2566, 2631], ++ 2491 : [2631], ++ # 2566 : [2631], ++ 2484 : [2692, 2729, 2745, 2766], ++ 2686 : [2811], ++ 2618 : [2673], ++ 2592 : [2719], ++ 2251 : [2847], ++ 2269 : [2734, 2853], ++ 1280 : [2768, 2790], ++ 2631 : [2685], ++ 2561 : [2699], ++ # frame 570 - 990 sequences ++ 1298 : [1360], ++ 1282 : [1362], ++ 1241 : [1366], ++ 1305 : [1398], ++ 1286 : [1377], ++ 1327 : [1409], ++ 1444 : [1576], ++ 1326 : [1455], ++ 1376 : [1425, 1435, 1444], ++ 1345 : [1445, 1456, 1476], ++ 1346 : [1477, 1491], ++ # frame 270 - 600 ++ 1067 : [1193], ++ 1109 : [1206], ++ 1126 : [1229], ++ 1127 : [1234], ++ 1137 : [1255], ++ 1150 : [1251], ++ 1168 : [1246], ++ 1184 : [1269], ++ 1203 : [1225], ++ 1196 : [1307], ++ 1211 : [1315], ++ 1191 : [1297], ++ 1183 : [1276], ++ 1188 : [1274], ++ 1178 : [1250], ++ # frame 1440 - 1860 ++ 1717 : [1761], ++ 1712 : [1778], ++ 1704 : [1786], ++ 1733 : [1803], ++ 1708 : [1722,1811], ++ 1674 : [1732,1740], ++ 1760 : [1837], ++ 1824 : [1836], ++ 1793 : [1872], ++ 1812 : [1869], ++ 1748 : [1812], ++ 1826 : [1864], ++ 1723 : [1756], ++ # rivermark ++ 9409 : [9517] ++} ++# false positive (or extrapolation is bad) ++# 2566 ++TRACK_REMOVE = {2676, 2687, 2590, 2833, 1199, 9413, 9449, 9437, 9589, 2588, 2669, 2562, 2516, 2566} # 9449 is rivermark dynamic, and 9437 is rivermark trash can ++ ++def check_single_veh_coll(traj_tgt, lw_tgt, traj_others, lw_others): ++ ''' ++ Checks if the target trajectory collides with each of the given other trajectories. ++ ++ Assumes all trajectories and attributes are UNNORMALIZED. Handles nan frames in traj_others by simply skipping. 
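++    Overlap is tested with a polygon IoU between the two oriented boxes; any timestep with IoU above 0.02 is flagged as a collision.
++    Example call (hypothetical arrays): veh_coll = check_single_veh_coll(tgt_states, np.array([4.0, 2.0]), other_states, other_lw)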
++
++    :param traj_tgt: (T x 4)
++    :param lw_tgt: (2, )
++    :param traj_others: (N x T x 4)
++    :param lw_others: (N x 2)
++
++    :returns veh_coll: (N x T) boolean mask, True where the target overlaps that agent at that timestep
++    '''
++    from shapely.geometry import Polygon
++
++    NA, FT, _ = traj_others.shape
++
++    veh_coll = np.zeros((NA, FT), dtype=bool)
++    poly_cache = dict() # for the tgt polygons since used many times
++    for aj in range(NA):
++        for t in range(FT):
++            # compute iou
++            if t not in poly_cache:
++                ai_state = traj_tgt[t, :]
++                if np.sum(np.isnan(ai_state)) > 0:
++                    continue
++                ai_corners = get_corners(ai_state, lw_tgt)
++                ai_poly = Polygon(ai_corners)
++                poly_cache[t] = ai_poly
++            else:
++                ai_poly = poly_cache[t]
++
++            aj_state = traj_others[aj, t, :]
++            if np.sum(np.isnan(aj_state)) > 0:
++                continue
++            aj_corners = get_corners(aj_state, lw_others[aj])
++            aj_poly = Polygon(aj_corners)
++            cur_iou = ai_poly.intersection(aj_poly).area / ai_poly.union(aj_poly).area
++            if cur_iou > 0.02:
++                veh_coll[aj, t] = True
++
++    return veh_coll
++
++def plt_color(i):
++    clist = plt.rcParams['axes.prop_cycle'].by_key()['color']
++    return clist[i]
++
++def get_rot(h):
++    return np.array([
++        [np.cos(h), np.sin(h)],
++        [-np.sin(h), np.cos(h)],
++    ])
++
++def get_corners(box, lw):
++    l, w = lw
++    simple_box = np.array([
++        [-l/2., -w/2.],
++        [l/2., -w/2.],
++        [l/2., w/2.],
++        [-l/2., w/2.],
++    ])
++    h = np.arctan2(box[3], box[2])
++    rot = get_rot(h)
++    simple_box = np.dot(simple_box, rot)
++    simple_box += box[:2]
++    return simple_box
++
++def plot_box(box, lw, color='g', alpha=0.7, no_heading=False):
++    l, w = lw
++    h = np.arctan2(box[3], box[2])
++    simple_box = get_corners(box, lw)
++
++    arrow = np.array([
++        box[:2],
++        box[:2] + l/2.*np.array([np.cos(h), np.sin(h)]),
++    ])
++
++    plt.fill(simple_box[:, 0], simple_box[:, 1], color=color, edgecolor='k', alpha=alpha, linewidth=1.0, zorder=3)
++    if not no_heading:
++        # plt.plot(arrow[:, 0], arrow[:, 1], color, alpha=1.0)
++        plt.plot(arrow[:, 0], arrow[:, 1], 'k', alpha=alpha, zorder=3)
++
++def create_video(img_path_form, out_path, fps):
++    '''
++    Creates a video from a frame filename format, e.g. 'data_out/frame%04d.png'.
++    Saves in out_path.
++    '''
++    import subprocess
++    # if RIVERMARK_VIZ:
++    #     ffmpeg_cmd = ['ffmpeg', '-y', '-i', img_path_form,
++    #                   '-vf', 'transpose=2', img_path_form]
++    #     subprocess.run(ffmpeg_cmd)
++
++    ffmpeg_cmd = ['ffmpeg', '-y', '-r', str(fps), '-i', img_path_form,
++                  '-vcodec', 'libx264', '-crf', '18', '-pix_fmt', 'yuv420p', out_path]
++    subprocess.run(ffmpeg_cmd)
++
++def debug_viz_vid(seg_dict, prefix, poses, poses_valid, poses_lwh,
++                  comp_out_path='./out/dev_nvmap',
++                  fps=10,
++                  subsamp=3,
++                  pose_ids=None,
++                  **kwargs):
++    poses = poses[:,::subsamp]
++    poses_valid = poses_valid[:,::subsamp]
++    T = poses.shape[1]
++    out_dir = os.path.join(comp_out_path, prefix)
++    if not os.path.exists(out_dir):
++        os.makedirs(out_dir)
++    for t in range(T):
++        print('rendering frame %d...' % (t))
++        debug_viz_segs(seg_dict, 'frame_%06d' % (t), poses[:,t:t+1], poses_valid[:,t:t+1], poses_lwh,
++                       comp_out_path=out_dir,
++                       pose_ids=pose_ids,
++                       ego_traj=poses[0],
++                       **kwargs)
++    create_video(os.path.join(out_dir, 'frame_%06d.jpg'), out_dir + '.mp4', fps)
++
++def debug_viz_segs(seg_dict, prefix, poses=None, poses_valid=None, poses_lwh=None, pose_ids=None,
++                   comp_out_path='./out/dev_nvmap',
++                   extent=80,
++                   grid=True,
++                   show_ticks=True,
++                   dpi=100,
++                   ego_traj=None):
++    '''
++    Visualize segments in the given dictionary.
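++    Drivable area is drawn first so it sits underneath lane/road dividers, road boundaries, and the agent boxes.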
++ ++ :param seg_dict: segments dictionary ++ :param prefix: prefix to save figure ++ :param poses: NxTx4x4 trajectory for N vehicles that will be plotted if given. ++ ''' ++ if not os.path.exists(comp_out_path): ++ os.makedirs(comp_out_path) ++ ++ # fig = plt.figure() ++ fig = plt.figure(dpi=dpi) ++ ++ origins = [] ++ arr_len = 10.0 ++ for _, seg in seg_dict.items(): ++ if poses is not None: ++ dist2ego = np.linalg.norm(seg.local2world[:2, -1] - poses[0,0,:2,-1]) ++ if dist2ego > 200: ++ continue ++ # plot layers ++ if 'drivable_area' in seg.layers: ++ # draw drivable area first so under everything else ++ for drivable_poly in seg.layers['drivable_area']: ++ polypatch = Polygon(drivable_poly[:,:2], ++ color='darkgray', ++ alpha=1.0, ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(polypatch) ++ ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 2: ++ # only drivable area ++ continue ++ ++ for layer_k, layer_v in seg.layers.items(): ++ if layer_k in {'lane_dividers', 'lane_divider'}: ++ for lane_div in layer_v: ++ polyline = lane_div.polyline if isinstance(lane_div, LaneDivider) else lane_div ++ linepatch = PathPatch(Path(polyline[:,:2]), ++ fill=False, ++ color='gold', ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(linepatch) ++ elif layer_k in {'road_dividers', 'road_divider'}: ++ for road_div in layer_v: ++ linepatch = PathPatch(Path(road_div[:,:2]), ++ fill=False, ++ color='orange', ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(linepatch) ++ elif layer_k in {'road_boundaries'}: ++ for road_bound in layer_v: ++ linepatch = PathPatch(Path(road_bound.polyline[:,:2]), ++ fill=False, ++ color='darkgray', ++ linestyle='-') ++ # linewidth=2) ++ plt.gca().add_patch(linepatch) ++ # elif layer_k in {'lane_channels'}: ++ # for lane_channel in layer_v: ++ # linepatch = PathPatch(Path(lane_channel.left), ++ # fill=False, ++ # color='blue', ++ # linestyle='-') ++ # # linewidth=2) ++ # plt.gca().add_patch(linepatch) ++ # linepatch = PathPatch(Path(lane_channel.right), ++ # fill=False, ++ # color='red', ++ # linestyle='-') ++ # # linewidth=2) ++ # plt.gca().add_patch(linepatch) ++ ++ # plot local coordinate system origin ++ if poses is None: ++ local_coords = np.array([[0.0, 0.0, 0.0, 1.0], [arr_len, 0.0, 0.0, 1.0], [0.0, arr_len, 0.0, 1.0]]) ++ world_coords = np.dot(seg.local2world, local_coords.T).T ++ world_coords = world_coords[:,:2] # only plot 2D coords ++ origins.append(world_coords[0]) ++ xdelta = world_coords[1] - world_coords[0] ++ ydelta = world_coords[2] - world_coords[0] ++ plt.arrow(world_coords[0, 0], world_coords[0, 1], xdelta[0], xdelta[1], color='red') ++ plt.arrow(world_coords[0, 0], world_coords[0, 1], ydelta[0], ydelta[1], color='green') ++ ++ if poses is not None and poses_valid is not None and poses_lwh is not None: ++ # center on ego ++ origin = poses[0,0,:2,-1] ++ if RIVERMARK_VIZ: ++ extent = 45 ++ origin = origin + np.array([extent - 15.0, 0.0]) ++ ++ # if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER > 3: ++ # # plot ego traj ++ # ego_pos = ego_traj[::10,:2,3] ++ # plt.plot(ego_pos[:,0], ego_pos[:,1], 'o-', c=NVCOLOR2, markersize=2.5) #, markersize=8), linewidth ++ ++ plt.xlim(origin[0]-extent, origin[0]+extent) ++ plt.ylim(origin[1]-extent, origin[1]+extent) ++ for n in range(poses.shape[0]): ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 1: ++ continue ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 3 and n != 0: ++ continue ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER < 4 and pose_ids[n] == 'extra': ++ continue ++ # if RIVERMARK_VIZ and 
RIVERMARK_VIZ_LAYER == 3 and n == 0: ++ # continue ++ cur_color = plt_color((n+2) % 9) ++ if RIVERMARK_VIZ and RIVERMARK_VIZ_LAYER >= 4 and pose_ids[n] == 'extra': ++ cur_color = '#ff00ff' ++ if RIVERMARK_VIZ: ++ print(n) ++ print(pose_ids[n]) ++ print(cur_color) ++ cur_poses = poses[n] #, ::20] ++ xy = cur_poses[:,:2,3] ++ hvec = cur_poses[:,:2,0] # x axis ++ hvec = hvec / np.linalg.norm(hvec, axis=1, keepdims=True) ++ for t in range(cur_poses.shape[0]): ++ if poses_valid[n, t]: ++ plot_box(np.array([xy[t,0], xy[t,1], hvec[t,0], hvec[t,1]]), poses_lwh[n,:2], ++ color=NVCOLOR if n ==0 else cur_color, alpha=1.0, no_heading=False) ++ if pose_ids is not None and not RIVERMARK_VIZ: ++ plt.text(xy[t,0] + 1.0, xy[t,1] + 1.0, pose_ids[n], c='red', fontsize='x-small') ++ ++ plt.gca().set_aspect('equal') ++ plt.grid(grid) ++ if not show_ticks: ++ plt.xticks([]) ++ plt.yticks([]) ++ plt.gca().axis('off') ++ # plt.tight_layout() ++ cur_save_path = os.path.join(comp_out_path, prefix + '.jpg') ++ fig.savefig(cur_save_path) ++ # plt.show() ++ plt.close(fig) ++ ++ if RIVERMARK_VIZ: ++ og_img = cv2.imread(cur_save_path) ++ rot_img = cv2.rotate(og_img, cv2.cv2.ROTATE_90_COUNTERCLOCKWISE) ++ cv2.imwrite(cur_save_path, rot_img) ++ ++ ++def get_tile_mask(tile, layer_name, local_box, canvas_size): ++ ''' ++ Rasterizes a layer of the given tile object into a binary mask. ++ Assumes tile object has been converted to hold nuscenes-like layers. ++ ++ :param tile: NVMapTile object holding the nuscenes layers ++ :param layer_name str: which layer to rasterize, currently supports ['drivable_area', 'carpark_area', 'road_divider', 'lane_divider'] ++ :param local_box tuple: (center_x, center_y, H, W) in meters which patch of the map to rasterize ++ :param canvas_size tuple: (H, W) pixels tuple which determines the resolution at which the layer is rasterized ++ ''' ++ # must transform each map element to pixel space ++ # https://github.com/nutonomy/nuscenes-devkit/blob/9b209638ef3dee6d0cdc5ac700c493747f5b35fe/python-sdk/nuscenes/map_expansion/map_api.py#L1894 ++ patch_x, patch_y, patch_h, patch_w = local_box ++ canvas_h = canvas_size[0] ++ canvas_w = canvas_size[1] ++ scale_height = canvas_h/patch_h ++ scale_width = canvas_w/patch_w ++ trans_x = -patch_x + patch_w / 2.0 ++ trans_y = -patch_y + patch_h / 2.0 ++ trans = np.array([[trans_x, trans_y]]) ++ scale = np.array([[scale_width, scale_height]]) ++ ++ map_mask = np.zeros(canvas_size, np.uint8) ++ for seg_id, seg in tile.segments.items(): ++ for poly_pts in seg.layers[layer_name]: ++ # convert to pixel coords ++ poly_pts = (poly_pts + trans)*scale ++ # rasterize ++ if layer_name in {'drivable_area'}: ++ # polygon ++ coords = poly_pts.astype(np.int32) ++ cv2.fillPoly(map_mask, [coords], 1) ++ elif layer_name in {'lane_divider', 'road_divider'}: ++ # polyline ++ coords = poly_pts.astype(np.int32) ++ cv2.polylines(map_mask, [coords], False, 1, 2) ++ elif layer_name in {'carpark_area'}: ++ # empty ++ pass ++ else: ++ print('Unrecognized layer %d - cannot render mask' % (layer_name)) ++ ++ return map_mask ++ ++# https://www.dmv.ca.gov/portal/handbook/california-driver-handbook/lane-control/ ++# - solid yellow lines = center of road for two-way (road divider) ++# - two, one solid one broken yellow = may pass but going opposite dir (road divider) ++# - two solid yellow = road divider ++# - solid white = edge of road going same way (lane divider) ++# - broken white = two or more lanes same direction (lane divider) ++# - double white = HOV (lane divider) ++# - invisible should 
be ommitted (e.g. will just be in intersections) ++ ++def convert_tile_to_nuscenes(tile): ++ ''' ++ Given a tile, converts its layers into similar format as nuscenes. ++ This includes converting lane dividers and road boundaries to lane/road dividers and ++ drivable area. ++ returns an updated copy of the tile. ++ ''' ++ print('Converting to nuscenes...') ++ tile = copy.deepcopy(tile) ++ for seg_id, seg in tile.segments.items(): ++ if 'lane_dividers' in seg.layers: ++ nusc_lane_dividers = [] ++ nusc_road_dividers = [] ++ for div in seg.layers['lane_dividers']: ++ style = div.style ++ if len(style) > 0: ++ div_color = style[0][2] ++ div_pattern = style[0][0] ++ if div_color in {'White'}: #, 'Green'}: ++ # divides traffic in same direction ++ nusc_lane_dividers.append(div.polyline[:,:2]) ++ elif div_color in {'Yellow'} or (len(style) == 2 and div_pattern == 'Botts Dots'): #, 'Blue', 'Red', 'Orange'}: ++ # divides traffic in opposite direction ++ nusc_road_dividers.append(div.polyline[:,:2]) ++ # elif RIVERMARK_VIZ: ++ # nusc_lane_dividers.append(div.polyline[:,:2]) ++ # update segment ++ seg.layers['lane_divider'] = nusc_lane_dividers ++ seg.layers['road_divider'] = nusc_road_dividers ++ ++ del seg.layers['lane_dividers'] ++ ++ if 'road_boundaries' in seg.layers: ++ del seg.layers['road_boundaries'] ++ pass # actually don't need this for now, it's just for reference ++ ++ if 'lane_channels' in seg.layers: ++ # convert to drivable area polygon ++ expand_len = NUSC_EXPAND_LEN # meters ++ drivable_area_polys = [] ++ for channel in seg.layers['lane_channels']: ++ left = channel.left[:,:2] ++ right = channel.right[:,:2] ++ if RIVERMARK_VIZ: ++ seg.layers['lane_divider'].append(left) ++ seg.layers['road_divider'].append(right) ++ if left.shape[0] > 1: ++ # compute normals at each vertex to expand ++ left_diff = left[1:] - left[:-1] ++ left_diff = np.concatenate([left_diff, left_diff[-1:,:]], axis=0) ++ left_diff = left_diff / np.linalg.norm(left_diff, axis=1, keepdims=True) ++ left_norm = np.concatenate([-left_diff[:,1:2], left_diff[:,0:1]], axis=1) ++ right_diff = right[1:] - right[:-1] ++ right_diff = np.concatenate([right_diff, right_diff[-1:,:]], axis=0) ++ right_diff = right_diff / np.linalg.norm(right_diff, axis=1, keepdims=True) ++ right_norm = np.concatenate([right_diff[:,1:2], -right_diff[:,0:1]], axis=1) ++ # expand channel ++ left = left + (left_norm * expand_len) ++ right = right + (right_norm * expand_len) ++ ++ channel_poly = np.concatenate([right, np.flip(left, axis=0)], axis=0) ++ drivable_area_polys.append(channel_poly) ++ seg.layers['drivable_area'] = drivable_area_polys ++ del seg.layers['lane_channels'] ++ ++ # add empty carpark area for completeness ++ seg.layers['carpark_area'] = [] ++ ++ # compute extents of map ++ map_maxes = np.array([-float('inf'), -float('inf')]) # xmax, ymax ++ map_mins = np.array([float('inf'), float('inf')]) # xmin, ymin ++ for _, seg in tile.segments.items(): ++ for k, v in seg.layers.items(): ++ if len(v) > 0: ++ all_pts = np.concatenate(v, axis=0) ++ cur_maxes = np.amax(all_pts, axis=0) ++ cur_mins = np.amin(all_pts, axis=0) ++ map_maxes = np.where(cur_maxes > map_maxes, cur_maxes, map_maxes) ++ map_mins = np.where(cur_mins < map_mins, cur_mins, map_mins) ++ map_xlim = (map_mins[0] - 10, map_maxes[0] + 10) # buffer of 10m ++ map_ylim = (map_mins[1] - 10, map_maxes[1] + 10) # buffer of 10m ++ W = map_xlim[1] - map_xlim[0] ++ H = map_ylim[1] - map_ylim[0] ++ # translate so bottom left corner is at origin ++ trans_offset = np.array([[-map_xlim[0], 
-map_ylim[0]]]) ++ tile.trans_offset = trans_offset ++ tile.H = H ++ tile.W = W ++ for _, seg in tile.segments.items(): ++ seg.local2world[:2, -1] += trans_offset[0] ++ for k, v in seg.layers.items(): ++ if len(v) > 0: ++ for pts in v: ++ pts += trans_offset ++ ++ return tile ++ ++def load_tile(tile_path, ++ layers=['lane_dividers_v1', 'lane_channels_v1']): ++ # load in all road segment dicts ++ print('Parsing segment graph...') ++ tile_name = tile_path.split('/')[-1] ++ segs_path = os.path.join(tile_path, SEGMENT_GRAPH_LAYER) ++ assert os.path.exists(tile_path), 'cannot find segment graph layer, which is required to load any layers' ++ seg_json_list, _ = load_json_dir(segs_path) ++ road_segments = parse_road_segments(seg_json_list) ++ print('Found %d road segments:' % (len(road_segments))) ++ ++ # fill road segments with other desired layers ++ print('Loading requested layers...') ++ for layer in layers: ++ assert layer in NVMAP_LAYERS, 'loading layer type %s is currently not supported!' % (layer) ++ layer_dir = os.path.join(tile_path, layer) ++ assert os.path.exists(layer_dir), 'could not find requested layer %s in tile directory!' % (layer) ++ layer_json_list, seg_names = load_json_dir(layer_dir) ++ if layer == 'lane_dividers_v1': ++ parse_lane_dividers(layer_json_list, seg_names, road_segments) ++ elif layer == 'lane_channels_v1': ++ parse_lane_channels(layer_json_list, seg_names, road_segments) ++ elif layer == 'lanes_v1': ++ raise NotImplementedError() ++ ++ return NVMapTile(road_segments, name=tile_name) ++ ++ ++def load_json_dir(json_dir_path): ++ ''' ++ Loads in all json files in the given directory and returns the resulting list of dicts along ++ with the names of the json files read from. ++ ''' ++ json_files = sorted(glob.glob(os.path.join(json_dir_path, '*.json'))) ++ file_names = ['.'.join(jf.split('/')[-1].split('.')[:-1]) for jf in json_files] ++ json_list = [] ++ for jf in json_files: ++ with open(jf, 'r') as f: ++ json_list.append(json.load(f)) ++ return json_list, file_names ++ ++def parse_pt(pt_dict): ++ pt_entries = ['x', 'y', 'z', 'w'] ++ if len(pt_dict) == 3: ++ pt = [pt_dict[pt_entries[0]], pt_dict[pt_entries[1]], pt_dict[pt_entries[2]]] ++ elif len(pt_dict) == 4: ++ pt = [pt_dict[pt_entries[0]], pt_dict[pt_entries[1]], pt_dict[pt_entries[2]], pt_dict[pt_entries[3]]] ++ else: ++ assert False, 'input point must be length 3 or 4' ++ return pt ++ ++def parse_pt_list(pt_list): ++ return np.array([parse_pt(pt) for pt in pt_list]) ++ ++def parse_lane_channels(lane_channels_json_list, seg_name_list, road_segments): ++ ''' ++ Parses lane channels in each segment to an object, and store in its respective segment. ++ Transforms channels into the global coordinate system. 
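++    Left and right channel edges are kept at matching vertex counts; convert_tile_to_nuscenes later stitches them into drivable-area polygons.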
++ ++ :param lane_channels_json_list: list of lane channels json dicts ++ :param seg_name_list: the name of the segment corresponding to each json file in lane_div_json_list ++ :param road_segments: dict of all road segments in a tile ++ :return: updated road_segments (also updated in place) ++ ''' ++ for channel_dict, seg_name in zip(lane_channels_json_list, seg_name_list): ++ cur_seg = road_segments[seg_name] ++ # load lane dividers ++ lane_channels = [] ++ lane_channel_dicts = channel_dict['channels'] ++ for channel in lane_channel_dicts: ++ # left ++ left_geom = channel['left_side']['geometry'][0]['chunk']['channel_edge_line'] ++ left_polyline = parse_pt_list(left_geom['points']) ++ left_polyline = cur_seg.to_global(left_polyline)[:,:3] ++ # right ++ right_geom = channel['right_side']['geometry'][0]['chunk']['channel_edge_line'] ++ right_polyline = parse_pt_list(right_geom['points']) ++ right_polyline = cur_seg.to_global(right_polyline)[:,:3] ++ assert left_polyline.shape[0] == right_polyline.shape[0], 'channel edges should be same length!' ++ # build lane channel object ++ lane_channels.append(LaneChannel(left_polyline, right_polyline)) ++ cur_seg.layers['lane_channels'] = lane_channels ++ ++ return road_segments ++ ++def parse_lane_dividers(lane_div_json_list, seg_name_list, road_segments): ++ ''' ++ Parses lane dividers in each segment to an object, and store in its respective segment. ++ Transforms dividers into the global coordinate system. ++ ++ :param lane_div_json_list: list of lane divider json dicts ++ :param seg_name_list: the name of the segment corresponding to each json file in lane_div_json_list ++ :param road_segments: dict of all road segments in a tile ++ :return: updated road_segments (also updated in place) ++ ''' ++ for div_dict, seg_name in zip(lane_div_json_list, seg_name_list): ++ cur_seg = road_segments[seg_name] ++ # load lane dividers ++ lane_divs = [] ++ lane_div_dicts = div_dict['dividers'] ++ for div in lane_div_dicts: ++ # NOTE: left/right/height lines also sometimes available - ignoring for now ++ # NOTE: lane divider styles (type/color) also available - ignoring for now ++ # center_line is only guaranteed ++ lane_geom = div['geometry']['center_line'] ++ lane_polyline = parse_pt_list(lane_geom['points']) ++ lane_polyline = cur_seg.to_global(lane_polyline)[:,:3] ++ # parse divider style ++ if 'style' in div: ++ lane_styles = div['style'] ++ lane_styles = [(style['pattern'], style['material'], style['color']) for style in lane_styles] ++ else: ++ lane_styles = [] ++ # build lane div object ++ lane_divs.append(LaneDivider(lane_polyline, lane_styles)) ++ cur_seg.layers['lane_dividers'] = lane_divs ++ ++ # load road boundaries ++ road_bounds = [] ++ road_bound_dicts = div_dict['road_boundaries'] ++ for div in road_bound_dicts: ++ # NOTE: left/right/height lines also sometimes available - ignoring for now ++ # NOTE: road boundary type also available - ignoring for now ++ # center_line is only guaranteed ++ bound_geom = div['geometry']['center_line'] ++ bound_polyline = parse_pt_list(bound_geom['points']) ++ bound_polyline = cur_seg.to_global(bound_polyline)[:,:3] ++ # parse boundary type ++ if 'type' in div: ++ bound_type = div['type'] ++ else: ++ bound_type = None ++ # build object ++ road_bounds.append(RoadBoundary(bound_polyline, bound_type)) ++ cur_seg.layers['road_boundaries'] = road_bounds ++ ++ return road_segments ++ ++def parse_road_segments(seg_json_list): ++ ''' ++ Parses road segments into objects, and transforms into a shared coordinate system. 
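++    Each segment's GPS origin is converted to an ECEF pose and re-expressed relative to GLOBAL_BASE_POSE to populate local2world.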
++
++    :param seg_json_list: list of road segment json dicts
++
++    :return: dict of all road segments mapping id -> RoadSegment
++    '''
++    # build objects for all road segments
++    road_segments = OrderedDict()
++    for seg_dict in seg_json_list:
++        cur_seg = build_road_seg(seg_dict)
++        road_segments[cur_seg.id] = cur_seg
++
++    # go through and annotate local2world by converting GPS to the "global" coordinate system
++    for seg_id, seg in road_segments.items():
++        lat_lng_alt = np.array(seg.gps_origin).reshape((-1,3))
++        rot_axis = np.array([1.0, 0.0, 0.0]).reshape((-1,3))
++        rot_angle = np.array([0.0]).reshape((-1,1))
++        ecef_pose = lat_lng_alt_2_ecef(lat_lng_alt, rot_axis, rot_angle, 'WGS84')[0]
++        seg.local2world = np.linalg.inv(GLOBAL_BASE_POSE) @ ecef_pose
++
++    return road_segments
++
++def build_road_seg(seg_dict):
++    '''
++    Parses a road segment json dictionary into an object.
++    '''
++    segment = seg_dict['segment']
++    seg_id = segment['id']
++    seg_origin = [segment['origin']['lat'], segment['origin']['lon'], segment['origin']['height']]
++    connections = []
++    if 'connections' in segment:
++        conn_list = segment['connections']
++        for conn in conn_list:
++            source_id = conn['source_id']
++            source2tgt = []
++            for ci in range(4):
++                source2tgt.append(np.array(parse_pt(conn['source_to_target'][f'column_{ci}'])))
++            source2tgt = np.stack(source2tgt, axis=1)
++            connections.append((source_id, source2tgt))
++    return RoadSegment(seg_id, seg_origin, connections)
++
++def collect_seg_origins(road_segments):
++    '''
++    Returns np array of all road segment GPS origins (in order of dict).
++    :param road_segments: OrderedDict of RoadSegment objects
++    '''
++    return np.array([seg.gps_origin for _, seg in road_segments.items()])
++
++class RoadSegment(object):
++    def __init__(self, seg_id, origin, connections,
++                 is_root=False,
++                 layers=None):
++        '''
++        :param seg_id str:
++        :param origin: list of [lat, lon, height]
++        :param connections list: list of tuples, each containing (neighbor_id, neighbor2local_transform)
++                                 where the transform is a np.array of shape (4,4)
++        :param layers dict: layer objects within this road segment
++        '''
++        self.id = seg_id
++        self.gps_origin = origin # GPS
++        self.connections = connections
++        self.layers = layers if layers is not None else dict()
++        # transformation matrix from local to world frame
++        self.local2world = None
++
++    def to_global(self, pts):
++        '''
++        Transform an array of points from this segment's frame to global.
++
++        :param pts: np array (N x 3)
++        '''
++        pts = np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=1)
++        pts = np.dot(self.local2world, pts.T).T
++        return pts
++
++    def transform(self, mat):
++        '''
++        Transforms this segment and all contained layers by the given 4x4 transformation matrix.
++        '''
++        self.local2world = mat @ self.local2world
++        for _, layer in self.layers.items():
++            for el in layer:
++                el.transform(mat)
++
++    def __repr__(self):
++        return '<RoadSegment: %s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items()))
++
++class LaneDivider(object):
++    def __init__(self, polyline, style):
++        '''
++        :param polyline: np.array Nx3 defining the divider geometry
++        :param style: list of tuples of (pattern, material, color)
++        '''
++        self.polyline = polyline
++        self.style = style
++
++    def transform(self, mat):
++        '''
++        Transforms this map element by the given 4x4 transformation matrix.
++        '''
++        pts = np.concatenate([self.polyline, np.ones((self.polyline.shape[0], 1))], axis=1)
++        pts = np.dot(mat, pts.T).T
++        self.polyline = pts[:,:3]
++
++    def __repr__(self):
++        return '<LaneDivider: %s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items()))
++
++class RoadBoundary(object):
++    def __init__(self, polyline, bound_type):
++        '''
++        :param polyline: np.array Nx3 defining the boundary geometry
++        :param bound_type str: type of the boundary
++        '''
++        self.polyline = polyline
++        self.bound_type = bound_type
++
++    def transform(self, mat):
++        '''
++        Transforms this map element by the given 4x4 transformation matrix.
++        '''
++        pts = np.concatenate([self.polyline, np.ones((self.polyline.shape[0], 1))], axis=1)
++        pts = np.dot(mat, pts.T).T
++        self.polyline = pts[:,:3]
++
++    def __repr__(self):
++        return '<RoadBoundary: %s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items()))
++
++class LaneChannel(object):
++    def __init__(self, left_polyline, right_polyline):
++        '''
++        :param left_polyline: np.array Nx3 defining the channel left edge geometry
++        :param right_polyline: np.array Nx3 defining the channel right edge geometry
++        '''
++        self.left = left_polyline
++        self.right = right_polyline
++
++    def transform(self, mat):
++        '''
++        Transforms this map element by the given 4x4 transformation matrix.
++        '''
++        pts = np.concatenate([self.left, np.ones((self.left.shape[0], 1))], axis=1)
++        pts = np.dot(mat, pts.T).T
++        self.left = pts[:,:3]
++        pts = np.concatenate([self.right, np.ones((self.right.shape[0], 1))], axis=1)
++        pts = np.dot(mat, pts.T).T
++        self.right = pts[:,:3]
++
++    def __repr__(self):
++        return '<LaneChannel: %s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items()))
++
++class NVMapTile(object):
++    def __init__(self, segments,
++                 name=None):
++        '''
++        :param segments: dict mapping seg_id -> RoadSegment objects for all road segments in the tile.
++        '''
++        self.segments = segments
++        self.name = name
++
++    def transform(self, mat):
++        '''
++        Multiply all elements of the map tile by the given 4x4 matrix
++        '''
++        for seg_id, seg in self.segments.items():
++            seg.transform(mat)
++
++    def __repr__(self):
++        return '<NVMapTile: %s>' % str('\n '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.items()))
++
++
++#######################################################################################################################
++
++#
++# Utils from Zan for converting between GPS and world coordinates
++#
++
++from scipy.spatial.transform import Rotation as R
++
++
++def lat_lng_alt_2_ecef(lat_lng_alt, orientation_axis, orientation_angle, earth_model='WGS84'):
++    ''' Computes the transformation from the world pose coordinate system to the earth centered earth fixed (ECEF) one
++    Args:
++        lat_lng_alt (np.array): latitude, longitude and altitude coordinate (in degrees and meters) [n,3]
++        orientation_axis (np.array): orientation in the local ENU coordinate system [n,3]
++        orientation_angle (np.array): orientation angle of the local ENU coordinate system in degrees [n,1]
++        earth_model (string): earth model used for conversion (the spherical model will be inaccurate when maps are large)
++    Out:
++        trans (np.array): transformation parameters from world pose to ECEF coordinate system in se3 form (n, 4, 4)
++    '''
++    n = lat_lng_alt.shape[0]
++    trans = np.tile(np.eye(4).reshape(1,4,4),[n,1,1])
++
++    theta = (90.
- lat_lng_alt[:, 0]) * np.pi/180 ++ phi = lat_lng_alt[:, 1] * np.pi/180 ++ ++ R_enu_ecef = local_ENU_2_ECEF_orientation(theta, phi) ++ ++ if earth_model == 'WGS84': ++ a = 6378137.0 ++ flattening = 1.0 / 298.257223563 ++ b = a * (1.0 - flattening) ++ translation = lat_lng_alt_2_translation_ellipsoidal(lat_lng_alt, a, b) ++ ++ elif earth_model == 'sphere': ++ earth_radius = 6378137.0 # Earth radius in meters ++ z_dir = np.concatenate([(np.sin(theta)*np.cos(phi))[:,None], ++ (np.sin(theta)*np.sin(phi))[:,None], ++ (np.cos(theta))[:,None] ],axis=1) ++ ++ translation = (earth_radius + lat_lng_alt[:, -1])[:,None] * z_dir ++ ++ else: ++ raise ValueError ("Selected ellipsoid not implemented!") ++ ++ world_pose_orientation = axis_angle_2_so3(orientation_axis, orientation_angle) ++ ++ trans[:,:3,:3] = R_enu_ecef @ world_pose_orientation ++ trans[:,:3,3] = translation ++ ++ return trans ++ ++def local_ENU_2_ECEF_orientation(theta, phi): ++ ''' Computes the rotation matrix between the world_pose and ECEF coordinate system ++ Args: ++ theta (np.array): theta coordinates in radians [n,1] ++ phi (np.array): phi coordinates in radians [n,1] ++ Out: ++ (np.array): rotation from world pose to ECEF in so3 representation [n,3,3] ++ ''' ++ z_dir = np.concatenate([(np.sin(theta)*np.cos(phi))[:,None], ++ (np.sin(theta)*np.sin(phi))[:,None], ++ (np.cos(theta))[:,None] ],axis=1) ++ z_dir = z_dir/np.linalg.norm(z_dir, axis=-1, keepdims=True) ++ ++ y_dir = np.concatenate([-(np.cos(theta)*np.cos(phi))[:,None], ++ -(np.cos(theta)*np.sin(phi))[:,None], ++ (np.sin(theta))[:,None] ],axis=1) ++ y_dir = y_dir/np.linalg.norm(y_dir, axis=-1, keepdims=True) ++ ++ x_dir = np.cross(y_dir, z_dir) ++ ++ return np.concatenate([x_dir[:,:,None], y_dir[:,:,None], z_dir[:,:,None]], axis = -1) ++ ++ ++def lat_lng_alt_2_translation_ellipsoidal(lat_lng_alt, a, b): ++ ''' Computes the translation based on the ellipsoidal earth model ++ Args: ++ lat_lng_alt (np.array): latitude, longitude and altitude coordinate (in degrees and meters) [n,3] ++ a (float/double): Semi-major axis of the ellipsoid ++ b (float/double): Semi-minor axis of the ellipsoid ++ Out: ++ (np.array): translation from world pose to ECEF [n,3] ++ ''' ++ ++ phi = lat_lng_alt[:, 0] * np.pi/180 ++ gamma = lat_lng_alt[:, 1] * np.pi/180 ++ ++ cos_phi = np.cos(phi) ++ sin_phi = np.sin(phi) ++ cos_gamma = np.cos(gamma) ++ sin_gamma = np.sin(gamma) ++ e_square = (a * a - b * b) / (a * a) ++ ++ N = a / np.sqrt(1 - e_square * sin_phi * sin_phi) ++ ++ ++ x = (N + lat_lng_alt[:, 2]) * cos_phi * cos_gamma ++ y = (N + lat_lng_alt[:, 2]) * cos_phi * sin_gamma ++ z = (N * (b*b)/(a*a) + lat_lng_alt[:, 2]) * sin_phi ++ ++ return np.concatenate([x[:,None] ,y[:,None], z[:,None]], axis=1 ) ++ ++def axis_angle_2_so3(axis, angle, degrees=True): ++ ''' Converts the axis angle representation of the so3 rotation matrix ++ Args: ++ axis (np.array): the rotation axes [n,3] ++ angle float/double: rotation angles either in degrees or radians [n] ++ degrees bool: True if angle is given in degrees else False ++ ++ Out: ++ (np array): rotations given so3 matrix representation [n,3,3] ++ ''' ++ # Treat angle (radians) below this as 0. 
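++    # e.g. axis=[[0, 0, 1]], angle=[90] (degrees) gives a 90-degree yaw rotation; below, the axis is rescaled so its
++    # norm equals the angle in radians and then converted with scipy's Rotation.from_rotvec.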
++ cutoff_angle = 1e-9 if not degrees else 1e-9*180/np.pi ++ angle[angle < cutoff_angle] = 0.0 ++ ++ # Scale the axis to have the norm representing the angle ++ if degrees: ++ angle = np.radians(angle) ++ axis_angle = (angle/np.linalg.norm(axis, axis=1, keepdims=True)) * axis ++ ++ return R.from_rotvec(axis_angle).as_matrix() ++ ++def ecef_2_lat_lng_alt(trans, earth_model='WGS84'): ++ ''' Converts the transformation from the earth centered earth fixed (ECEF) coordinate frame to the world pose ++ Args: ++ trans (np.array): transformation parameters in ECEF [n,4,4] ++ earth_model (string): earth model used for conversion (spheric will be unaccurate when maps are large) ++ Out: ++ lat_lng_alt (np.array): latitude, longitude and altitude coordinate (in degrees and meters) [n,3] ++ orientation_axis (np.array): orientation in the local ENU coordinate system [n,3] ++ orientation_angle (np.array): orientation angle of the local ENU coordinate system in degrees [n,1] ++ ''' ++ ++ translation = trans[:,:3,3] ++ rotation = trans[:,:3,:3] ++ ++ if earth_model == 'WGS84': ++ a = 6378137.0 ++ flattening = 1.0 / 298.257223563 ++ lat_lng_alt = translation_2_lat_lng_alt_ellipsoidal(translation, a, flattening) ++ ++ elif earth_model == 'sphere': ++ earth_radius = 6378137.0 # Earth radius in meters ++ lat_lng_alt = translation_2_lat_lng_alt_spherical(translation, earth_radius) ++ ++ else: ++ raise ValueError ("Selected ellipsoid not implemented!") ++ ++ ++ # Compute the orientation axis and angle ++ theta = (90. - lat_lng_alt[:, 0]) * np.pi/180 ++ phi = lat_lng_alt[:, 1] * np.pi/180 ++ ++ R_ecef_enu = local_ENU_2_ECEF_orientation(theta, phi).transpose(0,2,1) ++ ++ orientation = R_ecef_enu @ rotation ++ orientation_axis, orientation_angle = so3_2_axis_angle(orientation) ++ ++ ++ return lat_lng_alt, orientation_axis, orientation_angle ++ ++def translation_2_lat_lng_alt_spherical(translation, earth_radius): ++ ''' Computes the translation in the ECEF to latitude, longitude, altitude based on the spherical earth model ++ Args: ++ translation (np.array): translation in the ECEF coordinate frame (in meters) [n,3] ++ earth_radius (float/double): earth radius ++ Out: ++ (np.array): latitude, longitude and altitude [n,3] ++ ''' ++ altitude = np.linalg.norm(translation, axis=-1) - earth_radius ++ latitude = 90 - np.arccos(translation[:,2] / np.linalg.norm(translation, axis=-1, keepdims=True)) * 180/np.pi ++ longitude = np.arctan2(translation[:,1],translation[:,0]) * 180/np.pi ++ ++ return np.concatenate([latitude[:,None], longitude[:,None], altitude[:,None]], axis=1) ++ ++def translation_2_lat_lng_alt_ellipsoidal(translation, a, f): ++ ''' Computes the translation in the ECEF to latitude, longitude, altitude based on the ellipsoidal earth model ++ Args: ++ translation (np.array): translation in the ECEF coordinate frame (in meters) [n,3] ++ a (float/double): Semi-major axis of the ellipsoid ++ f (float/double): flattening factor of the earth ++ radius ++ Out: ++ (np.array): latitude, longitude and altitude [n,3] ++ ''' ++ ++ # Compute support parameters ++ f0 = (1 - f) * (1 - f) ++ f1 = 1 - f0 ++ f2 = 1 / f0 - 1 ++ ++ z_div_1_f = translation[:,2] / (1 - f) ++ x2y2 = np.square(translation[:,0]) + np.square(translation[:,1]) ++ ++ x2y2z2 = x2y2 + z_div_1_f*z_div_1_f ++ x2y2z2_pow_3_2 = x2y2z2 * np.sqrt(x2y2z2) ++ ++ gamma = (x2y2z2_pow_3_2 + a * f2 * z_div_1_f * z_div_1_f) / (x2y2z2_pow_3_2 - a * f1 * x2y2) * translation[:,2] / np.sqrt(x2y2) ++ ++ longitude = np.arctan2(translation[:,1], translation[:,0]) * 180/np.pi ++ 
latitude = np.arctan(gamma) * 180/np.pi ++ altitude = np.sqrt(1 + np.square(gamma)) * (np.sqrt(x2y2) - a / np.sqrt(1 + f0 * np.square(gamma))) ++ ++ return np.concatenate([latitude[:,None], longitude[:,None], altitude[:,None]], axis=1) ++ ++def so3_2_axis_angle(so3, degrees=True): ++ ''' Converts the so3 representation to axis_angle ++ Args: ++ so3 (np.array): the rotation matrices [n,3,3] ++ degrees bool: True if angle should be given in degrees ++ ++ Out: ++ axis (np array): the rotation axis [n,3] ++ angle (np array): the rotation angles, either in degrees (if degrees=True) or radians [n,] ++ ''' ++ rot_vec = R.from_matrix(so3).as_rotvec() ++ ++ angle = np.linalg.norm(rot_vec, axis=-1, keepdims=True) ++ axis = rot_vec / angle ++ if degrees: ++ angle = np.degrees(angle) ++ ++ return axis, angle ++ ++####################################################################################################################### ++ ++# ++# Utils for loading in ego and autolabel pose data for session ++# ++ ++import datetime ++import pickle ++ ++from scipy import spatial, interpolate ++ ++# NV_EGO_LWH = [4.084, 1.73, 1.562] # this is the nuscenes measurements ++NV_EGO_LWH = [5.30119, 2.1133, 1.49455] # actual hyperion 8 ++ ++class PoseInterpolator: ++ ''' Interpolates the poses to the desired time stamps. The translation component is interpolated linearly, ++ while spherical linear interpolation (SLERP) is used for the rotations. ++ https://en.wikipedia.org/wiki/Slerp ++ ++ Args: ++ poses (np.array): poses at given timestamps in a se3 representation [n,4,4] ++ timestamps (np.array): timestamps of the known poses [n] ++ ts_target (np.array): timestamps for which the poses will be interpolated [m,1] ++ Out: ++ (np.array): interpolated poses in se3 representation [m,4,4] ++ ''' ++ def __init__(self, poses, timestamps): ++ ++ self.slerp = spatial.transform.Slerp(timestamps, R.from_matrix(poses[:,:3,:3])) ++ self.f_x = interpolate.interp1d(timestamps, poses[:,0,3]) ++ self.f_y = interpolate.interp1d(timestamps, poses[:,1,3]) ++ self.f_z = interpolate.interp1d(timestamps, poses[:,2,3]) ++ ++ self.last_row = np.array([0,0,0,1]).reshape(1,1,-1) ++ ++ def interpolate_to_timestamps(self, ts_target): ++ x_interp = self.f_x(ts_target).reshape(-1,1,1) ++ y_interp = self.f_y(ts_target).reshape(-1,1,1) ++ z_interp = self.f_z(ts_target).reshape(-1,1,1) ++ R_interp = self.slerp(ts_target).as_matrix().reshape(-1,3,3) ++ ++ t_interp = np.concatenate([x_interp,y_interp,z_interp],axis=-2) ++ ++ return np.concatenate((np.concatenate([R_interp,t_interp],axis=-1), np.tile(self.last_row,(R_interp.shape[0],1,1))), axis=1) ++ ++def angle_diff(x, y, period=2*np.pi): ++ ''' ++ Get the smallest angle difference between 2 angles: the angle from y to x. ++ :param x: angle 1 (B) ++ :param y: angle 2 (B) ++ :param period: periodicity in radians for assessing difference. ++ :return diff: smallest angle difference between to angles (B) ++ ''' ++ # calculate angle difference, modulo to [0, 2*pi] ++ diff = (x - y + period / 2) % period - period / 2 ++ diff[diff > np.pi] = diff[diff > np.pi] - (2 * np.pi) # shift (pi, 2*pi] to (-pi, 0] ++ return diff ++ ++def load_ego_pose_from_image_meta(images_path, map_tile=None): ++ ''' ++ Loads in the SDC pose from the metadata attached to a session image stream. ++ ++ :param images_path str: directory of the images/metadata. Should contain *.pkl for each frame and timestamps.npz ++ :param map_tile: Tile object, if given translates the ego trajectory so in the same frame as this map tile. 
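++    :return: tuple (ego_poses, ego_t) of Nx4x4 ego poses and their corresponding timestamps
++    '''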
++ ''' ++ frame_meta = sorted(glob.glob(os.path.join(images_path, '*.pkl'))) ++ timestamps_pth = os.path.join(images_path, 'timestamps.npz') ++ ++ # load in timesteps ++ # ego_t = np.load(timestamps_pth)['frame_t'] ++ ++ ego_poses = [] ++ ego_t = [] ++ for meta_file in frame_meta: ++ with open(meta_file, 'rb') as f: ++ cur_meta = pickle.load(f) ++ # ego_poses.append(cur_meta['ego_pose_s']) ++ # ego_poses.append(cur_meta['ego_pose_timestamps'][0]) ++ ego_poses.append(cur_meta['ego_pose_e']) ++ ego_t.append(cur_meta['ego_pose_timestamps'][1]) ++ ego_poses = np.stack(ego_poses, axis=0) ++ ego_t = np.array(ego_t) ++ ++ # pose_sidx = int(frame_meta[0].split('/')[-1].split('.')[0]) ++ # pose_eidx = int(frame_meta[-1].split('/')[-1].split('.')[0]) + 1 ++ # ego_t = ego_t[pose_sidx:pose_eidx] ++ ++ if map_tile is not None: ++ ego_poses[:, :2, -1] += map_tile.trans_offset ++ ++ return ego_poses, ego_t ++ ++def check_time_overlap(s0, e0, s1, e1): ++ overlap = (s0 < s1 and e0 > s1) or \ ++ (s0 > s1 and s0 < e1) or \ ++ (s1 < s0 and e1 > s0) or \ ++ (s1 > s0 and s1 < e0) ++ return overlap ++ ++ ++def load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ map_tile=None, ++ frame_range=None, ++ postprocess=True, ++ fill_first_n=None, ++ mine_dups=False, ++ extra_obj_path=None, ++ load_ids=None, ++ crop2valid=True): ++ ''' ++ This only loads labeled trajectories that are available at the same section ++ as the ego labels. ++ ++ :param autolabels_path str: pkl file to load autolabels from ++ :param ego_images_path str: directory containing image metadata to load ego poses from ++ :param lidar_rig_path str: npz containing lidar2rig transform for ego ++ :param map_tile: Tile object, if given translates the trajectories so in the same frame as this map tile. ++ :param frame_range tuple: If given (start, end) only loads in this frame range (wrt the ego sequence) ++ :param postprocess bool: If true, runs some post-processing to associate track and heuristically remove rotation flips. 
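++    :param fill_first_n int: if given, interpolate/extrapolate so that the first n steps of each sufficiently long track are valid
++    :param mine_dups bool: if True, print track ids whose filled-in boxes collide most of the time (likely mis-associated duplicates)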
++ :param extra_obj_path str: if given, load in an additional trajectory from here and includes in the data ++ :return traj_T: ++ ''' ++ # Load the autolabels ++ with open(autolabels_path, 'rb') as f: ++ labels = pickle.load(f) ++ ++ # Load the poses and the timestamps ++ ego_poses, ego_pose_timestamps = load_ego_pose_from_image_meta(ego_images_path) ++ if frame_range is not None: ++ assert frame_range[1] > frame_range[0] ++ assert (frame_range[0] >= 0 and frame_range[0] <= ego_poses.shape[0]) ++ ego_poses = ego_poses[frame_range[0]:frame_range[1]] ++ ego_pose_timestamps = ego_pose_timestamps[frame_range[0]:frame_range[1]] ++ ++ # if ego_poses are not in rivermark frame, need to take them so can load autolabels ++ ego_poses_rivermark = ego_poses ++ if GLOBAL_BASE_POSE is not GLOBAL_BASE_POSE_RIVERMARK: ++ ego_poses_ecef = np.matmul(GLOBAL_BASE_POSE[np.newaxis], ego_poses) ++ ego_poses_rivermark = np.matmul(np.linalg.inv(GLOBAL_BASE_POSE_RIVERMARK)[np.newaxis], ego_poses_ecef) ++ ++ # Load the lidar to rig transformation parameters and timestamps ++ T_lidar_rig = np.load(lidar_rig_path)['T_lidar_rig'] ++ ++ # first pass to break tracks into contiguous subsequences ++ # and merge manually given missed associations ++ track_seqs = dict() ++ processed_ids = set() ++ updated_labels = dict() ++ for track_id, track in labels.items(): ++ if load_ids is not None and track_id not in load_ids: ++ continue ++ if track_id in processed_ids or track_id in TRACK_REMOVE: ++ # already processed this through association ++ # or should be removed ++ continue ++ ++ obj_ts = track['3D_bbox'][:,0] ++ if track_id in TRACK_ASSOC: ++ # stack all the data from all associated tracks ++ # NOTE: this assumes TRACK_ASSOC is sorted in temporal order already ++ assoc_data = [labels[assoc_track_id]['3D_bbox'] for assoc_track_id in TRACK_ASSOC[track_id]] ++ # if association is wrong, the tracks may overlap ++ valid_assoc = [not check_time_overlap(obj_ts[0], obj_ts[-1], assoc_label[0,0], assoc_label[-1,0]) for assoc_label in assoc_data] ++ if np.sum(valid_assoc) != len(valid_assoc): ++ print('Invalid associations for track_id %s!!' % (track_id)) ++ print('Ignoring: ') ++ print(np.array(TRACK_ASSOC[track_id])[~np.array(valid_assoc)]) ++ assoc_data = [assoc_label for aid, assoc_label in enumerate(assoc_data) if valid_assoc[aid]] ++ if len(assoc_data) > 0: ++ assoc_bbox = np.concatenate([track['3D_bbox']] + assoc_data, axis=0) ++ updated_labels[track_id] = {'3D_bbox' : assoc_bbox, 'type' : track['type']} ++ obj_ts = assoc_bbox[:,0] ++ processed_ids.update(TRACK_ASSOC[track_id]) ++ else: ++ updated_labels[track_id] = track ++ else: ++ updated_labels[track_id] = track ++ ++ if len(obj_ts) < 2: ++ # make sure track is longer than single frame ++ continue ++ track_seqs[track_id] = [] ++ # larger than 3 timesteps considered a break, otherwise should be reasonable to interpolate ++ # should we do even larger? 
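++        # (with AUTOLABEL_DT = 0.1 s this corresponds to a gap threshold of 0.35 s)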
++ track_break = np.diff(1e-6*obj_ts) > (AUTOLABEL_DT*3 + AUTOLABEL_DT*0.5) ++ seq_sidx = 0 ++ for tidx in range(1, obj_ts.shape[0]): ++ if track_break[tidx-1]: ++ track_seqs[track_id].append((seq_sidx, tidx)) ++ seq_sidx = tidx ++ track_seqs[track_id].append((seq_sidx, obj_ts.shape[0])) ++ processed_ids.add(track_id) ++ ++ # load extra object ++ if extra_obj_path is not None: ++ extra_obj_data = np.load(extra_obj_path) ++ extra_obj_poses = extra_obj_data['poses'] ++ extra_obj_timestamps = extra_obj_data['pose_timestamps'] ++ extra_obj_lwh = EXTRA_DYN_LWH ++ extra_track = { ++ 'poses' : extra_obj_poses, ++ 'timestamps' : extra_obj_timestamps, ++ 'lwh' : extra_obj_lwh, ++ 'type' : 'car' ++ } ++ updated_labels['extra'] = extra_track ++ track_seqs['extra'] = [(0,extra_obj_poses.shape[0])] ++ ++ # collect all tracks that overlap with ego data ++ traj_poses = [] ++ traj_valid = [] ++ traj_lwh = [] ++ traj_ids = [] ++ for track_id, cur_track_seqs in track_seqs.items(): ++ track = updated_labels[track_id] ++ if track_id == 'extra': ++ all_obj_ts = track['timestamps'] ++ all_obj_dat = track['poses'] ++ obj_lwh = track['lwh'] ++ else: ++ all_obj_ts = track['3D_bbox'][:,0] ++ all_obj_dat = track['3D_bbox'] ++ obj_lwh = np.median(all_obj_dat[:,4:7], axis=0) # use all timesteps for bbox size ++ # will fill these in as we go through each subseq ++ full_obj_traj = np.ones_like(ego_poses_rivermark)*np.nan ++ obj_valid = np.zeros((full_obj_traj.shape[0]), dtype=bool) ++ for seq_sidx, seq_eidx in cur_track_seqs: ++ obj_ts = all_obj_ts[seq_sidx:seq_eidx] ++ if (obj_ts[0] >= ego_pose_timestamps[0] and obj_ts[0] <= ego_pose_timestamps[-1]) or \ ++ (obj_ts[-1] >= ego_pose_timestamps[0] and obj_ts[-1] <= ego_pose_timestamps[-1]) or \ ++ (obj_ts[0] <= ego_pose_timestamps[0] and obj_ts[-1] >= ego_pose_timestamps[-1]): ++ obj_type = track['type'] ++ # if obj_type != 'car': ++ # continue ++ obj_dat = all_obj_dat[seq_sidx:seq_eidx] ++ # find steps overlapping the ego sequence ++ valid_ts = np.logical_and(obj_ts >= ego_pose_timestamps[0], obj_ts <= ego_pose_timestamps[-1]) ++ ++ overlap_inds = np.nonzero(valid_ts)[0] ++ if len(overlap_inds) < 2: ++ continue # need more than 1 frame overlap ++ sidx = np.amin(overlap_inds) ++ eidx = np.amax(overlap_inds)+1 ++ ++ obj_ts = obj_ts[sidx:eidx] ++ # some poses have the same timestep -- drop these so we can interpolate ++ valid_t = np.diff(obj_ts) > 0 ++ valid_t = np.append(valid_t, [True]) ++ if not valid_t[0]: ++ # want to keep the edge times in tact since these surround ego times ++ valid_t[0] = True ++ valid_t[1] = False ++ obj_ts = obj_ts[valid_t] ++ ++ if track_id == 'extra': ++ print(obj_ts) ++ glob_obj_poses = obj_dat[sidx:eidx] ++ print(glob_obj_poses.shape) ++ # exit() ++ else: ++ obj_pos = obj_dat[sidx:eidx,1:4][valid_t] ++ # print(obj_pos) ++ # print(obj_dat[sidx:eidx,4:7][valid_t]) ++ obj_rot_eulxyz = obj_dat[sidx:eidx,7:][valid_t] ++ obj_rot_eulxyz[obj_rot_eulxyz[:,2] < -np.pi, 2] += (2 * np.pi) ++ obj_rot_eulxyz[obj_rot_eulxyz[:,2] > np.pi, 2] -= (2 * np.pi) ++ obj_R = R.from_euler('xyz', obj_rot_eulxyz, degrees=False).as_matrix() ++ # build transformation matrix (pose sequence) ++ obj_poses = np.repeat(np.eye(4)[np.newaxis], len(obj_ts), axis=0) ++ obj_poses[:,:3,:3] = obj_R ++ obj_poses[:,:3,-1] = obj_pos ++ ++ # need to interpolate the ego pose to transform from lidar frame to global ++ overlap_ego_mask = np.logical_and(ego_pose_timestamps >= obj_ts[0] - 1e6, ego_pose_timestamps <= obj_ts[-1] + 1e6) # add 1 sec around so can interp first/last frames ++ 
overlap_ego_t = ego_pose_timestamps[overlap_ego_mask] ++ overlap_ego_poses = ego_poses_rivermark[overlap_ego_mask] ++ ego_interp = PoseInterpolator(overlap_ego_poses, overlap_ego_t) ++ T_rig_global = ego_interp.interpolate_to_timestamps(obj_ts) ++ ++ # transform to rig frame from lidar ++ rig_obj_poses = np.matmul(T_lidar_rig[np.newaxis], obj_poses) ++ ++ # print('elev') ++ # print(rig_obj_poses[:,2, 3]) ++ # print('height') ++ # print(obj_dat[sidx:eidx,6][valid_t]) ++ ++ # transform to global frame (w.r.t rivermark) from rig ++ glob_obj_poses = np.matmul(T_rig_global, rig_obj_poses) ++ # now to the desired global frame ++ if GLOBAL_BASE_POSE is not GLOBAL_BASE_POSE_RIVERMARK: ++ # to ECEF ++ glob_obj_poses = np.matmul(GLOBAL_BASE_POSE_RIVERMARK[np.newaxis], glob_obj_poses) ++ # to desired global pose ++ glob_obj_poses = np.matmul(np.linalg.inv(GLOBAL_BASE_POSE)[np.newaxis], glob_obj_poses) ++ ++ if postprocess and track_id != 'extra': ++ # we're going to collect frames with "correct" orientations ++ # by first looking at dynamic frames where can use motion to infer correct orientation, ++ # then using dynamic to determine correctness of static frames. ++ # then we can interpolate between all these correct frames. ++ glob_hvec = glob_obj_poses[:,:2,0] # x-axis ++ glob_hvec = glob_hvec / np.linalg.norm(glob_hvec, axis=-1, keepdims=True) ++ glob_yaw = np.arctan2(glob_hvec[:,1], glob_hvec[:,0]) ++ glob_pos = glob_obj_poses[:,:2,3] # 2d ++ ++ # TODO add smoothing to the position to avoid noisy velocities ++ obj_vel = np.diff(glob_pos[:,:2], axis=0) / np.diff(obj_ts*1e-6)[:,np.newaxis] ++ obj_vel = np.concatenate([obj_vel, obj_vel[-1:,:]], axis=0) ++ # is_dynamic = np.median(np.linalg.norm(obj_vel, axis=1)) > 2.0 # m/s ++ is_dynamic = np.linalg.norm(obj_vel, axis=1) > 2.0 # m/s ++ ++ is_correct_mask = np.zeros((glob_pos.shape[0]), dtype=bool) ++ ++ # dynamic first ++ if np.sum(is_dynamic) > 0: ++ dynamic_vel = obj_vel[is_dynamic] ++ dynamic_yaw = glob_yaw[is_dynamic] ++ dynamic_vel_norm = np.linalg.norm(dynamic_vel, axis=1, keepdims=True) ++ vel_dir = dynamic_vel / (dynamic_vel_norm + 1e-9) ++ head_dir = np.concatenate([np.cos(dynamic_yaw[:,np.newaxis]), np.sin(dynamic_yaw[:,np.newaxis])], axis=1) ++ vel_head_dot = np.sum(vel_dir * head_dir, axis=1) ++ dynamic_correct = vel_head_dot > 0 ++ is_correct_mask[is_dynamic] = dynamic_correct ++ ++ # now static, by referencing closest correct dynamic ++ dynamic_inds = np.nonzero(np.logical_and(is_dynamic, is_correct_mask))[0] ++ static_inds = np.nonzero(~is_dynamic)[0] ++ if len(static_inds) > 0: ++ # if no dynamic frames ++ # assume correct orientation has the most frequent sign ++ num_pos = np.sum(glob_yaw[~is_dynamic] >= 0) ++ num_neg = np.sum(glob_yaw[~is_dynamic] < 0) ++ for static_ind in static_inds: ++ if len(dynamic_inds) > 0: ++ closest_dyn_ind = np.argmin(np.abs(static_ind - dynamic_inds)) ++ dyn_stat_dot = np.sum(glob_hvec[closest_dyn_ind]*glob_hvec[static_ind]) ++ if dyn_stat_dot > 0: # going in same direction ++ is_correct_mask[static_ind] = True ++ else: ++ is_wrong = glob_yaw[static_ind] >= 0 if num_neg > num_pos else glob_yaw[static_ind] < 0 ++ is_correct_mask[static_ind] = not is_wrong ++ ++ if np.sum(is_correct_mask) > 0: ++ fix_interp_poses = glob_obj_poses[is_correct_mask] ++ fix_interp_t = obj_ts[is_correct_mask] ++ # what if edges are not correct? 
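++                    # pad the earliest/latest correct pose outward so the interpolation below spans the full time range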
++ if not is_correct_mask[0]: ++ # just pad first correct to beginning ++ fix_interp_poses = np.concatenate([fix_interp_poses[0:1], fix_interp_poses], axis=0) ++ fix_interp_t = np.concatenate([[obj_ts[0]], fix_interp_t], axis=0) ++ if not is_correct_mask[-1]: ++ # just pad first correct to beginning ++ fix_interp_poses = np.concatenate([fix_interp_poses, fix_interp_poses[-1:]], axis=0) ++ fix_interp_t = np.concatenate([fix_interp_t, [obj_ts[-1]]], axis=0) ++ ++ # now interpolate between correct frames ++ fix_flip_interp = PoseInterpolator(fix_interp_poses, fix_interp_t) ++ fixed_rot_poses = fix_flip_interp.interpolate_to_timestamps(obj_ts) ++ glob_obj_poses[:,:3,:3] = fixed_rot_poses[:,:3,:3] # don't want to update translation ++ ++ # after processing interpolate to the relevant ego timestamps (upsample from 10Hz to 30Hz) ++ obj_interp = PoseInterpolator(glob_obj_poses, obj_ts) ++ overlap_ego_mask = np.logical_and(ego_pose_timestamps >= obj_ts[0], ego_pose_timestamps <= obj_ts[-1]) ++ overlap_ego_t = ego_pose_timestamps[overlap_ego_mask] ++ glob_obj_poses = obj_interp.interpolate_to_timestamps(overlap_ego_t) ++ ++ # update full seq information ++ full_obj_traj[overlap_ego_mask] = glob_obj_poses ++ obj_valid[overlap_ego_mask] = True ++ ++ if np.sum(obj_valid) > 0: ++ traj_poses.append(full_obj_traj) ++ traj_valid.append(obj_valid) ++ traj_lwh.append(obj_lwh) ++ traj_ids.append(track_id) ++ ++ traj_poses = np.stack(traj_poses, axis=0) ++ traj_valid = np.stack(traj_valid, axis=0) ++ traj_lwh = np.stack(traj_lwh, axis=0) ++ traj_ids = np.array(traj_ids) ++ ++ if crop2valid: ++ # we interpolated inside ego timestamp maximum, so have to crop a bit ++ val_inds = np.nonzero(np.sum(traj_valid, axis=0) > 0)[0] ++ start_valid = np.amin(val_inds) ++ end_valid = np.amax(val_inds)+1 ++ traj_poses = traj_poses[:,start_valid:end_valid] ++ traj_valid = traj_valid[:,start_valid:end_valid] ++ ego_poses = ego_poses[start_valid:end_valid] ++ ego_pose_timestamps = ego_pose_timestamps[start_valid:end_valid] ++ ++ print(start_valid) ++ print(end_valid) ++ ++ if fill_first_n is not None: ++ # for each trajectory, make sure the first n steps are ++ # all valid either by interpolation or extrapolation. 
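++        # note: tracks valid for fewer than 30 frames (under ~1 s at 30 Hz) are skipped below rather than extrapolated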
++ all_ts = ego_pose_timestamps*1e-6 ++ for ai in range(traj_poses.shape[0]): ++ cur_poses = traj_poses[ai] ++ cur_trans = cur_poses[:,:3,3] ++ cur_R = cur_poses[:,:3,:3] ++ cur_valid = traj_valid[ai] ++ ++ if np.sum(cur_valid) < 30: ++ # if does't show up for at least a second throughout, don't be extrapolating ++ continue ++ ++ first_n_valid = cur_valid[:fill_first_n] ++ if np.sum(~first_n_valid) == 0: ++ continue ++ ++ first_n_timestamps = all_ts[:fill_first_n] ++ ++ all_val_inds = sorted(np.nonzero(cur_valid)[0]) ++ first_val_idx = all_val_inds[0] ++ last_val_idx = all_val_inds[-1] ++ ++ # interp steps are those between the first and last valid steps ++ first_n_steps = np.arange(min(fill_first_n, all_ts.shape[0])) ++ interp_steps = np.logical_and(first_n_steps >= first_val_idx, first_n_steps <= last_val_idx) ++ first_n_interp = None ++ if np.sum(interp_steps) > 0: ++ first_n_interp = PoseInterpolator(cur_poses[cur_valid], all_ts[cur_valid]) ++ first_n_interp = first_n_interp.interpolate_to_timestamps(first_n_timestamps[interp_steps]) ++ ++ # extrap fw are those past last valid step (disappear) ++ extrap_fw_steps = first_n_steps > last_val_idx ++ first_n_extrap_fw = None ++ if np.sum(extrap_fw_steps) > 0: ++ if last_val_idx > 0 and cur_valid[last_val_idx-1]: ++ # need to compute velocity to extrapolate ++ dt = first_n_timestamps[extrap_fw_steps] - all_ts[last_val_idx] ++ dt = dt[:,np.newaxis] ++ last_pos = np.repeat(cur_trans[last_val_idx][np.newaxis], dt.shape[0], axis=0) ++ # translation ++ last_lin_vel = (cur_trans[last_val_idx] - cur_trans[last_val_idx-1]) / (all_ts[last_val_idx] - all_ts[last_val_idx-1]) ++ last_lin_vel = np.repeat(last_lin_vel[np.newaxis], dt.shape[0], axis=0) ++ extrap_trans = last_pos + last_lin_vel*dt ++ # copy rotation ++ extrap_R = np.repeat(cur_R[last_val_idx:last_val_idx+1], dt.shape[0], axis=0) ++ # # extrapolate rotation ++ # last_delta_R = np.dot(cur_R[last_val_idx], cur_R[last_val_idx-1].T) ++ # last_delta_rotvec = R.from_matrix(last_delta_R).as_rotvec() ++ # last_delta_angle = np.linalg.norm(last_delta_rotvec) ++ # last_delta_axis = last_delta_rotvec / (last_delta_angle + 1e-9) ++ # last_ang_vel = last_delta_angle / (all_ts[last_val_idx] - all_ts[last_val_idx-1]) ++ # last_ang_vel = np.repeat(last_ang_vel[np.newaxis,np.newaxis], dt.shape[0], axis=0) ++ # extrap_angle = last_ang_vel*dt ++ # last_delta_axis = np.repeat(last_delta_axis[np.newaxis], dt.shape[0], axis=0) ++ # extrap_rotvec = extrap_angle*last_delta_axis ++ # extrap_delta_R = R.from_rotvec(extrap_rotvec).as_matrix() ++ # extrap_R = np.matmul(extrap_delta_R, cur_R[last_val_idx:last_val_idx+1]) ++ ++ # put together ++ extrap_poses = np.repeat(np.eye(4)[np.newaxis], dt.shape[0], axis=0) ++ extrap_poses[:,:3,:3] = extrap_R ++ extrap_poses[:,:3,3] = extrap_trans ++ first_n_extrap_fw = extrap_poses ++ ++ # extrap bw are those before first valid step (appear) ++ extrap_bw_steps = first_n_steps < first_val_idx ++ first_n_extrap_bw = None ++ if np.sum(extrap_bw_steps) > 0: ++ if first_val_idx < (cur_valid.shape[0]-1) and cur_valid[first_val_idx+1]: ++ # need to compute velocity to extrapolate ++ dt = first_n_timestamps[extrap_bw_steps] - all_ts[first_val_idx] # note will be < 0 ++ dt = dt[:,np.newaxis] ++ first_pos = np.repeat(cur_trans[first_val_idx][np.newaxis], dt.shape[0], axis=0) ++ # translation ++ first_lin_vel = (cur_trans[first_val_idx+1] - cur_trans[first_val_idx]) / (all_ts[first_val_idx+1] - all_ts[first_val_idx]) ++ first_lin_vel = np.repeat(first_lin_vel[np.newaxis], dt.shape[0], 
axis=0) ++ extrap_trans = first_pos + first_lin_vel*dt ++ # copy rotation ++ extrap_R = np.repeat(cur_R[first_val_idx:first_val_idx+1], dt.shape[0], axis=0) ++ # # extrapolate rotation ++ # first_delta_R = np.dot(cur_R[first_val_idx+1], cur_R[first_val_idx].T) ++ # first_delta_rotvec = R.from_matrix(first_delta_R).as_rotvec() ++ # first_delta_angle = np.linalg.norm(first_delta_rotvec) ++ # first_delta_axis = first_delta_rotvec / (first_delta_angle+1e-9) ++ # first_ang_vel = first_delta_angle / (all_ts[first_val_idx+1] - all_ts[first_val_idx]) ++ # first_ang_vel = np.repeat(first_ang_vel[np.newaxis,np.newaxis], dt.shape[0], axis=0) ++ # extrap_angle = first_ang_vel*dt ++ # first_delta_axis = np.repeat(first_delta_axis[np.newaxis], dt.shape[0], axis=0) ++ # extrap_rotvec = extrap_angle*first_delta_axis ++ # extrap_delta_R = R.from_rotvec(extrap_rotvec).as_matrix() ++ # extrap_R = np.matmul(extrap_delta_R, cur_R[first_val_idx:first_val_idx+1]) ++ # put together ++ extrap_poses = np.repeat(np.eye(4)[np.newaxis], dt.shape[0], axis=0) ++ extrap_poses[:,:3,:3] = extrap_R ++ extrap_poses[:,:3,3] = extrap_trans ++ first_n_extrap_bw = extrap_poses ++ ++ first_n_poses = cur_poses[:fill_first_n] ++ if first_n_interp is not None: ++ first_n_poses[interp_steps] = first_n_interp ++ if first_n_extrap_fw is not None: ++ first_n_poses[extrap_fw_steps] = first_n_extrap_fw ++ if first_n_extrap_bw is not None: ++ first_n_poses[extrap_bw_steps] = first_n_extrap_bw ++ ++ first_n_valid = np.sum(np.isnan(first_n_poses.reshape((first_n_poses.shape[0], 16))), axis=1) == 0 ++ ++ traj_poses[ai, :fill_first_n] = first_n_poses ++ traj_valid[ai, :fill_first_n] = first_n_valid ++ ++ # if traj_ids[ai] == 2588: ++ # print(extrap_fw_steps) ++ # print(first_n_extrap_fw) ++ # exit() ++ ++ # based on fill-in can tell if there are duplicated (mis-associated tracks) if they collide ++ if mine_dups: ++ print('Mining possible duplicates...') ++ first_n_xy = traj_poses[:, :fill_first_n, :2, 3] ++ first_n_hvec = traj_poses[:, :fill_first_n, :2, 0] ++ first_n_hvec = first_n_hvec / np.linalg.norm(first_n_hvec, axis=-1, keepdims=True) ++ for ai in range(traj_poses.shape[0]): ++ ai_mask = np.zeros((traj_poses.shape[0]), dtype=bool) ++ ai_mask[ai] = True ++ cur_id = traj_ids[ai] ++ other_ids = traj_ids[~ai_mask] ++ ++ traj_tgt = np.concatenate([first_n_xy[ai], first_n_hvec[ai]], axis=1) ++ lw_tgt = traj_lwh[ai, :2] ++ traj_others = np.concatenate([first_n_xy[~ai_mask], first_n_hvec[~ai_mask]], axis=2) ++ lw_others = traj_lwh[~ai_mask, :2] ++ ++ # if they collide more than 75% of the time ++ veh_coll = check_single_veh_coll(traj_tgt, lw_tgt, traj_others, lw_others) ++ dup_mask = np.sum(veh_coll, axis=1) > int(0.75*veh_coll.shape[1]) ++ ++ if np.sum(dup_mask) > 0: ++ dup_ids = sorted([otid for otid in other_ids[dup_mask].tolist() if otid > cur_id]) ++ dup_ids = [str(otid) for otid in dup_ids] ++ if len(dup_ids) > 0: ++ dup_str = ','.join(dup_ids) ++ dup_str = str(cur_id) + ' : [' + dup_str + '],' ++ print(dup_str) ++ ++ # add ego at index 0 ++ ego_poses = ego_poses[np.newaxis] ++ traj_poses = np.concatenate([ego_poses, traj_poses], axis=0) ++ traj_valid = np.concatenate([np.ones((1, ego_poses.shape[1]), dtype=bool), traj_valid], axis=0) ++ traj_lwh = np.concatenate([np.array([NV_EGO_LWH]), traj_lwh], axis=0) ++ traj_ids = np.array(['ego'] + traj_ids.tolist()) ++ ++ if map_tile is not None: ++ traj_poses[:, :, :2, -1] += map_tile.trans_offset ++ ++ return traj_poses, traj_valid, traj_lwh, ego_pose_timestamps*1e-6, traj_ids ++ ++ ++if 
__name__ == '__main__': ++ tile_path = './data/nvidia/nvmaps/92d651e5-21d2-4816-b16d-0feace622aa1/jsv3/92d651e5-21d2-4816-b16d-0feace622aa1/tile/4bd02829-cab6-435d-8ebd-679c96787f8b_json' ++ tile = load_tile(tile_path, layers=['lane_dividers_v1', 'lane_channels_v1']) ++ # convert to nuscenes-like map format if desired ++ nusc_tile = convert_tile_to_nuscenes(tile) ++ ++ # dynamic_obj_path = './data/nvidia/dynamic_object_poses.npz' ++ # dyn_obj_data = np.load(dynamic_obj_path) ++ # dyn_obj_poses = dyn_obj_data['poses'] ++ # dyn_obj_timestamps = dyn_obj_data['pose_timestamps'] ++ ++ # print(dyn_obj_poses.shape) ++ ++ # # debug_viz_vid(tile.segments, 'dyn_obj', ++ # # comp_out_path='./out/dev_gtc_demo/dev_preprocess', ++ # # poses=dyn_obj_poses[np.newaxis], ++ # # poses_valid=np.ones((1, dyn_obj_poses.shape[0])), ++ # # poses_lwh=np.array([[5.087, 2.307, 1.856]]), ++ # # # pose_ids=traj_ids, ++ # # subsamp=1, ++ # # fps=30) ++ ++ # # TODO option to just load specific ids? ++ ++ # # rivermark ++ # autolabels_path = './data/nvidia/endeavor/labels/autolabels.pkl' ++ # ego_images_path = './data/nvidia/ego_session/processed/44093/images/image_00' ++ # lidar_rig_path = './data/nvidia/endeavor/poses/T_lidar_rig.npz' ++ # frame_range = (510, 880) ++ # traj_poses, traj_valid, traj_lwh, traj_t, traj_ids = load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ # frame_range=frame_range, ++ # postprocess=True, ++ # # extra_obj_path=dynamic_obj_path, ++ # # load_ids=['ego', 'extra'], ++ # crop2valid=False, ++ # fill_first_n=160) ++ ++ # print(traj_poses.shape) ++ # debug_viz_vid(tile.segments, 'rivermark_fill', ++ # comp_out_path='./out/dev_gtc_demo/dev_preprocess', ++ # poses=traj_poses, ++ # poses_valid=traj_valid, ++ # poses_lwh=traj_lwh, ++ # pose_ids=traj_ids, ++ # subsamp=1, ++ # fps=30) ++ ++ # exit() ++ ++ ++ autolabels_path = './data/nvidia/endeavor/labels/autolabels.pkl' ++ ego_images_path = './data/nvidia/endeavor/images/image_00' ++ lidar_rig_path = './data/nvidia/endeavor/poses/T_lidar_rig.npz' ++ ++ # this just filters which frames to process, can be None ++ # frame_range = (3170, 3590) # (3200, 3545) # (0, 900), (1000, 2400), None ++ # frame_range = (60, 3590) ++ frame_range = (3000, 3590) ++ # frame_range = (2000, 2400) ++ # frame_range = (570, 870) # merge ++ # frame_range = (240, 660) # merge ++ # frame_range = (1440, 1860) # merge ++ ++ # process on the original nvmap ++ traj_poses, traj_valid, traj_lwh, traj_t, traj_ids = load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ frame_range=frame_range, ++ postprocess=True, ++ mine_dups=False, ++ fill_first_n=300) ++ print(traj_poses.shape) ++ debug_viz_vid(tile.segments, 'extrap_proc_ids_sframe_00003000_eframe_00003300', ++ comp_out_path='./out/dev_gtc_demo/dev_preprocess', ++ poses=traj_poses, ++ poses_valid=traj_valid, ++ poses_lwh=traj_lwh, ++ pose_ids=traj_ids, ++ subsamp=1, ++ fps=30) ++ ++ # # process with the nuscenes map version ++ # # (the only difference is the coordinate system is offset such that the origin is at bottom left) ++ # traj_poses, traj_valid, traj_lwh, traj_t, traj_ids = load_trajectories(autolabels_path, ego_images_path, lidar_rig_path, ++ # frame_range=frame_range, ++ # map_tile=nusc_tile, ++ # postprocess=True) ++ # debug_viz_vid(nusc_tile.segments, 'processed_sframe_00003200_eframe_00003545_nusc', ++ # comp_out_path='./out/gtc_demo/dev_preprocess', ++ # poses=traj_poses, ++ # poses_valid=traj_valid, ++ # poses_lwh=traj_lwh, ++ # pose_ids=traj_ids, ++ # subsamp=3, ++ # fps=10) ++ 
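For the mine_dups option of load_trajectories above, a rough shapely-based stand-in for the check_single_veh_coll test: two tracks are flagged as likely duplicates when their oriented boxes overlap in most frames. The 75% threshold mirrors the code above; the helper functions, shapely usage, and toy tracks are hypothetical.

import numpy as np
from shapely.geometry import Polygon
from shapely.affinity import rotate, translate

def oriented_box(x, y, yaw, length, width):
    box = Polygon([(-length / 2, -width / 2), (length / 2, -width / 2),
                   (length / 2, width / 2), (-length / 2, width / 2)])
    return translate(rotate(box, yaw, use_radians=True), x, y)

def mostly_colliding(xyh_a, lw_a, xyh_b, lw_b, frac=0.75):
    hits = [oriented_box(*pa, *lw_a).intersects(oriented_box(*pb, *lw_b))
            for pa, pb in zip(xyh_a, xyh_b)]
    return sum(hits) > int(frac * len(hits))

t = np.linspace(0.0, 1.0, 10)
track_a = np.stack([5.0 * t, np.zeros_like(t), np.zeros_like(t)], axis=1)  # (T, 3): x, y, yaw
track_b = track_a + np.array([0.3, 0.1, 0.0])                              # near-identical track
print(mostly_colliding(track_a, (4.5, 2.0), track_b, (4.5, 2.0)))          # True -> likely duplicate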
++ # ++ # output single pose for Amlan ++ # ++ ++ # single_frame_idx = 3500 ++ # last_pose = traj_poses[:,single_frame_idx:single_frame_idx+1] ++ # last_valid = traj_valid[:,single_frame_idx] ++ # last_step_valid_poses = last_pose[last_valid] ++ # print(last_step_valid_poses.shape) ++ # last_step_lwh = traj_lwh[last_valid] ++ ++ # debug_viz_segs(nusc_tile.segments, 'endeavor_nusc_step%d' % (single_frame_idx), poses=last_step_valid_poses, poses_valid=traj_valid[:,single_frame_idx:single_frame_idx+1][last_valid], poses_lwh=last_step_lwh) ++ ++ # np.savez('./out/dev_nvmap/endeavor_poses_frame%06d.npz' % (single_frame_idx), ++ # poses=last_step_valid_poses, ++ # lwh=last_step_lwh) ++ # exit() ++ ++ # ++ # Output GPS trajectories for Jeremy ++ # ++ ++ # # convert ego poses back to global ECEF coordinate system ++ # N, T, _, _ = traj_poses.shape ++ # ecef_traj_poses = np.matmul(GLOBAL_BASE_POSE[np.newaxis,np.newaxis], traj_poses) ++ # print(ecef_traj_poses.shape) ++ # # convert to GPS ++ # gps_traj_poses = ecef_2_lat_lng_alt(ecef_traj_poses.reshape((N*T, 4, 4)), earth_model='WGS84') ++ # lat_lng_alt, orientation_axis, orientation_angle = gps_traj_poses ++ ++ # save_path = os.path.join('./out/dev_nvmap/endeavor_trajectory_track_ids.npz') ++ # out_dict = { ++ # 'timestamps' : traj_t, ++ # 'track_ids' : traj_ids, ++ # 'ecef_poses' : ecef_traj_poses, ++ # 'pose_valid' : traj_valid, ++ # 'gps_lat_lng_alt' : lat_lng_alt.reshape((N, T, 3)), ++ # 'gps_orientation_axis' : orientation_axis.reshape((N, T, 3)), ++ # 'gps_orientation_angle_degrees' : orientation_angle.reshape((N, T, 1)) ++ # } ++ # for k, v in out_dict.items(): ++ # print(k) ++ # print(v.shape) ++ # # if k != 'timestamps': ++ # # print(np.sum(np.isnan(v), axis=1)) ++ # np.savez(save_path, **out_dict) ++ ++ # exit() ++ ++ # debug_viz_vid(tile.segments, 'gt_sframe_00003200_eframe_00003545', ++ # comp_out_path='./out/gtc_demo/dev_preprocess', ++ # poses=traj_poses, ++ # poses_valid=traj_valid, ++ # poses_lwh=traj_lwh, ++ # subsamp=3, ++ # fps=10) +diff --git a/src/trajdata/dataset_specific/drivesim/test.py b/src/trajdata/dataset_specific/drivesim/test.py +new file mode 100644 +index 0000000..ee6b58e +--- /dev/null ++++ b/src/trajdata/dataset_specific/drivesim/test.py +@@ -0,0 +1,79 @@ ++import numpy as np ++from typing import Dict ++import trajdata.dataset_specific.drivesim.nvmap_utils as nvutils ++ ++# This is from the tag from the .xodr file for Endeavor. ++LATLONALT_ORIGIN_ENDEAVOR = np.array([[37.37852062996696, -121.9596180846297, 0.0]]) ++ ++ ++def convert_to_DS(poses: np.ndarray, track_ids:list,fps:float): ++ """_summary_ ++ ++ Args: ++ poses (np.ndarray): x, y, h states of each agent at each time, relative to the ego vehicle's state at time t=0. (N, T, 3) ++ """ ++ # wgs84 = coutils.LocalToWGS84((0, 0, 0), (37.37852062996696, -121.9596180846297, 0.0)) ++ # wgs84_2 = coutils.ECEFtoWGS84(nvutils.GLOBAL_BASE_POSE_ENDEAVOR[:3, -1]) ++ ++ world_from_map_ft = nvutils.lat_lng_alt_2_ecef( ++ LATLONALT_ORIGIN_ENDEAVOR, ++ np.array([[1, 0, 0]]), ++ np.array([[0]]) ++ ) ++ ++ N, T = poses.shape[:2] ++ x = poses[..., 0] ++ y = poses[..., 1] ++ heading = poses[..., 2] ++ ++ c = np.cos(heading) ++ s = np.sin(heading) ++ T_mat = np.tile(np.eye(4), (N, T, 1, 1)) ++ T_mat[..., 0, 0] = c ++ T_mat[..., 0, 1] = -s ++ T_mat[..., 1, 0] = s ++ T_mat[..., 1, 1] = c ++ T_mat[..., 0, 3] = x ++ T_mat[..., 1, 3] = y ++ # TODO: Some height for ray-casting down to road? 
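The convert_to_DS helper above goes map-local to ECEF to lat/lon/alt via the repo's nvutils helpers. As a rough cross-check of the first hop only, the same WGS84 geodetic-to-ECEF conversion can be sketched with pyproj (pyproj and the EPSG codes are an assumption here, not something the repo uses):

from pyproj import Transformer

# EPSG:4979 = WGS84 geographic 3D (lon, lat, alt); EPSG:4978 = WGS84 geocentric (ECEF).
to_ecef = Transformer.from_crs("EPSG:4979", "EPSG:4978", always_xy=True)
lat, lon, alt = 37.37852062996696, -121.9596180846297, 0.0   # Endeavor map origin from above
x, y, z = to_ecef.transform(lon, lat, alt)
print(x, y, z)   # ECEF coordinates of the origin, in meters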
++ # T_mat[..., 2, 3] = 0 ++ ++ ecef_traj_poses = np.matmul(world_from_map_ft[:, np.newaxis], T_mat) ++ gps_traj_poses = nvutils.ecef_2_lat_lng_alt(ecef_traj_poses.reshape((N*T, 4, 4)), earth_model='WGS84') ++ lat_lng_alt, orientation_axis, orientation_angle = gps_traj_poses ++ breakpoint() ++ out_dict = { ++ 'timestamps' : np.linspace(0, T/fps, T), ++ 'track_ids' : np.array(track_ids), ++ # 'bbox_lwh' : np.array([[4.387, 1.907, 1.656], [4.387, 1.907, 1.656]]), ++ # 'ecef_poses' : ecef_traj_poses, ++ 'pose_valid' : np.ones((ecef_traj_poses.shape[0], ecef_traj_poses.shape[1]), dtype=bool), ++ 'gps_lat_lng_alt' : lat_lng_alt.reshape((N, T, 3)), ++ 'gps_orientation_axis' : orientation_axis.reshape((N, T, 3)), ++ 'gps_orientation_angle_degrees' : orientation_angle.reshape((N, T, 1)) ++ } ++ ++ return out_dict ++ ++ ++def main(): ++ poses = np.zeros((2, 315, 3)) ++ poses[0, :, 0] = np.linspace(-505, -519, 315) ++ poses[0, :, 1] = np.linspace(-1019, -866, 315) ++ poses[0, :, 2] = np.linspace(np.pi/2, 5*np.pi/8, 315) ++ ++ fps = 30 ++ out_dict = convert_to_DS(poses,["ego","769"],fps) ++ ++ # with np.load("/home/bivanovic/projects/drivesim-ov/source/extensions/omni.drivesim.dl_traffic_model/data/example_trajectories.npz") as data: ++ # out_dict["gps_lat_lng_alt"][1] = data["gps_lat_lng_alt"][1] ++ # out_dict["gps_orientation_axis"][1] = data["gps_orientation_axis"][1] ++ # out_dict["gps_orientation_angle_degrees"][1] = data["gps_orientation_angle_degrees"][1] ++ ++ np.savez( ++ "/home/bivanovic/projects/drivesim-ov/source/extensions/omni.drivesim.dl_traffic_model/data/test.npz", ++ **out_dict ++ ) ++ ++if __name__ == "__main__": ++ main() +\ No newline at end of file +diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py +index 1c4df3f..08f9977 100644 +--- a/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py ++++ b/src/trajdata/dataset_specific/nuplan/nuplan_dataset.py +@@ -360,38 +360,6 @@ class NuplanDataset(RawDataset): + + return agent_list, agent_presence + +- def cache_map( +- self, +- map_name: str, +- cache_path: Path, +- map_cache_class: Type[SceneCache], +- map_params: Dict[str, Any], +- ) -> None: +- nuplan_map: NuPlanMap = map_factory.get_maps_api( +- map_root=str(self.metadata.data_dir.parent / "maps"), +- map_version=nuplan_utils.NUPLAN_MAP_VERSION, +- map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name], +- ) +- +- # Loading all layer geometries. +- nuplan_map.initialize_all_layers() +- +- # This df has the normal lane_connectors with additional boundary information, +- # which we want to use, however the default index is not the lane_connector_fid, +- # although it is a 1:1 mapping so we instead create another index with the +- # lane_connector_fids as the key and the resulting integer indices as the value. 
+- lane_connector_fids: pd.Series = nuplan_map._vector_map[ +- "gen_lane_connectors_scaled_width_polygons" +- ]["lane_connector_fid"] +- lane_connector_idxs: pd.Series = pd.Series( +- index=lane_connector_fids, data=range(len(lane_connector_fids)) +- ) +- +- vector_map = VectorMap(map_id=f"{self.name}:{map_name}") +- nuplan_utils.populate_vector_map(vector_map, nuplan_map, lane_connector_idxs) +- +- map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) +- + def cache_maps( + self, + cache_path: Path, +@@ -406,4 +374,45 @@ class NuplanDataset(RawDataset): + desc=f"Caching {self.name} Maps at {map_params['px_per_m']:.2f} px/m", + position=0, + ): +- self.cache_map(map_name, cache_path, map_cache_class, map_params) ++ cache_map( ++ map_root=str(self.metadata.data_dir.parent / "maps"), ++ env_name=self.name, ++ map_name=map_name, ++ cache_path=cache_path, ++ map_cache_class=map_cache_class, ++ map_params=map_params, ++ ) ++ ++ ++def cache_map( ++ map_root: str, ++ env_name: str, ++ map_name: str, ++ cache_path: Path, ++ map_cache_class: Type[SceneCache], ++ map_params: Dict[str, Any], ++) -> None: ++ nuplan_map: NuPlanMap = map_factory.get_maps_api( ++ map_root=map_root, ++ map_version=nuplan_utils.NUPLAN_MAP_VERSION, ++ map_name=nuplan_utils.NUPLAN_FULL_MAP_NAME_DICT[map_name], ++ ) ++ ++ # Loading all layer geometries. ++ nuplan_map.initialize_all_layers() ++ ++ # This df has the normal lane_connectors with additional boundary information, ++ # which we want to use, however the default index is not the lane_connector_fid, ++ # although it is a 1:1 mapping so we instead create another index with the ++ # lane_connector_fids as the key and the resulting integer indices as the value. ++ lane_connector_fids: pd.Series = nuplan_map._vector_map[ ++ "gen_lane_connectors_scaled_width_polygons" ++ ]["lane_connector_fid"] ++ lane_connector_idxs: pd.Series = pd.Series( ++ index=lane_connector_fids, data=range(len(lane_connector_fids)) ++ ) ++ ++ vector_map = VectorMap(map_id=f"{env_name}:{map_name}") ++ nuplan_utils.populate_vector_map(vector_map, nuplan_map, lane_connector_idxs) ++ ++ map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params) +\ No newline at end of file +diff --git a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py +index 5007ace..dea194c 100644 +--- a/src/trajdata/dataset_specific/nuplan/nuplan_utils.py ++++ b/src/trajdata/dataset_specific/nuplan/nuplan_utils.py +@@ -189,6 +189,7 @@ class NuPlanObject: + + + def nuplan_type_to_unified_type(nuplan_type: str) -> AgentType: ++ # TODO (pkarkus) map traffic cones, barriers to static; generic_object to pedestrian + if nuplan_type == "pedestrian": + return AgentType.PEDESTRIAN + elif nuplan_type == "bicycle": +@@ -327,12 +328,20 @@ def populate_vector_map( + # The right boundary of Lane A has Lane A to its left. + boundary_connectivity_dict[right_boundary_id]["left"].append(lane_id) + ++ # Find road areas that this lane intersects for faster lane-based lookup later. ++ intersect_filt = nuplan_map._vector_map["drivable_area"].intersects(lane_info["geometry"]) ++ isnear_filt = (nuplan_map._vector_map["drivable_area"].distance(lane_info["geometry"]) < 3.) ++ road_area_ids = set(nuplan_map._vector_map["drivable_area"][intersect_filt | isnear_filt]["fid"].values) ++ if not road_area_ids: ++ print (f"Warning: no road lane associated with lane {lane_id}") ++ + # "partial" because we aren't adding lane connectivity until later. 
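The new drivable-area lookup above relies on GeoPandas geometric predicates over nuplan_map._vector_map["drivable_area"]. A self-contained sketch of the same intersects-or-within-3-meters filter on a toy GeoDataFrame (the data below is made up; only the predicate pattern matches the code above):

import geopandas as gpd
from shapely.geometry import LineString, Polygon

drivable = gpd.GeoDataFrame(
    {"fid": ["a1", "a2"]},
    geometry=[Polygon([(0, 0), (10, 0), (10, 5), (0, 5)]),
              Polygon([(100, 0), (110, 0), (110, 5), (100, 5)])],
)
lane = LineString([(1, 1), (9, 1)])

near = drivable[drivable.intersects(lane) | (drivable.distance(lane) < 3.0)]
road_area_ids = set(near["fid"].values)
print(road_area_ids)   # {'a1'}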
+ partial_new_lane = RoadLane( + id=lane_id, + center=Polyline(center_pts), + left_edge=Polyline(left_pts), + right_edge=Polyline(right_pts), ++ road_area_ids=road_area_ids, + ) + vector_map.add_map_element(partial_new_lane) + overall_pbar.update() +diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py +index 68bd5b9..0c6f29b 100644 +--- a/src/trajdata/dataset_specific/scene_records.py ++++ b/src/trajdata/dataset_specific/scene_records.py +@@ -28,6 +28,12 @@ class NuscSceneRecord(NamedTuple): + desc: str + data_idx: int + ++class CarlaSceneRecord(NamedTuple): ++ name: str ++ location: str ++ length: str ++ data_idx: int ++ + + class LyftSceneRecord(NamedTuple): + name: str +@@ -48,3 +54,11 @@ class NuPlanSceneRecord(NamedTuple): + split: str + # desc: str + data_idx: int ++ ++class DrivesimSceneRecord(NamedTuple): ++ name: str ++ location: str ++ length: str ++ split: str ++ # desc: str ++ data_idx: int +diff --git a/src/trajdata/maps/map_api.py b/src/trajdata/maps/map_api.py +index 36a1245..e6e4a6d 100644 +--- a/src/trajdata/maps/map_api.py ++++ b/src/trajdata/maps/map_api.py +@@ -1,6 +1,6 @@ + from __future__ import annotations + +-from typing import TYPE_CHECKING, Optional ++from typing import TYPE_CHECKING, Optional, Union + + if TYPE_CHECKING: + from trajdata.maps.map_kdtree import MapElementKDTree +@@ -15,9 +15,20 @@ from trajdata.utils import map_utils + + + class MapAPI: +- def __init__(self, unified_cache_path: Path) -> None: ++ def __init__(self, unified_cache_path: Path, keep_in_memory: bool = False) -> None: ++ """A simple interface for loading trajdata's vector maps which does not require ++ instantiation of a `UnifiedDataset` object. ++ ++ Args: ++ unified_cache_path (Path): Path to trajdata's local cache on disk. ++ keep_in_memory (bool): Whether loaded maps should be stored ++ in memory (memoized) for later re-use. For most cases (e.g., batched dataloading), ++ this is a good idea. However, this can cause rapid memory usage growth for some ++ datasets (e.g., Waymo) and it can be better to disable this. Defaults to False. ++ """ + self.unified_cache_path: Path = unified_cache_path + self.maps: Dict[str, VectorMap] = dict() ++ self._keep_in_memory = keep_in_memory + + def get_map( + self, map_id: str, scene_cache: Optional[SceneCache] = None, **kwargs +@@ -25,22 +36,43 @@ class MapAPI: + if map_id not in self.maps: + env_name, map_name = map_id.split(":") + env_maps_path: Path = self.unified_cache_path / env_name / "maps" +- stored_vec_map: VectorizedMap = map_utils.load_vector_map( +- env_maps_path / f"{map_name}.pb" +- ) ++ vec_map_path: Path = env_maps_path / f"{map_name}.pb" ++ ++ if not Path.exists(vec_map_path): ++ if self.data_dirs is None: ++ raise ValueError( ++ f"There is no cached map at {vec_map_path} and there was no " + ++ "`data_dirs` provided to rebuild cache.") ++ ++ # Rebuild maps by creating a dummy dataset object. ++ # TODO(pkarkus) We need support for rebuilding map files only, without creating dataset and building agent data. ++ from trajdata.dataset import UnifiedDataset ++ dataset = UnifiedDataset( ++ desired_data=[env_name], ++ rebuild_maps=True, ++ data_dirs=self.data_dirs, ++ cache_location=self.unified_cache_path, ++ verbose=True, ++ ) ++ # Hopefully we successfully created map cache. 
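Hypothetical usage of the MapAPI flag documented above; the cache path and map id are placeholders. With keep_in_memory=False, each call re-reads the protobuf from disk instead of memoizing it, trading speed for bounded memory on large map collections such as Waymo:

from pathlib import Path
from trajdata.maps.map_api import MapAPI

map_api = MapAPI(Path("~/.unified_data_cache").expanduser(), keep_in_memory=False)
vec_map = map_api.get_map("nusc_mini:boston-seaport", incl_road_areas=True)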
++ ++ stored_vec_map: VectorizedMap = map_utils.load_vector_map(vec_map_path) + + vec_map: VectorMap = VectorMap.from_proto(stored_vec_map, **kwargs) + vec_map.search_kdtrees: Dict[ + str, MapElementKDTree + ] = map_utils.load_kdtrees(env_maps_path / f"{map_name}_kdtrees.dill") + +- self.maps[map_id] = vec_map ++ if self._keep_in_memory: ++ self.maps[map_id] = vec_map ++ else: ++ vec_map = self.maps[map_id] + + if scene_cache is not None: +- self.maps[map_id].associate_scene_data( ++ vec_map.associate_scene_data( + scene_cache.get_traffic_light_status_dict( + kwargs.get("desired_dt", None) + ) + ) + +- return self.maps[map_id] ++ return vec_map +diff --git a/src/trajdata/maps/map_kdtree.py b/src/trajdata/maps/map_kdtree.py +index 59abdc2..a978a90 100644 +--- a/src/trajdata/maps/map_kdtree.py ++++ b/src/trajdata/maps/map_kdtree.py +@@ -1,7 +1,7 @@ + from __future__ import annotations + + from collections import defaultdict +-from typing import TYPE_CHECKING ++from typing import TYPE_CHECKING, Iterator + + if TYPE_CHECKING: + from trajdata.maps.vec_map import VectorMap +@@ -43,8 +43,7 @@ class MapElementKDTree: + total=len(vector_map), + disable=not verbose, + ): +- result = self._extract_points(map_elem) +- if result is not None: ++ for result in self._extract_points_and_metadata(map_elem): + points, extras = result + polyline_inds.extend([len(polylines)] * points.shape[0]) + +@@ -54,6 +53,7 @@ class MapElementKDTree: + + for k, v in extras.items(): + metadata[k].append(v) ++ metadata["map_elem_id"].append(np.array([map_elem.id])) + + points = np.concatenate(polylines, axis=0) + polyline_inds = np.array(polyline_inds) +@@ -64,7 +64,7 @@ class MapElementKDTree: + + def _extract_points_and_metadata( + self, map_element: MapElement +- ) -> Optional[Tuple[np.ndarray, dict[str, np.ndarray]]]: ++ ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: + """Defines the coordinates we want to store in the KDTree for a MapElement. + Args: + map_element (MapElement): the MapElement to store in the KDTree. +@@ -116,16 +116,16 @@ class LaneCenterKDTree(MapElementKDTree): + self.max_segment_len = max_segment_len + super().__init__(vector_map) + +- def _extract_points(self, map_element: MapElement) -> Optional[np.ndarray]: ++ def _extract_points_and_metadata( ++ self, map_element: MapElement ++ ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: + if map_element.elem_type == MapElementType.ROAD_LANE: + pts: Polyline = map_element.center + if self.max_segment_len is not None: + pts = pts.interpolate(max_dist=self.max_segment_len) + + # We only want to store xyz in the kdtree, not heading. +- return pts.xyz, {"heading": pts.h} +- else: +- return None ++ yield pts.xyz, {"heading": pts.h} + + def current_lane_inds( + self, +@@ -181,3 +181,42 @@ class LaneCenterKDTree(MapElementKDTree): + min_costs = [np.min(costs[lane_inds == ind]) for ind in unique_lane_inds] + + return unique_lane_inds[np.argsort(min_costs)] ++ ++ ++class RoadAreaKDTree(MapElementKDTree): ++ """KDTree for road area polygons. ++ The polygons may have holes. We will simply store points along both the ++ exterior_polygon and all interior_holes. 
Finding a nearest point in this KDTree will ++ correspond to finding any ++ """ ++ ++ def __init__( ++ self, vector_map: VectorMap, max_segment_len: Optional[float] = None ++ ) -> None: ++ """ ++ Args: ++ vec_map: the VectorizedMap object to build the KDTree for ++ max_segment_len (float, optional): if specified, we will insert extra points into the KDTree ++ such that all polyline segments are shorter then max_segment_len. ++ """ ++ self.max_segment_len = max_segment_len ++ super().__init__(vector_map) ++ ++ def _extract_points_and_metadata( ++ self, map_element: MapElement ++ ) -> Iterator[Tuple[np.ndarray, dict[str, np.ndarray]]]: ++ if map_element.elem_type == MapElementType.ROAD_AREA: ++ # Exterior polygon ++ pts: Polyline = map_element.exterior_polygon ++ if self.max_segment_len is not None: ++ pts = pts.interpolate(max_dist=self.max_segment_len) ++ # We only want to store xyz in the kdtree, not heading. ++ yield pts.xyz, {"exterior": np.array([True])} ++ ++ # Interior holes ++ for pts in map_element.interior_holes: ++ if self.max_segment_len is not None: ++ pts = pts.interpolate(max_dist=self.max_segment_len) ++ # We only want to store xyz in the kdtree, not heading. ++ yield pts.xyz, {"exterior": np.array([False])} ++ +diff --git a/src/trajdata/maps/vec_map.py b/src/trajdata/maps/vec_map.py +index 95b299d..52a225f 100644 +--- a/src/trajdata/maps/vec_map.py ++++ b/src/trajdata/maps/vec_map.py +@@ -25,9 +25,10 @@ import matplotlib.pyplot as plt + import numpy as np + from matplotlib.axes import Axes + from tqdm import tqdm ++from shapely.geometry import Polygon + + import trajdata.proto.vectorized_map_pb2 as map_proto +-from trajdata.maps.map_kdtree import LaneCenterKDTree ++from trajdata.maps.map_kdtree import LaneCenterKDTree, RoadAreaKDTree + from trajdata.maps.traffic_light_status import TrafficLightStatus + from trajdata.maps.vec_map_elements import ( + MapElement, +@@ -52,6 +53,7 @@ class VectorMap: + ) + search_kdtrees: Optional[Dict[MapElementType, MapElementKDTree]] = None + traffic_light_status: Optional[Dict[Tuple[int, int], TrafficLightStatus]] = None ++ online_metadict: Optional[Dict[Tuple[str, int], Dict]] = None + + def __post_init__(self) -> None: + self.env_name, self.map_name = self.map_id.split(":") +@@ -60,11 +62,16 @@ class VectorMap: + if MapElementType.ROAD_LANE in self.elements: + self.lanes = list(self.elements[MapElementType.ROAD_LANE].values()) + ++ # self._road_area_polygons: Dict[str, Polygon] = {} ++ + def add_map_element(self, map_elem: MapElement) -> None: + self.elements[map_elem.elem_type][map_elem.id] = map_elem + + def compute_search_indices(self) -> None: +- self.search_kdtrees = {MapElementType.ROAD_LANE: LaneCenterKDTree(self)} ++ self.search_kdtrees = { ++ MapElementType.ROAD_LANE: LaneCenterKDTree(self), ++ # MapElementType.ROAD_AREA: RoadAreaKDTree(self), ++ } + + def iter_elems(self) -> Iterator[MapElement]: + for elems_dict in self.elements.values(): +@@ -101,6 +108,9 @@ class VectorMap: + new_lane.adjacent_lanes_right.extend( + [lane_id.encode() for lane_id in road_lane.adj_lanes_right] + ) ++ # new_lane.road_area_ids.extend( ++ # [road_area_id.encode() for road_area_id in road_lane.road_area_ids] ++ # ) + + def _write_road_areas( + self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray +@@ -250,6 +260,9 @@ class VectorMap: + prev_lanes: Set[str] = set( + [iden.decode() for iden in road_lane_obj.entry_lanes] + ) ++ # road_area_ids: Set[str] = set( ++ # [iden.decode() for iden in road_lane_obj.road_area_ids] ++ # ) + + # 
Double-using the connectivity attributes for lane IDs now (will + # replace them with Lane objects after all Lane objects have been created). +@@ -262,6 +275,7 @@ class VectorMap: + adj_lanes_right, + next_lanes, + prev_lanes, ++ # road_area_ids=road_area_ids, + ) + map_elem_dict[MapElementType.ROAD_LANE][elem_id] = curr_lane + +@@ -369,6 +383,33 @@ class VectorMap: + self.lanes[idx] for idx in lane_kdtree.polyline_inds_in_range(xyz, dist) + ] + ++ def get_road_areas_within(self, xyz: np.ndarray, dist: float) -> List[RoadArea]: ++ road_area_kdtree: RoadAreaKDTree = self.search_kdtrees[MapElementType.ROAD_AREA] ++ polyline_inds = road_area_kdtree.polyline_inds_in_range(xyz, dist) ++ element_ids = set([ ++ road_area_kdtree.metadata["map_elem_id"][ind] for ind in polyline_inds ++ ]) ++ if MapElementType.ROAD_AREA not in self.elements: ++ raise ValueError( ++ "Road areas are not loaded. Use map_api.get_map(..., incl_road_areas=True)." ++ ) ++ return [ ++ self.elements[MapElementType.ROAD_AREA][id] for id in element_ids ++ ] ++ ++ def get_road_area_polygon_2d(self, id: str) -> Polygon: ++ if id not in self._road_area_polygons: ++ road_area: RoadArea = self.elements[MapElementType.ROAD_AREA][id] ++ road_area_polygon = Polygon( ++ shell=[(pt[0], pt[1]) for pt in road_area.exterior_polygon.points], ++ holes=[ ++ [(pt[0], pt[1]) for pt in polyline.points] ++ for polyline in road_area.interior_holes ++ ] ++ ) ++ self._road_area_polygons[id] = road_area_polygon ++ return self._road_area_polygons[id] ++ + def get_traffic_light_status( + self, lane_id: str, scene_ts: int + ) -> TrafficLightStatus: +@@ -380,6 +421,15 @@ class VectorMap: + else TrafficLightStatus.NO_DATA + ) + ++ def get_online_metadict( ++ self, lane_id: str, scene_ts: int = 0 ++ ) -> Dict: ++ return ( ++ self.online_metadict[(str(lane_id), scene_ts)] ++ if self.online_metadict is not None ++ else {} ++ ) ++ + def rasterize( + self, resolution: float = 2, **kwargs + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: +diff --git a/src/trajdata/maps/vec_map_elements.py b/src/trajdata/maps/vec_map_elements.py +index fd7a9f1..d1e34b6 100644 +--- a/src/trajdata/maps/vec_map_elements.py ++++ b/src/trajdata/maps/vec_map_elements.py +@@ -1,10 +1,11 @@ + from dataclasses import dataclass, field + from enum import IntEnum +-from typing import List, Optional, Set ++from typing import List, Optional, Set, Union + + import numpy as np + + from trajdata.utils import map_utils ++from trajdata.utils.arr_utils import angle_wrap + + + class MapElementType(IntEnum): +@@ -47,6 +48,15 @@ class Polyline: + def xyz(self) -> np.ndarray: + return self.points[..., :3] + ++ @property ++ def xyh(self) -> np.ndarray: ++ if self.has_heading: ++ return self.points[..., (0, 1, 3)] ++ else: ++ raise ValueError( ++ f"This Polyline only has {self.points.shape[-1]} coordinates, expected 4." ++ ) ++ + @property + def xyzh(self) -> np.ndarray: + if self.has_heading: +@@ -67,14 +77,17 @@ class Polyline: + map_utils.interpolate(self.points, num_pts=num_pts, max_dist=max_dist) + ) + +- def project_onto(self, xyz_or_xyzh: np.ndarray) -> np.ndarray: ++ def project_onto(self, xyz_or_xyzh: np.ndarray, return_index: bool = False) -> Union[np.ndarray, List]: + """Project the given points onto this Polyline. + + Args: + xyzh (np.ndarray): Points to project, of shape (M, D) ++ return_indices (bool): Return the index of starting point of the line segment ++ on which the projected points lies on. 
+ + Returns: + np.ndarray: The projected points, of shape (M, D) ++ np.ndarray: The index of previous polyline points if return_indices == True. + + Note: + D = 4 if this Polyline has headings, otherwise D = 3 +@@ -94,7 +107,8 @@ class Polyline: + dot_products: np.ndarray = (point_seg_diffs * line_seg_diffs).sum( + axis=-1, keepdims=True + ) +- norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2 ++ # norms: np.ndarray = np.linalg.norm(line_seg_diffs, axis=-1, keepdims=True) ** 2 ++ norms: np.ndarray = np.square(line_seg_diffs).sum(axis=-1, keepdims=True) + + # Clip ensures that the projected point stays within the line segment boundaries. + projs: np.ndarray = ( +@@ -102,20 +116,114 @@ class Polyline: + ) + + # 2. Find the nearest projections to the original points. +- closest_proj_idxs: int = np.linalg.norm(xyz - projs, axis=-1).argmin(axis=-1) ++ # We have nan values when two consecutive points are equal. This will never be ++ # the closest projection point, so we replace nans with a large number. ++ point_to_proj_dist = np.nan_to_num(np.linalg.norm(xyz - projs, axis=-1), nan=1e6) ++ closest_proj_idxs: int = point_to_proj_dist.argmin(axis=-1) ++ ++ proj_points = projs[range(xyz.shape[0]), closest_proj_idxs] + + if self.has_heading: + # Adding in the heading of the corresponding p0 point (which makes + # sense as p0 to p1 is a line => same heading along it). +- return np.concatenate( ++ proj_points = np.concatenate( + [ +- projs[range(xyz.shape[0]), closest_proj_idxs], ++ proj_points, + np.expand_dims(self.points[closest_proj_idxs, -1], axis=-1), + ], + axis=-1, + ) ++ ++ if return_index: ++ return proj_points, closest_proj_idxs + else: +- return projs[range(xyz.shape[0]), closest_proj_idxs] ++ return proj_points ++ ++ def distance_to_point(self, xyz: np.ndarray): ++ assert xyz.ndim == 2 ++ xyz_proj = self.project_onto(xyz) ++ return np.linalg.norm(xyz[..., :3] - xyz_proj[..., :3], axis=-1) ++ ++ def get_length(self): ++ # TODO(pkarkus) we could store cummulative distances to speed this up ++ dists = np.linalg.norm(self.xyz[1:, :3] - self.xyz[:-1, :3], axis=-1) ++ length = dists.sum() ++ return length ++ ++ ++ def get_length_from(self, start_ind: np.ndarray): ++ # TODO(pkarkus) we could store cummulative distances to speed this up ++ assert start_ind.ndim == 1 ++ dists = np.linalg.norm(self.xyz[1:, :3] - self.xyz[:-1, :3], axis=-1) ++ length_upto = np.cumsum(np.pad(dists, (1, 0))) ++ length_from = length_upto[-1][None] - length_upto[start_ind] ++ return length_from ++ ++ ++ def traverse_along(self, dist: np.ndarray, start_ind: Optional[np.ndarray] = None) -> np.ndarray: ++ """ ++ Interpolated endpoint of traversing `dist` distance along polyline from a starting point. ++ ++ Returns nan if the end point is not inside the polyline. ++ TODO(pkarkus) we could store cummulative distances to speed this up ++ ++ Args: ++ dist (np.ndarray): distances, any shape [...] ++ start_ind (np.ndarray): index of point along polyline to calcualte distance from. ++ Optional. Shape must match dist. [...] ++ ++ Returns: ++ endpoint_xyzh (np.ndarray): points along polyline `dist` distance from the ++ starting point. Nan if endpoint would require extrapolation. 
[..., 4] ++ ++ """ ++ assert self.has_heading ++ ++ # Add up distances from beginning of polyline ++ segment_lens = np.linalg.norm(self.xyz[1:] - self.xyz[:-1], axis=-1) # n-1 ++ cum_len = np.pad(np.cumsum(segment_lens, axis=0), (1, 0)) # n ++ ++ # Increase dist with the length of lane up to start_ind ++ if start_ind is not None: ++ assert start_ind.ndim == dist.ndim ++ dist = dist + cum_len[start_ind] ++ ++ # Find the first index where cummulative length is larger or equal than `dist` ++ inds = np.searchsorted(cum_len, dist, side='right') ++ # Invalidate inds == 0 and inds == len(cum_len), which means endpoint is outside the polyline. ++ invalid = np.logical_or(inds == 0, inds == len(cum_len)) ++ # Replace invalid indices so we can easily carry out computation below, and invalidate output eventually. ++ inds[invalid] = 1 ++ ++ # Remaining distance from last point ++ remaining_dist = dist - cum_len[inds-1] ++ ++ # Invalidate negative remaining dist (this should only happen when dist < 0) ++ invalid = np.logical_or(invalid, remaining_dist < 0.) ++ ++ # Interpolate between the previous and next points. ++ segment_vect_xyz = self.xyz[inds] - self.xyz[inds-1] ++ segment_len = np.linalg.norm(segment_vect_xyz, axis=-1) ++ assert (segment_len > 0.).all(), "Polyline segment has zero length" ++ ++ proportion = (remaining_dist / segment_len) ++ endpoint_xyz = segment_vect_xyz * proportion[..., np.newaxis] + self.xyz[inds] ++ endpoint_h = angle_wrap(angle_wrap(self.h[inds] - self.h[inds-1]) * proportion + self.h[inds-1]) ++ endpoint_xyzh = np.concatenate((endpoint_xyz, endpoint_h[..., np.newaxis]), axis=-1) ++ ++ # Invalidate dummy output ++ endpoint_xyzh[invalid] = np.nan ++ ++ return endpoint_xyzh ++ ++ def concatenate_with(self, other: "Polyline") -> "Polyline": ++ return self.concatenate([self, other]) ++ ++ @staticmethod ++ def concatenate(polylines: List["Polyline"]) -> "Polyline": ++ # Assumes no overlap between consecutive polylines, i.e. next lane starts after current lane ends. 
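A condensed 2D sketch of the traverse_along arithmetic above: cumulative segment lengths, searchsorted to find the containing segment, then linear interpolation inside it. The sketch interpolates from the segment start point (index inds-1, matching the heading branch above); the translation line above adds self.xyz[inds] instead, which may be worth double-checking. Function name and toy polyline are illustrative.

import numpy as np

def point_at_arclength(xy, dist):
    seg_len = np.linalg.norm(np.diff(xy, axis=0), axis=1)          # (n-1,)
    cum_len = np.concatenate([[0.0], np.cumsum(seg_len)])          # (n,)
    if not (0.0 <= dist <= cum_len[-1]):
        return np.full(2, np.nan)                                  # endpoint outside the polyline
    i = np.searchsorted(cum_len, dist, side="right")
    i = min(i, len(cum_len) - 1)                                   # clamp when dist == total length
    frac = (dist - cum_len[i - 1]) / seg_len[i - 1]
    return xy[i - 1] + frac * (xy[i] - xy[i - 1])

pline = np.array([[0.0, 0.0], [3.0, 0.0], [3.0, 4.0]])             # total length 7
print(point_at_arclength(pline, 5.0))                              # [3. 2.]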
++ points = np.concatenate([polyline.points for polyline in polylines], axis=0) ++ return Polyline(points) + + + @dataclass +@@ -132,6 +240,7 @@ class RoadLane(MapElement): + adj_lanes_right: Set[str] = field(default_factory=lambda: set()) + next_lanes: Set[str] = field(default_factory=lambda: set()) + prev_lanes: Set[str] = field(default_factory=lambda: set()) ++ road_area_ids: Set[str] = field(default_factory=lambda: set()) + elem_type: MapElementType = MapElementType.ROAD_LANE + + def __post_init__(self) -> None: +@@ -150,7 +259,35 @@ class RoadLane(MapElement): + @property + def reachable_lanes(self) -> Set[str]: + return self.adj_lanes_left | self.adj_lanes_right | self.next_lanes +- ++ ++ ++ def combine_next(self, next_lane): ++ assert next_lane.id in self.next_lanes ++ self.next_lanes.remove(next_lane.id) ++ self.next_lanes = self.next_lanes.union(next_lane.next_lanes) ++ self.center = self.center.concatenate_with(next_lane.center) ++ if self.left_edge is not None and next_lane.left_edge is not None: ++ self.left_edge = self.left_edge.concatenate_with(next_lane.left_edge) ++ if self.right_edge is not None and next_lane.right_edge is not None: ++ self.right_edge = self.right_edge.concatenate_with(next_lane.right_edge) ++ self.adj_lanes_right = self.adj_lanes_right.union(next_lane.adj_lanes_right) ++ self.adj_lanes_left = self.adj_lanes_left.union(next_lane.adj_lanes_left) ++ self.road_area_ids = self.road_area_ids.union(next_lane.road_area_ids) ++ ++ def combine_prev(self,prev_lane): ++ assert prev_lane.id in self.prev_lanes ++ self.prev_lanes.remove(prev_lane.id) ++ self.prev_lanes = self.prev_lanes.union(prev_lane.prev_lanes) ++ self.center = prev_lane.center.concatenate_with(self.center) ++ if self.left_edge is not None and prev_lane.left_edge is not None: ++ self.left_edge = prev_lane.left_edge.concatenate_with(self.left_edge) ++ if self.right_edge is not None and prev_lane.right_edge is not None: ++ self.right_edge = prev_lane.right_edge.concatenate_with(self.right_edge) ++ self.adj_lanes_right = self.adj_lanes_right.union(prev_lane.adj_lanes_right) ++ self.adj_lanes_left = self.adj_lanes_left.union(prev_lane.adj_lanes_left) ++ self.road_area_ids = self.road_area_ids.union(prev_lane.road_area_ids) ++ ++ + + @dataclass + class RoadArea(MapElement): +diff --git a/src/trajdata/maps/vec_map_remote.py b/src/trajdata/maps/vec_map_remote.py +new file mode 100644 +index 0000000..77f2375 +--- /dev/null ++++ b/src/trajdata/maps/vec_map_remote.py +@@ -0,0 +1,582 @@ ++from __future__ import annotations ++ ++from typing import TYPE_CHECKING ++ ++if TYPE_CHECKING: ++ from trajdata.maps.map_kdtree import MapElementKDTree, LaneCenterKDTree ++ ++from collections import defaultdict ++from dataclasses import dataclass, field ++from math import ceil ++from typing import ( ++ DefaultDict, ++ Dict, ++ Iterator, ++ List, ++ Optional, ++ Set, ++ Tuple, ++ Union, ++ overload, ++) ++ ++import matplotlib as mpl ++import matplotlib.pyplot as plt ++import numpy as np ++from matplotlib.axes import Axes ++from tqdm import tqdm ++ ++import trajdata.proto.vectorized_map_pb2 as map_proto ++from trajdata.maps.map_kdtree import LaneCenterKDTree ++from trajdata.maps.traffic_light_status import TrafficLightStatus ++from trajdata.maps.vec_map_elements import ( ++ MapElement, ++ MapElementType, ++ PedCrosswalk, ++ PedWalkway, ++ Polyline, ++ RoadArea, ++ RoadLane, ++) ++from trajdata.utils import map_utils, raster_utils ++ ++ ++@dataclass(repr=False) ++class VectorMap: ++ map_id: str ++ extent: Optional[ ++ np.ndarray 
++ ] = None # extent is [min_x, min_y, min_z, max_x, max_y, max_z] ++ elements: DefaultDict[MapElementType, Dict[str, MapElement]] = field( ++ default_factory=lambda: defaultdict(dict) ++ ) ++ search_kdtrees: Optional[Dict[MapElementType, MapElementKDTree]] = None ++ traffic_light_status: Optional[Dict[Tuple[int, int], TrafficLightStatus]] = None ++ ++ def __post_init__(self) -> None: ++ self.env_name, self.map_name = self.map_id.split(":") ++ ++ self.lanes: Optional[List[RoadLane]] = None ++ if MapElementType.ROAD_LANE in self.elements: ++ self.lanes = list(self.elements[MapElementType.ROAD_LANE].values()) ++ ++ def add_map_element(self, map_elem: MapElement) -> None: ++ self.elements[map_elem.elem_type][map_elem.id] = map_elem ++ ++ def compute_search_indices(self) -> None: ++ self.search_kdtrees = {MapElementType.ROAD_LANE: LaneCenterKDTree(self)} ++ ++ def iter_elems(self) -> Iterator[MapElement]: ++ for elems_dict in self.elements.values(): ++ for elem in elems_dict.values(): ++ yield elem ++ ++ def get_road_lane(self, lane_id: str) -> RoadLane: ++ return self.elements[MapElementType.ROAD_LANE][lane_id] ++ ++ def __len__(self) -> int: ++ return sum(len(elems_dict) for elems_dict in self.elements.values()) ++ ++ def _write_road_lanes( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ road_lane: RoadLane ++ for elem_id, road_lane in self.elements[MapElementType.ROAD_LANE].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_lane: map_proto.RoadLane = new_element.road_lane ++ map_utils.populate_lane_polylines(new_lane, road_lane, shifted_origin) ++ ++ new_lane.entry_lanes.extend( ++ [lane_id.encode() for lane_id in road_lane.prev_lanes] ++ ) ++ new_lane.exit_lanes.extend( ++ [lane_id.encode() for lane_id in road_lane.next_lanes] ++ ) ++ ++ new_lane.adjacent_lanes_left.extend( ++ [lane_id.encode() for lane_id in road_lane.adj_lanes_left] ++ ) ++ new_lane.adjacent_lanes_right.extend( ++ [lane_id.encode() for lane_id in road_lane.adj_lanes_right] ++ ) ++ ++ def _write_road_areas( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ road_area: RoadArea ++ for elem_id, road_area in self.elements[MapElementType.ROAD_AREA].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_area: map_proto.RoadArea = new_element.road_area ++ map_utils.populate_polygon( ++ new_area.exterior_polygon, ++ road_area.exterior_polygon.xyz, ++ shifted_origin, ++ ) ++ ++ hole: Polyline ++ for hole in road_area.interior_holes: ++ new_hole: map_proto.Polyline = new_area.interior_holes.add() ++ map_utils.populate_polygon( ++ new_hole, ++ hole.xyz, ++ shifted_origin, ++ ) ++ ++ def _write_ped_crosswalks( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ ped_crosswalk: PedCrosswalk ++ for elem_id, ped_crosswalk in self.elements[ ++ MapElementType.PED_CROSSWALK ++ ].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_crosswalk: map_proto.PedCrosswalk = new_element.ped_crosswalk ++ map_utils.populate_polygon( ++ new_crosswalk.polygon, ++ ped_crosswalk.polygon.xyz, ++ shifted_origin, ++ ) ++ ++ def _write_ped_walkways( ++ self, vectorized_map: map_proto.VectorizedMap, shifted_origin: np.ndarray ++ ) -> None: ++ ped_walkway: PedWalkway ++ for elem_id, ped_walkway in 
self.elements[MapElementType.PED_WALKWAY].items(): ++ new_element: map_proto.MapElement = vectorized_map.elements.add() ++ new_element.id = elem_id.encode() ++ ++ new_walkway: map_proto.PedWalkway = new_element.ped_walkway ++ map_utils.populate_polygon( ++ new_walkway.polygon, ++ ped_walkway.polygon.xyz, ++ shifted_origin, ++ ) ++ ++ def to_proto(self) -> map_proto.VectorizedMap: ++ output_map = map_proto.VectorizedMap() ++ output_map.name = self.map_id ++ ++ ( ++ output_map.min_pt.x, ++ output_map.min_pt.y, ++ output_map.min_pt.z, ++ output_map.max_pt.x, ++ output_map.max_pt.y, ++ output_map.max_pt.z, ++ ) = self.extent ++ ++ shifted_origin: np.ndarray = self.extent[:3] ++ ( ++ output_map.shifted_origin.x, ++ output_map.shifted_origin.y, ++ output_map.shifted_origin.z, ++ ) = shifted_origin ++ ++ # Populating the elements in the vectorized map protobuf. ++ self._write_road_lanes(output_map, shifted_origin) ++ self._write_road_areas(output_map, shifted_origin) ++ self._write_ped_crosswalks(output_map, shifted_origin) ++ self._write_ped_walkways(output_map, shifted_origin) ++ ++ return output_map ++ ++ @classmethod ++ def from_proto(cls, vec_map: map_proto.VectorizedMap, **kwargs): ++ # Options for which map elements to include. ++ incl_road_lanes: bool = kwargs.get("incl_road_lanes", True) ++ incl_road_areas: bool = kwargs.get("incl_road_areas", False) ++ incl_ped_crosswalks: bool = kwargs.get("incl_ped_crosswalks", False) ++ incl_ped_walkways: bool = kwargs.get("incl_ped_walkways", False) ++ ++ # Add any map offset in case the map origin was shifted for storage efficiency. ++ shifted_origin: np.ndarray = np.array( ++ [ ++ vec_map.shifted_origin.x, ++ vec_map.shifted_origin.y, ++ vec_map.shifted_origin.z, ++ 0.0, # Some polylines also have heading so we're adding ++ # this (zero) coordinate to account for that. ++ ] ++ ) ++ ++ map_elem_dict: Dict[str, Dict[str, MapElement]] = defaultdict(dict) ++ ++ map_elem: MapElement ++ for map_elem in vec_map.elements: ++ elem_id: str = map_elem.id.decode() ++ if incl_road_lanes and map_elem.HasField("road_lane"): ++ road_lane_obj: map_proto.RoadLane = map_elem.road_lane ++ ++ center_pl: Polyline = Polyline( ++ map_utils.proto_to_np(road_lane_obj.center) + shifted_origin ++ ) ++ ++ # We do not care for the heading of the left and right edges ++ # (only the center matters). ++ left_pl: Optional[Polyline] = None ++ if road_lane_obj.HasField("left_boundary"): ++ left_pl = Polyline( ++ map_utils.proto_to_np( ++ road_lane_obj.left_boundary, incl_heading=False ++ ) ++ + shifted_origin[:3] ++ ) ++ ++ right_pl: Optional[Polyline] = None ++ if road_lane_obj.HasField("right_boundary"): ++ right_pl = Polyline( ++ map_utils.proto_to_np( ++ road_lane_obj.right_boundary, incl_heading=False ++ ) ++ + shifted_origin[:3] ++ ) ++ ++ adj_lanes_left: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.adjacent_lanes_left] ++ ) ++ adj_lanes_right: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.adjacent_lanes_right] ++ ) ++ ++ next_lanes: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.exit_lanes] ++ ) ++ prev_lanes: Set[str] = set( ++ [iden.decode() for iden in road_lane_obj.entry_lanes] ++ ) ++ ++ # Double-using the connectivity attributes for lane IDs now (will ++ # replace them with Lane objects after all Lane objects have been created). 
++ curr_lane = RoadLane( ++ elem_id, ++ center_pl, ++ left_pl, ++ right_pl, ++ adj_lanes_left, ++ adj_lanes_right, ++ next_lanes, ++ prev_lanes, ++ ) ++ map_elem_dict[MapElementType.ROAD_LANE][elem_id] = curr_lane ++ ++ elif incl_road_areas and map_elem.HasField("road_area"): ++ road_area_obj: map_proto.RoadArea = map_elem.road_area ++ ++ exterior: Polyline = Polyline( ++ map_utils.proto_to_np( ++ road_area_obj.exterior_polygon, incl_heading=False ++ ) ++ + shifted_origin[:3] ++ ) ++ ++ interior_holes: List[Polyline] = list() ++ interior_hole: map_proto.Polyline ++ for interior_hole in road_area_obj.interior_holes: ++ interior_holes.append( ++ Polyline( ++ map_utils.proto_to_np(interior_hole, incl_heading=False) ++ + shifted_origin[:3] ++ ) ++ ) ++ ++ curr_area = RoadArea(elem_id, exterior, interior_holes) ++ map_elem_dict[MapElementType.ROAD_AREA][elem_id] = curr_area ++ ++ elif incl_ped_crosswalks and map_elem.HasField("ped_crosswalk"): ++ ped_crosswalk_obj: map_proto.PedCrosswalk = map_elem.ped_crosswalk ++ ++ polygon_vertices: Polyline = Polyline( ++ map_utils.proto_to_np(ped_crosswalk_obj.polygon, incl_heading=False) ++ + shifted_origin[:3] ++ ) ++ ++ curr_area = PedCrosswalk(elem_id, polygon_vertices) ++ map_elem_dict[MapElementType.PED_CROSSWALK][elem_id] = curr_area ++ ++ elif incl_ped_walkways and map_elem.HasField("ped_walkway"): ++ ped_walkway_obj: map_proto.PedCrosswalk = map_elem.ped_walkway ++ ++ polygon_vertices: Polyline = Polyline( ++ map_utils.proto_to_np(ped_walkway_obj.polygon, incl_heading=False) ++ + shifted_origin[:3] ++ ) ++ ++ curr_area = PedWalkway(elem_id, polygon_vertices) ++ map_elem_dict[MapElementType.PED_WALKWAY][elem_id] = curr_area ++ ++ return cls( ++ map_id=vec_map.name, ++ extent=np.array( ++ [ ++ vec_map.min_pt.x, ++ vec_map.min_pt.y, ++ vec_map.min_pt.z, ++ vec_map.max_pt.x, ++ vec_map.max_pt.y, ++ vec_map.max_pt.z, ++ ] ++ ), ++ elements=map_elem_dict, ++ search_kdtrees=None, ++ traffic_light_status=None, ++ ) ++ ++ def associate_scene_data( ++ self, traffic_light_status_dict: Dict[Tuple[int, int], TrafficLightStatus] ++ ) -> None: ++ """Associates vector map with scene-specific data like traffic light information""" ++ self.traffic_light_status = traffic_light_status_dict ++ ++ def get_current_lane( ++ self, ++ xyzh: np.ndarray, ++ max_dist: float = 2.0, ++ max_heading_error: float = np.pi / 8, ++ ) -> List[RoadLane]: ++ """ ++ Args: ++ xyzh (np.ndarray): 3d position and heading of agent in world coordinates ++ ++ Returns: ++ List[RoadLane]: List of possible road lanes that agent could be on ++ """ ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ return [ ++ self.lanes[idx] ++ for idx in lane_kdtree.current_lane_inds(xyzh, max_dist, max_heading_error) ++ ] ++ ++ def get_closest_lane(self, xyz: np.ndarray) -> RoadLane: ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ return self.lanes[lane_kdtree.closest_polyline_ind(xyz)] ++ ++ def get_closest_unique_lanes(self, xyz_vec: np.ndarray) -> List[RoadLane]: ++ assert xyz_vec.ndim == 2 # xyz_vec is assumed to be (*, 3) ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ closest_inds = lane_kdtree.closest_polyline_ind(xyz_vec) ++ unique_inds = np.unique(closest_inds) ++ return [self.lanes[ind] for ind in unique_inds] ++ ++ def get_lanes_within(self, xyz: np.ndarray, dist: float) -> List[RoadLane]: ++ lane_kdtree: LaneCenterKDTree = self.search_kdtrees[MapElementType.ROAD_LANE] ++ return [ ++ self.lanes[idx] 
for idx in lane_kdtree.polyline_inds_in_range(xyz, dist) ++ ] ++ ++ def get_traffic_light_status( ++ self, lane_id: str, scene_ts: int ++ ) -> TrafficLightStatus: ++ return ( ++ self.traffic_light_status.get( ++ (int(lane_id), scene_ts), TrafficLightStatus.NO_DATA ++ ) ++ if self.traffic_light_status is not None ++ else TrafficLightStatus.NO_DATA ++ ) ++ ++ def rasterize( ++ self, resolution: float = 2, **kwargs ++ ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: ++ """Renders this vector map at the specified resolution. ++ ++ Args: ++ resolution (float): The rasterized image's resolution in pixels per meter. ++ ++ Returns: ++ np.ndarray: The rasterized RGB image. ++ """ ++ return_tf_mat: bool = kwargs.get("return_tf_mat", False) ++ incl_centerlines: bool = kwargs.get("incl_centerlines", True) ++ incl_lane_edges: bool = kwargs.get("incl_lane_edges", True) ++ incl_lane_area: bool = kwargs.get("incl_lane_area", True) ++ ++ scene_ts: Optional[int] = kwargs.get("scene_ts", None) ++ ++ # (255, 102, 99) also looks nice. ++ center_color: Tuple[int, int, int] = kwargs.get("center_color", (129, 51, 255)) ++ # (86, 203, 249) also looks nice. ++ edge_color: Tuple[int, int, int] = kwargs.get("edge_color", (118, 185, 0)) ++ # (191, 215, 234) also looks nice. ++ area_color: Tuple[int, int, int] = kwargs.get("area_color", (214, 232, 181)) ++ ++ min_x, min_y, _, max_x, max_y, _ = self.extent ++ ++ world_center_m: Tuple[float, float] = ( ++ (max_x + min_x) / 2, ++ (max_y + min_y) / 2, ++ ) ++ ++ raster_size_x: int = ceil((max_x - min_x) * resolution) ++ raster_size_y: int = ceil((max_y - min_y) * resolution) ++ ++ raster_from_local: np.ndarray = np.array( ++ [ ++ [resolution, 0, raster_size_x / 2], ++ [0, resolution, raster_size_y / 2], ++ [0, 0, 1], ++ ] ++ ) ++ ++ # Compute pose from its position and rotation. ++ pose_from_world: np.ndarray = np.array( ++ [ ++ [1, 0, -world_center_m[0]], ++ [0, 1, -world_center_m[1]], ++ [0, 0, 1], ++ ] ++ ) ++ ++ raster_from_world: np.ndarray = raster_from_local @ pose_from_world ++ ++ map_img: np.ndarray = np.zeros( ++ shape=(raster_size_y, raster_size_x, 3), dtype=np.uint8 ++ ) ++ ++ lane_edges: List[np.ndarray] = list() ++ centerlines: List[np.ndarray] = list() ++ lane: RoadLane ++ for lane in tqdm( ++ self.elements[MapElementType.ROAD_LANE].values(), ++ desc=f"Rasterizing Map at {resolution:.2f} px/m", ++ leave=False, ++ ): ++ centerlines.append( ++ raster_utils.world_to_subpixel( ++ lane.center.points[:, :2], raster_from_world ++ ) ++ ) ++ if lane.left_edge is not None and lane.right_edge is not None: ++ left_pts: np.ndarray = lane.left_edge.points[:, :2] ++ right_pts: np.ndarray = lane.right_edge.points[:, :2] ++ ++ lane_edges += [ ++ raster_utils.world_to_subpixel(left_pts, raster_from_world), ++ raster_utils.world_to_subpixel(right_pts, raster_from_world), ++ ] ++ ++ lane_color = area_color ++ status = self.get_traffic_light_status(lane.id, scene_ts) ++ if status == TrafficLightStatus.GREEN: ++ lane_color = [0, 200, 0] ++ elif status == TrafficLightStatus.RED: ++ lane_color = [200, 0, 0] ++ elif status == TrafficLightStatus.UNKNOWN: ++ lane_color = [150, 150, 0] ++ ++ # Drawing lane areas. Need to do per loop because doing it all at once can ++ # create lots of wonky holes in the image. 
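A worked numeric example of the raster_from_world composition used in rasterize() above: world points are first shifted so the map center sits at the origin, then scaled to pixels and offset to the raster center. The extent values below are made up for illustration.

import numpy as np

resolution = 2.0                                              # px per meter
min_x, min_y, max_x, max_y = 0.0, 0.0, 100.0, 50.0
raster_size_x = int(np.ceil((max_x - min_x) * resolution))    # 200 px
raster_size_y = int(np.ceil((max_y - min_y) * resolution))    # 100 px
world_center = ((max_x + min_x) / 2, (max_y + min_y) / 2)     # (50, 25)

raster_from_local = np.array([[resolution, 0, raster_size_x / 2],
                              [0, resolution, raster_size_y / 2],
                              [0, 0, 1]])
pose_from_world = np.array([[1, 0, -world_center[0]],
                            [0, 1, -world_center[1]],
                            [0, 0, 1]])
raster_from_world = raster_from_local @ pose_from_world

pt_world = np.array([60.0, 30.0, 1.0])                        # homogeneous world point
print(raster_from_world @ pt_world)                           # [120. 60. 1.] -> pixel coordinates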
++ # See https://stackoverflow.com/questions/69768620/cv2-fillpoly-failing-for-intersecting-polygons ++ if incl_lane_area: ++ lane_area: np.ndarray = np.concatenate( ++ [left_pts, right_pts[::-1]], axis=0 ++ ) ++ raster_utils.rasterize_world_polygon( ++ lane_area, ++ map_img, ++ raster_from_world, ++ color=lane_color, ++ ) ++ ++ # Drawing all lane edge lines at the same time. ++ if incl_lane_edges: ++ raster_utils.cv2_draw_polylines(lane_edges, map_img, color=edge_color) ++ ++ # Drawing centerlines last (on top of everything else). ++ if incl_centerlines: ++ raster_utils.cv2_draw_polylines(centerlines, map_img, color=center_color) ++ ++ if return_tf_mat: ++ return map_img.astype(float) / 255, raster_from_world ++ else: ++ return map_img.astype(float) / 255 ++ ++ @overload ++ def visualize_lane_graph( ++ self, ++ origin_lane: RoadLane, ++ num_hops: int, ++ **kwargs, ++ ) -> Axes: ++ ... ++ ++ @overload ++ def visualize_lane_graph(self, origin_lane: str, num_hops: int, **kwargs) -> Axes: ++ ... ++ ++ @overload ++ def visualize_lane_graph(self, origin_lane: int, num_hops: int, **kwargs) -> Axes: ++ ... ++ ++ def visualize_lane_graph( ++ self, origin_lane: Union[RoadLane, str, int], num_hops: int, **kwargs ++ ) -> Axes: ++ ax = kwargs.get("ax", None) ++ if ax is None: ++ fig, ax = plt.subplots() ++ ++ origin: str ++ if isinstance(origin_lane, RoadLane): ++ origin = origin_lane.id ++ elif isinstance(origin_lane, str): ++ origin = origin_lane ++ elif isinstance(origin_lane, int): ++ origin = self.lanes[origin_lane].id ++ ++ viridis = mpl.colormaps[kwargs.get("cmap", "rainbow")].resampled(num_hops + 1) ++ ++ already_seen: Set[str] = set() ++ lanes_to_plot: List[Tuple[str, int]] = [(origin, 0)] ++ ++ if kwargs.get("legend", True): ++ ax.scatter([], [], label=f"Lane Endpoints", color="k") ++ ax.plot([], [], label=f"Origin Lane ({origin})", color=viridis(0)) ++ for h in range(1, num_hops + 1): ++ ax.plot( ++ [], ++ [], ++ label=f"{h} Lane{'s' if h > 1 else ''} Away", ++ color=viridis(h), ++ ) ++ ++ raster_from_world = kwargs.get("raster_from_world", None) ++ while len(lanes_to_plot) > 0: ++ lane_id, curr_hops = lanes_to_plot.pop(0) ++ already_seen.add(lane_id) ++ lane: RoadLane = self.get_road_lane(lane_id) ++ ++ center: np.ndarray = lane.center.points[..., :2] ++ first_pt_heading: float = lane.center.points[0, -1] ++ mdpt: np.ndarray = lane.center.midpoint[..., :2] ++ ++ if raster_from_world is not None: ++ center = map_utils.transform_points(center, raster_from_world) ++ mdpt = map_utils.transform_points(mdpt[None, :], raster_from_world)[0] ++ ++ ax.plot(center[:, 0], center[:, 1], color=viridis(curr_hops)) ++ ax.scatter(center[[0, -1], 0], center[[0, -1], 1], color=viridis(curr_hops)) ++ ax.quiver( ++ [center[0, 0]], ++ [center[0, 1]], ++ [np.cos(first_pt_heading)], ++ [np.sin(first_pt_heading)], ++ color=viridis(curr_hops), ++ ) ++ ax.text(mdpt[0], mdpt[1], s=lane_id) ++ ++ if curr_hops < num_hops: ++ lanes_to_plot += [ ++ (l, curr_hops + 1) ++ for l in lane.reachable_lanes ++ if l not in already_seen ++ ] ++ ++ if kwargs.get("legend", True): ++ ax.legend(loc="best", frameon=True) ++ ++ return ax +\ No newline at end of file +diff --git a/src/trajdata/proto/vectorized_map.proto b/src/trajdata/proto/vectorized_map.proto +index e1cd502..8f68281 100644 +--- a/src/trajdata/proto/vectorized_map.proto ++++ b/src/trajdata/proto/vectorized_map.proto +@@ -39,14 +39,14 @@ message Point { + } + + message Polyline { +- // Position deltas in millimeters. The origin is an arbitrary location. 
+- // From https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 ++ // Position deltas in 10^-5 meters. The origin is an arbitrary location. ++ // Inspired by https://github.com/woven-planet/l5kit/blob/master/l5kit/l5kit/data/proto/road_network.proto#L446 + // The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from + // the origin. For subsequent points, this field stores the difference between the point's + // coordinates and the previous point's coordinates. This is for representation efficiency. +- repeated sint32 dx_mm = 1; +- repeated sint32 dy_mm = 2; +- repeated sint32 dz_mm = 3; ++ repeated sint64 dx_mm = 1; ++ repeated sint64 dy_mm = 2; ++ repeated sint64 dz_mm = 3; + repeated double h_rad = 4; + } + +@@ -74,6 +74,9 @@ message RoadLane { + // A list of neighbors to the right of this lane. Neighbor lanes + // include only adjacent lanes going the same direction. + repeated bytes adjacent_lanes_right = 7; ++ ++ // A list of associated road area ids. ++ repeated bytes road_area_ids = 8; + } + + message RoadArea { +diff --git a/src/trajdata/proto/vectorized_map_pb2.py b/src/trajdata/proto/vectorized_map_pb2.py +index 7352109..c5910d9 100644 +--- a/src/trajdata/proto/vectorized_map_pb2.py ++++ b/src/trajdata/proto/vectorized_map_pb2.py +@@ -1,136 +1,545 @@ + # -*- coding: utf-8 -*- + # Generated by the protocol buffer compiler. DO NOT EDIT! + # source: vectorized_map.proto +-"""Generated protocol buffer code.""" ++ + from google.protobuf import descriptor as _descriptor +-from google.protobuf import descriptor_pool as _descriptor_pool + from google.protobuf import message as _message + from google.protobuf import reflection as _reflection + from google.protobuf import symbol_database as _symbol_database +- + # @@protoc_insertion_point(imports) + + _sym_db = _symbol_database.Default() + + +-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( +- b'\n\x14vectorized_map.proto\x12\x08trajdata"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x11\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x11\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x11\x12\r\n\x05h_rad\x18\x04 \x03(\x01"\x98\x02\n\x08RoadLane\x12"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 
\x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' ++ ++ ++DESCRIPTOR = _descriptor.FileDescriptor( ++ name='vectorized_map.proto', ++ package='trajdata', ++ syntax='proto3', ++ serialized_options=None, ++ create_key=_descriptor._internal_create_key, ++ serialized_pb=b'\n\x14vectorized_map.proto\x12\x08trajdata\"\xb0\x01\n\rVectorizedMap\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x08\x65lements\x18\x02 \x03(\x0b\x32\x14.trajdata.MapElement\x12\x1f\n\x06max_pt\x18\x03 \x01(\x0b\x32\x0f.trajdata.Point\x12\x1f\n\x06min_pt\x18\x04 \x01(\x0b\x32\x0f.trajdata.Point\x12\'\n\x0eshifted_origin\x18\x05 \x01(\x0b\x32\x0f.trajdata.Point\"\xd8\x01\n\nMapElement\x12\n\n\x02id\x18\x01 \x01(\x0c\x12\'\n\troad_lane\x18\x02 \x01(\x0b\x32\x12.trajdata.RoadLaneH\x00\x12\'\n\troad_area\x18\x03 \x01(\x0b\x32\x12.trajdata.RoadAreaH\x00\x12/\n\rped_crosswalk\x18\x04 \x01(\x0b\x32\x16.trajdata.PedCrosswalkH\x00\x12+\n\x0bped_walkway\x18\x05 \x01(\x0b\x32\x14.trajdata.PedWalkwayH\x00\x42\x0e\n\x0c\x65lement_data\"(\n\x05Point\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01\"F\n\x08Polyline\x12\r\n\x05\x64x_mm\x18\x01 \x03(\x12\x12\r\n\x05\x64y_mm\x18\x02 \x03(\x12\x12\r\n\x05\x64z_mm\x18\x03 \x03(\x12\x12\r\n\x05h_rad\x18\x04 \x03(\x01\"\xaf\x02\n\x08RoadLane\x12\"\n\x06\x63\x65nter\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12.\n\rleft_boundary\x18\x02 \x01(\x0b\x32\x12.trajdata.PolylineH\x00\x88\x01\x01\x12/\n\x0eright_boundary\x18\x03 \x01(\x0b\x32\x12.trajdata.PolylineH\x01\x88\x01\x01\x12\x13\n\x0b\x65ntry_lanes\x18\x04 \x03(\x0c\x12\x12\n\nexit_lanes\x18\x05 \x03(\x0c\x12\x1b\n\x13\x61\x64jacent_lanes_left\x18\x06 \x03(\x0c\x12\x1c\n\x14\x61\x64jacent_lanes_right\x18\x07 \x03(\x0c\x12\x15\n\rroad_area_ids\x18\x08 \x03(\x0c\x42\x10\n\x0e_left_boundaryB\x11\n\x0f_right_boundary\"d\n\x08RoadArea\x12,\n\x10\x65xterior_polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\x12*\n\x0einterior_holes\x18\x02 \x03(\x0b\x32\x12.trajdata.Polyline\"3\n\x0cPedCrosswalk\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polyline\"1\n\nPedWalkway\x12#\n\x07polygon\x18\x01 \x01(\x0b\x32\x12.trajdata.Polylineb\x06proto3' + ) + + +-_VECTORIZEDMAP = DESCRIPTOR.message_types_by_name["VectorizedMap"] +-_MAPELEMENT = DESCRIPTOR.message_types_by_name["MapElement"] +-_POINT = DESCRIPTOR.message_types_by_name["Point"] +-_POLYLINE = DESCRIPTOR.message_types_by_name["Polyline"] +-_ROADLANE = DESCRIPTOR.message_types_by_name["RoadLane"] +-_ROADAREA = DESCRIPTOR.message_types_by_name["RoadArea"] +-_PEDCROSSWALK = DESCRIPTOR.message_types_by_name["PedCrosswalk"] +-_PEDWALKWAY = DESCRIPTOR.message_types_by_name["PedWalkway"] +-VectorizedMap = _reflection.GeneratedProtocolMessageType( +- "VectorizedMap", +- (_message.Message,), +- { +- "DESCRIPTOR": _VECTORIZEDMAP, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.VectorizedMap) +- }, ++ ++ ++_VECTORIZEDMAP = _descriptor.Descriptor( ++ name='VectorizedMap', ++ full_name='trajdata.VectorizedMap', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='name', full_name='trajdata.VectorizedMap.name', index=0, ++ number=1, type=9, cpp_type=9, label=1, ++ has_default_value=False, 
default_value=b"".decode('utf-8'), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='elements', full_name='trajdata.VectorizedMap.elements', index=1, ++ number=2, type=11, cpp_type=10, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='max_pt', full_name='trajdata.VectorizedMap.max_pt', index=2, ++ number=3, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='min_pt', full_name='trajdata.VectorizedMap.min_pt', index=3, ++ number=4, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='shifted_origin', full_name='trajdata.VectorizedMap.shifted_origin', index=4, ++ number=5, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=35, ++ serialized_end=211, + ) +-_sym_db.RegisterMessage(VectorizedMap) + +-MapElement = _reflection.GeneratedProtocolMessageType( +- "MapElement", +- (_message.Message,), +- { +- "DESCRIPTOR": _MAPELEMENT, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.MapElement) +- }, ++ ++_MAPELEMENT = _descriptor.Descriptor( ++ name='MapElement', ++ full_name='trajdata.MapElement', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='id', full_name='trajdata.MapElement.id', index=0, ++ number=1, type=12, cpp_type=9, label=1, ++ has_default_value=False, default_value=b"", ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='road_lane', full_name='trajdata.MapElement.road_lane', index=1, ++ number=2, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='road_area', full_name='trajdata.MapElement.road_area', index=2, ++ number=3, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, 
enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='ped_crosswalk', full_name='trajdata.MapElement.ped_crosswalk', index=3, ++ number=4, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='ped_walkway', full_name='trajdata.MapElement.ped_walkway', index=4, ++ number=5, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ _descriptor.OneofDescriptor( ++ name='element_data', full_name='trajdata.MapElement.element_data', ++ index=0, containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[]), ++ ], ++ serialized_start=214, ++ serialized_end=430, + ) +-_sym_db.RegisterMessage(MapElement) + +-Point = _reflection.GeneratedProtocolMessageType( +- "Point", +- (_message.Message,), +- { +- "DESCRIPTOR": _POINT, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.Point) +- }, ++ ++_POINT = _descriptor.Descriptor( ++ name='Point', ++ full_name='trajdata.Point', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='x', full_name='trajdata.Point.x', index=0, ++ number=1, type=1, cpp_type=5, label=1, ++ has_default_value=False, default_value=float(0), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='y', full_name='trajdata.Point.y', index=1, ++ number=2, type=1, cpp_type=5, label=1, ++ has_default_value=False, default_value=float(0), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='z', full_name='trajdata.Point.z', index=2, ++ number=3, type=1, cpp_type=5, label=1, ++ has_default_value=False, default_value=float(0), ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=432, ++ serialized_end=472, + ) +-_sym_db.RegisterMessage(Point) + +-Polyline = _reflection.GeneratedProtocolMessageType( +- "Polyline", +- (_message.Message,), +- { +- "DESCRIPTOR": _POLYLINE, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.Polyline) +- }, ++ ++_POLYLINE = _descriptor.Descriptor( ++ 
name='Polyline', ++ full_name='trajdata.Polyline', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='dx_mm', full_name='trajdata.Polyline.dx_mm', index=0, ++ number=1, type=18, cpp_type=2, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='dy_mm', full_name='trajdata.Polyline.dy_mm', index=1, ++ number=2, type=18, cpp_type=2, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='dz_mm', full_name='trajdata.Polyline.dz_mm', index=2, ++ number=3, type=18, cpp_type=2, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='h_rad', full_name='trajdata.Polyline.h_rad', index=3, ++ number=4, type=1, cpp_type=5, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=474, ++ serialized_end=544, + ) +-_sym_db.RegisterMessage(Polyline) + +-RoadLane = _reflection.GeneratedProtocolMessageType( +- "RoadLane", +- (_message.Message,), +- { +- "DESCRIPTOR": _ROADLANE, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.RoadLane) +- }, ++ ++_ROADLANE = _descriptor.Descriptor( ++ name='RoadLane', ++ full_name='trajdata.RoadLane', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='center', full_name='trajdata.RoadLane.center', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='left_boundary', full_name='trajdata.RoadLane.left_boundary', index=1, ++ number=2, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='right_boundary', full_name='trajdata.RoadLane.right_boundary', index=2, ++ number=3, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, 
file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='entry_lanes', full_name='trajdata.RoadLane.entry_lanes', index=3, ++ number=4, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='exit_lanes', full_name='trajdata.RoadLane.exit_lanes', index=4, ++ number=5, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='adjacent_lanes_left', full_name='trajdata.RoadLane.adjacent_lanes_left', index=5, ++ number=6, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='adjacent_lanes_right', full_name='trajdata.RoadLane.adjacent_lanes_right', index=6, ++ number=7, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='road_area_ids', full_name='trajdata.RoadLane.road_area_ids', index=7, ++ number=8, type=12, cpp_type=9, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ _descriptor.OneofDescriptor( ++ name='_left_boundary', full_name='trajdata.RoadLane._left_boundary', ++ index=0, containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[]), ++ _descriptor.OneofDescriptor( ++ name='_right_boundary', full_name='trajdata.RoadLane._right_boundary', ++ index=1, containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[]), ++ ], ++ serialized_start=547, ++ serialized_end=850, + ) +-_sym_db.RegisterMessage(RoadLane) + +-RoadArea = _reflection.GeneratedProtocolMessageType( +- "RoadArea", +- (_message.Message,), +- { +- "DESCRIPTOR": _ROADAREA, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.RoadArea) +- }, ++ ++_ROADAREA = _descriptor.Descriptor( ++ name='RoadArea', ++ full_name='trajdata.RoadArea', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='exterior_polygon', full_name='trajdata.RoadArea.exterior_polygon', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), ++ _descriptor.FieldDescriptor( ++ name='interior_holes', full_name='trajdata.RoadArea.interior_holes', index=1, ++ number=2, type=11, cpp_type=10, label=3, ++ has_default_value=False, default_value=[], ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=852, ++ serialized_end=952, + ) +-_sym_db.RegisterMessage(RoadArea) + +-PedCrosswalk = _reflection.GeneratedProtocolMessageType( +- "PedCrosswalk", +- (_message.Message,), +- { +- "DESCRIPTOR": _PEDCROSSWALK, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.PedCrosswalk) +- }, ++ ++_PEDCROSSWALK = _descriptor.Descriptor( ++ name='PedCrosswalk', ++ full_name='trajdata.PedCrosswalk', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='polygon', full_name='trajdata.PedCrosswalk.polygon', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=954, ++ serialized_end=1005, + ) +-_sym_db.RegisterMessage(PedCrosswalk) + +-PedWalkway = _reflection.GeneratedProtocolMessageType( +- "PedWalkway", +- (_message.Message,), +- { +- "DESCRIPTOR": _PEDWALKWAY, +- "__module__": "vectorized_map_pb2" +- # @@protoc_insertion_point(class_scope:trajdata.PedWalkway) +- }, ++ ++_PEDWALKWAY = _descriptor.Descriptor( ++ name='PedWalkway', ++ full_name='trajdata.PedWalkway', ++ filename=None, ++ file=DESCRIPTOR, ++ containing_type=None, ++ create_key=_descriptor._internal_create_key, ++ fields=[ ++ _descriptor.FieldDescriptor( ++ name='polygon', full_name='trajdata.PedWalkway.polygon', index=0, ++ number=1, type=11, cpp_type=10, label=1, ++ has_default_value=False, default_value=None, ++ message_type=None, enum_type=None, containing_type=None, ++ is_extension=False, extension_scope=None, ++ serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), ++ ], ++ extensions=[ ++ ], ++ nested_types=[], ++ enum_types=[ ++ ], ++ serialized_options=None, ++ is_extendable=False, ++ syntax='proto3', ++ extension_ranges=[], ++ oneofs=[ ++ ], ++ serialized_start=1007, ++ serialized_end=1056, + ) ++ ++_VECTORIZEDMAP.fields_by_name['elements'].message_type = _MAPELEMENT ++_VECTORIZEDMAP.fields_by_name['max_pt'].message_type = _POINT ++_VECTORIZEDMAP.fields_by_name['min_pt'].message_type = _POINT ++_VECTORIZEDMAP.fields_by_name['shifted_origin'].message_type = _POINT ++_MAPELEMENT.fields_by_name['road_lane'].message_type = _ROADLANE ++_MAPELEMENT.fields_by_name['road_area'].message_type = _ROADAREA ++_MAPELEMENT.fields_by_name['ped_crosswalk'].message_type = _PEDCROSSWALK ++_MAPELEMENT.fields_by_name['ped_walkway'].message_type = _PEDWALKWAY 
++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['road_lane']) ++_MAPELEMENT.fields_by_name['road_lane'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['road_area']) ++_MAPELEMENT.fields_by_name['road_area'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['ped_crosswalk']) ++_MAPELEMENT.fields_by_name['ped_crosswalk'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_MAPELEMENT.oneofs_by_name['element_data'].fields.append( ++ _MAPELEMENT.fields_by_name['ped_walkway']) ++_MAPELEMENT.fields_by_name['ped_walkway'].containing_oneof = _MAPELEMENT.oneofs_by_name['element_data'] ++_ROADLANE.fields_by_name['center'].message_type = _POLYLINE ++_ROADLANE.fields_by_name['left_boundary'].message_type = _POLYLINE ++_ROADLANE.fields_by_name['right_boundary'].message_type = _POLYLINE ++_ROADLANE.oneofs_by_name['_left_boundary'].fields.append( ++ _ROADLANE.fields_by_name['left_boundary']) ++_ROADLANE.fields_by_name['left_boundary'].containing_oneof = _ROADLANE.oneofs_by_name['_left_boundary'] ++_ROADLANE.oneofs_by_name['_right_boundary'].fields.append( ++ _ROADLANE.fields_by_name['right_boundary']) ++_ROADLANE.fields_by_name['right_boundary'].containing_oneof = _ROADLANE.oneofs_by_name['_right_boundary'] ++_ROADAREA.fields_by_name['exterior_polygon'].message_type = _POLYLINE ++_ROADAREA.fields_by_name['interior_holes'].message_type = _POLYLINE ++_PEDCROSSWALK.fields_by_name['polygon'].message_type = _POLYLINE ++_PEDWALKWAY.fields_by_name['polygon'].message_type = _POLYLINE ++DESCRIPTOR.message_types_by_name['VectorizedMap'] = _VECTORIZEDMAP ++DESCRIPTOR.message_types_by_name['MapElement'] = _MAPELEMENT ++DESCRIPTOR.message_types_by_name['Point'] = _POINT ++DESCRIPTOR.message_types_by_name['Polyline'] = _POLYLINE ++DESCRIPTOR.message_types_by_name['RoadLane'] = _ROADLANE ++DESCRIPTOR.message_types_by_name['RoadArea'] = _ROADAREA ++DESCRIPTOR.message_types_by_name['PedCrosswalk'] = _PEDCROSSWALK ++DESCRIPTOR.message_types_by_name['PedWalkway'] = _PEDWALKWAY ++_sym_db.RegisterFileDescriptor(DESCRIPTOR) ++ ++VectorizedMap = _reflection.GeneratedProtocolMessageType('VectorizedMap', (_message.Message,), { ++ 'DESCRIPTOR' : _VECTORIZEDMAP, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.VectorizedMap) ++ }) ++_sym_db.RegisterMessage(VectorizedMap) ++ ++MapElement = _reflection.GeneratedProtocolMessageType('MapElement', (_message.Message,), { ++ 'DESCRIPTOR' : _MAPELEMENT, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.MapElement) ++ }) ++_sym_db.RegisterMessage(MapElement) ++ ++Point = _reflection.GeneratedProtocolMessageType('Point', (_message.Message,), { ++ 'DESCRIPTOR' : _POINT, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.Point) ++ }) ++_sym_db.RegisterMessage(Point) ++ ++Polyline = _reflection.GeneratedProtocolMessageType('Polyline', (_message.Message,), { ++ 'DESCRIPTOR' : _POLYLINE, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.Polyline) ++ }) ++_sym_db.RegisterMessage(Polyline) ++ ++RoadLane = _reflection.GeneratedProtocolMessageType('RoadLane', (_message.Message,), { ++ 'DESCRIPTOR' : _ROADLANE, ++ '__module__' : 'vectorized_map_pb2' ++ # 
@@protoc_insertion_point(class_scope:trajdata.RoadLane) ++ }) ++_sym_db.RegisterMessage(RoadLane) ++ ++RoadArea = _reflection.GeneratedProtocolMessageType('RoadArea', (_message.Message,), { ++ 'DESCRIPTOR' : _ROADAREA, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.RoadArea) ++ }) ++_sym_db.RegisterMessage(RoadArea) ++ ++PedCrosswalk = _reflection.GeneratedProtocolMessageType('PedCrosswalk', (_message.Message,), { ++ 'DESCRIPTOR' : _PEDCROSSWALK, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.PedCrosswalk) ++ }) ++_sym_db.RegisterMessage(PedCrosswalk) ++ ++PedWalkway = _reflection.GeneratedProtocolMessageType('PedWalkway', (_message.Message,), { ++ 'DESCRIPTOR' : _PEDWALKWAY, ++ '__module__' : 'vectorized_map_pb2' ++ # @@protoc_insertion_point(class_scope:trajdata.PedWalkway) ++ }) + _sym_db.RegisterMessage(PedWalkway) + +-if _descriptor._USE_C_DESCRIPTORS == False: +- +- DESCRIPTOR._options = None +- _VECTORIZEDMAP._serialized_start = 35 +- _VECTORIZEDMAP._serialized_end = 211 +- _MAPELEMENT._serialized_start = 214 +- _MAPELEMENT._serialized_end = 430 +- _POINT._serialized_start = 432 +- _POINT._serialized_end = 472 +- _POLYLINE._serialized_start = 474 +- _POLYLINE._serialized_end = 544 +- _ROADLANE._serialized_start = 547 +- _ROADLANE._serialized_end = 827 +- _ROADAREA._serialized_start = 829 +- _ROADAREA._serialized_end = 929 +- _PEDCROSSWALK._serialized_start = 931 +- _PEDCROSSWALK._serialized_end = 982 +- _PEDWALKWAY._serialized_start = 984 +- _PEDWALKWAY._serialized_end = 1033 ++ + # @@protoc_insertion_point(module_scope) +diff --git a/src/trajdata/utils/arr_utils.py b/src/trajdata/utils/arr_utils.py +index e76a678..6f10bb4 100644 +--- a/src/trajdata/utils/arr_utils.py ++++ b/src/trajdata/utils/arr_utils.py +@@ -93,7 +93,9 @@ def vrange(starts: np.ndarray, stops: np.ndarray) -> np.ndarray: + return np.repeat(stops - lens.cumsum(), lens) + np.arange(lens.sum()) + + +-def angle_wrap(radians: np.ndarray) -> np.ndarray: ++def angle_wrap( ++ radians: Union[np.ndarray, torch.Tensor] ++) -> Union[np.ndarray, torch.Tensor]: + """This function wraps angles to lie within [-pi, pi). + + Args: +@@ -130,12 +132,12 @@ def rotation_matrix(angle: Union[float, np.ndarray]) -> np.ndarray: + return rotmat.transpose(*np.arange(2, batch_dims + 2), 0, 1) + + +-def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor: ++def transform_matrices(angles: Tensor, translations: Optional[Tensor]) -> Tensor: + """Creates a 3x3 transformation matrix for each angle and translation in the input. + + Args: +- angles (Tensor): The (N,)-shaped angles tensor to rotate points by (in radians). +- translations (Tensor): The (N,2)-shaped translations to shift points by. ++ angles (Tensor): The (...)-shaped angles tensor to rotate points by (in radians). ++ translations (Tensor): The (...,2)-shaped translations to shift points by. + + Returns: + Tensor: The Nx3x3 transformation matrices. 
+@@ -143,12 +145,19 @@ def transform_matrices(angles: Tensor, translations: Tensor) -> Tensor: + cos_vals = torch.cos(angles) + sin_vals = torch.sin(angles) + last_rows = torch.tensor( +- [[0.0, 0.0, 1.0]], dtype=angles.dtype, device=angles.device +- ).expand((angles.shape[0], -1)) ++ [0.0, 0.0, 1.0], dtype=angles.dtype, device=angles.device ++ ).view([1] * angles.ndim + [3]).expand(list(angles.shape) + [-1]) ++ ++ if translations is None: ++ trans_x = torch.zeros_like(angles) ++ trans_y = trans_x ++ else: ++ trans_x, trans_y = torch.unbind(translations, dim=-1) ++ + return torch.stack( + [ +- torch.stack([cos_vals, -sin_vals, translations[:, 0]], dim=-1), +- torch.stack([sin_vals, cos_vals, translations[:, 1]], dim=-1), ++ torch.stack([cos_vals, -sin_vals, trans_x], dim=-1), ++ torch.stack([sin_vals, cos_vals, trans_y], dim=-1), + last_rows, + ], + dim=-2, +@@ -243,6 +252,125 @@ def transform_xyh_np(xyh: np.ndarray, tf_mat: np.ndarray) -> np.ndarray: + transformed_angles = transform_angles_np(xyh[..., 2], tf_mat) + return np.concatenate([transformed_xy, transformed_angles[..., None]], axis=-1) + ++def transform_xyh_torch(xyh: torch.Tensor, tf_mat: torch.Tensor) -> torch.Tensor: ++ """ ++ Returns transformed set of xyh points ++ ++ Args: ++ xyh (torch.Tensor): shape [...,3] ++ tf_mat (torch.Tensor): shape [...,3,3] ++ """ ++ transformed_xy = batch_nd_transform_points_pt(xyh[..., :2], tf_mat) ++ transformed_angles = batch_nd_transform_angles_pt(xyh[..., 2], tf_mat) ++ return torch.cat([transformed_xy, transformed_angles[..., None]], dim=-1) ++ ++# -------- TODO(pkarkus) redundant transforms, remove them ++ ++ ++def batch_nd_transform_points_np(points: np.ndarray, Mat: np.ndarray) -> np.ndarray: ++ ndim = Mat.shape[-1] - 1 ++ batch = list(range(Mat.ndim - 2)) + [Mat.ndim - 1] + [Mat.ndim - 2] ++ Mat = np.transpose(Mat, batch) ++ if points.ndim == Mat.ndim - 1: ++ return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ ++ ..., -1:, :ndim ++ ].squeeze(-2) ++ elif points.ndim == Mat.ndim: ++ return ( ++ (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) ++ + Mat[..., np.newaxis, -1:, :ndim] ++ ).squeeze(-2) ++ else: ++ raise Exception("wrong shape") ++ ++def batch_nd_transform_points_pt( ++ points: torch.Tensor, Mat: torch.Tensor ++) -> torch.Tensor: ++ ndim = Mat.shape[-1] - 1 ++ Mat = torch.transpose(Mat, -1, -2) ++ if points.ndim == Mat.ndim - 1: ++ return (points[..., np.newaxis, :] @ Mat[..., :ndim, :ndim]).squeeze(-2) + Mat[ ++ ..., -1:, :ndim ++ ].squeeze(-2) ++ elif points.ndim == Mat.ndim: ++ return ( ++ (points[..., np.newaxis, :] @ Mat[..., np.newaxis, :ndim, :ndim]) ++ + Mat[..., np.newaxis, -1:, :ndim] ++ ).squeeze(-2) ++ elif points.ndim == Mat.ndim + 1: ++ return ( ++ ( ++ points[..., np.newaxis, :] ++ @ Mat[..., np.newaxis, np.newaxis, :ndim, :ndim] ++ ) ++ + Mat[..., np.newaxis, np.newaxis, -1:, :ndim] ++ ).squeeze(-2) ++ else: ++ raise Exception("wrong shape") ++ ++ ++def batch_nd_transform_angles_np(angles: np.ndarray, Mat: np.ndarray) -> np.ndarray: ++ cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] ++ rot_angle = np.arctan2(sin_vals, cos_vals) ++ angles = angles + rot_angle ++ angles = angle_wrap(angles) ++ return angles ++ ++ ++def batch_nd_transform_angles_pt( ++ angles: torch.Tensor, Mat: torch.Tensor ++) -> torch.Tensor: ++ cos_vals, sin_vals = Mat[..., 0, 0], Mat[..., 1, 0] ++ rot_angle = torch.arctan2(sin_vals, cos_vals) ++ if rot_angle.ndim > angles.ndim: ++ raise ValueError("wrong shape") ++ while rot_angle.ndim < 
angles.ndim: ++ rot_angle = rot_angle.unsqueeze(-1) ++ angles = angles + rot_angle ++ angles = angle_wrap(angles) ++ return angles ++ ++ ++def batch_nd_transform_points_angles_np( ++ points_angles: np.ndarray, Mat: np.ndarray ++) -> np.ndarray: ++ assert points_angles.shape[-1] == 3 ++ points = batch_nd_transform_points_np(points_angles[..., :2], Mat) ++ angles = batch_nd_transform_angles_np(points_angles[..., 2:3], Mat) ++ points_angles = np.concatenate([points, angles], axis=-1) ++ return points_angles ++ ++ ++def batch_nd_transform_points_angles_pt( ++ points_angles: torch.Tensor, Mat: torch.Tensor ++) -> torch.Tensor: ++ assert points_angles.shape[-1] == 3 ++ points = batch_nd_transform_points_pt(points_angles[..., :2], Mat) ++ angles = batch_nd_transform_angles_pt(points_angles[..., 2:3], Mat) ++ points_angles = torch.concat([points, angles], axis=-1) ++ return points_angles ++ ++ ++def batch_nd_transform_xyvvaahh_pt(traj_xyvvaahh: torch.Tensor, tf: torch.Tensor) -> torch.Tensor: ++ """ ++ traj_xyvvaahh: [..., state_dim] where state_dim = [x, y, vx, vy, ax, ay, sinh, cosh] ++ This is the state representation used in AgentBatch and SceneBatch. ++ """ ++ rot_only_tf = tf.clone() ++ rot_only_tf[..., :2, -1] = 0. ++ ++ xy, vv, aa, hh = torch.split(traj_xyvvaahh, (2, 2, 2, 2), dim=-1) ++ xy = batch_nd_transform_points_pt(xy, tf) ++ vv = batch_nd_transform_points_pt(vv, rot_only_tf) ++ aa = batch_nd_transform_points_pt(aa, rot_only_tf) ++ # hh: sinh, cosh instead of cosh, sinh, so we use flip ++ hh = batch_nd_transform_points_pt(hh.flip(-1), rot_only_tf).flip(-1) ++ ++ return torch.concat((xy, vv, aa, hh), dim=-1) ++ ++ ++# -------- end of redundant transforms ++ + + def agent_aware_diff(values: np.ndarray, agent_ids: np.ndarray) -> np.ndarray: + values_diff: np.ndarray = np.diff( +@@ -325,3 +453,120 @@ def quaternion_to_yaw(q: np.ndarray): + 2 * (q[..., 0] * q[..., 3] - q[..., 1] * q[..., 2]), + 1 - 2 * (q[..., 2] ** 2 + q[..., 3] ** 2), + ) ++ ++ ++def batch_select( ++ x: torch.Tensor, ++ index: torch.Tensor, ++ batch_dims: int ++) -> torch.Tensor: ++ # Indexing into tensor, treating the first `batch_dims` dimensions as batch. 
++ # Kind of: output[..., k] = x[..., index[...]] ++ ++ assert index.ndim >= batch_dims ++ assert index.ndim <= x.ndim ++ assert x.shape[:batch_dims] == index.shape[:batch_dims] ++ ++ batch_shape = x.shape[:batch_dims] ++ x_flat = x.reshape(-1, *x.shape[batch_dims:]) ++ index_flat = index.reshape(-1, *index.shape[batch_dims:]) ++ x_flat = x_flat[torch.arange(x_flat.shape[0]), index_flat] ++ x = x_flat.reshape(*batch_shape, *x_flat.shape[1:]) ++ ++ return x ++ ++ ++def roll_with_tensor(mat: torch.Tensor, shifts: torch.LongTensor, dim: int): ++ if dim < 0: ++ dim = mat.ndim + dim ++ arange1 = torch.arange(mat.shape[dim], device=shifts.device) ++ expanded_shape = [1] * dim + [-1] + [1] * (mat.ndim-dim-1) ++ arange1 = arange1.view(expanded_shape).expand(mat.shape) ++ if shifts.ndim == 1: ++ shifts = shifts.view([1] * (dim-1) + [-1]) ++ # TODO(pkarkus) assert that shift dimenesions either match mat or 1 ++ shifts = shifts.view(list(shifts.shape) + [1] * (mat.ndim-dim)) ++ ++ arange2 = (arange1 - shifts) % mat.shape[dim] ++ # print(arange2) ++ return torch.gather(mat, dim, arange2) ++ ++def round_2pi(x): ++ return (x + np.pi) % (2 * np.pi) - np.pi ++ ++def batch_proj(x, line): ++ # x:[batch,3], line:[batch,N,3] ++ line_length = line.shape[-2] ++ batch_dim = x.ndim - 1 ++ if isinstance(x, torch.Tensor): ++ delta = line[..., 0:2] - torch.unsqueeze(x[..., 0:2], dim=-2).repeat( ++ *([1] * batch_dim), line_length, 1 ++ ) ++ dis = torch.linalg.norm(delta, axis=-1) ++ idx0 = torch.argmin(dis, dim=-1) ++ idx = idx0.view(*line.shape[:-2], 1, 1).repeat( ++ *([1] * (batch_dim + 1)), line.shape[-1] ++ ) ++ line_min = torch.squeeze(torch.gather(line, -2, idx), dim=-2) ++ dx = x[..., None, 0] - line[..., 0] ++ dy = x[..., None, 1] - line[..., 1] ++ delta_y = -dx * torch.sin(line_min[..., None, 2]) + dy * torch.cos( ++ line_min[..., None, 2] ++ ) ++ delta_x = dx * torch.cos(line_min[..., None, 2]) + dy * torch.sin( ++ line_min[..., None, 2] ++ ) ++ ++ delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) ++ ++ return ( ++ delta_x, ++ delta_y, ++ torch.unsqueeze(delta_psi, dim=-1), ++ ) ++ elif isinstance(x, np.ndarray): ++ delta = line[..., 0:2] - np.repeat( ++ x[..., np.newaxis, 0:2], line_length, axis=-2 ++ ) ++ dis = np.linalg.norm(delta, axis=-1) ++ idx0 = np.argmin(dis, axis=-1) ++ idx = idx0.reshape(*line.shape[:-2], 1, 1).repeat(line.shape[-1], axis=-1) ++ line_min = np.squeeze(np.take_along_axis(line, idx, axis=-2), axis=-2) ++ dx = x[..., None, 0] - line[..., 0] ++ dy = x[..., None, 1] - line[..., 1] ++ delta_y = -dx * np.sin(line_min[..., None, 2]) + dy * np.cos( ++ line_min[..., None, 2] ++ ) ++ delta_x = dx * np.cos(line_min[..., None, 2]) + dy * np.sin( ++ line_min[..., None, 2] ++ ) ++ ++ delta_psi = round_2pi(x[..., 2] - line_min[..., 2]) ++ return ( ++ delta_x, ++ delta_y, ++ np.expand_dims(delta_psi, axis=-1), ++ ) ++ ++def get_close_lanes(radius,ego_xyh,vec_map,num_pts): ++ # obtain close lanes, their distance to the ego ++ close_lanes = [] ++ while len(close_lanes)==0: ++ close_lanes=vec_map.get_lanes_within(ego_xyh,radius) ++ radius+=20 ++ dis = list() ++ lane_pts = np.stack([lane.center.interpolate(num_pts).points[:,[0,1,3]] for lane in close_lanes],0) ++ dx,dy,dh = batch_proj(ego_xyh[None].repeat(lane_pts.shape[0],0),lane_pts) ++ ++ idx = np.abs(dx).argmin(axis=1) ++ # hausdorff distance to the lane (longitudinal) ++ x_dis = np.take_along_axis(np.abs(dx),idx[:,None],axis=1).squeeze(1) ++ x_dis[(dx.min(1)<0) & (dx.max(1)>0)] = 0 ++ ++ y_dis = 
np.take_along_axis(np.abs(dy),idx[:,None],axis=1).squeeze(1) ++ ++ # distance metric to the lane (combining x,y) ++ dis = x_dis+y_dis ++ ++ ++ return close_lanes,dis +diff --git a/src/trajdata/utils/batch_utils.py b/src/trajdata/utils/batch_utils.py +index cf63009..555e790 100644 +--- a/src/trajdata/utils/batch_utils.py ++++ b/src/trajdata/utils/batch_utils.py +@@ -1,5 +1,9 @@ + from collections import defaultdict +-from typing import Any, Dict, Iterator, List, Optional, Tuple ++ ++import torch ++ ++from pathlib import Path ++from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + + import numpy as np + from torch.utils.data import Sampler +@@ -10,10 +14,15 @@ from trajdata.data_structures import ( + AgentBatchElement, + AgentDataIndex, + AgentType, ++ SceneBatch, + SceneBatchElement, + SceneTimeAgent, + ) +-from trajdata.data_structures.collation import agent_collate_fn ++from trajdata.data_structures.collation import agent_collate_fn, batch_rotate_raster_maps_for_agents_in_scene ++from trajdata.maps import RasterizedMapPatch ++from trajdata.utils.map_utils import load_map_patch ++from trajdata.utils.arr_utils import batch_nd_transform_xyvvaahh_pt, batch_select, PadDirection ++from trajdata.caching.df_cache import DataFrameCache + + + class SceneTimeBatcher(Sampler): +@@ -173,3 +182,107 @@ def convert_to_agent_batch( + ) + + return agent_collate_fn(batch_elems, return_dict=False, pad_format=pad_format) ++ ++ ++def get_agents_map_patch( ++ cache_path: Path, ++ map_name: str, ++ patch_params: Dict[str, int], ++ agent_world_states_xyh: Union[np.ndarray, torch.Tensor], ++ allow_nan: float = False, ++) -> List[RasterizedMapPatch]: ++ ++ if isinstance(agent_world_states_xyh, torch.Tensor): ++ agent_world_states_xyh = agent_world_states_xyh.cpu().numpy() ++ assert agent_world_states_xyh.ndim == 2 ++ assert agent_world_states_xyh.shape[-1] == 3 ++ ++ desired_patch_size: int = patch_params["map_size_px"] ++ resolution: float = patch_params["px_per_m"] ++ offset_xy: Tuple[float, float] = patch_params.get("offset_frac_xy", (0.0, 0.0)) ++ return_rgb: bool = patch_params.get("return_rgb", True) ++ no_map_fill_val: float = patch_params.get("no_map_fill_value", 0.0) ++ ++ env_name, location_name = map_name.split(':') # assumes map_name format nusc_mini:boston-seaport ++ ++ map_patches = list() ++ ++ ( ++ maps_path, ++ _, ++ _, ++ raster_map_path, ++ raster_metadata_path, ++ ) = DataFrameCache.get_map_paths( ++ cache_path, env_name, location_name, resolution ++ ) ++ ++ for i in range(agent_world_states_xyh.shape[0]): ++ patch_data, raster_from_world_tf, has_data = load_map_patch( ++ raster_map_path, ++ raster_metadata_path, ++ agent_world_states_xyh[i, 0], ++ agent_world_states_xyh[i, 1], ++ desired_patch_size, ++ resolution, ++ offset_xy, ++ agent_world_states_xyh[i, 2], ++ return_rgb, ++ rot_pad_factor=np.sqrt(2), ++ no_map_val=no_map_fill_val, ++ ) ++ map_patches.append( ++ RasterizedMapPatch( ++ data=patch_data, ++ rot_angle=agent_world_states_xyh[i, 2], ++ crop_size=desired_patch_size, ++ resolution=resolution, ++ raster_from_world_tf=raster_from_world_tf, ++ has_data=has_data, ++ ) ++ ) ++ ++ return map_patches ++ ++ ++def get_raster_maps_for_scene_batch(batch: SceneBatch, cache_path: Path, raster_map_params: Dict): ++ ++ # Get current states ++ agent_states = batch.agent_hist.as_format('x,y,xd,yd,xdd,ydd,s,c') ++ if batch.history_pad_dir == PadDirection.AFTER: ++ agent_states = batch_select(agent_states, index=batch.agent_hist_len-1, batch_dims=2) # b, N, t, 8 ++ else: ++ 
agent_states = agent_states[:, :, -1] ++ ++ agent_world_states_xyvvaahh = batch_nd_transform_xyvvaahh_pt( ++ agent_states.type_as(batch.centered_world_from_agent_tf), ++ batch.centered_world_from_agent_tf ++ ) ++ ++ agent_world_states_xyh = torch.concat(( ++ agent_world_states_xyvvaahh[..., :2], ++ torch.atan2(agent_world_states_xyvvaahh[..., 6:7], agent_world_states_xyvvaahh[..., 7:8])), dim=-1) ++ ++ maps: List[torch.Tensor] = [] ++ maps_resolution: List[torch.Tensor] = [] ++ raster_from_world_tf: List[torch.Tensor] = [] ++ ++ # Collect map patches for all elements and agents into a flat list ++ num_agents: List[int] = [] ++ map_patches: List[RasterizedMapPatch] = [] ++ ++ for b_i in range(agent_world_states_xyh.shape[0]): ++ num_agents.append(batch.num_agents[b_i]) ++ map_patches += get_agents_map_patch( ++ cache_path, batch.map_names[b_i], raster_map_params, agent_world_states_xyh[b_i, :batch.num_agents[b_i]]) ++ ++ # Batch transform map patches and pad ++ ( ++ maps, ++ maps_resolution, ++ raster_from_world_tf ++ ) = batch_rotate_raster_maps_for_agents_in_scene( ++ map_patches, num_agents, agent_world_states_xyh.shape[1], pad_value=np.nan, ++ ) ++ ++ return maps, maps_resolution, raster_from_world_tf +diff --git a/src/trajdata/utils/comm_utils.py b/src/trajdata/utils/comm_utils.py +new file mode 100644 +index 0000000..594ccb0 +--- /dev/null ++++ b/src/trajdata/utils/comm_utils.py +@@ -0,0 +1,24 @@ ++import numpy as np ++import socket ++ ++from contextlib import closing ++from typing import Callable, Optional ++ ++ ++def find_open_port(): ++ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: ++ s.bind(("", 0)) ++ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) ++ return s.getsockname()[1] ++ ++ ++def find_open_port_in_range(start_port, end_port): ++ for port in range(start_port, end_port+1): ++ try: ++ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: ++ s.bind(('localhost', port)) ++ s.listen(1) ++ return port ++ except OSError: ++ continue ++ return None +diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py +index 4726537..a429f68 100644 +--- a/src/trajdata/utils/env_utils.py ++++ b/src/trajdata/utils/env_utils.py +@@ -33,6 +33,7 @@ except ModuleNotFoundError: + # This can happen if the user did not install trajdata + # with the "trajdata[nuplan]" option. 
+ pass ++from trajdata.dataset_specific.drivesim import DrivesimDataset + + try: + from trajdata.dataset_specific.waymo import WaymoDataset +@@ -60,7 +61,10 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset: + ) + + if "nuplan" in dataset_name: +- return NuplanDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) ++ return NuplanDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) ++ ++ if "drivesim" in dataset_name: ++ return DrivesimDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) + + if "waymo" in dataset_name: + return WaymoDataset(dataset_name, data_dir, parallelizable=True, has_maps=True) +diff --git a/src/trajdata/utils/map_utils.py b/src/trajdata/utils/map_utils.py +index 01844b7..45c08fd 100644 +--- a/src/trajdata/utils/map_utils.py ++++ b/src/trajdata/utils/map_utils.py +@@ -3,19 +3,300 @@ from __future__ import annotations + from typing import TYPE_CHECKING + + if TYPE_CHECKING: +- from trajdata.maps import map_kdtree, vec_map ++ from trajdata.maps import map_kdtree, vec_map, RasterizedMapMetadata, RasterizedMapPatch + + from pathlib import Path +-from typing import Dict, Final, Optional ++from typing import Dict, Final, Optional, Tuple, List, Union + + import dill ++import kornia + import numpy as np ++import math ++import torch ++import zarr + from scipy.stats import circmean + + import trajdata.proto.vectorized_map_pb2 as map_proto + from trajdata.utils import arr_utils + +-MM_PER_M: Final[float] = 1000 ++NUM_DECIMALS: Final[int] = 5 ++COMPRESSION_SCALE: Final[float] = 10**NUM_DECIMALS ++import enum ++ ++class LaneSegRelation(enum.IntEnum): ++ """ ++ Categorical token describing the relationship between an agent and a Lane ++ """ ++ NOTCONNECTED = 0 ++ NEXT = 1 ++ PREV = 2 ++ LEFT = 3 ++ RIGHT = 4 ++ ++def pad_map_patch( ++ patch: np.ndarray, ++ # top, bot, left, right ++ patch_sides: Tuple[int, int, int, int], ++ patch_size: int, ++ map_dims: Tuple[int, int, int], ++) -> np.ndarray: ++ # TODO(pkarkus) remove equivalent function from df_cache ++ ++ if patch.shape[-2:] == (patch_size, patch_size): ++ return patch ++ ++ top, bot, left, right = patch_sides ++ channels, height, width = map_dims ++ ++ # If we're off the map, just return zeros in the ++ # desired size of the patch. ++ if bot <= 0 or top >= height or right <= 0 or left >= width: ++ return np.zeros((channels, patch_size, patch_size)) ++ ++ pad_top, pad_bot, pad_left, pad_right = 0, 0, 0, 0 ++ if top < 0: ++ pad_top = 0 - top ++ if bot >= height: ++ pad_bot = bot - height ++ if left < 0: ++ pad_left = 0 - left ++ if right >= width: ++ pad_right = right - width ++ ++ return np.pad(patch, [(0, 0), (pad_top, pad_bot), (pad_left, pad_right)]) ++ ++ ++def load_map_patch( ++ raster_map_path: Path, ++ raster_metadata_path: Path, ++ world_x: float, ++ world_y: float, ++ desired_patch_size: int, ++ resolution: float, ++ offset_xy: Tuple[float, float], ++ agent_heading: float, ++ return_rgb: bool, ++ rot_pad_factor: float = 1.0, ++ no_map_val: float = 0.0, ++ allow_missing_map: bool = False, ++) -> Tuple[np.ndarray, np.ndarray, bool]: ++ ++ # TODO(pkarkus) remove equivalent function from df_cache ++ ++ if not raster_metadata_path.exists(): ++ if not allow_missing_map: ++ raise ValueError(f"Missing map at {raster_metadata_path}") ++ # This dataset (or location) does not have any maps, ++ # so we return an empty map. 
++ patch_size: int = math.ceil((rot_pad_factor * desired_patch_size) / 2) * 2 ++ ++ return ( ++ np.full( ++ (1 if not return_rgb else 3, patch_size, patch_size), ++ fill_value=no_map_val, ++ ), ++ np.eye(3), ++ False, ++ ) ++ ++ with open(raster_metadata_path, "rb") as f: ++ map_info: RasterizedMapMetadata = dill.load(f) ++ ++ raster_from_world_tf: np.ndarray = map_info.map_from_world ++ map_coords: np.ndarray = map_info.map_from_world @ np.array( ++ [world_x, world_y, 1.0] ++ ) ++ map_x, map_y = map_coords[0].item(), map_coords[1].item() ++ ++ raster_from_world_tf = ( ++ np.array( ++ [ ++ [1.0, 0.0, -map_x], ++ [0.0, 1.0, -map_y], ++ [0.0, 0.0, 1.0], ++ ] ++ ) ++ @ raster_from_world_tf ++ ) ++ ++ # This first size is how much of the map we ++ # need to extract to match the requested metric size (meters x meters) of ++ # the patch. ++ data_patch_size: int = math.ceil( ++ desired_patch_size * map_info.resolution / resolution ++ ) ++ ++ # Incorporating offsets. ++ if offset_xy != (0.0, 0.0): ++ # x is negative here because I am moving the map ++ # center so that the agent ends up where the user wishes ++ # (the agent is pinned from the end user's perspective). ++ map_offset: Tuple[float, float] = ( ++ -offset_xy[0] * data_patch_size // 2, ++ offset_xy[1] * data_patch_size // 2, ++ ) ++ ++ rotated_offset: np.ndarray = ( ++ arr_utils.rotation_matrix(agent_heading) @ map_offset ++ ) ++ ++ off_x = rotated_offset[0] ++ off_y = rotated_offset[1] ++ ++ map_x += off_x ++ map_y += off_y ++ ++ raster_from_world_tf = ( ++ np.array( ++ [ ++ [1.0, 0.0, -off_x], ++ [0.0, 1.0, -off_y], ++ [0.0, 0.0, 1.0], ++ ] ++ ) ++ @ raster_from_world_tf ++ ) ++ ++ # This is the size of the patch taking into account expansion to allow for ++ # rotation to match the agent's heading. We also ensure the final size is ++ # divisible by two so that the // 2 below does not chop any information off. ++ data_with_rot_pad_size: int = math.ceil((rot_pad_factor * data_patch_size) / 2) * 2 ++ ++ disk_data = zarr.open_array(raster_map_path, mode="r") ++ ++ map_x = round(map_x) ++ map_y = round(map_y) ++ ++ # Half of the patch's side length. 
++ half_extent: int = data_with_rot_pad_size // 2 ++ ++ top: int = map_y - half_extent ++ bot: int = map_y + half_extent ++ left: int = map_x - half_extent ++ right: int = map_x + half_extent ++ ++ data_patch: np.ndarray = pad_map_patch( ++ disk_data[ ++ ..., ++ max(top, 0) : min(bot, disk_data.shape[1]), ++ max(left, 0) : min(right, disk_data.shape[2]), ++ ], ++ (top, bot, left, right), ++ data_with_rot_pad_size, ++ disk_data.shape, ++ ) ++ ++ if return_rgb: ++ rgb_groups = map_info.layer_rgb_groups ++ data_patch = np.stack( ++ [ ++ np.amax(data_patch[rgb_groups[0]], axis=0), ++ np.amax(data_patch[rgb_groups[1]], axis=0), ++ np.amax(data_patch[rgb_groups[2]], axis=0), ++ ], ++ ) ++ ++ if desired_patch_size != data_patch_size: ++ scale_factor: float = desired_patch_size / data_patch_size ++ data_patch = ( ++ kornia.geometry.rescale( ++ torch.from_numpy(data_patch).unsqueeze(0), ++ scale_factor, ++ # Default align_corners value, just putting it to remove warnings ++ align_corners=False, ++ antialias=True, ++ ) ++ .squeeze(0) ++ .numpy() ++ ) ++ ++ raster_from_world_tf = ( ++ np.array( ++ [ ++ [1 / scale_factor, 0.0, 0.0], ++ [0.0, 1 / scale_factor, 0.0], ++ [0.0, 0.0, 1.0], ++ ] ++ ) ++ @ raster_from_world_tf ++ ) ++ ++ return data_patch, raster_from_world_tf, True ++ ++ ++def batch_transform_raster_maps( ++ map_patches: List[RasterizedMapPatch], ++): ++ ++ patch_size: int = map_patches[0].crop_size ++ assert all( ++ x.crop_size == patch_size for x in map_patches ++ ) ++ ++ agents_rasters_from_world_tfs: List[np.ndarray] = [ ++ x.raster_from_world_tf for x in map_patches ++ ] ++ agents_patches: List[np.ndarray] = [x.data for x in map_patches] ++ agents_rot_angles_list: List[float] = [ ++ x.rot_angle for x in map_patches] ++ agents_res_list: List[float] = [x.resolution for x in map_patches] ++ ++ patch_data: torch.Tensor = torch.as_tensor(np.stack(agents_patches), dtype=torch.float) ++ agents_rot_angles: torch.Tensor = torch.as_tensor( ++ np.stack(agents_rot_angles_list), dtype=torch.float ++ ) ++ agents_rasters_from_world_tf: torch.Tensor = torch.as_tensor( ++ np.stack(agents_rasters_from_world_tfs), dtype=torch.float ++ ) ++ agents_resolution: torch.Tensor = torch.as_tensor( ++ np.stack(agents_res_list), dtype=torch.float ++ ) ++ ++ patch_size_y, patch_size_x = patch_data.shape[-2:] ++ center_y: int = patch_size_y // 2 ++ center_x: int = patch_size_x // 2 ++ half_extent: int = patch_size // 2 ++ ++ if torch.count_nonzero(agents_rot_angles) == 0: ++ agents_rasters_from_world_tf = torch.bmm( ++ torch.tensor( ++ [ ++ [ ++ [1.0, 0.0, half_extent], ++ [0.0, 1.0, half_extent], ++ [0.0, 0.0, 1.0], ++ ] ++ ], ++ dtype=agents_rasters_from_world_tf.dtype, ++ device=agents_rasters_from_world_tf.device, ++ ).expand((agents_rasters_from_world_tf.shape[0], -1, -1)), ++ agents_rasters_from_world_tf, ++ ) ++ ++ rot_crop_patches = patch_data ++ else: ++ agents_rasters_from_world_tf = torch.bmm( ++ arr_utils.transform_matrices( ++ -agents_rot_angles, ++ torch.tensor([[half_extent, half_extent]]).expand( ++ (agents_rot_angles.shape[0], -1) ++ ), ++ ), ++ agents_rasters_from_world_tf, ++ ) ++ ++ # Batch rotating patches by rot_angles. ++ rot_patches: torch.Tensor = kornia.geometry.transform.rotate( ++ patch_data, torch.rad2deg(agents_rot_angles)) ++ ++ # Center cropping via slicing. 
++ rot_crop_patches = rot_patches[ ++ ..., ++ center_y - half_extent : center_y + half_extent, ++ center_x - half_extent : center_x + half_extent, ++ ] ++ ++ return rot_crop_patches, agents_resolution, agents_rasters_from_world_tf + + + def decompress_values(data: np.ndarray) -> np.ndarray: +@@ -23,11 +304,11 @@ def decompress_values(data: np.ndarray) -> np.ndarray: + # The delta for the first point is just its coordinates tuple, i.e. it is a "delta" from + # the origin. For subsequent points, this field stores the difference between the point's + # coordinates and the previous point's coordinates. This is for representation efficiency. +- return np.cumsum(data, axis=0, dtype=float) / MM_PER_M ++ return np.cumsum(data, axis=0, dtype=float) / COMPRESSION_SCALE + + + def compress_values(data: np.ndarray) -> np.ndarray: +- return (np.diff(data, axis=0, prepend=0.0) * MM_PER_M).astype(np.int32) ++ return (np.diff(data, axis=0, prepend=0.0) * COMPRESSION_SCALE).astype(np.int64) + + + def get_polyline_headings(points: np.ndarray) -> np.ndarray: +diff --git a/src/trajdata/utils/py_utils.py b/src/trajdata/utils/py_utils.py +new file mode 100644 +index 0000000..c302aab +--- /dev/null ++++ b/src/trajdata/utils/py_utils.py +@@ -0,0 +1,13 @@ ++import hashlib ++import json ++from typing import Dict, List, Set, Tuple, Union ++ ++ ++def hash_dict(o: Union[Dict, List, Tuple, Set]) -> str: ++ """ ++ Makes a hash from a dictionary, list, tuple or set to any level, that contains ++ only other hashable types (including any lists, tuples, sets, and ++ dictionaries). ++ """ ++ string_rep: str = json.dumps(o) ++ return hashlib.sha1(str.encode(string_rep)).hexdigest() +diff --git a/src/trajdata/utils/scene_utils.py b/src/trajdata/utils/scene_utils.py +index 804c498..1c29bc5 100644 +--- a/src/trajdata/utils/scene_utils.py ++++ b/src/trajdata/utils/scene_utils.py +@@ -60,12 +60,13 @@ def interpolate_scene_dt(scene: Scene, desired_dt: float) -> None: + + def subsample_scene_dt(scene: Scene, desired_dt: float) -> None: + dt_ratio: float = desired_dt / scene.dt +- if not dt_ratio.is_integer(): ++ ++ if not is_integer_robust(dt_ratio): + raise ValueError( + f"Cannot subsample scene: {desired_dt} is not integer divisible by {scene.dt} for {str(scene)}" + ) + +- dt_factor: int = int(dt_ratio) ++ dt_factor: int = int(round(dt_ratio)) + + # E.g., the scene is currently at dt = 0.1s (10 Hz), + # but we want desired_dt = 0.5s (2 Hz). +@@ -86,3 +87,6 @@ def subsample_scene_dt(scene: Scene, desired_dt: float) -> None: + scene.dt = desired_dt + # Note we do not touch scene_info.env_metadata.dt, this will serve as our + # source of the "original" data dt information. 
++ ++def is_integer_robust(x): ++ return abs(x-round(x))<1e-6 +\ No newline at end of file +diff --git a/src/trajdata/utils/vis_utils.py b/src/trajdata/utils/vis_utils.py +index aab5acb..54fc036 100644 +--- a/src/trajdata/utils/vis_utils.py ++++ b/src/trajdata/utils/vis_utils.py +@@ -1,12 +1,15 @@ + from collections import defaultdict +-from typing import List, Optional, Tuple ++from typing import List, Optional, Tuple, Dict ++ + + import geopandas as gpd + import numpy as np + import pandas as pd + import seaborn as sns + from bokeh.models import ColumnDataSource, GlyphRenderer +-from bokeh.plotting import figure ++from bokeh.plotting import figure, curdoc ++import bokeh ++from bokeh.io import export_png + from shapely.geometry import LineString, Polygon + + from trajdata.data_structures.agent import AgentType +@@ -20,7 +23,12 @@ from trajdata.maps.vec_map_elements import ( + RoadArea, + RoadLane, + ) +-from trajdata.utils.arr_utils import transform_coords_2d_np ++from trajdata.utils.arr_utils import ( ++ transform_coords_2d_np, ++ batch_nd_transform_points_pt, ++ batch_nd_transform_points_np, ++) ++from PIL import Image + + + def apply_default_settings(fig: figure) -> None: +@@ -481,7 +489,7 @@ def draw_map_elems( + vec_map: VectorMap, + map_from_world_tf: np.ndarray, + bbox: Optional[Tuple[float, float, float, float]] = None, +- **kwargs ++ **kwargs, + ) -> Tuple[GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer, GlyphRenderer]: + """_summary_ + +diff --git a/tests/Drivesim_scene_generation.py b/tests/Drivesim_scene_generation.py +new file mode 100644 +index 0000000..fadd074 +--- /dev/null ++++ b/tests/Drivesim_scene_generation.py +@@ -0,0 +1,207 @@ ++from collections import namedtuple ++from typing import Any, List, Optional ++from collections import defaultdict ++import numpy as np ++import pandas as pd ++ ++from trajdata.data_structures.agent import AgentMetadata ++from trajdata.data_structures.environment import EnvMetadata ++from trajdata.data_structures.scene_metadata import Scene ++from trajdata import AgentType ++from trajdata.data_structures.agent import FixedExtent ++from trajdata.caching.scene_cache import SceneCache ++from trajdata.caching.env_cache import EnvCache ++from trajdata.caching.df_cache import DataFrameCache ++from trajdata.dataset_specific.scene_records import DrivesimSceneRecord ++from pathlib import Path ++import dill ++from trajdata import MapAPI, VectorMap ++import tbsim.utils.lane_utils as LaneUtils ++from bokeh.plotting import figure, show, save ++import bokeh ++ ++def generate_agentmeta(initial_state,agent_names,agent_extents,T,dt,hist_types): ++ total_agent_data = list() ++ total_agent_metadata = list() ++ for x0,name,extent,htype in zip(initial_state,agent_names,agent_extents,hist_types): ++ agent_meta = AgentMetadata(name=name, ++ agent_type= AgentType.VEHICLE, ++ first_timestep = 0, ++ last_timestep = T-1, ++ extent=extent) ++ x,y,v,yaw = x0 ++ if htype=="constvel": ++ vx = v*np.cos(yaw) ++ vy = v*np.sin(yaw) ++ ax = np.zeros_like(vx) ++ ay = np.zeros_like(vy) ++ yaw = yaw*np.ones(T) ++ scene_ts = np.arange(0,T) ++ x = x+vx*(scene_ts-T+1)*dt ++ y = y+vy*(scene_ts-T+1)*dt ++ elif htype in ["brake","accelerate"]: ++ acce = -2.0 if htype=="brake" else 2.0 ++ seq = np.concatenate([np.ones(T-int(T/2))*int(T/2),np.arange(int(T/2)-1,-1,-1)]) ++ vseq = np.clip(v-acce*seq*dt,0.0,10.0) ++ vx = vseq*np.cos(yaw) ++ vy = vseq*np.sin(yaw) ++ ax = vx[1:]-vx[:-1] ++ ax = np.concatenate([ax,ax[-1:]]) ++ ay = vy[1:]-vy[:-1] ++ ay = np.concatenate([ay,ay[-1:]]) 
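++            # Positions are integrated backwards from the specified state below,
++            # so the agent reaches (x, y) with speed v exactly at the last
++            # history step (scene_ts = T - 1).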
++ yaw = yaw*np.ones(T) ++ scene_ts = np.arange(0,T) ++ x = x+(vx.cumsum()-vx.sum())*dt ++ y = y+(vy.cumsum()-vy.sum())*dt ++ ++ track_id = [name]*T ++ z = np.zeros(T) ++ ++ ++ pd_frame = pd.DataFrame({"agent_id":track_id,"scene_ts":scene_ts,"x":x,"y":y,"z":z,"heading":yaw,"vx":vx,"vy":vy,"ax":ax,"ay":ay}) ++ total_agent_data.append(pd_frame) ++ total_agent_metadata.append(agent_meta) ++ total_agent_data = pd.concat(total_agent_data).set_index(["agent_id", "scene_ts"]) ++ return total_agent_data,total_agent_metadata ++ ++ ++ ++def generate_drivesim_scene(): ++ ++ repeat = 10 ++ ++ dt = 0.1 ++ data_dir = "" ++ T = 20 ++ cache_path = Path("/home/yuxiaoc/.unified_data_cache") ++ env_metadata = EnvMetadata("drivesim", ++ data_dir, ++ dt, ++ parts=[("train",),("main",)], ++ scene_split_map=defaultdict(lambda: "train"), ++ map_locations=("main",)) ++ env_cache = EnvCache(cache_path) ++ scene_records = list() ++ ++ agent_initial_state = [(np.array([-565,-1001,3.0,0]),"accelerate"), ++ (np.array([-573,-1001,3.0,0]),"accelerate"), ++ (np.array([-573,-1005,3.0,0.0]),"accelerate"), ++ (np.array([-530.0,-976,0.0,-0.95*np.pi/2]),"constvel"), ++ (np.array([-526.4,-976,0.0,-0.95*np.pi/2]),"brake"), ++ (np.array([-507.8,-1027,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-507.8,-1021,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-504,-1021,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-500.5,-1021,0.0,np.pi/2+0.08]),"brake"), ++ (np.array([-480.5,-992,0.0,np.pi]),"brake"), ++ (np.array([-473,-992,0.0,np.pi]),"brake"), ++ (np.array([-489,-1049,8.0,np.pi/2-0.1]),"constvel"), ++ (np.array([-488.4,-983,5.0,np.pi*0.75]),"accelerate"), ++ (np.array([-524.2,-964.6,2.0,-0.95*np.pi/2]),"brake"), ++ (np.array([-523.3,-976,0.0,-0.95*np.pi/2]),"brake"), ++ (np.array([-519.5,-975,0.0,-0.92*np.pi/2]),"brake"), ++ (np.array([-563.4,-1011,8.0,-0.18*np.pi/2]),"constvel"), ++ ] ++ ++ noise_spec = [1,1,1,2,2,3,3,3] ++ for r in range(repeat): ++ group1_xn =np.random.randn()*5 ++ group1_yn = np.random.randn()*0.5 ++ group1_vn = np.random.randn()*0.5 ++ group1_psin = np.random.randn()*0.01 ++ group2_xn = np.random.randn()*0.2 ++ group2_yn = np.random.randn()*0.2 ++ group2_vn = 0 ++ group2_psin = np.random.randn()*0.01 ++ group3_xn = np.random.randn()*0.2 ++ group3_yn = np.random.randn()*0.2 ++ group3_vn = 0 ++ group3_psin = np.random.randn()*0.01 ++ group1_n = np.array([group1_xn,group1_yn,group1_vn,group1_psin]) ++ group2_n = np.array([group2_xn,group2_yn,group2_vn,group2_psin]) ++ group3_n = np.array([group3_xn,group3_yn,group3_vn,group3_psin]) ++ init_state=np.stack([x0 for x0,_ in agent_initial_state]) ++ hist_types = [htype for _,htype in agent_initial_state] ++ for i in range(len(init_state)): ++ if i None: ++ super().__init__(methodName) ++ ++ data_source = "nusc_mini" ++ history_sec = 2.0 ++ prediction_sec = 6.0 ++ ++ attention_radius = defaultdict( ++ lambda: 20.0 ++ ) # Default range is 20m unless otherwise specified. 
++ attention_radius[(AgentType.PEDESTRIAN, AgentType.PEDESTRIAN)] = 10.0 ++ attention_radius[(AgentType.PEDESTRIAN, AgentType.VEHICLE)] = 20.0 ++ attention_radius[(AgentType.VEHICLE, AgentType.PEDESTRIAN)] = 20.0 ++ attention_radius[(AgentType.VEHICLE, AgentType.VEHICLE)] = 30.0 ++ ++ self._map_params = {"px_per_m": 2, "map_size_px": 100, "offset_frac_xy": (-0.75, 0.0)} ++ ++ self._scene_dataset = UnifiedDataset( ++ centric="scene", ++ desired_data=[data_source], ++ history_sec=(history_sec, history_sec), ++ future_sec=(prediction_sec, prediction_sec), ++ agent_interaction_distances=attention_radius, ++ incl_robot_future=False, ++ incl_raster_map=True, ++ raster_map_params=self._map_params, ++ only_predict=[AgentType.VEHICLE, AgentType.PEDESTRIAN], ++ no_types=[AgentType.UNKNOWN], ++ num_workers=0, ++ standardize_data=True, ++ data_dirs={ ++ "nusc_mini": "~/datasets/nuScenes", ++ }, ++ ) ++ ++ self._scene_dataloader = DataLoader( ++ self._scene_dataset, ++ batch_size=4, ++ shuffle=False, ++ collate_fn=self._scene_dataset.get_collate_fn(), ++ num_workers=0, ++ ) ++ ++ def _assert_allclose_with_nans(self, tensor1, tensor2, atol=1e-8): ++ """ ++ asserts that the two tensors have nans in the same locations, and the non-nan ++ elements all are close. ++ """ ++ # Check nans are in the same place ++ self.assertFalse( ++ torch.any( # True if there's any mismatch ++ torch.logical_xor( # True where either tensor1 or tensor 2 has nans, but not both (mismatch) ++ torch.isnan(tensor1), # True where tensor1 has nans ++ torch.isnan(tensor2), # True where tensor2 has nans ++ ) ++ ), ++ msg="Nans occur in different places.", ++ ) ++ valid_mask = torch.logical_not(torch.isnan(tensor1)) ++ self.assertTrue( ++ torch.allclose(tensor1[valid_mask], tensor2[valid_mask], atol=atol), ++ msg="Non-nan values don't match.", ++ ) ++ ++ def test_map_transform_scenebatch(self): ++ scene_batch: SceneBatch ++ for i, scene_batch in enumerate(self._scene_dataloader): ++ ++ # Make the tf double for more accurate transform. ++ scene_batch.centered_world_from_agent_tf = scene_batch.centered_world_from_agent_tf.double() ++ ++ maps, maps_resolution, raster_from_world_tf = get_raster_maps_for_scene_batch( ++ scene_batch, self._scene_dataset.cache_path, "nusc_mini", self._map_params) ++ ++ self._assert_allclose_with_nans(scene_batch.rasters_from_world_tf, raster_from_world_tf, atol=1e-2) ++ self._assert_allclose_with_nans(scene_batch.maps_resolution, maps_resolution) ++ self._assert_allclose_with_nans(scene_batch.maps, maps, atol=1e-4) ++ ++ if i > 50: ++ break ++ ++if __name__ == "__main__": ++ unittest.main(catchbreak=False) diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 6bd7408..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,6 +0,0 @@ -[build-system] -requires = [ - "setuptools>=58", - "wheel" -] -build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 0baea75..f55ac42 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,36 @@ -torch==1.12.1 -numpy -scipy -dill -pathos==0.2.9 +matplotlib==3.3.4 +seaborn==0.11.1 +# numpy +# scipy +# scikit-learn==0.24.1 +# torch==1.10.2 +pyquaternion==0.9.9 +pytest==6.2.2 orjson==3.5.1 ncls==0.0.57 +dill==0.3.5.1 +pathos==0.2.9 +tqdm>=4.53.0 +notebook==6.2.0 +nuscenes-devkit==1.1.5 +pykalman==0.9.5 +# You will likely need to source deactivate and activate +# again to get the correct tensorboard version used. 
+gym>=0.18.3 +stable-baselines3==1.1.0a11 # pip install git+https://github.com/DLR-RM/stable-baselines3 +requests==2.28.2 +pympler==1.0.1 +moviepy==1.0.3 wandb imageio==2.9.0 -tqdm -matplotlib<=3.7 +ipdb +l5kit==1.5.0 +networkx + +community +pytorch-lightning==2.0.5 +# You will likely need to uninstall opencv-python and install an earlier version instead because of a compatibility issue +# https://github.com/opencv/opencv-python/issues/591 +# pip uninstall opencv-python -y +# pip install "opencv-python-headless<4.3" +hydra-core==1.3.2 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 9bfc3f9..0000000 --- a/setup.cfg +++ /dev/null @@ -1,43 +0,0 @@ -[metadata] -name = diffstack -version = 0.0.2 -author = Peter Karkus -author_email = pkarkus@nvidia.com -description = DiffStack: a differentiable prediction, planning, control stack for autonomous driving. -long_description = file: README.md -long_description_content_type = text/markdown -license = NSCL -url = https://github.com/NVlabs/diffstack -classifiers = - Development Status :: 3 - Alpha - Intended Audience :: Developers - Programming Language :: Python :: 3.9 - -[options] -package_dir = - = ./ -packages = find: -python_requires = >=3.9 -install_requires = - numpy>=1.19 - tqdm>=4.62 - matplotlib>=3.5 - dill>=0.3.4 - pandas>=1.4.1 - pyarrow>=7.0.0 - torch>=1.13.1 - zarr>=2.11.0 - kornia>=0.6.4 - pathos>=0.2.9 - seaborn>=0.12 - protobuf>=3.19.4 # for trajdata map api - orjson>=3.5.1 - ncls>=0.0.57 - wandb - mpc @ git+https://github.com/locuslab/mpc.pytorch.git - trajdata[nusc] @ git+https://github.com/NVlabs/trajdata.git@d714b82dadec80c62e6413ae9a0feb42a517be57 - -[options.packages.find] -where = ./ - - diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..46bd280 --- /dev/null +++ b/setup.py @@ -0,0 +1,29 @@ +from setuptools import setup, find_packages + +# read the contents of your README file +from os import path + +this_directory = path.abspath(path.dirname(__file__)) + + +# remove images from README + +setup( + name="diffstack", + packages=[ + package for package in find_packages() if package.startswith("diffstack") + ], + install_requires=[ + "wandb", + "pytorch-lightning", + ], + eager_resources=["*"], + include_package_data=True, + python_requires=">=3", + description="diffstack", + author="NVIDIA AV Research", + author_email="", + version="0.0.1", + long_description="", + long_description_content_type="text/markdown", +)
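+
+# Development install (typical usage): pip install -e .
+# Only a minimal install_requires is declared above; the full environment is
+# pinned in requirements.txt.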